import random , operator
import z3
from tinygrad import Variable , dtypes
from tinygrad . uop . ops import UOp , graph_rewrite
from tinygrad . uop . spec import z3_renderer
from tinygrad . helpers import DEBUG , Context
# Draw a seed in [0, 100], reseed the global RNG with it and print it,
# so every later random call in this script derives from the printed value.
seed = random.randint(0, 100)
random.seed(seed)
print(f" Seed: {seed} ")
# Operations applied to a single integer expression. The random constant is
# drawn each time the op is called, not when the table is built.
# NOTE: element order matters, random_int_expr weights these tables positionally.
unary_ops = [
  lambda x: x + random.randint(-4, 4),
  lambda x: x * random.randint(-4, 4),
  lambda x: x // random.randint(1, 9),  # divisor is drawn from 1..9, never zero
  lambda x: x % random.randint(1, 9),   # modulus is drawn from 1..9, never zero
  lambda x: x.maximum(random.randint(-10, 10)),
  lambda x: x.minimum(random.randint(-10, 10)),
]

# Operations combining two integer expressions.
binary_ops = [
  lambda lhs, rhs: lhs + rhs,
  lambda lhs, rhs: lhs * rhs,
  lambda lhs, rhs: lhs.maximum(rhs),
  lambda lhs, rhs: lhs.minimum(rhs),
]

# Comparison operators; random_bool_expr picks one uniformly.
comp_ops = [operator.lt, operator.le, operator.gt, operator.ge]
def random_or_sub_expression_int(depth, expr):
  """Return either a new random integer expression or a node taken from `expr`.

  A candidate node is picked uniformly from the entries of `expr.toposort()`
  whose dtype is not `dtypes.bool`. A fresh expression is then generated with
  `random_int_expr(depth - 1)`, and one of the two is returned with equal
  probability. The fresh expression is always built, even when it is discarded.
  """
  non_bool_nodes = [node for node in expr.toposort() if node.dtype is not dtypes.bool]
  picked_node = random.choice(non_bool_nodes)
  fresh_expr = random_int_expr(depth - 1)
  return random.choice([fresh_expr, picked_node])
def random_int_expr(depth=10):
  """Recursively build a random integer expression.

  For `depth <= 0` a random element of the module-level list `v` is returned
  (`v` is assigned by the main loop before this function is called).
  Otherwise an expression of depth `depth - 1` is built first and then
  extended by one randomly chosen branch: a unary op, a binary op, or a
  `where` on a random comparison.
  """
  if depth <= 0:
    return random.choice(v)
  base = random_int_expr(depth - 1)

  def apply_unary():
    # arithmetic ops get more weight than minimum and maximum
    chosen = random.choices(unary_ops, weights=[4, 4, 4, 4, 1, 1])[0]
    return chosen(base)

  def apply_binary():
    # the second operand is either another random expression or a subexpression of the first operand
    chosen = random.choices(binary_ops, [8, 1, 1, 1])[0]
    return chosen(base, random_or_sub_expression_int(depth - 1, base))

  def apply_where():
    # condition compares a random/sub expression; `base` is the value used when it holds
    cond_operand = random_or_sub_expression_int(depth - 1, base)
    cond = random_bool_expr(3, cond_operand)
    return cond.where(base, random_or_sub_expression_int(depth - 1, base))

  # branch weights are proportional to the number of ops available in each branch
  branch = random.choices([apply_unary, apply_binary, apply_where], weights=[6, 4, 1])[0]
  return branch()
def random_bool_expr(depth=10, expr1=None):
  """Build a random comparison `expr1 <op> expr2` with `<op>` drawn from `comp_ops`.

  `expr1` defaults to a fresh `random_int_expr(depth - 1)`. The right-hand
  side is, with equal probability, either the result of
  `random_or_sub_expression_int(depth - 1, expr1)` or an int constant in
  [-10, 10]. Both candidates are always constructed before one is chosen.
  """
  # NOTE(review): this base case returns a plain Python bool, not a UOp; the
  # only caller visible in this file passes depth=3, so it is not reached there.
  if depth == 0:
    return True
  if expr1 is None:
    expr1 = random_int_expr(depth - 1)
  candidate_expr = random_or_sub_expression_int(depth - 1, expr1)
  candidate_const = UOp.const(dtypes.int, random.randint(-10, 10))
  rhs = random.choice([candidate_expr, candidate_const])
  compare = random.choice(comp_ops)
  return compare(expr1, rhs)
if __name__ == "__main__":
  # Fuzz loop: build a random integer expression over three bounded variables,
  # simplify it, and ask z3 whether the simplified and original expressions can
  # differ. A `sat` answer is a simplifier bug and aborts with a reproduction
  # script; `unknown` (timeout/interrupt) is counted and reported at the end.

  # candidate upper bounds for the variables; loop-invariant, so built once
  upper_bounds = [*range(1, 10), 16, 32, 64, 128, 256]
  skipped = 0
  for i in range(10000):
    if i % 1000 == 0:
      print(f" Running test {i} ")
    u1 = Variable(" v1 ", 0, random.choice(upper_bounds))
    u2 = Variable(" v2 ", 0, random.choice(upper_bounds))
    u3 = Variable(" v3 ", 0, random.choice(upper_bounds))
    v = [u1, u2, u3]  # module-level list that random_int_expr draws its leaves from
    expr = random_int_expr(6)
    with Context(CORRECT_DIVMOD_FOLDING=1):
      simplified_expr = expr.simplify()

    solver = z3.Solver()
    solver.set(timeout=5000)  # some expressions take very long to verify, but it's very unlikely they actually return sat
    # render both expressions and the three variables to z3 in a single rewrite
    z3_sink = graph_rewrite(expr.sink(simplified_expr, u1, u2, u3), z3_renderer, ctx=(solver, {}))
    z3_expr, z3_simplified_expr = z3_sink.src[0].arg, z3_sink.src[1].arg
    # sat means z3 found an assignment where the two expressions disagree
    check = solver.check(z3_simplified_expr != z3_expr)

    if check == z3.unsat:
      if DEBUG >= 2: print(f" validated {expr.render()} ")
      continue

    # variable declarations shared by the timeout report and the mismatch report
    var_decls = (
      f" v1=Variable( \" {u1.arg[0]} \" , {u1.arg[1]} , {u1.arg[2]} ) \n " +
      f" v2=Variable( \" {u2.arg[0]} \" , {u2.arg[1]} , {u2.arg[2]} ) \n " +
      f" v3=Variable( \" {u3.arg[0]} \" , {u3.arg[1]} , {u3.arg[2]} ) \n "
    )

    if check == z3.unknown:
      # always count the skip so the final summary is correct at every DEBUG level;
      # only the detailed report is gated on DEBUG
      skipped += 1
      if DEBUG >= 1:
        print(" Skipped due to timeout or interrupt: \n " + var_decls + f" expr = {expr.render(simplify=False)} \n ")
      continue

    # remaining outcome is z3.sat: extract the counterexample and re-evaluate both sides with it
    m = solver.model()
    v1, v2, v3 = z3_sink.src[2].arg, z3_sink.src[3].arg, z3_sink.src[4].arg
    n1, n2, n3 = m[v1], m[v2], m[v3]
    u1_val, u2_val, u3_val = u1.const_like(n1.as_long()), u2.const_like(n2.as_long()), u3.const_like(n3.as_long())
    assignment = {u1: u1_val, u2: u2_val, u3: u3_val}
    with Context(CORRECT_DIVMOD_FOLDING=1):
      num = expr.simplify().substitute(assignment).ssimplify()
    rn = expr.substitute(assignment).ssimplify()
    if num == rn: print(" z3 found a mismatch but the expressions are equal!! ")
    # raise explicitly (not `assert False`) so the failure is not stripped under `python -O`
    raise AssertionError("".join([
      f" mismatched {expr.render()} at v1= {m[v1]} ; v2= {m[v2]} ; v3= {m[v3]} = {num} != {rn} \n ",
      " Reproduce with: \n ",
      var_decls,
      f" expr = {expr} \n ",
      f" v1_val, v2_val, v3_val = UOp.const(dtypes.int, {n1.as_long()} ), UOp.const(dtypes.int, {n2.as_long()} ), ",
      f" UOp.const(dtypes.int, {n3.as_long()} ) \n ",
      " num = expr.simplify().substitute( { v1:v1_val, v2:v2_val, v3:v3_val}).ssimplify() \n ",
      " rn = expr.substitute( { v1:v1_val, v2:v2_val, v3:v3_val}).ssimplify() \n ",
      " assert num==rn, f \" {num} != {rn} \" \n ",
    ]))
  print(f" Skipped {skipped} expressions due to timeout ")