openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

772 lines
34 KiB

#!/usr/bin/env python
import unittest, pickle
from tinygrad.dtype import dtypes, ConstType
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.devectorizer import full_graph_rewrite, sym
from tinygrad.ops import UOp, Ops, graph_rewrite, sym_infer
from tinygrad import Variable
import functools
def render(self) -> tuple[str, ConstType, ConstType]:
  """Rewrite the expression `self` and return its rendered string with its bounds.

  The expression is stored into a global int buffer, passed through
  `full_graph_rewrite` and `linearize_uop`, and then read back from the last
  source of the first STORE found in the linearized uops.

  Returns:
    A `(rendered, vmin, vmax)` tuple taken from the rewritten expression.
  """
  # NOTE: the expression is wrapped in a STORE so that the ALU op has children
  out_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
  store = UOp(Ops.STORE, dtypes.void, (out_buf.index(UOp.const(dtypes.int, 0)), self))
  linearized = linearize_uop(full_graph_rewrite(store.sink()))
  stores = [u for u in linearized if u.op is Ops.STORE]
  simplified = stores[0].src[-1]
  return simplified.render(simplify=False), simplified.vmin, simplified.vmax
def uconst(val):
  """Return `val` as a `dtypes.int` constant built with `UOp.const`."""
  return UOp.const(dtypes.int, val)
def usum(ops):
  """Left-fold `ops` with `+` and return the result.

  A single-element iterable yields that element unchanged; an empty iterable
  raises `TypeError` from `functools.reduce` (no initial value is supplied).
  """
  def _add(lhs, rhs):
    return lhs + rhs
  return functools.reduce(_add, ops)
def uand(ops):
  """Left-fold `ops` with `*` and return the result.

  Despite the name, the combining operator is multiplication. A single-element
  iterable yields that element unchanged; an empty iterable raises `TypeError`
  from `functools.reduce` (no initial value is supplied).
  """
  def _mul(lhs, rhs):
    return lhs * rhs
  return functools.reduce(_mul, ops)
# *** leave tests the same
class TestSymbolicPickle(unittest.TestCase):
  """Round-trips symbolic expressions through `pickle`."""

  def _test_pickle_unpickle(self, x):
    """Assert that `x` compares equal to its own pickle round-trip."""
    restored = pickle.loads(pickle.dumps(x))
    self.assertEqual(x, restored)

  def test_pickle_variable(self):
    self._test_pickle_unpickle(Variable("a", 3, 8))

  def test_pickle_variable_times_2(self):
    self._test_pickle_unpickle(Variable("a", 3, 8)*2)
class TestSymbolic(unittest.TestCase):
  """Checks the rendered form and the (vmin, vmax) bounds of rewritten symbolic expressions.

  Almost every test goes through `helper_test_variable`, which passes the
  expression to the module-level `render` helper and compares the result with
  an expected string (or a tuple of acceptable strings) and expected bounds.
  """
  def helper_test_variable(self, v, n, m, s):
    """Render `v` and assert its string form and bounds.

    Args:
      v: the expression to pass to `render`.
      n: expected lower bound (compared with the returned vmin).
      m: expected upper bound (compared with the returned vmax).
      s: expected rendered string, or a tuple of acceptable rendered strings.
    """
    rendered, nmin, nmax = render(v)
    if isinstance(s, tuple): self.assertIn(rendered, s)
    else: self.assertEqual(rendered, s)
    self.assertEqual(nmin, n)
    self.assertEqual(nmax, m)
  def test_cmp_simple(self):
    self.helper_test_variable(Variable("a", 3, 8) < 4, 0, 1, "(a<4)")
    self.helper_test_variable(Variable("a", 3, 8) >= 8, 0, 1, "((a<8)!=True)")
  def test_ge(self):
    self.helper_test_variable(Variable("a", 3, 8) >= 77, 0, 0, "False")
    self.helper_test_variable(Variable("a", 3, 8) >= 9, 0, 0, "False")
    self.helper_test_variable(Variable("a", 3, 8) >= 8, 0, 1, "((a<8)!=True)")
    self.helper_test_variable(Variable("a", 3, 8) >= 4, 0, 1, "((a<4)!=True)")
    self.helper_test_variable(Variable("a", 3, 8) >= 3, 1, 1, "True")
    self.helper_test_variable(Variable("a", 3, 8) >= 2, 1, 1, "True")
  def test_lt(self):
    self.helper_test_variable(Variable("a", 3, 8) < 77, 1, 1, "True")
    self.helper_test_variable(Variable("a", 3, 8) < 9, 1, 1, "True")
    self.helper_test_variable(Variable("a", 3, 8) < 8, 0, 1, "(a<8)")
    self.helper_test_variable(Variable("a", 3, 8) < 4, 0, 1, "(a<4)")
    self.helper_test_variable(Variable("a", 3, 8) < 3, 0, 0, "False")
    self.helper_test_variable(Variable("a", 3, 8) < 2, 0, 0, "False")
    self.helper_test_variable(Variable("a", 3, 4) < Variable("b", 5, 6), 1, 1, "True")
    self.helper_test_variable(Variable("a", 3, 5) < Variable("b", 5, 6), 0, 1, "(a<b)")
    self.helper_test_variable(Variable("a", 5, 6) < Variable("b", 3, 5), 0, 0, "False")
    self.helper_test_variable(Variable("a", 3, 4) < Variable("a", 3, 4), 0, 0, "False")
  def test_lt_divides(self):
    expr = (Variable("idx", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512
    self.helper_test_variable(expr, 0, 1, "(idx<128)")
  def test_lt_divides_and(self):
    expr = uand([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512,
                 (Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512])
    self.helper_test_variable(expr, 0, 1, "((idx1<128)&(idx2<128))")
  def test_lt_factors(self):
    expr = (Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256)) < 512
    self.helper_test_variable(expr, 0, 1, ("(((idx1*4)+FLOAT4_INDEX)<512)", "((FLOAT4_INDEX+(idx1*4))<512)"))
  def test_div_reduction(self):
    self.helper_test_variable(Variable("a", 2, 3)//2, 1, 1, "1")
  def test_equality(self):
    # identity (`is`) checks: building the same expression twice is expected to yield the same object
    idx1 = Variable("idx1", 0, 3)
    idx2 = Variable("idx2", 0, 3)
    assert idx1 is idx1
    assert idx1 is not idx2
    assert idx1*4 is idx1*4
    assert idx1*4 is not idx1*3
    assert idx1*4 is not idx1+4
    assert idx1*4 is not idx2*4
    assert idx1+idx2 is idx1+idx2
    # assert idx1+idx2 is idx2+idx1
    assert idx1+idx2 is not idx2
    # assert idx1*idx2 is idx2*idx1
  def test_factorize(self):
    a = Variable("a", 0, 8)
    b = Variable("b", 0, 8)
    self.helper_test_variable(a*2+a*3, 0, 8*5, "(a*5)")
    self.helper_test_variable(b+a*2+a*3, 0, 8*6, "(b+(a*5))")
  def test_factorize_no_mul(self):
    a = Variable("a", 0, 8)
    b = Variable("b", 0, 8)
    self.helper_test_variable(a+a*3, 0, 8*4, "(a*4)")
    self.helper_test_variable((a+b)+a*3, 0, 8*5, "(b+(a*4))")
    self.helper_test_variable((a*3+b)+b*3, 0, 8*7, "((a*3)+(b*4))")
  def test_neg(self):
    self.helper_test_variable(-Variable("a", 0, 8), -8, 0, "(a*-1)")
  def test_add_1(self):
    self.helper_test_variable(Variable("a", 0, 8)+1, 1, 9, "(a+1)")
  def test_sub_1(self):
    self.helper_test_variable(Variable("a", 0, 8)-1, -1, 7, "(a+-1)")
  def test_add_self(self):
    a = Variable("a", 0, 8)
    b = Variable("b", 0, 8)
    self.helper_test_variable(a+a, 0, 16, "(a*2)")
    self.helper_test_variable((a+b)+b, 0, 24, "(a+(b*2))")
  def test_sub_self(self):
    a = Variable("a", 0, 8)
    self.helper_test_variable(a-a, 0, 0, "0")
    self.helper_test_variable(a*3-a, 0, 16, "(a*2)")
  def test_mul_0(self):
    self.helper_test_variable(Variable("a", 0, 8)*0, 0, 0, "0")
  def test_mul_1(self):
    self.helper_test_variable(Variable("a", 0, 8)*1, 0, 8, "a")
  # marked expectedFailure: the expectation below does not currently hold
  @unittest.expectedFailure
  def test_mul_neg_1(self):
    self.helper_test_variable((Variable("a", 0, 2)*-1)//3, -1, 0, "((((a*-1)+3)//3)+-1)")
  def test_mul_2(self):
    self.helper_test_variable(Variable("a", 0, 8)*2, 0, 16, "(a*2)")
  def test_div_1(self):
    self.helper_test_variable(Variable("a", 0, 8)//1, 0, 8, "a")
  def test_mod_1(self):
    self.helper_test_variable(Variable("a", 0, 8)%1, 0, 0, "0")
  def test_max_folds(self):
    self.helper_test_variable(Variable("a", 0, 20).maximum(10).maximum(11), 11, 20, "max(a, 11)")
  def test_add_min_max(self):
    self.helper_test_variable(Variable("a", 0, 8) * 2 + 12, 12, 16+12, "((a*2)+12)")
  def test_div_remove(self):
    self.helper_test_variable(Variable("a", 0, 7) // 20, 0, 0, "0")
  def test_div_min_max(self):
    self.helper_test_variable(Variable("a", 1, 7) // 2, 0, 3, "(a//2)")
    self.helper_test_variable(Variable("a", 0, 6) // 2, 0, 3, "(a//2)")
  def test_div_neg_min_max(self):
    self.helper_test_variable(Variable("a", 1, 7) // -2, -3, 0, "(a//-2)")
    self.helper_test_variable(Variable("a", 0, 6) // -2, -3, 0, "(a//-2)")
  def test_sum_div_remove(self):
    self.helper_test_variable(usum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 20, 0, 0, "0")
  def test_sum_div_min_max(self):
    self.helper_test_variable(usum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)")
  def test_sum_div_mod_factor(self):
    self.helper_test_variable(usum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) // 2, 0, 20, "((a*2)+(b*2))")
    self.helper_test_variable(usum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) % 2, 0, 0, "0")
  def test_sum_div_some_factor(self):
    self.helper_test_variable(usum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, ("(((a*5)//2)+(b*2))", "((b*2)+((a*5)//2))"))
  def test_sum_div_trim_const(self):
    self.helper_test_variable((Variable("a", 0, 7)*4 + Variable("b", 0, 3)*4 + 7) // 16, 0, 2, "(((a+b)+1)//4)")
  def test_sum_div_some_partial_factor(self):
    self.helper_test_variable(usum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
    self.helper_test_variable(usum([uconst(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")
    self.helper_test_variable((Variable("a", 0, 7)*30+20)//20, 1, 11, "(((a*3)//2)+1)")
  def test_sum_div_no_factor(self):
    self.helper_test_variable(usum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")
  def test_mod_factor(self):
    self.helper_test_variable(usum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 50, "((b%2)*50)")
  def test_mod_to_sub(self):
    self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, "(a+-1)")
  def test_mod_congruence(self):
    self.helper_test_variable((3+3*Variable("a",0,3))%4, 0, 3, "((a*-1)+3)")
    self.helper_test_variable((17+13*Variable("a",0,3))%18, 2, 17, "((a*-5)+17)")
    self.helper_test_variable((2+9*Variable("a",0,3))%18, 2, 11, "(((a%2)*9)+2)")
  def test_mod_congruence_mul_add(self):
    self.helper_test_variable((6*(Variable("a", 0, 2)+1))%9, 0, 6, "((a*-3)+6)")
  def test_mod_congruence_multiple_vars(self):
    self.helper_test_variable((9+9*Variable("x",0,3)+9*Variable("y",0,3))%10, 3, 9, "(((x*-1)+(y*-1))+9)")
    self.helper_test_variable((7+9*Variable("x",0,2)+9*Variable("y",0,2)+Variable("z",0,2))%10, 3, 9,
                              ("(((z+(x*-1))+(y*-1))+7)", "(((y*-1)+(z+(x*-1)))+7)"))
    self.helper_test_variable((10+12*Variable("x",0,2)+Variable("y", 0, 4)%3)%13, 8, 12, "(((x*-1)+(y%3))+10)")
  def test_div_congruence(self):
    self.helper_test_variable((3+3*Variable("a",0,3))//4, 0, 3, "a")
    self.helper_test_variable((18+17*Variable("a",0,2)+17)//18, 1, 3, "(a+1)")
  def test_div_congruence_multiple_vars(self):
    self.helper_test_variable((9+(9+10)*Variable("x",0,3)+(8+10)*Variable("y",0,2))//10, 0, 10, "((x*2)+(y*2))")
  def test_mod_binary_expression(self):
    self.helper_test_variable((3+Variable("a",0,1))%4, 0, 3, "((a*-3)+3)")
    self.helper_test_variable((3+Variable("a",4,5))%4, 0, 3, "((a*-3)+15)")
  def test_sum_div_const(self):
    self.helper_test_variable(usum([Variable("a", 0, 7)*4, uconst(3)]) // 4, 0, 7, "a")
  def test_sum_div_const_big(self):
    self.helper_test_variable(usum([Variable("a", 0, 7)*4, uconst(3)]) // 16, 0, 1, "(a//4)")
  def test_sum_lt_fold(self):
    self.helper_test_variable(usum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]) < 16, 0, 1, "(a<4)")
    self.helper_test_variable(usum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]) < 16, 0, 1,
                              ("(((a*4)+b)<16)", "((b+(a*4))<16)"))
    self.helper_test_variable(usum([Variable("uidx", 0, 3), Variable("a", 0, 1529) * 12]) < (4 * 67), 0, 1, "(a<23)")
  def test_mul_mod_large(self):
    self.helper_test_variable((Variable("a", 0, 20)*10)%9, 0, 8, "(a%9)")
  def test_mul_mod_small(self):
    self.helper_test_variable((Variable("a", 0, 5)*10)%9, 0, 5, "a")
  def test_mod_mod(self):
    self.helper_test_variable((Variable("a", 0, 31)%12)%4, 0, 3, "(a%4)")
    self.helper_test_variable(((4*Variable("a", 0, 31)) % 12) % 4, 0, 0, "0")
    self.helper_test_variable(((5*Variable("a", 0, 31)) % 12) % 5, 0, 4, "(((a*5)%12)%5)")
    self.helper_test_variable((Variable("a", 0, 31) % 4) % 12, 0, 3, "(a%4)")
  def test_mul_mul(self):
    self.helper_test_variable((Variable("a", 0, 5)*10)*9, 0, 5*10*9, "(a*90)")
  def test_mul_lt(self):
    self.helper_test_variable(Variable("a", 0, 5)*4 < 13, 0, 1, "(a<4)")
    self.helper_test_variable(Variable("a", 0, 5)*4 < 16, 0, 1, "(a<4)")
    self.helper_test_variable(Variable("a", 0, 5)*(-2) < 0, 0, 1, "((a*-1)<0)")
    self.helper_test_variable(Variable("a", 0, 5)*4 >= 12, 0, 1, "((a<3)!=True)")
    self.helper_test_variable(Variable("a", 0, 5)*4 >= 13, 0, 1, "((a<4)!=True)")
  def test_div_div(self):
    self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)")
  def test_div_const_div(self):
    a = Variable("a", 0, 124)
    self.helper_test_variable((a//2+1)//2, 0, 31, "((a+2)//4)")
  def test_distribute_mul(self):
    self.helper_test_variable(usum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))")
    self.helper_test_variable((1+Variable("a", 0, 3))*(-2)+12, 4, 10, "((a*-2)+10)")
  def test_mod_mul_sum(self):
    self.helper_test_variable(usum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, ("(b+a)", "(a+b)"))
  def test_sum_0(self):
    self.helper_test_variable(usum([Variable("a", 0, 7)]), 0, 7, "a")
  def test_mod_remove(self):
    self.helper_test_variable(Variable("a", 0, 6)%100, 0, 6, "a")
  def test_big_mod(self):
    # NOTE: we no longer support negative variables
    #self.helper_test_variable(Variable("a", -20, 20)%10, -9, 9, "(a%10)")
    #self.helper_test_variable(Variable("a", -20, 0)%10, -9, 0, "(a%10)")
    #self.helper_test_variable(Variable("a", -20, 1)%10, -9, 1, "(a%10)")
    self.helper_test_variable(Variable("a", 0, 20)%10, 0, 9, "(a%10)")
    #self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)")
  def test_ge_remove(self):
    self.helper_test_variable(Variable("a", 0, 6) >= 25, 0, 0, "False")
  def test_lt_remove(self):
    self.helper_test_variable(Variable("a", 0, 6) < -3, 0, 0, "False")
    self.helper_test_variable(Variable("a", 0, 6) < 3, 0, 1, "(a<3)")
    self.helper_test_variable(Variable("a", 0, 6) < 8, 1, 1, "True")
  def test_lt_sum_remove(self):
    self.helper_test_variable(Variable("a", 0, 6) + 2 < 3, 0, 1, "(a<1)")
  def test_lt_simple_factor(self):
    self.helper_test_variable((Variable("a", 0, 6)*6+Variable("b", 0, 6)*6) < 8, 0, 1, "(((a*3)+(b*3))<4)")
  def test_lt_sum_factor_rhs_partial(self):
    self.helper_test_variable((Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8) < 4, 0, 1,
                              ("((((a*3)+(b*2))+(c*4))<2)", "(((b*2)+((a*3)+(c*4)))<2)"))
  def test_lt_sum_factor_rhs_all(self):
    self.helper_test_variable((Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8) < 2, 0, 1,
                              ("((((a*3)+(b*2))+(c*4))<1)", "(((b*2)+((a*3)+(c*4)))<1)"))
  def test_and_fold(self):
    self.helper_test_variable(uand([uconst(0), Variable("a", 0, 1)]), 0, 0, "0")
  def test_and_remove(self):
    self.helper_test_variable(uand([uconst(1), Variable("a", 0, 1)]), 0, 1, "a")
  def test_mod_factor_negative(self):
    self.helper_test_variable(usum([uconst(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((a+27)%28)")
    self.helper_test_variable(usum([uconst(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((a+27)%28)")
  def test_sum_combine_num(self):
    self.helper_test_variable(usum([uconst(29), Variable("a", 0, 10), uconst(-23)]), 6, 16, "(a+6)")
  def test_sum_num_hoisted_and_factors_cancel_out(self):
    self.helper_test_variable(usum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1")
  def test_div_cancel(self):
    self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40])//40, -1, 9, "(b+-1)")
  def test_mod_cancel(self):
    self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, 0, 20, "(a*2)")
  def test_mul_div(self):
    self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a")
  def test_add_div(self):
    # careful about the lower bounds and upper bounds
    self.helper_test_variable((Variable("a", 0, 5)-2)//4, -1, 0, "(((a+2)//4)+-1)")
    self.helper_test_variable((Variable("a", 0, 5)-1)//4, -1, 1, "(((a+3)//4)+-1)")
    self.helper_test_variable((Variable("a", 0, 5))//4, 0, 1, "(a//4)")
    self.helper_test_variable((Variable("a", 0, 5)+1)//4, 0, 1, "((a+1)//4)")
    self.helper_test_variable((Variable("a", 0, 5)+2)//4, 0, 1, "((a+2)//4)")
    self.helper_test_variable((Variable("a", 0, 5)+3)//4, 0, 2, "((a+3)//4)")
    self.helper_test_variable((Variable("a", 0, 5)+4)//4, 1, 2, "((a//4)+1)")
    self.helper_test_variable((Variable("a", 0, 5)+5)//4, 1, 2, "(((a+1)//4)+1)")
  def test_mul_div_factor_mul(self):
    self.helper_test_variable((Variable("a", 0, 10)*8)//4, 0, 20, "(a*2)")
  def test_mul_div_factor_div(self):
    self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)")
  def test_sum_div_partial_remove(self):
    self.helper_test_variable(usum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0")
  # marked expectedFailure: the expectation below does not currently hold
  @unittest.expectedFailure
  def test_div_numerator_negative(self):
    self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -9, 0, "((((idx*-10)+99)//11)+-9)")
  def test_div_into_mod(self):
    self.helper_test_variable((Variable("idx", 0, 16)*4)%8//4, 0, 1, "(idx%2)")
  # TODO: simplify the expression
  def test_div_neg_cancel(self):
    self.helper_test_variable((-Variable("idx", 0, 100)+199)//-4 + 50, 1, 26, "((((idx*-1)+199)//-4)+50)")
    self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "((((idx*-1)+200)//-4)+50)")
    self.helper_test_variable((-Variable("idx", 0, 100)+201)//-4 + 50, 0, 25, "((((idx*-1)+201)//-4)+50)")
  def test_sum_div_big_const(self):
    gidx0 = Variable("gidx0", 0, 24)
    self.helper_test_variable((gidx0+19)//20, 0, 2, "((gidx0+19)//20)")
    self.helper_test_variable((gidx0+20)//20, 1, 2, "((gidx0//20)+1)")
    self.helper_test_variable((gidx0+21)//20, 1, 2, "(((gidx0+1)//20)+1)")
  def test_sum_div_complex1(self):
    gidx0 = Variable("gidx0", 0, 24)
    gidx1 = Variable("gidx1", 0, 1)
    gidx2 = Variable("gidx2", 0, 255)
    lidx0 = Variable("lidx0", 0, 1)
    lidx1 = Variable("lidx1", 0, 15)
    lidx2 = Variable("lidx2", 0, 3)
    alu0 = gidx2*640+gidx1*160+(gidx0//5)*2+lidx0*320+lidx1*10
    self.helper_test_variable((alu0+lidx2*2+1)//20, 0, 8192,
                              ("((((((gidx0//5)+lidx2)//5)+lidx1)//2)+(((gidx2*32)+(gidx1*8))+(lidx0*16)))",
                               "(((lidx1+((lidx2+(gidx0//5))//5))//2)+((gidx2*32)+((gidx1*8)+(lidx0*16))))",
                               "((((gidx1*8)+(gidx2*32))+(lidx0*16))+((lidx1+((lidx2+(gidx0//5))//5))//2))"))
  def test_sum_div_complex2(self):
    gidx0 = Variable("gidx0", 0, 7)
    lidx2 = Variable("lidx2", 0, 1)
    lidx3 = Variable("lidx3", 0, 1)
    self.helper_test_variable((gidx0*4+lidx2*2+1)//10, 0, 3, ("(((gidx0*2)+lidx2)//5)", "((lidx2+(gidx0*2))//5)"))
    self.helper_test_variable((gidx0*4+lidx2*2+lidx3)//10, 0, 3, ("(((gidx0*2)+lidx2)//5)", "((lidx2+(gidx0*2))//5)"))
    self.helper_test_variable((gidx0*2+lidx2)//10, 0, 1, "(gidx0//5)")
  def test_sum_div_complex3(self):
    gidx0 = Variable("gidx0", 0, 7)
    lidx2 = Variable("lidx2", 0, 12)
    lidx3 = Variable("lidx3", 0, 1)
    self.helper_test_variable((gidx0*4+lidx2*2+lidx3)//12, 0, 4, ("(((lidx2//2)+gidx0)//3)", "((gidx0+(lidx2//2))//3)"))
    self.helper_test_variable((lidx2*2+gidx0*4+lidx3)//12, 0, 4, ("(((lidx2//2)+gidx0)//3)", "((gidx0+(lidx2//2))//3)"))
  def test_sum_mul_distribute(self):
    gidx0 = Variable("gidx0", 0, 7)
    lidx2 = Variable("lidx2", 0, 12)
    lidx3 = Variable("lidx3", 0, 1)
    self.helper_test_variable((gidx0+lidx2+lidx3)*4, 0, 80,
                              ("(((gidx0*4)+(lidx2*4))+(lidx3*4))","((lidx3*4)+((gidx0*4)+(lidx2*4)))"))
  # marked expectedFailure: the expectations below do not currently hold
  @unittest.expectedFailure
  def test_variable_divmod(self):
    start_pos = Variable("start_pos", 0, 127)
    v = start_pos + 1
    idx0 = Variable("idx0", 0, 2)
    idx1 = Variable("idx1", 0, start_pos)
    self.helper_test_variable((idx0*v+idx1)//v, 0, 2, "(idx0)")
    self.helper_test_variable((idx0*v+idx1)%v, 0, start_pos, "idx1")
  def test_divmod_variable_denom_fold_to_const(self):
    x = Variable("x", 20, 23)
    y = Variable("y", 8, 10)
    self.helper_test_variable(x//y, 2, 2, "2")
    self.helper_test_variable(x%y, 0, 7, "(x+(y*-2))")
    # ensure all 4 corners are checked
    x = Variable("x", -10, 10)
    y = Variable("y", -8, 9)
    self.helper_test_variable(x//y, -2147483648, 2147483647, "(x//y)")
    self.helper_test_variable(x%y, -2147483648, 2147483647, "(x%y)")
  # TODO: simplify the expression
  def test_div_neg_all_range(self):
    gidx = Variable("gidx", 0, 124)
    lidx = Variable("lidx", 0, 7)
    self.helper_test_variable((-gidx*8-lidx+999)//-4 + 250, 1, 250, "(((((gidx*-8)+(lidx*-1))+999)//-4)+250)")
    self.helper_test_variable((-gidx*8-lidx+1000)//-4 + 250, 0, 250, "(((((gidx*-8)+(lidx*-1))+1000)//-4)+250)")
    self.helper_test_variable((-gidx*8-lidx+1001)//-4 + 250, 0, 250, "(((((gidx*-8)+(lidx*-1))+1001)//-4)+250)")
    self.helper_test_variable((-gidx*8-lidx+1002)//-4 + 250, 0, 250, "(((((gidx*-8)+(lidx*-1))+1002)//-4)+250)")
  # NOTE: tests are not correct in symbolic
  def test_div_neg_then_neg(self):
    # taken from arange opts
    lidx0 = Variable("lidx0", 0, 7)
    lidx1 = Variable("lidx1", 0, 7)
    alu2 = -lidx0-lidx1
    self.helper_test_variable((((alu2+14)//(-32))+4), 4, 4, "4")
    self.helper_test_variable(-(((alu2+14)//(-32))+4), -4, -4, "-4")
    self.helper_test_variable((((alu2+134)//(-32))+4), 0, 1, "(((((lidx0*-1)+(lidx1*-1))+134)//-32)+4)")
    self.helper_test_variable((((alu2+142)//(-32))+4), 0, 0, "0")
    self.helper_test_variable((((alu2+150)//(-32))+4), 0, 0, "0")
    self.helper_test_variable((((alu2+158)//(-32))+4), 0, 0, "0")
  def test_div_mod_recombine(self):
    gidx = Variable("gidx", 0, 124)
    self.helper_test_variable(gidx%4+(gidx//4)*4, 0, 124, "gidx")
    self.helper_test_variable((gidx//4)*4+gidx%4, 0, 124, "gidx")
  def test_div_mod_recombine_folded_mod(self):
    a = Variable("a", 0, 2)
    b = Variable("b", 0, 100)
    self.helper_test_variable((31 * a + 1) % 30 + ((31 * a + 1) // 30) * 30, 1, 63, "((a*31)+1)")
    # the helper is expected to fail here: this case does not currently match the recombined form
    with self.assertRaises(AssertionError):
      self.helper_test_variable((31 * b + 1) % 18 + ((31 * b + 1) // 18) * 18, 1, 3101, "((b*31)+1)")
  def test_div_mod_recombine_with_gcd(self):
    b = Variable("b", 0, 100)
    exp = (16 * b + 2) % 18 + ((16 * b + 2) // 18) * 18
    self.helper_test_variable(exp, 2, 1602, "((b*16)+2)")
    # the helper is expected to fail here: this case does not currently match the recombined form
    with self.assertRaises(AssertionError):
      self.helper_test_variable((30 * b + 1) % 18 + ((30 * b + 1) // 18) * 18, 1, 3001, "((b*30)+1)")
  def test_arange_unrolled4(self):
    gidx = Variable("gidx", 0, 2559)
    unrolled_div = (gidx+2561)//4+(gidx+2562)//4+(gidx+2560)//4+(gidx+2559)//4
    self.helper_test_variable(unrolled_div, 2559, 5118, "(gidx+2559)")
  def test_arange_unrolled4_small(self):
    gidx = Variable("gidx", 0, 3)
    unrolled_div = (gidx)//4+(gidx+2)//4+(gidx+3)//4+(gidx+1)//4
    self.helper_test_variable(unrolled_div, 0, 3, "gidx")
    gidx = Variable("gidx", 0, 2)
    unrolled_div = (gidx)//4+(gidx+2)//4+(gidx+3)//4+(gidx+1)//4
    self.helper_test_variable(unrolled_div, 0, 2, "gidx")
    gidx = Variable("gidx", 0, 1)
    unrolled_div = (gidx)//4+(gidx+2)//4+(gidx+3)//4+(gidx+1)//4
    self.helper_test_variable(unrolled_div, 0, 1, "gidx")
  def test_arange_unrolled2(self):
    gidx = Variable("gidx", 0, 2559)
    unrolled_div = (gidx+2559)//2+(gidx+2560)//2+3
    self.helper_test_variable(unrolled_div, 2562, 5121, "(gidx+2562)")
  def test_gated_load(self):
    idx = Variable("idx", 0, 24)
    self.helper_test_variable(idx//4, 0, 6, "(idx//4)")
    # TODO: simplify the true branch
    self.helper_test_variable((idx<4).where(idx//4, idx.const_like(-1)), -1, 6, "((idx//4) if (idx<4) else -1)")
  def test_idiv_lt(self):
    idx = Variable("idx", 0, 24)
    self.helper_test_variable((idx//4<3), 0, 1, "(idx<12)")
    self.helper_test_variable((idx//-4<-3), 0, 1, "((idx//-4)<-3)")
  def test_simplex_lt(self):
    a = Variable("a", 0, 3)
    b = Variable("b", 0, 3)
    c = Variable("c", 0, 3)
    d = Variable("d", -3, 3)
    self.helper_test_variable((a<1).ne(True), 0, 1, "((a<1)!=True)")
    self.helper_test_variable((a+b<1).ne(True), 0, 1, "(((a+b)<1)!=True)")
    self.helper_test_variable((a*3+b*4<1).ne(True), 0, 1, "(((a+b)<1)!=True)")
    self.helper_test_variable((a*(-3)+b*4<1).ne(True), 0, 1, "((((a*-3)+(b*4))<1)!=True)")  # negative coeff, should not be simplified
    self.helper_test_variable((a*3+d*4<1).ne(True), 0, 1, "((((a*3)+(d*4))<1)!=True)")  # var can be negative, should not be simplified
    self.helper_test_variable((a+b+c*2<1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)", '(((b+(a+c))<1)!=True)'))
    self.helper_test_variable((a+b*2+c*4<1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)", '(((b+(a+c))<1)!=True)'))
  def test_where_removal(self):
    cond = Variable("a", 0, 3) < 2
    u1, u0 = cond.ufix(1), cond.ufix(0)
    self.helper_test_variable(cond, 0, 1, "(a<2)")
    self.helper_test_variable(cond.where(u1, u0), 0, 1, "(a<2)")
    self.helper_test_variable(cond.where(u1, u0).where(u1, u0), 0, 1, "(a<2)")
  def test_where_combine(self):
    cond = Variable("x", 0, 3) < 2
    a = Variable("a", 0, 3)
    b = Variable("b", 0, 3)
    aa = cond.where(a, a.ufix(0))
    bb = cond.where(b, b.ufix(1))
    self.helper_test_variable(aa, 0, 3, "(a if (x<2) else 0)")
    self.helper_test_variable(bb, 0, 3, "(b if (x<2) else 1)")
    self.helper_test_variable(aa+bb, 0, 6, "((a+b) if (x<2) else 1)")
    self.helper_test_variable(aa.maximum(bb), 0, 3, "(max(a, b) if (x<2) else 1)")
    # not combining because it increased total ALU
    c = Variable("c", 0, 3)
    cc = cond.where(c, c+1)
    self.helper_test_variable(bb+cc, 0, 7, "((b if (x<2) else 1)+(c if (x<2) else (c+1)))")
    # not combining  # TODO: can combine if it can further simplify?
    ab = cond.where(a, b)
    ba = cond.where(b, a)
    self.helper_test_variable(ab+ba, 0, 6, "((a if (x<2) else b)+(b if (x<2) else a))")
    # not combining  # TODO: can combine if one is identity element const
    self.helper_test_variable(aa+ab, 0, 6, "((a if (x<2) else b)+(a if (x<2) else 0))")
  def test_where_cast(self):
    s = Variable("s", 0, 3)
    cond = s < 2
    a = Variable("a", 0, 3)
    b = Variable("b", 0, 3)
    expr = cond.where(a, b).cast(dtypes.half)
    # TODO: copied from render, render does not support cast
    glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
    uops = linearize_uop(full_graph_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink()))
    rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[-1]
    # the cast is expected to end up on both branches of the where
    self.assertEqual(rewritten_uop, cond.where(a.cast(dtypes.half), b.cast(dtypes.half)))
  def test_symbolic_div(self):
    # from symbolic arange
    a = Variable("a", 1, 10)
    denominator = ((a*-2)+1)
    numerator = (((((a*2)+-1)*2)+1)*a)
    self.helper_test_variable(denominator, -19, -1, "((a*-2)+1)")
    self.helper_test_variable(numerator, 3, 390, "(a*((a*4)+-1))")
    self.helper_test_variable((numerator//denominator)<=0, 1, 1, "True")
class TestSymbolicNumeric(unittest.TestCase):
  """Cross-checks symbolic bounds against plain integer evaluation of the same function."""

  def helper_test_numeric(self, f):
    """Check `f` on int constants and on bounded variables.

    For each constant in [0, 10) the expression rewritten with `sym` must have
    vmin == vmax == f(constant). For each variable range within [0, 9] the
    symbolic bounds must enclose every concrete value of `f` on that range.
    """
    lo, hi = 0, 10
    # constants: both bounds must collapse to the concrete result
    for i in range(lo, hi):
      folded = graph_rewrite(f(uconst(i)), sym)
      self.assertEqual(folded.vmin, folded.vmax)
      self.assertEqual(folded.vmin, f(i))
    # ranges: only pairs with kmin <= kmax are visited
    for kmin in range(lo, hi):
      for kmax in range(kmin, hi):
        sym_expr = f(Variable("tmp", kmin, kmax))
        concrete = [f(rv) for rv in range(kmin, kmax+1)]
        # the min and max may not be exact
        self.assertLessEqual(sym_expr.vmin, min(concrete))
        self.assertGreaterEqual(sym_expr.vmax, max(concrete))

  def test_mod_4(self):
    self.helper_test_numeric(lambda x: (x%4))

  def test_div_4(self):
    self.helper_test_numeric(lambda x: (x//4))

  def test_plus_1_div_2(self):
    self.helper_test_numeric(lambda x: (x+1)//2)

  def test_plus_1_mod_2(self):
    self.helper_test_numeric(lambda x: (x+1)%2)

  def test_times_2(self):
    self.helper_test_numeric(lambda x: x*2)

  def test_times_2_plus_3(self):
    self.helper_test_numeric(lambda x: x*2 + 3)

  def test_times_2_plus_3_mod_4(self):
    self.helper_test_numeric(lambda x: (x*2 + 3)%4)

  def test_times_2_plus_3_div_4(self):
    self.helper_test_numeric(lambda x: (x*2 + 3)//4)

  def test_times_2_plus_3_div_4_mod_4(self):
    self.helper_test_numeric(lambda x: ((x*2 + 3)//4)%4)
class TestSymbolicVars(unittest.TestCase):
  """Checks the set returned by `.vars()` on constants, variables and compound expressions.

  Uses `unittest` assertion methods rather than bare `assert` so the checks
  still run under `python -O` and report both operands on failure.
  """

  def test_simple(self):
    z = uconst(0)
    a = Variable("a", 0, 10)
    b = Variable("b", 0, 10)
    c = Variable("c", 0, 10)
    # calling vars() twice must give equal results
    self.assertEqual(z.vars(), z.vars())
    self.assertEqual(z.vars(), set())
    self.assertEqual(a.vars(), a.vars())
    self.assertEqual(a.vars(), {a})
    m = a * 3
    self.assertEqual(m.vars(), {a})
    s = usum([a, b, c])
    self.assertEqual(s.vars(), {a, b, c})

  def test_compound(self):
    a = Variable("a", 0, 10)
    b = Variable("b", 0, 10)
    c = Variable("c", 0, 10)
    self.assertEqual((a + b * c).vars(), {a, b, c})
    self.assertEqual((a % 3 + b // 5).vars(), {a, b})
    # TODO: fix me
    # the equality below is expected NOT to hold yet; assertEqual raises
    # AssertionError (the TestCase failureException), which is what is caught
    with self.assertRaises(AssertionError):
      self.assertEqual((a + b + c - a).vars(), {b, c})

  def test_dedup(self):
    a = Variable("a", 0, 10)
    self.assertEqual((a * a).vars(), {a})
    self.assertEqual((a//4 + a//6).vars(), {a})
class TestSymInfer(unittest.TestCase):
  """Checks `sym_infer` on plain numbers and on expressions with bound variable values."""

  def test_sym_infer(self):
    a = Variable("a", 0, 10)
    b = Variable("b", 0, 10)
    c = Variable("c", 0, 10)
    var_vals = {a: 2, b: 3, c: 4}
    # (input, expected result) with a=2, b=3, c=4
    cases = [
      (5, 5),
      (4.2, 4.2),
      (a, 2),
      (b, 3),
      (a+b, 5),
      (a-b, -1),
      (a+b+c, 9),
      (a*b, 6),
      (a*b+c, 10),
    ]
    for expr, expected in cases:
      assert sym_infer(expr, var_vals) == expected
"""
@unittest.skip("not supported on uops yet")
class TestSymbolicSymbolicOps(unittest.TestCase):
def test_node_divmod_node(self):
i = Variable("i", 1, 10)
idx0 = Variable("idx0", 0, i*3-1)
assert uconst(0) // (Variable("i", 1, 10)*128) == 0
assert uconst(0) % (Variable("i", 1, 10)*128) == 0
assert uconst(127) // (Variable("i", 1, 10)*128) == 0
assert uconst(127) % (Variable("i", 1, 10)*128) == 127
assert 127 // (Variable("i", 1, 10)*128) == 0
assert 127 % (Variable("i", 1, 10)*128) == 127
assert uconst(128) // (Variable("i", 1, 10)*128 + 128) == 0
assert uconst(128) % (Variable("i", 1, 10)*128 + 128) == 128
assert 128 // (Variable("i", 1, 10)*128 + 128) == 0
assert 128 % (Variable("i", 1, 10)*128 + 128) == 128
assert 0 // (Variable("i", 1, 10)*128) == 0
assert 0 % (Variable("i", 1, 10)*128) == 0
assert idx0 // (i*3) == 0
assert idx0 % (i*3) == idx0
assert i // i == 1
assert i % i == 0
assert 128 // uconst(4) == 32
assert 128 % uconst(4) == 0
assert uconst(128) // uconst(4) == 32
assert uconst(128) % uconst(4) == 0
def test_mulnode_divmod_node(self):
i = Variable("i", 1, 10)
idx0 = Variable("idx0", 0, 31)
# assert (idx0*(i*4+4)) // (i+1) == (idx0*4)
# assert (idx0*(i*4+4)) % (i+1) == 0
assert (idx0*i) % i == 0
def test_sumnode_divmod_sumnode(self):
i = Variable("i", 1, 10)
# idx0 = Variable("idx0", 0, 7)
# idx1 = Variable("idx1", 0, 3)
# idx2 = Variable("idx2", 0, i)
# assert (idx0*(i*4+4)+idx1*(i+1)+idx2) // (i+1) == idx0*4+idx1
# assert (idx0*(i*4+4)+idx1*(i+1)+idx2) % (i+1) == idx2
assert (i+1) // (i*128+128) == 0
assert (i+1) % (i*128+128) == (i+1)
# assert (i+1+idx2) // (i+1) == 1
# assert (i+1+idx2) % (i+1) == idx2
# assert (idx0*(i*4+4)+i+1+idx2) // (i+1) == idx0*4+1
# assert (idx0*(i*4+4)+i+1+idx2) % (i+1) == idx2
# assert (i*128+128)*2 // (i*128+128) == 2
# assert (i*128+128)*2 % (i*128+128) == 0
def test_sumnode_div_uconst_no_factoring(self):
gid = Variable("gid", 0, 1023)
lid = Variable("lid", 0, 3)
expr_before_div = uconst(-1019)-4*lid-gid
unfactored_expr = Node.__floordiv__(expr_before_div, uconst(-16), False)
factored_expr = Node.__floordiv__(expr_before_div, uconst(-16), True)
self.assertEqual(unfactored_expr.render(), "(((lid*4)+1019+gid)//16)")
self.assertEqual(factored_expr.render(), "(((((3+gid)//4)+2+lid)//4)+63)")
def test_mod_node_max(self):
i = Variable("i", 1, 128)
gidx0 = Variable("gidx0", 0, i)
mod = gidx0 % 8
assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 8
mod = gidx0 % 2
assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 2
gidx0 = Variable("gidx0", 0, i*8+7)
mod = gidx0 % 8
assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 8
mod = gidx0 % 2
assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 2
def test_nested_variable_mod(self):
i = Variable("i", 1, 5)
idx0 = Variable("idx0", 0, i)
with self.assertRaises(AssertionError):
assert idx0 % 2 == idx0
def test_num_node_mul_node(self):
a = Variable("a", 1, 5)
b = uconst(2) * a
assert b == a * 2
assert isinstance(b, MulNode)
b = uconst(1) * a
assert b == a
assert isinstance(b, Variable)
b = uconst(0) * a
assert b == 0
assert isinstance(b, uconst)
def test_substitute(self):
a = Variable("idx0", 1, 3)
b = a + 1
c = b.substitute({a: uconst(1)})
assert c == uconst(2)
"""
class TestSymbolicRealWorld(unittest.TestCase):
  """Regression checks on index expressions taken from real kernels."""

  def test_resnet_half(self):
    """The rewritten index must render as one of the accepted forms, ending in +2207744."""
    gidx0 = Variable("gidx0", 0, 3)
    gidx1 = Variable("gidx1", 0, 127)
    gidx2 = Variable("gidx2", 0, 7)
    lidx3 = Variable("lidx3", 0, 7)
    lidx4 = Variable("lidx4", 0, 1)
    lidx5 = Variable("lidx5", 0, 15)
    idx:UOp = ((((1+lidx5)%16)*49)+(((262145+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+-13151129600+lidx3)
    simplified = graph_rewrite(idx, sym)
    # NOTE: this used to have 13,151,129,600 in the output which is out of int32 range.
    accepted = (
      "((((((((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49))+(gidx0*3211264))+(gidx1*784))+(gidx2*8))+(lidx4*100352))+lidx3)+2207744)",
      '((lidx3+((((((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49))+(gidx0*3211264))+(gidx1*784))+(gidx2*8))+(lidx4*100352)))+2207744)',
      '((lidx3+((lidx4*100352)+((gidx2*8)+((gidx1*784)+((gidx0*3211264)+((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49)))))))+2207744)',
    )
    rendered = simplified.render()
    self.assertIn(rendered, accepted)
class TestBounds(unittest.TestCase):
  """Checks vmin/vmax propagation through negation, addition and negative floor-division."""

  def test_unrolled_arange(self):
    # Expression pieces below are taken from this generated Metal kernel:
    # #include <metal_stdlib>
    # using namespace metal;
    # kernel void r_2560_640_4(device int* data0, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
    #   int gidx0 = gid.x; /* 2560 */
    #   int alu0 = (gidx0*(-1));
    #   int alu1 = max((int)((-640)),((((alu0+2559)/(-4))*(-1))+(-640)));
    #   int alu2 = max((int)((-640)),((((alu0+2560)/(-4))*(-1))+(-640)));
    #   int alu3 = max((int)((-640)),((((alu0+2561)/(-4))*(-1))+(-640)));
    #   int alu4 = max((int)((-640)),((((alu0+2562)/(-4))*(-1))+(-640)));
    #   *(data0+gidx0) = ((alu1*(-1))+(alu2*(-1))+(alu4*(-1))+(alu3*(-1))+(-1));
    # }
    gidx0 = Variable("gidx0", 0, 2559)
    assert (gidx0.vmin, gidx0.vmax) == (0, 2559)

    alu0 = gidx0 * -1
    assert (alu0.vmin, alu0.vmax) == (-2559, 0)

    shifted = alu0+2559
    assert (shifted.vmin, shifted.vmax) == (0, 2559)

    quotient = shifted//-4
    assert (quotient.vmin, quotient.vmax) == (-639, 0)

    negated = quotient*(-1)
    assert (negated.vmin, negated.vmax) == (0, 639)
# run every TestCase in this module when the file is executed directly
if __name__ == "__main__":
  unittest.main()