openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

680 lines
28 KiB

#!/usr/bin/env python
import unittest
import numpy as np
import torch
from tinygrad import Tensor, Device, TinyJit
from tinygrad.ops import Ops
from tinygrad.helpers import CI, Context, OSX
from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm, RMSNorm, LSTMCell
from tinygrad.nn.state import load_state_dict
from tinygrad.engine.realize import run_schedule
@unittest.skipIf(CI and Device.DEFAULT in {"CUDA", "NV"}, "slow")
class TestNN(unittest.TestCase):
def test_sparse_cat_cross_entropy(self):
  """Check `sparse_categorical_crossentropy` against `torch.nn.CrossEntropyLoss`.

  Every combination of label smoothing, ignore_index and reduction is compared
  twice: on the plain (6, 5) logits and on the same data with an extra leading
  batch dimension of size 1.
  """
  # tinygrad inputs; 6x5 is not square to test that mean scaling uses the correct dimension
  logits = Tensor.randn(6, 5)
  labels = Tensor([0, 0, 0, 1, 2, 3])  # torch doesn't support target=-1
  # torch copies of the very same data
  ref_logits = torch.tensor(logits.numpy())
  ref_labels = torch.tensor(labels.numpy(), dtype=torch.long)

  def expect_close(got, want):
    np.testing.assert_allclose(got.numpy(), want.detach().numpy(), atol=1e-5, rtol=1e-6)

  for smoothing in (0.0, 0.1, 0.5, 1.0):
    for ignore_index in (-1, 0, 2):
      for reduction in ("none", "sum", "mean"):
        opts = dict(label_smoothing=smoothing, ignore_index=ignore_index, reduction=reduction)
        criterion = torch.nn.CrossEntropyLoss(**opts)

        expect_close(logits.sparse_categorical_crossentropy(labels, **opts), criterion(ref_logits, ref_labels))

        # also test with a batch dimension (of size 1); the torch input is permuted to (0, 2, 1)
        batched = logits.unsqueeze(0).sparse_categorical_crossentropy(labels.unsqueeze(0), **opts)
        ref_batched = criterion(ref_logits.unsqueeze(0).permute(0, 2, 1), ref_labels.unsqueeze(0))
        expect_close(batched, ref_batched)
def test_batchnorm2d(self, training=False, threed=False, track_running_stats=True):
  """Compare tinygrad `BatchNorm` with `torch.nn.BatchNorm2d` / `torch.nn.BatchNorm3d`.

  Args:
    training: run under `Tensor.train(training)` and set `tbn.training` to the same value.
    threed: use 5-D inputs with `torch.nn.BatchNorm3d` instead of 4-D inputs with `BatchNorm2d`.
    track_running_stats: forwarded to both layers; when true the running mean/var are
      randomised in tinygrad, copied to torch, and compared before and after the forward pass.
  """
  with Tensor.train(training):
    for sz in [4, 8, 16, 32]:
      # create in tinygrad
      bn = BatchNorm(sz, eps=1e-5, track_running_stats=track_running_stats)
      bn.weight = Tensor.randn(sz)
      bn.bias = Tensor.randn(sz)
      if track_running_stats:
        bn.running_mean = Tensor.randn(sz)
        # A running variance must not be negative. The earlier code zeroed entries of a
        # temporary `.numpy()` array and never rebound `bn.running_var`; clamp the tensor
        # itself so the non-negative values are what both layers actually use.
        bn.running_var = Tensor.randn(sz).relu()

      # create in torch, copying every parameter from the tinygrad layer
      with torch.no_grad():
        if threed:
          tbn = torch.nn.BatchNorm3d(sz, track_running_stats=track_running_stats).eval()
        else:
          tbn = torch.nn.BatchNorm2d(sz, track_running_stats=track_running_stats).eval()
        tbn.training = training
        tbn.weight[:] = torch.tensor(bn.weight.numpy())
        tbn.bias[:] = torch.tensor(bn.bias.numpy())
        if track_running_stats:
          tbn.running_mean[:] = torch.tensor(bn.running_mean.numpy())
          tbn.running_var[:] = torch.tensor(bn.running_var.numpy())

      # the copied running stats must agree before the forward pass
      if track_running_stats:
        np.testing.assert_allclose(bn.running_mean.numpy(), tbn.running_mean.detach().numpy(), rtol=1e-5, atol=1e-6)
        np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5, atol=1e-6)

      # trial
      if threed:
        inn = Tensor.randn(2, sz, 3, 3, 3)
      else:
        inn = Tensor.randn(2, sz, 3, 3)

      # in tinygrad
      outt = bn(inn)

      # in torch
      toutt = tbn(torch.tensor(inn.numpy()))

      # close
      np.testing.assert_allclose(outt.numpy(), toutt.detach().numpy(), rtol=5e-4, atol=1e-6)

      # ...and the running stats must still agree after the forward pass
      if track_running_stats:
        np.testing.assert_allclose(bn.running_mean.numpy(), tbn.running_mean.detach().numpy(), rtol=1e-5, atol=1e-6)
        np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5, atol=1e-6)
def test_batchnorm2d_training(self):
  """Run the BatchNorm comparison in training mode, 2-D variant, with running stats tracked."""
  self.test_batchnorm2d(training=True, threed=False, track_running_stats=True)
def test_batchnorm2d_no_running_stats(self):
  """Run the BatchNorm comparison in eval mode, 2-D variant, without running stats."""
  self.test_batchnorm2d(training=False, threed=False, track_running_stats=False)
def test_batchnorm2d_training_no_running_stats(self):
  """Run the BatchNorm comparison in training mode, 2-D variant, without running stats."""
  self.test_batchnorm2d(training=True, threed=False, track_running_stats=False)
def test_batchnorm3d(self):
  """Run the BatchNorm comparison in eval mode, 3-D variant, with running stats tracked."""
  self.test_batchnorm2d(training=False, threed=True, track_running_stats=True)
def test_batchnorm3d_training(self):
  """Run the BatchNorm comparison in training mode, 3-D variant, with running stats tracked."""
  self.test_batchnorm2d(training=True, threed=True, track_running_stats=True)
def test_batchnorm3d_no_running_stats(self):
  """Run the BatchNorm comparison in eval mode, 3-D variant, without running stats."""
  self.test_batchnorm2d(training=False, threed=True, track_running_stats=False)
def test_batchnorm3d_training_no_running_stats(self):
  """Run the BatchNorm comparison in training mode, 3-D variant, without running stats."""
  self.test_batchnorm2d(training=True, threed=True, track_running_stats=False)
def test_batchnorm_axis(self):
  """`Tensor.batchnorm` over axis=(0, 2) equals the default-axis call on permuted/reshaped data.

  The (2, 3)-shaped parameters are flattened to 6 entries for the second call.
  """
  shape = (2, 4, 3, 2, 2)
  x = Tensor.randn(shape)
  weight = Tensor.randn(2, 3)
  bias = Tensor.randn(2, 3)
  mean = Tensor.randn(2, 3)
  invstd = Tensor.randn(2, 3)
  params = (weight, bias, mean, invstd)

  def fold(t):
    # swap the first two axes, then reshape the result to (4, 6, 2, 2)
    return t.permute(1, 0, 2, 3, 4).reshape(4, 6, 2, 2)

  multi_axis = fold(x.batchnorm(*params, axis=(0, 2)))
  single_axis = fold(x).batchnorm(*[p.flatten() for p in params])

  t_x = torch.tensor(fold(x).numpy())
  t_weight, t_bias, t_mean, t_invstd = [torch.tensor(p.flatten().numpy()) for p in params]
  # NOTE(review): the torch result is computed but never compared with anything -- confirm intent
  torch.nn.functional.batch_norm(t_x, t_mean, 1.0 / t_invstd**2, t_weight, t_bias)

  np.testing.assert_allclose(multi_axis.numpy(), single_axis.numpy())
def test_linear(self):
  """Linear forward output matches torch.nn.Linear for 2-D and 3-D inputs."""
  def check(x, in_dim, out_dim):
    # tinygrad layer and its output
    model = Linear(in_dim, out_dim)
    got = model(x)

    # torch layer carrying the same weight and bias
    with torch.no_grad():
      ref = torch.nn.Linear(in_dim, out_dim).eval()
      ref.weight[:] = torch.tensor(model.weight.numpy(), dtype=torch.float32)
      ref.bias[:] = torch.tensor(model.bias.numpy(), dtype=torch.float32)
      want = ref(torch.tensor(x.numpy(), dtype=torch.float32))

    np.testing.assert_allclose(got.numpy(), want.detach().numpy(), atol=5e-4, rtol=1e-5)

  BS, T, in_dim, out_dim = 4, 2, 8, 16
  # the second shape tests with more dims
  for shape in ((BS, in_dim), (BS, T, in_dim)):
    check(Tensor.randn(*shape), in_dim, out_dim)
def test_conv1d(self):
  """Conv1d forward output matches torch.nn.Conv1d when both share weight and bias."""
  batch, c_in, width = 4, 16, 224//4
  c_out, ksize, stride, pad = 64, 7, 2, 1
  conv_kwargs = dict(kernel_size=ksize, stride=stride, padding=pad)

  # tinygrad layer first, then a torch layer that adopts its parameters
  layer = Conv1d(c_in, c_out, **conv_kwargs)
  with torch.no_grad():
    ref = torch.nn.Conv1d(c_in, c_out, **conv_kwargs).eval()
    ref.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
    ref.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)

  # run the same uniform input through both
  x = Tensor.uniform(batch, c_in, width)
  got = layer(x)
  want = ref(torch.tensor(x.numpy()))
  np.testing.assert_allclose(got.numpy(), want.detach().numpy(), atol=5e-4, rtol=1e-5)
def test_conv2d(self):
  """Conv2d forward output matches torch.nn.Conv2d when both share weight and bias."""
  batch, c_in, height, width = 4, 16, 224//4, 224//4
  c_out, ksize, stride, pad = 64, 7, 2, 1
  conv_kwargs = dict(kernel_size=ksize, stride=stride, padding=pad)

  # tinygrad layer first, then a torch layer that adopts its parameters
  layer = Conv2d(c_in, c_out, **conv_kwargs)
  with torch.no_grad():
    ref = torch.nn.Conv2d(c_in, c_out, **conv_kwargs).eval()
    ref.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
    ref.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)

  # run the same uniform input through both
  x = Tensor.uniform(batch, c_in, height, width)
  got = layer(x)
  want = ref(torch.tensor(x.numpy()))
  np.testing.assert_allclose(got.numpy(), want.detach().numpy(), atol=5e-4, rtol=1e-5)
def test_conv1d_same_padding(self):
  """Conv1d with padding='same' matches torch.nn.Conv1d with the same setting."""
  batch, c_in, width = 8, 3, 32
  c_out, ksize, stride, pad = 16, 3, 1, 'same'
  conv_kwargs = dict(kernel_size=ksize, stride=stride, padding=pad)

  # tinygrad layer first, then a torch layer that adopts its parameters
  layer = Conv1d(c_in, c_out, **conv_kwargs)
  with torch.no_grad():
    ref = torch.nn.Conv1d(c_in, c_out, **conv_kwargs).eval()
    ref.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
    ref.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)

  # run the same uniform input through both
  x = Tensor.uniform(batch, c_in, width)
  got = layer(x)
  want = ref(torch.tensor(x.numpy()))
  np.testing.assert_allclose(got.numpy(), want.detach().numpy(), atol=5e-4, rtol=1e-5)
def _run_conv2d_same_padding_test(self, BS, C1, C2, H, W, K, S, padding='same', D=1):
  """Shared driver: build matching tinygrad/torch Conv2d layers and compare their outputs.

  Args:
    BS, C1, H, W: shape of the uniform random input, passed as `Tensor.uniform(BS, C1, H, W)`.
    C2: second positional argument of both Conv2d constructors (after C1).
    K, S, padding, D: passed as kernel_size, stride, padding and dilation to both layers.
  """
  conv_kwargs = dict(kernel_size=K, stride=S, padding=padding, dilation=D)

  # tinygrad layer first, then a torch layer that adopts its parameters
  layer = Conv2d(C1, C2, **conv_kwargs)
  with torch.no_grad():
    ref = torch.nn.Conv2d(C1, C2, **conv_kwargs).eval()
    ref.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
    ref.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)

  # run the same uniform input through both
  x = Tensor.uniform(BS, C1, H, W)
  got = layer(x)
  want = ref(torch.tensor(x.numpy()))
  np.testing.assert_allclose(got.numpy(), want.detach().numpy(), atol=5e-4, rtol=1e-5)
def test_conv2d_same_padding_odd_input(self):
  """'same' padding on a 29x31 input with a 5x5 kernel."""
  self._run_conv2d_same_padding_test(BS=16, C1=16, C2=32, H=29, W=31, K=5, S=1, padding='same')
def test_conv2d_same_padding_large_kernel(self):
  """'same' padding on a 28x33 input with a 9x9 kernel."""
  self._run_conv2d_same_padding_test(BS=16, C1=16, C2=32, H=28, W=33, K=9, S=1, padding='same')
def test_conv2d_same_padding_with_dilation(self):
  """'same' padding on a 28x28 input with a 3x3 kernel and dilation 3."""
  self._run_conv2d_same_padding_test(BS=16, C1=3, C2=32, H=28, W=28, K=3, S=1, padding='same', D=3)
def test_conv2d_same_padding_invalid_stride(self):
  """Constructing Conv2d with padding='same' and stride 2 must raise ValueError."""
  with self.assertRaises(ValueError):
    Conv2d(16, 32, kernel_size=2, stride=2, padding='same')
def test_conv2d_same_padding_invalid_padding_str(self):
  """Constructing Conv2d with the padding string 'not_same' must raise ValueError."""
  with self.assertRaises(ValueError):
    Conv2d(16, 32, kernel_size=2, stride=1, padding='not_same')
@unittest.skip("Takes too long to compile for Compiled backends")
def test_conv2d_winograd(self):
  """Conv2d called under `Context(WINO=1)`: output and gradients match torch.nn.Conv2d."""
  batch, c_in, height, width = 2, 8, 16, 16
  c_out, ksize, stride, pad = 8, 3, 1, 1
  conv_kwargs = dict(kernel_size=ksize, stride=stride, padding=pad)

  # tinygrad layer with gradients enabled on both parameters
  layer = Conv2d(c_in, c_out, **conv_kwargs)
  for param in (layer.weight, layer.bias):
    param.requires_grad = True

  # torch layer whose parameters are replaced by copies of the tinygrad ones
  ref = torch.nn.Conv2d(c_in, c_out, **conv_kwargs).eval()
  ref.weight = torch.nn.Parameter(torch.tensor(layer.weight.numpy(), dtype=torch.float32))
  ref.bias = torch.nn.Parameter(torch.tensor(layer.bias.numpy(), dtype=torch.float32))

  # forward: only the tinygrad call happens inside the WINO context
  x = Tensor.uniform(batch, c_in, height, width, requires_grad=True)
  with Context(WINO=1):
    out = layer(x)

  ref_x = torch.tensor(x.numpy(), requires_grad=True)
  ref_out = ref(ref_x)
  np.testing.assert_allclose(out.numpy(), ref_out.detach().numpy(), atol=5e-4, rtol=1e-5)

  # backward through the mean of each output, then compare all three gradients
  out.mean().backward()
  grad_w = layer.weight.grad.realize()
  grad_b = layer.bias.grad.realize()
  grad_x = x.grad.realize()

  ref_out.mean().backward()
  for got, want in ((grad_w, ref.weight.grad), (grad_b, ref.bias.grad), (grad_x, ref_x.grad)):
    np.testing.assert_allclose(got.numpy(), want.numpy(), atol=5e-4, rtol=1e-5)
def test_conv_transpose1d(self):
  """ConvTranspose1d forward output matches torch.nn.ConvTranspose1d with shared parameters."""
  batch, c_in, width = 4, 16, 224//4
  c_out, ksize, stride, pad = 64, 7, 2, 1
  conv_kwargs = dict(kernel_size=ksize, stride=stride, padding=pad)

  # tinygrad layer first, then a torch layer that adopts its parameters
  layer = ConvTranspose1d(c_in, c_out, **conv_kwargs)
  with torch.no_grad():
    ref = torch.nn.ConvTranspose1d(c_in, c_out, **conv_kwargs).eval()
    ref.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
    ref.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)

  # run the same uniform input through both
  x = Tensor.uniform(batch, c_in, width)
  got = layer(x)
  want = ref(torch.tensor(x.numpy()))
  np.testing.assert_allclose(got.numpy(), want.detach().numpy(), atol=5e-4, rtol=1e-5)
def test_conv_transpose2d(self):
  """ConvTranspose2d forward output matches torch.nn.ConvTranspose2d with shared parameters."""
  batch, c_in, height, width = 4, 16, 224//4, 224//4
  c_out, ksize, stride, pad = 64, 7, 2, 1
  conv_kwargs = dict(kernel_size=ksize, stride=stride, padding=pad)

  # tinygrad layer first, then a torch layer that adopts its parameters
  layer = ConvTranspose2d(c_in, c_out, **conv_kwargs)
  with torch.no_grad():
    ref = torch.nn.ConvTranspose2d(c_in, c_out, **conv_kwargs).eval()
    ref.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
    ref.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)

  # run the same uniform input through both
  x = Tensor.uniform(batch, c_in, height, width)
  got = layer(x)
  want = ref(torch.tensor(x.numpy()))
  np.testing.assert_allclose(got.numpy(), want.detach().numpy(), atol=5e-4, rtol=1e-5)
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers")
def test_groupnorm(self):
  """GroupNorm output and input/weight/bias gradients match torch.nn.GroupNorm over 10 random inputs."""
  batch, height, width, channels, groups = 20, 10, 10, 6, 3

  # torch reference layer; the tinygrad layer adopts its weight and bias
  ref = torch.nn.GroupNorm(groups, channels).eval()
  layer = GroupNorm(groups, channels)
  layer.weight = Tensor(ref.weight.detach().numpy(), requires_grad=True)
  layer.bias = Tensor(ref.bias.detach().numpy(), requires_grad=True)

  def close(got, want, atol, rtol):
    np.testing.assert_allclose(got.numpy(), want.detach().numpy(), atol=atol, rtol=rtol)

  for _ in range(10):
    # tinygrad forward + backward
    x = Tensor.randn(batch, channels, height, width, requires_grad=True)
    out = layer(x)
    out.sum().backward()

    # torch forward + backward on the same input
    ref_x = torch.tensor(x.numpy(), requires_grad=True)
    ref_out = ref(ref_x)
    ref_out.sum().backward()

    close(out, ref_out, 5e-6, 5e-6)
    close(x.grad, ref_x.grad, 5e-4, 5e-4)
    close(layer.weight.grad, ref.weight.grad, 5e-4, 5e-4)
    close(layer.bias.grad, ref.bias.grad, 5e-4, 5e-4)
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers")
def test_layernorm(self):
  """LayerNorm([H, W]) output and gradients match torch.nn.LayerNorm over 10 random inputs."""
  batch, channels, height, width = 20, 5, 10, 10

  # torch reference layer; the tinygrad layer adopts its weight and bias
  ref = torch.nn.LayerNorm([height, width]).eval()
  layer = LayerNorm([height, width])
  layer.weight = Tensor(ref.weight.detach().numpy(), requires_grad=True)
  layer.bias = Tensor(ref.bias.detach().numpy(), requires_grad=True)

  def close(got, want, atol, rtol):
    np.testing.assert_allclose(got.numpy(), want.detach().numpy(), atol=atol, rtol=rtol)

  for _ in range(10):
    # tinygrad forward + backward
    x = Tensor.randn(batch, channels, height, width, requires_grad=True)
    out = layer(x)
    out.sum().backward()

    # torch forward + backward on the same input
    ref_x = torch.tensor(x.numpy(), requires_grad=True)
    ref_out = ref(ref_x)
    ref_out.sum().backward()

    close(out, ref_out, 5e-6, 5e-6)
    close(x.grad, ref_x.grad, 5e-4, 5e-4)
    close(layer.weight.grad, ref.weight.grad, 5e-4, 5e-4)
    close(layer.bias.grad, ref.bias.grad, 5e-4, 5e-4)
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers")
def test_layernorm_2d(self):
  """LayerNorm2d(C) matches torch.nn.LayerNorm([C]) applied to the input permuted to (0, 2, 3, 1)."""
  batch, channels, height, width = 20, 5, 10, 10

  # torch reference layer; the tinygrad layer adopts its weight and bias
  ref = torch.nn.LayerNorm([channels]).eval()
  layer = LayerNorm2d(channels)
  layer.weight = Tensor(ref.weight.detach().numpy(), requires_grad=True)
  layer.bias = Tensor(ref.bias.detach().numpy(), requires_grad=True)

  def close(got, want, atol, rtol):
    np.testing.assert_allclose(got.numpy(), want.detach().numpy(), atol=atol, rtol=rtol)

  for _ in range(10):
    # tinygrad forward + backward
    x = Tensor.randn(batch, channels, height, width, requires_grad=True)
    out = layer(x)
    out.sum().backward()

    # torch: permute to (0, 2, 3, 1) for the layer, then permute the result back
    ref_x = torch.tensor(x.numpy(), requires_grad=True)
    ref_out = ref(ref_x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    ref_out.sum().backward()

    close(out, ref_out, 5e-6, 5e-6)
    close(x.grad, ref_x.grad, 5e-4, 5e-4)
    close(layer.weight.grad, ref.weight.grad, 5e-4, 5e-4)
    close(layer.bias.grad, ref.bias.grad, 5e-4, 5e-4)
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers")
def test_instancenorm_2d(self):
  """InstanceNorm on 4-D input matches torch.nn.InstanceNorm2d(affine=True) over 10 random inputs."""
  batch, channels, height, width = 20, 10, 10, 10

  # torch reference layer; the tinygrad layer adopts its weight and bias
  ref = torch.nn.InstanceNorm2d(channels, affine=True).eval()
  layer = InstanceNorm(channels)
  layer.weight = Tensor(ref.weight.detach().numpy(), requires_grad=True)
  layer.bias = Tensor(ref.bias.detach().numpy(), requires_grad=True)

  def close(got, want, atol, rtol):
    np.testing.assert_allclose(got.numpy(), want.detach().numpy(), atol=atol, rtol=rtol)

  for _ in range(10):
    # tinygrad forward + backward
    x = Tensor.randn(batch, channels, height, width, requires_grad=True)
    out = layer(x)
    out.sum().backward()

    # torch forward + backward on the same input
    ref_x = torch.tensor(x.numpy(), requires_grad=True)
    ref_out = ref(ref_x)
    ref_out.sum().backward()

    close(out, ref_out, 5e-6, 5e-6)
    close(x.grad, ref_x.grad, 1e-3, 1e-3)
    close(layer.weight.grad, ref.weight.grad, 1e-3, 1e-3)
    close(layer.bias.grad, ref.bias.grad, 1e-3, 1e-3)
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers")
def test_instancenorm_3d(self):
  """InstanceNorm on 5-D input matches torch.nn.InstanceNorm3d(affine=True) over 10 random inputs."""
  batch, channels, depth, height, width = 20, 10, 10, 10, 10

  # torch reference layer; the tinygrad layer adopts its weight and bias
  ref = torch.nn.InstanceNorm3d(channels, affine=True).eval()
  layer = InstanceNorm(channels)
  layer.weight = Tensor(ref.weight.detach().numpy(), requires_grad=True)
  layer.bias = Tensor(ref.bias.detach().numpy(), requires_grad=True)

  def close(got, want, atol, rtol):
    np.testing.assert_allclose(got.numpy(), want.detach().numpy(), atol=atol, rtol=rtol)

  for _ in range(10):
    # tinygrad forward + backward
    x = Tensor.randn(batch, channels, depth, height, width, requires_grad=True)
    out = layer(x)
    out.sum().backward()

    # torch forward + backward on the same input
    ref_x = torch.tensor(x.numpy(), requires_grad=True)
    ref_out = ref(ref_x)
    ref_out.sum().backward()

    close(out, ref_out, 5e-6, 5e-6)
    close(x.grad, ref_x.grad, 1e-3, 1e-3)
    # the weight gradient uses a looser absolute tolerance than the other gradients
    close(layer.weight.grad, ref.weight.grad, 2e-3, 1e-3)
    close(layer.bias.grad, ref.bias.grad, 1e-3, 1e-3)
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers")
def test_rmsnorm(self):
  """RMSNorm output and input/weight gradients match a torch reference module over 10 random inputs."""
  class TorchRMSNorm(torch.nn.Module):
    # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L34C1-L77C36
    def __init__(self, dim: int, eps: float = 1e-6):
      super().__init__()
      self.eps = eps
      self.weight = torch.nn.Parameter(torch.ones(dim))

    def _norm(self, x):
      # scale by the reciprocal root of the mean square over the last axis
      mean_sq = x.pow(2).mean(-1, keepdim=True)
      return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x):
      # normalise in float, cast back to the input type, then apply the weight
      normed = self._norm(x.float()).type_as(x)
      return normed * self.weight

  B, T, embed_size = 4, 10, 20
  ref = TorchRMSNorm(embed_size)
  layer = RMSNorm(embed_size)
  layer.weight.requires_grad = True

  def close(got, want, atol, rtol):
    np.testing.assert_allclose(got.numpy(), want.detach().numpy(), atol=atol, rtol=rtol)

  for _ in range(10):
    # tinygrad forward + backward
    x = Tensor.randn(B, T, embed_size, requires_grad=True)
    out = layer(x)
    out.sum().backward()

    # torch forward + backward on the same input
    ref_x = torch.tensor(x.numpy(), requires_grad=True)
    ref_out = ref(ref_x)
    ref_out.sum().backward()

    close(out, ref_out, 5e-6, 5e-6)
    close(x.grad, ref_x.grad, 1e-3, 1e-3)
    close(layer.weight.grad, ref.weight.grad, 2e-3, 1e-3)
def test_embedding(self):
  """Embedding lookups match torch.nn.Embedding for regular, zero-length and TinyJit-wrapped calls."""
  B, T, embed_size, vocab_size = 4, 10, 20, 28

  # tinygrad layer, then a torch layer that adopts its weight
  layer = Embedding(vocab_size, embed_size)
  with torch.no_grad():
    ref = torch.nn.Embedding(vocab_size, embed_size).eval()
    ref.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)

  def random_ids(seq_len):
    return Tensor(np.random.randint(0, vocab_size, (B, seq_len), dtype=np.int32))

  def check(embed, ids):
    got = embed(ids)
    want = ref(torch.tensor(ids.numpy()))
    np.testing.assert_allclose(got.numpy(), want.detach().numpy(), atol=1e-8, rtol=1e-8)

  # test
  check(layer, random_ids(T))

  # test with empty input length
  check(layer, random_ids(0))

  # test with jit enabled
  def run_layer(ids):
    return layer(ids).realize()
  layer_jit = TinyJit(run_layer)

  for _ in range(3):
    check(layer_jit, random_ids(T))
def test_embedding_one_kernel(self):
  """Count SINK items in the schedule of two consecutive Embedding calls (3 expected, then 1)."""
  layer = Embedding(20, 30)
  layer.weight = Tensor.zeros_like(layer.weight).contiguous()

  def sink_count(sched):
    return sum(1 for item in sched if item.ast.op is Ops.SINK)

  first_ids = Tensor([[1, 5, 9, 11],
                      [12, 19, 8, 1]])
  first = layer(first_ids).schedule()
  self.assertEqual(3, sink_count(first), "first run realizes arange, weight, and embedding")
  run_schedule(first)

  second_ids = Tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])
  second = layer(second_ids).schedule()
  self.assertEqual(1, sink_count(second), "second run realizes embedding only")
  run_schedule(second)
def test_embedding_shape(self):
  """For index tensors of rank 0..4 with all dims 1, the output shape is the index shape plus (embed_size,)."""
  vocab_size, embed_size = 10, 16
  layer = Embedding(vocab_size, embed_size)
  for index_shape in [(1,) * rank for rank in range(5)]:
    out = layer(Tensor([3]).reshape(index_shape))
    self.assertEqual(out.shape, index_shape + (embed_size,))
def test_load_state_dict(self):
  """After `load_state_dict`, the Conv2d weight and bias equal the tensors from the state dict."""
  layer = Conv2d(3, 5, kernel_size=3)
  state_dict = {'weight': Tensor.randn(5, 3, 3, 3), 'bias': Tensor.randn(5)}
  load_state_dict(layer, state_dict)

  for name in ('weight', 'bias'):
    np.testing.assert_allclose(getattr(layer, name).numpy(), state_dict[name].numpy())
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
def test_load_state_dict_sharded_model(self):
  """Loading realized, unsharded tensors into a sharded Conv2d keeps the layer's devices and axes."""
  devices = tuple(f"{Device.DEFAULT}:{i}" for i in (1, 2, 3))

  layer = Conv2d(3, 5, kernel_size=3)
  layer.weight.shard_(devices, 3)
  layer.bias.shard_(devices, None)
  state_dict = {'weight': Tensor.randn(5, 3, 3, 3).realize(), 'bias': Tensor.randn(5).realize()}
  load_state_dict(layer, state_dict)

  # sharded model shards the state_dict
  for name, axis in (('weight', 3), ('bias', None)):
    param = getattr(layer, name)
    self.assertEqual(param.device, devices)
    self.assertEqual(param.lazydata.axis, axis)
  for name in ('weight', 'bias'):
    np.testing.assert_allclose(getattr(layer, name).numpy(), state_dict[name].numpy())
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
def test_load_state_dict_sharded_dict(self):
  """Loading sharded tensors into an unsharded Conv2d leaves the layer on `Device.DEFAULT`."""
  devices = tuple(f"{Device.DEFAULT}:{i}" for i in (1, 2, 3))

  layer = Conv2d(3, 5, kernel_size=3)
  state_dict = {
    'weight': Tensor.randn(5, 3, 3, 3).shard(devices, 3),
    'bias': Tensor.randn(5).shard(devices, None),
  }
  load_state_dict(layer, state_dict)

  # NOTE: model is not sharded, still not sharded after load_state_dict
  for name in ('weight', 'bias'):
    self.assertEqual(getattr(layer, name).device, Device.DEFAULT)
  for name in ('weight', 'bias'):
    np.testing.assert_allclose(getattr(layer, name).numpy(), state_dict[name].numpy())
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
def test_load_state_dict_sharded_model_dict_same_axis(self):
  """Layer and state dict sharded identically: devices, axes and values are all preserved."""
  devices = tuple(f"{Device.DEFAULT}:{i}" for i in (1, 2, 3))

  layer = Conv2d(3, 5, kernel_size=3)
  layer.weight.shard_(devices, 3)
  layer.bias.shard_(devices, None)
  state_dict = {
    'weight': Tensor.randn(5, 3, 3, 3).shard(devices, 3),
    'bias': Tensor.randn(5).shard(devices, None),
  }
  load_state_dict(layer, state_dict)

  for name, axis in (('weight', 3), ('bias', None)):
    param = getattr(layer, name)
    self.assertEqual(param.device, devices)
    self.assertEqual(param.lazydata.axis, axis)
  for name in ('weight', 'bias'):
    np.testing.assert_allclose(getattr(layer, name).numpy(), state_dict[name].numpy())
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
def test_load_state_dict_sharded_model_dict_different_axis(self):
  """Layer and state dict sharded differently: the loaded layer reports the state dict's sharding."""
  devices = tuple(f"{Device.DEFAULT}:{i}" for i in (1, 2, 3))
  devices5 = tuple(f"{Device.DEFAULT}:{i}" for i in (1, 2, 3, 4, 5))

  layer = Conv2d(3, 5, kernel_size=3)
  layer.weight.shard_(devices, 3)
  layer.bias.shard_(devices, None)

  # different shard axis
  state_dict = {
    'weight': Tensor.randn(5, 3, 3, 3).shard(devices, None),
    'bias': Tensor.randn(5).shard(devices5, 0),
  }
  load_state_dict(layer, state_dict)

  # NOTE: model and state_dict shard differently, use the state_dict sharding # TODO: revisit this?
  for name, want_devices, want_axis in (('weight', devices, None), ('bias', devices5, 0)):
    param = getattr(layer, name)
    self.assertEqual(param.device, want_devices)
    self.assertEqual(param.lazydata.axis, want_axis)
  for name in ('weight', 'bias'):
    np.testing.assert_allclose(getattr(layer, name).numpy(), state_dict[name].numpy())
def test_load_state_dict_shape_mismatch(self):
  """Loading a (4, 4) weight into a bias-free Linear(2, 2) raises ValueError naming both shapes."""
  layer_dim, dict_dim = 2, 4
  layer = Linear(layer_dim, layer_dim, bias=False)
  state_dict = {'weight': Tensor.randn(dict_dim, dict_dim)}
  expected_msg = r'Shape mismatch in layer `weight`: Expected shape \(2, 2\), but found \(4, 4\) in state dict.'
  with self.assertRaisesRegex(ValueError, expected_msg):
    load_state_dict(layer, state_dict)
def test_lstm_cell(self):
  """LSTMCell(32, 16) matches torch.nn.LSTMCell for a first step and a second step fed the first state."""
  layer = LSTMCell(32, 16)

  def check(got, want):
    np.testing.assert_allclose(got.numpy(), want.numpy(), atol=1e-6)

  # everything runs under no_grad: `.numpy()` is called directly on torch parameters and outputs
  with torch.no_grad():
    ref = torch.nn.LSTMCell(32, 16)
    # copy the torch parameters into the tinygrad cell
    for name in ("weight_hh", "weight_ih", "bias_hh", "bias_ih"):
      getattr(layer, name).assign(getattr(ref, name).numpy())

    inp = Tensor.randn(1, 32)

    # first step, no incoming state
    h, c = layer(inp)
    ref_h, ref_c = ref(torch.tensor(inp.numpy()))
    check(h, ref_h)
    check(c, ref_c)

    # second step, feeding back the state from the first
    h, c = layer(inp, (h, c))
    ref_h, ref_c = ref(torch.tensor(inp.numpy()), (ref_h, ref_c))
    check(h, ref_h)
    check(c, ref_c)
def test_lstm_cell_no_bias(self):
  """An LSTMCell built with `bias=False` runs with and without an initial state and has no bias tensors."""
  layer = LSTMCell(32, 16, bias=False)
  inp = Tensor.randn(1, 32)

  # step without an explicit state; realize both outputs so the computation actually runs
  out_h, out_c = layer(inp)
  out_h.realize()
  out_c.realize()

  # step with an explicit random (h, c) state
  h = Tensor.randn(1, 16)
  c = Tensor.randn(1, 16)
  out_h, out_c = layer(inp, (h, c))
  out_h.realize()
  out_c.realize()

  # TestCase assertions instead of bare `assert`: they are not stripped under `python -O`
  # and report the offending value on failure
  self.assertIsNone(layer.bias_hh)
  self.assertIsNone(layer.bias_ih)
if __name__ == "__main__":
  # run the test suite when this file is executed as a script
  unittest.main()