import unittest, functools, random
from typing import List
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes, Variable
from tinygrad.ops import Ops, UOp
from tinygrad.helpers import CI, getenv, prod, Context, OSX
from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule
from tinygrad.engine.multi import all_reduce
import numpy as np
from hypothesis import given, strategies as strat, settings
from tinygrad.device import is_dtype_supported

settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
settings.load_profile("my_profile")

d0 = f"{Device.DEFAULT}:0"
d1 = f"{Device.DEFAULT}:1"
d2 = f"{Device.DEFAULT}:2"
d3 = f"{Device.DEFAULT}:3"
d4 = f"{Device.DEFAULT}:4"
d5 = f"{Device.DEFAULT}:5"
devices_2 = (d1, d2)
devices_3 = (d1, d2, d3)
devices_4 = (d1, d2, d3, d4)
N = 128

# shard_x is "data parallel"
# shard_w is "model parallel"

def _test_allreduce(t:Tensor):
  aa = (t[0:64] + t[64:128] + t[128:192] + t[192:256]).repeat([4,1]).realize()
  ts = t.shard(devices_4, 0).realize()
  b = Tensor(UOp.multi(*all_reduce(Ops.ADD, ts.lazydata.src), axis=0))
  b.realize()
  return aa, b

@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"), "no GPU CI")
class TestMultiTensor(unittest.TestCase):
  def test_to(self):
    X = Tensor.ones(256).contiguous().realize()
    X.to_(devices_2)
    for lb in X.lazydata.src: assert lb.shape == (256,)
    (X + X).realize()

  def test_gradient(self):
    X = Tensor.ones(256).contiguous().realize()
    X.to_(devices_2)
    grad = X.sum().gradient(X)[0]
    grad.realize()

  def test_shard(self):
    X = Tensor.ones(256).contiguous().realize()
    X.shard_(devices_2, 0)
    for lb in X.lazydata.src: assert lb.shape == (128,)
    (X + X).realize()

  def test_shard_not_multiple(self):
    X = Tensor.ones(256).contiguous().realize()
    with self.assertRaises(RuntimeError):
      X.shard_(devices_3, 0)

  def test_tensor_from_multi(self):
    X = Tensor([1, 2], dtype=dtypes.int).shard_(devices_2, 0)
    Y = Tensor(X.lazydata)
    self.assertEqual(Y.device, Device.DEFAULT)
    np.testing.assert_equal(X.numpy(), Y.numpy())
    with self.assertRaises(AssertionError):
      _ = Tensor(X.lazydata, dtype=dtypes.float)

  def test_sharded_arange(self):
    sharded_arange = Tensor.arange(1000).shard(devices_2, 0)
    sharded_arange.realize()
    np.testing.assert_equal(sharded_arange.numpy(), np.arange(1000))

  def test_shard_no_recompile(self):
    X = Tensor.ones(256).contiguous().realize()
    X.shard_(devices_2, 0)
    out = (X + X)
    sched = out.schedule()
    names = []
    for si, ei in lower_schedule(sched):
      if isinstance(ei.prg, CompiledRunner): names.append(ei.prg.p.name)
      ei.run()
    self.assertEqual(len(set(names)), 3, "function was relinearized")

  @unittest.skip("this doesn't fold because shard_ calls contiguous on all lbs")
  def test_sharded_memory(self):
    # Buffer may be stuck in track_cross_buffer
    for x in (d0, d1, d2, d3, d4): Device[x].synchronize()
    mem_base = GlobalCounters.mem_used

    X = Tensor.ones(256).contiguous().realize()
    assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256, GlobalCounters.mem_used-mem_base
    X.shard_(devices_4).realize()
    for x in (d0, d1, d2, d3, d4): Device[x].synchronize()
    assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256 * 4, GlobalCounters.mem_used-mem_base

    X = Tensor.ones(256).contiguous().realize()
    assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256, GlobalCounters.mem_used-mem_base
    X.shard_(devices_4, axis=0).realize()
    for x in (d0, d1, d2, d3, d4): Device[x].synchronize()
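    # sharding on axis=0 splits the single copy of the data across the four devices, so total memory should stay at one copy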
    assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256, GlobalCounters.mem_used-mem_base

    X = Tensor.ones(256).realize()
    assert GlobalCounters.mem_used-mem_base == 0
    X.shard_(devices_4).realize()
    assert GlobalCounters.mem_used-mem_base == 0

    X = Tensor.ones(256).realize()
    assert GlobalCounters.mem_used-mem_base == 0
    X.shard_(devices_4, axis=0).realize()
    assert GlobalCounters.mem_used-mem_base == 0

  def test_shard_same_device(self):
    X = Tensor.ones(256).contiguous().realize()
    X.shard_((d1, X.device), 0)
    (X + X).realize()

  def test_shard_plus_one_sum(self):
    X = Tensor.ones(256).contiguous().realize()
    X.shard_((d1, d2), 0)
    (X + 1).sum().realize()

  def test_shard_plus_one_sum_d0(self):
    X = Tensor.ones(256).contiguous().realize()
    X.shard_((d0, d2), 0)
    (X + 1).sum().realize()

  def test_numpy(self):
    X = Tensor.ones(256)
    X.shard_((d1, d2), 0)
    np.testing.assert_allclose(X.numpy(), 1)

  def _test_simple_add_axis(self, shard_x, shard_w):
    X = Tensor.ones(256).contiguous().realize()
    W = Tensor.ones(256).contiguous().realize()
    X.shard_((d1, d2), shard_x)
    W.shard_((d1, d2), shard_w)
    O = X + W
    np.testing.assert_allclose(O.numpy(), 2)

  def test_simple_add(self): return self._test_simple_add_axis(None, None)
  def test_simple_add_X(self): return self._test_simple_add_axis(0, None)
  def test_simple_add_W(self): return self._test_simple_add_axis(None, 0)
  def test_simple_add_XW(self): return self._test_simple_add_axis(0, 0)

  def test_four_add(self):
    X = Tensor.ones(256, 256).contiguous().realize()
    W = Tensor.ones(256, 256).contiguous().realize()
    X.shard_(devices_4, 1)
    W.shard_(devices_4, None)
    O = X + W
    np.testing.assert_allclose(O.numpy(), 2)

  def test_elementwise_dtype(self):
    Tensor.manual_seed(0)
    X = Tensor.randn(8, 8).realize()
    W = Tensor.randn(8, 8).realize()
    X.shard_(devices_4, 0)
    W.shard_(devices_4, 0)
    O = X.shrink(((0, 2), None)) * W.shrink(((0, 2), None)) < 2
    np.testing.assert_allclose(O.numpy(), X.numpy()[0:2]*W.numpy()[0:2] < 2)

  @given(strat.sampled_from((4, 5)), strat.sampled_from((devices_2, devices_3)), strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)),
         strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)), strat.sampled_from((1, 0, -1)))
  def test_simple_reduce(self, N, devices, rop, shard_axis, reduce_axis, sign):
    N = N * len(devices)
    X = Tensor.rand(N*N).reshape(N, N).mul(sign)
    n = X.numpy()
    X.shard_(devices, shard_axis)
    f = {Ops.ADD: lambda x: x.sum(reduce_axis), Ops.MUL: lambda x: x.prod(reduce_axis), Ops.MAX: lambda x: x.max(reduce_axis)}[rop]
    fX = f(X)
    fn = f(n)
    np.testing.assert_allclose(fX.numpy(), fn, rtol=1e-6, atol=1e-6)

  def test_allreduce_naive(self):
    with Context(RING=0):
      a,b = _test_allreduce(Tensor.rand(256, 256))
      np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)

  def test_allreduce_ring(self):
    with Context(RING=2):
      a,b = _test_allreduce(Tensor.rand(256, 256))
      np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)

  def test_copy_jit(self):
    @TinyJit
    def copy_tensor(x:Tensor): return (x.to(f"{x.device.split(':')[0]}:1") + 1)
    for _ in range(5):
      t = Tensor.rand(256).realize()
      x = copy_tensor(t)
      np.testing.assert_equal((t+1).numpy(), x.numpy())

  def test_allreduce_naive_jit(self):
    with Context(RING=0):
      jit_allreduce = TinyJit(_test_allreduce)
      for _ in range(5):
        a,b = jit_allreduce(Tensor.rand(256, 256))
        np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)

  def test_allreduce_ring_jit(self):
    with Context(RING=2):
      jit_allreduce = TinyJit(_test_allreduce)
      for _ in range(5):
        a,b = jit_allreduce(Tensor.rand(256, 256))
        np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)

  @unittest.skip("slow")
  def test_fuzz_allreduce(self):
    random.seed(41)
    for it in range(100):
      for n in range(2, 4+1):
        shape = tuple([(n if i == 0 else 1) * random.randint(1, 10) for i in range(random.randint(1, 4))])
        t = Tensor.rand(shape).shard_(tuple([d0, d1, d2, d3][:n]), 0)
        with Context(RING=0):
          a = Tensor(UOp.multi(*all_reduce(Ops.ADD, t.lazydata.src), axis=0))
        with Context(RING=2):
          b = Tensor(UOp.multi(*all_reduce(Ops.ADD, t.lazydata.src), axis=0))
        diff = a - b
        mean_err = diff.reshape((prod(diff.shape),)).abs().mean().numpy()
        max_err = diff.reshape((prod(diff.shape),)).abs().max().numpy()
        assert mean_err < 1e-6, f"big mean error, iteration {it}_{n}"
        assert max_err < 1e-6, f"big max error, iteration {it}_{n}"

  def _test_matmul_shard_axis(self, shard_x, shard_w, device):
    X = Tensor.kaiming_uniform(N, N).realize()
    W = Tensor.kaiming_uniform(N, N).realize()
    Xs = X.shard(device, shard_x)
    Ws = W.shard(device, shard_w)
    O = (Xs@Ws)
    np.testing.assert_allclose(X.numpy() @ W.numpy(), O.to(Device.DEFAULT).numpy(), atol=1e-5)

  def _test_double_matmul_shard_axis(self, shard_x, shard_w, device):
    X = Tensor.kaiming_uniform(N, N).realize()
    W1 = Tensor.kaiming_uniform(N, N).realize()
    W2 = Tensor.kaiming_uniform(N, N).realize()
    Xs = X.shard(device, shard_x)
    W1s = W1.shard(device, shard_w)
    W2s = W2.shard(device, shard_w)
    O = (Xs@W1s)@W2s
    np.testing.assert_allclose((X.numpy() @ W1.numpy()) @ W2.numpy(), O.to(Device.DEFAULT).numpy(), atol=1e-5)

  def test_matmul_shard_none(self): return self._test_matmul_shard_axis(None, None, devices_2)
  def test_matmul_shard_X_0(self): return self._test_matmul_shard_axis(0, None, devices_2)
  def test_matmul_shard_X_1(self): return self._test_matmul_shard_axis(1, None, devices_2)
  def test_matmul_shard_W_0(self): return self._test_matmul_shard_axis(None, 0, devices_2)
  def test_matmul_shard_W_1(self): return self._test_matmul_shard_axis(None, 1, devices_2)

  def test_matmul_shard_0_0(self): return self._test_matmul_shard_axis(0, 0, devices_2)
  def test_matmul_shard_0_1(self): return self._test_matmul_shard_axis(0, 1, devices_2)
  def test_matmul_shard_1_0(self): return self._test_matmul_shard_axis(1, 0, devices_2)
  def test_matmul_shard_1_1(self): return self._test_matmul_shard_axis(1, 1, devices_2)

  def test_double_matmul_shard_X_0(self): return self._test_double_matmul_shard_axis(0, None, devices_2)
  def test_double_matmul_shard_X_1(self): return self._test_double_matmul_shard_axis(1, None, devices_2)
  def test_double_matmul_shard_W_0(self): return self._test_double_matmul_shard_axis(None, 0, devices_2)
  def test_double_matmul_shard_W_1(self): return self._test_double_matmul_shard_axis(None, 1, devices_2)

  def test_conv_data_shard(self):
    conv = nn.Conv2d(3, 16, 3, bias=False)
    for p in get_parameters(conv): p.shard_(devices_2)
    fake_image = Tensor.rand((2, 3, 32, 32)).shard(devices_2, axis=0)
    out = conv(fake_image)
    out.numpy()

  def test_conv_bias_data_shard(self):
    conv = nn.Conv2d(3, 16, 3)
    for p in get_parameters(conv): p.shard_(devices_2)
    fake_image = Tensor.rand((2, 3, 32, 32)).shard(devices_2, axis=0)
    out = conv(fake_image)
    out.numpy()

  def test_backprop_conv(self):
    with Tensor.train():
      conv = nn.Conv2d(3, 16, 3)
      for p in get_parameters(conv): p.shard_(devices_2)
      optim = nn.optim.Adam(get_parameters(conv))
      fake_image = Tensor.rand((2, 3, 32, 32)).shard(devices_2, axis=0)
      out = conv(fake_image)
      optim.zero_grad()
      out.mean().backward()
      #for p in get_parameters(conv): p.grad.realize()
      optim.step()
      out.numpy()
  def test_backprop_conv_wino(self):
    with Context(WINO=1): self.test_backprop_conv()

  def test_backward_sum(self):
    x = Tensor([[1.,2,3,4], [5,6,7,8]]).shard(devices_2, axis=0)
    w = Tensor([1.,2,3,4], requires_grad=True).shard(devices_2)
    out = x * w
    out.mean().backward()
    tst = w.grad.numpy()
    np.testing.assert_allclose(tst, [0.75, 1., 1.25, 1.5])

  def test_lr_scheduler_OneCycleLR(self):
    from extra.lr_scheduler import OneCycleLR
    conv = nn.Conv2d(3, 16, 3)
    for p in get_parameters(conv): p.shard_(devices_2)
    optim = nn.optim.SGD(get_parameters(conv))
    lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
    lr_sched.step()

  def test_embedding(self):
    B, T, embed_size, vocab_size = 4, 10, 20, 28

    layer = nn.Embedding(vocab_size, embed_size)
    x = Tensor(np.random.randint(0, vocab_size, (B, T), dtype=np.int32))
    z = layer(x)

    layer_sharded = nn.Embedding(vocab_size, embed_size)
    layer_sharded.weight.replace(layer.weight.shard(devices_2, axis=1)).realize()
    x_sharded = x.shard(devices_2, axis=None)
    z_shard = layer_sharded(x_sharded)

    np.testing.assert_allclose(z.numpy(), z_shard.numpy(), atol=1e-6, rtol=1e-6)

  def test_rmsnorm(self):
    B, T, embed_size = 4, 10, 20

    norm = nn.RMSNorm(embed_size)
    x = Tensor.rand((B, T, embed_size)).contiguous().realize()
    y = norm(x)

    # for norm layers, the correct way to shard weights is duplication
    norm_sharded = nn.RMSNorm(embed_size)
    norm_sharded.weight.shard_(devices_2, axis=None).realize()

    # if x is being sharded, then all-reduce is involved
    x_sharded = x.shard(devices_2, axis=2).realize()
    y_shard = norm_sharded(x_sharded).realize()
    np.testing.assert_allclose(y.numpy(), y_shard.numpy(), atol=1e-6, rtol=1e-6)

    # if x is being duplicated, then the operations remain inside each GPU
    # which is the common case
    x_sharded = x.shard(devices_2, axis=None).realize()
    y_shard = norm_sharded(x_sharded).realize()
    np.testing.assert_allclose(y.numpy(), y_shard.numpy(), atol=1e-6, rtol=1e-6)

  # NOTE: this is failing on LLVM CI, no idea why. Works locally.
  @unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV", "LLVM"), "slow")
  @unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers")
  def test_data_parallel_resnet(self):
    from extra.models.resnet import ResNet18

    fake_image = Tensor.rand((2, 3, 224//8, 224//8))
    fake_image_sharded = fake_image.shard(devices_2, axis=0)
    m = ResNet18()
    m.load_from_pretrained()
    real_output = m(fake_image).log_softmax().numpy()

    for p in get_parameters(m): p.shard_(devices_2).realize()
    GlobalCounters.reset()
    shard_output = m(fake_image_sharded).log_softmax().realize()
    assert shard_output.lazydata.src[0].shape == (1, 1000)
    assert shard_output.lazydata.src[1].shape == (1, 1000)
    shard_output_np = shard_output.numpy()
    np.testing.assert_allclose(real_output, shard_output_np, atol=1e-6, rtol=1e-6)

  @unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV", "LLVM"), "slow, and flaky on LLVM")
  @unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers")
  def test_data_parallel_resnet_train_step(self):
    from extra.models.resnet import ResNet18
    from tinygrad.nn.optim import LARS

    fake_image = Tensor.rand((2, 3, 224//8, 224//8))
    fake_image_sharded = fake_image.shard(devices_2, axis=0)
    labels = Tensor.randint(2, low=0, high=1000)
    labels_sharded = labels.shard(devices_2, axis=0)

    m = ResNet18()
    optimizer = LARS(get_parameters(m), 0.1)  # set requires_grad for all params

    optimizer.zero_grad()
    m.load_from_pretrained()
    output = m(fake_image).sparse_categorical_crossentropy(labels, label_smoothing=0.1)
    output.backward()
    grad = m.conv1.weight.grad.numpy()

    for p in get_parameters(m): p.shard_(devices_2).realize()
    GlobalCounters.reset()
    optimizer.zero_grad()
    shard_output = m(fake_image_sharded).sparse_categorical_crossentropy(labels_sharded, label_smoothing=0.1)
    shard_output.backward()
    shard_grad = m.conv1.weight.grad.numpy()
    # sometimes there are zeros in these grads... why?
    np.testing.assert_allclose(grad, shard_grad, atol=1e-5, rtol=1e-5)

  def test_assign_kv_cache_multi(self):
    bsz, max_context = 2, 8

    class Attn:
      @TinyJit
      def __call__(self, xk:Tensor, start_pos:UOp):
        seqlen = xk.shape[1]
        if not hasattr(self, "cache_k"):
          self.cache_k = Tensor.zeros(bsz, max_context, 1, 1).shard(devices_2).contiguous().realize()
        keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1).contiguous() if start_pos > 0 else xk
        self.cache_k.assign(keys.pad((None,(0,max_context-start_pos-seqlen),None,None)).contiguous()).realize()

    attn = Attn()
    xk = Tensor.ones(bsz, 3, 1, 1).shard(devices_2).contiguous()
    attn(xk, 0)
    for i in range(3,6):
      # copied from LLaMA
      start_pos = Variable("start_pos", 1, max_context).bind(i)
      xk = Tensor.ones(bsz, 1, 1, 1).shard(devices_2).contiguous()
      attn(xk, start_pos)

    out = attn.cache_k.flatten().numpy()
    np.testing.assert_allclose(out, [1.,1.,1.,1.,1.,1.,0.,0.,1.,1.,1.,1.,1.,1.,0.,0.])

  def test_multi_tensor_jit_param(self):
    @TinyJit
    def jf(a, b) -> Tensor:
      return (a + b).realize()

    for _ in range(5):
      a = Tensor.ones(256).contiguous().realize()
      b = Tensor.ones(256).contiguous().realize()
      a.shard_(devices_2)
      b.shard_(devices_2)
      c = jf(a, b)
      np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
    assert len(jf.jit_cache) > 0

  def test_multi_tensor_jit_body(self):
    @TinyJit
    def jf() -> Tensor:
      a = Tensor.ones(256).contiguous().realize()
      b = Tensor.ones(256).contiguous().realize()
      a.shard_(devices_2)
      b.shard_(devices_2)
      return (a + b).realize()

    for _ in range(5):
      r = jf()
      np.testing.assert_allclose(r.numpy(), np.ones(256)+np.ones(256), atol=1e-4, rtol=1e-5)
    assert len(jf.jit_cache) > 0

  #@unittest.skipIf(CI and Device.DEFAULT=="METAL", "no ICB in CI, creation of graph fails")
  @unittest.skip("test broken")
  def test_multi_device_jit_graph(self):
    if Device[d0].graph is None or Device[d1].graph is None: raise unittest.SkipTest("only test graphs")

    @TinyJit
    def jf(a: Tensor, b: Tensor, c: Tensor, d:Tensor):
      # Create 80 entries on device 0: 2 batches.
      for _ in range(40):
        a = ((a + b).realize() + (a * b).realize()).realize()
      # Create 80 entries on device 1: 2 batches.
      for _ in range(40):
        c = ((c + d).realize() + (c * d).realize()).realize()
      # Create a copy from device 0 to 1: 1 entry.
      a = a.to(d1).realize()
      # Creates one last entry on device 1: 1 batch.
      return (a + c).realize()

    a = Tensor.randn(10, 10, device=d0).realize()
    b = Tensor.randn(10, 10, device=d0).realize()
    c = Tensor.randn(10, 10, device=d1).realize()
    d = Tensor.randn(10, 10, device=d1).realize()

    ref = jf(a, b, c, d).numpy()
    for _ in range(5):
      o = jf(a, b, c, d).numpy()
      np.testing.assert_allclose(ref, o, atol=1e-4, rtol=1e-5)

    graph_d0 = Device[d0].graph.func if isinstance(Device[d0].graph, functools.partial) else Device[d0].graph
    graph_d1 = Device[d1].graph.func if isinstance(Device[d1].graph, functools.partial) else Device[d1].graph
    # Checking that 2 graphs per device, 1 copy and 1 last graph on device 1 are created.
    assert isinstance(jf.jit_cache[0].prg, graph_d0)
    assert isinstance(jf.jit_cache[1].prg, graph_d0)
    assert isinstance(jf.jit_cache[2].prg, graph_d1)
    assert isinstance(jf.jit_cache[3].prg, graph_d1)
    assert isinstance(jf.jit_cache[4].prg, BufferCopy)
    assert isinstance(jf.jit_cache[5].prg, graph_d1)

  @unittest.skip("no longer supports uneven shard")
  def test_uneven_shard(self):
    for N in range(1, 6):
      X = Tensor.rand(4, 1, 257).contiguous().realize()
      n = X.numpy()
      devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
      X.shard_(devices, 2)
      np.testing.assert_equal(X.numpy(), n)
      np.testing.assert_equal(X.reshape(2, 2, 257).numpy(), n.reshape((2, 2, 257)))
      np.testing.assert_equal(X.shrink(((0,2), (0, 1), (0,257))).numpy(), n[0:2, 0:1, 0:257])
      np.testing.assert_equal(X.expand((4, 4, 257)).numpy(), np.tile(n, (1, 4, 1)))
      np.testing.assert_equal(X.permute((0, 2, 1)).numpy(), np.transpose(n, (0, 2, 1)))

  @unittest.skip("no longer supports uneven shard")
  def test_uneven_multiple_zeros(self):
    for data in ([1, 2, 3, 4], [1, 2, 3], [1, 2], [1], []):
      for N in (1, 2, 3, 4):
        devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
        # make sure something is computed on each device
        X = ((Tensor(data).shard(devices, axis=0) + 1).realize() - 1).realize()
        np.testing.assert_equal(X.numpy(), data)

  @unittest.skip("no longer supports uneven shard")
  def test_uneven_shard_with_empty(self):
    N = 4
    X = Tensor.rand(16, 1, 3).contiguous().realize()
    np_x = X.numpy()
    devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))

    # test empty shard
    np.testing.assert_equal(X.shard(devices, 0).numpy(), np_x)

    # test reshape with empty shard
    np.testing.assert_equal(X.shard(devices, 0).reshape(8, 1, 6).numpy(), np_x.reshape(8, 1, 6))

  @unittest.skip("no longer supports uneven shard")
  def test_multiple_uneven_shard(self):
    N = 4
    X = Tensor.rand(4, 1, 257).contiguous().realize()
    Y = Tensor.rand(4, 1, 257).contiguous().realize()
    np_x, np_y = X.numpy(), Y.numpy()
    devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
    X.shard_(devices, 2)
    Y.shard_(devices, 2)
    np.testing.assert_equal(X.numpy(), np_x)
    np.testing.assert_equal(Y.numpy(), np_y)
    np.testing.assert_equal((X + Y).numpy(), np_x + np_y)

  def test_bn_ast_on_devices(self):
    t = Tensor.empty((16, 64, 112, 112)).shard(devices_4, axis=0)
    bn = nn.BatchNorm2d(64)
    for p in get_parameters(bn): p.shard_(devices_4).realize()

    out = bn(t)
    scheds = [sched for sched in out.schedule() if sched.bufs[0].device in devices_4 and sched.ast.op is not Ops.COPY]
    assert set(sched.bufs[0].device for sched in scheds) == set(devices_4), "should have ast on each shard device"
    asts = [sched.ast for sched in scheds]
    assert len(asts)
    # test case to show that ast can be different on devices
    # TODO: make ast identical on devices
    #assert len(set(asts)) == 4, len(asts)
    # for i, ast in enumerate(asts):
    #   print(f"{i} {ast}")

  def test_reshape_on_axis(self):
    t0 = Tensor.rand((26, 15, 7)).shard(devices_3, axis=1)

    # test split and rejoin to the right
    t1 = t0.reshape((26, 3, 5, 7))
    t2 = t0.reshape((26, 3, 35))
    t3 = t1.reshape((26, 15, 7))
    t4 = t2.reshape((26, 105,))
    for t in [t0, t1, t2, t3, t4]:
      assert t.lazydata.axis == 1
      np.testing.assert_allclose(t.numpy().flatten(), t0.numpy().flatten())

    # test shape-one axis
    t5 = t4.reshape((26, 1, 105))
    assert t5.lazydata.axis == 2
    np.testing.assert_allclose(t.numpy().flatten(), t5.numpy().flatten())

    # test split and rejoin to the right and reshape to the left
    t5 = t0.reshape((2, 13, 3, 5, 7))
    t6 = t0.reshape((13, 2, 3, 7, 5))
    t7 = t0.reshape((1, 13, 2, 3, 1, 7, 5))
    assert t5.lazydata.axis == 2
    assert t6.lazydata.axis == 2
    assert t7.lazydata.axis == 3
    np.testing.assert_allclose(t5.numpy().flatten(), t0.numpy().flatten())
    np.testing.assert_allclose(t6.numpy().flatten(), t0.numpy().flatten())
    np.testing.assert_allclose(t7.numpy().flatten(), t0.numpy().flatten())

    # test no left join
    with self.assertRaises((AssertionError, ValueError)):
      t0.reshape((26*15,7)).schedule()

  @unittest.skip("no longer supports uneven shard")
  def test_reshape_on_axis_uneven(self):
    def reshape_helper(t0, t, t_axis):
      assert t.lazydata.axis == t_axis
      np.testing.assert_allclose(t0.reshape(t.shape).numpy(), t.numpy())

    t0 = Tensor.rand((4, 42, 15)).shard(devices_3, axis=1, splits=[14, 7, 21])

    # ok to reshape as long as elements remain on same device
    reshape_helper(t0, t0.reshape(2, 2, 42, 3, 5), 2)
    # split to the right
    reshape_helper(t0, t0.reshape(2, 2, 6, 7, 15), 2)
    # split off and merge to the right
    reshape_helper(t0, t0.reshape(4, 6, 105), 1)
    # really blend the axes together
    reshape_helper(t0, t0.reshape(4, 30, 21), 1)
    # split off 1-shape
    reshape_helper(t0, t0.reshape(4, 1, 42, 15), 2)
    reshape_helper(t0, t0.reshape(4, 6, 1, 7, 15), 1)
    # assert if cannot maintain shard axis without moving items between devices
    with self.assertRaises(AssertionError): t0.reshape(4, 7, 6, 15)
    # assert for degenerate reshape
    with self.assertRaises(AssertionError): t0.reshape(4, 5, 7, 15)
    # assert for cannot maintain axis
    with self.assertRaises(AssertionError): t0.reshape(4, 3, 2, 7, 15)

  def test_mlb_assign_change_axis(self):
    t_none = Tensor.zeros((16, 16)).shard(devices_2).contiguous().realize()
    t_zero = Tensor.ones((16, 16)).shard(devices_2, axis=0)
    with self.assertRaises(AssertionError):
      # don't allow assigns that change axes
      t_none.assign(t_zero)
      t_none.schedule()

  def test_init_rand_with_multiple_devices_fail(self):
    # init rand with multi device is not allowed
    with self.assertRaises(ValueError):
      Tensor.rand(256, device=devices_2)

  def test_rand_on_multiple_devices(self):
    # different devices generate different rand
    d0_rand = Tensor.rand(256, device=d0).realize()
    d1_rand = Tensor.rand(256, device=d1).realize()
    assert not np.allclose(d0_rand.numpy(), d1_rand.numpy())

  def test_rand_on_multiple_devices_manual_seed(self):
    Tensor.manual_seed(123)
    d0_rand = Tensor.rand(2, device=d0).tolist()
    d1_rand = Tensor.rand(2, device=d1).tolist()

    # manual_seed again gives the same values
    Tensor.manual_seed(123)
    d0_rand2 = Tensor.rand(2, device=d0).tolist()
    d1_rand2 = Tensor.rand(2, device=d1).tolist()
    self.assertEqual(d0_rand, d0_rand2)
    self.assertEqual(d1_rand, d1_rand2)

    # device seed is only determined by init order, so flipping init order flips rands
    Tensor.manual_seed(123)
    d1_rand_flip = Tensor.rand(2, device=d1).tolist()
    d0_rand_flip = Tensor.rand(2, device=d0).tolist()
    self.assertEqual(d0_rand, d1_rand_flip)
    self.assertEqual(d1_rand, d0_rand_flip)

  def test_rand_like_on_shard(self):
    t = Tensor.empty((16, 16)).shard(devices_2)
    t2 = Tensor.rand_like(t)
    self.assertEqual(t.shape, t2.shape)
    self.assertEqual(t.device, t2.device)
    self.assertEqual(t.dtype, t2.dtype)
    self.assertEqual(t.lazydata.axis, t2.lazydata.axis)

  def test_rand_like_from_alu(self):
    a = Tensor.ones(4, 4).shard(devices_4, axis=0)
    aa = a + a
    self.assertEqual(aa.device, devices_4)
    self.assertEqual(aa.lazydata.axis, 0)
    raa = aa.rand_like()
    self.assertEqual(raa.device, devices_4)
    self.assertEqual(raa.lazydata.axis, 0)

    b = Tensor.empty(4, 4).shard(devices_4, axis=None)
    ab = a + b
    self.assertEqual(ab.device, devices_4)
    self.assertEqual(ab.lazydata.axis, 0)
    rab = ab.rand_like()
    self.assertEqual(rab.device, devices_4)
    self.assertEqual(rab.lazydata.axis, 0)

  @unittest.skip("no longer supports uneven shard")
  def test_rand_like_uneven_shard(self):
    t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1)
    t2 = Tensor.rand_like(t)
    self.assertEqual(t.shape, t2.shape)
    self.assertEqual(t.device, t2.device)
    self.assertEqual(t.dtype, t2.dtype)
    self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
    assert all(tlb.shape == t2lb.shape for tlb, t2lb in zip(t.lazydata.src, t2.lazydata.src))

  def test_rand_like_none_shard(self):
    t = Tensor.empty((16, 16)).shard(devices_2)
    t2 = Tensor.rand_like(t)
    self.assertEqual(t.shape, t2.shape)
    self.assertEqual(t.device, t2.device)
    self.assertEqual(t.dtype, t2.dtype)
    self.assertEqual(t.lazydata.axis, t2.lazydata.axis)

  def test_rand_like_arg_dtype(self):
    t = Tensor.empty((16, 16), dtype=dtypes.int32).shard(devices_2, axis=1)
    t2 = Tensor.rand_like(t, dtype=dtypes.float32)
    self.assertEqual(t.dtype, dtypes.int32)
    self.assertEqual(t2.dtype, dtypes.float32)

  def test_rand_like_arg_device(self):
    # axis=None
    t = Tensor.empty((16, 16)).shard((d1, d2), axis=None)
    with self.assertRaises(RuntimeError):
      Tensor.rand_like(t, device=(d3, d4))

    # axis=1
    t = Tensor.empty((16, 16)).shard((d1, d2), axis=1)
    with self.assertRaises(RuntimeError):
      Tensor.rand_like(t, device=(d3, d4))

  def test_dropout_on_shard(self):
    with Tensor.train():
      X = Tensor.ones(256).to(devices_2)
      output = X.dropout(0.5).numpy()
      unique, counts = np.unique(output, return_counts=True)
      assert set(unique) == {0, 2}, unique
      assert 100 < counts[0] < 156, counts[0]

  def test_dropout_on_shard_axis(self):
    with Tensor.train():
      X = Tensor.ones(512).shard(devices_2, axis=0)
      output = X.dropout(0.5).numpy()
      unique, counts = np.unique(output, return_counts=True)
      assert set(unique) == {0, 2}, unique
      assert 200 < counts[0] < 312, counts[0]

  @unittest.skip("no longer supports uneven shard")
  def test_dropout_on_uneven_shard_axis(self):
    with Tensor.train():
      X = Tensor.ones(256).shard(devices_3, axis=0)
      output = X.dropout(0.5).numpy()
      unique, counts = np.unique(output, return_counts=True)
      assert set(unique) == {0, 2}, unique
      assert 100 < counts[0] < 156, counts[0]

  @unittest.skip("test depends on UOp order. TODO: fix it")
  def test_broadcast_const(self):
    for axis in (None, 0, 1):
      t = Tensor.zeros(16, 16).contiguous().shard(devices_4, axis).realize()
      t = t + 1
      for si in t.schedule():
        ast = si.ast.src[0]
        assert ast.op is Ops.STORE
        assert ast.src[2].op is Ops.ADD
        assert ast.src[2].src[0].op is Ops.LOAD
        assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 1
      t = 2 * t
      for si in t.schedule():
        ast = si.ast.src[0]
        assert ast.op is Ops.STORE
        assert ast.src[2].op is Ops.MUL
        assert ast.src[2].src[0].src[1].op is Ops.CONST and ast.src[2].src[0].src[1].arg == 2
        assert ast.src[2].src[1].op is Ops.LOAD
      t = t + t.full_like(3)
      for si in t.schedule():
        ast = si.ast.src[0]
        assert ast.op is Ops.STORE
        assert ast.src[2].op is Ops.ADD
        assert ast.src[2].src[0].op is Ops.LOAD
        assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 3

  @unittest.skip("TODO: this requires forced_realize to be deleted.")
  def test_shard_memory(self):
    devices = (d0, d1, d2, d3)
    t = Tensor.zeros(16, 16).contiguous()
    t.shard_(devices, axis=0).realize()
    assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.lazydata.src])

  @unittest.skip("this is unreliable on OSX")
  def test_clone(self):
    t = Tensor.rand(16, 16).shard(devices_2, axis=None)
    np.testing.assert_allclose(t.numpy(), t.clone().numpy())

    t = Tensor.rand(16, 16).shard(devices_2, axis=0)
    np.testing.assert_allclose(t.numpy(), t.clone().numpy())

  def test_multi_const_folding(self):
    with Context(TRACK_MATCH_STATS=0):
      a = Tensor.arange(3).realize()
      zeros = Tensor.zeros(3).realize()
      b = a.to(devices_2)*zeros.to(devices_2)
      sched = b.schedule()
      self.assertEqual(len(sched), 6)
      # notably, only two copies (for the arange) - vs 4 copies if we didn't fold the const copy
      self.assertEqual(len([x for x in sched if any(u.op is Ops.COPY for u in x.ast.toposort)]), 2)
      run_schedule(sched)
      self.assertListEqual(b.tolist(), [0, 0, 0])

  @unittest.expectedFailure
  def test_dont_realize_intermediate_expand(self):
    a = Tensor.empty(16, 1).shard_(devices_2, axis=0)
    b = Tensor.empty(16, 16).to_(devices_2)
    c = Tensor.empty(16, 16).shard_(devices_2, axis=1)
    d = a+b
    (d*c).realize()
    assert not d.lazydata.is_realized

@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"), "no GPU CI")
class TestHandleData(unittest.TestCase):
  def test_copied_to_device(self):
    device = (d0, d1, d2, d3)
    t = Tensor([1, 2, 3, 4]).shard(device).realize()
    not_covered = t.to(d5)
    sched = not_covered.schedule()
    assert len(sched) == 1
    # setup again because create_schedule has side effect
    t = Tensor([1, 2, 3, 4]).shard(device).realize()
    not_covered = t.to(d5)
    assert not_covered.realize().tolist() == [1, 2, 3, 4]

    for d in device:
      t = Tensor([1, 2, 3, 4]).shard(device).realize()
      covered = t.to(d)
      sched = covered.schedule()
      assert len(sched) == 0
      # setup again because create_schedule has side effect
      t = Tensor([1, 2, 3, 4]).shard(device).realize()
      covered = t.to(d)
      assert covered.realize().tolist() == [1, 2, 3, 4]

@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"), "no GPU CI")
class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
  # shrink a multitensor on sharded axis
  def test_shrink_bad_args(self):
    t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
    t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)

    with self.assertRaises(AssertionError):
      # sharded axis shrink on non-device boundary is not allowed
      a = t.shrink(((0, 3), (0, 8)))
      a.schedule()
    with self.assertRaises(AssertionError):
      # cannot shrink sharded and non-sharded axis at the same time
      a = t.shrink(((0, 2), (2, 4)))
      a.schedule()

    a = t.shrink(((0, 2), (0, 8)))
    a.schedule()
    assert a.shape == (2, 8)
    assert a.lazydata.real == (True, False, False, False)
    with self.assertRaises(AssertionError):
      # cannot pad sharded and non-sharded axis at the same time
      p = a.pad(((0, 6), (0, 1)))
      p.schedule()
    with self.assertRaises(AssertionError):
      # can only pad to whole axis
      p = a.pad(((1, 5), (0, 0)))
      p.schedule()
    p = a.pad(((0, 6), (0, 0)))
    p.schedule()
    assert p.shape == (8, 8)
    assert p.lazydata.real == (True, True, True, True)

  @given(strat.sampled_from([dtypes.float, dtypes.int, dtypes.int64, dtypes.int16]))
  def test_ops(self, dtype):
    if not is_dtype_supported(dtype): return
    t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
    t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)
    for i in range(4):
      print(f"{i=}")
      a = t.shrink(((0+2*i,2+2*i),None))
      b = Tensor(t.numpy()[0+2*i:2+2*i])
      assert a.shape == b.shape == (2, 8)
      np.testing.assert_allclose(a.numpy(), b.numpy())
      assert a.lazydata.real == tuple(i==j for j in range(4))
      # cast
      np.testing.assert_allclose(a.float().numpy(), b.float().numpy())

      # elementwise
      np.testing.assert_allclose(a.exp().numpy(), b.exp().numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.reciprocal().numpy(), b.reciprocal().numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.pow(-0.5).numpy(), b.pow(-0.5).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose((a+a).numpy(), (b+b).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_equal((a+1).numpy(), (b+1).numpy())
      np.testing.assert_equal((1+a).numpy(), (1+b).numpy())
      np.testing.assert_allclose((a.where(a+a, a)).numpy(), (b.where(b+b, b)).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose((a.where(1, 0)).numpy(), (b.where(1, 0)).numpy(), rtol=1e-7, atol=1e-3)

      # reduce
      np.testing.assert_allclose(a.max().numpy(), b.max().numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.sum().numpy(), b.sum().numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.mean().numpy(), b.mean().numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.max(0).numpy(), b.max(0).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.sum(0).numpy(), b.sum(0).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.mean(0).numpy(), b.mean(0).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.max(1).numpy(), b.max(1).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.sum(1).numpy(), b.sum(1).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.mean(1).numpy(), b.mean(1).numpy(), rtol=1e-7, atol=1e-3)

      # pad it back
      np.testing.assert_allclose(a.pad(((2*i, 2*(4-i-1)), None)).numpy(), b.pad(((2*i, 2*(4-i-1)), None)).numpy(), rtol=1e-7, atol=1e-3)

      # other movement
      np.testing.assert_allclose(a.pad((None, (1, 1))).numpy(), b.pad((None, (1, 1))).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.shrink((None, (1, 3))).numpy(), b.shrink((None, (1, 3))).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.permute((1, 0)).numpy(), b.permute((1, 0)).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.reshape((2, 2, 4)).numpy(), b.reshape((2, 2, 4)).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), b.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.flip(-1).numpy(), b.flip(-1).numpy(), rtol=1e-7, atol=1e-3)

  @unittest.skip("no longer supports uneven shard")
  def test_uneven(self):
    t = Tensor.arange(24).reshape(3, 8).contiguous().realize()
    t.shard_([f"{Device.DEFAULT}:{i}" for i in range(2)], axis=0)

    a = t.shrink(((0, 2), None))
    b = t.shrink(((2, 3), None))
    na = t.numpy()[0:2]
    nb = t.numpy()[2:3]
    np.testing.assert_equal(a.numpy(), na)
    np.testing.assert_equal(b.numpy(), nb)
    np.testing.assert_equal((a+1).numpy(), na+1)
    np.testing.assert_equal((b+1).numpy(), nb+1)
    np.testing.assert_equal((1+a).numpy(), 1+na)
    np.testing.assert_equal((1+b).numpy(), 1+nb)
    np.testing.assert_equal((a+a).numpy(), na+na)
    np.testing.assert_equal((b+b).numpy(), nb+nb)

  @unittest.skip("why didn't this work?")
  def test_add_two_partitions(self):
    t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
    t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)

    a = t.shrink(((2, 4), None))
    b = t.shrink(((6, 8), None))
    na = t.numpy()[2:4]
    nb = t.numpy()[6:8]
    np.testing.assert_equal(a.numpy(), na)
    np.testing.assert_equal(b.numpy(), nb)
    self.assertEqual(a.lazydata.real, (False, True, False, False))
    self.assertEqual(b.lazydata.real, (False, False, False, True))
    with self.assertRaises(AssertionError):
      # cannot add directly
      c = a + b
      c.schedule()

    c = a.pad(((2, 4), None)) + b.pad(((6, 0), None))
    c.realize()
    self.assertEqual(c.lazydata.real, (True, True, True, True))
    expected = np.concatenate([np.zeros_like(t.numpy()[0:2]), na, np.zeros_like(t.numpy()[4:6]), nb])
    np.testing.assert_equal(c.numpy(), expected)

  def test_add_different_tensors(self):
    devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
    x = Tensor.arange(64).reshape(8, 8).contiguous().realize().shard(devices, axis=0)

    to_add = []
    for i in range(len(devices)):
      to_add.append((Tensor.ones(2, 8) * i).shard(devices))

    added:List[Tensor] = []
    for bound, a in zip(x.lazydata.bounds, to_add):
      added.append(x[bound[0]:bound[1]] + a)

    output = added[0].cat(*added[1:])
    expected = np.arange(64).reshape((8,8)) + np.array([[0,0,1,1,2,2,3,3] for _ in range(8)]).T
    np.testing.assert_allclose(output.numpy(), expected)

@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"), "no GPU CI")
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers")
class TestBatchNorm(unittest.TestCase):
  def test_unsynced_backprop_conv_bn(self):
    with Tensor.train():
      from extra.lr_scheduler import OneCycleLR
      convs = [nn.Conv2d(3, 16, 3), nn.Conv2d(3, 16, 3)]
      bns = [nn.BatchNorm2d(16), nn.BatchNorm2d(16)]

      for p in get_parameters(convs + bns): p.shard_((d1, d2))
      optim = nn.optim.Adam(get_parameters(convs + bns))
      lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
      lr_sched.step()

      fake_image = Tensor.rand((8, 3, 32, 32)).shard((d1, d2), axis=0)
      f1 = fake_image.shrink(((0, 4), None, None, None))
      f2 = fake_image.shrink(((4, 8), None, None, None))

      out1 = bns[0](convs[0](f1))
      out2 = bns[1](convs[1](f2))
      out = out1.cat(out2)
      optim.zero_grad()
      out.mean().backward()
      optim.step()
      out.numpy()

  @unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers")
  def test_unsynced_backprop_standalone_bn(self):
    from extra.lr_scheduler import OneCycleLR
    GPUS = (d1, d2)

    class BatchNorm:
      def __init__(self, num_features):
        self.bns:List[nn.BatchNorm2d] = []
        for _ in GPUS:
          bn = nn.BatchNorm2d(num_features, track_running_stats=False, eps=1e-12, momentum=0.85, affine=True)
          self.bns.append(bn)

      def __call__(self, x:Tensor):
        bn_ts = []
        each = x.shape[0]//len(self.bns)
        for i, bn in enumerate(self.bns):
          xi = x.shrink(((each*(i), each*(i+1)), None, None, None))
          bni = bn(xi)
          bn_ts.append(bni)
        return bn_ts[0].cat(*bn_ts[1:])
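    # each device's slice of the batch goes through its own BatchNorm2d above, so batch statistics are never synced across GPUs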
    with Tensor.train():
      conv = nn.Conv2d(3, 16, 3)
      bn = BatchNorm(16)
      for p in get_parameters([conv, bn]): p.shard_(GPUS)
      optim = nn.optim.Adam(get_parameters([conv, bn]))
      lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
      lr_sched.step()

      fake_image = Tensor.rand((8, 3, 32, 32)).shard(GPUS, axis=0)
      out = bn(conv(fake_image))
      optim.zero_grad()
      out.mean().backward()
      optim.step()

  def test_unsynced_backprop_sync_weights(self):
    from extra.lr_scheduler import OneCycleLR
    from examples.hlb_cifar10 import UnsyncedBatchNorm
    GPUS = (d1, d2)

    with Tensor.train():
      conv = nn.Conv2d(3, 16, 3)
      bn = UnsyncedBatchNorm(16, num_devices=len(GPUS))

      for k, p in get_state_dict([conv, bn]).items():
        if 'running_mean' in k or 'running_var' in k:
          p.shard_(GPUS, axis=0)
        else:
          p.to_(GPUS)
      optim = nn.optim.Adam(get_parameters([conv, bn]))
      lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
      lr_sched.step()

      fake_image = Tensor.rand((8, 3, 32, 32)).shard(GPUS, axis=0)
      out = bn(conv(fake_image))
      optim.zero_grad()
      out.mean().backward()
      optim.step()

  @given(strat.sampled_from((False, True)))
  def test_batchnorm(self, is_training):
    devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
    x = Tensor.arange(4096).reshape(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)

    with Tensor.train(is_training):
      bns = []
      for _ in range(len(devices)):
        bn = nn.BatchNorm2d(8)
        for p in get_parameters(bn): p.shard_(devices)
        bn.weight.requires_grad = True
        bn.bias.requires_grad = True
        bns.append(bn)

      bn_ts = []
      for bound, bn in zip(x.lazydata.bounds, bns):
        bni = bn(x[bound[0]:bound[1]])
        bn_ts.append(bni)

      bn_ts[0].cat(*bn_ts[1:]).numpy()

  def test_synced_vs_unsynced_bn(self):
    from examples.hlb_cifar10 import UnsyncedBatchNorm
    from tinygrad.nn import BatchNorm2d
    devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
    x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)

    with Tensor.train():
      synced_bn = BatchNorm2d(8)
      unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))

      for p in get_parameters(synced_bn): p.shard_(devices)
      for k, p in get_state_dict(unsynced_bn).items():
        if 'running_mean' in k or 'running_var' in k:
          p.shard_(devices, axis=0)
        else:
          p.to_(devices)

      synced_out = synced_bn(x)
      synced_si = list(synced_out.schedule())
      unsynced_out = unsynced_bn(x)
      unsynced_si = list(unsynced_out.schedule())

    # TODO: test synced / unsynced batchnorm cross device kernel and copies
    assert synced_si
    assert unsynced_si

def helper_test_shard_op(shps, fxn, atol=1e-6, rtol=1e-3):
  for shp in shps:
    single_in = Tensor.randn(shp)
    multi_in = single_in.shard(devices_2, axis=0)

    single_out = fxn(single_in).numpy()
    multi_out = fxn(multi_in).numpy()

    try:
      assert single_out.shape == multi_out.shape, f"shape mismatch: single={single_out.shape} | multi={multi_out.shape}"
      assert single_out.dtype == multi_out.dtype, f"dtype mismatch: single={single_out.dtype} | multi={multi_out.dtype}"
      np.testing.assert_allclose(single_out, multi_out, atol=atol, rtol=rtol)
    except Exception as e:
      raise Exception(f"Failed shape {single_out.shape}: {e}")

@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"), "no GPU CI")
class TestTensorOps(unittest.TestCase):
  def test_interpolate(self):
    helper_test_shard_op([(4,16,16),(4,24,24)], lambda x: Tensor.interpolate(x, (19,19)))

  def test_bitcast(self):
    helper_test_shard_op([(256,), (256,)], lambda x: x.bitcast(dtypes.int))

if __name__ == '__main__':
  unittest.main()