openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

162 lines
4.3 KiB

import unittest
from tinygrad import Tensor
from tinygrad.helpers import RANGEIFY
N = 256  # side length of the square matrices used by TestRangeify

@unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY")
class TestRangeify(unittest.TestCase):
  """Smoke tests for the RANGEIFY path (skipped unless RANGEIFY >= 1).

  Each test builds a lazy graph on uninitialized ``Tensor.empty`` inputs and
  calls ``realize()``. No values are checked, so a test fails only if building
  or realizing the graph raises.

  The ``*_half_contig`` variants pass ``arg=(1,)`` to ``contiguous()`` on the
  intermediate; the ``*_contig`` variants use a plain ``contiguous()``.
  """

  @staticmethod
  def _square(count):
    # `count` fresh, uninitialized N x N tensors, in creation order
    return [Tensor.empty(N, N) for _ in range(count)]

  def test_expand_children(self):
    # one reduced-then-expanded tensor feeds two consumers reducing along different axes
    row_sums = Tensor.empty(N, N).sum(axis=1)
    grid = row_sums.expand(N, N)
    left, right = (grid+1).sum(axis=1), (grid+2).sum(axis=0)
    (left + right).realize()

  def test_double_gemm(self):
    a, b, c = self._square(3)
    (a@b@c).realize()

  def test_double_gemm_exp(self):
    a, b, c = self._square(3)
    hidden = (a@b).exp()
    (hidden@c).exp().realize()

  def test_double_gemm_relu(self):
    a, b, c = self._square(3)
    hidden = (a@b).relu()
    (hidden@c).relu().realize()

  def test_double_gemm_relu_half_contig(self):
    a, b, c = self._square(3)
    hidden = (a@b).relu().contiguous(arg=(1,))
    (hidden@c).relu().realize()

  def test_double_gemm_half_contig(self):
    a, b, c = self._square(3)
    hidden = (a@b).contiguous(arg=(1,))
    (hidden@c).realize()

  def test_double_gemm_contig(self):
    a, b, c = self._square(3)
    hidden = (a@b).contiguous()
    (hidden@c).realize()

  def test_many_gemm(self):
    # left-associated chain of five matmuls over six matrices: ((((A@B)@C)@D)@E)@F
    first, *rest = self._square(6)
    chain = first
    for mat in rest:
      chain = chain@mat
    chain.realize()

  def test_conv2d(self):
    img, k1 = Tensor.empty(1, 4, 32, 32), Tensor.empty(8, 4, 3, 3)
    img.conv2d(k1).realize()

  def test_conv2d_t(self):
    img, k1 = Tensor.empty(1, 4, 32, 32), Tensor.empty(8, 4, 3, 3)
    # an elementwise op in front of the conv input
    scaled = img*2
    scaled.conv2d(k1).realize()

  def test_double_conv2d(self):
    img = Tensor.empty(1, 4, 32, 32)
    k1, k2 = Tensor.empty(8, 4, 3, 3), Tensor.empty(12, 8, 3, 3)
    img.conv2d(k1).conv2d(k2).realize()

  def test_double_conv2d_half_contig(self):
    img = Tensor.empty(1, 4, 32, 32)
    k1, k2 = Tensor.empty(8, 4, 3, 3), Tensor.empty(12, 8, 3, 3)
    # NOTE: this contiguous doesn't help
    img.conv2d(k1).contiguous(arg=(1,)).conv2d(k2).permute(0,2,3,1).contiguous().realize()

  def test_double_conv2d_contig(self):
    img = Tensor.empty(1, 4, 32, 32)
    k1, k2 = Tensor.empty(8, 4, 3, 3), Tensor.empty(12, 8, 3, 3)
    img.conv2d(k1).contiguous().conv2d(k2).realize()

  def test_transformer_ffn(self):
    from tinygrad.apps.llm import TransformerBlock
    from tinygrad import nn
    block = TransformerBlock(1024, 4096, 1, 1, 1e-5)
    # swap every parameter for an uninitialized tensor of the same shape
    for param in nn.state.get_parameters(block):
      param.replace(Tensor.empty(param.shape))
    hidden = Tensor.empty(128, 1024)
    block._feed_forward(hidden).realize()

  def test_flash_attention(self):
    shape = (4, 2, 16, 8)  # (BS, HEADS, MATDIM, EMB)
    q, k, v = (Tensor.empty(*shape) for _ in range(3))
    q.scaled_dot_product_attention(k, v).realize()
from tinygrad import dtypes
from tinygrad.uop.ops import UOp
# contiguous + reduce can support ranges?
@unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY")
class TestOuterworld(unittest.TestCase):
  """Tests that index tensors with an explicit ``UOp.range`` and close that range
  with ``contiguous(range)``. Skipped unless RANGEIFY >= 1."""

  def test_passthrough_range(self):
    t = Tensor.rand(10, 10).realize()
    # passthrough ranges: select row `a` for each value of the 10-element range and
    # collect the rows back; the result is expected to equal t exactly (pure copy)
    a = UOp.range(dtypes.int, 10, -1)
    sel = t[a]
    cpy = sel.contiguous(a).realize()
    self.assertTrue((t==cpy).all().item())

  def test_flip_range(self):
    t = Tensor.rand(10, 10).realize()
    # indexing with 9-a walks the rows in reverse, so the copy should equal t.flip(0)
    a = UOp.range(dtypes.int, 10, -1)
    sel = t[9-a]
    cpy = sel.contiguous(a).realize()
    self.assertTrue((t.flip(0)==cpy).all().item())

  def test_vmap(self):
    def f(x): return x.sum(axis=0)*2
    x = Tensor.ones(3, 10, 2).contiguous()
    # vmap across axis 0
    a = UOp.range(dtypes.int, 3, -1)
    out = f(x[a])
    out = out.contiguous(a)
    out.realize()
    # 3x2 grid of 20: each (10, 2) slice of ones sums to 10 along axis 0, then doubles.
    # These values are exact in floating point, so an exact comparison is safe; this also
    # checks the shape that the range axis produces.
    self.assertEqual(out.numpy().tolist(), [[20.0, 20.0]] * 3)

  def test_triple_gemm(self):
    import numpy as np
    x = Tensor.rand(1, 16).realize()
    W = Tensor.rand(3, 16, 16).realize()
    manual = (x @ W[0] @ W[1] @ W[2]).contiguous().realize()
    # same chain expressed as a loop over a range: x <- x @ W[a], keep the last step
    a = UOp.range(dtypes.int, 3, -1)
    x = x.assign(x @ W[a])
    out = x.contiguous(a)[-1].contiguous().realize()
    # the two graphs are scheduled differently and may accumulate in a different order,
    # so compare with a float tolerance instead of bit-exact equality
    np.testing.assert_allclose(out.numpy(), manual.numpy(), rtol=1e-5, atol=1e-6)
if __name__ == '__main__': unittest.main()  # run every TestCase in this module when executed directly