from __future__ import annotations
import math, itertools
from collections import defaultdict
from typing import cast, Final
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, can_pad, GroupOp
from tinygrad.device import Buffer
from tinygrad.dtype import AddrSpace, dtypes, ImageDType
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod
from tinygrad.codegen.opt import axis_colors, Opt, OptOps, KernelOptError, check, axis_letters
from tinygrad.codegen.simplify import pm_flatten_range
from tinygrad.renderer import Renderer
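# The Scheduler below wraps a kernel AST (a SINK UOp) and applies OptOps (upcast, unroll, local, group, thread, tensor cores)
# by rewriting its RANGE UOps in place. pm_postrange_opt at the bottom of this file is the graph_rewrite entry point.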
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
               AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
class Scheduler:
  def __init__(self, ast:UOp, opts:Renderer):
    self.ast, self.opts = ast, opts
    self.dont_use_locals = self.ast.arg.dont_use_locals if self.ast.arg is not None else False
    self.applied_opts = list(self.ast.arg.applied_opts) if self.ast.arg is not None else []
  @property
  def rngs(self):
    # always in order by axistype
    return sorted([u for u in self.ast.parents if u.op is Ops.RANGE and u.vmax > 0], key=lambda x: (axis_to_pos[x.arg[-1]],) + x.arg[0:-1])
  @property
  def shape_len(self): return len(self.rngs)
  @property
  def full_shape(self): return [ssimplify(x.src[0]) for x in self.rngs]
  @property
  def axis_types(self): return [x.arg[-1] for x in self.rngs]
  @property
  def maxarg(self): return max([x.arg[0] for x in self.rngs], default=0)
  # strings like ['g0', 'g1', 'l0', 'l1', 'l2', 'l3', 'l4', 'l5', 'R0', 'r0', 'r1', 'r2', 'u0', 'u1', 'u2']
  def shape_str(self) -> list[str]:
    ret: list[str] = []
    cnt: dict[AxisType, int] = {}
    for x in self.axis_types:
      cnt[x] = (cnt[x] + 1) if x in cnt else 0
      ret.append(f"{axis_letters[x]}{cnt[x]}")
    return ret
  def shape_str_to_axis(self, nms:list[str]) -> tuple[int, ...]: return tuple([self.shape_str().index(x) for x in nms])
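  # copy makes a new Scheduler with its own mutable opt state (applied_opts, dont_use_locals); the ast and Renderer are shared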
  def copy(self):
    ret = Scheduler(self.ast, self.opts)
    ret.dont_use_locals = self.dont_use_locals
    ret.applied_opts = self.applied_opts[:]
    return ret
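  # get_optimized_ast finalizes the kernel: it derives a colored kernel name (deduplicated via kernel_cnt),
  # flattens the ranges, and attaches a KernelInfo with the applied opts to the SINK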
  kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
  def get_optimized_ast(self, name_override:str|None=None):
    if name_override is not None: name = name_override
    else:
      kernel_type = "r" if self.reduceop is not None else "E"
      name = kernel_type + colored('_', 'BLACK').join(['']+[colored(x.src[0].render(), color) for x,color in zip(self.rngs, self.colors())])
      Scheduler.kernel_cnt[(function_name := to_function_name(name))] += 1
      num = f"n{Scheduler.kernel_cnt[function_name]-1}" if Scheduler.kernel_cnt[function_name] > 1 else ""
      name += colored(num, 'BLACK')
    self.ast = graph_rewrite(self.ast, pm_flatten_range, name="flatten range")
    return self.ast.replace(arg=KernelInfo(name=name, applied_opts=tuple(self.applied_opts), dont_use_locals=self.dont_use_locals), tag=1)
  def _globalizable_rngs(self) -> list[UOp]:
    store_rngs = self.ast.src[0].src[2:]
    # filter any not in local stores
    local_store_rngs = [x.ranges for x in self.ast.toposort() if (x.op is Ops.STORE and x.src[0].ptrdtype.addrspace == AddrSpace.LOCAL) \
        or (x.op is Ops.BUFFERIZE and x.arg == AddrSpace.LOCAL)]
    for ls in local_store_rngs: store_rngs = tuple([x for x in store_rngs if x in ls])
    return [x for x in UOp.sink(*store_rngs).toposort() if x.op is Ops.RANGE and x.arg[1] == AxisType.LOOP] if store_rngs else []
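  # promote the globalizable LOOP ranges to GLOBAL axes when the backend exposes a launch grid (has_local)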
  def convert_loop_to_global(self):
    if not self.opts.has_local: return None
    globalizable_rngs = self._globalizable_rngs()
    rng = [x.replace(arg=(x.arg[0], AxisType.GLOBAL)) if x in globalizable_rngs else x for x in self.rngs]
    self.ast = self.ast.substitute(dict(zip(self.rngs, rng)))
  def colors(self) -> list[str]: return [axis_colors[x] if not self.dont_use_locals or not x == AxisType.GLOBAL else "BLUE" for x in self.axis_types]
  def colored_shape(self) -> str: return ' '.join([colored(f'{x.src[0].render():>4s}', color) for x,color in zip(self.rngs, self.colors())])
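  # shift_to splits a RANGE of size old_sz*amount into an outer RANGE(old_sz) plus a new RANGE(amount) of the given AxisType,
  # e.g. shifting RANGE(16) by amount=4 to UPCAST rewrites the index as RANGE(4)*4 + RANGE(4, UPCAST);
  # with top=True the new range becomes the slow-moving (outer) part of the index instead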
  def shift_to(self, rng:UOp, amount:int, new_type:AxisType, top:bool=False, input_new_rng=None):
    if (old_sz:=rng.src[0].divides(amount)) is None:
      raise KernelOptError(f"{amount} can't divide {rng.src[0]} in {self.colored_shape()}")
    new_rng = UOp.range(amount, self.maxarg+1, new_type) if input_new_rng is None else input_new_rng
    replaced_rng = rng.replace(src=(UOp.const(dtypes.int, old_sz),))
    sub_axis = (new_rng * old_sz + replaced_rng) if top else (replaced_rng * amount + new_rng)
    self.ast = self.ast.substitute({rng:sub_axis}, name=f"shift {rng.arg[0]} {amount} {str(new_type).split('.')[1].lower()}")
    return replaced_rng, new_rng
  def ranges_of(self, *axis_type:AxisType) -> list[UOp]: return [r for r in self.rngs if r.arg[-1] in axis_type]
  def axes_of(self, *axis_type:AxisType) -> list[int]: return [i for i,t in enumerate(self.axis_types) if t in axis_type]
  # copied from kernel.py
  @property
  def upcastable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP) \
      if isinstance(s:=self.full_shape[i], int) and s > 1]
  @property
  def unrollable_dims(self) -> list[int]: return [i for i in self.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE) \
      if isinstance(s:=self.full_shape[i], int) and s > 1]
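  # real_axis translates an Opt's axis index into an index into self.rngs
  # (UNROLL indexes the unrollable dims, GROUP/GROUPTOP index the REDUCE dims, everything else is already absolute)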
  def real_axis(self, op:OptOps, axis:int|None):
    try:
      if axis is None or op is OptOps.TC: return -1
      if op is OptOps.UNROLL: return self.unrollable_dims[axis]
      if op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.axes_of(AxisType.REDUCE)[axis]
      check(axis < self.shape_len, f"invalid axis on {axis=} {op=} {self.shape_len=}")
      return axis
    except IndexError as e: raise KernelOptError from e
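  # apply_opt validates a single Opt against the current kernel state and applies it by rewriting self.ast in place;
  # invalid combinations raise KernelOptError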
  def apply_opt(self, opt:Opt, append_opt:bool=True):
    if opt.op is OptOps.NOLOCALS:
      check(all(x not in {AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE} for x in self.axis_types), "no locals can't have locals")
      if append_opt: self.applied_opts.append(opt)
      self.dont_use_locals = True
      return
    if opt.op in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}:
      check(self.opts.has_local, "locals needed for opt")
    rng = self.rngs[real_axis] if (real_axis:=self.real_axis(opt.op, opt.axis)) >= 0 else UOp(Ops.NOOP)
    opt_to_at = {
      OptOps.LOCAL: AxisType.LOCAL, OptOps.UPCAST: AxisType.UPCAST,
      OptOps.UNROLL: AxisType.UNROLL, OptOps.GROUP: AxisType.GROUP_REDUCE,
      OptOps.GROUPTOP: AxisType.GROUP_REDUCE, OptOps.THREAD: AxisType.THREAD}
    ret = None
    if opt.op in opt_to_at:
      amt:int = int(rng.vmax+1) if opt.arg == 0 else cast(int, opt.arg)
      # copied from kernel.py. prevents METAL compiler hangs
      if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
          (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
        upcast_local_sz = prod([self.full_shape[a] for a in self.axes_of(AxisType.UPCAST, AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE)])
        smem_sz = amt*upcast_local_sz*self.reduceop.dtype.itemsize
        check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
      if opt.op is OptOps.UNROLL:
        check(amt <= 32, "don't unroll more than 32")
        check(rng.arg[-1] in {AxisType.GROUP_REDUCE, AxisType.REDUCE}, "unroll is for GROUP_REDUCE/REDUCE")
      if opt.op is OptOps.UPCAST:
        check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16")
        check(rng.arg[-1] in {AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP}, f"upcast is for GLOBAL/LOCAL/LOOP, not {rng.arg[-1]}")
      if opt.op is OptOps.LOCAL:
        check(not self.dont_use_locals, "can't use locals")
        check(rng.arg[-1] in {AxisType.GLOBAL, AxisType.LOOP}, "local is for globals")
      if opt.op is OptOps.THREAD:
        check(self.opts is not None and self.opts.has_threads, "target does not support threads")
        check(self.opts is not None and self.opts.global_max is not None and amt <= self.opts.global_max[0], "too many threads")
        check(all(x is not AxisType.THREAD for x in self.axis_types), "already threaded")
        check(rng in self._globalizable_rngs(), "can't apply range to this dim")
      if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}:
        check(all(x.op is not OptOps.TC for x in self.applied_opts), "no grouping with tensor cores") # TODO: why is this wrong?
        check(not self.dont_use_locals, "can't use locals")
        check(rng.arg[-1] == AxisType.REDUCE, "group is for reduce")
      ret = self.shift_to(rng, amt, opt_to_at[opt.op], top=opt.op in {OptOps.GROUPTOP, OptOps.THREAD})
    elif opt.op is OptOps.TC:
      check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: remove the need for this by having warps
      check(opt.axis is not None, "tensor core opts must have an axis")
      check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 3, "tensor core opts must have valid arg")
      check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select")
      check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")
      check(0 < (use_tensor_cores:=cast(tuple, opt.arg)[2]) <= 2, "use_tensor_cores value is not valid")
      try: ret = self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt)
      except ValueError as e: raise KernelOptError(str(e))
      check(ret is not None, "no tensor core available")
    elif opt.op is OptOps.PADTO:
      check(rng.src[0].op is Ops.CONST, "only pad const axes")
      check(rng.arg[-1] not in {AxisType.UPCAST, AxisType.UNROLL}, "cannot pad upcasted") # TODO: why is this wrong?
      check(rng.arg[-1] is not AxisType.THREAD, "cannot pad thread")
      # ok to pad SUM if all parent ALU ops have f(0) = 0
      if (r:=self.reduceop) is not None and rng.arg[-1] in (AxisType.GROUP_REDUCE, AxisType.REDUCE):
        check(r.arg[0] is Ops.ADD and can_pad(r, {}), f"cannot pad {r}")
      new_sz = round_up(int(rng.vmax+1), cast(int, opt.arg))
      check(rng.vmax+1 > new_sz//4, "pad adds more than quadruple the work")
      replaced_rng = UOp.range(new_sz, *rng.arg)
      replaces = {rng:replaced_rng}
      valid = replaced_rng < rng.vmax+1
      for b in self.bufs:
        if rng in (i:=b.src[1].get_idx()).sparents:
          replaces[b] = b.replace(src=(b.src[0],(valid&b.src[1].get_valid()).where(i, UOp.invalid())))
      self.ast = self.ast.substitute(replaces, f"padto {rng.arg[:-1]} {opt.arg}")
    elif opt.op is OptOps.SWAP:
      try:
        altrng = self.rngs[opt.arg]
      except IndexError:
        raise KernelOptError
      check(rng.arg[-1] == AxisType.GLOBAL and altrng.arg[-1] == AxisType.GLOBAL, "swap only for globals")
      self.ast = self.ast.substitute({rng:rng.replace(arg=(*altrng.arg[0:-1], rng.arg[-1]), tag=1),
                                      altrng:altrng.replace(arg=(*rng.arg[0:-1], altrng.arg[-1]), tag=1)})
      self.ast = graph_rewrite(self.ast, remove_tags)
    else:
      raise KernelOptError(f"unsupported opt {opt.op}")
    if append_opt: self.applied_opts.append(opt)
    return ret
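  # _apply_tc_opt tries to map a MUL -> REDUCE(ADD) pattern onto one of the renderer's tensor cores: it picks an
  # (in1, in0, reduce) range triple, pads/splits those ranges to the tc dims, and (unless use_tensor_cores == 2)
  # replaces the reduce with a WMMA-based UOp. returns the chosen ranges, or None if no tensor core matched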
  def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> None|list[UOp]:
    reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE]
    if not len(reduceops): raise KernelOptError("no reduce ops for TensorCore")
    reduceop = reduceops[0]
    if use_tensor_cores and reduceop is not None and reduceop.arg is Ops.ADD:
      mul = reduceop.src[0] if reduceop.src[0].op is not Ops.CAST else reduceop.src[0].src[0]
      if mul.op is not Ops.MUL: return None
      in0, in1 = mul.src
      try:
        tensor_cores = self.opts.tensor_cores if tc_select == -1 else [self.opts.tensor_cores[tc_select]]
      except IndexError:
        raise KernelOptError(f"invalid tensor core choice {tc_select}")
      for tc in tensor_cores:
        if tc.dtype_in == in0.dtype.scalar() and tc.dtype_in == in1.dtype.scalar() and tc.dtype_out == reduceop.dtype.scalar():
          # tensor cores have three ranges. X, Y, and REDUCE
          in0_ranges = sorted([u for u in in0.ranges if u not in in1.ranges], key=lambda x: -x.arg[0])
          in1_ranges = sorted([u for u in in1.ranges if u not in in0.ranges], key=lambda x: -x.arg[0])
          red_ranges = sorted(reduceop.src[1:], key=lambda x: -x.arg[0])
          if DEBUG >= 3:
            print(f"TC({axis}): {[(x.arg[0],x.vmax+1) for x in in0_ranges]}",
                  f"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}")
          if not len(in0_ranges) or not len(in1_ranges) or not len(red_ranges): continue
          # pick ranges
          # NOTE: why are in1 and in0 switched?
          axis_choices = list(itertools.product(in1_ranges, in0_ranges, red_ranges))
          if not (axis < len(axis_choices)): continue
          axes = list(axis_choices[axis])
          # do optimizations and save the ranges
          try:
            for i,a in enumerate(axes):
              idx = self.rngs.index(a)
              if (a.vmax+1) % tc.dims[i] != 0:
                if opt_level < 2: raise KernelOptError("tc padding requires opt_level >= 2")
                # apply_opt should return the updated range?
                self.apply_opt(Opt(OptOps.PADTO, idx, tc.dims[i]), append_opt=False) # PADTO might fail
                axes[i] = self.rngs[idx]
          except KernelOptError: continue
          # we create the warp as a whole thing, in case some of these ranges are moved/removed later
          warp = UOp.range(tc.threads, -1, AxisType.WARP)
          ne: list[UOp] = []
          for opt in tc.opts:
            if opt[0] == "l":
              axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.LOCAL, input_new_rng=warp%2)
              warp //= 2
            elif opt[0] == "u":
              axes[int(opt[1])], new_range = self.shift_to(axes[int(opt[1])], 2, AxisType.UPCAST)
            else: raise RuntimeError(f"unsupported opt {opt[0]} in tensor cores")
            ne.append(new_range)
          for _, amt in tc.get_reduce_axes():
            axes[2], new_range = self.shift_to(axes[2], amt, AxisType.UNROLL)
            ne.append(new_range)
          if use_tensor_cores != 2:
            # fix the srcs
            reduceop = [x for x in self.ast.toposort() if x.op is Ops.REDUCE][0]
            tne = [x.replace(tag=1) for x in ne]
            ret = reduceop.substitute(dict(zip(ne, tne)))
            srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
            srcs = [x.substitute(dict(zip(tne, [ne[i] for i in argsort(p)]))) for x,p in zip(srcs, tc.permutes_for_shape_str(tc.base_shape_str()))]
            # get reduce/upcast axes for the tensor cores
            tc_reduce_axes = self.shape_str_to_axis([f"r{i}" for i in range(len(tc.get_reduce_axes()))])
            base_upcast_axes = tuple([(s,2) for s in self.shape_str_to_axis(tc.base_upcast_axes())])
            tc_upcast_axes = tuple([base_upcast_axes[:int(math.log2(tc.elements_per_thread[i]))] for i in range(3)])
            # axes to range number (was done in lowerer)
            tc_upcast_axes = tuple([tuple([(self.rngs[a].arg[0], sz) for a,sz in v]) for v in tc_upcast_axes])
            tc_reduce_axes = tuple([self.rngs[a].arg[0] for a in tc_reduce_axes])
            # construct the op
            # TODO: remove tc_upcast_axes from the arg
            # do the reduce_axes always disappear? i think they don't
            # they need to be moved into the WMMA srcs
            wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, tc.threads, tc_upcast_axes, ()) #, tc_reduce_axes)
            wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=(
              UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0], tag=1),
              UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1], tag=1),
              UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg, tag=1)
            tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2], tag=1)
            # preserve extra reduces
            reduce_ranges = [x for x in UOp.sink(*reduceop.src[1:]).toposort() if x.op is Ops.RANGE and x.arg[0] not in tc_reduce_axes]
            if len(reduce_ranges): tc_uop = UOp(Ops.REDUCE, tc_uop.dtype, (tc_uop,)+tuple(reduce_ranges), Ops.ADD)
            self.ast = self.ast.substitute({reduceop: tc_uop})
          return axes
    return None
  # helpers for hand_coded_optimizations
  @property
  def reduceop(self) -> UOp|None:
    red = [x for x in self.ast.parents if x.op is Ops.REDUCE]
    if not len(red): return None
    return UOp(Ops.REDUCE_AXIS, red[0].dtype, red[0].src, (red[0].arg, ()))
  @property
  def bufs(self) -> list[UOp]: return [x for x in self.ast.toposort() if x.op is Ops.INDEX][::-1]
  @property
  def output_shape(self):
    return [s if at not in {AxisType.REDUCE, AxisType.UNROLL, AxisType.GROUP_REDUCE} else 1 for s,at in zip(self.full_shape, self.axis_types)]
  @property
  def upcasted(self) -> int: return len(self.axes_of(AxisType.UPCAST, AxisType.UNROLL))
  @property
  def group_for_reduces(self) -> int: return len(self.axes_of(AxisType.GROUP_REDUCE))
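# allocate one Buffer per DEFINE_GLOBAL in the ast; used below to give BEAM search real buffers to run candidates with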
def bufs_from_ast(ast:UOp, dname:str) -> list[Buffer]:
  glbls = sorted([x for x in ast.parents if x.op is Ops.DEFINE_GLOBAL], key=lambda x: x.arg)
  return [Buffer(dname, x.ptrdtype.size, x.dtype.base if not isinstance(x.dtype, ImageDType) else x.dtype) for x in glbls]
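# apply_opts is the pattern matcher callback: for each untagged SINK it builds a Scheduler, then applies either the
# KernelInfo's opts_to_apply, a BEAM search result, or the hand-coded heuristics, and returns the finalized ast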
def apply_opts(ctx:Renderer, ast:UOp):
  if ast.tag is not None: return None
  k = Scheduler(ast, ctx)
  k.convert_loop_to_global()
  if ast.arg is not None and ast.arg.opts_to_apply is not None:
    for opt in ast.arg.opts_to_apply: k.apply_opt(opt)
  elif BEAM >= 1:
    from tinygrad.codegen.opt.search import beam_search
    rawbufs = bufs_from_ast(ast, ctx.device)
    k = beam_search(k, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
  elif not NOOPT and (ast.arg is None or ast.arg.applied_opts == ()):
    from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
    # NOTE: hand_coded_optimizations doesn't support multiblock opts yet
    if all(len(u.src) == 1 for u in ast.parents if u.op is Ops.LOAD):
      k = hand_coded_optimizations(k)
  return k.get_optimized_ast(name_override=ast.arg.name if ast.arg is not None and ast.arg.name != "test" else None)
pm_postrange_opt = PatternMatcher([
  (UPat(Ops.SINK, name="ast"), apply_opts),
])