# Hand-written spec for a tiled 4096x4096 GEMM kernel built directly from
# tinygrad UOps, benchmarked against tinygrad's own (BEAM-searched) matmul.

from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes
from tinygrad.helpers import prod, unwrap
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.opt.kernel import AxisType
from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
from tinygrad.uop.ops import graph_rewrite, PatternMatcher, UPat, Ops, UOp, GroupOp
from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
from tinygrad.kernelize.kernelize import merge_views
from tinygrad.shape.view import View
N = 4096
run_count = 5
# change reduceop axes and input ShapeTrackers, view gets replaced with a reshape.
# src->r->view --> src->view->r
def swizzle_reduceop(src:UOp, r:UOp, view:UOp):
if r.tag is not None: return None
# confirm the input is in order
# TODO: replace this with a UOp that allows for nothing else then remove this
permute = tuple(i for i in range(len(src.shape)) if i not in r.axis_arg)+r.axis_arg
assert permute == tuple(range(len(permute))), f"reduce axis must already be in order, {permute} isn't"
# append the reduce shape to each of the views
reduce_count = len(r.axis_arg)
prshape = prod(rshape:=src.shape[-reduce_count:])
rstrides = strides_for_shape(rshape)
nv = [View.create(v.shape[:-reduce_count]+rshape, tuple(x*prshape for x in v.strides[:-reduce_count])+rstrides, v.offset*prshape,
v.mask[:-reduce_count]+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in unwrap(view.st).views]
# no reshape required with shrinking REDUCE_AXIS
return UOp(Ops.REDUCE_AXIS, r.dtype, (src.view(ShapeTracker(tuple(nv))),),
(r.arg[0], tuple(range(len(view.shape)-reduce_count, len(view.shape)))))
early_view_left = merge_views+PatternMatcher([
# view before elementwise and buffer ops
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.VALID, Ops.STORE, Ops.LOAD}, name="e"),), name="view"),
lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src)) if e.tag is None else None),
# push a non contiguous ShapeTracker through reduceop
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
])
def hand_spec():
# Block Tile size . 128x128
# Thread Tile size . 4x4
# Wave Tile size . 128x32
# A wave is . 8x4
# ────── problem size and tiling params (mirror the C kernel) ───────────────────
BK = 8 # depth of K-tile
BN = BM = 128 # block-tile (output) sizes
# the real thread is 16x8 = 128 regs
TM = 4
nbIterWaveM = 2
TN = 4
nbIterWaveN = 4
# ────── shared-memory tile sizes (unchanged) ───────────────────────────────────
LDS_A_SZ = BK * BM # 1024 floats
LDS_B_SZ = BK * BN # 1024 floats
bC = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0) # output C
bA = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1) # input A
bB = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2) # input B
# TODO: this should not be a string, just a number
lAs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(LDS_A_SZ, local=True), arg="As")
lBs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(LDS_B_SZ, local=True), arg="Bs")
s0 = ShapeTracker.from_shape((N, N, N), (N, 0, 1))
s1 = ShapeTracker.from_shape((N, N, N), (0, 1, N))
s2 = ShapeTracker.from_shape((N, N, 1), (N, 1, 0))
ls0 = ShapeTracker.from_shape((BM, BK))
ls1 = ShapeTracker.from_shape((BN, BK))
buf_at = [AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.LOCAL, AxisType.LOCAL, AxisType.LOCAL, AxisType.UPCAST, AxisType.UPCAST]
buf_bt = [AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.LOCAL, AxisType.LOCAL, AxisType.LOCAL, AxisType.UPCAST, AxisType.UPCAST]
axis_types = buf_at + buf_bt + [AxisType.REDUCE, AxisType.UNROLL, AxisType.UNROLL, AxisType.UNROLL]
# 128 x 128 x 8
full_shape = (N//BM, 2, 2, 2, 2, 2, 2, 2, N//BN, 2, 2, 2, 2, 2, 2, 2, N//BK, 2, 2, 2)
s0 = s0.reshape(full_shape)
s1 = s1.reshape(full_shape)
s2 = s2.reshape(full_shape[:-4] + (1,)*4)
ls0 = ls0.reshape((1, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2)).expand(s0.shape)
ls1 = ls1.reshape((1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2)).expand(s1.shape)
assert ls0.real_size() == LDS_A_SZ
assert ls1.real_size() == LDS_B_SZ
# BK is a loop of 8
# each loop reads 8 in A, 16 in B
print(ls0)
print(ls1)
permaxis = []
for axis_order in [AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP, AxisType.UPCAST, AxisType.GROUP_REDUCE, AxisType.REDUCE, AxisType.UNROLL]:
permaxis += [i for i,a in enumerate(axis_types) if a == axis_order]
axis_types = [axis_types[x] for x in permaxis]
s0, s1, s2, ls0, ls1 = [x.permute(tuple(permaxis)) for x in [s0, s1, s2, ls0, ls1]]
print(axis_types)
lw0, lr0 = ls0, ls0
lw1, lr1 = ls1, ls1
# first round of permutes
permaxis = (0, 1, 19, 18, 17, 12, 11, 10, 5, 4, 3, 2, 6, 7, 8, 9, 16, 13, 14, 15)
s0 = s0.permute(permaxis)
lw0 = lw0.permute(permaxis)
permaxis = (0, 1, 15, 14, 9, 8, 7, 6, 13, 19, 18, 17, 5, 4, 3, 2, 16, 12, 11, 10)
s1 = s1.permute(permaxis)
lw1 = lw1.permute(permaxis)
# second round of permutes
#permaxis = (0, 1, 12, 11, 5, 4, 3, 2, 10, 6, 7, 8, 9, 13, 14, 15, 16, 17, 18, 19)
#lw0 = lw0.permute(permaxis)
#lr0 = lr0.permute(permaxis)
from tinygrad.opt.kernel import axis_colors, colored
print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(s0.shape, s0.views[0].strides, axis_types)]))
print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(s1.shape, s1.views[0].strides, axis_types)]))
print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(s2.shape, s2.views[0].strides, axis_types)]))
print("lw")
print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(lw0.shape, lw0.views[0].strides, axis_types)]))
print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(lw1.shape, lw1.views[0].strides, axis_types)]))
print("lr")
print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(lr0.shape, lr0.views[0].strides, axis_types)]))
print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(lr1.shape, lr1.views[0].strides, axis_types)]))
# loads and stores
bs0 = bA.view(s0).load()
bs1 = bB.view(s1).load()
bs0 = lAs.view(lr0).load(lAs.view(lw0).store(bs0))
bs1 = lBs.view(lr1).load(lBs.view(lw1).store(bs1))
mat = (bs0 * bs1).r(Ops.ADD, tuple([i for i,a in enumerate(axis_types) if a in (AxisType.REDUCE, AxisType.UNROLL)]), permute=False)
st = bC.view(s2).store(mat)
ast = st.sink(arg=KernelInfo(axis_types=tuple(axis_types), name="tinygemm"))
ast = graph_rewrite(ast, merge_views)
prg = get_program(ast, Device.default.renderer)
print(prg.src)
return prg
if __name__ == "__main__":
hprg = hand_spec()
hrunner = CompiledRunner(hprg)
a = Tensor.randn(N, N).realize()
b = Tensor.randn(N, N).realize()
hc = Tensor.zeros(N, N).contiguous().realize()
GlobalCounters.reset()
with Context(DEBUG=2, BEAM=4):
for _ in range(run_count): tc = (a@b).realize()
GlobalCounters.reset()
ei = ExecItem(hrunner, [hc.uop.buffer, a.uop.buffer, b.uop.buffer])
with Context(DEBUG=2):
for _ in range(run_count): ei.run(wait=True)
err = (hc-tc).square().mean().item()
print(f"hrunner {err}")
assert err < 1e-06