openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

168 lines
5.5 KiB

import unittest
from tinygrad import Tensor, TinyJit, Variable, dtypes
import numpy as np
class TestSetitem(unittest.TestCase):
    """Tests for in-place ``Tensor.__setitem__``, checked against NumPy or literal expected values."""

    def test_simple_setitem(self):
        # Each case is (shape, index, value): the value is written at `index` into a
        # zero tensor and into a zero ndarray, and the two results are compared.
        cases = (
            ((6,6), (slice(2,4), slice(3,5)), Tensor.ones(2,2)),
            ((6,6), (slice(2,4), slice(3,5)), Tensor([1.,2.])),
            ((6,6), (slice(2,4), slice(3,5)), 1.0),
            ((6,6), (3, 4), 1.0),
            ((6,6), (3, None, 4, None), 1.0),
            ((4,4,4,4), (Ellipsis, slice(1,3), slice(None)), Tensor(4)),
            ((4,4,4,4), (Ellipsis, slice(1,3)), 4),
            ((4,4,4,4), (2, slice(1,3), None, 1), 4),
            ((4,4,4,4), (slice(1,3), slice(None), slice(0,4,2)), 4),
            ((4,4,4,4), (slice(1,3), slice(None), slice(None), slice(0,3)), 4),
            ((6,6), (slice(1,5,2), slice(0,5,3)), 1.0),
            ((6,6), (slice(5,1,-2), slice(5,0,-3)), 1.0),
        )
        for i, (shp, slc, val) in enumerate(cases):
            # subTest names the failing case and lets the remaining cases still run
            with self.subTest(case=i, shape=shp, index=slc):
                t = Tensor.zeros(shp).contiguous()
                t[slc] = val
                n = np.zeros(shp)
                n[slc] = val.numpy() if isinstance(val, Tensor) else val
                np.testing.assert_allclose(t.numpy(), n)

    def test_setitem_into_unrealized(self):
        # target is built from arange/reshape with no explicit contiguous()/realize() call
        t = Tensor.arange(4).reshape(2, 2)
        t[1] = 5
        np.testing.assert_allclose(t.numpy(), [[0, 1], [5, 5]])

    def test_setitem_dtype(self):
        # assigning a float, int or bool scalar must leave the target dtype unchanged
        for dt in (dtypes.int, dtypes.float, dtypes.bool):
            for v in (5., 5, True):
                with self.subTest(dtype=dt, value=v):
                    t = Tensor.ones(6,6, dtype=dt).contiguous()
                    t[1] = v
                    self.assertEqual(t.dtype, dt)

    def test_setitem_into_noncontiguous(self):
        # Tensor.ones(4) is asserted to be non-contiguous; assigning into it must raise
        t = Tensor.ones(4)
        self.assertFalse(t.lazydata.st.contiguous)
        with self.assertRaises(RuntimeError):
            t[1] = 5

    @unittest.skip("TODO: flaky")
    def test_setitem_inplace_operator(self):
        # each augmented assignment is applied to row 1 of a fresh [[0, 1], [2, 3]] tensor
        t = Tensor.arange(4).reshape(2, 2).contiguous()
        t[1] += 2
        np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 5]])

        t = Tensor.arange(4).reshape(2, 2).contiguous()
        t[1] -= 1
        np.testing.assert_allclose(t.numpy(), [[0, 1], [1, 2]])

        t = Tensor.arange(4).reshape(2, 2).contiguous()
        t[1] *= 2
        np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 6]])

        # NOTE: have to manually cast setitem target to least_upper_float for div
        t = Tensor.arange(4, dtype=dtypes.float).reshape(2, 2).contiguous()
        t[1] /= 2
        np.testing.assert_allclose(t.numpy(), [[0, 1], [1, 1.5]])

        t = Tensor.arange(4).reshape(2, 2).contiguous()
        t[1] **= 2
        np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 9]])

        t = Tensor.arange(4).reshape(2, 2).contiguous()
        t[1] ^= 5
        np.testing.assert_allclose(t.numpy(), [[0, 1], [7, 6]])

    # NOTE: previously marked expectedFailure; passing after delete_forced_realize
    def test_setitem_consecutive_inplace_operator(self):
        t = Tensor.arange(4).reshape(2, 2).contiguous()
        t[1] += 2
        t = t.contiguous()
        # this second in-place update is the step that used to fail with
        # "RuntimeError: can't double realize in one schedule"
        t[1] -= 1
        np.testing.assert_allclose(t.numpy(), [[0, 1], [3, 4]])

    def test_setitem_overlapping_indices(self):
        # with a repeated index, the last value written is the one expected to remain
        t = Tensor([1,2,3,4])
        # regular overlapping indices
        t[[1,1]] = Tensor([5,6])
        np.testing.assert_allclose(t.numpy(), [1,6,3,4])
        # overlapping indices with zero value overlapped
        t[[1,1]] = Tensor([0,1])
        np.testing.assert_allclose(t.numpy(), [1,1,3,4])

    def test_setitem_overlapping_indices_with_0(self):
        # same repeated index, but the last (winning) value is 0
        t = Tensor([1,2,3,4])
        t[[1,1]] = Tensor([1,0])
        np.testing.assert_allclose(t.numpy(), [1,0,3,4])

    def test_setitem_with_1_in_shape(self):
        # repeated row index on a (3, 1) tensor
        t = Tensor([[1],[2],[3]])
        t[[0,0]] = Tensor([[1],[2]])
        np.testing.assert_allclose(t.numpy(), [[2],[2],[3]])

    def test_fancy_setitem(self):
        # list indices on both axes, compared against the same assignment in NumPy
        t = Tensor.zeros(6,6).contiguous()
        t[[1,2], [3,2]] = 3
        n = np.zeros((6,6))
        n[[1,2], [3,2]] = 3
        np.testing.assert_allclose(t.numpy(), n)

    def test_simple_jit_setitem(self):
        @TinyJit
        def f(t:Tensor, a:Tensor):
            t[2:4, 3:5] = a

        # call the jitted function repeatedly with a fresh target and a different fill value
        for i in range(1, 6):
            t = Tensor.zeros(6, 6).contiguous().realize()
            a = Tensor.full((2, 2), fill_value=i, dtype=dtypes.float).contiguous()
            f(t, a)

            n = np.zeros((6, 6))
            n[2:4, 3:5] = np.full((2, 2), i)
            np.testing.assert_allclose(t.numpy(), n)

    def test_jit_setitem_variable_offset(self):
        @TinyJit
        def f(t:Tensor, a:Tensor, v:Variable):
            # assign `a` into the one-row window starting at the bound variable v
            t.shrink(((v,v+1), None)).assign(a).realize()

        # the same target tensor is reused across calls, so rows fill up one by one
        t = Tensor.zeros(6, 6).contiguous().realize()
        n = np.zeros((6, 6))

        for i in range(6):
            # NOTE(review): v is declared as Variable("v", 0, 6) but only 0..5 are bound
            # here -- confirm the upper bound is intended
            v = Variable("v", 0, 6).bind(i)
            a = Tensor.full((1, 6), fill_value=i+1, dtype=dtypes.float).contiguous()
            n[i, :] = i+1
            f(t, a, v)
            np.testing.assert_allclose(t.numpy(), n)

        np.testing.assert_allclose(t.numpy(), [
            [1,1,1,1,1,1],
            [2,2,2,2,2,2],
            [3,3,3,3,3,3],
            [4,4,4,4,4,4],
            [5,5,5,5,5,5],
            [6,6,6,6,6,6],
        ])

    def test_setitem_overlapping_inplace1(self):
        # source and destination are overlapping slices of the same tensor (shift down)
        t = Tensor([[3.0], [2.0], [1.0]]).contiguous()
        t[1:] = t[:-1]
        self.assertEqual(t.tolist(), [[3.0], [3.0], [2.0]])

    def test_setitem_overlapping_inplace2(self):
        # source and destination are overlapping slices of the same tensor (shift up)
        t = Tensor([[3.0], [2.0], [1.0]]).contiguous()
        t[:-1] = t[1:]
        self.assertEqual(t.tolist(), [[2.0], [1.0], [1.0]])
class TestWithGrad(unittest.TestCase):
    """Setitem behaviour when the target or the assigned value has ``requires_grad`` set."""

    def test_no_requires_grad_works(self):
        # neither side requires grad: the assignment only has to complete without raising
        dst = Tensor.rand(8, 8)
        src = Tensor.rand(8)
        dst[:3] = src

    def test_set_into_requires_grad(self):
        # the target requires grad: the assignment must raise NotImplementedError
        dst = Tensor.rand(8, 8, requires_grad=True)
        src = Tensor.rand(8)
        with self.assertRaises(NotImplementedError):
            dst[:3] = src

    def test_set_with_requires_grad(self):
        # the assigned value requires grad: the assignment must raise NotImplementedError
        dst = Tensor.rand(8, 8)
        src = Tensor.rand(8, requires_grad=True)
        with self.assertRaises(NotImplementedError):
            dst[:3] = src
# run the test classes in this file when it is executed directly as a script
if __name__ == "__main__":
    unittest.main()