onnxmodel: fp16_to_fp32 (#30080)

* onnxmodel: force fp32

* rename

* rename this too
old-commit-hash: 04e239f7ed
YassineYousfi 2 years ago committed by GitHub
parent eb0530f6dd
commit a3f3e0c122
1 changed file, 29 changes:
  selfdrive/modeld/runners/onnxmodel.py

@@ -1,3 +1,5 @@
+import onnx
+import itertools
 import os
 import sys
 import numpy as np
@@ -7,7 +9,27 @@ from openpilot.selfdrive.modeld.runners.runmodel_pyx import RunModel
 ORT_TYPES_TO_NP_TYPES = {'tensor(float16)': np.float16, 'tensor(float)': np.float32, 'tensor(uint8)': np.uint8}
 
-def create_ort_session(path):
+def attributeproto_fp16_to_fp32(attr):
+  float32_list = np.frombuffer(attr.raw_data, dtype=np.float16)
+  attr.data_type = 1
+  attr.raw_data = float32_list.astype(np.float32).tobytes()
+
+def convert_fp16_to_fp32(path):
+  model = onnx.load(path)
+  for i in model.graph.initializer:
+    if i.data_type == 10:
+      attributeproto_fp16_to_fp32(i)
+  for i in itertools.chain(model.graph.input, model.graph.output):
+    if i.type.tensor_type.elem_type == 10:
+      i.type.tensor_type.elem_type = 1
+  for i in model.graph.node:
+    for a in i.attribute:
+      if hasattr(a, 't'):
+        if a.t.data_type == 10:
+          attributeproto_fp16_to_fp32(a.t)
+  return model.SerializeToString()
+
+def create_ort_session(path, fp16_to_fp32):
   os.environ["OMP_NUM_THREADS"] = "4"
   os.environ["OMP_WAIT_POLICY"] = "PASSIVE"
@@ -28,8 +50,9 @@ def create_ort_session(path):
   options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
   provider = 'CPUExecutionProvider'
 
+  model_data = convert_fp16_to_fp32(path) if fp16_to_fp32 else path
   print("Onnx selected provider: ", [provider], file=sys.stderr)
-  ort_session = ort.InferenceSession(path, options, providers=[provider])
+  ort_session = ort.InferenceSession(model_data, options, providers=[provider])
   print("Onnx using ", ort_session.get_providers(), file=sys.stderr)
   return ort_session
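
ort.InferenceSession accepts either a file path or serialized model bytes as its first argument, which is why the converted protobuf can be handed to onnxruntime directly without writing a temporary file. A hedged sketch of that code path outside the runner, again with a placeholder path:

import onnxruntime as ort

from openpilot.selfdrive.modeld.runners.onnxmodel import convert_fp16_to_fp32

model_data = convert_fp16_to_fp32("model_fp16.onnx")  # placeholder path; returns serialized fp32 bytes
sess = ort.InferenceSession(model_data, providers=['CPUExecutionProvider'])
print(sess.get_providers())  # expected to report CPUExecutionProvider
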
@@ -40,7 +63,7 @@ class ONNXModel(RunModel):
     self.output = output
     self.use_tf8 = use_tf8
-    self.session = create_ort_session(path)
+    self.session = create_ort_session(path, fp16_to_fp32=True)
     self.input_names = [x.name for x in self.session.get_inputs()]
     self.input_shapes = {x.name: [1, *x.shape[1:]] for x in self.session.get_inputs()}
     self.input_dtypes = {x.name: ORT_TYPES_TO_NP_TYPES[x.type] for x in self.session.get_inputs()}
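
Because ONNXModel now always passes fp16_to_fp32=True, a model stored with fp16 weights still surfaces float32 inputs to the runner, so the existing ORT_TYPES_TO_NP_TYPES lookup resolves to np.float32 when the input buffers are prepared. A small sketch of that expectation, assuming a placeholder fp16 model:

from openpilot.selfdrive.modeld.runners.onnxmodel import ORT_TYPES_TO_NP_TYPES, create_ort_session

sess = create_ort_session("model_fp16.onnx", fp16_to_fp32=True)  # placeholder path
for x in sess.get_inputs():
  # once the graph IO is widened, x.type should be 'tensor(float)' -> np.float32
  print(x.name, ORT_TYPES_TO_NP_TYPES[x.type])
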
