#!/usr/bin/env python3
"""Run an ONNX model as a stdin/stdout tensor pipe.

The script loads the model given as ``sys.argv[1]`` and then loops forever:
for every model input it reads one raw tensor from file descriptor 0, runs
the model, and writes the raw bytes of every output to file descriptor 1.
All diagnostics go to stderr so that stdout carries tensor bytes only.

Pass ``--use_tf8`` to read the input named ``input_img`` as one unsigned
byte per element (scaled to [0, 1]) instead of float32.
"""
import os
import sys

import numpy as np

# These are set before onnxruntime is imported, as in the original layout --
# presumably so the OpenMP runtime picks them up when the library loads.
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["OMP_WAIT_POLICY"] = "PASSIVE"

import onnxruntime as ort # pylint: disable=import-error

STDIN_FD = 0
STDOUT_FD = 1


def read(sz, tf8=False):
  """Read ``sz`` tensor elements from stdin and return them as float32.

  Args:
    sz: number of elements to read.
    tf8: when true each element is a single uint8 byte and the result is
      divided by 255; otherwise each element is a 4-byte float32.

  Returns:
    A flat float32 ``np.ndarray`` holding ``sz`` elements.

  Raises:
    EOFError: stdin was closed before all requested bytes arrived.
  """
  itemsize = 1 if tf8 else 4
  remaining = int(sz) * itemsize
  chunks = []
  while remaining > 0:
    # os.read may return fewer bytes than requested, so keep reading.
    chunk = os.read(STDIN_FD, remaining)
    if not chunk:
      # An explicit raise (not an assert) so that `python -O` cannot turn a
      # closed stdin into an endless busy loop.
      raise EOFError("stdin closed with %d bytes still expected" % remaining)
    chunks.append(chunk)
    remaining -= len(chunk)

  raw = np.frombuffer(b''.join(chunks), dtype=np.uint8 if tf8 else np.float32)
  r = raw.astype(np.float32)
  if tf8:
    r = r / 255.
  return r


def write(d):
  """Write the raw bytes of array ``d`` to stdout, retrying on short writes."""
  buf = memoryview(d.tobytes())
  while len(buf) > 0:
    # os.write reports how many bytes it accepted; resend whatever is left.
    written = os.write(STDOUT_FD, buf)
    buf = buf[written:]


def run_loop(m, tf8_input=False):
  """Serve inference requests over stdin/stdout until the process is stopped.

  Args:
    m: the inference session; it must provide ``get_inputs()``,
      ``get_providers()`` and ``run()``.
    tf8_input: when true, the input named ``input_img`` is read as uint8
      (see ``read``); every other input is always read as float32.
  """
  model_inputs = m.get_inputs()
  # Each input is fed with its first dimension forced to 1.
  ishapes = [[1]+ii.shape[1:] for ii in model_inputs]
  keys = [x.name for x in model_inputs]

  # run once to initialize CUDA provider
  if "CUDAExecutionProvider" in m.get_providers():
    m.run(None, dict(zip(keys, [np.zeros(shp, dtype=np.float32) for shp in ishapes])))

  print("ready to run onnx model", keys, ishapes, file=sys.stderr)

  # Element counts never change between iterations, so compute them once.
  # np.prod replaces np.product, which was removed in NumPy 2.0.
  sizes = [int(np.prod(shp)) for shp in ishapes]

  while 1:
    inputs = [
      read(ts, (k == 'input_img' and tf8_input)).reshape(shp)
      for k, shp, ts in zip(keys, ishapes, sizes)
    ]
    ret = m.run(None, dict(zip(keys, inputs)))
    for r in ret:
      write(r)


def main():
  """Pick an execution provider, load the model and start the serving loop."""
  print(sys.argv, file=sys.stderr)
  available = ort.get_available_providers()
  print("Onnx available providers: ", available, file=sys.stderr)

  options = ort.SessionOptions()
  options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL

  # Defining ONNXCPU in the environment (any value) forces the CPU provider.
  accel_allowed = 'ONNXCPU' not in os.environ
  if accel_allowed and 'OpenVINOExecutionProvider' in available:
    provider = 'OpenVINOExecutionProvider'
  elif accel_allowed and 'CUDAExecutionProvider' in available:
    options.intra_op_num_threads = 2
    provider = 'CUDAExecutionProvider'
  else:
    options.intra_op_num_threads = 2
    options.inter_op_num_threads = 8
    options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    provider = 'CPUExecutionProvider'

  try:
    print("Onnx selected provider: ", [provider], file=sys.stderr)
    ort_session = ort.InferenceSession(sys.argv[1], options, providers=[provider])
    print("Onnx using ", ort_session.get_providers(), file=sys.stderr)
    run_loop(ort_session, tf8_input=("--use_tf8" in sys.argv))
  except KeyboardInterrupt:
    # Ctrl-C is the expected way to stop the loop; exit without a traceback.
    pass


if __name__ == "__main__":
  main()