modeld: ort helpers (#34258)
	
		
	
				
					
				
			* ort helpers * import from ort helpers * import that too * linter * linter * linter
			pull/34260/head
							parent
							
								
									11fb0b95d2
								
							
						
					
					
						commit
						a98210aeec
					
				
				 2 changed files with 39 additions and 30 deletions
			
			
		| @ -0,0 +1,36 @@ | |||||||
|  | import onnx | ||||||
|  | import onnxruntime as ort | ||||||
|  | import numpy as np | ||||||
|  | import itertools | ||||||
|  | 
 | ||||||
# Lookup from an onnxruntime tensor type string to the matching numpy scalar type.
ORT_TYPES_TO_NP_TYPES = {
  'tensor(float16)': np.float16,
  'tensor(float)': np.float32,
  'tensor(uint8)': np.uint8,
}
|  | 
 | ||||||
def attributeproto_fp16_to_fp32(attr):
  """Rewrite a float16 ONNX tensor in place so that it holds float32 data.

  Args:
    attr: a tensor message (despite the parameter name) exposing `raw_data`,
      `int32_data` and `data_type`. The caller is expected to have checked
      that it currently holds float16 data (data_type 10).

  After the call `data_type` is 1 (float32), `raw_data` holds the widened
  values and `int32_data` is empty. Nothing is returned.
  """
  if attr.raw_data:
    # The ONNX spec stores raw_data in little-endian order.
    half_values = np.frombuffer(attr.raw_data, dtype='<f2')
  else:
    # Without raw_data, ONNX keeps float16 values in int32_data, one uint16
    # bit pattern per element. Reinterpret the bits instead of dropping them.
    bit_patterns = np.asarray(attr.int32_data, dtype=np.int32).astype(np.uint16)
    half_values = bit_patterns.view(np.float16)
    # The payload moves to raw_data below; a tensor must not carry both.
    del attr.int32_data[:]
  attr.data_type = 1  # float32
  attr.raw_data = half_values.astype('<f4').tobytes()
|  | 
 | ||||||
def convert_fp16_to_fp32(model):
  """Convert the float16 parts of an ONNX model to float32, in place.

  The following are rewritten (10 and 1 are the ONNX data type codes for
  float16 and float32):
    * float16 initializers, via attributeproto_fp16_to_fp32
    * float16 element types on graph inputs, outputs and value_info entries
    * Cast nodes whose `to` attribute targets float16
    * float16 tensors stored in node attributes (field `t`)

  Args:
    model: the loaded ONNX model; it is mutated.

  Returns:
    The result of model.SerializeToString() after the rewrite.
  """
  graph = model.graph

  for tensor in graph.initializer:
    if tensor.data_type == 10:
      attributeproto_fp16_to_fp32(tensor)

  # value_info describes intermediate tensors; it is included so those
  # annotations do not keep claiming float16 after the conversion.
  for value in itertools.chain(graph.input, graph.output, graph.value_info):
    if value.type.tensor_type.elem_type == 10:
      value.type.tensor_type.elem_type = 1

  for node in graph.node:
    is_cast = node.op_type == 'Cast'
    for a in node.attribute:
      # Find the Cast target by name: `to` is not guaranteed to be the first
      # attribute on the node.
      if is_cast and a.name == 'to' and a.i == 10:
        a.i = 1
      if hasattr(a, 't') and a.t.data_type == 10:
        attributeproto_fp16_to_fp32(a.t)

  return model.SerializeToString()
|  | 
 | ||||||
|  | 
 | ||||||
def make_onnx_cpu_runner(model_path):
  """Create a CPU-only onnxruntime inference session for the model at model_path.

  The model is loaded with onnx.load and passed through convert_fp16_to_fp32;
  the serialized bytes it returns are what the session is built from.

  Returns:
    An ort.InferenceSession restricted to the 'CPUExecutionProvider'.
  """
  session_options = ort.SessionOptions()
  session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
  session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
  session_options.intra_op_num_threads = 4

  fp32_model_bytes = convert_fp16_to_fp32(onnx.load(model_path))
  return ort.InferenceSession(
    fp32_model_bytes,
    session_options,
    providers=['CPUExecutionProvider'],
  )
					Loading…
					
					
				
		Reference in new issue