import os, sys, pickle, time
import numpy as np

# Default tinygrad settings for the openpilot build; a value already exported in the environment wins.
# They are set before the tinygrad imports below (presumably because tinygrad reads them at import time).
os.environ.setdefault("FLOAT16", "1")
os.environ.setdefault("IMAGE", "2")
os.environ.setdefault("NOLOCALS", "1")
os.environ.setdefault("JIT_BATCH_SIZE", "0")

from tinygrad import fetch, Tensor, TinyJit, Context, GlobalCounters, Device
from tinygrad.helpers import DEBUG, getenv
from tinygrad.tensor import _from_np_dtype
from tinygrad.engine.realize import CompiledRunner

import onnx
from onnx.helper import tensor_dtype_to_np_dtype
from tinygrad.frontend.onnx import OnnxRunner

# argv[1]: ONNX model location (handed to tinygrad `fetch`); argv[2]: where the pickled JIT is written/read
OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx"
OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/openpilot.pkl"


def compile(onnx_file):  # NOTE: shadows the builtin `compile`; name kept for existing callers
  """Build a TinyJit for the ONNX model at `onnx_file`, validate it, and pickle it to OUTPUT.

  The JIT is called three times on the same seeded random inputs; the output of the last call must equal
  the output of the second call (i == 1) exactly. The captured kernels are then counted and checked against
  the optional ALLOWED_KERNEL_COUNT / ALLOWED_READ_IMAGE / ALLOWED_GATED_READ_IMAGE env limits (-1 = no limit).

  Returns the numpy output of the second call, used as the reference value by the later tests.
  Raises AssertionError if the outputs differ or a kernel/read_image limit is violated.
  """
  onnx_model = onnx.load(onnx_file)
  run_onnx = OnnxRunner(onnx_model)
  print("loaded model")

  graph_inputs = onnx_model.graph.input
  input_shapes = {inp.name: tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in graph_inputs}
  input_types = {inp.name: tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in graph_inputs}
  # Float inputs and outputs to tinyjits for openpilot are always float32
  input_types = {k: (np.float32 if v == np.float16 else v) for k, v in input_types.items()}
  # seed 100 matches the seed in __main__ ("same randomness as compile")
  Tensor.manual_seed(100)
  new_inputs = {k: Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize()
                for k, shp in sorted(input_shapes.items())}
  new_inputs_numpy = {k: v.numpy() for k, v in new_inputs.items()}
  print("created tensors")

  # move every input to the default device, run the ONNX graph, and return its first output cast to float32
  run_onnx_jit = TinyJit(lambda **kwargs:
                         next(iter(run_onnx({k: v.to(Device.DEFAULT) for k, v in kwargs.items()}).values())).cast('float32'),
                         prune=True)
  for i in range(3):
    GlobalCounters.reset()
    print(f"run {i}")
    # 'img' inputs are cloned tensors; every other input wraps its numpy array on the NPY device
    inputs = {**{k: v.clone() for k, v in new_inputs.items() if 'img' in k},
              **{k: Tensor(v, device="NPY").realize() for k, v in new_inputs_numpy.items() if 'img' not in k}}
    # raise DEBUG to at least 1 (at least 2 on the last run) for this call only
    with Context(DEBUG=max(DEBUG.value, 2 if i == 2 else 1)):
      ret = run_onnx_jit(**inputs).numpy()
    # copy i == 1 so use of JITBEAM is okay
    if i == 1: test_val = np.copy(ret)
  print(f"captured {len(run_onnx_jit.captured.jit_cache)} kernels")
  np.testing.assert_equal(test_val, ret, "JIT run failed")
  print("jit run validated")

  # checks from compile2
  kernel_count = 0
  read_image_count = 0
  gated_read_image_count = 0
  for ei in run_onnx_jit.captured.jit_cache:
    if isinstance(ei.prg, CompiledRunner):
      kernel_count += 1
      read_image_count += ei.prg.p.src.count("read_image")
      gated_read_image_count += ei.prg.p.src.count("?read_image")
  print(f"{kernel_count=}, {read_image_count=}, {gated_read_image_count=}")
  # explicit raises (not `assert`) so these CI limits still apply under `python -O`
  if (allowed_kernel_count := getenv("ALLOWED_KERNEL_COUNT", -1)) != -1 and kernel_count > allowed_kernel_count:
    raise AssertionError(f"too many kernels! {kernel_count=}, {allowed_kernel_count=}")
  if (allowed_read_image := getenv("ALLOWED_READ_IMAGE", -1)) != -1 and read_image_count != allowed_read_image:
    raise AssertionError(f"different read_image! {read_image_count=}, {allowed_read_image=}")
  if (allowed_gated_read_image := getenv("ALLOWED_GATED_READ_IMAGE", -1)) != -1 and gated_read_image_count > allowed_gated_read_image:
    raise AssertionError(f"too many gated read_image! {gated_read_image_count=}, {allowed_gated_read_image=}")

  with open(OUTPUT, "wb") as f:
    pickle.dump(run_onnx_jit, f)
  mdl_sz = os.path.getsize(onnx_file)
  pkl_sz = os.path.getsize(OUTPUT)
  print(f"mdl size is {mdl_sz/1e6:.2f}M")
  print(f"pkl size is {pkl_sz/1e6:.2f}M")
  print("**** compile done ****")
  return test_val


def test_vs_compile(run, new_inputs, test_val=None):
  """Run the jitted model `run` 20 times on `new_inputs`, printing enqueue and total times per run.

  If `test_val` is given, the output of the last run must equal it exactly. If any input is wrapped on the
  NPY device, the backing numpy arrays are then doubled in place and the model is rerun; the output must
  change, which checks that those inputs share memory with the numpy arrays.
  Returns the numpy output of the last of the 20 runs.
  """
  new_inputs_numpy = {k: v.numpy() for k, v in new_inputs.items()}

  # create fake "from_blob" tensors for the inputs, and wrapped NPY tensors for the numpy inputs (these have the same underlying memory)
  inputs = {**{k: v for k, v in new_inputs.items() if 'img' in k},
            **{k: Tensor(v, device="NPY").realize() for k, v in new_inputs_numpy.items() if 'img' not in k}}

  # run 20 times
  for _ in range(20):
    st = time.perf_counter()
    out = run(**inputs)
    mt = time.perf_counter()
    val = out.numpy()
    et = time.perf_counter()
    print(f"enqueue {(mt-st)*1e3:6.2f} ms -- total run {(et-st)*1e3:6.2f} ms")
  print(out, val.shape, val.dtype)
  if test_val is not None: np.testing.assert_equal(test_val, val)
  print("**** test done ****")

  # test that changing the numpy changes the model outputs
  if any(x.device == 'NPY' for x in inputs.values()):
    for v in new_inputs_numpy.values(): v *= 2  # in place, so the NPY-wrapped tensors see the change
    out = run(**inputs)
    changed_val = out.numpy()
    np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, val, changed_val)
  return val


def test_vs_onnx(new_inputs, test_val, onnx_file, ort=False):
  """Run `onnx_file` with onnxruntime (`ort` truthy) or onnx2torch and return per-call timings in seconds.

  With `test_val` the reference model runs once and its output (reshaped to test_val's shape) must match
  `test_val` within atol=1e-4, rtol=1e-2. Without it the model runs 5 times, for benchmarking only.
  """
  new_inputs_numpy = {k: v.numpy() for k, v in new_inputs.items()}
  onnx_model = onnx.load(onnx_file)
  n_runs = 1 if test_val is not None else 5

  timings = []
  if ort:
    # test with onnxruntime; imported under its own name so the module does not shadow the `ort` flag
    import onnxruntime
    onnx_session = onnxruntime.InferenceSession(onnx_file)
    # feed each input in the dtype its graph input declares; names not found among the graph inputs keep
    # the float16 cast that was previously applied to every input
    declared_types = {inp.name: tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input}
    for _ in range(n_runs):
      st = time.perf_counter()
      onnx_output = onnx_session.run([onnx_model.graph.output[0].name],
                                     {k: v.astype(declared_types.get(k, np.float16)) for k, v in new_inputs_numpy.items()})
      timings.append(time.perf_counter() - st)
    new_torch_out = onnx_output[0]
  else:
    # test with torch
    import torch
    from onnx2torch import convert
    # positional order follows the ONNX graph's input order
    inputs = {k.name: new_inputs_numpy[k.name] for k in onnx_model.graph.input}
    torch_model = convert(onnx_model).float()
    with torch.no_grad():
      for _ in range(n_runs):
        st = time.perf_counter()
        torch_out = torch_model(*[torch.tensor(x) for x in inputs.values()])
        timings.append(time.perf_counter() - st)
      new_torch_out = torch_out.numpy()

  if test_val is not None:
    np.testing.assert_allclose(new_torch_out.reshape(test_val.shape), test_val, atol=1e-4, rtol=1e-2)
    print("test vs onnx passed")
  return timings


if __name__ == "__main__":
  onnx_file = fetch(OPENPILOT_MODEL)
  # RUN=1 skips compilation and only tests an existing pickle at OUTPUT
  test_val = compile(onnx_file) if not getenv("RUN") else None

  # NOTE(review): pickle.load can execute arbitrary code. With RUN=1 this file was not produced by this run,
  # and the default OUTPUT is under /tmp (shared on typical Unix systems) -- only load pickles you built.
  with open(OUTPUT, "rb") as f: pickle_loaded = pickle.load(f)

  # same randomness as compile
  Tensor.manual_seed(100)
  new_inputs = {nm: Tensor.randn(*st.shape, dtype=dtype).mul(8).realize() for nm, (st, _, dtype, _) in
                sorted(zip(pickle_loaded.captured.expected_names, pickle_loaded.captured.expected_st_vars_dtype_device))}

  test_val = test_vs_compile(pickle_loaded, new_inputs, test_val)
  if getenv("BENCHMARK"):
    for be in ["torch", "ort"]:
      try:
        timings = test_vs_onnx(new_inputs, None, onnx_file, be == "ort")
        print(f"timing {be}: {min(timings)*1000:.2f} ms")
      except Exception as e:  # best-effort: a missing backend must not abort the benchmark
        print(f"{be} fail with {e}")

  if not getenv("FLOAT16"): test_vs_onnx(new_inputs, test_val, onnx_file, getenv("ORT"))