openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

206 lines
8.5 KiB

import unittest
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad import Variable
from tinygrad.tensor import Tensor
class TestSymbolic(unittest.TestCase):
  """Tests of ShapeTracker and View when shapes, strides or masks contain Variables."""

  def assert_tuple_equal(self, x, y):
    """Assert that `x` and `y` have equal length and that no element pair compares unequal.

    Each pair is checked with `assertFalse(a != b)` instead of `assertEqual`.
    NOTE(review): presumably because `!=` on symbolic values yields an expression that must evaluate to False - confirm.
    """
    # zip() stops at the shorter input, so a missing or extra trailing element would otherwise go unnoticed
    self.assertEqual(len(x), len(y))
    for a,b in zip(x,y): self.assertFalse(a != b)

  def test_symbolic_st(self):
    """A tracker built from shape (x, 3) reports that shape and real strides (3, 1)."""
    x = Variable("x", 1, 100)
    st = ShapeTracker.from_shape((x, 3))
    assert st.shape == (x, 3)
    assert st.real_strides() == (3, 1)

  @unittest.expectedFailure
  def test_real_strides_0(self):
    """real_strides of a two-view tracker with a symbolic mask; decorated as an expected failure."""
    start_pos = Variable('start_pos', 1, 8)
    masked = View(shape=(2, start_pos+1, 1, 1), strides=(8, 1, 0, 0), offset=0,
                  mask=((0, 2), (0, start_pos), (0, 1), (0, 1)), contiguous=False)
    merged = View(shape=(2, start_pos+1), strides=(start_pos+1, 1), offset=0, mask=None, contiguous=True)
    st = ShapeTracker(views=(masked, merged))
    self.assertEqual(st.real_strides(), (8, None))

  @unittest.expectedFailure
  def test_real_strides_1(self):
    """real_strides of a single masked view of shape (3, i+2); decorated as an expected failure."""
    i = Variable('i', 1, 10)
    view = View(shape=(3, i+2), strides=(i, 1), offset=0, mask=((0, 3), (0, i)), contiguous=False)
    st = ShapeTracker(views=(view,))
    self.assertEqual(st.real_strides(), (i, None))

  @unittest.expectedFailure
  def test_real_strides_2(self):
    """real_strides of a single masked view of shape (3, i+j); decorated as an expected failure."""
    i = Variable('i', 1, 10)
    j = Variable('j', 1, 10)
    view = View(shape=(3, i+j), strides=(i, 1), offset=0, mask=((0, 3), (0, i)), contiguous=False)
    st = ShapeTracker(views=(view,))
    self.assertEqual(st.real_strides(), (i, None))

  def test_merge_view_recursion_err(self):
    """Adding a (1,)-shaped view onto a symbolic-shaped, zero-stride view yields the (1,)-shaped view."""
    vm2 = View(shape=(Variable('j', 1, 10),), strides=(0,), offset=0, mask=None, contiguous=False)
    vm1 = View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True)
    self.assertEqual(vm2+vm1, vm1)

  def test_merge_view_recursion_err2(self):
    """vm2+vm1 is None for the hand-built masked vm1, while vm2+vm3 equals vm2 for the View.create result."""
    a = Variable('a', 1, 10).bind(4)
    vm2 = View(shape=(a,), strides=(0,), offset=0, mask=None, contiguous=False)
    # NOTE: vm1 is different from what create function would give, and this test vm2+vm1 halts
    vm1 = View(shape=(a,), strides=(1,), offset=0, mask=((0, a),), contiguous=False)
    self.assertEqual(vm2+vm1, None)

    vm3 = View.create(shape=(a,))
    self.assertEqual(vm3.shape, vm1.shape)
    self.assertEqual(vm3.strides, vm1.strides)
    self.assertEqual(vm2+vm3, vm2)

  def test_cat_dim0_strides(self):
    """Concatenating symbolically sliced tensors along dim 0 gives a summed symbolic shape and constant strides."""
    i = Variable("i", 1, 5).bind(3)
    j = Variable("j", 1, 5).bind(3)
    k = Variable("k", 1, 5).bind(3)

    t = Tensor.rand(5, 4)[:i].cat(Tensor.rand(5, 4)[:j], dim=0).cat(Tensor.rand(5, 4)[:k], dim=0)
    st = t.uop.st
    self.assert_tuple_equal(st.shape, (i+j+k, 4))
    assert st.real_strides() == (4, 1)

    # the same variable used twice, followed by a tensor with a constant shape
    t = Tensor.rand(5, 3)[:i].cat(Tensor.rand(5, 3)[:i], dim=0).cat(Tensor.rand(3, 3), dim=0)
    st = t.uop.st
    self.assert_tuple_equal(st.shape, (2*i+3, 3))
    assert st.real_strides() == (3, 1)

  def test_cat_dim1_strides(self):
    """Concatenating symbolically sliced tensors along dim 1 gives a symbolic row stride of i+j+k."""
    i = Variable("i", 1, 5).bind(4)
    j = Variable("j", 1, 5).bind(4)
    k = Variable("k", 1, 5).bind(4)

    t = Tensor.rand(3, 5)[:, :i].cat(Tensor.rand(3, 5)[:, :j], dim=1).cat(Tensor.rand(3, 5)[:, :k], dim=1)
    st = t.uop.st
    self.assert_tuple_equal(st.shape, (3, i+j+k))
    self.assert_tuple_equal(st.real_strides(), (i+j+k, 1))
class TestSymbolicVarVals(unittest.TestCase):
  """Tests of the `var_vals` mapping a ShapeTracker reports for bound Variables."""

  def assert_equal(self, x, y):
    """Assert that `x != y` evaluates falsy."""
    differs = x != y
    self.assertFalse(differs)

  def test_var_vals_empty(self):
    """A tracker with a purely constant shape has an empty var_vals."""
    st = ShapeTracker.from_shape((3, 4, 5))
    assert st.var_vals == {}

  def test_var_vals_shape(self):
    """A bound Variable used in the shape appears in var_vals keyed by the unbound Variable."""
    bound_x = Variable("x", 1, 100).bind(3)
    st = ShapeTracker.from_shape((bound_x, 3))
    assert st.var_vals == {Variable("x", 1, 100): 3}

  def test_var_vals_offset(self):
    """A bound Variable used as a shrink start ends up in the view offset and in var_vals."""
    bound_x = Variable("x", 1, 100).bind(3)
    base = ShapeTracker.from_shape((4, 3))
    st = base.shrink(((bound_x, bound_x+1), (0, 3)))
    last_view = st.views[-1]
    self.assert_equal(last_view.offset, bound_x * 3)
    assert st.var_vals == {Variable("x", 1, 100): 3}

  def test_var_vals_mask(self):
    """A bound Variable used only in a view mask is reported in var_vals."""
    bound_x = Variable("x", 1, 100).bind(3)
    masked_view = View.create(shape=(3,4), strides=(4,1), offset=0, mask=((0, bound_x), (0, 4)))
    st = ShapeTracker(views=(masked_view,))
    assert st.var_vals == {Variable("x", 1, 100): 3}

  def test_var_vals_complex(self):
    """Three bound Variables spread over shape and shrink bounds are all reported in var_vals."""
    bound_x = Variable("x", 1, 100).bind(3)
    bound_y = Variable("y", 1, 100).bind(4)
    bound_z = Variable("z", 1, 100).bind(5)
    base = ShapeTracker.from_shape((bound_x, 5, bound_y))
    st = base.shrink(((0, bound_x), (bound_z, bound_z+1), (0, 3)))
    self.assert_equal(st.views[-1].offset, bound_y * bound_z)
    expected = {Variable("x", 1, 100): 3, Variable("y", 1, 100):4, Variable("z", 1, 100): 5}
    assert st.var_vals == expected

  def test_shrink_reshape(self):
    """var_vals is still reported after a symbolic shrink is followed by a reshape to 1D."""
    bound_x = Variable("x", 1, 100).bind(3)
    shrunk = ShapeTracker.from_shape((10, 10, 10)).shrink(((bound_x, bound_x+3), (3, 7), (2, 5)))
    flat = shrunk.reshape((3*4*3,))
    assert flat.var_vals == {Variable("x", 1, 100): 3}
class TestShapeTrackerUnbind(unittest.TestCase):
  """Tests that `unbind()` returns an object using the unbound Variable plus a {Variable: value} mapping."""

  def test_view_unbind(self):
    """Unbinding a View with a bound shape gives the unbound View and the bound value."""
    v = Variable("v", 1, 100)
    bv = Variable("v", 1, 100).bind(3)
    bound_view = View.create(shape=(bv, 4))
    unbound_view, var_val = bound_view.unbind()
    assert unbound_view == View.create(shape=(v, 4))
    assert var_val == {v: 3}

  def test_shrink_unbind(self):
    """Unbinding works when the bound Variable is a shrink end and when it is a shrink start."""
    v = Variable("v", 1, 100)
    bv = Variable("v", 1, 100).bind(2)

    # bound Variable as the end of the shrink range: it shows up in the shape
    shrunk = Tensor.rand(3, 4).shrink(((0,bv),(0,4)))
    unbound_st, var_val = shrunk.uop.st.unbind()
    expected_st = ShapeTracker((View.create(shape=(v, 4)),))
    assert unbound_st == expected_st
    assert var_val == {v: 2}

    # bound Variable as the start of the shrink range: it shows up in the offset
    shrunk = Tensor.rand(3, 4).shrink(((bv, bv+1), (0, 4)))
    unbound_st, var_val = shrunk.uop.st.unbind()
    expected_st = ShapeTracker((View.create(shape=(1, 4), offset=4*v),))
    assert unbound_st == expected_st
    assert var_val == {v: 2}
class TestSymbolicReshape(unittest.TestCase):
  """Tests of reshape on tensors and views whose shapes contain bound Variables."""

  def test_reshape(self):
    """Reshaping a symbolically sliced tensor keeps the symbolic leading dimension."""
    a = Tensor.rand(5, 4)
    b = Tensor.rand(5, 6)
    for n in range(1, 6):
      vi = Variable("i", 1, 5).bind(n)

      same_rank = a[:vi].reshape((vi, 4))
      assert same_rank.shape == (vi, 4)

      split_last = b[:vi].reshape((vi, 2, 3))
      assert split_last.shape == (vi, 2, 3)

  def test_two_symbol_reshape(self):
    """Reshapes between (vi, vj), (vj, vi) and (1, vi*vj) report the requested symbolic shape."""
    t = Tensor.rand(5, 5)
    for n in range(1, 6):
      for m in range(1, 6):
        vi = Variable("i", 1, 5).bind(n)
        vj = Variable("j", 1, 5).bind(m)
        sliced = t[:vi, :vj]

        swapped = sliced.reshape(vj, vi)
        assert swapped.shape == (vj, vi)

        restored = swapped.reshape(vi, vj)
        assert restored.shape == (vi, vj)

        flat = restored.reshape(1, vi*vj)
        assert flat.shape == (1, vi*vj)

  def test_symbolic_mask(self):
    """View.reshape returns None for two masked symbolic views."""
    # taken from gpt2 single kvcache
    # these two caused problems in gpt2 if reshape merged views
    bound_pos = Variable('start_pos', 1, 128).bind(2)
    view = View(shape=(1, bound_pos+1, 16, 64), strides=(0, 0, 64, 1), offset=1024,
                mask=((0, 1), (bound_pos, bound_pos+1), (0, 16), (0, 64)), contiguous=False)
    new_shape = (1, 1, bound_pos+1, 16, 64)
    assert view.reshape(new_shape) is None

    # second case uses the unbound Variable
    pos = Variable('start_pos', 1, 128)
    view = View(shape=(2, 1, pos+1, 16, 64), strides=(0, 0, 1024, 64, 1), offset=131072,
                mask=((1, 2), (0, 1), (0, pos+1), (0, 16), (0, 64)), contiguous=False)
    new_shape = (2, pos+1, 16, 64)
    assert view.reshape(new_shape) is None
class TestSymbolicExpand(unittest.TestCase):
  """Tests of expand (and broadcasting) into dimensions given by bound Variables."""

  def test_expand_into_symbols(self):
    """A size-1 dimension can be expanded to a bound Variable, once and then a second time."""
    vi = Variable("i", 1, 5).bind(3)
    vj = Variable("j", 1, 5).bind(3)

    column = Tensor([[1], [2], [3]])
    a = column.expand((3, vi))
    assert a.shape == (3, vi)

    a = a.reshape(3, vi, 1)
    a = a.expand((3, vi, vj))
    assert a.shape == (3, vi, vj)

  def test_plus_expands_constant(self):
    """Adding a Python constant to a symbolically sliced tensor keeps the symbolic shape."""
    a = Tensor.rand(3, 5)
    for n in range(1, 6):
      vi = Variable("i", 1, 5).bind(n)
      shifted = a[:, :vi] + 1
      self.assertTupleEqual(shifted.shape, (3, vi))

  def test_pad_then_expand_into_symbols(self):
    """A padded tensor expanded to (vi, 25) can be reshaped to 1D with either operand order of the product."""
    vi = Variable("i", 1, 10).bind(3)
    padded = Tensor(1).unsqueeze(0).pad((0, 24))
    a = padded.unsqueeze(0).expand((vi, 25))
    self.assertEqual(a.shape, (vi, 25))

    flat_shape = (vi*25,)
    self.assertEqual(a.reshape(25*vi).shape, flat_shape)
    self.assertEqual(a.reshape(vi*25).shape, flat_shape)
class TestSymbolicShrink(unittest.TestCase):
  """Tests of Tensor.shrink with unbound Variables in the shrink ranges."""

  def test_shrink_symbols_simple(self):
    """Shrinking the last dimension to (0, vi) gives a symbolic last dimension."""
    vi = Variable("i", 1, 5)
    source = Tensor.rand(5, 5)
    t = source.shrink(((0, 5),(0,vi)))
    assert t.shape == (5, vi)

  def test_shrink_symbols(self):
    """Shrinking to the symbolic range (vi, vi+1) gives a constant dimension of 1."""
    vi = Variable("i", 1, 5)
    source = Tensor.rand(3, 5)
    t = source.shrink(((0, 2), (vi, vi+1)))
    assert t.shape == (2, 1)
class TestSymbolicPad(unittest.TestCase):
  """Tests of padding a tensor that was sliced with a bound Variable."""

  def test_pad(self):
    """Slicing ones(100) to v=5 elements and left-padding by 4 gives four zeros followed by five ones."""
    v = Variable("v", 1, 100).bind(5)
    sliced = Tensor.ones(100)[:v]
    padded = sliced.pad(((4, 0),))
    # reshape to the concrete length (4 pad + 5 kept) before reading the values back
    flat = padded.reshape(9)
    expected = [0]*4 + [1]*5
    assert flat.tolist() == expected
if __name__ == "__main__":
  # run the test cases of this module when the file is executed directly
  unittest.main()