You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
76 lines
2.6 KiB
76 lines
2.6 KiB
#!/usr/bin/env python3
|
|
|
|
import os
|
|
import sys
|
|
import numpy as np
|
|
|
|
os.environ["OMP_NUM_THREADS"] = "4"
|
|
os.environ["OMP_WAIT_POLICY"] = "PASSIVE"
|
|
|
|
import onnxruntime as ort # pylint: disable=import-error
|
|
|
|
ORT_TYPES_TO_NP_TYPES = {'tensor(float16)': np.float16, 'tensor(float)': np.float32, 'tensor(uint8)': np.uint8}
|
|
|
|
def read(sz, tf8=False):
|
|
dd = []
|
|
gt = 0
|
|
szof = 1 if tf8 else 4
|
|
while gt < sz * szof:
|
|
st = os.read(0, sz * szof - gt)
|
|
assert(len(st) > 0)
|
|
dd.append(st)
|
|
gt += len(st)
|
|
r = np.frombuffer(b''.join(dd), dtype=np.uint8 if tf8 else np.float32)
|
|
if tf8:
|
|
r = r / 255.
|
|
return r
|
|
|
|
def write(d):
|
|
os.write(1, d.tobytes())
|
|
|
|
def run_loop(m, tf8_input=False):
|
|
ishapes = [[1]+ii.shape[1:] for ii in m.get_inputs()]
|
|
keys = [x.name for x in m.get_inputs()]
|
|
itypes = [ORT_TYPES_TO_NP_TYPES[x.type] for x in m.get_inputs()]
|
|
|
|
# run once to initialize CUDA provider
|
|
if "CUDAExecutionProvider" in m.get_providers():
|
|
m.run(None, dict(zip(keys, [np.zeros(shp, dtype=itp) for shp, itp in zip(ishapes, itypes)])))
|
|
|
|
print("ready to run onnx model", keys, ishapes, file=sys.stderr)
|
|
while 1:
|
|
inputs = []
|
|
for k, shp, itp in zip(keys, ishapes, itypes):
|
|
ts = np.product(shp)
|
|
#print("reshaping %s with offset %d" % (str(shp), offset), file=sys.stderr)
|
|
inputs.append(read(ts, (k=='input_img' and tf8_input)).reshape(shp).astype(itp))
|
|
ret = m.run(None, dict(zip(keys, inputs)))
|
|
#print(ret, file=sys.stderr)
|
|
for r in ret:
|
|
write(r.astype(np.float32))
|
|
|
|
|
|
if __name__ == "__main__":
|
|
print(sys.argv, file=sys.stderr)
|
|
print("Onnx available providers: ", ort.get_available_providers(), file=sys.stderr)
|
|
options = ort.SessionOptions()
|
|
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
|
|
if 'OpenVINOExecutionProvider' in ort.get_available_providers() and 'ONNXCPU' not in os.environ:
|
|
provider = 'OpenVINOExecutionProvider'
|
|
elif 'CUDAExecutionProvider' in ort.get_available_providers() and 'ONNXCPU' not in os.environ:
|
|
options.intra_op_num_threads = 2
|
|
provider = 'CUDAExecutionProvider'
|
|
else:
|
|
options.intra_op_num_threads = 2
|
|
options.inter_op_num_threads = 8
|
|
options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
|
|
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
|
provider = 'CPUExecutionProvider'
|
|
|
|
try:
|
|
print("Onnx selected provider: ", [provider], file=sys.stderr)
|
|
ort_session = ort.InferenceSession(sys.argv[1], options, providers=[provider])
|
|
print("Onnx using ", ort_session.get_providers(), file=sys.stderr)
|
|
run_loop(ort_session, tf8_input=("--use_tf8" in sys.argv))
|
|
except KeyboardInterrupt:
|
|
pass
|
|
|