from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes
from tinygrad.helpers import prod, unwrap
from tinygrad.uop.ops import UOp, Ops, KernelInfo, graph_rewrite, PatternMatcher, UPat, GroupOp
from tinygrad.opt.kernel import AxisType
from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
from tinygrad.kernelize.kernelize import merge_views
from tinygrad.shape.view import View

N = 4096        # matmul problem size: C (NxN) = A (NxN) @ B (NxN)
run_count = 5   # timed iterations for both the reference and the hand-written kernel

# change the reduceop axes and the input ShapeTrackers; the view gets replaced with a reshape.
# src->r->view --> src->view->r
def swizzle_reduceop(src:UOp, r:UOp, view:UOp):
  if r.tag is not None: return None
  # confirm the input is in order
  # TODO: replace this with a UOp that allows for nothing else, then remove this
  permute = tuple(i for i in range(len(src.shape)) if i not in r.axis_arg)+r.axis_arg
  assert permute == tuple(range(len(permute))), f"reduce axis must already be in order, {permute} isn't"

  # append the reduce shape to each of the views
  reduce_count = len(r.axis_arg)
  prshape = prod(rshape:=src.shape[-reduce_count:])
  rstrides = strides_for_shape(rshape)
  nv = [View.create(v.shape[:-reduce_count]+rshape, tuple(x*prshape for x in v.strides[:-reduce_count])+rstrides, v.offset*prshape,
                    v.mask[:-reduce_count]+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in unwrap(view.st).views]

  # no reshape required with shrinking REDUCE_AXIS
  return UOp(Ops.REDUCE_AXIS, r.dtype, (src.view(ShapeTracker(tuple(nv))),),
             (r.arg[0], tuple(range(len(view.shape)-reduce_count, len(view.shape)))))

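# early_view_left pushes views toward the buffers: a VIEW above an elementwise/buffer op is moved onto
# that op's sources, and a VIEW above a REDUCE_AXIS is swizzled through it by swizzle_reduceop (e.g., with
# hypothetical shapes, reducing (8, 8, 16) to (8, 8, 1) and then viewing as (64, 1) becomes: view src as
# (64, 16) first, then reduce the last axis).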
early_view_left = merge_views+PatternMatcher([
  # view before elementwise and buffer ops
  (UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.VALID, Ops.STORE, Ops.LOAD}, name="e"),), name="view"),
   lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src)) if e.tag is None else None),
  # push a non-contiguous ShapeTracker through reduceop
  (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
])

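# hand_spec builds the AST for a tiled NxN (4096x4096) matmul kernel by hand, mirroring the structure of a
# hand-written GPU GEMM (the "C kernel" referenced below), and returns the rendered program for CompiledRunner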
def hand_spec():
  # Block Tile size:  128x128
  # Thread Tile size: 4x4
  # Wave Tile size:   128x32
  # A wave is:        8x4
  # ────── problem size and tiling params (mirror the C kernel) ───────────────────
  BK = 8          # depth of K-tile
  BN = BM = 128   # block-tile (output) sizes
  # the real thread is 16x8 = 128 regs
  TM = 4
  nbIterWaveM = 2
  TN = 4
  nbIterWaveN = 4

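  # note: TM/TN/nbIterWaveM/nbIterWaveN document the intended thread/wave tiling but are not referenced
  # below; the actual tiling is expressed directly through the binary axis splits and axis_types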
  # ────── shared-memory tile sizes (unchanged) ───────────────────────────────────
  LDS_A_SZ = BK * BM   # 1024 floats
  LDS_B_SZ = BK * BN   # 1024 floats

  bC = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0)  # output C
  bA = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1)  # input A
  bB = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2)  # input B

  # TODO: this should not be a string, just a number
  lAs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(LDS_A_SZ, local=True), arg="As")
  lBs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(LDS_B_SZ, local=True), arg="Bs")

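  # global iteration space is (i, j, k): s0 reads A[i,k] (strides (N,0,1), broadcast over j),
  # s1 reads B[k,j] (strides (0,1,N), broadcast over i), s2 writes C[i,j]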
  s0 = ShapeTracker.from_shape((N, N, N), (N, 0, 1))
  s1 = ShapeTracker.from_shape((N, N, N), (0, 1, N))
  s2 = ShapeTracker.from_shape((N, N, 1), (N, 1, 0))

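  # local (LDS) tiles: As holds a BM x BK tile of A, Bs holds a BN x BK tile of B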
  ls0 = ShapeTracker.from_shape((BM, BK))
  ls1 = ShapeTracker.from_shape((BN, BK))

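  # axis tags for the 20 kernel axes: 8 for the M side of the output, 8 for the N side, then the K reduce.
  # each of BM and BN is split into seven binary axes in full_shape below (2**7 == 128); how those axes are
  # tagged (LOCAL vs UPCAST) picks the thread grid vs the per-thread register tile. K is split into N//BK
  # REDUCE loops of 8, with the 8 fully UNROLLed (2*2*2).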
  buf_at = [AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.LOCAL, AxisType.LOCAL, AxisType.LOCAL, AxisType.UPCAST, AxisType.UPCAST]
  buf_bt = [AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.LOCAL, AxisType.LOCAL, AxisType.LOCAL, AxisType.UPCAST, AxisType.UPCAST]
  axis_types = buf_at + buf_bt + [AxisType.REDUCE, AxisType.UNROLL, AxisType.UNROLL, AxisType.UNROLL]

  # block tile is 128 x 128 x 8 (BM x BN x BK)
  full_shape = (N//BM, 2, 2, 2, 2, 2, 2, 2, N//BN, 2, 2, 2, 2, 2, 2, 2, N//BK, 2, 2, 2)

  s0 = s0.reshape(full_shape)
  s1 = s1.reshape(full_shape)
  s2 = s2.reshape(full_shape[:-4] + (1,)*4)   # output keeps size-1 dims where the reduce axes are

  ls0 = ls0.reshape((1, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2)).expand(s0.shape)
  ls1 = ls1.reshape((1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2)).expand(s1.shape)
  assert ls0.real_size() == LDS_A_SZ
  assert ls1.real_size() == LDS_B_SZ

  # BK is a loop of 8
  # each loop reads 8 in A, 16 in B

  print(ls0)
  print(ls1)

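  # reorder all axes into the canonical kernel order (GLOBAL, LOCAL, LOOP, UPCAST, GROUP_REDUCE, REDUCE, UNROLL)
  # and permute every ShapeTracker to match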
  permaxis = []
  for axis_order in [AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP, AxisType.UPCAST, AxisType.GROUP_REDUCE, AxisType.REDUCE, AxisType.UNROLL]:
    permaxis += [i for i,a in enumerate(axis_types) if a == axis_order]
  axis_types = [axis_types[x] for x in permaxis]
  s0, s1, s2, ls0, ls1 = [x.permute(tuple(permaxis)) for x in [s0, s1, s2, ls0, ls1]]
  print(axis_types)

  lw0, lr0 = ls0, ls0
  lw1, lr1 = ls1, ls1

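  # the LDS tiles are written (lw*) and read back (lr*) through separate views. the permutes below are
  # applied to the global load views (s0, s1) and the matching local write views (lw0, lw1) only, so the
  # local read layout (lr0, lr1) is left unchanged.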
  # first round of permutes

  permaxis = (0, 1, 19, 18, 17, 12, 11, 10, 5, 4, 3, 2, 6, 7, 8, 9, 16, 13, 14, 15)
  s0 = s0.permute(permaxis)
  lw0 = lw0.permute(permaxis)

  permaxis = (0, 1, 15, 14, 9, 8, 7, 6, 13, 19, 18, 17, 5, 4, 3, 2, 16, 12, 11, 10)
  s1 = s1.permute(permaxis)
  lw1 = lw1.permute(permaxis)

  # second round of permutes
  #permaxis = (0, 1, 12, 11, 5, 4, 3, 2, 10, 6, 7, 8, 9, 13, 14, 15, 16, 17, 18, 19)
  #lw0 = lw0.permute(permaxis)
  #lr0 = lr0.permute(permaxis)

  # debug print: shape(stride) for every axis, colored by axis type
  from tinygrad.opt.kernel import axis_colors, colored
  print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(s0.shape, s0.views[0].strides, axis_types)]))
  print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(s1.shape, s1.views[0].strides, axis_types)]))
  print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(s2.shape, s2.views[0].strides, axis_types)]))
  print("lw")
  print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(lw0.shape, lw0.views[0].strides, axis_types)]))
  print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(lw1.shape, lw1.views[0].strides, axis_types)]))
  print("lr")
  print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(lr0.shape, lr0.views[0].strides, axis_types)]))
  print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(lr1.shape, lr1.views[0].strides, axis_types)]))

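  # global loads are staged through LDS: each element of A/B is loaded through its global view, stored
  # into As/Bs through the write view, then loaded back through the read view. making the LDS load a
  # consumer of the LDS store keeps them ordered.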
  # loads and stores
  bs0 = bA.view(s0).load()
  bs1 = bB.view(s1).load()
  bs0 = lAs.view(lr0).load(lAs.view(lw0).store(bs0))
  bs1 = lBs.view(lr1).load(lBs.view(lw1).store(bs1))

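  # multiply and accumulate over every REDUCE/UNROLL axis (permute=False presumably keeps the axes in
  # their current positions rather than reordering them), then store the result through the output view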
  mat = (bs0 * bs1).r(Ops.ADD, tuple([i for i,a in enumerate(axis_types) if a in (AxisType.REDUCE, AxisType.UNROLL)]), permute=False)
  st = bC.view(s2).store(mat)

  ast = st.sink(arg=KernelInfo(axis_types=tuple(axis_types), name="tinygemm"))
  ast = graph_rewrite(ast, merge_views)
  prg = get_program(ast, Device.default.renderer)
  print(prg.src)
  return prg


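# run the hand-written kernel and compare it against tinygrad's own matmul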
if __name__ == "__main__":
  hprg = hand_spec()
  hrunner = CompiledRunner(hprg)

  a = Tensor.randn(N, N).realize()
  b = Tensor.randn(N, N).realize()
  hc = Tensor.zeros(N, N).contiguous().realize()

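  # baseline: time tinygrad's own kernel (BEAM=4 enables beam search, DEBUG=2 prints per-kernel timings)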
  GlobalCounters.reset()
  with Context(DEBUG=2, BEAM=4):
    for _ in range(run_count): tc = (a@b).realize()

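  # run the hand-written kernel on the same inputs and check its output against the baseline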
  GlobalCounters.reset()
  ei = ExecItem(hrunner, [hc.uop.buffer, a.uop.buffer, b.uop.buffer])
  with Context(DEBUG=2):
    for _ in range(run_count): ei.run(wait=True)
  err = (hc-tc).square().mean().item()
  print(f"hrunner {err}")
  assert err < 1e-06