import random, operator
import z3
from tinygrad import Variable, dtypes
from tinygrad.uop.ops import UOp, graph_rewrite
from tinygrad.uop.spec import z3_renderer
from tinygrad.helpers import DEBUG, Context

# The seed is printed before it is applied so the random sequence of a run can be identified from its output.
seed = random.randint(0, 100)
print(f"Seed: {seed}")
random.seed(seed)

# Each unary op combines its operand with a freshly drawn random integer constant.
unary_ops = [lambda a:a+random.randint(-4, 4), lambda a: a*random.randint(-4, 4),
             lambda a: a//random.randint(1, 9), lambda a: a%random.randint(1, 9),
             lambda a:a.maximum(random.randint(-10, 10)), lambda a:a.minimum(random.randint(-10, 10))]
binary_ops = [lambda a,b: a+b, lambda a,b: a*b, lambda a,b:a.maximum(b), lambda a,b:a.minimum(b)]
comp_ops = [operator.lt, operator.le, operator.gt, operator.ge]

def random_or_sub_expression_int(depth, expr):
  """Return either a new random int expression of depth `depth-1` or a random non-bool node of `expr`.

  The sub-expression is drawn from `expr.toposort()`, excluding nodes whose dtype is `dtypes.bool`.
  A fresh expression is always generated (consuming random state) even when the sub-expression is picked.
  """
  sub_expr = random.choice([e for e in expr.toposort() if e.dtype is not dtypes.bool])
  return random.choice([random_int_expr(depth-1), sub_expr])

def random_int_expr(depth=10):
  """Build a random integer expression by recursively wrapping a smaller expression in a random op.

  At `depth <= 0` a random element of the module-level list `v` is returned, so `v` must be assigned
  before this function is called (the `__main__` block below does so on every iteration).
  """
  if depth <= 0: return random.choice(v)
  expr1 = random_int_expr(depth-1)

  # we give more weight to arithmetic ops than to minimum and maximum
  ops = [
    lambda: random.choices(unary_ops, weights=[4, 4, 4, 4, 1, 1])[0](expr1),
    # for the second operand it's either another random expression or some subexpression of the first operand
    lambda: random.choices(binary_ops, [8, 1, 1, 1])[0](expr1, random_or_sub_expression_int(depth-1, expr1)),
    lambda: random_bool_expr(3, random_or_sub_expression_int(depth-1, expr1)).where(expr1, random_or_sub_expression_int(depth-1, expr1)),
  ]
  # we give weight proportional to the amount of ops in each branch
  return random.choices(ops, weights=[6, 4, 1])[0]()

def random_bool_expr(depth=10, expr1=None):
  """Return a random comparison (`<`, `<=`, `>`, `>=`) between `expr1` and a second int operand.

  If `expr1` is None a random int expression of depth `depth-1` is generated for it. The second operand
  is either a random/sub expression derived from `expr1` or an int constant in [-10, 10].
  """
  # NOTE(review): this path returns a Python bool, not a UOp; the only call in this file passes depth=3, so it is not hit here.
  if depth == 0: return True
  if expr1 is None: expr1 = random_int_expr(depth-1)
  expr2 = random.choice([random_or_sub_expression_int(depth-1, expr1), UOp.const(dtypes.int, random.randint(-10, 10))])
  return random.choice(comp_ops)(expr1, expr2)

def _render_variable_decls(*variables):
  """Return one `vN=Variable("...", ..., ...)` line per variable (N counted from 1), built from `arg[0]`, `arg[1]`, `arg[2]`."""
  return "".join(f"v{n}=Variable(\"{u.arg[0]}\", {u.arg[1]}, {u.arg[2]})\n" for n, u in enumerate(variables, start=1))


if __name__ == "__main__":
  skipped = 0
  # loop-invariant: candidate upper bounds for the three fuzz variables
  upper_bounds = [*range(1, 10), 16, 32, 64, 128, 256]
  for i in range(10000):
    if i % 1000 == 0:
      print(f"Running test {i}")
    u1 = Variable("v1", 0, random.choice(upper_bounds))
    u2 = Variable("v2", 0, random.choice(upper_bounds))
    u3 = Variable("v3", 0, random.choice(upper_bounds))
    v = [u1,u2,u3]  # module-level: read by random_int_expr when it bottoms out
    expr = random_int_expr(6)

    with Context(CORRECT_DIVMOD_FOLDING=1):
      simplified_expr = expr.simplify()

    solver = z3.Solver()
    solver.set(timeout=5000)  # some expressions take very long to verify, but it's very unlikely they actually return sat
    # After the rewrite, each source of the sink carries its z3 counterpart in `.arg`:
    # src[0] is the original expression, src[1] the simplified one, src[2:5] the three variables.
    z3_sink = graph_rewrite(expr.sink(simplified_expr, u1, u2, u3), z3_renderer, ctx=(solver, {}))
    z3_expr, z3_simplified_expr = z3_sink.src[0].arg, z3_sink.src[1].arg
    check = solver.check(z3_simplified_expr != z3_expr)

    if check == z3.unknown:
      # Always count the skip so the final summary is accurate; only the verbose dump depends on DEBUG.
      skipped += 1
      if DEBUG >= 1:
        print("Skipped due to timeout or interrupt:\n" +
              _render_variable_decls(u1, u2, u3) +
              f"expr = {expr.render(simplify=False)}\n")
    elif check == z3.sat:
      # z3 found an assignment where the simplified and original expressions differ.
      m = solver.model()
      v1, v2, v3 = z3_sink.src[2].arg, z3_sink.src[3].arg, z3_sink.src[4].arg
      n1, n2, n3 = m[v1], m[v2], m[v3]
      u1_val, u2_val, u3_val = u1.const_like(n1.as_long()), u2.const_like(n2.as_long()), u3.const_like(n3.as_long())
      assignment = {u1:u1_val, u2:u2_val, u3:u3_val}
      # NOTE(review): the flattened source does not show whether `rn` was inside this context; it is kept inside - confirm.
      with Context(CORRECT_DIVMOD_FOLDING=1):
        num = expr.simplify().substitute(assignment).ssimplify()
        rn = expr.substitute(assignment).ssimplify()
      if num==rn: print("z3 found a mismatch but the expressions are equal!!")
      # NOTE(review): the repro snippet below does not set CORRECT_DIVMOD_FOLDING - confirm whether it needs to.
      # Raised explicitly (instead of `assert False`) so the failure is not stripped when running with -O.
      raise AssertionError(
        f"mismatched {expr.render()} at v1={m[v1]}; v2={m[v2]}; v3={m[v3]} = {num} != {rn}\n" +
        "Reproduce with:\n" +
        _render_variable_decls(u1, u2, u3) +
        f"expr = {expr}\n" +
        f"v1_val, v2_val, v3_val = UOp.const(dtypes.int, {n1.as_long()}), UOp.const(dtypes.int, {n2.as_long()})," +
        f"UOp.const(dtypes.int, {n3.as_long()})\n" +
        "num = expr.simplify().substitute({v1:v1_val, v2:v2_val, v3:v3_val}).ssimplify()\n" +
        "rn = expr.substitute({v1:v1_val, v2:v2_val, v3:v3_val}).ssimplify()\n" +
        "assert num==rn, f\"{num} != {rn}\"\n")
    elif DEBUG >= 2:
      # Only reached when the solver proved there is no mismatch; timeouts are not reported as validated.
      print(f"validated {expr.render()}")

  print(f"Skipped {skipped} expressions due to timeout")