You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							29 lines
						
					
					
						
							1.0 KiB
						
					
					
				
			
		
		
	
	
							29 lines
						
					
					
						
							1.0 KiB
						
					
					
				| #!/usr/bin/env python3
 | |
| import sys
 | |
| import pathlib
 | |
| import onnx
 | |
| import codecs
 | |
| import pickle
 | |
| from typing import Tuple
 | |
| 
 | |
def get_name_and_shape(
    value_info: "onnx.ValueInfoProto", unknown_dim: int = 0
) -> Tuple[str, Tuple[int, ...]]:
    """Return the name and static shape of an ONNX value (graph input/output).

    Args:
        value_info: A ``ValueInfoProto``, e.g. an element of
            ``model.graph.input`` or ``model.graph.output``.
        unknown_dim: Value reported for a dimension that carries no concrete
            ``dim_value`` (e.g. one described only by a symbolic
            ``dim_param``). Defaults to ``0``, which is what reading an unset
            ``dim_value`` yields and so matches the historical behaviour.
            Pass something like ``-1`` to make such dimensions
            distinguishable from genuinely zero-sized ones.

    Returns:
        ``(name, shape)`` where ``shape`` has one int per entry in
        ``value_info.type.tensor_type.shape.dim``. NOTE(review): a value
        whose shape is absent altogether presumably also yields ``()`` (empty
        ``dim`` list), i.e. the same as a scalar. Confirm if callers care.
    """
    dims = value_info.type.tensor_type.shape.dim
    # ``dim_value`` and ``dim_param`` are alternatives; only trust
    # ``dim_value`` when it was actually set.
    shape = tuple(
        int(d.dim_value) if d.HasField("dim_value") else unknown_dim
        for d in dims
    )
    return value_info.name, shape
 | |
| 
 | |
if __name__ == "__main__":
    # Usage: <script> MODEL.onnx
    # Extracts the pickled ``output_slices`` metadata plus the graph's input
    # and output shapes, and writes them to ``<model stem>_metadata.pkl`` next
    # to the model.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} MODEL.onnx")
    model_path = pathlib.Path(sys.argv[1])
    model = onnx.load(str(model_path))

    # First metadata_props entry with key 'output_slices' (same first-match
    # semantics as list.index), or None if the model lacks it.
    output_slices = next(
        (p.value for p in model.metadata_props if p.key == "output_slices"),
        None,
    )
    if output_slices is None:
        sys.exit(f"{model_path}: no 'output_slices' entry in model metadata_props")

    metadata = {}
    # The value is stored as base64 text wrapping a pickle payload.
    # SECURITY: pickle.loads can execute arbitrary code embedded in the
    # payload. Only run this on model files from a trusted source.
    metadata['output_slices'] = pickle.loads(codecs.decode(output_slices.encode(), "base64"))
    metadata['input_shapes'] = dict(get_name_and_shape(x) for x in model.graph.input)
    metadata['output_shapes'] = dict(get_name_and_shape(x) for x in model.graph.output)

    metadata_path = model_path.parent / (model_path.stem + '_metadata.pkl')
    with open(metadata_path, 'wb') as f:
        pickle.dump(metadata, f)

    print(f'saved metadata to {metadata_path}')
 | |
| 
 |