You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
55 lines
2.3 KiB
55 lines
2.3 KiB
import unittest
|
|
from tinygrad import Tensor, Device
|
|
from tinygrad.codegen.kernel import Kernel
|
|
from tinygrad.device import Buffer
|
|
from tinygrad.engine.search import get_test_global_size, bufs_from_lin
|
|
from tinygrad.helpers import GlobalCounters
|
|
from extra.optimization.helpers import time_linearizer
|
|
|
|
class TestSearchUtil(unittest.TestCase):
|
|
def test_get_test_global_size(self):
|
|
self.assertEqual(get_test_global_size([256, 256, 256], 65536, {}), ([256, 16, 16], 256.0))
|
|
self.assertEqual(get_test_global_size([65536, 1, 1], 256, {}), ([256, 1, 1], 256.0))
|
|
self.assertEqual(get_test_global_size([77, 1, 1], 16, {}), ([9, 1, 1], 77/9))
|
|
|
|
def test_bufs_from_lin(self):
|
|
a = Tensor([1,2,3,4]).realize()
|
|
si = (a+1).schedule()[0]
|
|
rawbufs = bufs_from_lin(lin:=Kernel(si.ast))
|
|
assert len(rawbufs) == len(lin.membufs) == 2
|
|
assert all(r is not None for r in rawbufs)
|
|
assert all(isinstance(r, Buffer) for r in rawbufs)
|
|
assert all(r.size > 0 for r in rawbufs)
|
|
|
|
def test_bufs_from_lin_alt(self):
|
|
a = Tensor.randn(4, 4).realize()
|
|
b = a+a[0]
|
|
si = b.schedule()[0]
|
|
rawbufs = bufs_from_lin(k:=Kernel(si.ast))
|
|
assert len(rawbufs) == len(k.membufs) == 2
|
|
assert all(r is not None for r in rawbufs)
|
|
assert all(isinstance(r, Buffer) for r in rawbufs)
|
|
assert all(r.size > 0 for r in rawbufs)
|
|
|
|
class TestTimeLinearizer(unittest.TestCase):
  """Tests for `time_linearizer` from extra.optimization.helpers."""

  @unittest.skipIf(Device.DEFAULT == "WEBGPU", "WebGPU timestamps are low precision, tm is 0")
  def test_reasonable_time(self):
    """The time reported for the `a+1` kernel must be positive and not infinite."""
    a = Tensor([1,2,3,4]).realize()
    si = (a+1).schedule()[0]
    # create fresh empty buffers
    rawbufs = [Buffer(b.device, b.size, b.dtype).allocate() for b in si.bufs]
    tm = time_linearizer(Kernel(si.ast), rawbufs, allow_test_size=False, cnt=10, disable_cache=True)
    # unittest assertions instead of a bare `assert`: they are not stripped by
    # `python -O` and they print the measured value when the check fails.
    self.assertGreater(tm, 0)
    self.assertNotEqual(tm, float('inf'))

  # Ensure that the kernel count is not incremented by time_linearizer when clearing l2
  def test_kernel_count(self):
    """`GlobalCounters.kernel_count` must be unchanged after timing with `clear_l2=True`."""
    ast = Tensor.zeros(16).contiguous().kernelize().uop.src[1].arg.ast
    lin = Kernel(ast)
    bufs = bufs_from_lin(lin)

    # snapshot the counter before timing so any increment can be detected afterwards
    kernel_count = GlobalCounters.kernel_count
    time_linearizer(lin, bufs, allow_test_size=False, cnt=2, disable_cache=True, clear_l2=True)
    self.assertEqual(GlobalCounters.kernel_count, kernel_count, "kernel count was incremented by time_linearizer")
|
|
|
|
if __name__ == "__main__":
  # Run the test cases of this module when the file is executed directly.
  unittest.main()