|
|
|
@ -14,8 +14,12 @@ def attributeproto_fp16_to_fp32(attr): |
|
|
|
|
attr.data_type = 1 |
|
|
|
|
attr.raw_data = float32_list.astype(np.float32).tobytes() |
|
|
|
|
|
|
|
|
|
def convert_fp16_to_fp32(path): |
|
|
|
|
model = onnx.load(path) |
|
|
|
|
def convert_fp16_to_fp32(onnx_path_or_bytes): |
|
|
|
|
if isinstance(onnx_path_or_bytes, bytes): |
|
|
|
|
model = onnx.load_from_string(onnx_path_or_bytes) |
|
|
|
|
elif isinstance(onnx_path_or_bytes, str): |
|
|
|
|
model = onnx.load(onnx_path_or_bytes) |
|
|
|
|
|
|
|
|
|
for i in model.graph.initializer: |
|
|
|
|
if i.data_type == 10: |
|
|
|
|
attributeproto_fp16_to_fp32(i) |
|
|
|
@ -23,6 +27,8 @@ def convert_fp16_to_fp32(path): |
|
|
|
|
if i.type.tensor_type.elem_type == 10: |
|
|
|
|
i.type.tensor_type.elem_type = 1 |
|
|
|
|
for i in model.graph.node: |
|
|
|
|
if i.op_type == 'Cast' and i.attribute[0].i == 10: |
|
|
|
|
i.attribute[0].i = 1 |
|
|
|
|
for a in i.attribute: |
|
|
|
|
if hasattr(a, 't'): |
|
|
|
|
if a.t.data_type == 10: |
|
|
|
|