openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

203 lines
12 KiB

from typing import cast, Optional, Callable
import itertools, functools, random, math, time, multiprocessing, traceback, signal, atexit
from collections import defaultdict
from dataclasses import replace
from tinygrad.ops import UOp, Ops, Variable, sym_infer
from tinygrad.device import Device, Buffer, Compiler
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, time_to_str
from tinygrad.helpers import IGNORE_BEAM_CACHE, TC_SEARCH_OVER_SHAPE
from tinygrad.dtype import ImageDType, PtrDType
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
from tinygrad.tensor import Tensor
from tinygrad.engine.realize import CompiledRunner
from tinygrad.renderer import ProgramSpec
# Every optimization beam search may try. Order is significant: get_kernel_actions keys each resulting kernel by
# 1 + its index in (a copy of) this list.
actions = [Opt(op=OptOps.UPCAST, axis=ax, arg=amt) for amt, ax in itertools.product([0,2,3,4,5,7], range(6))]
actions += [Opt(op=OptOps.UNROLL, axis=ax, arg=amt) for amt, ax in itertools.product([0,4,7], range(5))]
actions += [Opt(op=OptOps.LOCAL, axis=ax, arg=amt) for amt, ax in itertools.product([2,3,4,8,13,16,29], range(6))]
actions += [Opt(op=OptOps.GROUPTOP, axis=ax, arg=amt) for amt, ax in itertools.product([13,16,28,29,32,49,64,256], range(3))]
actions += [Opt(op=OptOps.GROUP, axis=ax, arg=amt) for amt, ax in itertools.product([0,4,8,16], range(3))]
# PADTO actions (arg=32) are included unless BEAM_PADTO=0
if getenv("BEAM_PADTO", 1):
  actions += [Opt(op=OptOps.PADTO, axis=ax, arg=32) for ax in range(7)]
actions += [Opt(op=OptOps.LOCAL, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=6, arg=2)]
# tensor-core actions; arg[0] == -1 means "default tensor core" (expanded per core by get_kernel_actions)
actions += [Opt(op=OptOps.TC, axis=0, arg=(-1, 0))]
actions += [Opt(op=OptOps.TC, axis=ax, arg=(-1, getenv("TC_OPT", 2))) for ax in range(9)] # covers resnet kernels (3 global * 3 reduce)
# SWAP every ordered pair (a0 < a1) of the first five axes
actions += [Opt(op=OptOps.SWAP, axis=a0, arg=a1) for a0, a1 in itertools.combinations(range(5), 2)]
if getenv("NOLOCALS"):
  actions += [Opt(op=OptOps.NOLOCALS)]
def _get_test_global_size(global_size, max_global_size, var_vals):
  """Shrink a launch's global size so a timing run covers at most `max_global_size` work items.

  Symbolic dimensions are first resolved with `var_vals`. While the total size exceeds the limit, the last dimension
  that is still larger than 16 is halved (integer division). Returns `(test_global_size, factor)`, where `factor` is
  2**(number of halvings). _time_program multiplies measured times by it.

  If no dimension is larger than 16, the size cannot shrink any further. The best-effort size is then returned as-is
  instead of looping forever.
  """
  test_global_size, factor = [sym_infer(sz, var_vals) for sz in global_size], 1
  while prod(test_global_size) > max_global_size:
    for j in reversed(range(len(test_global_size))):
      if test_global_size[j] > 16:
        test_global_size[j] //= 2
        factor *= 2
        break
    else:
      # nothing left to halve: previously this spun forever
      break
  return test_global_size, factor
def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[Variable, int], rawbufs:list[Buffer], early_stop:Optional[float]=None,
                  max_global_size:Optional[int]=65536, clear_l2=False, cnt=3, name="test") -> list[float]:
  """Run the precompiled program `lib` up to `cnt` times and return the measured times.

  When both the program's global size and `max_global_size` are set, the launch is shrunk with _get_test_global_size,
  and every measured time is multiplied by the returned factor. Returns `[math.inf] * cnt` if CompiledRunner rejects the
  program with an AssertionError. With `clear_l2`, before each run the device's invalidate_caches() is called if it has
  one; otherwise an unrelated 1024x1024 ones tensor is realized. Runs stop early once even the fastest time so far is
  slower than `early_stop`. `name` is not used here.
  """
  scale = 1
  if p.global_size is not None and max_global_size is not None:
    shrunk_size, scale = _get_test_global_size(p.global_size, max_global_size, var_vals)
    p = replace(p, global_size=shrunk_size)
  try:
    runner = CompiledRunner(p, precompiled=lib)
  except AssertionError:
    return [math.inf] * cnt
  launch_bufs = [rawbufs[idx] for idx in runner.p.globals]
  timings = []
  for _ in range(cnt):
    if clear_l2:
      device = Device[p.device]
      if not hasattr(device, 'invalidate_caches'):
        # no explicit flush available; presumably realizing a large unrelated tensor pushes prior data out of L2
        with Context(DEBUG=0, BEAM=0, CAPTURING=0, TRACK_MATCH_STATS=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
      else:
        device.invalidate_caches()
    timings.append(cast(float, runner(launch_bufs, var_vals, wait=True)) * scale)
    if early_stop is not None and min(timings) > early_stop:
      break
  return timings
class TimeoutException(Exception):
  """Raised by timeout_handler when the SIGALRM deadline set during a candidate compile expires."""

def timeout_handler(signum, frame):
  """Signal handler that converts the alarm signal into a TimeoutException in the interrupted code."""
  raise TimeoutException()
def _try_compile_linearized_w_idx(x:tuple[int,Kernel], compiler:Compiler) -> tuple[int, Optional[tuple[ProgramSpec, bytes, float]]]:
  """Render and compile one candidate kernel. beam_search uses this as its (possibly pooled) worker function.

  `x` is `(index, kernel)`. Returns `(index, (program, compiled binary, compile seconds))`, or `(index, None)` in these
  cases: rendering or compiling fails, the program has at least BEAM_UOPS_MAX uops (when that limit is > 0), or the
  BEAM_TIMEOUT_SEC alarm fires. Failures other than RuntimeError are re-raised when BEAM_STRICT_MODE is set.

  Where signal.alarm exists, a SIGALRM handler is installed only for the duration of the call. The previous handler is
  restored afterwards, so running this in the main process (no worker pool) no longer leaves timeout_handler installed
  process-wide.
  """
  has_alarm = hasattr(signal, "alarm")
  prev_handler = None
  if has_alarm:
    prev_handler = signal.signal(getattr(signal, 'SIGALRM'), timeout_handler)
    # set timeout
    signal.alarm(getenv("BEAM_TIMEOUT_SEC", 10))
  ret = None
  try:
    p = x[1].to_program(name_override="test")
    assert p.uops is not None, "uop list wasn't generated?"
    if len(p.uops) >= (uops_max:=getenv("BEAM_UOPS_MAX", 3000)) > 0:
      if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many uops. {len(p.uops)=}, {uops_max=}")
      raise RuntimeError("too many uops")
    st = time.perf_counter()
    prog = compiler.compile(p.src)
    et = time.perf_counter() - st
    ret = (p, prog, et)
  except RuntimeError:
    if DEBUG >= 4: traceback.print_exc()
  except Exception:
    # includes TimeoutException from the alarm; swallowed (ret stays None) unless strict mode is requested
    if getenv("BEAM_STRICT_MODE"): raise
  finally:
    if has_alarm:
      signal.alarm(0)
      # signal.signal returns None when the old handler was not installed from Python; it cannot be re-installed then
      if prev_handler is not None: signal.signal(getattr(signal, 'SIGALRM'), prev_handler)
  return x[0], ret
def _init_worker():
  """Pool initializer: make each worker process ignore SIGINT (ctrl-c), leaving interrupt handling to the parent."""
  signal.signal(signal.SIGINT, signal.SIG_IGN)
def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]:
  """Call ensure_allocated() on every buffer in `bufs` and return the results in the same order."""
  allocated = []
  for b in bufs:
    allocated.append(b.ensure_allocated())
  return allocated
# *** external API ***

def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]:
  """Create scratch Buffers matching the kernel's global buffers, for timing candidate kernels.

  One Buffer is created on lin.opts.device for each DEFINE_GLOBAL slot, stored at the index given by the slot's arg.
  NOTE(review): this assumes those args are exactly 0..n-1. Image dtypes are sized from the image shape. Other buffers
  use the largest st_arg.real_size() among the accesses to that slot, and a size of 0 is bumped to 1. With
  `allocate=False` the Buffers are constructed but not allocated.
  """
  uops_by_slot: defaultdict[int, list[UOp]] = defaultdict(list)
  for u in lin.bufs:
    define = u.src[0]
    if define.op is Ops.DEFINE_GLOBAL: uops_by_slot[define.arg].append(u)
  rawbufs: list[Optional[Buffer]] = [None] * len(uops_by_slot)
  for slot, accesses in uops_by_slot.items():
    dtype = accesses[0].src[0].dtype
    if isinstance(dtype, ImageDType): size = prod(dtype.shape)
    else: size = max(acc.st_arg.real_size() for acc in accesses)
    assert isinstance(dtype, (PtrDType, ImageDType))
    # no cell is accessed in the kernel: still hand it a 1-element buffer. TODO: remove from kernel input in this case.
    if size == 0: size = 1
    elem_dtype = dtype if isinstance(dtype, ImageDType) else dtype.base
    buf = Buffer(lin.opts.device, size, elem_dtype)
    rawbufs[slot] = buf.allocate() if allocate else buf
  assert all(b is not None for b in rawbufs)
  return cast(list[Buffer], rawbufs)
# get dictionary of all possible actions
def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
  """Apply each candidate action to a copy of `lin` and return the kernels that accept it.

  Keys are 1 + the action's index in this call's action list; key 0 maps to `lin` itself when `include_0`. An action
  is skipped in any of these cases:
  - its real axis is out of range;
  - its arg equals the full size of that axis and an arg=0 action for the same op and axis is also in the list;
  - apply_opt raises KernelOptError;
  - the upcast product (magenta/yellow dims, divided by the tensor core's prod(dims)//threads) exceeds BEAM_UPCAST_MAX;
  - the local product (cyan/green/white dims) exceeds BEAM_LOCAL_MAX.

  When TC_SEARCH_OVER_SHAPE is set and `lin` has no applied opts, each TC action whose arg[0] == -1 (default tensor
  core) is replaced by one TC action per entry of lin.opts.tensor_cores.
  """
  acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
  kernel_actions = actions.copy()

  if TC_SEARCH_OVER_SHAPE and len(lin.applied_opts) == 0: # tensor core opts must be first
    # Build a fresh list instead of splicing into the list being enumerated. Splicing shifts later elements, and with
    # no tensor cores (empty replacement) the action right after each removed one was skipped and left unexpanded.
    expanded: list[Opt] = []
    for action in kernel_actions:
      if action.op == OptOps.TC and (tc_arg := cast(tuple, action.arg))[0] == -1:
        # replace every tc_action with default tc with one tc_action for each available tc
        expanded.extend(Opt(op=OptOps.TC, axis=action.axis, arg=(tc_select, tc_arg[1])) for tc_select,_ in enumerate(lin.opts.tensor_cores))
      else:
        expanded.append(action)
    kernel_actions = expanded

  for i,a in enumerate(kernel_actions):
    if a.axis is not None and a.op is not OptOps.TC:
      # skip out-of-range axes, and "amount == whole axis" when the equivalent arg=0 action is also being tried
      if ((ax:=lin.real_axis(a)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue
    lin2 = lin.copy()
    try:
      lin2.apply_opt(a)
      up, lcl, tc_up = 1, 1, prod(tc.dims)//tc.threads if (tc:=lin2.tensor_core) else 1
      for s,c in zip(lin2.full_shape, lin2.colors()):
        if c in {"magenta", "yellow"}: up *= s
        elif c in {"cyan", "green", "white"}: lcl *= s
      if up//tc_up > max_up or lcl > max_lcl:
        if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many upcast/local. {up//tc_up=}, {max_up=}, {lcl=}, {max_lcl=}")
        continue
      acted_lins[i+1] = lin2
    except KernelOptError: pass
  return acted_lins
# worker pool shared by all beam_search calls; beam_search creates it lazily when PARALLEL workers are enabled
beam_pool = None
BEAM_DEBUG = getenv("BEAM_DEBUG")
def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True, disable_cache=IGNORE_BEAM_CACHE.value) -> Kernel:
  """Search for a fast sequence of optimization actions for `lin` and return the best kernel found.

  Each round does the following:
  - expands every kernel in the beam with get_kernel_actions;
  - compiles the candidates, on `beam_pool` when parallel workers are enabled;
  - times the new binaries on `rawbufs`;
  - keeps the `amt` fastest as the next beam.

  The search ends when no candidate could be timed, when the best time drops below `min_progress`
  (BEAM_MIN_PROGRESS / 1e6), or when a round improves the best time by less than `min_progress`.

  When CACHELEVEL >= 1, the winner's applied opts are written to the disk cache. Unless `disable_cache` is set, a
  cache hit replays the stored opts on a copy of `lin` instead of searching. In this function `allow_test_size` only
  participates in the cache key.

  If the search is interrupted (KeyboardInterrupt), the worker pool is terminated and the module-level `beam_pool` is
  reset, so a later call spawns a fresh pool instead of reusing a dead one.
  """
  global beam_pool
  key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
  if not disable_cache and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None:
    ret = lin.copy()
    for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
    return ret

  beam: list[tuple[Kernel, float]] = [(lin, float("inf"))]
  seen_libs = set()  # binaries already timed; identical binaries are only timed once

  default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV", "METAL", "HIP"} else 0
  if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
    beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))

    @atexit.register
    def close_pool():
      # beam_pool is reset to None when a KeyboardInterrupt terminates it
      if beam_pool is not None: beam_pool.close()

  min_progress = getenv("BEAM_MIN_PROGRESS", 0.01)/1e6
  if BEAM_DEBUG: print(f"BEAM_SEARCH:\n{lin.ast}")
  if DEBUG >= 2: print(f" 0.00s: from 1 -> 1 actions {lin.colored_shape()}")

  try:
    rawbufs = _ensure_buffer_alloc(rawbufs)
    var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
    exiting, st = False, time.perf_counter()
    dev = Device[lin.opts.device]
    # loop-invariant: the same compiler is used for every round
    _compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
    while not exiting:
      acted_lins: list[Kernel] = flatten([get_kernel_actions(k, include_0=False).values() for k,_ in beam])
      timed_lins: list[tuple[Kernel, float]] = []
      least_compute_ops = math.inf
      compiled = map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))
      for i,proc in compiled:
        if proc is None: continue
        p, lib, compile_et = proc
        if lib in seen_libs: continue
        # filter out kernels that use 1000x more compute than the smallest
        least_compute_ops = min(this_compute_ops:=sym_infer(p.estimates.ops, var_vals), least_compute_ops)
        if least_compute_ops*1000 < this_compute_ops: continue
        seen_libs.add(lib)
        try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches'))
        except RuntimeError: continue # for runtime issues
        timed_lins.append((acted_lins[i], min(tms)))
        if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(list, p.uops)):5d} uops {time_to_str(compile_et, w=12)} compile/{time_to_str(timed_lins[-1][1], w=12)} run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
        elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {time_to_str(timed_lins[-1][1], w=12)} {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501

      # done
      opts = sorted(timed_lins, key=lambda x: x[1])
      exiting = len(opts) == 0 or (opts[0][1] < min_progress) or (len(beam) > 0 and ((beam[0][1]-opts[0][1]) < min_progress))
      if not exiting: beam = opts[:amt]
      elif len(opts) > 0 and opts[0][1] < beam[0][1]: beam = opts[:1]
      if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(time_to_str(beam[0][1], w=12), "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501
  except KeyboardInterrupt:
    if beam_pool is not None:
      beam_pool.terminate()
      # drop the dead pool so the next call creates a fresh one instead of failing with "Pool not running"
      beam_pool = None
    raise

  if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
  if BEAM_DEBUG: print(f"BEAM_SEARCH: final tm={time_to_str(beam[0][1], w=0)}, applied_opts={beam[0][0].applied_opts}")
  return beam[0][0]
def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffer]) -> list[int]:
  """Time `_prg` under candidate local sizes and return the fastest one.

  For each global dimension `sz`, the candidates are the values in {sz, 1, 2, 4, ..., 256, MAX_WORKGROUP} that are <= sz.
  Candidates are combined across dimensions, keeping only combinations whose product is <= MAX_WORKGROUP. Each
  combination is tried twice, in random order; a launch that raises counts as infinitely slow. The global size passed
  to `_prg` is `g // l` when divisible, otherwise `g / l`.

  If rawbufs[0] also appears later in the list, a fresh buffer with the same device/size/dtype replaces it
  (presumably so writes to it don't clobber an aliased input; confirm).

  Raises AssertionError if every launch failed.
  """
  if rawbufs[0] in rawbufs[1:]:
    first = rawbufs[0]
    test_rawbuffers = [Buffer(first.device, first.size, first.dtype).allocate(), *rawbufs[1:]]
  else:
    test_rawbuffers = rawbufs
  MAX_WORKGROUP = 1024
  candidates = (1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP)
  local_dims = [[x for x in {sz, *candidates} if x <= sz] for sz in global_size]
  local_sizes = [list(combo) for combo in itertools.product(*local_dims) if prod(combo) <= MAX_WORKGROUP] * 2 # try each valid size twice

  def try_exec(local_size):
    try:
      launch_global = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
      return _prg(*[b._buf for b in test_rawbuffers], global_size=launch_global, local_size=local_size, wait=True)
    except Exception:
      return float('inf')

  timed = [(try_exec(ls), ls) for ls in random.sample(local_sizes, len(local_sizes))]
  best_time, best_size = min(timed)
  assert not math.isinf(best_time), "all optimize_local_size exec failed"
  return best_size