#!/usr/bin/env python3
"""Extract metadata from an ONNX model and save it as a pickle file.

Usage: pass the path to the ONNX model as the first command-line argument.

The script reads the ``policy_model``, ``vision_model`` and ``output_slices``
entries from the model's ``metadata_props``, collects the names and shapes of
the graph inputs and outputs, and writes everything as a dict to
``<model stem>_metadata.pkl`` in the same directory as the model.
"""
import sys
import pathlib
import onnx
import codecs
import pickle
from typing import Any


def get_name_and_shape(value_info: onnx.ValueInfoProto) -> tuple[str, tuple[int, ...]]:
    """Return the name and the static shape of a graph input/output.

    Args:
        value_info: The ONNX value-info entry to inspect.

    Returns:
        A ``(name, shape)`` pair where ``shape`` holds ``dim_value`` of every
        dimension of the tensor type, converted to ``int``.
    """
    # NOTE(review): dimensions without a concrete ``dim_value`` (e.g. symbolic
    # ``dim_param`` dimensions) presumably show up as 0 here - confirm whether
    # consumers of the metadata rely on that.
    shape = tuple(
        int(dim.dim_value)
        for dim in value_info.type.tensor_type.shape.dim
    )
    return value_info.name, shape


def get_metadata_value_by_name(model: onnx.ModelProto, name: str) -> str | None:
    """Look up a value in the model's ``metadata_props`` by key.

    Args:
        model: The loaded ONNX model.
        name: The metadata key to search for.

    Returns:
        The value of the first property whose key equals ``name``, or ``None``
        when no such property exists.
    """
    for prop in model.metadata_props:
        if prop.key == name:
            return prop.value
    return None


def main(argv: list[str]) -> int:
    """Build the metadata dict for the model named in ``argv[1]`` and pickle it.

    Args:
        argv: Command-line arguments in ``sys.argv`` layout (program name
            first, model path second). Additional arguments are ignored.

    Returns:
        Process exit code: 0 on success, 2 when the model path is missing.

    Raises:
        ValueError: If the model has no ``output_slices`` metadata entry.
    """
    if len(argv) < 2:
        program = argv[0] if argv else 'script'
        print(f'usage: {program} <model.onnx>', file=sys.stderr)
        return 2

    model_path = pathlib.Path(argv[1])
    model = onnx.load(str(model_path))

    output_slices = get_metadata_value_by_name(model, 'output_slices')
    # Explicit check instead of ``assert``: asserts are stripped under ``-O``,
    # which would turn a missing entry into an obscure AttributeError below.
    if output_slices is None:
        raise ValueError('output_slices not found in metadata')

    # SECURITY: ``pickle.loads`` can execute arbitrary code embedded in the
    # payload. The payload comes from the model file itself, so only run this
    # script on models from a trusted source.
    decoded_slices = pickle.loads(codecs.decode(output_slices.encode(), "base64"))

    metadata = {
        'policy_model': get_metadata_value_by_name(model, 'policy_model'),
        'vision_model': get_metadata_value_by_name(model, 'vision_model'),
        'output_slices': decoded_slices,
        'input_shapes': dict(get_name_and_shape(x) for x in model.graph.input),
        'output_shapes': dict(get_name_and_shape(x) for x in model.graph.output),
    }

    # The metadata file lives next to the model: "<stem>_metadata.pkl".
    metadata_path = model_path.parent / (model_path.stem + '_metadata.pkl')
    with open(metadata_path, 'wb') as f:
        pickle.dump(metadata, f)

    print(f'saved metadata to {metadata_path}')
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv))