openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

93 lines
4.6 KiB

import random, operator
import z3
from tinygrad import Variable, dtypes
from tinygrad.uop.ops import UOp, graph_rewrite
from tinygrad.uop.spec import z3_renderer
from tinygrad.helpers import DEBUG, Context
# Pick a small random seed, print it, then seed the generator with it so the
# printed value identifies the random sequence used by this run.
seed = random.randint(0, 100)
print(f"Seed: {seed}")
random.seed(seed)

# Operations taking one expression; each call draws a fresh random constant.
# NOTE: the order of entries matters, the lists are sampled with per-index weights.
unary_ops = [
    lambda x: x + random.randint(-4, 4),
    lambda x: x * random.randint(-4, 4),
    lambda x: x // random.randint(1, 9),   # divisor is always >= 1
    lambda x: x % random.randint(1, 9),    # modulus is always >= 1
    lambda x: x.maximum(random.randint(-10, 10)),
    lambda x: x.minimum(random.randint(-10, 10)),
]

# Operations combining two expressions.
binary_ops = [
    operator.add,
    operator.mul,
    lambda lhs, rhs: lhs.maximum(rhs),
    lambda lhs, rhs: lhs.minimum(rhs),
]

# Comparison operators used to build boolean expressions.
comp_ops = [operator.lt, operator.le, operator.gt, operator.ge]
def random_or_sub_expression_int(depth, expr):
    """Return either a freshly generated expression or a non-bool node of ``expr``.

    A node whose dtype is not ``dtypes.bool`` is drawn uniformly from
    ``expr.toposort()``; then ``random_int_expr(depth-1)`` is generated, and one
    of the two is returned with equal probability.
    """
    non_bool_nodes = [node for node in expr.toposort() if node.dtype is not dtypes.bool]
    reused = random.choice(non_bool_nodes)
    # The fresh expression is always built (even when it ends up unused), and only
    # after the sub-expression was picked, so the order of random draws is fixed.
    fresh = random_int_expr(depth - 1)
    return random.choice([fresh, reused])
def random_int_expr(depth=10):
    """Recursively build a random expression out of the entries of the global list ``v``.

    With ``depth <= 0`` a random element of ``v`` is returned. Otherwise an
    expression of depth ``depth-1`` is generated first and then extended by one
    randomly chosen step: a unary op, a binary op, or a ``where`` select.
    """
    if depth <= 0:
        return random.choice(v)
    base = random_int_expr(depth - 1)

    def extend_unary():
        # arithmetic ops get more weight than minimum and maximum
        chosen_op = random.choices(unary_ops, weights=[4, 4, 4, 4, 1, 1])[0]
        return chosen_op(base)

    def extend_binary():
        chosen_op = random.choices(binary_ops, [8, 1, 1, 1])[0]
        # the second operand is either another random expression or some subexpression of the first operand
        other = random_or_sub_expression_int(depth - 1, base)
        return chosen_op(base, other)

    def extend_where():
        # condition is built first, then the alternative value; `base` is the value taken when the condition holds
        cond_operand = random_or_sub_expression_int(depth - 1, base)
        cond = random_bool_expr(3, cond_operand)
        alternative = random_or_sub_expression_int(depth - 1, base)
        return cond.where(base, alternative)

    # each branch is weighted proportionally to the number of ops it can produce
    pick = random.choices([extend_unary, extend_binary, extend_where], weights=[6, 4, 1])[0]
    return pick()
def random_bool_expr(depth=10, expr1=None):
    """Build a random boolean expression that compares two expressions.

    Args:
      depth: budget passed on (as ``depth-1``) to the expression generators.
        With ``depth <= 0`` a constant-true boolean UOp is returned.
      expr1: left-hand side of the comparison. When ``None`` it is generated
        with ``random_int_expr(depth-1)``.

    Returns:
      ``cmp(expr1, expr2)`` for a random ``cmp`` from ``comp_ops``, where
      ``expr2`` is either the result of ``random_or_sub_expression_int`` or an
      int constant in ``[-10, 10]``.
    """
    # Return a UOp rather than a bare Python bool, so callers can keep chaining
    # UOp methods such as `.where(...)` on the result. `<=` matches the guard
    # used by random_int_expr.
    if depth <= 0:
        return UOp.const(dtypes.bool, True)
    if expr1 is None:
        expr1 = random_int_expr(depth - 1)
    # Both candidates are always generated (in this order) before one is chosen.
    related = random_or_sub_expression_int(depth - 1, expr1)
    small_const = UOp.const(dtypes.int, random.randint(-10, 10))
    expr2 = random.choice([related, small_const])
    return random.choice(comp_ops)(expr1, expr2)
def _render_variable_defs(variables):
    """Return one ``vN=Variable("<arg[0]>", <arg[1]>, <arg[2]>)`` line per variable (N counts from 1)."""
    return "".join(f'v{idx}=Variable("{var.arg[0]}", {var.arg[1]}, {var.arg[2]})\n' for idx, var in enumerate(variables, start=1))


if __name__ == "__main__":
    # Fuzz loop: generate a random expression, simplify it, and ask z3 whether the
    # simplified and the original expression can ever differ.
    # Candidate upper bounds for the variables (loop-invariant, so built once).
    upper_bounds = [*range(1, 10), 16, 32, 64, 128, 256]
    skipped = 0
    for i in range(10000):
        if i % 1000 == 0:
            print(f"Running test {i}")
        u1 = Variable("v1", 0, random.choice(upper_bounds))
        u2 = Variable("v2", 0, random.choice(upper_bounds))
        u3 = Variable("v3", 0, random.choice(upper_bounds))
        # `v` has to stay a module-level name: random_int_expr draws its leaves from it.
        v = [u1, u2, u3]
        expr = random_int_expr(6)
        with Context(CORRECT_DIVMOD_FOLDING=1):
            simplified_expr = expr.simplify()

        solver = z3.Solver()
        solver.set(timeout=5000)  # some expressions take very long to verify, but it's very unlikely they actually return sat
        # The variables are part of the sink so their z3 counterparts can be read back from src[2:5].
        z3_sink = graph_rewrite(expr.sink(simplified_expr, u1, u2, u3), z3_renderer, ctx=(solver, {}))
        z3_expr, z3_simplified_expr = z3_sink.src[0].arg, z3_sink.src[1].arg
        check = solver.check(z3_simplified_expr != z3_expr)

        if check == z3.unknown:
            # Always count the skip so the final summary is right at every DEBUG level;
            # only the detailed report is gated on DEBUG.
            skipped += 1
            if DEBUG >= 1:
                print("Skipped due to timeout or interrupt:\n" +
                      _render_variable_defs(v) +
                      f"expr = {expr.render(simplify=False)}\n")
        elif check == z3.sat:
            # z3 found an assignment for which the two expressions differ.
            m = solver.model()
            v1, v2, v3 = z3_sink.src[2].arg, z3_sink.src[3].arg, z3_sink.src[4].arg
            n1, n2, n3 = m[v1], m[v2], m[v3]
            u1_val, u2_val, u3_val = u1.const_like(n1.as_long()), u2.const_like(n2.as_long()), u3.const_like(n3.as_long())
            # Evaluate both sides at the counterexample under the same folding setting.
            with Context(CORRECT_DIVMOD_FOLDING=1):
                num = expr.simplify().substitute({u1:u1_val, u2:u2_val, u3:u3_val}).ssimplify()
                rn = expr.substitute({u1:u1_val, u2:u2_val, u3:u3_val}).ssimplify()
            if num == rn:
                print("z3 found a mismatch but the expressions are equal!!")
            # Raise explicitly: a bare `assert False` would be stripped under `python -O`.
            raise AssertionError(
                f"mismatched {expr.render()} at v1={n1}; v2={n2}; v3={n3} = {num} != {rn}\n"
                "Reproduce with:\n"
                + _render_variable_defs(v)
                + f"expr = {expr}\n"
                + f"v1_val, v2_val, v3_val = UOp.const(dtypes.int, {n1.as_long()}), UOp.const(dtypes.int, {n2.as_long()}),"
                + f"UOp.const(dtypes.int, {n3.as_long()})\n"
                + "num = expr.simplify().substitute({v1:v1_val, v2:v2_val, v3:v3_val}).ssimplify()\n"
                + "rn = expr.substitute({v1:v1_val, v2:v2_val, v3:v3_val}).ssimplify()\n"
                + "assert num==rn, f\"{num} != {rn}\"\n"
            )
        elif DEBUG >= 2:
            # Reached only when the check was neither unknown nor sat.
            print(f"validated {expr.render()}")
    print(f"Skipped {skipped} expressions due to timeout")