openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

75 lines
2.7 KiB

import pickle, sys
from dataclasses import replace
from tinygrad import Device, Context, Tensor, GlobalCounters
from tinygrad.device import Buffer
from tinygrad.helpers import getenv, BEAM
from tinygrad.engine.jit import TinyJit
from tinygrad.engine.realize import CompiledRunner, ExecItem, ScheduleItem, lower_schedule_item
from tinygrad.renderer import ProgramSpec
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
from tinygrad.codegen.heuristic import hand_coded_optimizations
import numpy as np
def move_jit_captured_to_dev(captured, device="DSP"):
    """Retarget a captured JIT so that the buffers it references live on ``device``.

    ``captured`` is modified in place and is also the return value:

    * every entry of ``expected_st_vars_dtype_device`` keeps its first three
      fields and gets ``device`` as its fourth;
    * every non-``None`` buffer referenced from ``jit_cache`` is replaced by a
      new ``Buffer`` on ``device``. Each source buffer is converted only once,
      so items that shared a buffer before still share one afterwards;
    * ``jit_cache`` is rebuilt as ``ExecItem(prg, bufs)`` entries that point at
      the replacement buffers.

    Args:
        captured: object exposing ``expected_st_vars_dtype_device`` and
            ``jit_cache`` (a sequence of items with ``prg`` and ``bufs``).
        device: name of the target device; defaults to ``"DSP"``.

    Returns:
        The same ``captured`` object that was passed in.
    """
    captured.expected_st_vars_dtype_device = [
        entry[:3] + (device,) for entry in captured.expected_st_vars_dtype_device
    ]

    # Source buffer -> replacement buffer on the target device.
    moved = {}

    def relocate(src):
        # Reuse the replacement if this buffer was already handled.
        try:
            return moved[src]
        except KeyError:
            pass

        if src._base is None:
            # Stand-alone buffer: create it on the target and, if the source
            # holds data, copy that data across.
            dst = Buffer(device, src.size, src.dtype)
            if src.is_allocated():
                dst.ensure_allocated().copyin(src.as_buffer())
        else:
            # Buffer defined relative to a base: relocate the base first and
            # recreate this one on top of it at the same offset. No data is
            # copied here.
            dst = Buffer(device, src.size, src.dtype, base=relocate(src._base), offset=src.offset)

        moved[src] = dst
        return dst

    # First pass: create replacements for everything the cache references.
    for exec_item in captured.jit_cache:
        for buf in exec_item.bufs:
            if buf is None:
                continue
            relocate(buf)

    # Second pass: rebuild the cache; ``None`` slots are kept as they are.
    rebuilt = []
    for exec_item in captured.jit_cache:
        new_bufs = [moved.get(buf, buf) for buf in exec_item.bufs]
        rebuilt.append(ExecItem(exec_item.prg, new_bufs))
    captured.jit_cache = rebuilt
    return captured
if __name__ == "__main__":
    # Replays the kernels of a pickled TinyJit on the "DSP" device.
    #
    # Usage: <this script> <pickled TinyJit file>
    #
    # Environment knobs read below:
    #   KNUM=n      only re-lower the n-th eligible kernel (1-based); 0/unset means all
    #   VALIDATE=1  compare every re-lowered kernel against a reference run made with NOOPT=1
    #   NOOPT=1     do not apply hand_coded_optimizations to the kernels
    #   RUN_JIT=1   afterwards run the whole JIT using only the re-lowered kernels
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <pickled_jit_file>")

    with Context(DEBUG=0):
        # NOTE(security): pickle.load can execute arbitrary code contained in the
        # file, so only pass pickles that come from a trusted source.
        with open(sys.argv[1], "rb") as f:
            fxn: TinyJit = pickle.load(f)
            print(f"{f.tell()/1e6:.2f}M loaded")
        print(type(fxn))

    # Move all buffers to DSP device.
    fxn.captured = move_jit_captured_to_dev(fxn.captured, "DSP")
    new_jit = []

    # These do not change while the loop runs, so read them once.
    pknum = getenv("KNUM", 0)
    validate = getenv("VALIDATE")
    noopt = getenv("NOOPT")

    knum = 1  # 1-based index over the eligible kernels, matched against KNUM
    for ei in fxn.captured.jit_cache:
        # Only compiled kernels whose buffers are all present are replayed;
        # every other cache entry (e.g. a copy) is skipped and not counted.
        if not (isinstance(ei.prg, CompiledRunner) and all(x is not None for x in ei.bufs)):
            continue

        if pknum == 0 or knum == pknum:
            p: ProgramSpec = ei.prg.p
            k = Kernel(p.ast, Device["DSP"].renderer)

            if validate:
                # Produce the reference output with NOOPT=1 set in the Context,
                # then zero the output buffer so the re-lowered kernel has to
                # write it again.
                with Context(NOOPT=1):
                    lower_schedule_item(ScheduleItem(p.ast, ei.bufs)).run()
                correct = ei.bufs[0].numpy()
                ei.bufs[0].copyin(memoryview(bytearray(b'\x00'*ei.bufs[0].nbytes)))
                # Undo the kernel-count increment caused by the reference run.
                GlobalCounters.kernel_count -= 1

            if not noopt:
                k.apply_opts(hand_coded_optimizations(k))
            p2 = k.to_program()
            new_ei = replace(ei, prg=CompiledRunner(p2))
            new_ei.run()
            new_jit.append(new_ei)
            test = ei.bufs[0].numpy()

            if validate:
                np.testing.assert_allclose(correct, test, rtol=1e-3, atol=1e-3)
        knum += 1

    if getenv("RUN_JIT", 0):
        # Replace the cache with the re-lowered kernels only and run the JIT
        # once on an all-zero (1, 3, 224, 224) float32 input.
        fxn.captured.free_intermediates()
        fxn.captured.jit_cache = new_jit
        fxn(input=Tensor(np.zeros((1, 3, 224, 224), dtype=np.float32), device="DSP"))