openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

55 lines
2.3 KiB

import unittest
from tinygrad import Tensor, Device
from tinygrad.codegen.kernel import Kernel
from tinygrad.device import Buffer
from tinygrad.engine.search import get_test_global_size, bufs_from_lin
from tinygrad.helpers import GlobalCounters
from extra.optimization.helpers import time_linearizer
class TestSearchUtil(unittest.TestCase):
    """Tests for the get_test_global_size and bufs_from_lin helpers."""

    # Checks shared by the bufs_from_lin tests: the buffer list has the same
    # length as the kernel's membufs (and that length is 2), and every entry
    # is a non-None Buffer with a positive size.
    def _check_rawbufs(self, rawbufs, kernel):
        assert len(rawbufs) == len(kernel.membufs) == 2
        assert all(buf is not None for buf in rawbufs)
        assert all(isinstance(buf, Buffer) for buf in rawbufs)
        assert all(buf.size > 0 for buf in rawbufs)

    # Each row pairs the first two positional arguments with the expected
    # return value; the third argument is always an empty dict.
    def test_get_test_global_size(self):
        cases = (
            (([256, 256, 256], 65536), ([256, 16, 16], 256.0)),
            (([65536, 1, 1], 256), ([256, 1, 1], 256.0)),
            (([77, 1, 1], 16), ([9, 1, 1], 77/9)),
        )
        for args, expected in cases:
            self.assertEqual(get_test_global_size(*args, {}), expected)

    # Buffers built for the first scheduled kernel of `a + 1`.
    def test_bufs_from_lin(self):
        tensor = Tensor([1, 2, 3, 4]).realize()
        schedule_item = (tensor + 1).schedule()[0]
        kernel = Kernel(schedule_item.ast)
        self._check_rawbufs(bufs_from_lin(kernel), kernel)

    # Same checks for the first scheduled kernel of `a + a[0]` on a random 4x4 tensor.
    def test_bufs_from_lin_alt(self):
        tensor = Tensor.randn(4, 4).realize()
        summed = tensor + tensor[0]
        schedule_item = summed.schedule()[0]
        kernel = Kernel(schedule_item.ast)
        self._check_rawbufs(bufs_from_lin(kernel), kernel)
class TestTimeLinearizer(unittest.TestCase):
    """Tests for time_linearizer."""

    # The measured time must be strictly positive and not infinite.
    @unittest.skipIf(Device.DEFAULT == "WEBGPU", "WebGPU timestamps are low precision, tm is 0")
    def test_reasonable_time(self):
        tensor = Tensor([1, 2, 3, 4]).realize()
        schedule_item = (tensor + 1).schedule()[0]
        # create fresh empty buffers
        fresh_bufs = [Buffer(src.device, src.size, src.dtype).allocate() for src in schedule_item.bufs]
        kernel = Kernel(schedule_item.ast)
        elapsed = time_linearizer(kernel, fresh_bufs, allow_test_size=False, cnt=10, disable_cache=True)
        is_positive = elapsed > 0
        is_not_inf = elapsed != float('inf')
        assert is_positive and is_not_inf

    # Ensure that the kernel count is not incremented by time_linearizer when clearing l2
    def test_kernel_count(self):
        kernelized = Tensor.zeros(16).contiguous().kernelize()
        kernel = Kernel(kernelized.uop.src[1].arg.ast)
        bufs = bufs_from_lin(kernel)
        count_before = GlobalCounters.kernel_count
        time_linearizer(kernel, bufs, allow_test_size=False, cnt=2, disable_cache=True, clear_l2=True)
        count_after = GlobalCounters.kernel_count
        assert count_after == count_before, "kernel count was incremented by time_linearizer"
# Run this module's test cases when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()