diff --git a/selfdrive/modeld/runners/onnxmodel.py b/selfdrive/modeld/runners/onnxmodel.py index 3c20a39760..f86bee3878 100644 --- a/selfdrive/modeld/runners/onnxmodel.py +++ b/selfdrive/modeld/runners/onnxmodel.py @@ -1,3 +1,5 @@ +import onnx +import itertools import os import sys import numpy as np @@ -7,7 +9,27 @@ from openpilot.selfdrive.modeld.runners.runmodel_pyx import RunModel ORT_TYPES_TO_NP_TYPES = {'tensor(float16)': np.float16, 'tensor(float)': np.float32, 'tensor(uint8)': np.uint8} -def create_ort_session(path): +def attributeproto_fp16_to_fp32(attr): + float32_list = np.frombuffer(attr.raw_data, dtype=np.float16) + attr.data_type = 1 + attr.raw_data = float32_list.astype(np.float32).tobytes() + +def convert_fp16_to_fp32(path): + model = onnx.load(path) + for i in model.graph.initializer: + if i.data_type == 10: + attributeproto_fp16_to_fp32(i) + for i in itertools.chain(model.graph.input, model.graph.output): + if i.type.tensor_type.elem_type == 10: + i.type.tensor_type.elem_type = 1 + for i in model.graph.node: + for a in i.attribute: + if hasattr(a, 't'): + if a.t.data_type == 10: + attributeproto_fp16_to_fp32(a.t) + return model.SerializeToString() + +def create_ort_session(path, fp16_to_fp32): os.environ["OMP_NUM_THREADS"] = "4" os.environ["OMP_WAIT_POLICY"] = "PASSIVE" @@ -28,8 +50,9 @@ def create_ort_session(path): options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL provider = 'CPUExecutionProvider' + model_data = convert_fp16_to_fp32(path) if fp16_to_fp32 else path print("Onnx selected provider: ", [provider], file=sys.stderr) - ort_session = ort.InferenceSession(path, options, providers=[provider]) + ort_session = ort.InferenceSession(model_data, options, providers=[provider]) print("Onnx using ", ort_session.get_providers(), file=sys.stderr) return ort_session @@ -40,7 +63,7 @@ class ONNXModel(RunModel): self.output = output self.use_tf8 = use_tf8 - self.session = create_ort_session(path) + self.session = create_ort_session(path, fp16_to_fp32=True) self.input_names = [x.name for x in self.session.get_inputs()] self.input_shapes = {x.name: [1, *x.shape[1:]] for x in self.session.get_inputs()} self.input_dtypes = {x.name: ORT_TYPES_TO_NP_TYPES[x.type] for x in self.session.get_inputs()}