openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

72 lines
1.9 KiB

import unittest
import torch
# Imported only for its side effects. NOTE(review): presumably this registers tinygrad's
# "tiny" device with torch, so it must stay before the set_default_device call below — confirm.
import tinygrad.frontend.torch
# Tensors created by torch factory functions in this module (torch.ones, torch.zeros,
# torch.arange, torch.full, ...) are placed on the "tiny" device, i.e. the backend under test.
torch.set_default_device("tiny")
import numpy as np
class TestTorchBackendInplace(unittest.TestCase):
  """In-place ops on a tensor, or on a view/slice of it, must update the original tensor.

  Tensors are created on torch's default device, which this module sets to "tiny".
  Expected values are written as literals or built with numpy, never with torch, so a
  broken backend cannot also produce a matching broken expectation.
  """

  def test_zero(self):
    """zero_() on the whole tensor."""
    a = torch.ones(4)
    a.zero_()
    np.testing.assert_equal(a.cpu().numpy(), [0, 0, 0, 0])

  def test_view_zero(self):
    """zero_() through a reshaped view clears the base tensor."""
    a = torch.ones(4)
    a.view((2, 2)).zero_()
    np.testing.assert_equal(a.cpu().numpy(), [0, 0, 0, 0])

  def test_slice_zero(self):
    """zero_() on a tail slice clears only that slice."""
    a = torch.ones(4)
    a[2:].zero_()
    np.testing.assert_equal(a.cpu().numpy(), [1, 1, 0, 0])

  def test_slice_permute_zero(self):
    """zero_() on a row slice of a transposed view clears column 1 of the base."""
    a = torch.ones((3, 2))
    a.permute(1, 0)[1:].zero_()
    np.testing.assert_equal(a.cpu().numpy(), [[1, 0], [1, 0], [1, 0]])

  def test_slice_fill(self):
    """fill_() on a tail slice."""
    a = torch.zeros(4)
    a[2:].fill_(2)
    np.testing.assert_equal(a.cpu().numpy(), [0, 0, 2, 2])

  def test_slice_mul(self):
    """Augmented assignment (*=) on two disjoint slices."""
    a = torch.ones(4)
    a[:2] *= 3
    a[2:] *= 2
    np.testing.assert_equal(a.cpu().numpy(), [3, 3, 2, 2])

  def test_stacked_mul(self):
    """*= through a view of a view: both scalings land in the base tensor."""
    a = torch.ones((3, 3))
    b = a[1:, 1:].permute(1, 0)  # b[i, j] aliases a[1 + j, 1 + i]
    c = b[1:, :]                 # c[0, j] aliases b[1, j], i.e. a[1 + j, 2]
    b *= 2  # a[1:, 1:] -> 2
    c *= 3  # a[1:, 2] -> 2 * 3 = 6
    np.testing.assert_equal(a.cpu().numpy(), [[1, 1, 1], [1, 2, 6], [1, 2, 6]])

  def test_flatten_reshape_add(self):
    """+= through the base, a flattened view and a reshaped view all hit the same storage."""
    a = torch.zeros((2, 2, 12, 32))
    b = a.flatten()
    c = b.reshape((48, 32))
    a += 1
    b += 1
    c += 1
    # Expected values come from numpy, not torch.full: with the default device set to
    # "tiny", torch.full would itself run on the backend under test.
    np.testing.assert_equal(c.cpu().numpy(), np.full((48, 32), 3))
    # The base must see the writes made through its views as well.
    np.testing.assert_equal(a.cpu().numpy(), np.full((2, 2, 12, 32), 3))

  def test_noncontig(self):
    """In-place += on a row slice of a tensor allocated with column-major strides (1, 4)."""
    a = torch.empty_strided((4, 4), (1, 4), dtype=torch.int64)
    # TODO: we are contiguous when it's not required; re-enable once fixed:
    # self.assertFalse(a.is_contiguous())
    a.zero_()
    b = a.view((4, 4))
    b[1:3, :] += 1
    np.testing.assert_equal(a.cpu().numpy(), [[0] * 4, [1] * 4, [1] * 4, [0] * 4])

  def test_detach(self):
    """A detached tensor shares storage, so in-place += on it updates the original."""
    a = torch.zeros(4)
    d = a.detach()
    d += torch.arange(4)
    # Literal expectation (not torch.arange on the backend under test), compared as numpy
    # like every other test in this class.
    np.testing.assert_equal(a.cpu().numpy(), [0, 1, 2, 3])
if __name__ == "__main__":
  # Run this module's TestCase classes when the file is executed directly.
  unittest.main()