import numpy as np
import unittest
from dataclasses import replace
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.codegen.gpudims import get_grouped_dims
from tinygrad.uop.ops import UOp, Ops, GroupOp
from tinygrad.device import Device, Buffer, is_dtype_supported
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner, get_program
from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT
from tinygrad.dtype import DType, dtypes, PtrDType, AddrSpace
from tinygrad.codegen import apply_rewrites, rewrites_for_views
from tinygrad.renderer.ptx import PTXRenderer

class TestLinearizer(unittest.TestCase):
  def test_arg_dedup(self):
    # NOTE: this realize exists because Tensor.numpy calls .contiguous() internally.
    # without contiguous folding, rand.to("CPU") and rand.contiguous().to("CPU") are different UOps;
    # this test asserts they are the identical Buffer.
    # having different buffers is fine for correctness, because the outputs match.
    a, b = Tensor.randn(4).realize(), Tensor.randn(4).realize()
    np_a, np_b = a.numpy(), b.numpy()
    c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),))))
    lowered = [x[1] for x in lower_schedule(c.schedule())]
    for ei in lowered: ei.run()
    rawbufs = lowered[-1].bufs
    assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.uop.base.realized, b.uop.base.realized}
    np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
    np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4)

  def test_load_removed(self):
    a = Tensor.rand(1).realize()
    b = Tensor.rand(1).realize()
    ta = Tensor.where(Tensor(True), a, b).numpy()
    tb = Tensor.where(Tensor(False), a, b).numpy()
    np.testing.assert_equal(a.numpy(), ta)
    np.testing.assert_equal(b.numpy(), tb)

  def test_multioutput(self):
    dtype, st = dtypes.int, ShapeTracker.from_shape((8,))
    g0, g1, g2, g3 = [UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), arg=i) for i in range(4)]
    a = UOp(Ops.LOAD, dtype, src=(g2.view(st),))
    b = UOp(Ops.LOAD, dtype, src=(g3.view(st),))
    out0 = UOp(Ops.STORE, dtypes.void, src=(g0.view(st), a + b))
    out1 = UOp(Ops.STORE, dtypes.void, src=(g1.view(st), a * b))
    sink = UOp(Ops.SINK, src=(out0, out1))
    a_t = Tensor.full(st.shape, 2).contiguous().realize()
    b_t = Tensor.full(st.shape, 3).contiguous().realize()
    helper_linearizer_ast(sink, [a_t, b_t], wanna_output=[a_t.numpy()+b_t.numpy(), a_t.numpy()*b_t.numpy()])

    uops = get_program(sink, opts=[]).uops
    stores = [u for u in uops if u.op is Ops.STORE]
    mutable_bufs = dedup(flatten([[x for x in u.src[0].toposort() if x.op is Ops.DEFINE_GLOBAL] for u in stores]))
    assert len(mutable_bufs) == len(stores) == 2
    self.assertSetEqual(set([u.arg for u in mutable_bufs]), set([0,1]))

  def _test_no_nested_ranges(self, lins, skip=None):
    for l in lins:
      range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.DEFINE_REG])
      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.ENDRANGE and u.src[0] in range_in_acc)]
      for i,u in enumerate(ranges):
        if skip and i in skip: continue
        assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1]}, {u}"
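
  # The next few tests check that reduces over broadcast (expanded) axes collapse:
  # an expanded axis repeats the same value, so summing over it is just a multiply
  # and at most one RANGE should survive. A minimal sketch of the identity being
  # exploited (illustrative, not part of the test suite):
  #   a = Tensor.randn(2).realize()
  #   assert abs(a.reshape(2, 1).expand(2, 3).sum().item() - (a * 3).sum().item()) < 1e-5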
{ranges[i-1], {u}}" def test_two_nested_range(self): a = Tensor.randn(2, ).realize() out = a.reshape(2, 1).expand(2, 3).sum() ast = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)).sum()]) uops = get_program(ast, opts=[]).uops ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE] assert len(ranges) == 1 # NOTE: it collapses now def test_three_nested_range(self): a = Tensor.randn(2, ).realize() out = a.reshape(2, 1).expand(2, 3).expand(2, 2, 3).sum() ast = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)), (2, 2, 3)).sum()]) uops = get_program(ast, opts=[]).uops ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE] assert len(ranges) == 1 # NOTE: it collapses now def test_two_nested_range_alt_indexing(self): a = Tensor([2, 2]).realize() out = a.reshape(2, 1).pad(((1, 1), (1, 1)), value=2).sum() ast = helper_linearizer_opt(out, wanna_output=[24]) uops = get_program(ast, opts=[]).uops ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE] # RANGE -> ALU -> RANGE -> ALU + LOAD -> STORE assert any(x.op in GroupOp.ALU for x in uops[ranges[0]:ranges[1]]) assert not any(x.op is Ops.LOAD for x in uops[ranges[0]:ranges[1]]) assert any(x.op in {*GroupOp.ALU, Ops.LOAD} for x in uops[ranges[1]:]) def test_range_outer_op_before_phi(self): a = Tensor.randn(4, 1).realize() b = Tensor.randn(1, 1).realize() out = (a + b[0]).sum() + b[0] ast = helper_linearizer_opt(out, wanna_output=[(a.numpy()+b.numpy()[0]).sum()+b.numpy()]) uops = get_program(ast, opts=[]).uops ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE] # LOAD -> RANGE -> LOAD -> STORE assert len([x for x in uops[:ranges[0]] if x.op is Ops.LOAD]) == 1 def test_range_outer_op_before_phi_nested_range(self): a = Tensor.randn(2, ).realize() b = Tensor.randn(1, 1).realize() out = (a.reshape(2, 1).expand(2, 3) + b[0]).sum() + b[0] ast = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)) + b.numpy()[0]).sum() + b.numpy()]) uops = get_program(ast, opts=[]).uops ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE] assert len(ranges) == 1 # NOTE: it collapses now def test_load_dedup(self): # for different leaves in the AST, the same loads may occur. a = Tensor.randn(4).realize() # these are of size 3 to avoid float4 coalesce r = a[:-1] + a[1:] uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops num_loads = len([uop for uop in uops if uop.op is Ops.LOAD]) assert num_loads <= 4, "more load uops than needed" assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?" def test_upcast_cse(self): # when upcasting, within a subtree, there may be common expressions. 
  def test_upcast_cse(self):
    # when upcasting, within a subtree, there may be common expressions.
    a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
    r = a.expand([2]) + b.expand([2])
    uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
    num_ops = len([uop for uop in uops if uop.op in GroupOp.ALU])
    assert num_ops <= 1, "more alu uops than needed"

  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  def test_reduce_upcast(self):
    x, w = Tensor.randn((1,1,3)).realize(), Tensor.randn((1,1,2)).realize()
    r = Tensor.conv2d(x,w,padding=1).relu()
    uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]).uops
    accs = [u for u in uops if u.op is Ops.DEFINE_REG]
    stores = [u for u in uops if u.op is Ops.STORE]
    assert len(accs) == 0  # it's removed now
    assert len(stores) == 1
    assert stores[0].src[1].dtype == dtypes.float.vec(4)

  # NOTE: can reenable, it does work. it just makes BEAM slow
  @unittest.expectedFailure
  @unittest.skipUnless(Device.DEFAULT == "CPU", "test only for CPU")
  def test_upcast_with_locals_cpu(self):
    out = Tensor.ones(64,64).contiguous() @ Tensor.ones(64,64).contiguous()
    prg = get_program(out.schedule()[-1].ast, opts=[Opt(OptOps.LOCAL, axis=0, arg=4)])
    self.assertEqual(len(prg.src.split("for")), 5)

  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "broken on ptx for some reason")
  def test_upcast_with_locals(self):
    x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
    r = (x@y).relu()
    opts_to_apply = [Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)]
    program = get_program(r.schedule()[-1].ast, opts=opts_to_apply)

    stores = [u for u in program.uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG]
    # the first store is to lds and can be upcasted
    assert stores[0].src[1].dtype == dtypes.float.vec(4)
    assert any(x.op is Ops.DEFINE_LOCAL for x in stores[0].toposort())
    # the second store is to gds with no upcasts
    assert stores[1].src[1].dtype == dtypes.float
    assert any(x.op is Ops.DEFINE_GLOBAL for x in stores[1].toposort())

  def test_zero_fold(self):
    a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
    r = Tensor.stack(a, b)
    uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
    num_ops = len([uop for uop in uops if uop.op in GroupOp.ALU])
    assert num_ops == 0, "more alu uops than needed"

  def test_sum_acc_dtype(self):
    for tensor_dtype, acc_dtype in (
      (dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)):
      if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype):
        a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
        realized_ast = a.schedule()[-1].ast
        program = get_program(realized_ast, opts=[])
        local = [uop for uop in program.uops if uop.op is Ops.DEFINE_REG]
        assert local[0].dtype.base == acc_dtype
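
  # Reduces accumulate in a wider dtype than the tensor to avoid overflow (bool/int16
  # sum into int32) and precision loss (float16/bfloat16 sum into float32), unless an
  # explicit acc dtype overrides it, as test_arg_acc_dtype checks below. For intuition
  # (illustrative): summing 4096 ones in a float16 accumulator would stall at 2048,
  # since 2049 is not representable in float16, while a float32 accumulator is exact.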
  def test_arg_acc_dtype(self):
    def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
      realized_ast = c.schedule()[-1].ast
      program = get_program(realized_ast, opts=[])
      local = [uop for uop in program.uops if uop.op is Ops.DEFINE_REG]
      self.assertEqual(local[0].dtype.base, expected_dtype)
    tests = (
      (dtypes.float16, None, dtypes.float),
      (dtypes.bfloat16, None, dtypes.float),
      (dtypes.float, None, dtypes.float),
      (dtypes.float16, dtypes.float16, dtypes.float16),
      (dtypes.bfloat16, dtypes.bfloat16, dtypes.bfloat16),
      (dtypes.float, dtypes.float16, dtypes.float16),
    )
    for tensor_dtype, acc_dtype, expected_dtype in tests:
      if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype) and is_dtype_supported(expected_dtype):
        a, b = Tensor.rand(8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, dtype=tensor_dtype)
        helper_arg_acc_dtype(a.sum(dtype=acc_dtype), expected_dtype)
        helper_arg_acc_dtype(a.matmul(b, dtype=acc_dtype), expected_dtype)
        helper_arg_acc_dtype(Tensor.einsum("ki,ij->kj", a, b, dtype=acc_dtype), expected_dtype)
        d, w = Tensor.rand(4, 8, 8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, 2, 2, dtype=tensor_dtype)
        helper_arg_acc_dtype(d.conv2d(w, dtype=acc_dtype), expected_dtype)

  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  def test_simple_unroll_no_between_phi_dependencies(self):
    x, y = Tensor.rand(128, 128), Tensor.rand(128, 128)
    r = (x@y).relu()
    opt = [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4)]
    ast = helper_linearizer_opt(r, [opt])
    # the uops graph is DEFINE_REG -> 4x STORE 0.0 -> RANGE -> 4x ALU -> 4x STORE -> ENDRANGE
    uops = get_program(ast, opts=opt).uops
    begin_range = [i for i, x in enumerate(uops) if x.op is Ops.RANGE][-1]
    end_range = [i for i, x in enumerate(uops) if x.op is Ops.ENDRANGE][0]
    for i,u in enumerate(uops): print(i, u.op, [uops.index(s) for s in u.src], u.arg, u.dtype)
    for u in uops:
      if u.op is Ops.STORE and isinstance(dt:=u.src[0].dtype, PtrDType) and dt.addrspace is AddrSpace.REG:
        if uops.index(u) < begin_range:
          assert u.src[1].op is Ops.CONST
        else:
          assert u.src[1].op in GroupOp.ALU
          assert begin_range < uops.index(u) < end_range
      # children of STORE are placed after ENDRANGE
      if any(x.op is Ops.STORE and x.src[1].op in GroupOp.ALU for x in u.src):
        assert end_range < uops.index(u)

  def test_grouped_dims(self):
    def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes, assert_same_length=True):
      idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims)
      loop_idxs = dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs]))
      loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg)
      sizes = [x.src[0].arg for x in loop_idxs]
      assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
      if assert_same_length:
        assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
      assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}"
      # TODO: add these back after uop symbolic
      # for i in range(len(dims)):
      #   assert idxs[i].max+1 == dims[i], f"idxs[{i}] should have max {dims[i]-1}"
      # for i in range(len(loop_idxs)):
      #   assert loop_idxs[i].expr.startswith(prefix), f"loop_idxs[{i}] must start with {prefix}"
      #   assert loop_idxs[i].max+1 == sizes[i], f"loop_idxs[{i}] should have max {sizes[i]-1}"

    # no-op
    _assert_grouped_dims("gidx", (2,), (16,16,16), False, [2])
    _assert_grouped_dims("gidx", (2,3), (16,16,16), False, [2,3])

    # check reverse dims
    _assert_grouped_dims("gidx", (2,3), (16,16,16), True, [3,2])
    _assert_grouped_dims("gidx", (2,3,4), (16,16,16), False, [2,3,4])

    # test splitting globals: len(dims) == len(max)
    _assert_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4])
    _assert_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,3,16])
    _assert_grouped_dims("gidx", (64,3,4), (16,16,16), True, [16,3,16])
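
    # Walkthrough of the splits above (illustrative, not an API contract): for
    # (64,3,4) under max (16,16,16), 64 exceeds the 16 limit, so it factors as 16*4
    # and the spare 4 folds into a neighboring axis, giving [16, 3*4, 4] == [16,12,4];
    # with reverse_dims=True the dims are matched against the limits right to left.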
_assert_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,3,32]) _assert_grouped_dims("gidx", (4,4,512), (16,4,256), False, [8,4,256]) # prefer group_dim strategy when possible _assert_grouped_dims("gidx", (512,4,2), (8192,2,2), False, [2048,2]) # test splitting globals: len(dims) < len(max) # len(dim) -> len(limited) # 1 -> 2 _assert_grouped_dims("gidx", (128,), (16,16,256), False, [16,8], False) # 1 -> 3 _assert_grouped_dims("gidx", (65536,), (16,16,256), False, [16,16,256], False) # 2 -> 3 _assert_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False) # test when the only divisor is the square root of dim _assert_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False) # collapse on onto the left most axis _assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5]) _assert_grouped_dims("gidx", (2,3,4,5), (32,16,16), True, [20,3,2]) # _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (32,16,16), True, [20,3,Variable("start_pos",1,2)]) # collapse on left-most available axis (the left most is too small) _assert_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5]) _assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), True, [5,12,2]) # _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (16,16,16), False, [Variable("start_pos",1,2)*3,4,5]) # dim too large and not factorable with self.assertRaises(RuntimeError): get_grouped_dims("gidx", (23,), (16,16,16), False,) with self.assertRaises(RuntimeError): get_grouped_dims("gidx", (128,3,4), (16,2,2), False,) # too large for sizes with self.assertRaises(RuntimeError): get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16)) # # variable too large # with self.assertRaises(AssertionError): # get_grouped_dims("gidx", (Variable("start_pos",0,16),3,4), (16,16,16), False,) @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") def test_default_global_reversed(self): # shrink so that the dims do not collapse t = Tensor.ones(5, 6, 7).contiguous().realize().shrink(((0, 4), (0, 5), (0, 6))) ast = helper_linearizer_opt(t+1) uops = get_program(ast, opts=[]).uops idxs = dedup([uop for uop in uops if uop.op is Ops.SPECIAL]) idxs = sorted(idxs, key=lambda uop: uop.arg) assert (idxs[0].arg, idxs[0].src[0].arg) == ('gidx0', 6), idxs[0] assert (idxs[1].arg, idxs[1].src[0].arg) == ('gidx1', 5), idxs[1].arg assert (idxs[2].arg, idxs[2].src[0].arg) == ('gidx2', 4), idxs[2].arg def test_sum_collapse(self): t = Tensor([2]).reshape(1, 1).expand(256, 256).sum() sched = [si for si in t.schedule() if si.ast.op is Ops.SINK] # sum_collapse is a full collapse now assert len(sched) == 1 assert not any(u.op is Ops.REDUCE_AXIS for u in sched[0].ast.toposort()), "found reduce in sum collapse" #lin = Kernel(sched[0].ast) #assert not any(u.op is Ops.RANGE for u in lin.linearize().uops), "found loop in sum collapse" def test_assign_fold(self): a = Tensor.ones(4, 4).contiguous().realize() m = Tensor.ones(4, 4).shrink(((1, 2), None)).pad(((1, 2), None)) a.assign(a+m) a.realize() np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.]) def test_where_fold(self): a = Tensor.ones(4, 4).contiguous().realize() b = a.shrink(((1, 2), None)).pad(((1, 2), None)) a.assign(b.where(2, a)) sched = a.schedule() assert len(sched) == 1 sched_copy = sched[:] run_schedule(sched) np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.]) program = get_program(sched_copy[-1].ast, opts=()) assert not any(u.op == Ops.WHERE for u in 
program.uops), "found where where where should be folded" def test_phi_simplification(self): def helper(t, max_ops=0): ast = helper_linearizer_opt(t) uops = get_program(ast).uops # ignore kernel optimized IF statements for now if if_op:=next((u for u in uops if u.op is Ops.IF), None): uops = uops[:uops.index(if_op)] assert len(set([u.op for u in uops if u.op in {Ops.RANGE, Ops.SPECIAL}])) == 1, "has either specials or ranges, not both" reg_stores = [u for u in uops if u.op is Ops.STORE and isinstance(dt:=u.src[0].dtype, PtrDType) and dt.addrspace == AddrSpace.REG] assert len(reg_stores) == 0, "STORE to reg should have been simplified" # TODO: once uops track min/max this will be fixed #assert len([u for u in uops if u.op is Ops.MAX]) <= max_ops, "no unnecessary MAX ops" helper(Tensor.arange(5.5, (3.5*300), 3.5), max_ops=2) helper(Tensor.arange(-1, -100, -5), max_ops=2) # NOTE: both of these split the reduce (this just wasn't tracked before) #helper(Tensor.arange(-3.2, 6.7, 0.64), max_ops=2) #helper(Tensor.arange(256), max_ops=2) helper(Tensor.arange(255), max_ops=2) @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "broken on ptx for some reason") def test_grouped_store_phis(self): """ float4 acc0 = float4(0.0,0.0,0.0,0.0); { acc0 = // ... } *((device float4*)(data0+alu2)) = float4(acc0.x,acc0.y,acc0.z,acc0.w); simplifies to: *((device float4*)(data0+alu2)) = acc0; """ x, y = Tensor.randn(64,64), Tensor.randn(64,64) out = x.matmul(y) with Context(TC=0): ast = helper_linearizer_opt(out) uops = get_program(ast).uops # check that the float4 cast collapses store_vals = [u.src[1] for u in uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG] for val in store_vals: assert val.dtype == dtypes.float.vec(4) # and val.op is not Ops.VECTORIZE @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") def test_grouped_store_values(self): x = Tensor.randn((4,3,6,6)).realize() out = x.flip((0,1)).contiguous() ast = helper_linearizer_opt(out) store_val = [u.src[1] for u in get_program(ast).uops if u.op is Ops.STORE][0] assert store_val.dtype == dtypes.float.vec(4) and store_val.op is not Ops.VECTORIZE @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") def test_grouped_store_locals_and_globals(self): x, y = Tensor.rand(128, 128), Tensor.rand(128, 128) out = x@y opt = [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 2)] # upcast accs in both reduces ast = helper_linearizer_opt(out, opts=[opt]) def get_recursive(uop): return set.union(set(uop.src), [uop], *[get_recursive(v) for v in uop.src]) uops = get_program(ast, opts=opt).uops local_stores = [u for u in uops if u.op is Ops.STORE and any(x.op is Ops.DEFINE_LOCAL for x in get_recursive(u.src[0]))] global_stores = [u for u in uops if u.op is Ops.STORE and any(x.op is Ops.DEFINE_GLOBAL for x in get_recursive(u.src[0]))] barrier = [u for u in uops if u.op is Ops.BARRIER][0] # check that the float4 cast collapses for all stores for store in local_stores+global_stores: assert store.src[1].dtype.count > 1 # and store.src[2].op is not Ops.VECTORIZE # # check the children's vins # TODO: src ALU are 
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "broken on ptx for some reason")
  def test_grouped_store_local_only(self):
    x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
    r = (x@y).relu()
    ast = helper_linearizer_opt(r)
    uops = get_program(ast).uops
    stores = [u for u in uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG]
    # the float4 value is stored directly in lds and the upcast is skipped
    self.assertEqual(stores[0].src[1].dtype, dtypes.float.vec(4))
    #assert stores[0].src[-1].op is not Ops.VECTORIZE
    # the global store doesn't change
    assert stores[1].src[1].dtype == dtypes.float

  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  def test_skip_unmatching_upcasts(self):
    Tensor.manual_seed(0)
    c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9600), arg=0, src=())
    c1 = c0.view(ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),)))
    c2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9600), arg=1, src=())
    c3 = c2.view(ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)))
    c4 = c3.load()
    c5 = c1.store(c4)
    ast = c5.sink()
    opt = [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16),
           Opt(op=OptOps.LOCAL, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=3, arg=2)]
    helper_linearizer_ast(ast, [Tensor.randn(240*40).realize()], opts=[opt])
    out = [u for u in get_program(ast, opts=opt).uops if u.op is Ops.STORE][0]
    assert out.src[1].op is Ops.VECTORIZE and out.src[1].dtype == dtypes.float.vec(4)

  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  def test_skip_unmatching_upcasts_with_gep(self):
    Tensor.manual_seed(0)
    c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=0, src=())
    c1 = c0.view(ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)))
    c2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=1, src=())
    c3 = c2.view(ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),)))
    c4 = c3.load()
    c5 = c1.store(c4)
    ast = c5.sink()
    opt = [Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=2, arg=2), Opt(op=OptOps.LOCAL, axis=1, arg=8),
           Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8),
           Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=2)]
    helper_linearizer_ast(ast, [Tensor.randn(8*32).realize()], opts=[opt])
    out = [u for u in get_program(ast).uops if u.op is Ops.STORE][0]
    assert out.src[1].op is Ops.VECTORIZE and out.src[1].dtype.count != 1

# *** helpers ***
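
# These helpers run an AST unoptimized, with hand-coded opts, and with any custom opt
# lists, comparing all outputs. Typical call pattern from the tests above (illustrative):
#   r = (Tensor.rand(4, 4) + 1).sum()
#   ast = helper_linearizer_opt(r, opts=[[Opt(OptOps.UPCAST, 0, 4)]])  # opts is a list of opt-lists
#   uops = get_program(ast, opts=[]).uops                              # inspect the linearized program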

def push_views(ast): return apply_rewrites(ast, rewrites_for_views)

def helper_realized_ast(r:Tensor|list[Tensor]) -> tuple[UOp, list[Buffer]]:
  if isinstance(r, Tensor): r = [r]
  s = Tensor.schedule(*r)
  run_schedule(s[:-1])  # run all kernels except the last one
  assert s[-1].ast.op is Ops.SINK, f"helper_realized_ast expects a SINK {s[-1]}"
  # now all input buffers in s[-1] should be realized
  # create fresh buffers for the outputs
  bufs = [Buffer(x.device, x.size, x.dtype).allocate() if i < len(s[-1].ast.src) else x for i,x in enumerate(s[-1].bufs)]
  return push_views(s[-1].ast), bufs

def helper_linearizer_ast(ast:UOp, inputs:list[Tensor], *args, **kwargs):
  assert isinstance(ast, UOp), "ast must be UOp"
  inbufs = [x.uop.base.buffer for x in inputs]
  outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.st_arg.size, out.src[1].dtype).allocate() for out in ast.src]
  _helper_linearizer_opt_ast(ast, outbufs+inbufs, *args, **kwargs)

def helper_linearizer_opt(r:Tensor|list[Tensor], *args, **kwargs):
  realized_ast, real_bufs = helper_realized_ast(r)
  _helper_linearizer_opt_ast(realized_ast, real_bufs, *args, **kwargs)
  return realized_ast

def copyout_outputs(outbufs:list[Buffer]) -> list[np.ndarray]:
  return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs]

def reset_bufs(bufs:list[Buffer]):
  for buf in bufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data)  # Zero to check that all values are filled

def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[], apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]):
  outbufs = real_bufs[:len(realized_ast.src)]
  device = real_bufs[0].device
  wanna_output = [np.array(x).flatten() for x in wanna_output]
  def get_prg(opts): return CompiledRunner(replace(get_program(realized_ast, opts=opts), device=device))
  def check_opt(opts):
    prg = get_prg(opts=opts)
    reset_bufs(outbufs)
    prg.exec(real_bufs)
    for x,want in zip(copyout_outputs(outbufs), wanna_output): np.testing.assert_allclose(x, want, atol=atol, rtol=rtol)

  # Get the baseline output if it is not provided; this run applies no optimizations at all.
  prg = get_prg(opts=())
  prg.exec(real_bufs)
  if len(wanna_output) == 0: wanna_output = copyout_outputs(outbufs)
  else:
    for buf,want in zip(copyout_outputs(outbufs), wanna_output): np.testing.assert_allclose(buf, want, atol=atol, rtol=rtol)

  # Check correctness of handcoded optimizations.
  prg = get_prg(opts=None)
  reset_bufs(outbufs)
  prg.exec(real_bufs)
  for buf,want in zip(copyout_outputs(outbufs), wanna_output): np.testing.assert_allclose(buf, want, atol=atol, rtol=rtol)
  for x in opts:  # Check custom transformations if any.
    check_opt(([Opt(OptOps.TC, 0, (TC_SELECT.value, TC_OPT.value, 1))] if apply_tc else [])+x)

if __name__ == '__main__':
  unittest.main()