|
|
|
@ -4,22 +4,32 @@ import pathlib |
|
|
|
|
import onnx |
|
|
|
|
import codecs |
|
|
|
|
import pickle |
|
|
|
|
from typing import Any |
|
|
|
|
|
|
|
|
|
def get_name_and_shape(value_info:onnx.ValueInfoProto) -> tuple[str, tuple[int,...]]: |
|
|
|
|
shape = tuple([int(dim.dim_value) for dim in value_info.type.tensor_type.shape.dim]) |
|
|
|
|
name = value_info.name |
|
|
|
|
return name, shape |
|
|
|
|
|
|
|
|
|
def get_metadata_value_by_name(model:onnx.ModelProto, name:str) -> str | Any: |
|
|
|
|
for prop in model.metadata_props: |
|
|
|
|
if prop.key == name: |
|
|
|
|
return prop.value |
|
|
|
|
return None |
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
model_path = pathlib.Path(sys.argv[1]) |
|
|
|
|
model = onnx.load(str(model_path)) |
|
|
|
|
i = [x.key for x in model.metadata_props].index('output_slices') |
|
|
|
|
output_slices = model.metadata_props[i].value |
|
|
|
|
output_slices = get_metadata_value_by_name(model, 'output_slices') |
|
|
|
|
assert output_slices is not None, 'output_slices not found in metadata' |
|
|
|
|
|
|
|
|
|
metadata = {} |
|
|
|
|
metadata['output_slices'] = pickle.loads(codecs.decode(output_slices.encode(), "base64")) |
|
|
|
|
metadata['input_shapes'] = dict([get_name_and_shape(x) for x in model.graph.input]) |
|
|
|
|
metadata['output_shapes'] = dict([get_name_and_shape(x) for x in model.graph.output]) |
|
|
|
|
metadata = { |
|
|
|
|
'policy_model': get_metadata_value_by_name(model, 'policy_model'), |
|
|
|
|
'vision_model': get_metadata_value_by_name(model, 'vision_model'), |
|
|
|
|
'output_slices': pickle.loads(codecs.decode(output_slices.encode(), "base64")), |
|
|
|
|
'input_shapes': dict([get_name_and_shape(x) for x in model.graph.input]), |
|
|
|
|
'output_shapes': dict([get_name_and_shape(x) for x in model.graph.output]) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
metadata_path = model_path.parent / (model_path.stem + '_metadata.pkl') |
|
|
|
|
with open(metadata_path, 'wb') as f: |
|
|
|
|