parent
6bdf338d75
commit
27263cbfbe
1 changed files with 29 additions and 0 deletions
@ -0,0 +1,29 @@ |
||||
#!/usr/bin/env python3 |
||||
import sys |
||||
import pathlib |
||||
import onnx |
||||
import codecs |
||||
import pickle |
||||
from typing import Tuple |
||||
|
||||
def get_name_and_shape(value_info:onnx.ValueInfoProto) -> Tuple[str, Tuple[int,...]]: |
||||
shape = tuple([int(dim.dim_value) for dim in value_info.type.tensor_type.shape.dim]) |
||||
name = value_info.name |
||||
return name, shape |
||||
|
||||
if __name__ == "__main__": |
||||
model_path = pathlib.Path(sys.argv[1]) |
||||
model = onnx.load(str(model_path)) |
||||
i = [x.key for x in model.metadata_props].index('output_slices') |
||||
output_slices = model.metadata_props[i].value |
||||
|
||||
metadata = {} |
||||
metadata['output_slices'] = pickle.loads(codecs.decode(output_slices.encode(), "base64")) |
||||
metadata['input_shapes'] = dict([get_name_and_shape(x) for x in model.graph.input]) |
||||
metadata['output_shapes'] = dict([get_name_and_shape(x) for x in model.graph.output]) |
||||
|
||||
metadata_path = model_path.parent / (model_path.stem + '_metadata.pkl') |
||||
with open(metadata_path, 'wb') as f: |
||||
pickle.dump(metadata, f) |
||||
|
||||
print(f'saved metadata to {metadata_path}') |
Loading…
Reference in new issue