#!/usr/bin/env python
"""Unit tests for symbolic simplification and value bounds of tinygrad UOps.

Most tests build an integer expression out of ``Variable`` objects, push it
through the rewrite/linearize pipeline (see ``render``) and compare the
rendered string and the ``vmin``/``vmax`` bounds against expected values.
"""
import functools
import pickle
import unittest
from typing import Tuple

from tinygrad.dtype import dtypes, ConstType
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
from tinygrad.ops import UOp, Ops, graph_rewrite, sym_infer
from tinygrad import Variable


def render(self) -> Tuple[str, ConstType, ConstType]:
    """Rewrite the UOp ``self`` and return ``(rendered string, vmin, vmax)``.

    The expression is stored into index 0 of an int global, the resulting sink
    is run through ``full_graph_rewrite`` and ``linearize_uop``, and the value
    operand (last source) of the first STORE is what gets rendered and bounded.
    """
    # NOTE: we need STORE so the ALU op has children
    glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
    uops = linearize_uop(full_graph_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), self)).sink()))
    rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[-1]
    return rewritten_uop.render(simplify=False), rewritten_uop.vmin, rewritten_uop.vmax


def uconst(val):
    """Return an int-typed constant UOp holding ``val``."""
    return UOp.const(dtypes.int, val)


def usum(ops):
    """Left-fold ``ops`` with ``+``; ``ops`` must be non-empty (no initial value is given)."""
    return functools.reduce(lambda x, y: x+y, ops)


def uand(ops):
    """Left-fold ``ops`` with ``*``; ``ops`` must be non-empty (no initial value is given)."""
    return functools.reduce(lambda x, y: x*y, ops)


# *** leave tests the same

class TestSymbolicPickle(unittest.TestCase):
    """Round-trip symbolic expressions through ``pickle``."""

    def _test_pickle_unpickle(self, x):
        self.assertEqual(x, pickle.loads(pickle.dumps(x)))

    def test_pickle_variable(self):
        self._test_pickle_unpickle(Variable("a", 3, 8))

    def test_pickle_variable_times_2(self):
        self._test_pickle_unpickle(Variable("a", 3, 8)*2)


class TestSymbolic(unittest.TestCase):
    """Check the rendered form and the bounds of rewritten expressions."""

    def helper_test_variable(self, v, n, m, s):
        # v: expression under test, n/m: expected vmin/vmax,
        # s: expected rendering, or a tuple of acceptable renderings
        rendered, nmin, nmax = render(v)
        if isinstance(s, tuple):
            self.assertIn(rendered, s)
        else:
            self.assertEqual(rendered, s)
        self.assertEqual(nmin, n)
        self.assertEqual(nmax, m)

    def test_cmp_simple(self):
        self.helper_test_variable(Variable("a", 3, 8) < 4, 0, 1, "(a<4)")
        self.helper_test_variable(Variable("a", 3, 8) >= 8, 0, 1, "((a<8)!=True)")

    def test_ge(self):
        self.helper_test_variable(Variable("a", 3, 8) >= 77, 0, 0, "False")
        self.helper_test_variable(Variable("a", 3, 8) >= 9, 0, 0, "False")
        self.helper_test_variable(Variable("a", 3, 8) >= 8, 0, 1, "((a<8)!=True)")
        self.helper_test_variable(Variable("a", 3, 8) >= 4, 0, 1, "((a<4)!=True)")
        self.helper_test_variable(Variable("a", 3, 8) >= 3, 1, 1, "True")
        self.helper_test_variable(Variable("a", 3, 8) >= 2, 1, 1, "True")

    def test_lt(self):
        self.helper_test_variable(Variable("a", 3, 8) < 77, 1, 1, "True")
        self.helper_test_variable(Variable("a", 3, 8) < 9, 1, 1, "True")
        self.helper_test_variable(Variable("a", 3, 8) < 8, 0, 1, "(a<8)")
        self.helper_test_variable(Variable("a", 3, 8) < 4, 0, 1, "(a<4)")
        self.helper_test_variable(Variable("a", 3, 8) < 3, 0, 0, "False")
        self.helper_test_variable(Variable("a", 3, 8) < 2, 0, 0, "False")
        self.helper_test_variable(Variable("a", 3, 4) < Variable("b", 5, 6), 1, 1, "True")
        # NOTE(review): this copy of the file was cut off right after `"(a` in the expected string below;
        # "(a<b)" is what the expression and the other renderings in this class imply -- confirm against version control.
        self.helper_test_variable(Variable("a", 3, 5) < Variable("b", 5, 6), 0, 1, "(a<b)")

    # NOTE(review): everything between the `"(a<b)"` assertion above and the `>= 12` assertion below is missing
    # from this copy of the file (the span from `<b` to the next `>` was dropped). Any tests that lived there
    # must be restored from version control. The method name below is a placeholder, and the first call was
    # rebuilt from its surviving tail (`= 12, 0, 1, "((a<3)!=True)")`) by analogy with the intact call after it.
    def test_recovered_mul_ge(self):
        self.helper_test_variable(Variable("a", 0, 5)*4 >= 12, 0, 1, "((a<3)!=True)")
        self.helper_test_variable(Variable("a", 0, 5)*4 >= 13, 0, 1, "((a<4)!=True)")

    def test_div_div(self):
        self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)")

    def test_div_const_div(self):
        a = Variable("a", 0, 124)
        self.helper_test_variable((a//2+1)//2, 0, 31, "((a+2)//4)")

    def test_distribute_mul(self):
        self.helper_test_variable(usum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))")
        self.helper_test_variable((1+Variable("a", 0, 3))*(-2)+12, 4, 10, "((a*-2)+10)")

    def test_mod_mul_sum(self):
        self.helper_test_variable(usum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, ("(b+a)", "(a+b)"))

    def test_sum_0(self):
        self.helper_test_variable(usum([Variable("a", 0, 7)]), 0, 7, "a")

    def test_mod_remove(self):
        self.helper_test_variable(Variable("a", 0, 6)%100, 0, 6, "a")

    def test_big_mod(self):
        # NOTE: we no longer support negative variables
        #self.helper_test_variable(Variable("a", -20, 20)%10, -9, 9, "(a%10)")
        #self.helper_test_variable(Variable("a", -20, 0)%10, -9, 0, "(a%10)")
        #self.helper_test_variable(Variable("a", -20, 1)%10, -9, 1, "(a%10)")
        self.helper_test_variable(Variable("a", 0, 20)%10, 0, 9, "(a%10)")
        #self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)")

    def test_ge_remove(self):
        self.helper_test_variable(Variable("a", 0, 6) >= 25, 0, 0, "False")

    def test_lt_remove(self):
        self.helper_test_variable(Variable("a", 0, 6) < -3, 0, 0, "False")
        self.helper_test_variable(Variable("a", 0, 6) < 3, 0, 1, "(a<3)")
        self.helper_test_variable(Variable("a", 0, 6) < 8, 1, 1, "True")

    def test_lt_sum_remove(self):
        self.helper_test_variable(Variable("a", 0, 6) + 2 < 3, 0, 1, "(a<1)")

    def test_lt_simple_factor(self):
        self.helper_test_variable((Variable("a", 0, 6)*6+Variable("b", 0, 6)*6) < 8, 0, 1, "(((a*3)+(b*3))<4)")

    def test_lt_sum_factor_rhs_partial(self):
        self.helper_test_variable((Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8) < 4, 0, 1,
                                  "((((a*3)+(b*2))+(c*4))<2)")

    def test_lt_sum_factor_rhs_all(self):
        self.helper_test_variable((Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8) < 2, 0, 1,
                                  "((((a*3)+(b*2))+(c*4))<1)")

    def test_and_fold(self):
        self.helper_test_variable(uand([uconst(0), Variable("a", 0, 1)]), 0, 0, "0")

    def test_and_remove(self):
        self.helper_test_variable(uand([uconst(1), Variable("a", 0, 1)]), 0, 1, "a")

    def test_mod_factor_negative(self):
        self.helper_test_variable(usum([uconst(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((a+27)%28)")
        self.helper_test_variable(usum([uconst(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((a+27)%28)")

    def test_sum_combine_num(self):
        self.helper_test_variable(usum([uconst(29), Variable("a", 0, 10), uconst(-23)]), 6, 16, "(a+6)")

    def test_sum_num_hoisted_and_factors_cancel_out(self):
        self.helper_test_variable(usum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1")

    def test_div_cancel(self):
        self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40])//40, -1, 9, "(b+-1)")

    def test_mod_cancel(self):
        self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, 0, 20, "(a*2)")

    def test_mul_div(self):
        self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a")

    def test_add_div(self):
        # careful about the lower bounds and upper bounds
        self.helper_test_variable((Variable("a", 0, 5)-2)//4, -1, 0, "(((a+2)//4)+-1)")
        self.helper_test_variable((Variable("a", 0, 5)-1)//4, -1, 1, "(((a+3)//4)+-1)")
        self.helper_test_variable((Variable("a", 0, 5))//4, 0, 1, "(a//4)")
        self.helper_test_variable((Variable("a", 0, 5)+1)//4, 0, 1, "((a+1)//4)")
        self.helper_test_variable((Variable("a", 0, 5)+2)//4, 0, 1, "((a+2)//4)")
        self.helper_test_variable((Variable("a", 0, 5)+3)//4, 0, 2, "((a+3)//4)")
        self.helper_test_variable((Variable("a", 0, 5)+4)//4, 1, 2, "((a//4)+1)")
        self.helper_test_variable((Variable("a", 0, 5)+5)//4, 1, 2, "(((a+1)//4)+1)")

    def test_mul_div_factor_mul(self):
        self.helper_test_variable((Variable("a", 0, 10)*8)//4, 0, 20, "(a*2)")

    def test_mul_div_factor_div(self):
        self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)")

    def test_sum_div_partial_remove(self):
        self.helper_test_variable(usum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0")

    @unittest.expectedFailure
    def test_div_numerator_negative(self):
        self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -9, 0, "((((idx*-10)+99)//11)+-9)")

    def test_div_into_mod(self):
        self.helper_test_variable((Variable("idx", 0, 16)*4)%8//4, 0, 1, "(idx%2)")

    # TODO: simplify the expression
    def test_div_neg_cancel(self):
        self.helper_test_variable((-Variable("idx", 0, 100)+199)//-4 + 50, 1, 26, "((((idx*-1)+199)//-4)+50)")
        self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "((((idx*-1)+200)//-4)+50)")
        self.helper_test_variable((-Variable("idx", 0, 100)+201)//-4 + 50, 0, 25, "((((idx*-1)+201)//-4)+50)")

    def test_sum_div_big_const(self):
        gidx0 = Variable("gidx0", 0, 24)
        self.helper_test_variable((gidx0+19)//20, 0, 2, "((gidx0+19)//20)")
        self.helper_test_variable((gidx0+20)//20, 1, 2, "((gidx0//20)+1)")
        self.helper_test_variable((gidx0+21)//20, 1, 2, "(((gidx0+1)//20)+1)")

    def test_sum_div_complex1(self):
        gidx0 = Variable("gidx0", 0, 24)
        gidx1 = Variable("gidx1", 0, 1)
        gidx2 = Variable("gidx2", 0, 255)
        lidx0 = Variable("lidx0", 0, 1)
        lidx1 = Variable("lidx1", 0, 15)
        lidx2 = Variable("lidx2", 0, 3)
        alu0 = gidx2*640+gidx1*160+(gidx0//5)*2+lidx0*320+lidx1*10
        self.helper_test_variable((alu0+lidx2*2+1)//20, 0, 8192,
                                  ("((((((gidx0//5)+lidx2)//5)+lidx1)//2)+(((gidx2*32)+(gidx1*8))+(lidx0*16)))",
                                   "((((gidx1*8)+(gidx2*32))+(lidx0*16))+((lidx1+((lidx2+(gidx0//5))//5))//2))"))

    def test_sum_div_complex2(self):
        gidx0 = Variable("gidx0", 0, 7)
        lidx2 = Variable("lidx2", 0, 1)
        lidx3 = Variable("lidx3", 0, 1)
        self.helper_test_variable((gidx0*4+lidx2*2+1)//10, 0, 3, ("(((gidx0*2)+lidx2)//5)", "((lidx2+(gidx0*2))//5)"))
        self.helper_test_variable((gidx0*4+lidx2*2+lidx3)//10, 0, 3, ("(((gidx0*2)+lidx2)//5)", "((lidx2+(gidx0*2))//5)"))
        self.helper_test_variable((gidx0*2+lidx2)//10, 0, 1, "(gidx0//5)")

    def test_sum_div_complex3(self):
        gidx0 = Variable("gidx0", 0, 7)
        lidx2 = Variable("lidx2", 0, 12)
        lidx3 = Variable("lidx3", 0, 1)
        self.helper_test_variable((gidx0*4+lidx2*2+lidx3)//12, 0, 4, ("(((lidx2//2)+gidx0)//3)", "((gidx0+(lidx2//2))//3)"))
        self.helper_test_variable((lidx2*2+gidx0*4+lidx3)//12, 0, 4, ("(((lidx2//2)+gidx0)//3)", "((gidx0+(lidx2//2))//3)"))

    def test_sum_mul_distribute(self):
        gidx0 = Variable("gidx0", 0, 7)
        lidx2 = Variable("lidx2", 0, 12)
        lidx3 = Variable("lidx3", 0, 1)
        self.helper_test_variable((gidx0+lidx2+lidx3)*4, 0, 80, "(((gidx0*4)+(lidx2*4))+(lidx3*4))")

    @unittest.expectedFailure
    def test_variable_divmod(self):
        start_pos = Variable("start_pos", 0, 127)
        v = start_pos + 1
        idx0 = Variable("idx0", 0, 2)
        idx1 = Variable("idx1", 0, start_pos)
        self.helper_test_variable((idx0*v+idx1)//v, 0, 2, "(idx0)")
        self.helper_test_variable((idx0*v+idx1)%v, 0, start_pos, "idx1")

    # TODO: simplify the expression
    def test_div_neg_all_range(self):
        gidx = Variable("gidx", 0, 124)
        lidx = Variable("lidx", 0, 7)
        self.helper_test_variable((-gidx*8-lidx+999)//-4 + 250, 1, 250, "(((((gidx*-8)+(lidx*-1))+999)//-4)+250)")
        self.helper_test_variable((-gidx*8-lidx+1000)//-4 + 250, 0, 250, "(((((gidx*-8)+(lidx*-1))+1000)//-4)+250)")
        self.helper_test_variable((-gidx*8-lidx+1001)//-4 + 250, 0, 250, "(((((gidx*-8)+(lidx*-1))+1001)//-4)+250)")
        self.helper_test_variable((-gidx*8-lidx+1002)//-4 + 250, 0, 250, "(((((gidx*-8)+(lidx*-1))+1002)//-4)+250)")

    # NOTE: tests are not correct in symbolic
    def test_div_neg_then_neg(self):
        # taken from arange opts
        lidx0 = Variable("lidx0", 0, 7)
        lidx1 = Variable("lidx1", 0, 7)
        alu2 = -lidx0-lidx1
        self.helper_test_variable((((alu2+14)//(-32))+4), 4, 4, "4")
        self.helper_test_variable(-(((alu2+14)//(-32))+4), -4, -4, "-4")
        self.helper_test_variable((((alu2+134)//(-32))+4), 0, 1, "(((((lidx0*-1)+(lidx1*-1))+134)//-32)+4)")
        self.helper_test_variable((((alu2+142)//(-32))+4), 0, 0, "0")
        self.helper_test_variable((((alu2+150)//(-32))+4), 0, 0, "0")
        self.helper_test_variable((((alu2+158)//(-32))+4), 0, 0, "0")

    def test_div_mod_recombine(self):
        gidx = Variable("gidx", 0, 124)
        self.helper_test_variable(gidx%4+(gidx//4)*4, 0, 124, "gidx")
        self.helper_test_variable((gidx//4)*4+gidx%4, 0, 124, "gidx")

    def test_div_mod_recombine_folded_mod(self):
        a = Variable("a", 0, 2)
        b = Variable("b", 0, 100)
        self.helper_test_variable((31 * a + 1) % 30 + ((31 * a + 1) // 30) * 30, 1, 63, "((a*31)+1)")
        with self.assertRaises(AssertionError):
            self.helper_test_variable((31 * b + 1) % 18 + ((31 * b + 1) // 18) * 18, 1, 3101, "((b*31)+1)")

    def test_div_mod_recombine_with_gcd(self):
        b = Variable("b", 0, 100)
        exp = (16 * b + 2) % 18 + ((16 * b + 2) // 18) * 18
        self.helper_test_variable(exp, 2, 1602, "((b*16)+2)")
        with self.assertRaises(AssertionError):
            self.helper_test_variable((30 * b + 1) % 18 + ((30 * b + 1) // 18) * 18, 1, 3001, "((b*30)+1)")

    def test_arange_unrolled4(self):
        gidx = Variable("gidx", 0, 2559)
        unrolled_div = (gidx+2561)//4+(gidx+2562)//4+(gidx+2560)//4+(gidx+2559)//4
        self.helper_test_variable(unrolled_div, 2559, 5118, "(gidx+2559)")

    def test_arange_unrolled4_small(self):
        gidx = Variable("gidx", 0, 3)
        unrolled_div = (gidx)//4+(gidx+2)//4+(gidx+3)//4+(gidx+1)//4
        self.helper_test_variable(unrolled_div, 0, 3, "gidx")

        gidx = Variable("gidx", 0, 2)
        unrolled_div = (gidx)//4+(gidx+2)//4+(gidx+3)//4+(gidx+1)//4
        self.helper_test_variable(unrolled_div, 0, 2, "gidx")

        gidx = Variable("gidx", 0, 1)
        unrolled_div = (gidx)//4+(gidx+2)//4+(gidx+3)//4+(gidx+1)//4
        self.helper_test_variable(unrolled_div, 0, 1, "gidx")

    def test_arange_unrolled2(self):
        gidx = Variable("gidx", 0, 2559)
        unrolled_div = (gidx+2559)//2+(gidx+2560)//2+3
        self.helper_test_variable(unrolled_div, 2562, 5121, "(gidx+2562)")

    def test_gated_load(self):
        idx = Variable("idx", 0, 24)
        self.helper_test_variable(idx//4, 0, 6, "(idx//4)")
        # TODO: simplify the true branch
        self.helper_test_variable((idx<4).where(idx//4, idx.const_like(-1)), -1, 6, "((idx//4) if (idx<4) else -1)")

    def test_idiv_lt(self):
        idx = Variable("idx", 0, 24)
        self.helper_test_variable((idx//4<3), 0, 1, "(idx<12)")
        self.helper_test_variable((idx//-4<-3), 0, 1, "((idx//-4)<-3)")

    def test_simplex_lt(self):
        a = Variable("a", 0, 3)
        b = Variable("b", 0, 3)
        c = Variable("c", 0, 3)
        d = Variable("d", -3, 3)
        self.helper_test_variable((a<1).ne(True), 0, 1, "((a<1)!=True)")
        self.helper_test_variable((a+b<1).ne(True), 0, 1, "(((a+b)<1)!=True)")
        self.helper_test_variable((a*3+b*4<1).ne(True), 0, 1, "(((a+b)<1)!=True)")
        self.helper_test_variable((a*(-3)+b*4<1).ne(True), 0, 1, "((((a*-3)+(b*4))<1)!=True)")  # negative coeff, should not be simplified
        self.helper_test_variable((a*3+d*4<1).ne(True), 0, 1, "((((a*3)+(d*4))<1)!=True)")  # var can be negative, should not be simplified
        self.helper_test_variable((a+b+c*2<1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)"))
        self.helper_test_variable((a+b*2+c*4<1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)"))

    def test_where_removal(self):
        cond = Variable("a", 0, 3) < 2
        u1, u0 = cond.ufix(1), cond.ufix(0)
        self.helper_test_variable(cond, 0, 1, "(a<2)")
        self.helper_test_variable(cond.where(u1, u0), 0, 1, "(a<2)")
        self.helper_test_variable(cond.where(u1, u0).where(u1, u0), 0, 1, "(a<2)")

    def test_where_combine(self):
        cond = Variable("x", 0, 3) < 2
        a = Variable("a", 0, 3)
        b = Variable("b", 0, 3)
        aa = cond.where(a, a.ufix(0))
        bb = cond.where(b, b.ufix(1))
        self.helper_test_variable(aa, 0, 3, "(a if (x<2) else 0)")
        self.helper_test_variable(bb, 0, 3, "(b if (x<2) else 1)")
        self.helper_test_variable(aa+bb, 0, 6, "((a+b) if (x<2) else 1)")
        self.helper_test_variable(aa.maximum(bb), 0, 3, "(max(a, b) if (x<2) else 1)")

        # not combining because it increased total ALU
        c = Variable("c", 0, 3)
        cc = cond.where(c, c+1)
        self.helper_test_variable(bb+cc, 0, 7, "((b if (x<2) else 1)+(c if (x<2) else (c+1)))")

        # not combining  # TODO: can combine if it can further simplify?
        ab = cond.where(a, b)
        ba = cond.where(b, a)
        self.helper_test_variable(ab+ba, 0, 6, "((a if (x<2) else b)+(b if (x<2) else a))")

        # not combining  # TODO: can combine if one is identity element const
        self.helper_test_variable(aa+ab, 0, 6, "((a if (x<2) else b)+(a if (x<2) else 0))")

    def test_symbolic_div(self):
        # from symbolic arange
        a = Variable("a", 1, 10)
        denominator = ((a*-2)+1)
        numerator = (((((a*2)+-1)*2)+1)*a)
        self.helper_test_variable(denominator, -19, -1, "((a*-2)+1)")
        self.helper_test_variable(numerator, 3, 390, "(a*((a*4)+-1))")
        self.helper_test_variable((numerator//denominator)<=0, 1, 1, "True")


class TestSymbolicNumeric(unittest.TestCase):
    """Compare symbolic bounds with values computed on plain Python ints."""

    def helper_test_numeric(self, f):
        # f is applied both to UOps and to plain ints, so it may only use operators both support
        MIN, MAX = 0, 10
        # one number
        for i in range(MIN, MAX):
            v = graph_rewrite(f(uconst(i)), sym)
            self.assertEqual(v.vmin, v.vmax)
            self.assertEqual(v.vmin, f(i))
        # every variable range [kmin, kmax] with MIN <= kmin <= kmax < MAX
        for kmin in range(MIN, MAX):
            for kmax in range(MIN, MAX):
                if kmin > kmax:
                    continue
                v = f(Variable("tmp", kmin, kmax))
                values = [f(rv) for rv in range(kmin, kmax+1)]
                # the min and max may not be exact
                self.assertLessEqual(v.vmin, min(values))
                self.assertGreaterEqual(v.vmax, max(values))

    def test_mod_4(self):
        self.helper_test_numeric(lambda x: (x%4))

    def test_div_4(self):
        self.helper_test_numeric(lambda x: (x//4))

    def test_plus_1_div_2(self):
        self.helper_test_numeric(lambda x: (x+1)//2)

    def test_plus_1_mod_2(self):
        self.helper_test_numeric(lambda x: (x+1)%2)

    def test_times_2(self):
        self.helper_test_numeric(lambda x: x*2)

    def test_times_2_plus_3(self):
        self.helper_test_numeric(lambda x: x*2 + 3)

    def test_times_2_plus_3_mod_4(self):
        self.helper_test_numeric(lambda x: (x*2 + 3)%4)

    def test_times_2_plus_3_div_4(self):
        self.helper_test_numeric(lambda x: (x*2 + 3)//4)

    def test_times_2_plus_3_div_4_mod_4(self):
        self.helper_test_numeric(lambda x: ((x*2 + 3)//4)%4)


class TestSymbolicVars(unittest.TestCase):
    """Check the set returned by ``.vars()`` on expressions."""

    def test_simple(self):
        z = uconst(0)
        a = Variable("a", 0, 10)
        b = Variable("b", 0, 10)
        c = Variable("c", 0, 10)
        assert z.vars() == z.vars() == set()
        print(a.vars())
        assert a.vars() == a.vars() == {a}
        m = a * 3
        assert m.vars() == {a}
        s = usum([a, b, c])
        assert s.vars() == {a, b, c}

    def test_compound(self):
        a = Variable("a", 0, 10)
        b = Variable("b", 0, 10)
        c = Variable("c", 0, 10)
        assert (a + b * c).vars() == {a, b, c}
        assert (a % 3 + b // 5).vars() == {a, b}
        # TODO: fix me
        with self.assertRaises(AssertionError):
            assert (a + b + c - a).vars() == {b, c}

    def test_dedup(self):
        a = Variable("a", 0, 10)
        assert (a * a).vars() == {a}
        assert (a//4 + a//6).vars() == {a}


class TestSymInfer(unittest.TestCase):
    """Evaluate expressions with ``sym_infer`` under concrete variable values."""

    def test_sym_infer(self):
        a = Variable("a", 0, 10)
        b = Variable("b", 0, 10)
        c = Variable("c", 0, 10)
        var_vals = {a: 2, b: 3, c: 4}
        assert sym_infer(5, var_vals) == 5
        assert sym_infer(a, var_vals) == 2
        assert sym_infer(b, var_vals) == 3
        assert sym_infer(a+b, var_vals) == 5
        assert sym_infer(a-b, var_vals) == -1
        assert sym_infer(a+b+c, var_vals) == 9
        assert sym_infer(a*b, var_vals) == 6
        assert sym_infer(a*b+c, var_vals) == 10


# The tests below are disabled by keeping them inside a bare string literal. They reference
# Node, ModNode and MulNode, none of which are imported by this file.
"""
@unittest.skip("not supported on uops yet")
class TestSymbolicSymbolicOps(unittest.TestCase):
    def test_node_divmod_node(self):
        i = Variable("i", 1, 10)
        idx0 = Variable("idx0", 0, i*3-1)
        assert uconst(0) // (Variable("i", 1, 10)*128) == 0
        assert uconst(0) % (Variable("i", 1, 10)*128) == 0
        assert uconst(127) // (Variable("i", 1, 10)*128) == 0
        assert uconst(127) % (Variable("i", 1, 10)*128) == 127
        assert 127 // (Variable("i", 1, 10)*128) == 0
        assert 127 % (Variable("i", 1, 10)*128) == 127
        assert uconst(128) // (Variable("i", 1, 10)*128 + 128) == 0
        assert uconst(128) % (Variable("i", 1, 10)*128 + 128) == 128
        assert 128 // (Variable("i", 1, 10)*128 + 128) == 0
        assert 128 % (Variable("i", 1, 10)*128 + 128) == 128
        assert 0 // (Variable("i", 1, 10)*128) == 0
        assert 0 % (Variable("i", 1, 10)*128) == 0
        assert idx0 // (i*3) == 0
        assert idx0 % (i*3) == idx0
        assert i // i == 1
        assert i % i == 0
        assert 128 // uconst(4) == 32
        assert 128 % uconst(4) == 0
        assert uconst(128) // uconst(4) == 32
        assert uconst(128) % uconst(4) == 0

    def test_mulnode_divmod_node(self):
        i = Variable("i", 1, 10)
        idx0 = Variable("idx0", 0, 31)
        # assert (idx0*(i*4+4)) // (i+1) == (idx0*4)
        # assert (idx0*(i*4+4)) % (i+1) == 0
        assert (idx0*i) % i == 0

    def test_sumnode_divmod_sumnode(self):
        i = Variable("i", 1, 10)
        # idx0 = Variable("idx0", 0, 7)
        # idx1 = Variable("idx1", 0, 3)
        # idx2 = Variable("idx2", 0, i)
        # assert (idx0*(i*4+4)+idx1*(i+1)+idx2) // (i+1) == idx0*4+idx1
        # assert (idx0*(i*4+4)+idx1*(i+1)+idx2) % (i+1) == idx2
        assert (i+1) // (i*128+128) == 0
        assert (i+1) % (i*128+128) == (i+1)
        # assert (i+1+idx2) // (i+1) == 1
        # assert (i+1+idx2) % (i+1) == idx2
        # assert (idx0*(i*4+4)+i+1+idx2) // (i+1) == idx0*4+1
        # assert (idx0*(i*4+4)+i+1+idx2) % (i+1) == idx2
        # assert (i*128+128)*2 // (i*128+128) == 2
        # assert (i*128+128)*2 % (i*128+128) == 0

    def test_sumnode_div_uconst_no_factoring(self):
        gid = Variable("gid", 0, 1023)
        lid = Variable("lid", 0, 3)
        expr_before_div = uconst(-1019)-4*lid-gid
        unfactored_expr = Node.__floordiv__(expr_before_div, uconst(-16), False)
        factored_expr = Node.__floordiv__(expr_before_div, uconst(-16), True)
        self.assertEqual(unfactored_expr.render(), "(((lid*4)+1019+gid)//16)")
        self.assertEqual(factored_expr.render(), "(((((3+gid)//4)+2+lid)//4)+63)")

    def test_mod_node_max(self):
        i = Variable("i", 1, 128)
        gidx0 = Variable("gidx0", 0, i)
        mod = gidx0 % 8
        assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 8
        mod = gidx0 % 2
        assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 2

        gidx0 = Variable("gidx0", 0, i*8+7)
        mod = gidx0 % 8
        assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 8
        mod = gidx0 % 2
        assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 2

    def test_nested_variable_mod(self):
        i = Variable("i", 1, 5)
        idx0 = Variable("idx0", 0, i)
        with self.assertRaises(AssertionError):
            assert idx0 % 2 == idx0

    def test_num_node_mul_node(self):
        a = Variable("a", 1, 5)
        b = uconst(2) * a
        assert b == a * 2
        assert isinstance(b, MulNode)
        b = uconst(1) * a
        assert b == a
        assert isinstance(b, Variable)
        b = uconst(0) * a
        assert b == 0
        assert isinstance(b, uconst)

    def test_substitute(self):
        a = Variable("idx0", 1, 3)
        b = a + 1
        c = b.substitute({a: uconst(1)})
        assert c == uconst(2)
"""


class TestSymbolicRealWorld(unittest.TestCase):
    """Index expressions written out in the shape a real kernel produced them."""

    def test_resnet_half(self):
        gidx0 = Variable("gidx0", 0, 3)
        gidx1 = Variable("gidx1", 0, 127)
        gidx2 = Variable("gidx2", 0, 7)
        lidx3 = Variable("lidx3", 0, 7)
        lidx4 = Variable("lidx4", 0, 1)
        lidx5 = Variable("lidx5", 0, 15)

        idx:UOp = ((((1+lidx5)%16)*49)+(((262145+lidx5)//16)*802816)+(gidx0*3211264)+(gidx1*784)+(gidx2*8)+(lidx4*100352)+-13151129600+lidx3)
        idx = graph_rewrite(idx, sym)
        #print(idx.render())
        # NOTE: this used to have 13,151,129,600 in the output which is out of int32 range.
        self.assertIn(idx.render(),
                      ("((((((((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49))+(gidx0*3211264))+(gidx1*784))+(gidx2*8))+(lidx4*100352))+lidx3)+2207744)",
                       '((lidx3+((((((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49))+(gidx0*3211264))+(gidx1*784))+(gidx2*8))+(lidx4*100352)))+2207744)',
                       ))


class TestBounds(unittest.TestCase):
    """Check ``vmin``/``vmax`` directly on expressions, without rendering."""

    def test_unrolled_arange(self):
        # NOTE(review): the header name after `#include` was missing from this copy of the file;
        # <metal_stdlib> is assumed because the kernel below is Metal -- confirm against version control.
        # #include <metal_stdlib>
        # using namespace metal;
        # kernel void r_2560_640_4(device int* data0, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
        #   int gidx0 = gid.x; /* 2560 */
        #   int alu0 = (gidx0*(-1));
        #   int alu1 = max((int)((-640)),((((alu0+2559)/(-4))*(-1))+(-640)));
        #   int alu2 = max((int)((-640)),((((alu0+2560)/(-4))*(-1))+(-640)));
        #   int alu3 = max((int)((-640)),((((alu0+2561)/(-4))*(-1))+(-640)));
        #   int alu4 = max((int)((-640)),((((alu0+2562)/(-4))*(-1))+(-640)));
        #   *(data0+gidx0) = ((alu1*(-1))+(alu2*(-1))+(alu4*(-1))+(alu3*(-1))+(-1));
        # }
        gidx0 = Variable("gidx0", 0, 2559)
        assert gidx0.vmin == 0 and gidx0.vmax == 2559
        alu0 = gidx0 * -1
        assert alu0.vmin == -2559 and alu0.vmax == 0
        assert (alu0+2559).vmin == 0 and (alu0+2559).vmax == 2559
        assert ((alu0+2559)//-4).vmin == -639 and ((alu0+2559)//-4).vmax == 0
        assert (((alu0+2559)//-4)*(-1)).vmin == 0 and (((alu0+2559)//-4)*(-1)).vmax == 639


if __name__ == '__main__':
    unittest.main()