openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

124 lines
6.0 KiB

# this converts a lowerer program into a vectorized program
import functools, itertools, operator
from tinygrad.helpers import AMX, dedup, flatten, all_same, prod
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, graph_rewrite
from tinygrad.codegen.symbolic import sym
def _expand_arg_to_idx(args:tuple[tuple[int, int], ...], rpk:dict[int, int]) -> int:
idx, mul = 0, 1
for axis,m in args[::-1]:
idx += rpk[axis] * mul
mul *= m
return idx
def _choices_from_args(args:tuple[tuple[int, int], ...]) -> list[dict[int, int]]:
return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]
@functools.lru_cache(None)
def _swizzle_args(cargs:tuple[tuple[int, int], ...], eargs:tuple[tuple[int, int], ...], exclude_args:tuple[int, ...]) -> list[int]:
    """Map every choice over `cargs` to its flat index in the `eargs` layout.

    For each position assignment enumerated from `cargs`, computes the flat index
    under `eargs`. Axes listed in `exclude_args` are forced to position 0 before
    indexing. Results are memoized on the (hashable) arguments; callers should not
    mutate the returned list since it is shared through the cache.
    """
    # positions to overwrite with 0 for excluded axes
    zeroed = {axis: 0 for axis in exclude_args}
    swizzle = []
    for choice in _choices_from_args(cargs):
        positions = {**choice, **zeroed} if exclude_args else choice
        swizzle.append(_expand_arg_to_idx(eargs, positions))
    return swizzle
def do_expand(root:UOp):
    """Rewrite `root` so that UNROLLs on its sources move to its output.

    Returns None when no source of `root` is an UNROLL. Otherwise builds a new UOp
    with the same op as `root`, a dtype of `root.dtype.scalar()` widened to
    `root.dtype.count * expand_sz` lanes, and rewritten sources, then wraps it in an
    UNROLL that has `root.dtype` and the combined expand args.
    """
    unrolls = [s for s in root.src if s.op is Ops.UNROLL]
    if not unrolls:
        return None

    # NOTE: we 0 out the reduce axis for WMMA. in theory they should all be the same, but is this always correct?
    if root.op is Ops.WMMA:
        exclude_args = tuple(dedup(root.arg[-1] + tuple(ax[0] for ax in flatten(root.arg[-2]))))
    else:
        exclude_args = ()

    unroll_args = [u.arg for u in unrolls]
    if all_same(unroll_args) and not exclude_args:
        # every UNROLL carries the same arg and nothing is excluded: reuse it directly (optimization)
        expand_args = unrolls[0].arg
    else:
        # otherwise take the sorted, de-duplicated union of axes (minus excluded ones) and GEP into each source
        expand_args = tuple(ax for ax in sorted(dedup(flatten(unroll_args))) if ax[0] not in exclude_args)
    expand_sz = prod([ax[1] for ax in expand_args])

    new_srcs = []
    for pos, src in enumerate(root.src):
        if src.op is not Ops.UNROLL:
            # non-UNROLL input
            if root.op is Ops.IF or (root.op is Ops.REDUCE and src.op is Ops.RANGE):
                # non-UNROLL sources of IF, and RANGE sources of REDUCE, pass through untouched
                new_srcs.append(src)
            elif src.dtype.count > 1:
                # put any input dtype > 1 grouped together
                cat_dtype = src.dtype.scalar().vec(expand_sz*src.dtype.count)
                new_srcs.append(UOp(Ops.CAT, cat_dtype, (src,)*expand_sz))
            else:
                # repeat the arg
                new_srcs.append(src.broadcast(expand_sz))
            continue

        inner = src.src[0]
        if root.op is Ops.IF and pos == 0:
            # IF means OR on first arg to IF
            new_srcs.append(functools.reduce(operator.__or__, [inner.gep(k) for k in range(expand_sz)]))
        elif expand_args == src.arg:
            # just remove the expand
            new_srcs.append(inner)
        else:
            lanes = _swizzle_args(expand_args, src.arg, exclude_args)
            # if the base dtype is > 1, put those at the end
            if src.dtype.count > 1:
                width = src.dtype.count
                lanes = flatten([[lane*width + j for j in range(width)] for lane in lanes])
            new_srcs.append(inner.gep(tuple(lanes)))

    if root.op is Ops.GEP:
        assert root.dtype.count == 1
        # is this right?
        total = new_srcs[0].dtype.count
        new_arg = tuple(range(root.arg[0], total, total // expand_sz))
    else:
        new_arg = root.arg

    widened = UOp(root.op, root.dtype.scalar().vec(root.dtype.count*expand_sz), tuple(new_srcs), new_arg)
    return UOp(Ops.UNROLL, root.dtype, (widened,), expand_args)
def do_contract(con:UOp):
    """Resolve a CONTRACT against the UNROLL it consumes.

    If the source of `con` is not an UNROLL, returns a VECTORIZE of `con.dtype` whose
    sources are `con.src` repeated `con.dtype.count` times. Otherwise returns an
    UNROLL over the axes of the source that are not in `con.arg`, whose single source
    is a GEP of the original unrolled value with the lanes reordered so that the
    contracted axes vary fastest.
    """
    unrolled = con.src[0]
    # CONTRACT without UNROLL repeats the element VECTORIZED
    if unrolled.op is not Ops.UNROLL:
        return UOp(Ops.VECTORIZE, con.dtype, con.src*con.dtype.count)

    # CONTRACT may remove several axes from UNROLL
    assert con.dtype.count == prod([x[1] for x in con.arg]), "dtype is wrong"
    kept_axes = tuple(ax for ax in unrolled.arg if ax not in con.arg)
    contracted_choices = _choices_from_args(con.arg)
    # outer loop over the surviving axes, inner loop over the contracted ones
    idxs = tuple(
        _expand_arg_to_idx(unrolled.arg, {**kept, **contracted})
        for kept in _choices_from_args(kept_axes)
        for contracted in contracted_choices
    )
    return UOp(Ops.UNROLL, con.dtype, (unrolled.src[0].gep(idxs),), kept_axes)
# rewrite rules used by the second pass of expand_rewrite: each entry pairs a UPat with the
# function or lambda invoked on a match (lambda parameter names mirror the `name=` / var names in the pattern)
expander = PatternMatcher([
    # double expand: an UNROLL of an UNROLL becomes one UNROLL over the inner source with arg inner.arg+outer.arg
    (UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
    lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
    # do expansion: the listed ops are handed to do_expand, which returns None when no source is an UNROLL
    # NOTE(review): custom_early_reject presumably skips roots with no UNROLL source before calling do_expand -- confirm in UPat
    (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
    Ops.VECTORIZE, Ops.IF, Ops.REDUCE), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
    # CONTRACT is resolved by do_contract
    (UPat(Ops.CONTRACT, name="con"), do_contract),
    # vectorize DEFINE_ACC: the acc takes the VECTORIZE dtype and its first source is broadcast to that width
    (UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC, name="acc"), name="v"),
    lambda acc,v: acc.replace(dtype=v.dtype, src=(acc.src[0].broadcast(v.dtype.count),)+acc.src[1:])),
    # BARRIERs aren't actually expanded: the BARRIER is rebuilt over the UNROLL's sources and re-wrapped in an UNROLL
    (UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
    lambda ex: UOp(Ops.UNROLL, src=(UOp(Ops.BARRIER, src=ex.src),)*len(ex.src), arg=ex.arg)),
    # empty UNROLL is NOOP
    (UPat(Ops.UNROLL, src=(UPat.var('x'),), arg=()), lambda x: x),
    # UNROLL GEP (needed for WMMA, generalize this) -> vectorized ALU
    # the pattern has a fixed number of sources: 256 when AMX is truthy, otherwise 8
    (UPat(Ops.UNROLL, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
    lambda ex,x,y: UOp(Ops.UNROLL, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
])
def create_gate(root:UOp) -> UOp|None:
    """Gate barrier-dependent LOADs under `root` with the index's third source.

    Looks at `root.src[0]` (seeing through one CAST). If that is an INDEX with a
    source at position 2, every LOAD reachable from `root` whose last source is a
    BARRIER gets that BARRIER wrapped in an IF on the gate. Traversal does not
    descend through BARRIERs. Returns the rebuilt UOp, or None when the index is not
    an INDEX, has exactly two sources, or nothing was changed.
    """
    @functools.lru_cache(None)
    def _gate_srcs(u:UOp, gate:UOp) -> UOp:
        # stop at barriers: nothing behind them is rewritten
        if u.op is Ops.BARRIER:
            return u
        if u.op is Ops.LOAD and u.src[-1].op is Ops.BARRIER:
            gated_barrier = UOp(Ops.IF, src=(gate, u.src[-1]))
            return UOp(u.op, u.dtype, u.src[:-1]+(gated_barrier,), arg=u.arg)
        rewritten = tuple(_gate_srcs(s, gate) for s in u.src)
        # keep the original object when no source changed
        if rewritten == u.src:
            return u
        return UOp(u.op, u.dtype, rewritten, u.arg)

    idx = root.src[0]
    if idx.op is Ops.CAST:
        idx = idx.src[0]
    if idx.op is not Ops.INDEX or len(idx.src) == 2:
        return None
    gated = _gate_srcs(root, idx.src[2])
    return None if gated is root else gated
# rules applied by the first rewrite pass of expand_rewrite, ahead of the expander rules
migrate_indexing = PatternMatcher([
    # create gate MUST BE BEFORE expander
    # every STORE is offered to create_gate, which returns None when there is nothing to gate
    (UPat(Ops.STORE, name="root"), create_gate),
])
def expand_rewrite(sink:UOp) -> UOp:
    """Run the two rewrite passes of this module over `sink` and return the result.

    The first pass applies `sym` together with `migrate_indexing`; the second applies
    `sym` together with `expander` to the output of the first.
    """
    # initial symbolic + migrate indexing (remove this)
    migrated = graph_rewrite(sink, sym+migrate_indexing)
    # expand
    expanded = graph_rewrite(migrated, sym+expander)
    return expanded