openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

62 lines
2.9 KiB

import time, sys, hashlib
from pathlib import Path
from tinygrad.frontend.onnx import OnnxRunner
from tinygrad import Tensor, dtypes, TinyJit
from tinygrad.helpers import IMAGE, GlobalCounters, fetch, colored, getenv, trange
import numpy as np
from extra.bench_log import BenchEvent, WallTimeEvent
OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx"
if __name__ == "__main__":
run_onnx = OnnxRunner(fetch(OPENPILOT_MODEL))
Tensor.manual_seed(100)
input_shapes = {name: spec.shape for name, spec in run_onnx.graph_inputs.items()}
input_types = {name: spec.dtype for name, spec in run_onnx.graph_inputs.items()}
new_inputs = {k:Tensor.randn(*shp, dtype=input_types[k]).mul(8).realize() for k,shp in input_shapes.items()}
new_inputs_junk = {k:Tensor.randn(*shp, dtype=input_types[k]).mul(8).realize() for k,shp in input_shapes.items()}
new_inputs_junk_numpy = {k:v.numpy() for k,v in new_inputs_junk.items()}
# benchmark
for _ in range(5):
GlobalCounters.reset()
st = time.perf_counter_ns()
ret = next(iter(run_onnx(new_inputs_junk).values())).cast(dtypes.float32).numpy()
print(f"unjitted: {(time.perf_counter_ns() - st)*1e-6:7.4f} ms")
# NOTE: the inputs to a JIT must be first level arguments
run_onnx_jit = TinyJit(lambda **kwargs: run_onnx(kwargs), prune=True)
for _ in range(20):
GlobalCounters.reset()
st = time.perf_counter_ns()
with WallTimeEvent(BenchEvent.STEP):
# Need to cast non-image inputs from numpy, this is only realistic way to run model
inputs = {**{k:v for k,v in new_inputs_junk.items() if 'img' in k},
**{k:Tensor(v) for k,v in new_inputs_junk_numpy.items() if 'img' not in k}}
ret = next(iter(run_onnx_jit(**inputs).values())).cast(dtypes.float32).numpy()
print(f"jitted: {(time.perf_counter_ns() - st)*1e-6:7.4f} ms")
suffix = ""
if IMAGE.value < 2: suffix += f"_image{IMAGE.value}" # image=2 has no suffix for compatibility
if getenv("FLOAT16") == 1: suffix += "_float16"
path = Path(__file__).parent / "openpilot" / f"{hashlib.md5(OPENPILOT_MODEL.encode()).hexdigest()}{suffix}.npy"
# validate if we have records
tinygrad_out = next(iter(run_onnx_jit(**new_inputs).values())).cast(dtypes.float32).numpy()
if getenv("SAVE_OUTPUT"):
np.save(path, tinygrad_out)
print(f"saved output to {path}!")
elif getenv("FUZZ") and path.exists():
known_good_out = np.load(path)
for _ in trange(1000):
ret = next(iter(run_onnx_jit(**new_inputs).values())).cast(dtypes.float32).numpy()
np.testing.assert_allclose(known_good_out, ret, atol=1e-2, rtol=1e-2)
print(colored("fuzz validated!", "green"))
elif path.exists():
known_good_out = np.load(path)
np.testing.assert_allclose(known_good_out, tinygrad_out, atol=1e-2, rtol=1e-2)
print(colored("outputs validated!", "green"))
else:
print(colored("skipping validation", "yellow"))