You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
68 lines
2.2 KiB
68 lines
2.2 KiB
#!/usr/bin/env python3
|
|
|
|
import os
|
|
import sys
|
|
import numpy as np
|
|
|
|
os.environ["OMP_NUM_THREADS"] = "4"
|
|
os.environ["OMP_WAIT_POLICY"] = "PASSIVE"
|
|
|
|
import onnxruntime as ort # pylint: disable=import-error
|
|
|
|
def read(sz):
    """Read exactly *sz* float32 values from stdin (file descriptor 0).

    ``os.read`` may return fewer bytes than requested, so the call is
    repeated until ``sz * 4`` bytes have been collected.

    Args:
        sz: number of float32 elements to read (4 bytes each).

    Returns:
        A 1-D ``np.float32`` array of length *sz* built from the received
        bytes.

    Raises:
        EOFError: if stdin reaches end-of-file before all bytes arrived.
    """
    wanted = sz * 4
    chunks = []
    received = 0
    while received < wanted:
        chunk = os.read(0, wanted - received)
        # An empty read means EOF. Raise explicitly instead of asserting:
        # asserts are stripped under ``python -O``, which would turn this
        # loop into a busy spin on a closed pipe.
        if not chunk:
            raise EOFError(
                f"stdin closed after {received} of {wanted} expected bytes"
            )
        chunks.append(chunk)
        received += len(chunk)
    return np.frombuffer(b''.join(chunks), dtype=np.float32)
|
|
|
|
def write(d):
    """Write the raw bytes of array *d* to stdout (file descriptor 1).

    ``os.write`` returns the number of bytes actually written, which can be
    less than requested; the write is therefore repeated on the remainder
    until the whole buffer has been sent, so no output is silently dropped.

    Args:
        d: an object exposing ``tobytes()`` (e.g. a NumPy array).
    """
    remaining = memoryview(d.tobytes())
    while remaining:
        sent = os.write(1, remaining)
        # Slicing a memoryview does not copy the underlying bytes.
        remaining = remaining[sent:]
|
|
|
|
def run_loop(m):
    """Serve inference requests for session *m* over stdin/stdout.

    Each iteration reads one tensor per model input from stdin via
    ``read`` (in the order reported by ``m.get_inputs()``), runs the model,
    and sends every output to stdout via ``write`` in the order returned by
    ``m.run``. The loop has no exit condition; it ends only when an
    exception propagates (for example from ``read``).

    Args:
        m: an inference session exposing ``get_inputs()``,
            ``get_providers()`` and ``run()``.
    """
    # The leading dimension of every input shape is replaced with 1.
    ishapes = [[1] + ii.shape[1:] for ii in m.get_inputs()]
    keys = [x.name for x in m.get_inputs()]

    # run once to initialize CUDA provider
    if "CUDAExecutionProvider" in m.get_providers():
        m.run(None, dict(zip(keys, [np.zeros(shp, dtype=np.float32) for shp in ishapes])))

    print("ready to run onnx model", keys, ishapes, file=sys.stderr)

    # Element counts never change between iterations, so compute them once.
    # np.prod is used because the np.product alias was removed in NumPy 2.0.
    sizes = [int(np.prod(shp)) for shp in ishapes]

    while True:
        inputs = [read(ts).reshape(shp) for ts, shp in zip(sizes, ishapes)]
        ret = m.run(None, dict(zip(keys, inputs)))
        for r in ret:
            write(r)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: load the model at sys.argv[1] with the chosen
    # execution provider and serve it through run_loop().
    available = ort.get_available_providers()
    print("Onnx available providers: ", available, file=sys.stderr)

    options = ort.SessionOptions()
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL

    # The mere presence of ONNXCPU in the environment forces the CPU provider.
    allow_accel = 'ONNXCPU' not in os.environ

    if allow_accel and 'OpenVINOExecutionProvider' in available:
        provider = 'OpenVINOExecutionProvider'
    elif allow_accel and 'CUDAExecutionProvider' in available:
        options.intra_op_num_threads = 2
        provider = 'CUDAExecutionProvider'
    else:
        # CPU fallback: explicit thread counts, sequential execution and
        # full graph optimization.
        options.intra_op_num_threads = 2
        options.inter_op_num_threads = 8
        options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        provider = 'CPUExecutionProvider'

    try:
        print("Onnx selected provider: ", [provider], file=sys.stderr)
        ort_session = ort.InferenceSession(sys.argv[1], options, providers=[provider])
        print("Onnx using ", ort_session.get_providers(), file=sys.stderr)
        run_loop(ort_session)
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop the runner; exit quietly.
        pass
|
|
|