From f06701ea24c58acbc3412080de12132dd625a02d Mon Sep 17 00:00:00 2001 From: YassineYousfi Date: Mon, 17 Feb 2025 14:15:34 -0800 Subject: [PATCH] modeld: add model names to metadata (#34602) * modeld: add model names to metadata * lint * type hint * oops * assert * ok Any --- selfdrive/modeld/get_model_metadata.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/selfdrive/modeld/get_model_metadata.py b/selfdrive/modeld/get_model_metadata.py index 144860204f..0f1fd2a98b 100755 --- a/selfdrive/modeld/get_model_metadata.py +++ b/selfdrive/modeld/get_model_metadata.py @@ -4,22 +4,32 @@ import pathlib import onnx import codecs import pickle +from typing import Any def get_name_and_shape(value_info:onnx.ValueInfoProto) -> tuple[str, tuple[int,...]]: shape = tuple([int(dim.dim_value) for dim in value_info.type.tensor_type.shape.dim]) name = value_info.name return name, shape +def get_metadata_value_by_name(model:onnx.ModelProto, name:str) -> str | Any: + for prop in model.metadata_props: + if prop.key == name: + return prop.value + return None + if __name__ == "__main__": model_path = pathlib.Path(sys.argv[1]) model = onnx.load(str(model_path)) - i = [x.key for x in model.metadata_props].index('output_slices') - output_slices = model.metadata_props[i].value + output_slices = get_metadata_value_by_name(model, 'output_slices') + assert output_slices is not None, 'output_slices not found in metadata' - metadata = {} - metadata['output_slices'] = pickle.loads(codecs.decode(output_slices.encode(), "base64")) - metadata['input_shapes'] = dict([get_name_and_shape(x) for x in model.graph.input]) - metadata['output_shapes'] = dict([get_name_and_shape(x) for x in model.graph.output]) + metadata = { + 'policy_model': get_metadata_value_by_name(model, 'policy_model'), + 'vision_model': get_metadata_value_by_name(model, 'vision_model'), + 'output_slices': pickle.loads(codecs.decode(output_slices.encode(), "base64")), + 'input_shapes': dict([get_name_and_shape(x) for x in model.graph.input]), + 'output_shapes': dict([get_name_and_shape(x) for x in model.graph.output]) + } metadata_path = model_path.parent / (model_path.stem + '_metadata.pkl') with open(metadata_path, 'wb') as f: