You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
56 lines
3.4 KiB
56 lines
3.4 KiB
# ruff: noqa: E501
|
|
from tinygrad import dtypes
|
|
from tinygrad.helpers import Timing, getenv
|
|
from tinygrad.codegen.opt.kernel import Opt, OptOps
|
|
from tinygrad.engine.realize import get_program, CompiledRunner
|
|
from tinygrad.uop.ops import UOp, Ops, AxisType
|
|
|
|
if __name__ == "__main__":
|
|
if getenv("TC", 0) == 0:
|
|
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=0, src=())
|
|
c1 = UOp.range(UOp.const(dtypes.int, 512), 0, AxisType.GLOBAL)
|
|
c2 = UOp.range(UOp.const(dtypes.int, 64), 1, AxisType.GLOBAL)
|
|
c3 = UOp.range(UOp.const(dtypes.int, 6), 2, AxisType.GLOBAL)
|
|
c4 = UOp.range(UOp.const(dtypes.int, 6), 3, AxisType.GLOBAL)
|
|
c5 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(2097152), arg=1, src=())
|
|
c6 = UOp.range(UOp.const(dtypes.int, 64), 1004, AxisType.REDUCE)
|
|
c7 = UOp.range(UOp.const(dtypes.int, 3), 1005, AxisType.REDUCE)
|
|
c8 = UOp.range(UOp.const(dtypes.int, 3), 1006, AxisType.REDUCE)
|
|
c9 = c5.index(((((((c1*UOp.const(dtypes.int, 4096))+(c3*UOp.const(dtypes.int, 8)))+c4)+(c6*UOp.const(dtypes.int, 64)))+(c7*UOp.const(dtypes.int, 8)))+c8), UOp.const(dtypes.bool, True)).load()
|
|
c10 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(36864), arg=2, src=())
|
|
c11 = c10.index(((((c2*UOp.const(dtypes.int, 576))+(c6*UOp.const(dtypes.int, 9)))+(c7*UOp.const(dtypes.int, 3)))+c8), UOp.const(dtypes.bool, True)).load()
|
|
c12 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=3, src=())
|
|
c13 = c12.index(c2, UOp.const(dtypes.bool, True)).load()
|
|
c14 = ((c9*c11).reduce(c6, c7, c8, arg=Ops.ADD)+c13)
|
|
c15 = c0.index(((((c1*UOp.const(dtypes.int, 2304))+(c2*UOp.const(dtypes.int, 36)))+(c3*UOp.const(dtypes.int, 6)))+c4), UOp.const(dtypes.bool, True)).store(c14, c1, c2, c3, c4)
|
|
ast = c15.sink()
|
|
|
|
# this does have tons of locals
|
|
opts = [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.UPCAST, axis=3, arg=0),
|
|
Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.UPCAST, axis=3, arg=2),
|
|
Opt(op=OptOps.GROUPTOP, axis=0, arg=16)]
|
|
else:
|
|
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(10616832), arg=0, src=())
|
|
c1 = UOp.range(UOp.const(dtypes.int, 512), 0, AxisType.GLOBAL)
|
|
c2 = UOp.range(UOp.const(dtypes.int, 64), 1, AxisType.GLOBAL)
|
|
c3 = UOp.range(UOp.const(dtypes.int, 36), 2, AxisType.GLOBAL)
|
|
c4 = UOp.range(UOp.const(dtypes.int, 9), 3, AxisType.GLOBAL)
|
|
c5 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(36864), arg=1, src=())
|
|
c6 = UOp.range(UOp.const(dtypes.int, 64), 1004, AxisType.REDUCE)
|
|
c7 = c5.index((((c2*UOp.const(dtypes.int, 9))+c4)+(c6*UOp.const(dtypes.int, 576))), UOp.const(dtypes.bool, True)).load()
|
|
c8 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=2, src=())
|
|
c9 = c8.index((((c1*UOp.const(dtypes.int, 2304))+c3)+(c6*UOp.const(dtypes.int, 36))), UOp.const(dtypes.bool, True)).load()
|
|
c10 = (c7*c9).reduce(c6, arg=Ops.ADD)
|
|
c11 = c0.index(((((c1*UOp.const(dtypes.int, 20736))+(c2*UOp.const(dtypes.int, 324)))+(c3*UOp.const(dtypes.int, 9)))+c4), UOp.const(dtypes.bool, True)).store(c10, c1, c2, c3, c4)
|
|
ast = c11.sink()
|
|
|
|
opts = [Opt(op=OptOps.TC, axis=0, arg=(0, 0, 1)), Opt(op=OptOps.UPCAST, axis=2, arg=4),
|
|
Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=0)]
|
|
|
|
prg = get_program(ast, opts=opts)
|
|
print(prg.src)
|
|
for i in range(10):
|
|
with Timing(f"try {i}: "):
|
|
# NOTE: this doesn't even run the kernel
|
|
try: CompiledRunner(prg)
|
|
except RuntimeError: pass
|
|
|