# this will be the new test_ops for the next level # schedule confirms the right things are capable of fusing # NOTE: this has overlap with external_test_opt.py import unittest, functools import numpy as np from typing import cast from hypothesis import assume, given, settings, strategies as strat from tinygrad import nn, dtypes, Device, Tensor, Variable from tinygrad.device import is_dtype_supported from tinygrad.dtype import DType, ImageDType from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp from tinygrad.schedule.rangeify import get_rangeify_map, Kernel from tinygrad.engine.schedule import create_schedule_with_vars from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule class KernelCountException(Exception): pass def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True): if to_prerealize: with Context(DEBUG=0, TRACK_MATCH_STATS=0): Tensor.realize(*to_prerealize) if isinstance(t, Tensor): sched = t.schedule() elif isinstance(t, list) and isinstance(t[0], Tensor): sched = Tensor.schedule(*t) else: assert isinstance(t, UOp), f"can't schedule {t}" sink = UOp.sink(t) if t.op is not Ops.SINK else t becomes_map = get_rangeify_map(sink) sched, _ = create_schedule_with_vars(sink.substitute(becomes_map)) # test lowering all the ScheduleItems to ExecItems kernel_cnt = len([si for si,ei in lower_schedule(sched.copy()) if isinstance(ei.prg, CompiledRunner) or not filter_sink]) if kernel_cnt != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}") if DEBUG >= 3: for i,s in enumerate(sched): print("kernel", i+1) print(s.ast) raise KernelCountException(f"{kernel_cnt} != {allowed}") return sched def _realize_weights(m): for p in nn.state.get_parameters(m): p.realize() def _test_conv2d(allowed:int, dtype:DType=dtypes.float): old_default_float, dtypes.default_float = dtypes.default_float, dtype dtypes.default_float = dtype Tensor.manual_seed(0) BS, CIN = 2, 3 img = Tensor.randn(BS, CIN, 64, 64, requires_grad=True).realize() w = Tensor.uniform(16, CIN, 3, 3, requires_grad=True).realize() ret = Tensor.conv2d(img, w).relu().mean().backward() dtypes.default_float = old_default_float s = Tensor.schedule(ret, img.grad, w.grad) run_schedule(s.copy()) cnt = len([si for si in s if si.ast.op is Ops.SINK]) assert cnt == allowed, f"expected {allowed} kernels, got {cnt}" if getenv("CHECK", 1): import torch ref_img = torch.tensor(img.numpy(), requires_grad=True) ref_w = torch.tensor(w.numpy(), requires_grad=True) torch.nn.functional.conv2d(ref_img, ref_w).relu().mean().backward() assert ref_img.grad is not None and ref_w.grad is not None and img.grad is not None and w.grad is not None np.testing.assert_allclose(img.grad.numpy(), ref_img.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2) np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2) class TestSchedule(unittest.TestCase): def test_arange_avgpool2d(self, kcount=1): x = Tensor.arange(25).reshape(1,1,5,5).cast(dtypes.float32) t = x.avg_pool2d(padding=1) sched = t.schedule() self.assertEqual(len(sched), kcount) run_schedule(sched) import torch torch_out = torch.nn.functional.avg_pool2d(torch.arange(25).reshape(1,1,5,5).float(), kernel_size=(2,2), padding=1).numpy() np.testing.assert_allclose(t.numpy(), torch_out) def test_arange_avgpool2d_fused_noopt(self): with Context(NOOPT=1): self.test_arange_avgpool2d(kcount=1) # linearizer error @unittest.skip("recursion error no longer raised") @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "needs supports_float4 to fail") def test_arange_avgpool2d_fused(self): with self.assertRaises(RecursionError): with Context(NOOPT=0): self.test_arange_avgpool2d(kcount=1) # when we're fusing a reduce, all ReduceOps must have the same N in the dimensions # all permutes, reshapes, expands and shrinks push through the reduce def test_arange_sum(self): a = Tensor.arange(6).reshape(3, 2).sum(axis=1) run_schedule(check_schedule(a, 1)) self.assertListEqual(a.tolist(), [1, 5, 9]) def test_arange_sum_alt(self): a = (Tensor.arange(5).reshape(1,5).expand(6,5)*Tensor(2)).reshape(1,6,5).sum(axis=2) run_schedule(check_schedule(a, 1)) np.testing.assert_equal(a.numpy(), 20) def test_permute_arange(self): a = Tensor.arange(6).reshape(6, 1, 1).permute(2, 0, 1).sum(axis=1) run_schedule(check_schedule(a, 1)) self.assertListEqual(a.tolist(), [[15]]) @unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch") def test_error_on_device_mismatch(self): a = Tensor.empty(10) b = Tensor.empty(10, device="CPU") c = a+b with self.assertRaisesRegex(RuntimeError, "all buffers must be on the same device"): check_schedule(c, 1) @unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch") def test_error_on_device_mismatch_alt(self): a = Tensor.empty(10) b = Tensor.empty((1,), device="CPU").expand(10).contiguous() c = a+b with self.assertRaisesRegex(RuntimeError, "all buffers must be on the same device"): check_schedule(c, 2) @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") def test_expand_buffer_before_cast(self): a = Tensor.randn(4, 2, 1).realize().permute((1, 0, 2)) b = a.cast(dtypes.half).expand((2, 4, 4))+2 run_schedule(check_schedule(b, 1)) np.testing.assert_allclose(b.numpy(), np.broadcast_to(a.numpy().astype(np.float16), (2, 4, 4))+2) def test_indexing_scalars_simple(self): X = Tensor.randn(2, 2).realize() xt = X[Tensor(1)][Tensor(0)] run_schedule(check_schedule(xt, 1)) np.testing.assert_equal(xt.numpy(), X.numpy()[1][0]) @unittest.skipIf(CI and Device.DEFAULT == "NV", "crashes on NV CI") def test_add_chain_buffers(self): N = 31 with Context(TRACK_MATCH_STATS=0, DEBUG=0): bufs = [Tensor(i).reshape((1,)).contiguous().realize() for i in range(N)] for X in range(1,N): root = bufs[0] for i in range(1,N,X): root = root + functools.reduce(lambda a,b:a+b, bufs[i:i+X]) self.assertEqual(root.item(), sum(range(N))) @given(strat.sampled_from(range(2,4)), strat.sampled_from(range(2,4)), strat.sampled_from(range(0,4)), strat.sampled_from(range(0,4))) @settings(deadline=None) def test_indexing_scalars(self, x, y, a, b): assume(a1 children but should still fuse # run_schedule(check_schedule(out, 1)) run_schedule(check_schedule(out, 2)) np.testing.assert_allclose(out.numpy(), \ (c.numpy()*a.numpy().sum(axis=-1,keepdims=True)).sum(-1) + (b.numpy()*a.numpy().sum(axis=-1,keepdims=True)).sum(-1), atol=1e-4, rtol=1e-4) def test_reduce_expand_reduce_fusion(self): Tensor.manual_seed(0) a = Tensor.randn(4, 32).realize() out = (a+a.sum(-1, keepdim=True)).sum(-1) # run_schedule(check_schedule(out, 1)) run_schedule(check_schedule(out, 2)) np.testing.assert_allclose(out.numpy(), (a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4) def test_reduce_expand_reduce_expand_fusion(self): Tensor.manual_seed(0) a = Tensor.randn(4, 32).realize() out = a+(a+a.sum(-1,keepdim=True)).sum(-1, keepdim=True) # run_schedule(check_schedule(out, 2)) run_schedule(check_schedule(out, 3)) np.testing.assert_allclose(out.numpy(), \ a.numpy()+(a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1,keepdims=True), atol=1e-4, rtol=1e-4) def test_branching_reduces_and_expands_fusion(self): Tensor.manual_seed(0) a = Tensor.randn(4, 32).realize() out0 = a+a.sum(-1, keepdim=True) out1 = out0.sum(-1) # run_schedule(check_schedule(out, 2)) run_schedule(check_schedule([out0, out1], 3)) np.testing.assert_allclose(out0.numpy(), a.numpy()+a.numpy().sum(axis=-1,keepdims=True), atol=1e-4, rtol=1e-4) np.testing.assert_allclose(out1.numpy(), (a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4) def test_multireduce_fusion_simple_sequential(self): Tensor.manual_seed(0) x = Tensor.randn(4, 32).realize() y = Tensor.randn(4, 32).realize() out = (y + x.sum(axis=-1, keepdim=True)).sum(axis=-1) # run_schedule(check_schedule(out, 1)) run_schedule(check_schedule(out, 2)) np.testing.assert_allclose(out.numpy(), (y.numpy() + x.numpy().sum(axis=-1, keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4) def test_multireduce_fusion_simple_parallel(self): Tensor.manual_seed(0) x = Tensor.randn(4, 32).realize() y = Tensor.randn(4, 32).realize() out = y.sum(axis=-1) + x.sum(axis=-1) run_schedule(check_schedule(out, 1)) np.testing.assert_allclose(out.numpy(), y.numpy().sum(axis=-1) + x.numpy().sum(axis=-1), atol=1e-4, rtol=1e-4) def test_multireduce_fusion_sequential(self): Tensor.manual_seed(0) x = Tensor.randn(4, 32).realize() out = x.std(-1) # run_schedule(check_schedule(out, 1)) run_schedule(check_schedule(out, 2)) np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4) def test_multireduce_fusion_parallel(self): Tensor.manual_seed(0) x = Tensor.randn(4, 32).realize() y = Tensor.randn(4, 32).realize() out = x.std(-1) + y.std(-1) # run_schedule(check_schedule(out, 1)) run_schedule(check_schedule(out, 3)) np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1) + y.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4) def test_multireduce_diffops_sequential(self): Tensor.manual_seed(0) x = Tensor.randn(4, 32).realize() out = (x - x.max(-1, keepdim=True)).sum(-1) # run_schedule(check_schedule(out, 1)) run_schedule(check_schedule(out, 2)) np.testing.assert_allclose(out.numpy(), (x.numpy() - x.numpy().max(axis=-1, keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4) def test_multireduce_fusion_diffops_parallel(self): Tensor.manual_seed(0) x = Tensor.randn(4, 32).realize() y = Tensor.randn(4, 32).realize() out = x.sum(-1) + y.max(-1) run_schedule(check_schedule(out, 1)) np.testing.assert_allclose(out.numpy(), x.numpy().sum(axis=-1) + y.numpy().max(axis=-1), atol=1e-4, rtol=1e-4) def test_multireduce_fusion_sequential_and_parallel(self): Tensor.manual_seed(0) x = Tensor.randn(4, 32).realize() y = Tensor.randn(4, 32).realize() mu = (x - x.max(axis=-1, keepdim=True)).mean(axis=-1, keepdim=True) + (y - y.max(axis=-1, keepdim=True)).mean(axis=-1, keepdim=True) out = [((x - mu).square().sum(-1)/x.shape[-1]).sqrt(), ((y - mu).square().sum(-1)/y.shape[-1]).sqrt()] np_mu = (x.numpy() - x.numpy().max(axis=-1, keepdims=True)).mean(axis=-1, keepdims=True) + \ (y.numpy() - y.numpy().max(axis=-1, keepdims=True)).mean(axis=-1, keepdims=True) # run_schedule(check_schedule(out, 1)) run_schedule(check_schedule(out, 5)) np.testing.assert_allclose(out[0].numpy(), np.sqrt(np.square(x.numpy() - np_mu).sum(-1)/x.shape[-1]), atol=1e-4, rtol=1e-4) np.testing.assert_allclose(out[1].numpy(), np.sqrt(np.square(y.numpy() - np_mu).sum(-1)/y.shape[-1]), atol=1e-4, rtol=1e-4) def test_multimatmul_fusion(self): Tensor.manual_seed(0) a,b = Tensor.randn(4, 64).realize(), Tensor.rand(64,8).realize() c,d = Tensor.randn(4, 64).realize(), Tensor.rand(64,8).realize() out = a@b + c@d run_schedule(check_schedule(out, 1)) np.testing.assert_allclose(out.numpy(), a.numpy()@b.numpy() + c.numpy()@d.numpy(), atol=1e-4, rtol=1e-4) def test_softmax_fusion(self): Tensor.manual_seed(0) x = Tensor.randn(4, 12, 64, 64).realize() out = x.softmax() run_schedule(check_schedule(out, 3)) expected = (x_exp:=np.exp(x.numpy()-x.numpy().max(-1, keepdims=True)))/x_exp.sum(-1, keepdims=True) np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4) @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") def test_softmax_upcast(self): # input half, softmax in float Tensor.manual_seed(0) x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.half).realize() out = x.softmax(dtype=dtypes.float) sched = out.schedule() self.assertEqual(len(sched), 3) self.assertEqual(sched[0].bufs[0].dtype, dtypes.float) # input float, softmax in float Tensor.manual_seed(0) x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.float).realize() out = x.softmax(dtype=dtypes.float) sched = out.schedule() self.assertEqual(len(sched), 3) self.assertEqual(sched[0].bufs[0].dtype, dtypes.float) def test_softmax_backward(self): Tensor.manual_seed(0) x = Tensor.randn(4, 12, 64, 64, requires_grad=True).realize() x.softmax().sum().backward() run_schedule(check_schedule(x.grad, 4)) def test_layernorm_onelayer_fusion(self): Tensor.manual_seed(0) layer = nn.LayerNorm([10, 10]) layer.weight = Tensor.randn(10,10).realize() layer.bias = Tensor.randn(10,10).realize() x = Tensor.randn(20, 5, 10, 10).realize() out = layer(x) # run_schedule(check_schedule(out, 2)) run_schedule(check_schedule(out, 3)) y = (x.numpy() - x.numpy().mean(layer.axis, keepdims=True)) expected = y / np.sqrt((y*y).mean(layer.axis, keepdims=True) + layer.eps) np.testing.assert_allclose(out.numpy(), expected * layer.weight.numpy() + layer.bias.numpy(), atol=1e-4, rtol=1e-4) def test_scaled_dot_product_attention_fusion(self): x, y, z, m = (Tensor.empty(32, 8, 16, 16) for _ in range(4)) out = Tensor.scaled_dot_product_attention(x, y, z, attn_mask=m) check_schedule(out, 4) def test_scaled_dot_product_attention_causal_fusion(self): x, y, z = (Tensor.empty(32, 8, 16, 16) for _ in range(3)) out = Tensor.scaled_dot_product_attention(x, y, z, is_causal=True) check_schedule(out, 4) def test_adam_step_fusion(self): with Tensor.train(): x = Tensor.empty(4, 64, 32) layer = nn.Linear(32, 32*4) _realize_weights(layer) opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4) layer(x).relu().sum().backward() check_schedule(opt.schedule_step(), 19) def test_adam_conv_fuse(self): with Tensor.train(): img = Tensor.empty(2,3,4,4) c1 = nn.Conv2d(3,32,3) _realize_weights(c1) opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4) opt.zero_grad() c1(img).relu().sum().backward() check_schedule(opt.schedule_step(), 19) def test_adam_2convs_fuse(self): with Tensor.train(): img = Tensor.empty(2,3,4,4) c1 = nn.Conv2d(3,16,3,bias=False) c2 = nn.Conv2d(16,32,2,bias=False) _realize_weights([c1, c2]) opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4) opt.zero_grad() c2(c1(img).relu()).relu().sum().backward() check_schedule(opt.schedule_step(), 21) def test_sgd_conv_fuse(self): with Tensor.train(): img = Tensor.empty(2,3,4,4) c1 = nn.Conv2d(3,32,3) _realize_weights(c1) opt = nn.optim.SGD(nn.state.get_parameters(c1)) opt.zero_grad() c1(img).relu().sum().backward() check_schedule(opt.schedule_step(), 5) # TODO: 3? def test_sgd_2convs_fuse(self): with Tensor.train(): img = Tensor.empty(2,3,4,4) c1 = nn.Conv2d(3,16,3,bias=False) c2 = nn.Conv2d(16,32,2,bias=False) _realize_weights([c1, c2]) opt = nn.optim.SGD(nn.state.get_parameters([c1, c2])) opt.zero_grad() c2(c1(img).relu()).relu().sum().backward() check_schedule(opt.schedule_step(), 7) @unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong") def test_fold_2convs_sgd_nesterov_momentum_wd(self): with Tensor.train(): img = Tensor.empty(2,3,4,4) c1 = nn.Conv2d(3,16,3,bias=False) c2 = nn.Conv2d(16,32,2,bias=False) _realize_weights([c1, c2]) opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1) opt.zero_grad() c2(c1(img).relu()).relu().sum().backward() check_schedule(opt.schedule_step(), 13) def test_sgd_4convs_fuse(self): with Tensor.train(): img = Tensor.empty(2,3,16,16) c1 = nn.Conv2d(3,4,3,bias=False) c2 = nn.Conv2d(4,8,3,bias=False) c3 = nn.Conv2d(8,16,3,bias=False) c4 = nn.Conv2d(16,32,3,bias=False) _realize_weights([c1, c2, c3, c4]) opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4])) opt.zero_grad() c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward() check_schedule(opt.schedule_step(), 15) def test_sgd_4convs_fuse_conv_bw(self): with Tensor.train(): img = Tensor.empty(2,3,16,16) c1 = nn.Conv2d(3,4,3,bias=False) c2 = nn.Conv2d(4,8,3,bias=False) c3 = nn.Conv2d(8,16,3,bias=False) c4 = nn.Conv2d(16,32,3,bias=False) _realize_weights([c1, c2, c3, c4]) opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4])) opt.zero_grad() c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward() check_schedule(opt.schedule_step(), 15) def test_reduce_simple_chase(self): a = Tensor.empty(4, 4, 4) r = a.sum(0) + 6 b = r.sum(0) * 4 c = r.sum(1) * 2 check_schedule([b, c], 3) def test_multireduce_simple_chase(self): Tensor.manual_seed(0) a = Tensor.randn(4, 4, 4).realize() r = (a + (a.sum(0, keepdim=True) + 6)).sum(0) * 2 b = r.sum(0) + 8 c = r.sum(1) + 12 np_r = (a.numpy() + (a.numpy().sum(0) + 6)).sum(0) * 2 # schedule = check_schedule([b,c], 3) # self.assertIs(schedule[0].ast[0].src[0].arg, Ops.MUL) schedule = check_schedule([b,c], 4) run_schedule(schedule) np.testing.assert_allclose(b.numpy(), np_r.sum(0) + 8, atol=1e-4, rtol=1e-4) np.testing.assert_allclose(c.numpy(), np_r.sum(1) + 12, atol=1e-4, rtol=1e-4) def test_push_permute_chase(self): a = Tensor.empty(4, 4, 4) b = Tensor.empty(4, 4) r = a.sum(2) + b d = r.T * 4 e = r * d check_schedule([d, e], 3) def test_multireduce_push_permute_chase(self): Tensor.manual_seed(0) a = Tensor.randn(4, 4, 4).realize() b = Tensor.randn(4, 4).realize() r = a.sum(2) + b d = r.T * 4 e = r * (d + a).sum(2) schedule = check_schedule([d, e], 3) # make sure it doesn't fuse run_schedule(schedule) np.testing.assert_allclose(d.numpy(), (a.numpy().sum(2) + b.numpy()).T * 4, atol=1e-4, rtol=1e-4) np.testing.assert_allclose(e.numpy(), (a.numpy().sum(2) + b.numpy()) * (d.numpy() + a.numpy()).sum(2), atol=1e-4, rtol=1e-4) def test_push_shrink_chase(self): a = Tensor.empty(16, 16) b = Tensor.empty(4) c = Tensor.empty(16, ) r = a.sum(1) + c d = r[:4] * b check_schedule(d, 1) def test_multireduce_push_shrink_chase(self): Tensor.manual_seed(0) a = Tensor.randn(16, 16).realize() b = Tensor.randn(4).realize() c = Tensor.randn(16, ).realize() d = Tensor.randn(16, 16).realize() r = a.sum(1) + c out = r[:4] * b + d.sum(1)[:4] schedule = check_schedule(out, 1) run_schedule(schedule) np.testing.assert_allclose(out.numpy(), (a.numpy().sum(1) + c.numpy())[:4] * b.numpy() + d.numpy().sum(1)[:4], atol=1e-4, rtol=1e-4) def test_midreduce_nochase(self): a = Tensor.empty(16, 16) b = (a.sum(0) + a.max(1)) + 2 check_schedule(b, 1) def test_multireduce_midreduce_nochase(self): Tensor.manual_seed(0) a = Tensor.randn(16, 16).realize() b = (a.sum(0)+a.max(0) + a.max(1)+a.sum(1)) + 2 schedule = check_schedule(b, 1) run_schedule(schedule) np.testing.assert_allclose(b.numpy(), a.numpy().sum(0)+a.numpy().max(0) + a.numpy().max(1)+a.numpy().sum(1)+2, atol=1e-4, rtol=1e-4) # pattern in test_transformer def test_partial_fuse1(self): Tensor.manual_seed(0) a = Tensor.randn(16, 16).realize() b = Tensor.randn(16, 16).realize() c = a.sum() + 2 d = (a.sum() - b.sum()) * 4 # run_schedule(check_schedule([c, d], 1)) run_schedule(check_schedule([c, d], 2)) np.testing.assert_allclose(c.numpy(), a.numpy().sum()+2, atol=1e-4, rtol=1e-4) np.testing.assert_allclose(d.numpy(), (a.numpy().sum() - b.numpy().sum()) * 4, atol=1e-4, rtol=1e-4) # pattern in conv def test_partial_fuse2(self): Tensor.manual_seed(0) a = Tensor.randn(16, 16).realize() b = Tensor.randn(16, 16).realize() c = a.sum() + 2 d = b.sum() - c # run_schedule(check_schedule([c, d], 1)) run_schedule(check_schedule([c, d], 2)) np.testing.assert_allclose(c.numpy(), a.numpy().sum()+2, atol=1e-4, rtol=1e-4) np.testing.assert_allclose(d.numpy(), b.numpy().sum()-(a.numpy().sum()+2), atol=1e-4, rtol=1e-4) # pattern in adam def test_partial_fuse3(self): Tensor.manual_seed(0) a = Tensor.randn(16, 16).realize() b = Tensor.randn(16, 16).realize() c = a.sum() + 2 d = a.sum() * 2 e = c * d f = b.sum() - e # run_schedule(check_schedule([c, d, e, f], 1)) run_schedule(check_schedule([c, d, e, f], 4)) np.testing.assert_allclose(c.numpy(), c_np:=a.numpy().sum()+2, atol=1e-4, rtol=1e-4) np.testing.assert_allclose(d.numpy(), d_np:=a.numpy().sum()*2, atol=1e-4, rtol=1e-4) np.testing.assert_allclose(e.numpy(), e_np:=c_np*d_np, atol=1e-4, rtol=1e-4) np.testing.assert_allclose(f.numpy(), b.numpy().sum() - e_np, atol=1e-4, rtol=1e-4) def test_partial_fuse4(self): Tensor.manual_seed(0) a = Tensor.randn(16, 16).realize() b = Tensor.randn(16, 16).realize() c = a.sum() + 2 d = a.sum() * 2 e = c * d f = (b - d).sum() - e # run_schedule(check_schedule([c, d, e, f], 1)) run_schedule(check_schedule([c, d, e, f], 4)) np.testing.assert_allclose(c.numpy(), c_np:=a.numpy().sum()+2, atol=1e-4, rtol=1e-4) np.testing.assert_allclose(d.numpy(), d_np:=a.numpy().sum()*2, atol=1e-4, rtol=1e-4) np.testing.assert_allclose(e.numpy(), e_np:=c_np*d_np, atol=1e-4, rtol=1e-4) np.testing.assert_allclose(f.numpy(), (b.numpy()-d_np).sum()-e_np, atol=1e-4, rtol=1e-4) def test_pad_reduce_safe(self): Tensor.manual_seed(0) a = Tensor.rand(3, 4, 5).realize() b = Tensor.rand(3, 4, 5).realize() out = (a + b).pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum().contiguous() run_schedule(check_schedule(out, 1)) np.testing.assert_allclose(out.numpy(), np.pad(a.numpy()+b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-5, rtol=1e-6) def test_multireduce_pad_reduce_safe(self): Tensor.manual_seed(0) a = Tensor.randn(3, 4, 5).realize() b = Tensor.randn(3, 4, 5).realize() out = (a.pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum(keepdim=True)+b.pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum()).contiguous() run_schedule(check_schedule(out, 1)) np.testing.assert_allclose(out.numpy(), np.pad(a.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(keepdims=True) + \ np.pad(b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-4, rtol=1e-4) def test_pad_reduce_unsafe(self): Tensor.manual_seed(0) a = Tensor.rand(3, 4, 5).realize() out = a.log2().pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum().contiguous() run_schedule(check_schedule(out, 1)) np.testing.assert_allclose(out.numpy(), np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-5, rtol=1e-6) def test_multireduce_pad_reduce_unsafe(self): Tensor.manual_seed(0) a = Tensor.randn(3, 4, 5).abs().realize() b = Tensor.randn(3, 4, 5).abs().realize() out = (a.log2().pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum()+b).abs().log2().pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum().contiguous() # run_schedule(check_schedule(out, 1)) run_schedule(check_schedule(out, 2)) np.testing.assert_allclose(out.numpy(), np.pad(np.log2(np.abs(np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum() + \ b.numpy())), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=3e-4, rtol=1e-5) def test_shrink_pad_safe(self): a = Tensor.ones((3, )).contiguous().realize() b = Tensor.ones((3, )).contiguous().realize() out = (a + b).shrink(((0, 1),)).pad(((0, 1),)).contiguous() run_schedule(check_schedule(out, 1)) np.testing.assert_equal(out.numpy(), [2, 0]) def test_shrink_pad_unsafe(self): a = Tensor.ones((3, )).contiguous().realize() out = a.exp2().shrink(((0, 1),)).pad(((0, 1),)).contiguous() run_schedule(check_schedule(out, 1)) np.testing.assert_equal(out.numpy(), [2, 0]) def test_base_change_shrink_pad(self): a = Tensor.ones(3, 3).contiguous().realize() b = a.exp2() c = b[:-1, :-1] d = c.pad(((0, 1), (0, 1))) * 2 run_schedule(check_schedule(d, 1)) np.testing.assert_equal(d.numpy(), np.pad(np.exp2(a.numpy())[:-1, :-1], ((0, 1), (0, 1)))*2) def test_base_change_expand_pad(self): a = Tensor.ones(3, 3).contiguous().realize() b = a.exp2() c = b[:, None, :] d = c.pad(((0, 0), (1, 1), (0, 0))) * 2 run_schedule(check_schedule(d, 1)) np.testing.assert_equal(d.numpy(), np.pad(np.exp2(a.numpy())[:, None, :], ((0, 0), (1, 1), (0, 0)))*2) def test_fuse_arange_pad_replicate_mode(self): x = Tensor.empty(3,3,3,3, requires_grad=True) y = x.pad((-1,2,2,-1), mode="replicate") dx = y.sum().gradient(x)[0] sched = check_schedule(dx, 1) run_schedule(sched) np.testing.assert_allclose(dx.numpy(), [[[[0.,3.,9.],[0,1.,3.],[0.,0.,0.]]]*3]*3) def test_fuse_arange_avg_pool2d_ceil_mode(self): x = Tensor.avg_pool2d(Tensor.empty(1,1,6,6), kernel_size=(3,3), padding=1, stride=3, ceil_mode=True) sched = check_schedule(x, 1) self.assertEqual(len([x for x in sched[0].ast.backward_slice_with_self if x.op is Ops.REDUCE]), 1) def test_fuse_arange_pad_circular_mode_bw(self): x = Tensor.empty(1,1,5,5,5) out = x.pad((1,2,3,5,1,2), mode="circular") g = out.sum().gradient(x)[0] sched = check_schedule(g, 1) self.assertEqual(len([x for x in sched[0].ast.backward_slice_with_self if x.op is Ops.REDUCE]), 0) # TODO like openpilot with imagef @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") def test_base_change_expand_expand(self): a = Tensor.ones(4, 4).contiguous().realize() b = a.cast(dtypes.half).expand(2, 4, 4) c = b.cast(dtypes.int).expand(2, 2, 4, 4) run_schedule(check_schedule(c, 1)) np.testing.assert_equal(c.numpy(), np.ones(((2, 2, 4, 4)), dtype=np.int32)) def test_base_change_pad_expand(self): a = Tensor.full((4, 4), 1.).contiguous().realize() b = Tensor.full((4, 4), 2.).contiguous().realize() c = (a + b).pad(((1, 1), (1, 1))) d = c.cast(dtypes.int).expand((2, 6, 6)) * 4 run_schedule(check_schedule(d, 1)) c_np = np.pad((np.full((4, 4), 2., dtype=np.float32) + np.full((4, 4), 1., dtype=np.float32)), ((1, 1), (1, 1)), constant_values=0.0) np.testing.assert_equal(d.numpy(), np.broadcast_to(c_np.astype(np.half), (2, *c_np.shape)) * 4) def test_pad_reduce_unsafe_multiview_st(self): P = Tensor.ones(3, 3).contiguous() sums = P.sum(axis=1, keepdim=True) P /= sums p = P[0] p = p.pad(((1, 0), )) p = p.repeat([2]) run_schedule(check_schedule(p, 3)) tiny_ret = p.numpy() P = np.ones((3, 3), dtype=np.float32) sums = P.sum(axis=1, keepdims=True) P /= sums p = P[0] p = np.pad(p, (1, 0), 'constant') p = np.tile(p, 2) np.testing.assert_allclose(tiny_ret, p) def test_bitcast_fuses(self): x = Tensor.empty(1, dtype=dtypes.float32) a = x.exp2().bitcast(dtypes.int32) b = x.bitcast(dtypes.int32) check_schedule(a+b, 1) # this should fuse when it makes sense @unittest.skip("disabling subbuffer manually isn't supported anymore") def test_bitcast_disable_subbufer(self): x = cast(UOp, Tensor.empty(1, dtype=dtypes.float32).realize().uop) a = x.alu(Ops.EXP2).cast(dtypes.int32, True, allow_buffer_view=False) b = x.cast(dtypes.int32, True, allow_buffer_view=False) b = a.alu(Ops.ADD, b) check_schedule(b, 1) def test_reduceop_reshape_dont_push(self): Tensor.manual_seed(0) x = Tensor.randn(10, 20).realize() out = x.argmax(1) run_schedule(check_schedule(out, 2)) def test_conv2d(self): _test_conv2d(5 if SPLIT_REDUCEOP else 4) def test_conv2d_fused(self): _test_conv2d(5 if SPLIT_REDUCEOP else 4) def test_resnet_conv2d(self): x = Tensor.empty(1, 8, 32, 32) w1 = Tensor.empty(8, 8, 3, 3) w2 = Tensor.empty(8, 8, 1, 1) out = x.conv2d(w1).conv2d(w2) check_schedule(out, 2) @unittest.skipUnless(is_dtype_supported(dtypes.half) and is_dtype_supported(dtypes.ulong), "need half and ulong") def test_conv2d_half(self): _test_conv2d(5 if SPLIT_REDUCEOP else 4, dtype=dtypes.half) @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") @unittest.skipIf(Device.DEFAULT == "WEBGPU", "Causes other tests to fail") def test_conv2d_fused_half(self): _test_conv2d(5 if SPLIT_REDUCEOP else 4, dtype=dtypes.half) def test_schedule_mem_used(self): base = GlobalCounters.mem_used Tensor.ones(256).contiguous().realize() Tensor.ones(5, 5).contiguous().schedule() self.assertEqual(GlobalCounters.mem_used-base, 0) @unittest.skip("TODO: this is consistently creating non reproducible failures") def test_schedule_mem_used_with_inputs(self): base = GlobalCounters.mem_used x = Tensor.ones(256).contiguous().realize() (x+Tensor.ones(256).contiguous()).schedule() self.assertEqual(GlobalCounters.mem_used-base, 1024) def test_const_schedule(self): constv = Tensor.empty(2, 2).uop.const_like(10) check_schedule(constv, 0) def test_const_schedule_contig(self): constv = Tensor.empty(2, 2).uop.const_like(10).contiguous() check_schedule(constv, 1) @unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL") def test_image_matmul(self): with Context(IMAGE=2): x = Tensor.randn((9, 9)).realize() y = Tensor.randn((9, 9)).realize() out = x@y run_schedule(check_schedule(out, 3)) np.testing.assert_allclose(out.numpy(), x.numpy()@y.numpy(), atol=1e-4, rtol=1e-4) self.assertIsInstance(out.dtype, ImageDType) self.assertIsNotNone(out.uop.base.realized) self.assertIsInstance(out.uop.base.realized.dtype, ImageDType) def _test_fusion(self, shapes, f, cnt): with Context(DEBUG=0, TRACK_MATCH_STATS=0): args = [Tensor.randn(s).realize() for s in shapes] run_schedule(check_schedule(compare:=f(*args), cnt)) if getenv("COMPARE", 1): import torch good = f(*[torch.tensor(x.numpy()) for x in args]) np.testing.assert_allclose(compare.numpy(), good.numpy(), atol=1e-4, rtol=1e-4) def test_late_fusion_simple(self): self._test_fusion([(4, 4), (4, 1)], lambda a,b:a.sum(1, keepdim=True)+b, 1) def test_late_fusion_post_reshape(self): self._test_fusion([(4, 4), (1, 4)], lambda a,b:a.sum(1).reshape(b.shape)+b, 1) def test_late_fusion_post_permute(self): self._test_fusion([(4, 6, 4), (4, 4, 1)], lambda a,b:a.sum(1, keepdim=True).permute((2, 0, 1))+b, 1) def test_late_fusion_double_transpose(self): self._test_fusion([(32, 16, 1)], lambda a:(a.expand(32, 16, 16).sum((2,), keepdim=True).permute((1, 0, 2))+2).permute((1, 0, 2)).contiguous(), 1) def test_late_fusion_post_expand(self): self._test_fusion([(32, 32)], lambda a:a-a.sum(1), 2) def test_cast_padded_view(self): a = Tensor.arange(4).reshape(1, 4) casted_view = a.pad(((0, 1), (0, 0))).cast(dtypes.float) casted_view.realize() self.assertEqual(casted_view.uop.base.realized.size, 8) contig = casted_view.contiguous().realize() self.assertEqual(contig.uop.base.realized.size, 8) self.assertListEqual(contig.tolist(), [[0.0, 1.0, 2.0, 3.0], [0.0, 0.0, 0.0, 0.0]]) # NOTE: we only reorder CAST if it's an EXPAND def test_cast_after_shrink(self): a = Tensor.arange(4).reshape(1, 4) casted_view = a.shrink(((0, 1), (0, 2))).cast(dtypes.float) casted_view.realize() self.assertEqual(casted_view.uop.base.realized.size, 2) realized_view = casted_view.contiguous().realize() self.assertEqual(realized_view.uop.base.realized.size, 2) self.assertListEqual(realized_view.tolist(), [[0, 1]]) def test_cast_const_view(self): a = Tensor.ones((4, 4), dtype=dtypes.float32) casted_view = a.cast(dtypes.int32) run_schedule(check_schedule(casted_view, 0)) self.assertIsNone(casted_view.uop.base.realized) realized_const_view = casted_view.contiguous() run_schedule(check_schedule(realized_const_view, 1)) self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]) @given(strat.sampled_from(dtypes.all), strat.sampled_from(dtypes.all)) @unittest.skip("kernel count depends on input") def test_cast_padded_const(self, dt1, dt2): assume(is_dtype_supported(dt1) and is_dtype_supported(dt2)) a = Tensor(1, dtype=dt1).reshape(1, 1).pad(((1, 1), None)) casted_view = a.cast(dt2) run_schedule(check_schedule(casted_view, 0)) realized_const_view = casted_view.contiguous() run_schedule(check_schedule(realized_const_view, 1)) np.testing.assert_equal(realized_const_view.numpy(), [[0], [1], [0]]) def test_simple_indexing(self): X = Tensor.randn(10, 10).realize() idxs = Tensor([0, 2]).realize() xt = X[idxs] run_schedule(check_schedule(xt, 1)) np.testing.assert_equal(xt.numpy(), X.numpy()[idxs.numpy()]) def test_simple_indexing_alt(self): X = Tensor.arange(16).reshape(4, 4) xt = X[[1, 2], [-1, 2]] run_schedule(check_schedule(xt, 1)) np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [-1, 2]]) def test_advanced_indexing(self): X = Tensor.arange(10)+1 xt = X[[0, -1]] run_schedule(check_schedule(xt, 1)) np.testing.assert_equal(xt.numpy(), (np.arange(10)+1)[[0, -1]]) def test_advanced_indexing_alt(self): X = Tensor.arange(6).reshape(3, 2)+1 xt = X[[Tensor([2]), Tensor([1])]] run_schedule(check_schedule(xt, 1)) np.testing.assert_equal(xt.numpy(), 6) def test_advanced_simple_indexing_combined(self): X = Tensor.arange(16).reshape(4, 4) xt = X[1:2, [-1, 2]] run_schedule(check_schedule(xt, 1)) def test_push_through_reshape(self): Tensor.manual_seed(0) x = Tensor.randn(10, 20).realize() out = x.argmax(1) run_schedule(check_schedule(out, 2)) np.testing.assert_allclose(out.numpy(), np.argmax(x.numpy(), 1)) def test_arange_push_through_expand(self): Tensor.manual_seed(0) a = Tensor.arange(4,) b = Tensor.randn(4, 4).realize() out = (a+b).sum() run_schedule(check_schedule(out, 1)) np.testing.assert_allclose(out.numpy(), (np.arange(4)+b.numpy()).sum(), atol=1e-5) def test_argmin(self): Tensor.manual_seed(0) x = Tensor.randn(4, 32).realize() out = x.argmin(-1) run_schedule(check_schedule(out, 2)) np.testing.assert_equal(out.numpy(), x.numpy().argmin(axis=-1)) def test_argmax(self): Tensor.manual_seed(0) x = Tensor.randn(4, 32).realize() out = x.argmax(-1) run_schedule(check_schedule(out, 2)) np.testing.assert_equal(out.numpy(), x.numpy().argmax(axis=-1)) def test_arange_transposed(self): Tensor.manual_seed(0) x = Tensor.randint(4, 1).realize() a = ((Tensor.arange(4,)*x).T).sum() run_schedule(check_schedule(a, 1)) np.testing.assert_equal(a.numpy(), (np.arange(4)*x.numpy()).T.sum()) def test_div_padded_arange(self): x = Tensor.full((2,2), 16) y = x.idiv(Tensor.linspace(2, 8, steps=4, dtype=dtypes.int).reshape(2,2)).pad(((1,1), (1,1))) out = y.sum(axis=1) run_schedule(check_schedule(out, 1)) self.assertListEqual(out.tolist(), [0, 12, 4, 0]) def test_arange_transposed_descendants(self): Tensor.manual_seed(0) x = Tensor.randint(4, 1).realize() a = (Tensor.arange(4,)*x).T b = Tensor.randint(4, 4).realize() out = (a+b).sum() run_schedule(check_schedule(out, 1)) np.testing.assert_equal(out.numpy(), ((np.arange(4)*x.numpy()).T+b.numpy()).sum()) def test_arange_index(self): Tensor.manual_seed(0) x = Tensor.randn(5, 2).realize() a = Tensor.arange(10) out = (x + a[2]).sum() run_schedule(check_schedule(out, 1)) np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6) def test_arange_index_shrink(self): Tensor.manual_seed(0) with Context(TRACK_MATCH_STATS=0): x = Tensor.randn(11).realize() a = Tensor.arange(22) out = (x + a[:11]).sum() check_schedule(out, 1) def test_arange_index_contiguous(self): Tensor.manual_seed(0) x = Tensor.randn(5, 2).realize() a = Tensor.arange(10).contiguous() out = (x + a[2]).sum() run_schedule(check_schedule(out, 2)) np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6) def test_arange_index_child(self): Tensor.manual_seed(0) x = Tensor.randn(5, 2).realize() a = Tensor.arange(10)+1 out = (x + a[2]).sum() run_schedule(check_schedule(out, 1)) np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6) def test_user_contiguous(self): Tensor.manual_seed(0) x = Tensor.randn(5, 2).realize() a = (Tensor.arange(10)+1).contiguous() out = (x + a[2]).sum() run_schedule(check_schedule(out, 2)) np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6) @unittest.skip("BUFFER_VIEW no longer supported on non-disk devices") def test_arange_view_op(self): a = Tensor.arange(12).reshape(4, 3).shrink(((1, 2), (1, 3))).contiguous() sched = run_schedule(check_schedule(a, 1)) self.assertIs(sched[1].ast.op, Ops.BUFFER_VIEW) np.testing.assert_equal(a.numpy(), [[4, 5]]) @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") def test_precompute_freqs_cis(self): from extra.models.llama import precompute_freqs_cis args = {"dim":32 if CI else 128, "end":2048 if CI else 8192, "theta":10000} fused = precompute_freqs_cis(**args) run_schedule(check_schedule(fused, 1)) if getenv("CHECK", 1): ref = precompute_freqs_cis(**args) run_schedule(check_schedule(ref, 1)) np.testing.assert_equal(fused.numpy(), ref.numpy()) def test_fuse_assign_contiguous(self): x = Tensor.zeros(4, 4, dtype=dtypes.int).contiguous().realize() a = Tensor.arange(8).reshape(4, 2) run_schedule(check_schedule(x.shrink((None, (0, 2))).assign(a.contiguous()), 2)) np.testing.assert_equal(x.numpy(), [[0, 1, 0, 0], [2, 3, 0, 0], [4, 5, 0, 0], [6, 7, 0, 0]]) def test_assign_non_contiguous_alt(self): self.test_assign_non_contiguous(alt=True) def test_assign_non_contiguous(self, alt=False): x = (Tensor.arange(16)-100).reshape(4,4).contiguous().realize() xref = x.numpy() if alt: y = Tensor.randint(2, 4).contiguous().realize() a = Tensor.arange(8).reshape(2, 4)+y tst = x.shrink(((0, 2), None)).assign(a).realize() xref[:2, :] = np.arange(8).reshape(2, 4)+y.numpy() else: y = Tensor.randint(4, 2).contiguous().realize() a = Tensor.arange(8).reshape(4, 2)+y tst = x.shrink((None, (0, 2))).assign(a).realize() xref[:, :2] = np.arange(8).reshape(4, 2)+y.numpy() np.testing.assert_equal(x.numpy(), xref) np.testing.assert_equal(tst.numpy(), a.numpy()) def test_setitem_sched(self, mop=lambda x:x, expected_kcount=1): a = Tensor.arange(16, device="CPU").reshape(4, 4).contiguous().realize() a2 = mop(a) expected = (a+a2).tolist() a.assign(a+a2) kcount = len(sched:=a.schedule()) run_schedule(sched) self.assertListEqual(a.tolist(), expected) self.assertEqual(kcount, expected_kcount) def test_setitem_permuted_sched(self): self.test_setitem_sched(lambda x: x.T, 2) def test_setitem_paddded_sched(self): self.test_setitem_sched(lambda x: x.shrink_to(4, 1).pad_to(4, 4), 1) def test_sparse_categorical_crossentropy_simple(self): X = Tensor([[0, 2, 3], [1, 2, 3]]).realize() Y = Tensor([1, 2]).realize() loss = X.sparse_categorical_crossentropy(Y) run_schedule(check_schedule(loss, 3)) np.testing.assert_allclose(loss.item(), 0.878309, atol=1e-5, rtol=1e-6) def test_const_folding_alt(self): t = Tensor.full((2,), 1.) lt = (t < 0.) a = Tensor.empty(2).assign(t*lt.where(-1., 0.)) b = Tensor.empty(2, dtype=dtypes.bool).assign(lt) Tensor.realize(a, b) self.assertEqual(a.tolist(), [0., 0.]) self.assertEqual(b.tolist(), [False, False]) @unittest.skipIf(Device.DEFAULT == "WEBGPU", "Validation error on WebGPU") def test_mnist_val(self): from tinygrad.nn.datasets import mnist import torch _, Y_train, _, _ = mnist() samples = Tensor.randint(BS:=getenv("BS", 512), high=cast(int,Y_train.shape[-1])).realize() yt = Tensor.randn(BS, 10).realize() with Context(SPLIT_REDUCEOP=0): loss = yt.sparse_categorical_crossentropy(Y_train[samples]) run_schedule(check_schedule(loss, 4)) loss_fused = loss.numpy() loss_ref = torch.nn.CrossEntropyLoss()(torch.tensor(yt.numpy()), torch.tensor(Y_train.numpy())[torch.tensor(samples.numpy())]) np.testing.assert_allclose(loss_fused, loss_ref.numpy(), atol=1e-6, rtol=1e-6) def test_arange_fuse_grouped_children(self): X = Tensor.randn(4, 4).realize() r = (X+Tensor.arange(16).reshape(4, 4)).sum() out0 = r+2 out1 = r+3 run_schedule(check_schedule([out0, out1], 2)) # TODO: 1? r_ref = (X.numpy()+np.arange(16).reshape(4, 4)).sum() np.testing.assert_allclose(out0.numpy(), r_ref+2, rtol=2e-7) np.testing.assert_allclose(out1.numpy(), r_ref+3, rtol=2e-7) def test_recursive_swizzle(self): a = Tensor([1,2,3,4]).realize() for _ in range(24): a = a + a new_uop = a.reshape(4,1).realize().uop assert new_uop.base.op is Ops.BUFFER @unittest.skipIf(CI and Device.DEFAULT == "NV", "crashes on NV CI") def test_limit_bufs_with_var(self): N = 31 with Context(TRACK_MATCH_STATS=0, DEBUG=0): bufs = [Tensor([1]*10).contiguous().realize() for i in range(N)] vi = Variable("i", 0, 9).bind(1) vj = Variable("j", 0, 9).bind(2) root = bufs[0][vi] + bufs[0][vj] for X in range(1,N): root = root + bufs[X][vi] + bufs[X][vj] self.assertEqual(root.item(), N * 2) def test_limit_bufs_kernelize(self): N = 31 with Context(TRACK_MATCH_STATS=0, DEBUG=0): bufs = [Tensor(i).contiguous().realize() for i in range(N)] x = bufs[0] for y in bufs[1:]: x = x+y x.kernelize() kcount = len([s for s in x.uop.toposort() if s.op is Ops.KERNEL]) z = x+Tensor.empty(1) # z only loads 2 buffers sched = z.schedule() self.assertEqual(len(sched), kcount+1) class TestSwizzle(unittest.TestCase): def test_swizzle_simple(self): Tensor.manual_seed(0) with Context(DEBUG=0, TRACK_MATCH_STATS=0): a = Tensor.randint(32, 32).realize() r = (a+a).sum(1).sum(0) # double reduce collapses to a single reduce run_schedule(check_schedule(r, 1)) self.assertEqual(r.numpy(), (a.numpy()+a.numpy()).sum(1).sum(0)) def test_single_swizzle(self): Tensor.manual_seed(0) with Context(DEBUG=0, TRACK_MATCH_STATS=0): a = Tensor.randint(4, 1).realize() b = Tensor.ones((1, 1), dtype=a.dtype).contiguous().realize() # ADD(REDUCE(RESHAPE(LOAD)), LOAD) to ADD(REDUCE(RESHAPE(LOAD))), RESHAPE(LOAD) r = a.sum(0)+b run_schedule(check_schedule(r, 1)) self.assertEqual(r.numpy(), a.numpy().sum(0)+1) def test_double_swizzle_possible(self): Tensor.manual_seed(0) with Context(DEBUG=0, TRACK_MATCH_STATS=0): a = Tensor.randint(4,).realize() b = Tensor.randint(4,).realize() # parallel reduce! add = a.sum(0)+b.sum(0) run_schedule(check_schedule(add, 1)) self.assertEqual(add.numpy(), a.numpy().sum(0)+b.numpy().sum(0)) def test_softmax_one_kernel(self): Tensor.manual_seed(0) with Context(DEBUG=0, TRACK_MATCH_STATS=0): a = Tensor.randn(32, 32).realize() t = a.softmax() check_schedule(t, 3) # TODO: 1? def test_argmax_one_kernel(self): Tensor.manual_seed(0) with Context(DEBUG=0, TRACK_MATCH_STATS=0): a = Tensor.randn(10, 20).realize() t = a.argmax(0) check_schedule(t, 2) # TODO: 1? def test_swizzle_reduceop(self): Tensor.manual_seed(0) x = Tensor.randn(4,4).realize() y = Tensor.randn(4,4,4).realize() out = x.reshape(4,4,1).expand(4,4,4).sum(axis=(1,))+y run_schedule(check_schedule(out, 2)) # TODO: 1? np.testing.assert_allclose(out.numpy(), np.tile(x.numpy().reshape(4,4,1), (1,1,4)).sum(axis=1)+y.numpy()) def test_permute_rewrite(self): x = Tensor.randn(4, 4, 16).realize() y = Tensor.randn(4, 1, 16).realize() z = Tensor.randn(4, 4, 1).realize() t = (x*y).sum(axis=(0, 2)).reshape(1, 4, 1).permute(0, 2, 1)+z run_schedule(check_schedule(t, 2)) # TODO: 1? t_np = (x.numpy()*y.numpy()).sum(axis=(0, 2)).reshape(1, 4, 1).transpose(0, 2, 1)+z.numpy() np.testing.assert_allclose(t.numpy(), t_np, atol=1e-6, rtol=1e-3) @unittest.skip("TODO: this swizzle isn't resolvable when there's a mask") def test_swizzle_failure_permute(self): a = Tensor.empty(45,65).T.reshape(65,1,45).pad((None,None,(0,45))).expand(65,45,90) b = Tensor.empty(45,65) a_reduce = a.sum(axis=(2,), keepdim=True).sum(axis=(1,)) b_reduce = b.sum(axis=(0,)) t = a_reduce+b_reduce run_schedule(check_schedule(t, 1)) def test_parallel_reduce_possible(self): Tensor.manual_seed(0) x = Tensor.randn(4, 2, 2).realize() y = Tensor.randn(4, 2, 2).realize() t = x.sum(axis=1)+y.sum(axis=1) run_schedule(check_schedule(t, 1)) np.testing.assert_allclose(t.numpy(), x.numpy().sum(axis=1)+y.numpy().sum(axis=1), atol=1e-6, rtol=1e-3) # kernels can only have 1 or n in each dim def test_dont_parallelize_different_n(self): Tensor.manual_seed(0) x = Tensor.randn(4, 2, 2).realize() y = Tensor.randn(4, 3, 2).realize() t = x.sum(axis=1)+y.sum(axis=1) run_schedule(check_schedule(t, 1)) np.testing.assert_allclose(t.numpy(), x.numpy().sum(axis=1)+y.numpy().sum(axis=1), atol=1e-6, rtol=1e-3) def test_unsafe_pad(self): x = Tensor.full((2,2), 1.0).contiguous() y = x*x.sum((1,)).reciprocal() t = y.pad(((0,1),None)) run_schedule(check_schedule(t, 3)) np.testing.assert_equal(t.numpy(), [[0.5, 0.5], [0.5, 0.5], [0., 0.]]) zero_pm = UPat(Ops.CONST, arg=0) class TestView(unittest.TestCase): def test_all_masked_out(self): # start with non CONST Ops a = Tensor.rand(10, 10).realize() # all masked out, degrades to const 0 b = a.pad(((0, 10), None))[10:] sched = check_schedule(b.contiguous(), 1) run_schedule(sched) np.testing.assert_equal(b.numpy(), 0) def test_mask_dim_1(self): # mask out dim = 1 works too a = Tensor.rand(10, 10).realize() b = a.pad((None, (0, 10)))[:, 10:] assert b.shape == (10, 10) sched = check_schedule(b.contiguous(), 1) run_schedule(sched) np.testing.assert_equal(b.numpy(), 0) def test_zero_size_alt(self): a = Tensor.empty(135, 0, 9) b = a.pad(((0, 0), (0, 0), (18, 0))) check_schedule(b, 0) def test_partial_mask(self): # partial masked out does not degrade into CONST a = Tensor.rand(10, 10).realize() b = a.pad(((0, 5), None))[5:] assert b.shape == (10, 10) sched = check_schedule(b.contiguous(), 1) run_schedule(sched) np.testing.assert_allclose(b.numpy(), np.pad(a.numpy(), ((0, 5), (0, 0)))[5:]) # a*VIEW(x), where VIEW(x) = 0 # x collapses along with its children def test_parent_view_collapses(self): a = Tensor([1, 2]) b = Tensor.arange(3).contiguous() bv = b.pad(((0, 2),))[-2:] # this becomes a late a*0 late_mul = a*bv check_schedule(late_mul, 0) # the arange doesn't realize self.assertIsNone(b.uop.base.realized) # mul doesn't realize self.assertIsNone(late_mul.uop.base.realized) self.assertEqual(late_mul.tolist(), [0, 0]) # SINK has two branches: # a*VIEW(x), where VIEW(x) = 0 # x+2 # as long as one child realizes, x does not collapse def test_parent_multiple_children_no_collapse(self): a = Tensor([1, 2]) b = Tensor.arange(3).contiguous() bv = b.pad(((0, 2),))[-2:] late_mul = a*bv other_child = b+2 s = check_schedule([late_mul, other_child], 2) # the arange becomes a BUFFER self.assertIs(b.uop.base.op, Ops.BUFFER) # mul still collapses self.assertIs(late_mul.uop.base.op, Ops.CONST) run_schedule(s) self.assertEqual(other_child.tolist(), [2, 3, 4]) @unittest.skipIf(Device.DEFAULT == "CPU", "tests copy from another device to cpu") class TestCopyFolding(unittest.TestCase): def test_const_copy_is_free(self): b = Tensor(1).to("CPU") check_schedule(b, 0, filter_sink=False) assert b.item() == 1 def test_one_hot_with_copy(self): y = Tensor([1, 2, 3]).to("CPU") x = y.one_hot(10) check_schedule(x, 3, filter_sink=False) def test_const_copy_multi(self): x = Tensor.ones(1, device="CPU").to_(["CPU", "CPU:1"]) check_schedule(x, 0, filter_sink=False) self.assertEqual(x.item(), 1) def test_late_const_copy_folding(self): a = Tensor.arange(3).realize() zeros = Tensor.zeros(3).realize() b = (a*zeros).to("CPU") run_schedule(check_schedule(b, 0, filter_sink=False)) self.assertListEqual(b.tolist(), [0, 0, 0]) self.assertEqual(b.device, "CPU") def test_alu_after_copy(self): a = Tensor.ones((4,)).to("CPU") b = Tensor.empty(4, device="CPU") add = a+b assert all_same([x.device for x in add.uop.src]), f"ALU has different devices! {[x.device for x in add.src]}" add.kernelize() def test_alu_before_copy(self): buf = Tensor.ones(1).contiguous().realize() a = buf+1 b = a.to("CPU") self.assertListEqual(b.tolist(), [2.]) def test_copy_to_same_device(self): a = Tensor.empty(4).uop b = a.copy_to_device(a.device) check_schedule(b, 1, filter_sink=False) # TODO: 0? def test_copy_to_same_device_alt(self): a = Tensor.empty(4, 4).uop b = a.copy_to_device(a.device) check_schedule(b, 1, filter_sink=False) # TODO: 0? def test_copy_to_same_device_sched(self): a = Tensor.ones(4).contiguous().realize().uop.as_buf() t = Tensor(a.copy_to_device(a.device)) sched = t.schedule() assert len([s for s in sched if s.ast.op is Ops.COPY]) == 0 run_schedule(sched) assert t.uop.is_realized, f"didn't realize Tensor {t}" self.assertListEqual(t.tolist(), [1.,1.,1.,1.]) def test_clone(self): a = Tensor.empty(4) check_schedule(a.clone(), 1, filter_sink=False) def test_shrink_copy(self): a = Tensor.arange(4) view = a.shrink(((0, 2),)) b = view.clone() run_schedule(check_schedule(b, 1, filter_sink=False)) self.assertEqual(b.uop.base.buffer.size, 2) self.assertEqual(b.uop.size, 2) self.assertListEqual(b.tolist(), [0, 1]) def test_expanded_copy(self): a = Tensor.arange(2) view = a.reshape(2, 1).expand(2, 2) b = view.clone() run_schedule(check_schedule(b, 1, filter_sink=False)) self.assertEqual(b.uop.base.buffer.size, 4) self.assertEqual(b.uop.size, 4) self.assertListEqual(b.tolist(), [[0, 0], [1, 1]]) def test_permuted_copy(self): a = Tensor.arange(4) b = a.reshape(2, 2).permute(1, 0) b.realize() self.assertListEqual(b.tolist(), [[0, 2], [1, 3]]) def test_permute_on_disk(self): with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer()) a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}") b = a.reshape(2, 2).permute(1, 0).to("CPU") b.realize() self.assertListEqual(b.tolist(), [[0, 2], [1, 3]]) def test_permute_on_disk_contiguous(self): with open(temp('dt_arange_4_permute_contig'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer()) a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute_contig')}") b = a.reshape(2, 2).permute(1, 0).contiguous().to("CPU") b.realize() self.assertListEqual(b.tolist(), [[0, 2], [1, 3]]) def test_permute_after_shrink(self): a = Tensor.arange(5) b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CPU") b.realize() self.assertListEqual(b.tolist(), [[0, 2], [1, 3]]) # NOTE: disk permute must come after COPY def test_permute_after_shrink_on_disk(self): with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).realize().uop.base.buffer.as_buffer()) a = Tensor.empty(5, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_5_permute')}") b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CPU") b.realize() self.assertListEqual(b.tolist(), [[0, 2], [1, 3]]) class TestBufferUOp(unittest.TestCase): # BUFFER has a ShapeTracker of shape=(n,) and stride=(1,) def test_buffer_has_buffer(self): buf = Tensor.empty(10) self.assertIsNotNone(buf.uop.buffer) self.assertEqual(buf.uop.shape, (10,)) # the device Buffer remains unallocated until it's we run the schedule self.assertFalse(buf.uop.buffer.is_allocated()) add = buf+1 sched = add.schedule() self.assertFalse(buf.uop.buffer.is_allocated()) run_schedule(sched) self.assertTrue(buf.uop.buffer.is_allocated()) def test_buffer_has_unique_buffer(self): buf = Tensor.empty(10) buf1 = buf.uop.buffer buf2 = buf.uop.buffer self.assertIs(buf1, buf2) # we also allow VIEW(BUFFER) to access the underlying device Buffer, as long as it's contiguous def test_buffer_view_allowed(self): add = Tensor.empty(1, 1)+Tensor.empty(1, 1) add.realize() self.assertIsNotNone(add.uop.buffer) self.assertEqual(add.uop.shape, (1, 1)) def test_buffer_view_not_allowed(self): permuted_view = Tensor.empty(1, 2, 3).permute(0, 2, 1) with self.assertRaisesRegex(AssertionError, "can only be RESHAPE"): permuted_view.uop.buffer # cannot access Buffer of a non contiguous VIEW def test_buffer_only_after_realize(self): a = Tensor([1])+Tensor([2]) # accessing realized will return None self.assertIsNone(a.uop.realized) # accessing Buffer will assert with self.assertRaisesRegex(AssertionError, "must be BUFFER"): a.uop.buffer # there is no BUFFER on an unrealized ADD # Buffer only exists once we realize it a.realize() self.assertIsNotNone(a.uop.buffer) def test_const_does_not_realize(self): a = Tensor(1)+Tensor(2) run_schedule(check_schedule(a, 0)) self.assertIsNone(a.uop.base.realized) def test_var_does_not_realize(self): a = Tensor(UOp.variable("a", 0, 10).bind(1)) run_schedule(check_schedule(a, 0)) self.assertIsNone(a.uop.base.realized) def test_view_does_not_realize(self): a = Tensor.randn(1, 4).expand(4, 4) a.realize() self.assertEqual(a.uop.base.realized.size, 4) a2 = a.contiguous().realize() self.assertEqual(a2.uop.base.realized.size, 16) class TestContiguous(unittest.TestCase): def test_contiguous_buffer(self): a = Tensor.empty(4) b = a.contiguous() check_schedule(b, 0) def test_contiguous_buffer_view(self): a = Tensor.empty(4) b = a.reshape((2, 2)).contiguous() check_schedule(b, 0) def test_non_contiguous_buffer_view(self): a = Tensor.empty(4, 1) b = a.expand((4, 4)).contiguous() check_schedule(b, 1) def test_size_change_buffer_view(self): a = Tensor.empty(4) b = a.reshape((1, 1, 4)).shrink(((0, 1), (0, 1), (0, 3))).contiguous() check_schedule(b, 1) def test_double_contiguous_realizes_once(self): a = Tensor.empty(4, 1) b = a.expand((4, 4)).contiguous().contiguous() check_schedule(b, 1) def test_view_does_not_realize(self): a = Tensor.empty(4) b = a.expand((4, 4)) check_schedule(b, 0) self.assertEqual(b.uop.base.buffer.size, 4) def test_contiguous_view_realizes(self): a = Tensor.empty(4) b = a.expand((4, 4)).contiguous() check_schedule(b, 1) self.assertEqual(b.uop.base.buffer.size, 16) class TestUOpBecome(unittest.TestCase): # the simplest case, if we create a new BUFFER for this tensor UOp def test_new_buffer(self): a = Tensor.empty(4, 4) b = Tensor.empty(4, 4) add = a+b check_schedule(add, 1) # NOTE: realized base is always a flat buffer assert UPat(Ops.BUFFER).match(add.uop.base, {}) # the Tensor UOp can optionally stack a VIEW on top of the BUFFER, in this case to preserve the (4, 4) shape of the tensor assert add.uop is not add.uop.base self.assertEqual(add.uop.size, 16) self.assertEqual(add.uop.shape, (4, 4)) def test_new_buffer_view(self): a = Tensor.empty(4, 4) b = Tensor.empty(4, 4) add = (a+b).reshape(8, 2) check_schedule(add, 1) assert UPat(Ops.BUFFER).match(add.uop.base, {}) # the shape is preserverd in the becomes_map. self.assertEqual(add.uop.shape, (8, 2)) assert add.uop is not add.uop.base def test_new_flat_buffer(self): a = Tensor.empty(4,) b = Tensor.empty(4,) add = a+b check_schedule(add, 1) # BUFFER already has a shape (4,), this tensor just becomes a contiguous BUFFER assert UPat(Ops.BUFFER).match(add.uop.base, {}) # sometimes we prefer to perform an op before movement ops, in this case we should stack the mops on top of the new buffer def test_reorder_expand(self): a = Tensor.empty(4, 1) b = a.expand(4, 4).reciprocal() check_schedule(b, 1) self.assertEqual(b.uop.base.buffer.size, 4) self.assertEqual(b.uop.shape, (4, 4)) def test_reorder_expand_alt(self): x = Tensor.empty(4, 1) y = Tensor.empty(4, 1) img = Tensor.empty(4, 4) z = (img*x) / y check_schedule(z, 1) # TODO: rangeify doesn't yet cleanup this kind of re-indexing @unittest.expectedFailure def test_become_existing_buffer(self): a = Tensor.empty(4, 4) b = a*1 assert UPat(Ops.MUL).match(b.uop, {}) # before scheduling it's a mul check_schedule(b, 0) self.assertIs(a.uop.base.buffer, b.uop.base.buffer) def test_become_buf_with_mops(self): a = Tensor.empty(2, 4, 2) noop = a.shrink(((1, 2), (0, 4), (0, 2))).reshape(4, 2)*1+0 # before realizing, this tensor is base assert noop.uop is noop.uop.base noop.realize() # it becomes a realized view after realize assert noop.uop is not noop.uop.base assert noop.uop.base.op is Ops.BUFFER late_add = noop+2 late_add.realize() def test_become_const_in_base(self): a = Tensor.empty(4) b = a*0 assert UPat(Ops.MUL).match(b.uop, {}) # before scheduling it's a mul check_schedule(b, 0) assert UPat(Ops.CONST, arg=0).match(b.uop.base, {}) # scheduling replaces the tensor uop with a VIEW(BUFFER) def test_become_const_from_const(self): const_add = Tensor(1)+Tensor(2) assert UPat(Ops.ADD).match(const_add.uop, {}) check_schedule(const_add, 0) assert UPat(Ops.CONST, arg=3).match(const_add.uop.base, {}) # tensors can become another realized tensor source @unittest.expectedFailure def test_become_existing_buf_simple(self): a = Tensor.empty(4, 4) b = a+0 check_schedule(b, 0) assert b.uop.base.op is Ops.BUFFER self.assertIs(a.uop, b.uop) # they can also chain other movement ops on top of the tensor source @unittest.expectedFailure def test_become_existing_buf_view(self): a = Tensor.empty(4, 4) b = a.permute((1, 0))+0 check_schedule(b, 0) self.assertEqual(b.uop.st, a.uop.permute((1, 0)).st) @unittest.expectedFailure def test_become_existing_buf_view_alt(self): a = Tensor.empty(4, 4) b = a.permute((1, 0)).reshape((8, 2))+0 check_schedule(b, 0) self.assertEqual(b.uop.st, a.uop.permute((1, 0)).reshape((8, 2)).st) # they can also have other base parents that simplified, in that case we just backtrack to the chained mops @unittest.expectedFailure def test_become_existing_buf_complex(self): a = Tensor.empty(4, 4) b = (a.permute((1, 0))+0).reshape((8, 2))+0 check_schedule(b, 0) self.assertEqual(b.uop.st, a.uop.permute((1, 0)).reshape((8, 2)).st) assert b.uop.base.op is Ops.BUFFER @unittest.expectedFailure def test_become_multiple_choices(self): a = Tensor.empty(16) b = (a.reshape(1, 1, 4, 1, 4)+0).reshape(1, 1, 4, 4).shrink(((0, 1), (0, 1), (0, 3), (0, 3)))+0 c = (a.reshape(1, 1, 4, 4)+0).shrink(((0, 1), (0, 1), (0, 3), (0, 3)))+0 check_schedule([b, c], 0) assert all_same([x.uop.base.realized for x in [a,b,c]]) def test_setitem_becomes_subbuffer(self): a = Tensor.full((4,), 2.).contiguous().realize() b = a.shrink(((0, 2),)).assign(Tensor.full((2,), 1.0)) b.realize() assert a.uop.is_realized assert a.uop.buffer._base is None assert b.uop.op_in_backward_slice_with_self(Ops.SHRINK) assert b.uop.base is a.uop.base def test_setitem_offset(self): a = Tensor.full((16,), 0.).contiguous().realize() b = Tensor.full((16,), 1.).contiguous().realize() a_view = a[4:].reshape(3, 4).shrink(((0,2),(0,2))).reshape((4,)) b.shrink(((0,4),)).assign(a_view).realize() self.assertListEqual(b.tolist(), [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) if __name__ == '__main__': unittest.main(verbosity=2)