openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

90 lines
4.5 KiB

# the job of the lowerer is to do indexing
from dataclasses import dataclass
from typing import cast
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop
from tinygrad.helpers import prod, partition, flatten
# ***** indexing *****
@dataclass
class IndexContext:
  """Per-axis index UOps built by get_index and consumed by the lowering rules below.

  `idxs` and `ridxs` have one entry per axis of the kernel's full shape. `ridxs` starts as a copy of
  `idxs`; get_index then swaps the grouped-reduce axes for fresh RANGE UOps.
  """
  idxs: list[UOp]   # indexes used for non-local loads, all stores, and constant masking
  ridxs: list[UOp]  # indexes used for DEFINE_LOCAL loads and for REDUCE_AXIS lowering
def get_index(ast:UOp) -> IndexContext:
  """Build the per-axis index UOps used to lower the kernel `ast`.

  NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>

  - Every axis before the trailing `ki.upcasted` ones becomes a RANGE over its size (arg = axis number).
  - Every upcasted axis becomes an UNROLL of the constant int vector (0, 1, ..., g-1). Such axes must
    have an int size.
  - The number of grouped-reduce axes is the count of axes in [first_reduce, first_upcasted) on which
    some load from a DEFINE_LOCAL buffer has a different shape than `ast.src[0]`.
  - That many axes, starting at `first_reduce`, get a fresh RANGE (arg = 1000+axis) in `ridxs`. This
    relies on the layout NOTE above; `idxs` keeps the original RANGEs.

  Returns an IndexContext(idxs, ridxs).
  """
  ki = ast.arg if isinstance(ast.arg, KernelInfo) else KernelInfo()
  full_shape = ast.full_shape
  first_upcasted = len(full_shape)-ki.upcasted
  # all loops are RANGES
  idxs = [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(g),), i) for i,g in enumerate(full_shape[:first_upcasted])]
  # upcast loops: UNROLL over the constant vector 0..g-1
  for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
    assert isinstance(g, int), "needs to be int to upcast/unroll"
    idxs.append(UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))
  # walk the graph once, collecting every reduced axis and every load from a local buffer
  # (previously ast.toposort() was traversed twice for these two queries)
  reduce_axes: list[int] = []
  local_loads: list[UOp] = []
  for u in ast.toposort():
    if u.op is Ops.REDUCE_AXIS: reduce_axes.extend(u.axis_arg)
    elif u.op is Ops.LOAD and u.src[0].base.op is Ops.DEFINE_LOCAL: local_loads.append(u)
  # late indexes (group for reduce)
  # if there's no reduce, this is first_upcasted. assumes reduces are at the end
  first_reduce = min([first_upcasted, *reduce_axes])
  # NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces
  # (ast.src[0].st_arg is only evaluated when there is at least one local load and one candidate axis)
  group_for_reduces = sum(any(l.st_arg.shape[i]!=ast.src[0].st_arg.shape[i] for l in local_loads)
                          for i in range(first_reduce, first_upcasted))
  ridxs = idxs[:]
  for a in range(first_reduce, first_reduce+group_for_reduces):
    ridxs[a] = UOp(Ops.RANGE, dtypes.int, (sint_to_uop(full_shape[a]),), 1000+a)
  return IndexContext(idxs, ridxs)
# ***** lowering (given index) *****
def lower_reduce_axis(ctx: IndexContext, x: UOp):
  """Lower a REDUCE_AXIS UOp into a REDUCE.

  Each reduced axis is looked up in ctx.ridxs (NOTE: always using ridxs is fine here). Axes whose index
  is a RANGE become extra sources of the REDUCE. All remaining axes must be UNROLLs; their (axis, size)
  pairs are gathered into a CONTRACT wrapped around the input, whose dtype is x.dtype vectorized by the
  product of those sizes.
  """
  axis_idxs = [ctx.ridxs[i] for i in x.axis_arg]
  reduce_range, reduce_expand = partition(axis_idxs, lambda u: u.op is Ops.RANGE)
  assert all(u.op is Ops.UNROLL for u in reduce_expand), f"not all UNROLLS in {reduce_expand} for {x.axis_arg}"
  alu_op: Ops = x.arg[0]
  src = x.src[0]
  contract_axis = [pair for u in reduce_expand for pair in u.arg]
  if contract_axis:
    vec_size = prod(pair[1] for pair in contract_axis)
    src = UOp(Ops.CONTRACT, x.dtype.vec(vec_size), (src,), tuple(contract_axis))
  # REDUCE supports both "horizontal" reduction and range reduction. the horizontal elements are taken in the nearest group
  return UOp(Ops.REDUCE, x.dtype, (src,)+tuple(reduce_range), alu_op)
def lower_load(ctx: IndexContext, x: UOp, buf: UOp):
  """Rewrite LOAD(VIEW(buf), ...) into a LOAD from buf indexed through x's view.

  Loads from a DEFINE_LOCAL buffer are indexed with ctx.ridxs and gain a BARRIER source built over
  x.src[1]. Every other load is indexed with ctx.idxs and has the indexed buffer as its only source.
  """
  from_local = buf.op is Ops.DEFINE_LOCAL
  idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if from_local else ctx.idxs)
  if not from_local:
    return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),))
  barrier = UOp(Ops.BARRIER, dtypes.void, (x.src[1],))
  return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid), barrier))
def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
  """Rewrite STORE(VIEW(buf), value) into a STORE to buf indexed through x's view using ctx.idxs.

  The validity mask gets `idx.eq(0)` ANDed in for every axis where ctx.idxs and ctx.ridxs differ. This
  applies to stores to a non-local buffer, and to a local store whose value is a REDUCE of a load from
  a local buffer ("store back"). Per the original note, this means only the threads actually doing the
  reduce perform the write.
  """
  idx, valid = x.st_arg.to_indexed_uops(ctx.idxs)
  value = x.src[1]
  buf_is_local = cast(PtrDType, buf.dtype).local
  store_back = False
  if buf_is_local and value.op is Ops.REDUCE:
    reduce_input = value.src[0]
    store_back = reduce_input.op is Ops.LOAD and cast(PtrDType, reduce_input.src[0].dtype).local
  if not buf_is_local or store_back:
    for differing in (o for o, r in zip(ctx.idxs, ctx.ridxs) if o is not r):
      valid = valid * differing.eq(0)
  return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid), value))
def lower_const(ctx:IndexContext, view:UOp, c:UOp):
  """Lower VIEW(CONST/DEFINE_VAR).

  If none of view.arg's views has a mask, `c` is returned as-is. Otherwise the validity expression of
  the view (indexed with ctx.idxs) selects `c` where valid and a zero constant of the same type elsewhere.
  """
  if any(v.mask is not None for v in view.arg.views):
    _, valid = view.arg.to_indexed_uops(ctx.idxs)
    return valid.where(c, c.const_like(0))
  return c
# Rewrite rules that turn view-based REDUCE_AXIS / CONST / LOAD / STORE UOps into explicitly indexed UOps.
# The lower_* functions take a `ctx` argument; presumably the caller supplies the IndexContext from
# get_index as the rewrite context (not visible here — confirm against the call site).
# NOTE(review): rule order may matter if PatternMatcher tries rules in list order — keep the hacks first.
pm_lowerer = PatternMatcher([
  # TODO: remove these hacks
  # hack for old style CONST(VIEW) (now it's just VIEW(CONST))
  # CONST/DEFINE_VAR with a VIEW source -> same UOp with no sources, wrapped in that view
  (UPat((Ops.DEFINE_VAR, Ops.CONST), src=(UPat(Ops.VIEW, name="v"),), name="c"), lambda c,v: c.replace(src=()).view(v.arg)),
  # hack for old style VALID (now it's just VIEW(CONST))
  # VALID(VIEW).where(c, 0) -> c with no sources, wrapped in that view
  (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c"), UPat(Ops.CONST, arg=0)), lambda c,v: c.replace(src=()).view(v.arg)),
  # reduce/view_const
  (UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
  (UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),), name="view"), lower_const),
  # rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
  # the first source must be VIEW(buf); any further sources are allowed (allow_any_len=True)
  (UPat(Ops.LOAD, src=(UPat.var("buf").view(),), allow_any_len=True, name="x"), lower_load),
  (UPat(Ops.STORE, src=(UPat.var("buf").view(),), allow_any_len=True, name="x"), lower_store),
])