dragonpilot - 基於 openpilot 的開源駕駛輔助系統
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

837 lines
28 KiB

#!/usr/bin/env python
import unittest, functools
import numpy as np
from hypothesis import given, settings, strategies as strat
from test.helpers import assert_jit_cache_len, not_support_multi_device, REAL_DEV
from tinygrad.tensor import Tensor
from tinygrad.engine.jit import TinyJit, GraphRunner, MultiGraphRunner, graph_class
from tinygrad.engine.realize import CompiledRunner, BufferCopy, BufferXfer
from tinygrad.device import Device
from tinygrad.helpers import Context, JIT, GlobalCounters, getenv
from tinygrad.dtype import dtypes
from extra.models.unet import ResBlock
def _simple_test(add, extract=lambda x: x, N=10):
  """Call `add` five times on fresh random NxN tensors and check every result equals a+b.

  `extract` pulls the output Tensor out of whatever `add` returns (identity by default).
  After the loop the jit cache of `add` is asserted to hold exactly one entry.
  """
  for _ in range(5):
    lhs, rhs = Tensor.randn(N, N), Tensor.randn(N, N)
    out = extract(add(lhs, rhs))
    np.testing.assert_allclose(out.numpy(), lhs.numpy()+rhs.numpy(), atol=1e-4, rtol=1e-5)
  assert_jit_cache_len(add, 1)
class TestJit(unittest.TestCase):
  # Tests for TinyJit: simple capture/replay, argument validation, kwargs handling,
  # random number generation under the jit, and copies between devices.

  # Runs a stack of ResBlocks followed by a transcendental op (chosen by hypothesis)
  # through TinyJit once, under TRANSCENDENTAL=2 and a hypothesis deadline of 2e4.
  @settings(deadline=2e4)
  @unittest.skipUnless(REAL_DEV in ["LLVM", "CPU"], f"no support on {REAL_DEV}")
  @given(strat.sampled_from([Tensor.exp2, Tensor.log2, Tensor.sin]))
  def test_approx_jit_timeout(self, op):
    with Context(TRANSCENDENTAL=2):
      blocks = [ResBlock(16, 24, 16) for _ in range(4)]
      @TinyJit
      def fw_approx(t, t2):
        for blk in blocks:
          t = blk(t, t2)
        return op(t).realize()
      fw_approx(Tensor.empty(4, 16, 8, 8), Tensor.empty(1, 24))

  # An explicitly realized sum is captured as one jit cache entry.
  def test_simple_jit(self):
    @TinyJit
    def add(a, b):
      return (a+b).realize()
    _simple_test(add)

  # After reset() the same jit can be re-captured with a different input shape.
  def test_simple_jit_reset(self):
    @TinyJit
    def add(a, b):
      return (a+b).realize()
    _simple_test(add)
    add.reset()
    _simple_test(add, N=20)

  # Returning an unrealized Tensor works too.
  def test_simple_jit_norealize(self):
    @TinyJit
    def add(a, b):
      return (a+b)
    _simple_test(add)

  # Unrealized Tensor wrapped in a list.
  def test_simple_jit_norealize_list(self):
    @TinyJit
    def add(a, b):
      return [a+b]
    _simple_test(add, extract=lambda x: x[0])

  # Unrealized Tensor wrapped in a dict.
  def test_simple_jit_norealize_dict(self):
    @TinyJit
    def add(a, b):
      return {"billy": a+b}
    _simple_test(add, extract=lambda x: x["billy"])

  # Three separately realized outputs give three jit cache entries.
  def test_jit_multiple_outputs(self):
    @TinyJit
    def f(a, b):
      return (a+b).realize(), (a-b).realize(), (a*b).realize()
    for _ in range(5):
      a, b = Tensor.randn(10, 10), Tensor.randn(10, 10)
      total, diff, prod = f(a, b)
      np.testing.assert_allclose(total.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
      np.testing.assert_allclose(diff.numpy(), a.numpy()-b.numpy(), atol=1e-4, rtol=1e-5)
      np.testing.assert_allclose(prod.numpy(), a.numpy()*b.numpy(), atol=1e-4, rtol=1e-5)
    assert_jit_cache_len(f, 3)

  # A jitted function that runs no kernels is expected to raise AssertionError.
  def test_nothing_jitted(self):
    @TinyJit
    def add(a, b):
      return None
    with self.assertRaises(AssertionError):
      for _ in range(5):
        a, b = Tensor.randn(10, 10), Tensor.randn(10, 10)
        add(a, b)

  # With JIT=0 the function still computes correctly and the jit cache stays empty.
  def test_jit_zero_does_not_jit(self):
    @TinyJit
    def add(a, b):
      return (a+b).realize()
    with Context(JIT=0):
      for i in range(5):
        a, b = Tensor([i]), Tensor([i])
        out = add(a, b)
        np.testing.assert_allclose(out.numpy(), 2*i)
      assert_jit_cache_len(add, 0)

  # A kernel realized under CAPTURING=0 is left out of the jit cache.
  def test_jit_not_capturing(self):
    @TinyJit
    def add(a, b):
      Tensor.zeros(4, 4).contiguous().realize()  # no-op kernel is captured
      return (a+b).realize()
    for i in range(5):
      a, b = Tensor([i]), Tensor([i])
      out = add(a, b)
      np.testing.assert_allclose(out.numpy(), 2*i)
    assert_jit_cache_len(add, 2)

    @TinyJit
    def add2(a, b):
      with Context(CAPTURING=0):  # not captured
        Tensor.zeros(4, 4).contiguous().realize()
      return (a+b).realize()
    for i in range(5):
      a, b = Tensor([i]), Tensor([i])
      out = add2(a, b)
      np.testing.assert_allclose(out.numpy(), 2*i)
    assert_jit_cache_len(add2, 1)

  # Calling a captured jit with a differently shaped input raises AssertionError.
  def test_jit_shape_mismatch(self):
    @TinyJit
    def add(a, b):
      return (a+b).realize()
    for _ in range(5):
      a, b = Tensor.randn(10, 10), Tensor.randn(10, 10)
      add(a, b)
    wrong_shape = Tensor.randn(20, 20)
    with self.assertRaises(AssertionError):
      add(a, wrong_shape)

  # Inputs that are views with a changing offset are rejected with AssertionError.
  def test_jit_shape_views_mismatch(self):
    @TinyJit
    def add(a):
      return (a+1).realize()
    with self.assertRaises(AssertionError):
      for i in range(1,5):
        # a has an offset that the kernel doesn't know about
        view = Tensor.randn(10, 10).realize()[:, i:i+2]
        add(view)

  def test_jit_duplicate_fail(self):
    # the jit doesn't support duplicate arguments
    @TinyJit
    def add(a, b):
      return (a+b).realize()
    t = Tensor.randn(10, 10)
    with self.assertRaises(AssertionError):
      add(t, t)

  # In-place += inside the jit updates the caller's tensor on every call.
  def test_jit_assign(self, dtype=dtypes.float32):
    @TinyJit
    def add(a):
      a += 1
      a.realize()
    acc = Tensor.zeros(1, dtype=dtype).contiguous().realize()
    for _ in range(5):
      add(acc)
    self.assertEqual(acc.item(), 5)

  def test_jit_assign_int8(self):
    self.test_jit_assign(dtypes.int8)

  # Inputs may be passed as keyword arguments.
  def test_kwargs_jit(self):
    @TinyJit
    def add_kwargs(first, second):
      return (first+second).realize()
    for _ in range(5):
      a, b = Tensor.randn(10, 10), Tensor.randn(10, 10)
      out = add_kwargs(first=a, second=b)
      np.testing.assert_allclose(out.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
    assert_jit_cache_len(add_kwargs, 1)

  # Keyword order at the call site must not matter (division makes a swap visible).
  def test_reorder_kwargs_jit(self):
    @TinyJit
    def add_kwargs(first, second):
      return (first/second).realize()
    for _ in range(2):
      a, b = Tensor.randn(10, 10), Tensor.randn(10, 10)
      out = add_kwargs(second=b, first=a)
      np.testing.assert_allclose(out.numpy(), a.numpy()/b.numpy(), atol=1e-4, rtol=1e-5)
    for _ in range(2):
      a, b = Tensor.randn(10, 10), Tensor.randn(10, 10)
      out = add_kwargs(first=a, second=b)
      np.testing.assert_allclose(out.numpy(), a.numpy()/b.numpy(), atol=1e-4, rtol=1e-5)
    assert_jit_cache_len(add_kwargs, 1)

  # Tensors hidden inside a list argument are not tracked as jit inputs.
  def test_array_jit(self):
    @TinyJit
    def add_array(a, arr):
      return (a+arr[0]).realize()
    for i in range(5):
      a, b = Tensor.randn(10, 10), Tensor.randn(10, 10)
      a.realize(), b.realize()
      out = add_array(a, [b])
      if i < 2:
        np.testing.assert_allclose(out.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
      else:
        # should fail once jitted since jit can't handle arrays
        np.testing.assert_allclose(np.any(np.not_equal(out.numpy(),a.numpy()+b.numpy())), True, atol=1e-4, rtol=1e-5)
    assert_jit_cache_len(add_array, 1)

  # A Tensor built from a Python list inside the jitted function still gives correct results.
  def test_jit_copyin(self):
    @TinyJit
    def f(a):
      return a + Tensor([1,2,3])
    for _ in range(5):
      inp = Tensor.randn(3)
      out = f(inp)
      np.testing.assert_allclose(out.numpy(), inp.numpy()+[1,2,3], atol=1e-4, rtol=1e-5)

  # TinyJit can decorate a method (here __call__) that reads instance state.
  def test_method_jit(self):
    class Fun:
      def __init__(self):
        self.a = Tensor.randn(10, 10)
      @TinyJit
      def __call__(self, b:Tensor) -> Tensor:
        return (self.a+b).realize()
    fun = Fun()
    for _ in range(5):
      inp = Tensor.randn(10, 10)
      out = fun(inp)
      np.testing.assert_allclose(out.numpy(), fun.a.numpy()+inp.numpy(), atol=1e-4, rtol=1e-5)
    assert_jit_cache_len(fun.__call__.func.__self__, 1)

  # A one-element tensor input whose value changes on every call.
  def test_jit_size1_input(self):
    @TinyJit
    def f(a, b):
      return (a+b).realize()
    base = Tensor([1, 2, 3])
    for i in range(5):
      np.testing.assert_allclose(f(base, Tensor([i])).numpy(), (base+i).numpy(), atol=1e-4, rtol=1e-5)
    assert_jit_cache_len(f, 1)

  # Non-Tensor return values are not updated across jitted calls.
  def test_jit_output_non_tensor_fail(self):
    @TinyJit
    def f(a, b, i):
      return (a+b).realize(), i
    got_tensors, got_ints = [], []
    want_tensors, want_ints = [], []
    for i in range(5):
      a, b = Tensor.randn(10, 10), Tensor.randn(10, 10)
      o1, o2 = f(a, b, i)
      got_tensors.append(o1.numpy().copy())
      got_ints.append(o2)
      want_tensors.append(a.numpy().copy()+b.numpy().copy())
      want_ints.append(i)
    np.testing.assert_allclose(got_tensors, want_tensors, atol=1e-4, rtol=1e-5)
    # the jit only works with Tensor outputs
    assert got_ints != want_ints
    assert_jit_cache_len(f, 1)

  # Random numbers drawn inside the jit differ per call and are reproducible per seed.
  def test_jit_random_regen(self):
    def f(a, b):
      rn = Tensor.randn(*a.shape)
      return ((a+b)*rn).realize()
    a = Tensor.randn(10, 10).realize()  # realize these before resetting the random seed
    b = Tensor.randn(10, 10).realize()

    def sample(seed):
      # seed, wrap f in a fresh TinyJit, and collect element [0][0] of five consecutive outputs
      Tensor.manual_seed(seed)
      jitted = TinyJit(f)
      return {jitted(a, b).numpy()[0][0] for _ in range(5)}

    res = sample(1234)
    assert len(res) == 5, "All values should be different, rand works in jit."
    res2 = sample(1234)
    assert len(res2) == 5, "All values should be different, rand works in jit."
    assert res == res2, "Jit rand is not reproducible with the same seed"
    res3 = sample(3421)
    assert len(res3) == 5, "All values should be different, rand works in jit."
    assert res3 != res2, "Jit rand is diff with diff seeds"

  # Jitted and non-jitted runs with the same seed should produce the same values (currently expected to fail).
  @unittest.expectedFailure # TODO: fix
  def test_jit_v_nojit_random_regen(self):
    def f(a, b):
      rn = Tensor.randn(*a.shape)
      rn = rn * a
      rn2 = Tensor.randn(*a.shape)
      rn2 = rn2 * b
      rn = rn + rn2
      rn2 = rn2 + Tensor.randn(*a.shape)
      return ((a+b)*rn).realize(), ((a+b)*rn2).realize()
    Tensor.manual_seed(0)
    a = Tensor.randn(10, 10).realize()  # realize these before resetting the random seed
    b = Tensor.randn(10, 10).realize()

    def sample(fxn):
      # collect element [0][0] of both outputs over five calls
      seen = set()
      for _ in range(5):
        o1, o2 = fxn(a, b)
        seen.add(o1.numpy()[0][0])
        seen.add(o2.numpy()[0][0])
      return seen

    Tensor.manual_seed(1234)
    without_jit = sample(f)
    assert len(without_jit) == 10, "All values should be different."
    Tensor.manual_seed(1234)
    with_jit = sample(TinyJit(f))
    assert len(with_jit) == 10, "All values should be different."
    assert with_jit == without_jit, "Jit rand produced different values from no jit."

  # Same as test_jit_random_regen but with several randn calls feeding two outputs.
  def test_jit_multiple_random_regen(self):
    def f(a, b):
      rn = Tensor.randn(*a.shape)
      rn = rn * a
      rn2 = Tensor.randn(*a.shape)
      rn2 = rn2 * b
      rn = rn + rn2
      rn2 = rn2 + Tensor.randn(*a.shape)
      return ((a+b)*rn).realize(), ((a+b)*rn2).realize()
    a = Tensor.randn(10, 10).realize()  # realize these before resetting the random seed
    b = Tensor.randn(10, 10).realize()

    def sample(seed):
      # seed, wrap f in a fresh TinyJit, and collect element [0][0] of both outputs over five calls
      Tensor.manual_seed(seed)
      jitted = TinyJit(f)
      seen = set()
      for _ in range(5):
        o1, o2 = jitted(a, b)
        seen.add(o1.numpy()[0][0])
        seen.add(o2.numpy()[0][0])
      return seen

    res = sample(1234)
    assert len(res) == 10, "All values should be different, rand works in jit."
    res2 = sample(1234)
    assert len(res2) == 10, "All values should be different, rand works in jit."
    assert res == res2, "Jit rand is not reproducible with the same seed"
    res3 = sample(3421)
    assert len(res3) == 10, "All values should be different, rand works in jit."
    assert res3 != res2, "Jit rand is diff with diff seeds"

  # An unrealized Tensor.rand() created before the jit runs must not make replays repeat values.
  # (previously marked: @unittest.expectedFailure, "requires contiguous folding")
  def test_jit_random_after_unrealized_random(self):
    @TinyJit
    def f():
      return Tensor.rand()
    Tensor.manual_seed(1234)
    Tensor.rand()
    res = [f().numpy() for _ in range(3)]
    assert res[1] != res[2]

  # Each output is copied out with .numpy() right after its call, so later calls cannot change it.
  def test_jit_realization_and_sampling(self):
    w = Tensor.eye(5)
    @TinyJit
    def foo (x):
      return w.dot(x).realize()
    rows = [
      [1,2,3,4,5],
      [1,3,3,4,6],
      [1,2,5,4,7],
      [0,2,3,1,0],
    ]
    inputs = [Tensor(r) for r in rows]
    outputs = [foo(t).numpy() for t in inputs]
    foo(Tensor([7,7,7,7,7]))
    # multiplying by the identity must return the inputs unchanged, as floats
    want = [[float(v) for v in r] for r in rows]
    np.testing.assert_allclose(want, outputs)

  # Outputs returned by jitted calls alias one buffer: result_2 reads back result_3's value.
  def test_jit_buffer_behavior(self):
    @TinyJit
    def foo(x) -> Tensor:
      return x.sum().realize()
    result_1 = foo(Tensor([1] * 2))
    result_2 = foo(Tensor([2] * 2))
    result_3 = foo(Tensor([3] * 2))
    # expect the buffer to share underlying buffer
    for result, want in ((result_1, [2]), (result_2, [6]), (result_3, [6])):
      np.testing.assert_allclose(result.numpy(), want, atol=1e-4, rtol=1e-5)

  # A long jit is expected to be split into at least two graph runners.
  def test_jit_batch_split(self):
    if Device[Device.DEFAULT].graph is None or JIT >= 2:
      raise unittest.SkipTest("only test graphs")

    # Create long jit with 83 kernels.
    def f(a, b, c, d, e):
      for _ in range(80):
        a = (a+b).realize()
      y = (a*c).realize()
      z = (y*d).realize()
      w = (z*e)
      return w.realize()

    a, b, c, d, e = (Tensor.randn(10, 10).realize() for _ in range(5))
    jf = TinyJit(f)
    prev = None
    for _ in range(5):
      cur = jf(a, b, c, d, e).numpy()
      if prev is not None:
        np.testing.assert_allclose(cur, prev, atol=1e-4, rtol=1e-5)
      prev = cur

    # the device's graph attribute may be a functools.partial wrapping the graph class
    graph = Device[Device.DEFAULT].graph
    graph_t = graph.func if isinstance(graph, functools.partial) else graph
    # Checking that 2 graphs are inited.
    assert isinstance(jf.jit_cache[0].prg, graph_t)
    assert isinstance(jf.jit_cache[1].prg, graph_t)

  # Tensor.ones / Tensor.zeros can be passed as jit inputs.
  def test_jit_const_inputs(self):
    @TinyJit
    def g(x,y,z):
      return (x+y+z).realize()
    for i in range(5):
      out = g(Tensor([i]*3), Tensor.ones(3), Tensor.zeros(3))
      np.testing.assert_equal(out.numpy(), np.array([i+1]*3))

  # clone() inside the jit returns the input's values.
  def test_jitted_clone(self):
    def f(a):
      return a.clone().realize()
    jf = TinyJit(f)
    for _ in range(5):
      src = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
      cloned = jf(src)
      np.testing.assert_allclose(src.numpy(), cloned.numpy(), atol=1e-4, rtol=1e-5)

  # Transfers from device :0 to device :1 inside the jit preserve values.
  @unittest.skipIf(not_support_multi_device(), "no multi")
  def test_jitted_transfers(self):
    d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
    def f(a, b):
      x = a.to(d1)
      y = b.to(d1)
      return x.realize(), y.realize()
    jf = TinyJit(f)
    for _ in range(5):
      a = Tensor.randn(10, 10, device=d0).realize()
      b = Tensor.randn(10, 10, device=d0).realize()
      xc, yc = jf(a, b)
      np.testing.assert_allclose(a.numpy(), xc.numpy(), atol=1e-4, rtol=1e-5)
      np.testing.assert_allclose(b.numpy(), yc.numpy(), atol=1e-4, rtol=1e-5)

  # Inputs live on "CPU" and are moved to the default device inside the jit.
  def test_jit_several_devs(self):
    d0, d1 = f"{Device.DEFAULT}:0", "CPU"
    def f(a, b):
      x = a.to(d0).realize()
      y = b.to(d0).realize()
      # NOTE(review): `.realize()` binds to `y` only, so `x+y` and `x*y` are returned unrealized;
      # possibly `(x+y).realize()` was intended. Left as written.
      return x+y.realize(), x*y.realize()
    jf = TinyJit(f)
    for _ in range(5):
      a = Tensor.randn(10, 10, device=d1).realize()
      b = Tensor.randn(10, 10, device=d1).realize()
      zc, wc = jf(a, b)
      np.testing.assert_allclose((a.numpy()+b.numpy()), zc.numpy(), atol=1e-4, rtol=1e-5)
      np.testing.assert_allclose((a.numpy()*b.numpy()), wc.numpy(), atol=1e-4, rtol=1e-5)

  # A bitcast result is transferred to a second device inside the jit.
  @unittest.skipIf(not_support_multi_device(), "no multi")
  def test_jitted_view(self):
    d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
    def f(a):
      x1 = a.sum(axis=(1,))
      x = (x1 + 5).bitcast(dtypes.int32)
      y = x.to(d1)
      return y.realize()
    jf = TinyJit(f)
    for _ in range(5):
      a = Tensor.randn(10, 1000, device=d0).realize()
      xc = jf(a)
      want = (a.numpy().sum(axis=(1,)) + 5).view(np.int32)
      np.testing.assert_allclose(want, xc.numpy(), atol=1e-4, rtol=5e-5)

  # A clone of a jit output must keep its value after the jit runs again.
  def test_jit_output_clone(self):
    @TinyJit
    def f(x:Tensor) -> Tensor:
      return (x + 1).realize()
    f(Tensor([0.0]))
    f(Tensor([0.0]))
    kept = f(Tensor([1.0])).clone().realize()
    latest = f(Tensor([2.0]))
    assert abs((kept - latest).item()) > 0.5
@unittest.skip("Pending multioutput implementation #3607")
class TestMultioutputJit(unittest.TestCase):
  # Jit cache length expectations for functions returning three tensors,
  # depending on which of the outputs are realized explicitly.

  # Runs f five times on random 10x10 inputs and checks the (sum, difference, product) outputs.
  def _test(self, f):
    for _ in range(5):
      a, b = Tensor.randn(10, 10), Tensor.randn(10, 10)
      total, diff, prod = f(a, b)
      np.testing.assert_allclose(total.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
      np.testing.assert_allclose(diff.numpy(), a.numpy()-b.numpy(), atol=1e-4, rtol=1e-5)
      np.testing.assert_allclose(prod.numpy(), a.numpy()*b.numpy(), atol=1e-4, rtol=1e-5)

  # every output realized on its own: three cache entries expected
  def test_jit_multioutput_realize(self):
    @TinyJit
    def fxn(a, b):
      return (a+b).realize(), (a-b).realize(), (a*b).realize()
    self._test(fxn)
    assert_jit_cache_len(fxn, 3)

  # nothing realized explicitly: one cache entry expected
  def test_jit_multioutput_norealize(self):
    @TinyJit
    def fxn(a, b):
      return a+b, a-b, a*b
    self._test(fxn)
    assert_jit_cache_len(fxn, 1)

  # only the product realized explicitly: two cache entries expected
  def test_jit_multioutput_mix(self):
    @TinyJit
    def fxn(a, b):
      return a+b, a-b, (a*b).realize()
    self._test(fxn)
    assert_jit_cache_len(fxn, 2)
class TestJitInsideJit(unittest.TestCase):
  # Calling one TinyJit from inside another must be reported as a RuntimeError.
  def test_jit_jit_error(self):
    @TinyJit
    def f(t):
      return t + 1

    @TinyJit
    def g(t):
      return f(t) * 3

    # NOTE: first does not raise
    g(Tensor([1])).realize()
    # the second call is the one expected to raise
    with self.assertRaisesRegex(RuntimeError, "having TinyJit inside another TinyJit is not supported"):
      g(Tensor([1])).realize()
class TestCopyInsideJit(unittest.TestCase):
  # A device copy (`.to`) performed inside the jitted function must still give the right sum.
  def test_copy_inside_jit(self):
    @TinyJit
    def add(x,y) -> Tensor:
      return x.to(Device.DEFAULT)+y

    for _ in range(5):
      # create a Tensor on CPU
      on_cpu = Tensor.rand(16,16,device="CPU").realize()
      on_default = Tensor.rand(16,16).realize()
      out = add(on_cpu, on_default)
      got = out.flatten().tolist()
      want = [p+q for p,q in zip(on_cpu.flatten().tolist(), on_default.flatten().tolist())]
      np.testing.assert_allclose(got, want)
class TestJitPrune(unittest.TestCase):
  # Compares TinyJit(fn) with TinyJit(fn, prune=True) on functions where part of the
  # work (weights*2) does not depend on the jit input.

  def test_simple_prune(self):
    weights = Tensor.rand(16).realize()
    def w2(x) -> Tensor:
      return (weights*2).contiguous() + x
    w2_noprune = TinyJit(w2)
    w2_prune = TinyJit(w2, prune=True)

    # without pruning both kernels stay in the cache; with pruning only one remains
    for jitted, cache_len in ((w2_noprune, 2), (w2_prune, 1)):
      for _ in range(3):
        a = Tensor.rand(16).realize()
        out = jitted(a)
        np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
      assert len(jitted.captured.jit_cache) == cache_len

  # The input is copied from CPU inside the jit; results must be correct with and without pruning.
  def test_prune_w_copy_correct(self):
    weights = Tensor.rand(16).realize()
    def w2(x) -> Tensor:
      return (weights*2).contiguous() + x.to(Device.DEFAULT)
    w2_noprune = TinyJit(w2)
    w2_prune = TinyJit(w2, prune=True)

    for jitted in (w2_noprune, w2_prune):
      for _ in range(3):
        a = Tensor.rand(16, device="CPU").realize()
        out = jitted(a)
        np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])

  # The copy here depends only on the weights, not on the jit input.
  def test_prune_w_independent_copy_correct(self):
    weights = Tensor.rand(16, device="CPU").realize()
    def w2(x) -> Tensor:
      return (weights*2).contiguous().to(Device.DEFAULT) + x
    w2_noprune = TinyJit(w2)
    w2_prune = TinyJit(w2, prune=True)

    for jitted in (w2_noprune, w2_prune):
      for _ in range(3):
        a = Tensor.rand(16).realize()
        out = jitted(a)
        np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])

    assert len(w2_prune.captured.jit_cache) == 1, "prune should have removed the copy"
class TestJitFree(unittest.TestCase):
  # Tests for releasing / re-planning the intermediate buffers held by a captured jit.

  # free_intermediates() must release memory, and the jit must still work afterwards.
  def test_free_intermediates(self):
    ext_tensor = Tensor([1,24,23,45,1])
    @TinyJit
    def fxn(x:Tensor):
      out = (x*2+ext_tensor).reshape(5,1).expand(5, 100).contiguous()
      return out.sum()

    def freed_bytes(times=1):
      # call free_intermediates() `times` times and return the drop in GlobalCounters.mem_used
      before = GlobalCounters.mem_used
      for _ in range(times):
        fxn.captured.free_intermediates()
      return before - GlobalCounters.mem_used

    for i in range(5):
      out = fxn(Tensor([i,1,2,3,4]))
      self.assertEqual(out.item(), 11400+200*i)

    savings_after_free = freed_bytes()
    # Different allocator implementations have different savings.
    expected_savings = 8196 if hasattr(Device[Device.DEFAULT].allocator, '_offset') else 2024
    self.assertEqual(savings_after_free, expected_savings)
    out = fxn(Tensor([11,1,2,3,4]))
    self.assertEqual(out.item(), 13600)

    # Try one more time...
    # the 2nd free_intermediates() call validates that freeing twice changes nothing further
    savings_after_free = freed_bytes(times=2)
    self.assertEqual(savings_after_free, expected_savings)
    out = fxn(Tensor([11,1,2,3,4]))
    self.assertEqual(out.item(), 13600)

  # A tensor that the jit updates in place must survive free_intermediates().
  def test_updated_not_freed(self):
    x = Tensor([1]).realize()
    @TinyJit
    def fxn(y):
      nonlocal x
      x += y
      return x
    for _ in range(5):
      fxn(Tensor([1]))
    self.assertEqual(x.item(), 6)

    pre_free = GlobalCounters.mem_used
    fxn.captured.free_intermediates()
    savings_after_free = pre_free - GlobalCounters.mem_used
    # nothing is expected to be released here
    self.assertEqual(savings_after_free, 0)

    fxn(Tensor([2]))
    self.assertEqual(x.item(), 8)

  # replan_buffers_memory_layout() must reduce the number of distinct base buffers (4 -> 2 here).
  def test_replan_buffers_memory_layout(self):
    if not hasattr(Device[Device.DEFAULT].allocator, '_offset'):
      raise unittest.SkipTest("replan_buffers_memory_layout useless")
    ext_tensor = Tensor([1,24,23,45,1])
    ext_tensor_2 = Tensor([2,2,2,2,2])
    @TinyJit
    def fxn(x:Tensor):
      out = (x*ext_tensor_2+ext_tensor).reshape(5,1).expand(5, 100).contiguous()
      return out.sum()

    def base_buffers():
      # distinct base buffers referenced by any item of the captured jit cache
      return {b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None}

    for i in range(5):
      out = fxn(Tensor([i,1,2,3,4]))
      self.assertEqual(out.item(), 11400+200*i)

    # assertEqual (instead of a bare assert) reports the actual count on failure
    self.assertEqual(len(base_buffers()), 4)
    fxn.captured.replan_buffers_memory_layout()
    self.assertEqual(len(base_buffers()), 2)

    out = fxn(Tensor([11,1,2,3,4]))
    self.assertEqual(out.item(), 13600)
class TestJitGraphSplit(unittest.TestCase):
  # Checks how a captured jit is split into graph runners, compiled kernels and copies.
  # Each test passes three expectations to `expect`; which one is checked depends on
  # the graph class of the default device (see `expect`).

  # One compute kernel on `device`; asserts the input actually lives there.
  def compute(self, device, inp):
    assert inp.device == device, f"Input device {inp.device} does not match expected {device}"
    return (inp + 1.0).contiguous().realize()

  # One copy from `device` to `to_device`; asserts the input actually lives on `device`.
  def copy(self, device, to_device, inp):
    assert inp.device == device, f"Input device {inp.device} does not match expected {device}"
    return inp.to(to_device).realize()

  # Skip the current test when a second device of the default type cannot be opened.
  def _require_second_device(self):
    try: Device[f"{Device.DEFAULT}:1"]
    except Exception: raise unittest.SkipTest("no multidevice")

  def expect(self, f, *args, graph=None, multigraph=None, hcqgraph=None):
    """Run jitted `f` five times on `args`, then validate the layout of its jit cache.

    All runs must produce the same outputs as the first one. If the default device has
    a graph class, `f.jit_cache` is compared item by item against one expectation list:
    `hcqgraph` when the graph class is HCQGraph, `multigraph` when it is a subclass of
    MultiGraphRunner, `graph` otherwise. Each expectation is a dict produced by the
    `ji_*` helpers. Nothing is validated when the device has no graph class.
    """
    def _numpies(tpl):
      return tpl.numpy() if isinstance(tpl, Tensor) else tuple(t.numpy() for t in tpl)

    expected = _numpies(f(*args))
    for _ in range(4):
      res = _numpies(f(*args))
      np.testing.assert_allclose(res, expected, atol=1e-4, rtol=1e-5)

    dev = Device[Device.DEFAULT]
    graph_t = graph_class(dev)
    if graph_t is None: return

    jit_items = f.jit_cache
    from tinygrad.runtime.graph.hcq import HCQGraph
    if graph_t is HCQGraph:
      validate = hcqgraph
    elif issubclass(graph_t, MultiGraphRunner):
      validate = multigraph
    else:
      validate = graph
    # fail with a readable message instead of a TypeError from len(None)
    assert validate is not None, f"no expectation was provided for graph class {graph_t}"

    assert len(jit_items) == len(validate), f"Expected {len(validate)} operations, got {len(jit_items)}"
    # distinct loop names: do not shadow `expected` (reference outputs) or the cache list being iterated
    for want, item in zip(validate, jit_items):
      if want["type"] == "graph":
        assert isinstance(item.prg, GraphRunner), f"Expected GraphRunner, got {type(item.prg)}"
        assert len(item.prg.jit_cache) == want["cnt"], f"Expected {want['cnt']} operations in graph, got {len(item.prg.jit_cache)}"
      elif want["type"] == "comp":
        assert isinstance(item.prg, CompiledRunner), f"Expected CompiledRunner, got {type(item.prg)}"
      elif want["type"] == "copy":
        assert isinstance(item.prg, BufferCopy), f"Expected BufferCopy, got {type(item.prg)}"
      elif want["type"] == "xfer":
        assert isinstance(item.prg, BufferXfer), f"Expected BufferXfer, got {type(item.prg)}"
      else:
        # a misspelled type would otherwise validate nothing
        raise ValueError(f"unknown expected jit item type {want['type']!r}")

  # Expectation builders consumed by `expect`.
  def ji_graph(self, cnt): return {"type": "graph", "cnt": cnt}
  def ji_comp(self): return {"type": "comp"}
  def ji_copy(self): return {"type": "copy"}
  def ji_xfer(self): return {"type": "xfer"}

  # Three kernels on one device: a single graph of three items.
  def test_jit_split_simple(self):
    @TinyJit
    def f(inp):
      op0 = self.compute(Device.DEFAULT, inp)
      op1 = self.compute(Device.DEFAULT, op0)
      op2 = self.compute(Device.DEFAULT, op1)
      return op2

    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
    self.expect(f, inp,
      graph=[self.ji_graph(3)],
      multigraph=[self.ji_graph(3)],
      hcqgraph=[self.ji_graph(3)])

  # One CPU kernel placed between kernels on the default device.
  def test_jit_cpu_simple(self):
    if Device.DEFAULT == "CPU": raise unittest.SkipTest("CPU is not a valid default device for this test")

    @TinyJit
    def f(inp, inp_cpu):
      op0 = self.compute(Device.DEFAULT, inp)
      op1 = self.compute(Device.DEFAULT, op0)
      op2 = self.compute("CPU", inp_cpu)
      op3 = self.compute(Device.DEFAULT, op1)
      return op2, op3

    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
    inp_cpu = Tensor.randn(10, 10, device="CPU").realize()
    self.expect(f, inp, inp_cpu,
      graph=[self.ji_graph(2), self.ji_comp(), self.ji_comp()],
      multigraph=[self.ji_graph(2), self.ji_comp(), self.ji_comp()],
      hcqgraph=[self.ji_graph(4)])

  # Two consecutive CPU kernels between kernels on the default device.
  def test_jit_cpu_several(self):
    if Device.DEFAULT == "CPU": raise unittest.SkipTest("CPU is not a valid default device for this test")

    @TinyJit
    def f(inp, inp_cpu):
      op0 = self.compute(Device.DEFAULT, inp)
      op1 = self.compute(Device.DEFAULT, op0)
      op2 = self.compute("CPU", inp_cpu)
      op3 = self.compute("CPU", op2)
      op4 = self.compute(Device.DEFAULT, op1)
      return op3, op4

    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
    inp_cpu = Tensor.randn(10, 10, device="CPU").realize()
    self.expect(f, inp, inp_cpu,
      graph=[self.ji_graph(2), self.ji_graph(2), self.ji_comp()],
      multigraph=[self.ji_graph(2), self.ji_graph(2), self.ji_comp()],
      hcqgraph=[self.ji_graph(5)])

  # Kernels interleaved across two devices of the default type.
  def test_jit_multidev(self):
    if Device.DEFAULT == "CPU": raise unittest.SkipTest("CPU is not a valid default device for this test")
    self._require_second_device()

    @TinyJit
    def f(inp, inp_d1):
      op0 = self.compute(Device.DEFAULT, inp)
      op1 = self.compute(Device.DEFAULT, op0)
      op2 = self.compute(f"{Device.DEFAULT}:1", inp_d1)
      op3 = self.compute(f"{Device.DEFAULT}:1", op2)
      op4 = self.compute(Device.DEFAULT, op1)
      return op3, op4

    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
    inp_d1 = Tensor.randn(10, 10, device=f"{Device.DEFAULT}:1").realize()
    self.expect(f, inp, inp_d1,
      graph=[self.ji_graph(2), self.ji_graph(2), self.ji_comp()],
      multigraph=[self.ji_graph(5)],
      hcqgraph=[self.ji_graph(5)])

  # Two devices with a device-to-device transfer in the middle.
  def test_jit_multidev_xfer(self):
    if Device.DEFAULT in {"CPU", "LLVM"}: raise unittest.SkipTest("CPU/LLVM is not a valid default device for this test (zero-copies)")
    if Device.DEFAULT == "METAL" or REAL_DEV == "METAL": raise unittest.SkipTest("Metal is flaky, with multidevice (same as metal llama 4gpu?)")
    self._require_second_device()

    @TinyJit
    def f(inp, inp_d1):
      op0 = self.compute(Device.DEFAULT, inp)
      op1 = self.compute(Device.DEFAULT, op0)
      op2 = self.compute(f"{Device.DEFAULT}:1", inp_d1)
      op3 = self.copy(f"{Device.DEFAULT}:1", Device.DEFAULT, op2)
      op4 = self.compute(f"{Device.DEFAULT}:1", op2)
      op5 = self.compute(Device.DEFAULT, op3)
      return op1, op4, op5

    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
    inp_d1 = Tensor.randn(10, 10, device=f"{Device.DEFAULT}:1").realize()
    self.expect(f, inp, inp_d1,
      graph=[self.ji_graph(2), self.ji_comp(), self.ji_xfer(), self.ji_comp(), self.ji_comp()],
      multigraph=[self.ji_graph(6)],
      hcqgraph=[self.ji_graph(6)])

  # Default device -> CPU copy followed by a CPU kernel.
  @unittest.skipIf(getenv("MOCKGPU"), "MockGPU does not support parallel copies")
  def test_jit_multidev_copy(self):
    if Device.DEFAULT in {"CPU", "LLVM"}: raise unittest.SkipTest("CPU/LLVM is not a valid default device for this test (zero-copies)")

    @TinyJit
    def f(inp):
      op0 = self.compute(Device.DEFAULT, inp)
      op1 = self.compute(Device.DEFAULT, op0)
      op2 = self.copy(Device.DEFAULT, "CPU", op1)
      op3 = self.compute("CPU", op2)
      return op3

    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
    self.expect(f, inp,
      graph=[self.ji_graph(2), self.ji_copy(), self.ji_comp()],
      multigraph=[self.ji_graph(2), self.ji_copy(), self.ji_comp()],
      hcqgraph=[self.ji_graph(4)])
# Run the whole module with the unittest runner when executed as a script.
if __name__ == '__main__':
  unittest.main()