# the job of the lowerer is to do indexing
import functools, itertools, operator, math
from dataclasses import dataclass
from typing import cast
from tinygrad.dtype import dtypes, PtrDType, least_upper_dtype
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
from tinygrad.renderer import Renderer
from tinygrad.helpers import all_int, prod, partition, flatten, unwrap, QUANTIZE
from tinygrad.codegen.expander import expand_rewrite
from tinygrad.codegen.symbolic import symbolic

# returns the axes to create new_shape if new_shape can be created by combining axes from old_shape
def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
  acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
  try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
  except ValueError: return None
  return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
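# illustrative examples (not part of the original file), following the logic above:
#   get_contraction((2,3,4), (6,4)) -> [[0, 1], [2]]    (6 = 2*3, so axes 0 and 1 merge)
#   get_contraction((2,3,4), (4,6)) -> None             (4 is not a prefix product of (2,3,4))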

# ***** indexing *****

def _group_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
  # TODO: symbolic shape
  if not all_int(dims): return dims
  while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
    for i,m in enumerate(max_sizes):
      if i < (len(dims)-1) and dims[i] * dims[i+1] <= m:
        dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
        break
    else: return None
  return dims
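# illustrative example (not part of the original file): folding a 4d grid into 3 hardware dims
#   _group_dims((4,5,6,7), (1024,1024,1024)) -> (20, 6, 7)   (the first two dims merge into 20)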

def _split_dims(dims, max_sizes):
  if all(d <= m for d,m in zip(dims, max_sizes)): return dims
  _dims = list(dims) + [1]*(3-len(dims))
  for i in range(len(_dims)):
    while _dims[i] > max_sizes[i]:
      div = next((d for d in range(2, math.ceil(math.sqrt(_dims[i])) + 1) if (_dims[i] % d) == 0), 1)
      if div == 1: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
      _dims[i], _dims[(i+1)%len(_dims)] = _dims[i]//div, _dims[(i+1)%len(_dims)]*div
  return tuple(_dims[:2] if _dims[2] == 1 else _dims[0:1] if _dims[1:3] == [1,1] else _dims)
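# illustrative example (not part of the original file): splitting one oversized dim across two
#   _split_dims((2048,), (1024,1024,1024)) -> (1024, 2)   (the smallest divisor is pushed into the next dim)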

def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|None, reverse=False) -> list[UOp]:
  if reverse: dims = dims[::-1]
  # try to group first: (a, b, c, d) -> (ab, c, d)
  limited = (grouped if (grouped := _group_dims(dims, max_sizes)) else dims) if max_sizes is not None else dims
  # check if grouping failed
  if max_sizes is not None and len(limited) > len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
  # try to split up dims: (a,) -> (b, c)
  if limited == dims: limited = _split_dims(dims, max_sizes) if max_sizes is not None else dims
  ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
  if len(limited) < len(dims):
    ret = []
    if (contraction:=get_contraction(dims, limited)) is None: raise AssertionError(f"get_contraction should not be None {dims=} {limited=}")
    for idx, contraction_group in zip(raw_idxs, contraction):
      for c in contraction_group[:-1]:
        ret.append(idx % dims[c])
        idx //= dims[c]
      ret.append(idx)
  elif len(limited) > len(dims):
    a, b = len(limited), len(dims)
    if a == 2 and b == 1: ret = [raw_idxs[0] * limited[1] + raw_idxs[1]]
    if a == 3 and b == 1: ret = [raw_idxs[0] * (limited[1] * limited[2]) + raw_idxs[1] * limited[2] + raw_idxs[2]]
    if a == 3 and b == 2: ret = [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]]
  return ret[::-1] if reverse else ret
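# illustrative example (not part of the original file): three logical dims, two hardware dims
#   get_grouped_dims("gidx", (4,5,6), (64,64)) defines SPECIALs gidx0 (size 20) and gidx1 (size 6)
#   and returns the unpacked indices [gidx0 % 4, gidx0 // 4, gidx1]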

@dataclass
class IndexContext:
  idxs: list[UOp]
  ridxs: list[UOp]
  acc_num: int = 0
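# idxs are the indices normal loads/stores use; ridxs replace the group_for_reduces axes with
# late RANGEs (see get_index below) so loads from local memory iterate the whole grouped axis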

def get_index(ast:UOp, opts:Renderer) -> IndexContext:
  ki = ast.arg if isinstance(ast.arg, KernelInfo) else KernelInfo()
  # NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>
  full_shape = ast.full_shape
  first_upcasted = len(full_shape)-ki.upcasted
  # if there's no reduce, this is first_upcasted. assumes reduces are at the end
  first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.toposort if x.op is Ops.REDUCE_AXIS))
  local_loads = [x for x in ast.toposort if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL]
  # NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces
  group_for_reduces = sum([any(l.st_arg.shape[i]!=ast.src[0].st_arg.shape[i] for l in local_loads) for i in range(first_reduce,first_upcasted)])
  global_dims = first_reduce-ki.local_dims

  if opts.has_local:
    if ki.dont_use_locals:
      assert ki.local_dims == 0, "dont_use_locals requires local_dims == 0"
idxs = get_grouped_dims("idx", full_shape[:global_dims], opts.global_max, reverse=True)
else:
# define indexes for GPU-like execution
idxs = get_grouped_dims("gidx", full_shape[:global_dims], opts.global_max, reverse=True) + \
get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
else:
# all loops are RANGES
idxs = [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), i) for i,g in enumerate(full_shape[:first_reduce])]
# reduce loops
idxs += [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), i)
for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]
# upcast loops
for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
assert isinstance(g, int), "needs to be int to upcast/unroll"
idxs.append(UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))
# late indexes (group for reduce)
ridxs = idxs[:]
for a in range(first_reduce, first_reduce+group_for_reduces):
ridxs[a] = UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(full_shape[a])), 1000+a)
return IndexContext(idxs, ridxs)
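# illustrative example (not part of the original file): full_shape (64, 16, 256, 4) with
# local_dims=1, upcasted=1, and a REDUCE_AXIS over axis 2 gives first_upcasted=3, first_reduce=2,
# global_dims=1: one gidx SPECIAL, one lidx SPECIAL, one reduce RANGE, and one UNROLL of width 4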

# ***** lowering (given index) *****

def lower_reduce_axis(ctx: IndexContext, x: UOp):
  # NOTE: always using ridxs is fine here
  reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
  assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand} for {x.axis_arg}"
  alu_op: Ops = x.arg[0]
  ret = x.src[0]
  # create acc
  acc = UOp(Ops.DEFINE_ACC, x.dtype, (x.const_like(identity_element(alu_op, x.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,))
  ctx.acc_num += 1
  if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
    ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
    ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [acc]+[ret.gep(i) for i in range(ret.dtype.count)])
  else:
    ret = acc.alu(alu_op, ret)
  if not len(reduce_range): return ret
  # create ACC and assign
  return acc.assign(ret)
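# illustrative sketch (not part of the original file) of what this produces for a sum over a RANGE axis:
#   REDUCE_AXIS(src, arg=(Ops.ADD, (2,)))  ->  DEFINE_ACC(0, range_2).assign(acc + src)
# for UNROLL axes, src is first CONTRACTed to a vector and each lane is folded into acc via gep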

def lower_load_store(ctx: IndexContext, x: UOp):
  idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
  buf = x.src[0]
  if x.op is Ops.LOAD:
    barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is Ops.DEFINE_LOCAL else ()
    return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + barrier)
  # NOTE: only store the local reduceop in the threads that are actually doing the reduce
  if cast(PtrDType, x.src[0].dtype).local and x.src[2].op is Ops.ASSIGN:
    reduce_input = x.src[2].src[1].src[1] if x.src[2].src[1].src[1] is not x.src[2].src[0] else x.src[2].src[1].src[0]
    store_back = reduce_input.op is Ops.LOAD and cast(PtrDType, reduce_input.src[0].dtype).local
  else: store_back = False
  # NOTE: if we're storing the reduced value back into each thread, need to zero-out the reduced axes
  if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if u in x.src[2].src else u for u in ctx.idxs])
  if (not cast(PtrDType, x.src[0].dtype).local) or store_back:
    for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
      if oidx is not ridx: valid = valid * oidx.eq(0)
  return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid), x.src[2]))
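# note on the valid gating above (not part of the original file): every axis where oidx differs
# from ridx is a grouped-reduce axis, so valid * oidx.eq(0) restricts the final global store to
# the threads at local index 0 along those axes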

def lower_const(x:UOp):
  assert all(v.mask is None for v in unwrap(x.st).views), f"VIEW in CONST/DEFINE_VAR source must be unmasked, got {x.st}"
  return x.replace(src=())

pm_lowerer = PatternMatcher([
  (UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
  (UPat((Ops.CONST, Ops.DEFINE_VAR), src=(UPat(Ops.VIEW),), name="x"), lower_const),
  (UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
  # rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
  (UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
  (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
  (UPat(Ops.IGNORE, name="x"), lambda x: x.src[0]),
])

# **** this is the "quantization preprocessor", it makes ONNX quantized models, and probably also others, actually use ints ****

FP = (1 << 16)  # fixed-point scale: 16 fractional bits
pm_quant = symbolic+PatternMatcher([
  # cast after add/mul
  (UPat.var("x").cast(dtypes.float32) + UPat.var("y").cast(dtypes.float32),
   lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))+y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
  (UPat.var("x").cast(dtypes.float32) * UPat.var("y").cast(dtypes.float32),
   lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))*y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
  # masked MUL after masked ADD
  ((UPat.var("x") + UPat.var("v").where(UPat.var('cadd'), UPat(Ops.CONST, arg=0))) * UPat.var("v").where(UPat.var('cmul'), UPat(Ops.CONST, arg=0)),
   lambda x,v,cadd,cmul: x*v.where(cmul, 0)+v.where(cadd*cmul, 0)),
  # MUL after reduce
  (UPat(Ops.REDUCE_AXIS, src=(UPat.var("x") * UPat.cvar("c"),), name="r"), lambda x,c,r: r.replace(src=(x,))*c),
  # CAST after reduce (doesn't work if it's a size change)
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="r"),
   lambda x,r: r.replace(dtype=x.dtype, src=(x,)).cast(r.dtype) if dtypes.is_float(r.dtype) else None),
  # x*c1 + y*c2 -> (x+y)*c1 (if c1 and c2 are close floats)
  (UPat.var("x")*UPat.cvar("c1", dtype=dtypes.floats) + UPat.var("y")*UPat.cvar("c2", dtype=dtypes.floats),
   lambda x,y,c1,c2: (x+y)*c1 if abs(c1.arg-c2.arg) < 1e-9 else None),
  # mul 0 * c1 is 0
  (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
   UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v"))).cast(dtypes.int).cast(dtypes.float).named("ld"), lambda ld,v,c1: ld*c1),
  # mul (with plus) 0 * c1 is 0
  (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
   (UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v"))).cast(dtypes.int) + \
    UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
   lambda ld,v,c1: ld*c1),
  # fixed point mult, replace (x.float()*c1+c2).int() with an int expression
  ((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("c2")).cast(dtypes.int),
   lambda x,c1,c2: (x * (c1 * FP).cast(dtypes.int) + (c2 * FP).cast(dtypes.int)) // FP),
  # fixed point mult, replace (x.float()*c1 + y.float()*c2) with an int expression
  ((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("y").cast(dtypes.float)*UPat.var("c2")),
   lambda x,y,c1,c2: ((x * (c1 * FP).cast(dtypes.int) + y * (c2 * FP).cast(dtypes.int)) // FP).cast(dtypes.float)),
  # where move
  (UPat.var("valid").where(UPat.var("yes"), UPat(Ops.CONST, arg=0))*UPat.var("mul"), lambda valid, yes, mul:
   (yes*mul*valid.where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))) if yes.op is not Ops.CONST or yes.arg != 1 else None),
  ((UPat.var("x")*UPat.cvar("c"))*(UPat.var().where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)).named("v")), lambda x,c,v: (x*v)*c),
  (UPat.var("x").cast().named('c') * UPat.var('valid').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)), lambda x,c,valid:
   (x*valid.where(UOp.const(x.dtype, 1), UOp.const(x.dtype, 0))).cast(c.dtype)),
  ((UPat.var('x') * UPat.var('v1').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)) *
    UPat.var('v2').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0))).named("mul"), lambda x, mul, v1, v2:
   x * (v1&v2).where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))),
  # where on two adds
(UPat.var("x") + UPat.var("v").where(UPat.var("a0"), UPat.var("a1")) + UPat.var("v").where(UPat.var("b0"), UPat.var("b1")),
lambda x,v,a0,a1,b0,b1: x + v.where(a0+a1, b0+b1)),
  # split REDUCE into multiple reduces
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, name="v1")+UPat.var("c1")) * UPat(Ops.CAST, name="v2",), name="r"),
   lambda v1,v2,c1,r: r.replace(src=(v1*v2,)) + r.replace(src=(c1*v2,))),
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, name="v1")+UPat.var("c1")) * (UPat(Ops.CAST, name="v2",)+UPat.var("c2")), name="r"),
   lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,))),
])
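# worked example (not part of the original file) of the fixed point rewrite with FP = 1<<16:
#   (x.float()*0.5 + 2.0).int()  ->  (x * 32768 + 131072) // 65536
#   e.g. x = 10: (327680 + 131072) // 65536 = 7, which matches int(10*0.5 + 2.0)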

def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
  if QUANTIZE and opts.device in {"CPU", "DSP"}: ast = graph_rewrite(ast, pm_quant, name="quantize")
  sink = graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
  # expand_rewrite turns this into a vectorized program
  return expand_rewrite(sink)
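# pipeline summary (not part of the original file): optional quantize rewrite -> lower indexing with
# pm_lowerer using the IndexContext from get_index -> expand_rewrite vectorizes the UNROLLs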