# kernel8_batched_gmem.s from https://seb-v.github.io/optimization/update/2025/01/20/Fast-GPU-Matrix-multiplication.html
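# Overview: schedule a 4096x4096 fp32 matmul in tinygrad, hand-apply LOCAL/UNROLL opts,
# then rewrite the global loads to stage tiles through local (shared) memory via
# transform_load below. With FAST=1 (the default) on AMD, the generated program is
# replaced by the handwritten kernel8_batched_gmem.s from the blog post for comparison.
# Results are checked against numpy at the end.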
import pathlib
from dataclasses import replace
import numpy as np
from tinygrad import Tensor, Device, Context
from tinygrad.helpers import getenv
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
from tinygrad.engine.realize import CompiledRunner, ExecItem
from tinygrad.ops import graph_rewrite, PatternMatcher, UPat, Ops, UOp
from tinygrad.shape.shapetracker import ShapeTracker, View

# TODO: on METAL for `DEBUG=4 python3 extra/gemm/amd_matmul.py`
# * fix load grouping (like float4). idk why it's not working, need new devectorizer (this is a Monday project)
# * DONE - remove extra barrier
# * DONE (moved Ops.ADD) - fix load order to be in order (the +0 one is last!)
# * explore async (fast) global load -> local store
# * why is TC=3 broken for 4096x4096?
# * write syntactic sugar for these local additions + use it in tensor core kernel.py

N = 4096        # matrix dimension
LN = 16         # local/tile dimension (workgroup is LN x LN)
run_count = 5   # number of timed runs

def transform_load(ctx:tuple[Kernel, set[UOp]], x:UOp):
  # rewrite a LOAD from a global buffer into a global->local staging copy plus a local load
  if x.src[0].op is not Ops.DEFINE_GLOBAL: return None
  if x in ctx[1]: return None  # already transformed
  print(ctx[0].colored_shape())
  ctx[1].add(x)
  input_st: ShapeTracker = x.src[1].arg
  #strides = input_st.real_strides()
  #strides = (0,0)+strides[2:]
  # a zero real stride on one of the local axes identifies which matmul input this is;
  # pick the matching permute and the strides for the local staging buffer
  if input_st.real_strides()[2] == 0:
    perm = (0,1,5,3,4,2)
    strides = (0,0,LN*4,4,0,0,1,0)
  elif input_st.real_strides()[3] == 0:
    perm = (0,1,2,5,4,3)
    strides = (0,0,LN*4,4,0,0,0,1)
  else:
    return None
  if len(input_st.shape) == 8:
    local_st = ShapeTracker(views=(View.create((1,1,LN,LN,1,1,4,4), strides),))
    perm = perm + (6,7)
  else:
    local_st = ShapeTracker(views=(View.create((1,1,LN,LN,1,1)),))
  load_st = local_st.permute(perm)
  input_st = input_st.permute(perm)
  # local (shared memory) buffer to stage the tile through
  lcl = UOp(Ops.DEFINE_LOCAL, x.dtype.ptr(local_st.real_size(), local=True), (), f"temp{x.src[0].arg}")
  global_load = x.replace(src=(x.src[0], input_st.to_uop()))
  ret = UOp(Ops.STORE, src=(lcl, local_st.to_uop(), global_load))
  return UOp(Ops.LOAD, x.dtype, src=(lcl, load_st.to_uop(), ret))
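
# Roughly, for each global input the rewrite above turns
#   LOAD(DEFINE_GLOBAL, input_st)
# into
#   LOAD(DEFINE_LOCAL, load_st, STORE(DEFINE_LOCAL, local_st, LOAD(DEFINE_GLOBAL, permuted input_st)))
# so every element is copied global->local once, and the compute reads out of local memory.
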
# rewrite every LOAD in the AST through transform_load (it returns None for non-global loads)
local_loads_pm = PatternMatcher([
  (UPat(Ops.LOAD, name="x"), transform_load),
])

def ast_transform(k, ast):
  #return ast
  ast = graph_rewrite(ast, local_loads_pm, ctx=(k, set()))
  #ast = ast.replace(arg=replace(ast.arg, upcasted=0))
  print(ast)
  return ast

if __name__ == "__main__":
  rng = np.random.default_rng()
  a = Tensor(na:=rng.random((4096, 4096), dtype=np.float32)).realize()
  b = Tensor(nb:=rng.random((4096, 4096), dtype=np.float32)).realize()
  c = a @ b
  # take the scheduled matmul kernel and hand-apply opts to it
  si = c.schedule()[-1]
  k = Kernel(si.ast, opts=Device[Device.DEFAULT].renderer)
  #opts = [Opt(op=OptOps.LOCAL, axis=1, arg=16),
  #        Opt(op=OptOps.LOCAL, axis=0, arg=8),
  #        Opt(op=OptOps.UPCAST, axis=2, arg=4),
  #        Opt(op=OptOps.UPCAST, axis=1, arg=4),
  #        Opt(op=OptOps.UPCAST, axis=0, arg=2)]
  #opts = [Opt(op=OptOps.UPCAST, axis=1, arg=4),
  #        Opt(op=OptOps.UPCAST, axis=0, arg=4),
  #        Opt(op=OptOps.LOCAL, axis=1, arg=8),
  #        Opt(op=OptOps.LOCAL, axis=0, arg=4)]
  opts = [Opt(op=OptOps.UNROLL, axis=0, arg=LN),
          #Opt(op=OptOps.UPCAST, axis=0, arg=4),
          #Opt(op=OptOps.UPCAST, axis=1, arg=4),
          Opt(op=OptOps.LOCAL, axis=1, arg=LN),
          Opt(op=OptOps.LOCAL, axis=0, arg=LN)]
  for opt in opts: k.apply_opt(opt)
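  # net effect of the opts: a LN x LN (16x16) workgroup tiling the output, with the
  # reduce axis unrolled by LN to match the LN-sized local views transform_load builds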
  prg = k.to_program(ast_transform=ast_transform)
  if getenv("FAST", 1) and Device.DEFAULT == "AMD":
    # swap in the handwritten assembly kernel from the blog post linked above
    #src = (pathlib.Path(__file__).parent / "fp32_sgemm_amd" / "src" / "kernel8_batched_gmem.s").read_text()
    src = (pathlib.Path(__file__).parent / "kernel8_batched_gmem.s").read_text()
    prg = replace(prg, src=src, global_size=[N//128, N//128, 1], local_size=[128, 1, 1])
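    # launch geometry assumes the assembly kernel computes one 128x128 output tile
    # per 128-thread workgroup, hence the (N//128) x (N//128) grid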
  print(prg.global_size, prg.local_size)
  ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
  with Context(DEBUG=2):
    for _ in range(run_count): ei.run(wait=True)
  nc = c.numpy()
  np.testing.assert_allclose(na@nb, nc, rtol=1e-5)