openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1123 lines
44 KiB

import unittest, functools, random
from typing import List
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes, Variable
from tinygrad.ops import Ops, UOp
from tinygrad.helpers import CI, getenv, prod, Context, OSX
from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule
from tinygrad.engine.multi import all_reduce
import numpy as np
from hypothesis import given, strategies as strat, settings
from tinygrad.device import is_dtype_supported
# hypothesis profile shared by every @given test here; the DERANDOMIZE_CI env var is forwarded to `derandomize`
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
settings.load_profile("my_profile")

# device names "<DEFAULT>:0" .. "<DEFAULT>:5"
d0, d1, d2, d3, d4, d5 = (f"{Device.DEFAULT}:{idx}" for idx in range(6))
# device groups used by the tests; note that every group starts at d1, not d0
devices_4 = (d1, d2, d3, d4)
devices_3 = devices_4[:3]
devices_2 = devices_4[:2]
# side length of the square matrices used by the matmul tests
N = 128

# shard_x is "data parallel"
# shard_w is "model parallel"
def _test_allreduce(t:Tensor):
  """All-reduce (ADD) `t` sharded on axis 0 over devices_4 and return `(expected, result)`.

  `expected` is the elementwise sum of the four 64-row slices of `t`, tiled 4x along axis 0.
  `result` is the per-shard all-reduce outputs rejoined on axis 0. Callers compare the two.
  The slice bounds are hard-coded, so `t` is expected to have 256 rows (4 devices x 64).
  """
  # left fold: ((t[0:64] + t[64:128]) + t[128:192]) + t[192:256], same order as a chained `+`
  chunk_sum = functools.reduce(lambda acc, part: acc + part, (t[lo:lo+64] for lo in (0, 64, 128, 192)))
  expected = chunk_sum.repeat([4,1]).realize()
  sharded = t.shard(devices_4, 0).realize()
  result = Tensor(UOp.multi(*all_reduce(Ops.ADD, sharded.lazydata.src), axis=0))
  result.realize()
  return expected, result
@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"), "no GPU CI")
class TestMultiTensor(unittest.TestCase):
  """Tensors spread over several devices with to()/shard().

  Covers per-device buffer shapes; elementwise, reduce and matmul results; all-reduce; JIT capture;
  autograd and optimizers; and random ops.
  """
  def test_to(self):
    X = Tensor.ones(256).contiguous().realize()
    X.to_(devices_2)
    # to_ replicates: every per-device buffer keeps the full shape
    for lb in X.lazydata.src:
      assert lb.shape == (256,)
    (X + X).realize()

  def test_gradient(self):
    X = Tensor.ones(256).contiguous().realize()
    X.to_(devices_2)
    grad = X.sum().gradient(X)[0]
    grad.realize()

  def test_shard(self):
    X = Tensor.ones(256).contiguous().realize()
    X.shard_(devices_2, 0)
    # sharding on axis 0 splits the 256 elements evenly across the 2 devices
    for lb in X.lazydata.src:
      assert lb.shape == (128,)
    (X + X).realize()

  def test_shard_not_multiple(self):
    X = Tensor.ones(256).contiguous().realize()
    # 256 elements cannot be split evenly across 3 devices
    with self.assertRaises(RuntimeError):
      X.shard_(devices_3, 0)

  def test_tensor_from_multi(self):
    X = Tensor([1, 2], dtype=dtypes.int).shard_(devices_2, 0)
    Y = Tensor(X.lazydata)
    self.assertEqual(Y.device, Device.DEFAULT)
    np.testing.assert_equal(X.numpy(), Y.numpy())
    with self.assertRaises(AssertionError):
      _ = Tensor(X.lazydata, dtype=dtypes.float)

  def test_sharded_arange(self):
    sharded_arange = Tensor.arange(1000).shard(devices_2, 0)
    sharded_arange.realize()
    np.testing.assert_equal(sharded_arange.numpy(), np.arange(1000))

  def test_shard_no_recompile(self):
    X = Tensor.ones(256).contiguous().realize()
    X.shard_(devices_2, 0)
    out = (X + X)
    sched = out.schedule()
    names = []
    for _, ei in lower_schedule(sched):
      if isinstance(ei.prg, CompiledRunner): names.append(ei.prg.p.name)
      ei.run()
    # expect exactly 3 distinct compiled kernel names; the message is passed to assertEqual so it shows on failure
    self.assertEqual(len(set(names)), 3, "function was relinearized")

  @unittest.skip("this doesn't fold because shard_ calls contiguous on all lbs")
  def test_sharded_memory(self):
    # Buffer may be stuck in track_cross_buffer
    for x in (d0, d1, d2, d3, d4): Device[x].synchronize()
    mem_base = GlobalCounters.mem_used
    X = Tensor.ones(256).contiguous().realize()
    assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256, GlobalCounters.mem_used-mem_base
    X.shard_(devices_4).realize()
    for x in (d0, d1, d2, d3, d4): Device[x].synchronize()
    assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256 * 4, GlobalCounters.mem_used-mem_base
    X = Tensor.ones(256).contiguous().realize()
    assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256, GlobalCounters.mem_used-mem_base
    X.shard_(devices_4, axis=0).realize()
    for x in (d0, d1, d2, d3, d4): Device[x].synchronize()
    assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256, GlobalCounters.mem_used-mem_base
    X = Tensor.ones(256).realize()
    assert GlobalCounters.mem_used-mem_base == 0
    X.shard_(devices_4).realize()
    assert GlobalCounters.mem_used-mem_base == 0
    X = Tensor.ones(256).realize()
    assert GlobalCounters.mem_used-mem_base == 0
    X.shard_(devices_4, axis=0).realize()
    assert GlobalCounters.mem_used-mem_base == 0

  def test_shard_same_device(self):
    X = Tensor.ones(256).contiguous().realize()
    X.shard_((d1, X.device), 0)
    (X + X).realize()

  def test_shard_plus_one_sum(self):
    X = Tensor.ones(256).contiguous().realize()
    X.shard_((d1, d2), 0)
    (X + 1).sum().realize()

  def test_shard_plus_one_sum_d0(self):
    X = Tensor.ones(256).contiguous().realize()
    X.shard_((d0, d2), 0)
    (X + 1).sum().realize()

  def test_numpy(self):
    X = Tensor.ones(256)
    X.shard_((d1, d2), 0)
    np.testing.assert_allclose(X.numpy(), 1)

  def _test_simple_add_axis(self, shard_x, shard_w):
    """Add two all-ones 256-element tensors sharded on (d1, d2) along the given axes; every element must be 2."""
    X = Tensor.ones(256).contiguous().realize()
    W = Tensor.ones(256).contiguous().realize()
    X.shard_((d1, d2), shard_x)
    W.shard_((d1, d2), shard_w)
    O = X + W
    np.testing.assert_allclose(O.numpy(), 2)

  def test_simple_add(self): return self._test_simple_add_axis(None, None)
  def test_simple_add_X(self): return self._test_simple_add_axis(0, None)
  def test_simple_add_W(self): return self._test_simple_add_axis(None, 0)
  def test_simple_add_XW(self): return self._test_simple_add_axis(0, 0)

  def test_four_add(self):
    X = Tensor.ones(256, 256).contiguous().realize()
    W = Tensor.ones(256, 256).contiguous().realize()
    X.shard_(devices_4, 1)
    W.shard_(devices_4, None)
    O = X + W
    np.testing.assert_allclose(O.numpy(), 2)

  def test_elementwise_dtype(self):
    Tensor.manual_seed(0)
    X = Tensor.randn(8, 8).realize()
    W = Tensor.randn(8, 8).realize()
    X.shard_(devices_4, 0)
    W.shard_(devices_4, 0)
    O = X.shrink(((0, 2), None)) * W.shrink(((0, 2), None)) < 2
    np.testing.assert_allclose(O.numpy(), X.numpy()[0:2]*W.numpy()[0:2] < 2)

  @given(strat.sampled_from((4, 5)), strat.sampled_from((devices_2, devices_3)),
         strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)),
         strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)), strat.sampled_from((1, 0, -1)))
  def test_simple_reduce(self, N, devices, rop, shard_axis, reduce_axis, sign):
    # scale N by the device count so either axis can be sharded evenly
    N = N * len(devices)
    X = Tensor.rand(N*N).reshape(N, N).mul(sign)
    n = X.numpy()
    X.shard_(devices, shard_axis)
    f = {Ops.ADD: lambda x: x.sum(reduce_axis), Ops.MUL: lambda x: x.prod(reduce_axis),
         Ops.MAX: lambda x: x.max(reduce_axis)}[rop]
    fX = f(X)
    fn = f(n)
    np.testing.assert_allclose(fX.numpy(), fn, rtol=1e-6, atol=1e-6)

  def test_allreduce_naive(self):
    with Context(RING=0):
      a,b = _test_allreduce(Tensor.rand(256, 256))
      np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)

  def test_allreduce_ring(self):
    with Context(RING=2):
      a,b = _test_allreduce(Tensor.rand(256, 256))
      np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)

  def test_copy_jit(self):
    @TinyJit
    def copy_tensor(x:Tensor): return (x.to(f"{x.device.split(':')[0]}:1") + 1)
    for _ in range(5):
      t = Tensor.rand(256).realize()
      x = copy_tensor(t)
      np.testing.assert_equal((t+1).numpy(), x.numpy())

  def test_allreduce_naive_jit(self):
    with Context(RING=0):
      jit_allreduce = TinyJit(_test_allreduce)
      for _ in range(5):
        a,b = jit_allreduce(Tensor.rand(256, 256))
        np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)

  def test_allreduce_ring_jit(self):
    with Context(RING=2):
      jit_allreduce = TinyJit(_test_allreduce)
      for _ in range(5):
        a,b = jit_allreduce(Tensor.rand(256, 256))
        np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)

  @unittest.skip("slow")
  def test_fuzz_allreduce(self):
    random.seed(41)
    for it in range(100):
      for n in range(2, 4+1):
        # the sharded (first) dim is a multiple of n so it splits evenly across n devices
        shape = tuple([(n if i == 0 else 1) * random.randint(1, 10) for i in range(random.randint(1, 4))])
        t = Tensor.rand(shape).shard_(tuple([d0, d1, d2, d3][:n]), 0)
        with Context(RING=0):
          a = Tensor(UOp.multi(*all_reduce(Ops.ADD, t.lazydata.src), axis=0))
        with Context(RING=2):
          b = Tensor(UOp.multi(*all_reduce(Ops.ADD, t.lazydata.src), axis=0))
        diff = a - b
        mean_err = diff.reshape((prod(diff.shape),)).abs().mean().numpy()
        max_err = diff.reshape((prod(diff.shape),)).abs().max().numpy()
        assert mean_err < 1e-6, f"big mean error, iteration {it}_{n}"
        assert max_err < 1e-6, f"big max error, iteration {it}_{n}"

  def _test_matmul_shard_axis(self, shard_x, shard_w, device):
    """Check X@W with X sharded on `shard_x` and W on `shard_w` over `device` against numpy."""
    X = Tensor.kaiming_uniform(N, N).realize()
    W = Tensor.kaiming_uniform(N, N).realize()
    Xs = X.shard(device, shard_x)
    Ws = W.shard(device, shard_w)
    O = (Xs@Ws)
    np.testing.assert_allclose(X.numpy() @ W.numpy(), O.to(Device.DEFAULT).numpy(), atol=1e-5)

  def _test_double_matmul_shard_axis(self, shard_x, shard_w, device):
    """Check (X@W1)@W2 with X sharded on `shard_x` and both weights on `shard_w` over `device` against numpy."""
    X = Tensor.kaiming_uniform(N, N).realize()
    W1 = Tensor.kaiming_uniform(N, N).realize()
    W2 = Tensor.kaiming_uniform(N, N).realize()
    Xs = X.shard(device, shard_x)
    W1s = W1.shard(device, shard_w)
    W2s = W2.shard(device, shard_w)
    O = (Xs@W1s)@W2s
    np.testing.assert_allclose((X.numpy() @ W1.numpy()) @ W2.numpy(), O.to(Device.DEFAULT).numpy(), atol=1e-5)

  def test_matmul_shard_none(self): return self._test_matmul_shard_axis(None, None, devices_2)
  def test_matmul_shard_X_0(self): return self._test_matmul_shard_axis(0, None, devices_2)
  def test_matmul_shard_X_1(self): return self._test_matmul_shard_axis(1, None, devices_2)
  def test_matmul_shard_W_0(self): return self._test_matmul_shard_axis(None, 0, devices_2)
  def test_matmul_shard_W_1(self): return self._test_matmul_shard_axis(None, 1, devices_2)
  def test_matmul_shard_0_0(self): return self._test_matmul_shard_axis(0, 0, devices_2)
  def test_matmul_shard_0_1(self): return self._test_matmul_shard_axis(0, 1, devices_2)
  def test_matmul_shard_1_0(self): return self._test_matmul_shard_axis(1, 0, devices_2)
  def test_matmul_shard_1_1(self): return self._test_matmul_shard_axis(1, 1, devices_2)

  def test_double_matmul_shard_X_0(self): return self._test_double_matmul_shard_axis(0, None, devices_2)
  def test_double_matmul_shard_X_1(self): return self._test_double_matmul_shard_axis(1, None, devices_2)
  def test_double_matmul_shard_W_0(self): return self._test_double_matmul_shard_axis(None, 0, devices_2)
  def test_double_matmul_shard_W_1(self): return self._test_double_matmul_shard_axis(None, 1, devices_2)

  def test_conv_data_shard(self):
    conv = nn.Conv2d(3, 16, 3, bias=False)
    for p in get_parameters(conv): p.shard_(devices_2)
    fake_image = Tensor.rand((2, 3, 32, 32)).shard(devices_2, axis=0)
    out = conv(fake_image)
    out.numpy()

  def test_conv_bias_data_shard(self):
    conv = nn.Conv2d(3, 16, 3)
    for p in get_parameters(conv): p.shard_(devices_2)
    fake_image = Tensor.rand((2, 3, 32, 32)).shard(devices_2, axis=0)
    out = conv(fake_image)
    out.numpy()

  def test_backprop_conv(self):
    with Tensor.train():
      conv = nn.Conv2d(3, 16, 3)
      for p in get_parameters(conv): p.shard_(devices_2)
      optim = nn.optim.Adam(get_parameters(conv))
      fake_image = Tensor.rand((2, 3, 32, 32)).shard(devices_2, axis=0)
      out = conv(fake_image)
      optim.zero_grad()
      out.mean().backward()
      #for p in get_parameters(conv): p.grad.realize()
      optim.step()
      out.numpy()

  def test_backprop_conv_wino(self):
    with Context(WINO=1): self.test_backprop_conv()

  def test_backward_sum(self):
    x = Tensor([[1.,2,3,4], [5,6,7,8]]).shard(devices_2, axis=0)
    w = Tensor([1.,2,3,4], requires_grad=True).shard(devices_2)
    out = x * w
    out.mean().backward()
    tst = w.grad.numpy()
    # d(mean(x*w))/dw = column sums of x / 8 = [6, 8, 10, 12] / 8
    np.testing.assert_allclose(tst, [0.75, 1., 1.25, 1.5])

  def test_lr_scheduler_OneCycleLR(self):
    from extra.lr_scheduler import OneCycleLR
    conv = nn.Conv2d(3, 16, 3)
    for p in get_parameters(conv): p.shard_(devices_2)
    optim = nn.optim.SGD(get_parameters(conv))
    lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
    lr_sched.step()

  def test_embedding(self):
    B, T, embed_size, vocab_size = 4, 10, 20, 28

    layer = nn.Embedding(vocab_size, embed_size)
    x = Tensor(np.random.randint(0, vocab_size, (B, T), dtype=np.int32))
    z = layer(x)

    layer_sharded = nn.Embedding(vocab_size, embed_size)
    layer_sharded.weight.replace(layer.weight.shard(devices_2, axis=1)).realize()
    x_sharded = x.shard(devices_2, axis=None)
    z_shard = layer_sharded(x_sharded)

    np.testing.assert_allclose(z.numpy(), z_shard.numpy(), atol=1e-6, rtol=1e-6)

  def test_rmsnorm(self):
    B, T, embed_size = 4, 10, 20

    norm = nn.RMSNorm(embed_size)
    x = Tensor.rand((B, T, embed_size)).contiguous().realize()
    y = norm(x)

    # for norm layers, the correct way to shard weights is duplication
    norm_sharded = nn.RMSNorm(embed_size)
    norm_sharded.weight.shard_(devices_2, axis=None).realize()

    # if x is being sharded, then all-reduce is involved
    x_sharded = x.shard(devices_2, axis=2).realize()
    y_shard = norm_sharded(x_sharded).realize()
    np.testing.assert_allclose(y.numpy(), y_shard.numpy(), atol=1e-6, rtol=1e-6)

    # if x is being duplicated, then the operations remain inside each GPU
    # which is the common case
    x_sharded = x.shard(devices_2, axis=None).realize()
    y_shard = norm_sharded(x_sharded).realize()
    np.testing.assert_allclose(y.numpy(), y_shard.numpy(), atol=1e-6, rtol=1e-6)

  # NOTE: this is failing on LLVM CI, no idea why. Works locally.
  @unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV", "LLVM"), "slow")
  @unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers")
  def test_data_parallel_resnet(self):
    from extra.models.resnet import ResNet18

    fake_image = Tensor.rand((2, 3, 224//8, 224//8))
    fake_image_sharded = fake_image.shard(devices_2, axis=0)
    m = ResNet18()
    m.load_from_pretrained()
    real_output = m(fake_image).log_softmax().numpy()

    for p in get_parameters(m): p.shard_(devices_2).realize()
    GlobalCounters.reset()
    shard_output = m(fake_image_sharded).log_softmax().realize()
    # the batch of 2 is split on axis 0, so each device holds the logits for one image
    assert shard_output.lazydata.src[0].shape == (1, 1000)
    assert shard_output.lazydata.src[1].shape == (1, 1000)
    shard_output_np = shard_output.numpy()
    np.testing.assert_allclose(real_output, shard_output_np, atol=1e-6, rtol=1e-6)

  @unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV", "LLVM"), "slow, and flaky on LLVM")
  @unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers")
  def test_data_parallel_resnet_train_step(self):
    from extra.models.resnet import ResNet18
    from tinygrad.nn.optim import LARS

    fake_image = Tensor.rand((2, 3, 224//8, 224//8))
    fake_image_sharded = fake_image.shard(devices_2, axis=0)
    labels = Tensor.randint(2, low=0, high=1000)
    labels_sharded = labels.shard(devices_2, axis=0)

    m = ResNet18()
    optimizer = LARS(get_parameters(m), 0.1)  # set requires_grad for all params

    optimizer.zero_grad()
    m.load_from_pretrained()
    output = m(fake_image).sparse_categorical_crossentropy(labels, label_smoothing=0.1)
    output.backward()
    grad = m.conv1.weight.grad.numpy()

    for p in get_parameters(m): p.shard_(devices_2).realize()
    GlobalCounters.reset()
    optimizer.zero_grad()
    shard_output = m(fake_image_sharded).sparse_categorical_crossentropy(labels_sharded, label_smoothing=0.1)
    shard_output.backward()
    shard_grad = m.conv1.weight.grad.numpy()
    # sometimes there are zeros in these grads... why?
    np.testing.assert_allclose(grad, shard_grad, atol=1e-5, rtol=1e-5)

  def test_assign_kv_cache_multi(self):
    bsz, max_context = 2, 8

    class Attn:
      @TinyJit
      def __call__(self, xk:Tensor, start_pos:UOp):
        seqlen = xk.shape[1]
        if not hasattr(self, "cache_k"):
          self.cache_k = Tensor.zeros(bsz, max_context, 1, 1).shard(devices_2).contiguous().realize()
        keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1).contiguous() if start_pos > 0 else xk
        self.cache_k.assign(keys.pad((None,(0,max_context-start_pos-seqlen),None,None)).contiguous()).realize()

    attn = Attn()
    xk = Tensor.ones(bsz, 3, 1, 1).shard(devices_2).contiguous()
    attn(xk, 0)
    for i in range(3,6):
      # copied from LLaMA
      start_pos = Variable("start_pos", 1, max_context).bind(i)
      xk = Tensor.ones(bsz, 1, 1, 1).shard(devices_2).contiguous()
      attn(xk, start_pos)

    out = attn.cache_k.flatten().numpy()
    # 3 initial positions + 3 single-step appends = 6 filled slots per batch row, 2 left as zero
    np.testing.assert_allclose(out, [1.,1.,1.,1.,1.,1.,0.,0.,1.,1.,1.,1.,1.,1.,0.,0.])

  def test_multi_tensor_jit_param(self):
    @TinyJit
    def jf(a, b) -> Tensor:
      return (a + b).realize()

    for _ in range(5):
      a = Tensor.ones(256).contiguous().realize()
      b = Tensor.ones(256).contiguous().realize()
      a.shard_(devices_2)
      b.shard_(devices_2)
      c = jf(a, b)
      np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
    assert len(jf.jit_cache) > 0

  def test_multi_tensor_jit_body(self):
    @TinyJit
    def jf() -> Tensor:
      a = Tensor.ones(256).contiguous().realize()
      b = Tensor.ones(256).contiguous().realize()
      a.shard_(devices_2)
      b.shard_(devices_2)
      return (a + b).realize()

    for _ in range(5):
      r = jf()
      np.testing.assert_allclose(r.numpy(), np.ones(256)+np.ones(256), atol=1e-4, rtol=1e-5)
    assert len(jf.jit_cache) > 0

  #@unittest.skipIf(CI and Device.DEFAULT=="METAL", "no ICB in CI, creation of graph fails")
  @unittest.skip("test broken")
  def test_multi_device_jit_graph(self):
    if Device[d0].graph is None or Device[d1].graph is None: raise unittest.SkipTest("only test graphs")

    @TinyJit
    def jf(a: Tensor, b: Tensor, c: Tensor, d:Tensor):
      # Create 80 entries on device 0: 2 batches.
      for _ in range(40):
        a = ((a + b).realize() + (a * b).realize()).realize()
      # Create 80 entries on device 1: 2 batches.
      for _ in range(40):
        c = ((c + d).realize() + (c * d).realize()).realize()
      # Create a copy from device 0 to 1: 1 entry.
      a = a.to(d1).realize()
      # Creates one last entry on device 1: 1 batch.
      return (a + c).realize()

    a = Tensor.randn(10, 10, device=d0).realize()
    b = Tensor.randn(10, 10, device=d0).realize()
    c = Tensor.randn(10, 10, device=d1).realize()
    d = Tensor.randn(10, 10, device=d1).realize()

    ref = jf(a, b, c, d).numpy()
    for _ in range(5):
      o = jf(a, b, c, d).numpy()
      np.testing.assert_allclose(ref, o, atol=1e-4, rtol=1e-5)

    graph_d0 = Device[d0].graph.func if isinstance(Device[d0].graph, functools.partial) else Device[d0].graph
    graph_d1 = Device[d1].graph.func if isinstance(Device[d1].graph, functools.partial) else Device[d1].graph
    # Checking that 2 graphs per device, 1 copy and 1 last graph on device 1 are created.
    assert isinstance(jf.jit_cache[0].prg, graph_d0)
    assert isinstance(jf.jit_cache[1].prg, graph_d0)
    assert isinstance(jf.jit_cache[2].prg, graph_d1)
    assert isinstance(jf.jit_cache[3].prg, graph_d1)
    assert isinstance(jf.jit_cache[4].prg, BufferCopy)
    assert isinstance(jf.jit_cache[5].prg, graph_d1)

  @unittest.skip("no longer supports uneven shard")
  def test_uneven_shard(self):
    for N in range(1, 6):
      X = Tensor.rand(4, 1, 257).contiguous().realize()
      n = X.numpy()
      devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
      X.shard_(devices, 2)
      np.testing.assert_equal(X.numpy(), n)
      np.testing.assert_equal(X.reshape(2, 2, 257).numpy(), n.reshape((2, 2, 257)))
      np.testing.assert_equal(X.shrink(((0,2), (0, 1), (0,257))).numpy(), n[0:2, 0:1, 0:257])
      np.testing.assert_equal(X.expand((4, 4, 257)).numpy(), np.tile(n, (1, 4, 1)))
      np.testing.assert_equal(X.permute((0, 2, 1)).numpy(), np.transpose(n, (0, 2, 1)))

  @unittest.skip("no longer supports uneven shard")
  def test_uneven_multiple_zeros(self):
    for data in ([1, 2, 3, 4], [1, 2, 3], [1, 2], [1], []):
      for N in (1, 2, 3, 4):
        devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
        # make sure something is computed on each device
        X = ((Tensor(data).shard(devices, axis=0) + 1).realize() - 1).realize()
        np.testing.assert_equal(X.numpy(), data)

  @unittest.skip("no longer supports uneven shard")
  def test_uneven_shard_with_empty(self):
    N = 4
    X = Tensor.rand(16, 1, 3).contiguous().realize()
    np_x = X.numpy()
    devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))

    # test empty shard
    np.testing.assert_equal(X.shard(devices, 0).numpy(), np_x)

    # test reshape with empty shard
    np.testing.assert_equal(X.shard(devices, 0).reshape(8, 1, 6).numpy(), np_x.reshape(8, 1, 6))

  @unittest.skip("no longer supports uneven shard")
  def test_multiple_uneven_shard(self):
    N = 4
    X = Tensor.rand(4, 1, 257).contiguous().realize()
    Y = Tensor.rand(4, 1, 257).contiguous().realize()
    np_x, np_y = X.numpy(), Y.numpy()
    devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
    X.shard_(devices, 2)
    Y.shard_(devices, 2)
    np.testing.assert_equal(X.numpy(), np_x)
    np.testing.assert_equal(Y.numpy(), np_y)
    np.testing.assert_equal((X + Y).numpy(), np_x + np_y)

  def test_bn_ast_on_devices(self):
    t = Tensor.empty((16, 64, 112, 112)).shard(devices_4, axis=0)
    bn = nn.BatchNorm2d(64)
    for p in get_parameters(bn): p.shard_(devices_4).realize()

    out = bn(t)
    scheds = [sched for sched in out.schedule() if sched.bufs[0].device in devices_4 and sched.ast.op is not Ops.COPY]
    assert set(sched.bufs[0].device for sched in scheds) == set(devices_4), "should have ast on each shard device"
    asts = [sched.ast for sched in scheds]
    assert len(asts)
    # test case to show that ast can be different on devices
    # TODO: make ast identical on devices
    #assert len(set(asts)) == 4, len(asts)
    # for i, ast in enumerate(asts):
    #   print(f"{i} {ast}")

  def test_reshape_on_axis(self):
    t0 = Tensor.rand((26, 15, 7)).shard(devices_3, axis=1)

    # test split and rejoin to the right
    t1 = t0.reshape((26, 3, 5, 7))
    t2 = t0.reshape((26, 3, 35))
    t3 = t1.reshape((26, 15, 7))
    t4 = t2.reshape((26, 105,))

    for t in [t0, t1, t2, t3, t4]:
      assert t.lazydata.axis == 1
      np.testing.assert_allclose(t.numpy().flatten(), t0.numpy().flatten())

    # test shape-one axis
    t5 = t4.reshape((26, 1, 105))
    assert t5.lazydata.axis == 2
    # compare against t4 explicitly rather than the loop variable leaked from the for-loop above
    np.testing.assert_allclose(t4.numpy().flatten(), t5.numpy().flatten())

    # test split and rejoin to the right and reshape to the left
    t5 = t0.reshape((2, 13, 3, 5, 7))
    t6 = t0.reshape((13, 2, 3, 7, 5))
    t7 = t0.reshape((1, 13, 2, 3, 1, 7, 5))
    assert t5.lazydata.axis == 2
    assert t6.lazydata.axis == 2
    assert t7.lazydata.axis == 3
    np.testing.assert_allclose(t5.numpy().flatten(), t0.numpy().flatten())
    np.testing.assert_allclose(t6.numpy().flatten(), t0.numpy().flatten())
    np.testing.assert_allclose(t7.numpy().flatten(), t0.numpy().flatten())

    # test no left join
    with self.assertRaises((AssertionError, ValueError)):
      t0.reshape((26*15,7)).schedule()

  @unittest.skip("no longer supports uneven shard")
  def test_reshape_on_axis_uneven(self):
    def reshape_helper(t0, t, t_axis):
      assert t.lazydata.axis == t_axis
      np.testing.assert_allclose(t0.reshape(t.shape).numpy(), t.numpy())

    t0 = Tensor.rand((4, 42, 15)).shard(devices_3, axis=1, splits=[14, 7, 21])

    # ok to reshape as long as elements remain on same device
    reshape_helper(t0, t0.reshape(2, 2, 42, 3, 5), 2)
    # split to the right
    reshape_helper(t0, t0.reshape(2, 2, 6, 7, 15), 2)
    # split off and merge to the right
    reshape_helper(t0, t0.reshape(4, 6, 105), 1)
    # really blend the axes together
    reshape_helper(t0, t0.reshape(4, 30, 21), 1)
    # split off 1-shape
    reshape_helper(t0, t0.reshape(4, 1, 42, 15), 2)
    reshape_helper(t0, t0.reshape(4, 6, 1, 7, 15), 1)
    # assert if cannot maintain shard axis without moving items between devices
    with self.assertRaises(AssertionError): t0.reshape(4, 7, 6, 15)
    # assert for degenerate reshape
    with self.assertRaises(AssertionError): t0.reshape(4, 5, 7, 15)
    # assert for cannot maintain axis
    with self.assertRaises(AssertionError): t0.reshape(4, 3, 2, 7, 15)

  def test_mlb_assign_change_axis(self):
    t_none = Tensor.zeros((16, 16)).shard(devices_2).contiguous().realize()
    t_zero = Tensor.ones((16, 16)).shard(devices_2, axis=0)
    with self.assertRaises(AssertionError):
      # don't allow assigns that change axes
      t_none.assign(t_zero)
      t_none.schedule()

  def test_init_rand_with_multiple_devices_fail(self):
    # init rand with multi device is not allowed
    with self.assertRaises(ValueError):
      Tensor.rand(256, device=devices_2)

  def test_rand_on_multiple_devices(self):
    # different devices generate different rand
    d0_rand = Tensor.rand(256, device=d0).realize()
    d1_rand = Tensor.rand(256, device=d1).realize()
    assert not np.allclose(d0_rand.numpy(), d1_rand.numpy())

  def test_rand_on_multiple_devices_manual_seed(self):
    Tensor.manual_seed(123)
    d0_rand = Tensor.rand(2, device=d0).tolist()
    d1_rand = Tensor.rand(2, device=d1).tolist()

    # manual_seed again gives the same values
    Tensor.manual_seed(123)
    d0_rand2 = Tensor.rand(2, device=d0).tolist()
    d1_rand2 = Tensor.rand(2, device=d1).tolist()
    self.assertEqual(d0_rand, d0_rand2)
    self.assertEqual(d1_rand, d1_rand2)

    # device seed is only determined by init order, so flipping init order flips rands
    Tensor.manual_seed(123)
    d1_rand_flip = Tensor.rand(2, device=d1).tolist()
    d0_rand_flip = Tensor.rand(2, device=d0).tolist()
    self.assertEqual(d0_rand, d1_rand_flip)
    self.assertEqual(d1_rand, d0_rand_flip)

  def test_rand_like_on_shard(self):
    t = Tensor.empty((16, 16)).shard(devices_2)
    t2 = Tensor.rand_like(t)
    self.assertEqual(t.shape, t2.shape)
    self.assertEqual(t.device, t2.device)
    self.assertEqual(t.dtype, t2.dtype)
    self.assertEqual(t.lazydata.axis, t2.lazydata.axis)

  def test_rand_like_from_alu(self):
    a = Tensor.ones(4, 4).shard(devices_4, axis=0)
    aa = a + a
    self.assertEqual(aa.device, devices_4)
    self.assertEqual(aa.lazydata.axis, 0)
    raa = aa.rand_like()
    self.assertEqual(raa.device, devices_4)
    self.assertEqual(raa.lazydata.axis, 0)

    b = Tensor.empty(4, 4).shard(devices_4, axis=None)
    ab = a + b
    self.assertEqual(ab.device, devices_4)
    self.assertEqual(ab.lazydata.axis, 0)
    rab = ab.rand_like()
    self.assertEqual(rab.device, devices_4)
    self.assertEqual(rab.lazydata.axis, 0)

  @unittest.skip("no longer supports uneven shard")
  def test_rand_like_uneven_shard(self):
    t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1)
    t2 = Tensor.rand_like(t)
    self.assertEqual(t.shape, t2.shape)
    self.assertEqual(t.device, t2.device)
    self.assertEqual(t.dtype, t2.dtype)
    self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
    assert all(tlb.shape == t2lb.shape for tlb, t2lb in zip(t.lazydata.src, t2.lazydata.src))

  def test_rand_like_none_shard(self):
    t = Tensor.empty((16, 16)).shard(devices_2)
    t2 = Tensor.rand_like(t)
    self.assertEqual(t.shape, t2.shape)
    self.assertEqual(t.device, t2.device)
    self.assertEqual(t.dtype, t2.dtype)
    self.assertEqual(t.lazydata.axis, t2.lazydata.axis)

  def test_rand_like_arg_dtype(self):
    t = Tensor.empty((16, 16), dtype=dtypes.int32).shard(devices_2, axis=1)
    t2 = Tensor.rand_like(t, dtype=dtypes.float32)
    self.assertEqual(t.dtype, dtypes.int32)
    self.assertEqual(t2.dtype, dtypes.float32)

  def test_rand_like_arg_device(self):
    # axis=None
    t = Tensor.empty((16, 16)).shard((d1, d2), axis=None)
    with self.assertRaises(RuntimeError):
      Tensor.rand_like(t, device=(d3, d4))

    # axis=1
    t = Tensor.empty((16, 16)).shard((d1, d2), axis=1)
    with self.assertRaises(RuntimeError):
      Tensor.rand_like(t, device=(d3, d4))

  def test_dropout_on_shard(self):
    with Tensor.train():
      X = Tensor.ones(256).to(devices_2)
      output = X.dropout(0.5).numpy()
      unique, counts = np.unique(output, return_counts=True)
      # kept elements are scaled by 1/(1-0.5) = 2; roughly half of 256 are dropped
      assert set(unique) == {0, 2}, unique
      assert 100 < counts[0] < 156, counts[0]

  def test_dropout_on_shard_axis(self):
    with Tensor.train():
      X = Tensor.ones(512).shard(devices_2, axis=0)
      output = X.dropout(0.5).numpy()
      unique, counts = np.unique(output, return_counts=True)
      assert set(unique) == {0, 2}, unique
      assert 200 < counts[0] < 312, counts[0]

  @unittest.skip("no longer supports uneven shard")
  def test_dropout_on_uneven_shard_axis(self):
    with Tensor.train():
      X = Tensor.ones(256).shard(devices_3, axis=0)
      output = X.dropout(0.5).numpy()
      unique, counts = np.unique(output, return_counts=True)
      assert set(unique) == {0, 2}, unique
      assert 100 < counts[0] < 156, counts[0]

  @unittest.skip("test depends on UOp order. TODO: fix it")
  def test_broadcast_const(self):
    for axis in (None, 0, 1):
      t = Tensor.zeros(16, 16).contiguous().shard(devices_4, axis).realize()
      t = t + 1
      for si in t.schedule():
        ast = si.ast.src[0]
        assert ast.op is Ops.STORE
        assert ast.src[2].op is Ops.ADD
        assert ast.src[2].src[0].op is Ops.LOAD
        assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 1
      t = 2 * t
      for si in t.schedule():
        ast = si.ast.src[0]
        assert ast.op is Ops.STORE
        assert ast.src[2].op is Ops.MUL
        assert ast.src[2].src[0].src[1].op is Ops.CONST and ast.src[2].src[0].src[1].arg == 2
        assert ast.src[2].src[1].op is Ops.LOAD
      t = t + t.full_like(3)
      for si in t.schedule():
        ast = si.ast.src[0]
        assert ast.op is Ops.STORE
        assert ast.src[2].op is Ops.ADD
        assert ast.src[2].src[0].op is Ops.LOAD
        assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 3

  @unittest.skip("TODO: this requires forced_realize to be deleted.")
  def test_shard_memory(self):
    devices = (d0, d1, d2, d3)
    t = Tensor.zeros(16, 16).contiguous()
    t.shard_(devices, axis=0).realize()
    assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.lazydata.src])

  @unittest.skip("this is unreliable on OSX")
  def test_clone(self):
    t = Tensor.rand(16, 16).shard(devices_2, axis=None)
    np.testing.assert_allclose(t.numpy(), t.clone().numpy())

    t = Tensor.rand(16, 16).shard(devices_2, axis=0)
    np.testing.assert_allclose(t.numpy(), t.clone().numpy())

  def test_multi_const_folding(self):
    with Context(TRACK_MATCH_STATS=0):
      a = Tensor.arange(3).realize()
      zeros = Tensor.zeros(3).realize()
    b = a.to(devices_2)*zeros.to(devices_2)
    sched = b.schedule()
    self.assertEqual(len(sched), 6)
    # notably, only two copies (for the arange) - vs 4 copies if we didn't fold the const copy
    self.assertEqual(len([x for x in sched if any(u.op is Ops.COPY for u in x.ast.toposort)]), 2)
    run_schedule(sched)
    self.assertListEqual(b.tolist(), [0, 0, 0])

  @unittest.expectedFailure
  def test_dont_realize_intermediate_expand(self):
    a = Tensor.empty(16, 1).shard_(devices_2, axis=0)
    b = Tensor.empty(16, 16).to_(devices_2)
    c = Tensor.empty(16, 16).shard_(devices_2, axis=1)
    d = a+b
    (d*c).realize()
    assert not d.lazydata.is_realized
@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"), "no GPU CI")
class TestHandleData(unittest.TestCase):
  """Moving a sharded tensor to a single device with .to()."""
  def test_copied_to_device(self):
    device = (d0, d1, d2, d3)

    def fresh_sharded():
      # rebuilt before every schedule/realize because creating a schedule has a side effect on the tensor
      return Tensor([1, 2, 3, 4]).shard(device).realize()

    # d5 is outside the shard devices: exactly one schedule item is needed
    assert len(fresh_sharded().to(d5).schedule()) == 1
    assert fresh_sharded().to(d5).realize().tolist() == [1, 2, 3, 4]

    # every shard device already has the data (shard() without an axis), so nothing is scheduled
    for target in device:
      assert len(fresh_sharded().to(target).schedule()) == 0
      assert fresh_sharded().to(target).realize().tolist() == [1, 2, 3, 4]
@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"), "no GPU CI")
class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
# shrink a multitensor on sharded axis
def test_shrink_bad_args(self):
  """A shrink/pad on the sharded axis must land on whole device boundaries and cannot be combined with other axes."""
  t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
  t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)

  def expect_assertion(build):
    # the check may fire while building the op or while scheduling it, so both run under assertRaises
    with self.assertRaises(AssertionError):
      build().schedule()

  # sharded axis shrink on non-device boundary is not allowed
  expect_assertion(lambda: t.shrink(((0, 3), (0, 8))))
  # cannot shrink sharded and non-sharded axis at the same time
  expect_assertion(lambda: t.shrink(((0, 2), (2, 4))))

  shrunk = t.shrink(((0, 2), (0, 8)))
  shrunk.schedule()
  assert shrunk.shape == (2, 8)
  # rows [0, 2) belong to the first device only, so only that shard stays real
  assert shrunk.lazydata.real == (True, False, False, False)

  # cannot pad sharded and non-sharded axis at the same time
  expect_assertion(lambda: shrunk.pad(((0, 6), (0, 1))))
  # can only pad to whole axis
  expect_assertion(lambda: shrunk.pad(((1, 5), (0, 0))))

  padded = shrunk.pad(((0, 6), (0, 0)))
  padded.schedule()
  assert padded.shape == (8, 8)
  assert padded.lazydata.real == (True, True, True, True)
@given(strat.sampled_from([dtypes.float, dtypes.int, dtypes.int64, dtypes.int16]))
def test_ops(self, dtype):
if not is_dtype_supported(dtype): return
t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)
for i in range(4):
print(f"{i=}")
a = t.shrink(((0+2*i,2+2*i),None))
b = Tensor(t.numpy()[0+2*i:2+2*i])
assert a.shape == b.shape == (2, 8)
np.testing.assert_allclose(a.numpy(), b.numpy())
assert a.lazydata.real == tuple(i==j for j in range(4))
# cast
np.testing.assert_allclose(a.float().numpy(), b.float().numpy())
# elementwise
np.testing.assert_allclose(a.exp().numpy(), b.exp().numpy(), rtol=1e-7, atol=1e-3)
np.testing.assert_allclose(a.reciprocal().numpy(), b.reciprocal().numpy(), rtol=1e-7, atol=1e-3)
np.testing.assert_allclose(a.pow(-0.5).numpy(), b.pow(-0.5).numpy(), rtol=1e-7, atol=1e-3)
np.testing.assert_allclose((a+a).numpy(), (b+b).numpy(), rtol=1e-7, atol=1e-3)
np.testing.assert_equal((a+1).numpy(), (b+1).numpy())
np.testing.assert_equal((1+a).numpy(), (1+b).numpy())
np.testing.assert_allclose((a.where(a+a, a)).numpy(), (b.where(b+b, b)).numpy(), rtol=1e-7, atol=1e-3)
np.testing.assert_allclose((a.where(1, 0)).numpy(), (b.where(1, 0)).numpy(), rtol=1e-7, atol=1e-3)
# reduce
np.testing.assert_allclose(a.max().numpy(), b.max().numpy(), rtol=1e-7, atol=1e-3)
np.testing.assert_allclose(a.sum().numpy(), b.sum().numpy(), rtol=1e-7, atol=1e-3)
np.testing.assert_allclose(a.mean().numpy(), b.mean().numpy(), rtol=1e-7, atol=1e-3)
np.testing.assert_allclose(a.max(0).numpy(), b.max(0).numpy(), rtol=1e-7, atol=1e-3)
np.testing.assert_allclose(a.sum(0).numpy(), b.sum(0).numpy(), rtol=1e-7, atol=1e-3)
np.testing.assert_allclose(a.mean(0).numpy(), b.mean(0).numpy(), rtol=1e-7, atol=1e-3)
np.testing.assert_allclose(a.max(1).numpy(), b.max(1).numpy(), rtol=1e-7, atol=1e-3)
np.testing.assert_allclose(a.sum(1).numpy(), b.sum(1).numpy(), rtol=1e-7, atol=1e-3)
np.testing.assert_allclose(a.mean(1).numpy(), b.mean(1).numpy(), rtol=1e-7, atol=1e-3)
# pad it back
np.testing.assert_allclose(a.pad(((2*i, 2*(4-i-1)), None)).numpy(), b.pad(((2*i, 2*(4-i-1)), None)).numpy(), rtol=1e-7, atol=1e-3)
# other movement
np.testing.assert_allclose(a.pad((None, (1, 1))).numpy(), b.pad((None, (1, 1))).numpy(), rtol=1e-7, atol=1e-3)
np.testing.assert_allclose(a.shrink((None, (1, 3))).numpy(), b.shrink((None, (1, 3))).numpy(), rtol=1e-7, atol=1e-3)
np.testing.assert_allclose(a.permute((1, 0)).numpy(), b.permute((1, 0)).numpy(), rtol=1e-7, atol=1e-3)
np.testing.assert_allclose(a.reshape((2, 2, 4)).numpy(), b.reshape((2, 2, 4)).numpy(), rtol=1e-7, atol=1e-3)
np.testing.assert_allclose(a.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), b.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), rtol=1e-7, atol=1e-3)
np.testing.assert_allclose(a.flip(-1).numpy(), b.flip(-1).numpy(), rtol=1e-7, atol=1e-3)
@unittest.skip("no longer supports uneven shard")
def test_uneven(self):
t = Tensor.arange(24).reshape(3, 8).contiguous().realize()
t.shard_([f"{Device.DEFAULT}:{i}" for i in range(2)], axis=0)
a = t.shrink(((0, 2), None))
b = t.shrink(((2, 3), None))
na = t.numpy()[0:2]
nb = t.numpy()[2:3]
np.testing.assert_equal(a.numpy(), na)
np.testing.assert_equal(b.numpy(), nb)
np.testing.assert_equal((a+1).numpy(), na+1)
np.testing.assert_equal((b+1).numpy(), nb+1)
np.testing.assert_equal((1+a).numpy(), 1+na)
np.testing.assert_equal((1+b).numpy(), 1+nb)
np.testing.assert_equal((a+a).numpy(), na+na)
np.testing.assert_equal((b+b).numpy(), nb+nb)
@unittest.skip("why didn't this work?")
def test_add_two_partitions(self):
t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)
a = t.shrink(((2, 4), None))
b = t.shrink(((6, 8), None))
na = t.numpy()[2:4]
nb = t.numpy()[6:8]
np.testing.assert_equal(a.numpy(), na)
np.testing.assert_equal(b.numpy(), nb)
self.assertEqual(a.lazydata.real, (False, True, False, False))
self.assertEqual(b.lazydata.real, (False, False, False, True))
with self.assertRaises(AssertionError):
# cannot add directly
c = a + b
c.schedule()
c = a.pad(((2, 4), None)) + b.pad(((6, 0), None))
c.realize()
self.assertEqual(c.lazydata.real, (True, True, True, True))
expected = np.concatenate([np.zeros_like(t.numpy()[0:2]), na, np.zeros_like(t.numpy()[4:6]), nb])
np.testing.assert_equal(c.numpy(), expected)
def test_add_different_tensors(self):
devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
x = Tensor.arange(64).reshape(8, 8).contiguous().realize().shard(devices, axis=0)
to_add = []
for i in range(len(devices)):
to_add.append((Tensor.ones(2, 8) * i).shard(devices))
added:List[Tensor] = []
for bound, a in zip(x.lazydata.bounds, to_add):
added.append(x[bound[0]:bound[1]] + a)
output = added[0].cat(*added[1:])
expected = np.arange(64).reshape((8,8)) + np.array([[0,0,1,1,2,2,3,3] for _ in range(8)]).T
np.testing.assert_allclose(output.numpy(), expected)
@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"), "no GPU CI")
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers")
class TestBatchNorm(unittest.TestCase):
  """Batchnorm (synced and per-device "unsynced") with parameters and inputs spread over several devices."""
  @staticmethod
  def _place_unsynced_bn_state(module, devices):
    """Shard state-dict entries whose key contains 'running_mean'/'running_var' along axis 0 over `devices`;
    move every other entry to all of `devices` via `to_`."""
    for k, p in get_state_dict(module).items():
      if 'running_mean' in k or 'running_var' in k:
        p.shard_(devices, axis=0)
      else:
        p.to_(devices)

  def test_unsynced_backprop_conv_bn(self):
    with Tensor.train():
      from extra.lr_scheduler import OneCycleLR
      convs = [nn.Conv2d(3, 16, 3), nn.Conv2d(3, 16, 3)]
      bns = [nn.BatchNorm2d(16), nn.BatchNorm2d(16)]
      for p in get_parameters(convs + bns):
        p.shard_((d1, d2))
      optim = nn.optim.Adam(get_parameters(convs + bns))
      lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
      lr_sched.step()
      fake_image = Tensor.rand((8, 3, 32, 32)).shard((d1, d2), axis=0)
      # each half of the batch goes through its own conv+bn pair
      f1 = fake_image.shrink(((0, 4), None, None, None))
      f2 = fake_image.shrink(((4, 8), None, None, None))
      out1 = bns[0](convs[0](f1))
      out2 = bns[1](convs[1](f2))
      out = out1.cat(out2)
      optim.zero_grad()
      out.mean().backward()
      optim.step()
      out.numpy()

  # the WEBGPU skip previously repeated here is already applied by the class-level decorator
  def test_unsynced_backprop_standalone_bn(self):
    from extra.lr_scheduler import OneCycleLR
    GPUS = (d1, d2)
    class BatchNorm:
      """One independent BatchNorm2d per GPU, each applied to its own contiguous slice of the batch."""
      def __init__(self, num_features):
        self.bns:List[nn.BatchNorm2d] = [
          nn.BatchNorm2d(num_features, track_running_stats=False, eps=1e-12, momentum=0.85, affine=True) for _ in GPUS]
      def __call__(self, x:Tensor):
        each = x.shape[0]//len(self.bns)
        bn_ts = [bn(x.shrink(((each*(i), each*(i+1)), None, None, None))) for i, bn in enumerate(self.bns)]
        return bn_ts[0].cat(*bn_ts[1:])
    with Tensor.train():
      conv = nn.Conv2d(3, 16, 3)
      bn = BatchNorm(16)
      for p in get_parameters([conv, bn]):
        p.shard_(GPUS)
      optim = nn.optim.Adam(get_parameters([conv, bn]))
      lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
      lr_sched.step()
      fake_image = Tensor.rand((8, 3, 32, 32)).shard(GPUS, axis=0)
      out = bn(conv(fake_image))
      optim.zero_grad()
      out.mean().backward()
      optim.step()

  def test_unsynced_backprop_sync_weights(self):
    from extra.lr_scheduler import OneCycleLR
    from examples.hlb_cifar10 import UnsyncedBatchNorm
    GPUS = (d1, d2)
    with Tensor.train():
      conv = nn.Conv2d(3, 16, 3)
      bn = UnsyncedBatchNorm(16, num_devices=len(GPUS))
      self._place_unsynced_bn_state([conv, bn], GPUS)
      optim = nn.optim.Adam(get_parameters([conv, bn]))
      lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
      lr_sched.step()
      fake_image = Tensor.rand((8, 3, 32, 32)).shard(GPUS, axis=0)
      out = bn(conv(fake_image))
      optim.zero_grad()
      out.mean().backward()
      optim.step()

  @given(strat.sampled_from((False, True)))
  def test_batchnorm(self, is_training):
    devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
    x = Tensor.arange(4096).reshape(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)
    with Tensor.train(is_training):
      bns = []
      for _ in range(len(devices)):
        bn = nn.BatchNorm2d(8)
        for p in get_parameters(bn):
          p.shard_(devices)
        bn.weight.requires_grad = True
        bn.bias.requires_grad = True
        bns.append(bn)
      # apply one batchnorm to the rows covered by each shard's bounds
      bn_ts = [bn(x[bound[0]:bound[1]]) for bound, bn in zip(x.lazydata.bounds, bns)]
      bn_ts[0].cat(*bn_ts[1:]).numpy()

  def test_synced_vs_unsynced_bn(self):
    from examples.hlb_cifar10 import UnsyncedBatchNorm
    devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
    x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)
    with Tensor.train():
      synced_bn = nn.BatchNorm2d(8)
      unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))
      for p in get_parameters(synced_bn):
        p.shard_(devices)
      self._place_unsynced_bn_state(unsynced_bn, devices)
      synced_out = synced_bn(x)
      synced_si = list(synced_out.schedule())
      unsynced_out = unsynced_bn(x)
      unsynced_si = list(unsynced_out.schedule())
      # TODO: test synced / unsynced batchnorm cross device kernel and copies
      assert synced_si
      assert unsynced_si
def helper_test_shard_op(shps, fxn, atol=1e-6, rtol=1e-3):
  """Check that `fxn` gives the same result on a single-device tensor and on that data sharded over `devices_2`.

  For each shape in `shps`, a random input is drawn and sharded along axis 0. The two outputs must match in
  shape, dtype and values (within `atol`/`rtol`).

  Raises:
    AssertionError: on the first mismatch. The message names the input shape and chains the original failure.
  """
  for shp in shps:
    single_in = Tensor.randn(shp)
    multi_in = single_in.shard(devices_2, axis=0)
    single_out = fxn(single_in).numpy()
    multi_out = fxn(multi_in).numpy()
    try:
      assert single_out.shape == multi_out.shape, f"shape mismatch: single={single_out.shape} | multi={multi_out.shape}"
      assert single_out.dtype == multi_out.dtype, f"dtype mismatch: single={single_out.dtype} | multi={multi_out.dtype}"
      np.testing.assert_allclose(single_out, multi_out, atol=atol, rtol=rtol)
    except AssertionError as e:
      # keep it an AssertionError so the runner reports a test failure, and chain the cause for the traceback
      raise AssertionError(f"Failed shape {shp} (output {single_out.shape}): {e}") from e
@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"), "no GPU CI")
class TestTensorOps(unittest.TestCase):
  """Compare tensor ops on inputs sharded over two devices against the same ops on one device."""
  def test_interpolate(self):
    helper_test_shard_op([(4,16,16),(4,24,24)], lambda x: Tensor.interpolate(x, (19,19)))
  def test_bitcast(self):
    # the shape list used to repeat (256,); a distinct 2-D shape makes the second case add coverage
    helper_test_shard_op([(256,), (16, 16)], lambda x: x.bitcast(dtypes.int))
if __name__ == '__main__':
  # run this module's unittest test cases when the file is executed directly
  unittest.main()