modeld: add model names to metadata (#34602)

* modeld: add model names to metadata

* lint

* type hint

* oops

* assert

* ok Any
pull/34611/head
YassineYousfi 2 months ago committed by GitHub
parent 07ef523ec1
commit f06701ea24
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 22
      selfdrive/modeld/get_model_metadata.py

@ -4,22 +4,32 @@ import pathlib
import onnx import onnx
import codecs import codecs
import pickle import pickle
from typing import Any
def get_name_and_shape(value_info:onnx.ValueInfoProto) -> tuple[str, tuple[int,...]]: def get_name_and_shape(value_info:onnx.ValueInfoProto) -> tuple[str, tuple[int,...]]:
shape = tuple([int(dim.dim_value) for dim in value_info.type.tensor_type.shape.dim]) shape = tuple([int(dim.dim_value) for dim in value_info.type.tensor_type.shape.dim])
name = value_info.name name = value_info.name
return name, shape return name, shape
def get_metadata_value_by_name(model:onnx.ModelProto, name:str) -> str | Any:
for prop in model.metadata_props:
if prop.key == name:
return prop.value
return None
if __name__ == "__main__": if __name__ == "__main__":
model_path = pathlib.Path(sys.argv[1]) model_path = pathlib.Path(sys.argv[1])
model = onnx.load(str(model_path)) model = onnx.load(str(model_path))
i = [x.key for x in model.metadata_props].index('output_slices') output_slices = get_metadata_value_by_name(model, 'output_slices')
output_slices = model.metadata_props[i].value assert output_slices is not None, 'output_slices not found in metadata'
metadata = {} metadata = {
metadata['output_slices'] = pickle.loads(codecs.decode(output_slices.encode(), "base64")) 'policy_model': get_metadata_value_by_name(model, 'policy_model'),
metadata['input_shapes'] = dict([get_name_and_shape(x) for x in model.graph.input]) 'vision_model': get_metadata_value_by_name(model, 'vision_model'),
metadata['output_shapes'] = dict([get_name_and_shape(x) for x in model.graph.output]) 'output_slices': pickle.loads(codecs.decode(output_slices.encode(), "base64")),
'input_shapes': dict([get_name_and_shape(x) for x in model.graph.input]),
'output_shapes': dict([get_name_and_shape(x) for x in model.graph.output])
}
metadata_path = model_path.parent / (model_path.stem + '_metadata.pkl') metadata_path = model_path.parent / (model_path.stem + '_metadata.pkl')
with open(metadata_path, 'wb') as f: with open(metadata_path, 'wb') as f:

Loading…
Cancel
Save