openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

719 lines
33 KiB

from typing import List
import unittest, time, pytest
from tinygrad import dtypes, Device
from tinygrad.helpers import DEBUG
from tinygrad.ops import Ops, UOp, KernelInfo, UPat, PatternMatcher, track_rewrites
from tinygrad.renderer import Renderer
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
from tinygrad.codegen.devectorizer import full_graph_rewrite, graph_rewrite, sym
from tinygrad.codegen.expander import expander, expand_rewrite
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.shape.shapetracker import ShapeTracker, View
# Toy rewrite rules used by the TestGraphRewrite cases below.
simple_pm = PatternMatcher([
  # "magic" rule: any int const is replaced by float(1.0) + float(2.0), which the next rule folds to 3.0
  # (test_magic_4 checks this)
  (UPat.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
  # fold const + const into a single float const
  (UPat.cvar('x') + UPat.cvar('y'), lambda x,y: UOp.const(dtypes.float, x.arg+y.arg)),
  # fold a product of three consts into a single float const
  (UPat.cvar('x') * UPat.cvar('y') * UPat.cvar('z'), lambda x,y,z: UOp.const(dtypes.float, x.arg*y.arg*z.arg)),
  # reassociate (x + c1) + c2 -> x + (c1+c2) so the two consts combine even when x is not const
  ((UPat.var('x') + UPat.cvar('c1')) + UPat.cvar('c2'), lambda x,c1,c2: x + (c1.arg+c2.arg)),
])
def to_uops_list(u:List[UOp]) -> List[UOp]:
  """Sink the given UOps, run full_graph_rewrite on the sink, and return the linearized uop list."""
  sink = UOp.sink(*u)
  rewritten = full_graph_rewrite(sink)
  return linearize_uop(rewritten)
class TestGraphRewriteEfficiency(unittest.TestCase):
  """Rough timing / allocation measurements for UOp construction and full_graph_rewrite.

  These tests only print their measurements; they make no assertions.
  """
  def test_create_many_uops(self):
    # build 10000 ADD nodes over the same two consts and report how long construction takes
    c1 = UOp.const(dtypes.int, 1)
    c2 = UOp.const(dtypes.int, 2)
    st = time.perf_counter()
    uops = [UOp(Ops.ADD, dtypes.int, (c1, c2)) for _ in range(10000)]
    et = time.perf_counter() - st
    print(f"created {len(uops)} uops in {et*1000:.2f} ms")

  def test_expand_rewrite(self):
    # a captured kernel AST (STORE of a REDUCE_AXIS over a half*half product) used as a rewrite workload
    sink = UOp(Ops.SINK, dtypes.void, arg=KernelInfo(local_dims=2, upcasted=4, dont_use_locals=False), src=(
      UOp(Ops.STORE, dtypes.void, arg=None, src=(
        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 4, 64, 8, 16, 1, 1, 3, 3, 4, 1),
          strides=(1179648, 9216, 1, 147456, 576, 0, 0, 64, 192, 36864, 0),
          offset=0, mask=None, contiguous=False),)), src=()),
        UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 10)), src=(
          UOp(Ops.CAST, dtypes.float, arg=None, src=(
            UOp(Ops.MUL, dtypes.half, arg=None, src=(
              UOp(Ops.LOAD, dtypes.half, arg=None, src=(
                UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
                UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(
                  View(shape=(1, 1024, 1, 64, 4, 17, 4, 17), strides=(0, 14400, 0, 225, 0, 15, 0, 1), offset=-16,
                    mask=((0, 1), (0, 1024), (0, 1), (0, 64), (0, 4), (1, 16), (0, 4), (1, 16)), contiguous=False),
                  View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(0, 73984, 4734976, 0, 4624, 295936, 68, 18, 1224, 0, 1), offset=0,
                    mask=None, contiguous=False))), src=()),)),
              UOp(Ops.LOAD, dtypes.half, arg=None, src=(
                UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
                UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(
                  View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(7200, 0, 230400, 900, 0, 14400, 15, 0, 0, 225, 1), offset=0,
                    mask=None, contiguous=False),)), src=()),)),)),)),)),)),))
    lower_sink = rewrite_shapetracker_with_index(sink, Device[Device.DEFAULT].renderer)
    # count UOp constructions during the rewrite by temporarily wrapping UOp.__init__
    cnt = [0]
    old_init = UOp.__init__
    def uop_hook(uop, *args, **kwargs):
      cnt[0] += 1
      old_init(uop, *args, **kwargs)
    UOp.__init__ = uop_hook
    try:
      st = time.perf_counter()
      new_sink = full_graph_rewrite(lower_sink)
      et = time.perf_counter() - st
    finally:
      # restore the real constructor even if the rewrite raises; otherwise every later test in the
      # process would keep running with the counting hook installed
      UOp.__init__ = old_init
    print(f"rewrote in {et*1000:.2f} ms, from {len(lower_sink.toposort)} -> {len(new_sink.toposort)}, creating {cnt[0]} uops")
class TestGraphRewriteConst(unittest.TestCase):
  """Folding of vector constants (GEP and elementwise ADD) under the `sym` rules."""
  def test_gep_const(self):
    # lane 1 of the VCONST (0,1,2) becomes the scalar int 1
    lane = graph_rewrite(UOp.const(dtypes.int.vec(3), (0,1,2)).gep(1), sym)
    self.assertEqual(lane.dtype, dtypes.int)
    self.assertEqual(lane.arg, 1)

  def test_gep_const_single(self):
    # a vec3 const built from the single value 4: any lane folds to the scalar 4
    lane = graph_rewrite(UOp.const(dtypes.int.vec(3), 4).gep(1), sym)
    self.assertEqual(lane.dtype, dtypes.int)
    self.assertEqual(lane.arg, 4)

  def test_add_const(self):
    # the elementwise sum of two VCONSTs stays a VCONST holding the per-lane sums
    lhs = UOp.const(dtypes.int.vec(3), (0,1,2))
    rhs = UOp.const(dtypes.int.vec(3), (5,6,7))
    total = graph_rewrite(lhs+rhs, sym)
    self.assertEqual(total.op, Ops.VCONST)
    self.assertEqual(total.dtype, dtypes.int.vec(3))
    self.assertEqual(total.arg, (5,7,9))

  def test_add_const_lose_v(self):
    # when every lane of the sum is equal, the result is a plain CONST that keeps the vec3 dtype
    lhs = UOp.const(dtypes.int.vec(3), (0,1,2))
    rhs = UOp.const(dtypes.int.vec(3), (2,1,0))
    total = graph_rewrite(lhs+rhs, sym)
    self.assertEqual(total.op, Ops.CONST)
    self.assertEqual(total.dtype, dtypes.int.vec(3))
    self.assertEqual(total.arg, 2)
# pytest xfail marker shared by the TestModularWraparound cases: const folding does not yet wrap
# results the way fixed-width integer arithmetic would.
xfail_broken_const_wraparound = pytest.mark.xfail(reason="const folding does not properly implement modular arithmetic")
class TestModularWraparound(unittest.TestCase):
  """Const folding should wrap like fixed-width integer arithmetic (all cases currently xfail)."""
  def _test(self, uop:UOp, expected:int):
    # the whole expression must linearize to exactly one CONST of the same dtype holding `expected`
    folded = to_uops_list([uop])
    self.assertEqual(len(folded), 1)
    only = folded[0]
    self.assertEqual(only.op, Ops.CONST)
    self.assertEqual(only.dtype, uop.dtype)
    self.assertEqual(only.arg, expected)

  @xfail_broken_const_wraparound
  def test_cast(self):
    # narrowing to uint8 keeps only the low byte, and widening back keeps it
    self._test(UOp.const(dtypes.uint, 0xABCD17D6).cast(dtypes.uint8), 0xD6)
    self._test(UOp.const(dtypes.uint, 0xABCD17D6).cast(dtypes.uint8).cast(dtypes.uint), 0xD6)

  @xfail_broken_const_wraparound
  def test_mul(self):
    self._test(UOp.const(dtypes.uint, 0xABCD17D6) * 0xAABBCCDD, 1147018174)
    self._test(UOp.const(dtypes.int, 0xABCD17D6) * 10, -1241321892)

  @xfail_broken_const_wraparound
  def test_div(self):
    self._test(UOp.const(dtypes.uint, 0xABCD17D6) * 0xAABBCCDD // 11, 104274379)
    self._test(UOp.const(dtypes.int, 0xABCD17D6) * 10 // 11, -112847444)

  @xfail_broken_const_wraparound
  def test_neg(self):
    # negating unsigned 1 yields the all-ones value of each width
    for dt, want in ((dtypes.uint8, 0xFF), (dtypes.uint16, 0xFFFF), (dtypes.uint32, 0xFFFFFFFF),
                     (dtypes.uint64, 0xFFFFFFFFFFFFFFFF)):
      self._test(-UOp.const(dt, 1), want)

  @xfail_broken_const_wraparound
  def test_neg_min_int(self):
    # negating the minimum signed value overflows back to itself
    for dt, bits in ((dtypes.int8, 7), (dtypes.int16, 15), (dtypes.int32, 31), (dtypes.int64, 63)):
      self._test(-UOp.const(dt, -2**bits), -2**bits)

  @xfail_broken_const_wraparound
  def test_payne_hanek_reduction_bug(self):
    # regression expression named after a Payne-Hanek reduction bug; should fold to 0
    seed = (UOp.const(dtypes.uint, 43748177600).cast(dtypes.uint) | 36).cast(dtypes.ulong)
    mixed = 2536655455 * seed + 4294967296 * UOp.const(dtypes.ulong, 25366554550)
    self._test((mixed + 2261737165) // 4611686018427387904, 0)
class TestGraphRewrite(unittest.TestCase):
  """graph_rewrite basics: node deduplication, the toy `simple_pm` folds, and `sym` const ordering."""
  def _assert_folds_to(self, expr, value):
    # rewrite `expr` with simple_pm and require a single CONST whose arg equals `value`
    folded = graph_rewrite(expr, simple_pm)
    self.assertEqual(folded.op, Ops.CONST)
    self.assertEqual(folded.arg, value)

  def test_dedup(self):
    # two identically built DEFINE_VARs must end up as the very same source object
    lhs, rhs = UOp(Ops.DEFINE_VAR, dtypes.float), UOp(Ops.DEFINE_VAR, dtypes.float)
    rewritten = graph_rewrite(lhs+rhs, PatternMatcher([]))
    self.assertIs(rewritten.src[0], rewritten.src[1])

  # NOTE: this shows why we can't have a UOp in arg
  @unittest.expectedFailure
  def test_no_dedup_args(self):
    first = UOp(Ops.DEFINE_VAR, dtypes.int, (), ("a1", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11)))
    second = UOp(Ops.DEFINE_VAR, dtypes.int, (), ("a2", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11)))
    rewritten = graph_rewrite(first.sink(second), PatternMatcher([]))
    self.assertEqual(len([n for n in rewritten.toposort if n.op is Ops.DEFINE_VAR]), 1)

  def test_simple(self):
    one, two = UOp.const(dtypes.float, 1.0), UOp.const(dtypes.float, 2.0)
    self._assert_folds_to(one+two, 3.0)

  def test_depth_2_late(self):
    # (3+3) must fold first before the three-const product rule can fire
    one, two, three = (UOp.const(dtypes.float, f) for f in (1.0, 2.0, 3.0))
    self._assert_folds_to(one*two*(three+three), 12.0)

  def test_double(self):
    one, two, three = (UOp.const(dtypes.float, f) for f in (1.0, 2.0, 3.0))
    self._assert_folds_to(one+two+three, 6.0)

  def test_triple(self):
    one, two, three, four = (UOp.const(dtypes.float, f) for f in (1.0, 2.0, 3.0, 4.0))
    self._assert_folds_to(one+two+three+four, 10.0)

  def test_diamond(self):
    # both sub-sums share the `one` node
    one, two, three = (UOp.const(dtypes.float, f) for f in (1.0, 2.0, 3.0))
    self._assert_folds_to((one+two)+(one+three), 7.0)

  def test_magic_4(self):
    # simple_pm rewrites any int const to 1.0+2.0, which then folds to 3.0
    self._assert_folds_to(UOp.const(dtypes.int, 4.0), 3.0)

  def test_depth_2_fold(self):
    # (v+1)+2 is reassociated into v+3 rather than folded away
    var = UOp(Ops.DEFINE_VAR, dtypes.float)
    one, two = UOp.const(dtypes.float, 1.0), UOp.const(dtypes.float, 2.0)
    rewritten = graph_rewrite(var+one+two, simple_pm)
    self.assertEqual(rewritten.op, Ops.ADD)
    self.assertEqual(rewritten.src[0].op, Ops.DEFINE_VAR)
    self.assertEqual(rewritten.src[1].op, Ops.CONST)
    self.assertEqual(rewritten.src[1].arg, 3.0)

  def test_commutative_work(self):
    # a+b and b+a simplify to the identical node
    a, b = UOp.variable('a', 0, 1), UOp.variable('b', 0, 1)
    self.assertIs((a+b).simplify(), (b+a).simplify())

  def test_consts_go_last_right_away(self):
    a = UOp.variable('a', 0, 1)
    simplified = (2+a).simplify()
    self.assertIs(simplified.src[0], a)
    self.assertIs(simplified.src[1], a.const_like(2))

  def test_consts_go_last(self):
    # after `sym`, each sum is an ADD whose right operand is the single remaining CONST
    a, b, c, d = (UOp.variable(name, 0, 1) for name in "abcd")
    exprs = [2+a, 2+a+d+3+b+c+4, UOp(Ops.ADD, a.dtype, src=(a.const_like(2), a)), (4+d)+c+(2+a)+b]
    for expr in exprs:
      rewritten = graph_rewrite(expr, sym)
      print(rewritten.render())
      self.assertEqual(rewritten.op, Ops.ADD)
      self.assertEqual(rewritten.src[1].op, Ops.CONST)
      self.assertEqual(len([n for n in rewritten.toposort if n.op is Ops.CONST]), 1)
class TestUOpGraph(unittest.TestCase):
  """End-to-end checks of full_graph_rewrite + linearize_uop (through to_uops_list) on small hand-built graphs."""
  def test_add_constant_fold(self):
    # CONST + CONST folds to one CONST
    c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
    c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
    out = UOp(Ops.ADD, dtypes.float, (c1, c2))
    uops = to_uops_list([out])
    self.assertEqual(len(uops), 1)
    out = uops[-1]
    self.assertEqual(out.op, Ops.CONST)
    self.assertEqual(out.arg, 3.0)

  def test_where_same_fold(self):
    # WHERE(cond, c1, c1) folds to c1 even though cond depends on a variable
    v = UOp.variable('tmp', 0, 1)
    c0 = UOp(Ops.CONST, dtypes.int, arg=0)
    vc = UOp(Ops.CMPNE, dtypes.bool, (v, c0))
    c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
    out = UOp(Ops.WHERE, dtypes.float, (vc, c1, c1))
    uops = to_uops_list([out])
    self.assertEqual(len(uops), 1)
    out = uops[-1]
    self.assertEqual(out.op, Ops.CONST)
    self.assertEqual(out.arg, 1.0)

  def test_where_const_fold(self):
    # WHERE(False, c1, c2) folds to c2
    bf = UOp(Ops.CONST, dtypes.bool, arg=False)
    c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
    c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
    out = UOp(Ops.WHERE, dtypes.float, (bf, c1, c2))
    uops = to_uops_list([out])
    self.assertEqual(len(uops), 1)
    out = uops[-1]
    self.assertEqual(out.op, Ops.CONST)
    self.assertEqual(out.arg, 2.0)

  def test_const_cast(self):
    # casting the bool const False to int folds to the int const 0
    bf = UOp(Ops.CONST, dtypes.bool, arg=False)
    out = UOp(Ops.CAST, dtypes.int, (bf,))
    uops = to_uops_list([out])
    self.assertEqual(len(uops), 1)
    out = uops[-1]
    self.assertEqual(out.op, Ops.CONST)
    self.assertEqual(out.arg, 0)

  @unittest.skip("this test isn't valid uops")
  def test_noop_vectorize_fold(self):
    d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
    idx = UOp.const(dtypes.int, 0)
    ld = UOp(Ops.LOAD, dtypes.float.vec(2), (d0, idx))
    vec = UOp(Ops.VECTORIZE, dtypes.float.vec(2), (ld,))
    x = UOp(Ops.GEP, dtypes.float, (vec, ), arg=0)
    alu = UOp(Ops.SQRT, dtypes.float, (x, ))
    out = UOp(Ops.STORE, dtypes.void, (d0, idx, alu))
    uops = to_uops_list([out])
    self.assertEqual(len([x for x in uops if x.op is Ops.VECTORIZE]), 0)

  def test_gep_vec_fold(self):
    # A VECTORIZE of GEPs folds back to the source LOAD for an in-order read of all 4 lanes of one
    # value. It stays a VECTORIZE for reordered lanes, mismatched widths, or GEPs from different values.
    d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
    d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
    d2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 2)
    idx = UOp.const(dtypes.int, 0)
    def _test_vec(geps, count=4):
      # store VECTORIZE(geps) into d0 and return the value operand of the last linearized uop
      vec = UOp(Ops.VECTORIZE, dtypes.float.vec(count), geps)
      out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), vec))
      uops = to_uops_list([out])
      if DEBUG >= 4:
        # Device is already imported at module level, so it is not re-imported here
        print(Device[Device.DEFAULT].renderer.render("test", uops))
      return uops[-1].src[-1]
    # possible
    val = UOp(Ops.LOAD, dtypes.float.vec(4), (d1.index(idx),))
    xyzw = tuple(UOp(Ops.GEP, dtypes.float, (val,), (i,)) for i in range(4))
    self.assertIs(_test_vec(xyzw).op, Ops.LOAD)
    # unaligned
    val = UOp(Ops.LOAD, dtypes.float.vec(4), (d1.index(idx),))
    wzyx = tuple(UOp(Ops.GEP, dtypes.float, (val,), (i,)) for i in reversed(range(4)))
    self.assertIs(_test_vec(wzyx).op, Ops.VECTORIZE)
    # different_size
    val = UOp(Ops.LOAD, dtypes.float.vec(2), (d1.index(idx),))
    xy = tuple(UOp(Ops.GEP, dtypes.float, (val, ), (i,)) for i in range(2))
    self.assertIs(_test_vec(xy+xy).op, Ops.VECTORIZE)
    val = UOp(Ops.LOAD, dtypes.float.vec(4), (d1.index(idx),))
    xy = tuple(UOp(Ops.GEP, dtypes.float, (val, ), (i,)) for i in range(2))
    self.assertIs(_test_vec(xy, count=2).op, Ops.VECTORIZE)
    # different vals
    val1 = UOp(Ops.LOAD, dtypes.float.vec(2), (d1.index(idx),))
    val2 = UOp(Ops.LOAD, dtypes.float.vec(2), (d2.index(idx),))
    xy1 = tuple(UOp(Ops.GEP, dtypes.float, (val1, ), (i,)) for i in range(2))
    xy2 = tuple(UOp(Ops.GEP, dtypes.float, (val2, ), (i,)) for i in range(2))
    self.assertIs(_test_vec(xy1+xy2).op, Ops.VECTORIZE)

  def test_gep_vec_const_fold(self):
    # GEP i of a VECTORIZE of consts folds to the i-th const itself
    for vec_size in [2, 4, 8]:
      consts = [UOp.const(dtypes.float, float(i)) for i in range(vec_size)]
      vec = UOp(Ops.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
      uops = to_uops_list([UOp(Ops.GEP, dtypes.float, (vec,), (i,)) for i in range(vec_size)])
      for uop, const in zip(uops, consts):
        self.assertEqual(uop, const)

  def test_wmma_vectorize_fold(self):
    # a WMMA whose first or second operand is an all-zero VECTORIZE folds to its accumulator
    for i in [2, 4, 8]:
      vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
      var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i))
      acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
      wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (vec, var, acc))
      uops = to_uops_list([wmma])
      self.assertEqual(uops[0], acc)
      self.assertEqual(len(uops), 1)
    for i in [2, 4, 8]:
      var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i))
      vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
      acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
      wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (var, vec, acc))
      uops = to_uops_list([wmma])
      self.assertEqual(uops[0], acc)
      self.assertEqual(len(uops), 1)

  @unittest.skip("wmma is wrong here, it needs an arg")
  def test_wmma_vectorize_no_fold(self):
    # partially-zero or non-zero VECTORIZE operands must NOT fold the WMMA away
    for i in [4, 8]:
      vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i),
                tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
                tuple(UOp(Ops.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
      var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
      acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
      wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (vec, var, acc))
      uops = to_uops_list([wmma])
      self.assertEqual(uops[-1], wmma)
    for i in [4, 8]:
      var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
      vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i),
                tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
                tuple(UOp(Ops.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
      acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
      wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (var, vec, acc))
      uops = to_uops_list([wmma])
      self.assertEqual(uops[-1], wmma)
    for i in [2, 4, 8]:
      vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i),
                tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
      var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
      acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
      wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (vec, var, acc))
      uops = to_uops_list([wmma])
      self.assertEqual(uops[-1], wmma)
    for i in [2, 4, 8]:
      var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
      vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i),
                tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
      acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
      wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (var, vec, acc))
      uops = to_uops_list([wmma])
      self.assertEqual(uops[-1], wmma)

  def test_cast_alu_fold(self):
    # (ld<1) is a comparison; the explicit cast of it to bool is expected to disappear
    d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0)
    d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
    idx = UOp.const(dtypes.int, 0)
    ld = UOp(Ops.LOAD, dtypes.int, (d1.index(idx),))
    alu = (ld<1).cast(dtypes.bool)
    out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu))
    uops = to_uops_list([out])
    self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 0)

  def test_double_cast_fold(self):
    # int -> float -> float leaves only one CAST
    d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
    d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
    idx = UOp.const(dtypes.int, 0)
    ld = UOp(Ops.LOAD, dtypes.int, (d1.index(idx),))
    alu = ld.cast(dtypes.float).cast(dtypes.float)
    out = UOp(Ops.STORE, dtypes.void, (d0.index(idx), alu))
    uops = to_uops_list([out])
    self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 1)

  def test_depth_2_const_fold(self):
    # (v+2)+4 is reassociated into a single ADD of v and the const 6
    v = UOp.variable("tmp", 0, 1)
    c2 = UOp(Ops.CONST, dtypes.int, arg=2)
    c4 = UOp(Ops.CONST, dtypes.int, arg=4)
    vc = UOp(Ops.ADD, dtypes.int, (v, c2))
    out = UOp(Ops.ADD, dtypes.int, (vc, c4))
    uops = to_uops_list([out])
    self.assertEqual(len(uops), 3)
    out = uops[-1]
    self.assertEqual(out.op, Ops.ADD)
    self.assertEqual(out.src[1].op, Ops.CONST)
    self.assertEqual(out.src[1].arg, 6)

  def test_bitcast_to_same_dtype_fold(self):
    # bitcasting a value to its own dtype is removed, for every int/float dtype and bool
    for dt in dtypes.ints + dtypes.floats + (dtypes.bool,):
      d0 = UOp(Ops.DEFINE_GLOBAL, dt.ptr(), arg=0)
      v = UOp(Ops.LOAD, dt, (d0.index(UOp.const(dtypes.int, 0)),))
      uops = to_uops_list([v.bitcast(dt)])
      self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST]), 0, f"dtype = {dt}")

  def test_out_of_bounds_access(self):
    # constant index 42 into a buffer declared with size 16 must raise during rewriting
    glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
    ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 42)),))
    with self.assertRaises(RuntimeError): to_uops_list([ld0])

  def test_fold_gated_load(self):
    # ld0 is gated by False and folds away, so the stored ld1+ld0 reduces to ld1 alone
    glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
    glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
    glbl2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
    idx = UOp.const(dtypes.int, 0)
    ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1.index(idx, UOp.const(dtypes.bool, False)),))
    ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2.index(idx, UOp.const(dtypes.bool, True)),))
    uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
    # the value operand of the final STORE; a distinct name avoids confusing it with the original ld0
    stored_val = uops[-1].src[-1]
    # the gate and invalid value are deleted from ld1
    self.assertEqual(stored_val, UOp.load(glbl2.index(idx), dtype=dtypes.int))

  def test_fold_gated_load_local(self):
    # same as test_fold_gated_load, but the loads read a local buffer behind a BARRIER
    glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
    smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=18, local=True), (), "temp")
    lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 16))
    st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int)))
    barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
    ld0 = UOp(Ops.LOAD, dtypes.int, (smem.index(lidx+1, UOp.const(dtypes.bool, False)), barrier))
    ld1 = UOp(Ops.LOAD, dtypes.int, (smem.index(lidx+2, UOp.const(dtypes.bool, True)), barrier))
    uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))])
    # the value operand of the final STORE (what remains of ld1)
    stored_val = uops[-1].src[-1]
    # the gate and invalid value are deleted from ld1
    self.assertEqual(stored_val.src[0], smem.index(lidx+2))

  def test_fold_gated_store(self):
    # a store gated by False is dropped; the True-gated store loses its gate
    glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
    idx0 = UOp.const(dtypes.int, 0)
    idx1 = UOp.const(dtypes.int, 0)
    val = UOp.const(dtypes.int, 42)
    st0 = UOp(Ops.STORE, dtypes.void, (glbl.index(idx0, UOp.const(dtypes.bool, False)), val))
    st1 = UOp(Ops.STORE, dtypes.void, (glbl.index(idx1, UOp.const(dtypes.bool, True)), val))
    uops = to_uops_list([st0, st1])
    # only the second store happens
    self.assertEqual(len(uops), 5)
    self.assertEqual(uops[-1], UOp.store(glbl.index(idx1), val))

  @unittest.skip("this is a uop type error")
  def test_asserts_bad_gate(self):
    glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
    idx = UOp.const(dtypes.int, 0)
    bad_gate = UOp.const(dtypes.int, 1)
    with self.assertRaises(AssertionError): to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])

  def test_switched_range_order(self):
    # the index multiplies r2*r1 (inner range first); ranges must still be closed in the right order
    glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
    c0 = UOp.const(dtypes.int, 0)
    c2 = UOp.const(dtypes.int, 2)
    cf = UOp.const(dtypes.float, 0.0)
    r1 = UOp(Ops.RANGE, dtypes.int, (c0, c2), 0)
    r2 = UOp(Ops.RANGE, dtypes.int, (c0, c2), 1)
    alu = UOp(Ops.MUL, dtypes.int, (r2, r1))
    store = UOp(Ops.STORE, dtypes.void, (glbl.index(alu), cf))
    uops = to_uops_list([store])
    ranges = [x for x in uops if x.op is Ops.RANGE]
    endranges = [x for x in uops if x.op is Ops.ENDRANGE]
    # ranges are closed in the right order
    self.assertEqual(endranges[-1].src[0], ranges[0])
@track_rewrites()
def expander_rewrite(sink):
  """Rewrite `sink` using the symbolic rules combined with the expander rules."""
  rules = sym + expander
  return graph_rewrite(sink, rules)
@track_rewrites()
def float4_rewrite(sink):
  """Run full_graph_rewrite on `sink` against a default-constructed Renderer."""
  renderer = Renderer()
  return full_graph_rewrite(sink, renderer)
class TestExpander(unittest.TestCase):
  """Checks of the `expander` rules (run with `sym` via expander_rewrite) on UNROLL / CONTRACT graphs."""
  # NOTE(review): UNROLL/CONTRACT args look like tuples of (axis, size) pairs -- inferred from the vector
  # widths used below (e.g. ((1,4),(2,4)) paired with a 16-wide VCONST); confirm against the expander source.
  def test_expand_add_broadcast(self):
    # adding the scalar 3 to an UNROLL adds it to every unrolled element
    e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
    sink = expander_rewrite(e1+3)
    assert sink.op is Ops.UNROLL and len(sink.src[0].arg) == 4
    self.assertTupleEqual(sink.src[0].arg, (3,4,5,6))
  def test_contract_simple(self):
    # contracting the only unrolled axis yields a plain VCONST in the original order
    e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
    con = UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
    sink = expander_rewrite(con)
    self.assertEqual(sink.op, Ops.VCONST)
    self.assertTupleEqual(sink.arg, (0,1,2,3))
  def test_contract_axis_1(self):
    # contracting axis 1 of a two-axis UNROLL leaves an UNROLL over axis 2 with a strided element order
    e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
    con = UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
    sink = expander_rewrite(con)
    assert sink.op is Ops.UNROLL and len(sink.src[0].arg) == 16 and sink.arg == ((2,4),)
    assert sink.src[0].op is Ops.VCONST
    self.assertTupleEqual(sink.src[0].arg[0:4], (0,4,8,12))
    self.assertTupleEqual(sink.src[0].arg[12:], (3,7,11,15))
  def test_contract_axis_2(self):
    # contracting the last axis leaves an UNROLL over axis 1 and keeps the original element order
    e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
    con = UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((2,4),))
    sink = expander_rewrite(con)
    assert sink.op is Ops.UNROLL and len(sink.src[0].arg) == 16 and sink.arg == ((1,4),)
    assert sink.src[0].op is Ops.VCONST
    self.assertTupleEqual(sink.src[0].arg[0:4], (0,1,2,3))
    self.assertTupleEqual(sink.src[0].arg[12:], (12,13,14,15))
  def test_contract_axis_2_big(self):
    # contracting axis 2 out of four size-2 axes; the remaining three stay unrolled
    e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
    con = UOp(Ops.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
    sink = expander_rewrite(con)
    assert sink.op is Ops.UNROLL and sink.arg == ((1, 2), (3, 2), (4, 2))
    self.assertTupleEqual(sink.src[0].arg[0:2], (0,4))
    self.assertTupleEqual(sink.src[0].arg[12:14], (10,14))
  def test_contract_multi_axis(self):
    # the order of the axes listed in CONTRACT determines the element order inside each contracted vector
    e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
    sink = expander_rewrite(UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((3, 2), (2, 2))))
    assert sink.op is Ops.UNROLL and sink.arg == ((1, 2), (4, 2))
    self.assertTupleEqual(sink.src[0].arg[0:4], (0, 4, 2, 6))
    sink = expander_rewrite(UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((2, 2), (3, 2))))
    assert sink.op is Ops.UNROLL and sink.arg == ((1, 2), (4, 2))
    self.assertTupleEqual(sink.src[0].arg[0:4], (0, 2, 4, 6))
  def test_contract_mid(self):
    # contracting the middle of three size-2 axes
    e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(8), tuple(x for x in range(8))),), ((1,2),(2,2),(3,2)))
    con = UOp(Ops.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
    sink = expander_rewrite(con)
    assert sink.op is Ops.UNROLL and sink.arg == ((1,2),(3,2))
    assert sink.src[0].op is Ops.VCONST and len(sink.src[0].arg) == 8
    self.assertTupleEqual(sink.src[0].arg, (0,2,1,3,4,6,5,7))
  def test_contract_no_expand(self):
    # contracting a value with no UNROLL just broadcasts it into a VECTORIZE of identical srcs
    e1 = UOp(Ops.DEFINE_VAR, dtypes.int)
    con = UOp(Ops.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
    sink = expander_rewrite(con)
    assert sink.op is Ops.VECTORIZE and len(sink.src) == 2
    assert sink.src[0] == sink.src[1]
  def test_contract_half_expand(self):
    # contracting axis 1 (present) plus axis 2 (absent): each element is repeated for the missing axis
    e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
    con = UOp(Ops.CONTRACT, dtypes.int.vec(8), (e1,), ((1,4), (2,2)))
    sink = expander_rewrite(con)
    assert sink.op is Ops.VCONST and len(sink.arg) == 8
    assert sink.arg[0] == sink.arg[1]
    assert sink.arg[0] != sink.arg[2]
    assert sink.arg[6] == sink.arg[7]
  def test_expand_same_axis(self):
    # two UNROLLs on the same axis add elementwise
    e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
    e2 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
    sink = expander_rewrite(e1+e2)
    self.assertEqual(sink.op, Ops.UNROLL)
    self.assertEqual(sink.src[0].op, Ops.VCONST)
    self.assertTupleEqual(sink.src[0].arg, (0,5,10,15))
  def test_expand_different_axis(self, flip=False):
    # UNROLLs on different axes combine into one UNROLL over both axes; `flip` swaps the operand order
    # and the result must be the same (see test_expand_different_axis_flip)
    e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
    e2 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((2,4),))
    sink = expander_rewrite((e2+e1) if flip else (e1+e2))
    assert sink.op is Ops.UNROLL and len(sink.src[0].arg) == 16
    assert sink.arg == ((1, 4), (2, 4))
    self.assertTupleEqual(sink.src[0].arg, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
  def test_expand_different_axis_flip(self): self.test_expand_different_axis(True)
  # The skipped tests below build UNROLLs directly from scalar CONST srcs (and use REDUCE), unlike the
  # single-VCONST form used above.
  @unittest.skip("no longer supported")
  def test_reduce_known_axis(self):
    e1 = UOp(Ops.UNROLL, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
    sink = UOp(Ops.REDUCE, dtypes.int, (3*e1,e1), Ops.ADD)
    sink = expander_rewrite(sink)
    assert sink.op is Ops.CONST
    self.assertEqual(sink.arg, 3*(0+1+2+3))
  @unittest.skip("no longer supported")
  def test_reduce_const(self):
    e1 = UOp(Ops.UNROLL, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
    sink = UOp(Ops.REDUCE, dtypes.int, (UOp.const(dtypes.int, 3), e1), Ops.ADD)
    sink = expander_rewrite(sink)
    assert sink.op is Ops.CONST
    self.assertEqual(sink.arg, 3*4)
  @unittest.skip("no longer supported")
  def test_double_expand(self):
    e1 = UOp(Ops.UNROLL, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
    e2 = UOp(Ops.UNROLL, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((2,4),))
    e = UOp(Ops.UNROLL, dtypes.int, (e1, e2), ((1,2),))
    sink = expander_rewrite(e)
    assert sink.op is Ops.UNROLL and len(sink.src) == 8
    assert sink.arg == ((1, 2), (2, 4))
    self.assertListEqual([x.arg for x in sink.src], [0,1,2,3,4,5,6,7])
  @unittest.skip("no longer supported")
  def test_double_expand_reverse(self):
    e1 = UOp(Ops.UNROLL, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
    e2 = UOp(Ops.UNROLL, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,4),))
    e = UOp(Ops.UNROLL, dtypes.int, (e1, e2), ((2,2),))
    sink = expander_rewrite(e)
    assert sink.op is Ops.UNROLL and len(sink.src) == 8
    assert sink.arg == ((1, 4), (2, 2))
    self.assertListEqual([x.arg for x in sink.src], [0, 4, 1, 5, 2, 6, 3, 7])
  @unittest.skip("no longer supported")
  def test_double_expand_middle(self):
    e1 = UOp(Ops.UNROLL, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,2),(3,2)))
    e2 = UOp(Ops.UNROLL, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,2),(3,2)))
    e = UOp(Ops.UNROLL, dtypes.int, (e1, e2), ((2,2),))
    sink = expander_rewrite(e)
    assert sink.op is Ops.UNROLL and len(sink.src) == 8
    assert sink.arg == ((1, 2), (2, 2), (3, 2))
    self.assertListEqual([x.arg for x in sink.src], [0, 1, 4, 5, 2, 3, 6, 7])
  # does this need to work?
  @unittest.expectedFailure
  @unittest.skip
  def test_reduce_different_axis(self):
    e1 = UOp(Ops.UNROLL, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
    e2 = UOp(Ops.UNROLL, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
    sink = UOp(Ops.REDUCE, dtypes.int, (e1,e2), Ops.ADD)
    sink = expander_rewrite(sink)
    print(sink)
class TestIFUOps(unittest.TestCase):
  """Checks that gated stores end up under exactly one IF uop conditioned on their shared gate."""
  def test_create_ifs(self):
    # one gated global store of a value loaded from local memory
    gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
    sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(size=4, local=True), (), "smem")
    valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 10))<5
    lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 4))
    gate = valid&(lidx.ne(2))
    idx = UOp.const(dtypes.int, 0)
    st = UOp(Ops.STORE, dtypes.void, (sbuf.index(idx), UOp.const(dtypes.float, 42)))
    barrier = UOp(Ops.BARRIER, dtypes.void, (st,))
    lbuf = UOp(Ops.LOAD, dtypes.float, (sbuf.index(UOp.const(dtypes.int, 0)), barrier))
    store = UOp(Ops.STORE, dtypes.void, (gbuf.index(UOp.const(dtypes.int, 0), gate), lbuf))
    sink = UOp(Ops.SINK, dtypes.void, (store,))
    sink = full_graph_rewrite(expand_rewrite(sink))
    # exactly one IF is produced and its condition is the store's gate
    if_uops = [u for u in sink.toposort if u.op is Ops.IF]
    self.assertEqual(len(if_uops), 1)
    self.assertEqual(if_uops[0].src[0], gate)
    # NOTE: `st` is rebound here to each top-level store of the rewritten sink
    for st in sink.src:
      self.assertEqual(len(st.src), 2)
  def test_expand_ifs_one_gate(self):
    # four stores share one gate and should still produce a single IF
    gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
    sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(size=16, local=True), (), "smem")
    valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 4))<1
    lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 16))
    gate = valid&(lidx.ne(2))
    # NOTE(review): this STORE uses the 3-src form (buf, idx, value), while test_create_ifs stores through
    # `sbuf.index(idx)`; confirm this older form is intentional.
    st = UOp(Ops.STORE, dtypes.void, (sbuf, lidx, UOp.const(dtypes.float, 42)))
    barrier = UOp(Ops.BARRIER, dtypes.void, (st,))
    lbufs = [UOp(Ops.LOAD, dtypes.float, (sbuf.index(UOp.const(dtypes.int, i)), barrier)) for i in range(4)]
    stores = [UOp(Ops.STORE, dtypes.void, (gbuf.index(UOp.const(dtypes.int, i), gate), lbufs[i])) for i in range(4)]
    sink = UOp(Ops.SINK, dtypes.void, tuple(stores))
    sink = full_graph_rewrite(expand_rewrite(sink))
    if_uops = [u for u in sink.toposort if u.op is Ops.IF]
    self.assertEqual(len(if_uops), 1)
    self.assertEqual(if_uops[0].src[0], gate)
    for st in sink.src:
      self.assertEqual(len(st.src), 2)
  # this will be fixed with the merge gated stores bounty
  @unittest.expectedFailure
  def test_expand_ifs_dumb(self):
    # NOTE(review): these STOREs use the 4-src form (buf, idx, value, gate) and skip expand_rewrite
    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
    valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 10))<5
    lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 4))
    gate = valid&(lidx.ne(2))
    stores = [UOp(Ops.STORE, dtypes.void, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
    sink = UOp(Ops.SINK, dtypes.void, tuple(stores))
    sink = full_graph_rewrite(sink)
    if_uops = [u for u in sink.toposort if u.op is Ops.IF]
    self.assertEqual(len(if_uops), 1)
    self.assertEqual(if_uops[0].src[0], gate)
    for st in sink.src:
      self.assertEqual(len(st.src), 2)
if __name__ == '__main__':
  # run the suite directly, printing one line per test
  unittest.main(verbosity=2)