openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

281 lines
12 KiB

import unittest, math
from tinygrad import dtypes
from tinygrad.helpers import all_same
from tinygrad.ops import GroupOp, UOp, Ops, exec_alu
from tinygrad.codegen.devectorizer import full_graph_rewrite
# Helper function to apply the graph rewrite
def apply_rewrite(expr):
  """Run `full_graph_rewrite` on `expr` and return the rewritten expression.

  `expr` is wrapped with `.sink()` before being rewritten; the first source of the
  rewritten result is what gets returned.
  """
  rewritten_sink = full_graph_rewrite(expr.sink())
  return rewritten_sink.src[0]
def evaluate_uop(uop, variables):
  """Recursively evaluate `uop` to a Python value.

  `variables` maps a DEFINE_VAR's name (the first element of its `arg`) to the value to use for it.
  CONST nodes yield their `arg`, ALU nodes are computed with `exec_alu` over their evaluated sources,
  and any other op raises NotImplementedError.
  """
  op = uop.op
  if op == Ops.CONST:
    return uop.arg
  if op == Ops.DEFINE_VAR:
    return variables[uop.arg[0]]
  if op not in GroupOp.ALU:
    raise NotImplementedError(f"Unsupported UOp {uop.op}")
  # evaluate the operands first (in source order), then apply the ALU op to the resulting values
  operands = [evaluate_uop(child, variables) for child in uop.src]
  return exec_alu(op, uop.dtype, operands)
class TestArithmeticSimplifications(unittest.TestCase):
  """Arithmetic on constants is expected to collapse to a single CONST after the full graph rewrite."""

  def test_full_graph_rewrite_division_by_zero(self):
    # 10.0 / 0.0 must still fold to a constant; either inf or nan is accepted as its value
    numerator = UOp.const(dtypes.float32, 10.0)
    denominator = UOp.const(dtypes.float32, 0.0)
    folded = apply_rewrite(numerator / denominator)
    self.assertEqual(folded.op, Ops.CONST)
    self.assertTrue(math.isinf(folded.arg) or math.isnan(folded.arg))

  def test_full_graph_rewrite_redundant_operations(self):
    # (10.0 + 0.0) * 1.0 should reduce to the constant 10.0
    ten = UOp.const(dtypes.float32, 10.0)
    zero = UOp.const(dtypes.float32, 0.0)
    plus_zero = ten + zero
    one = UOp.const(dtypes.float32, 1.0)
    folded = apply_rewrite(plus_zero * one)
    self.assertEqual(folded.op, Ops.CONST)
    self.assertEqual(folded.arg, 10.0)

  def test_full_graph_rewrite_large_graph(self):
    # a chain of 100 constant additions (0 + 1 + 2 + ... + 100) should fold to one constant
    running_total = UOp.const(dtypes.int32, 0)
    for term in range(1, 101):
      running_total += UOp.const(dtypes.int32, term)
    folded = apply_rewrite(running_total)
    expected_total = sum(range(1, 101))
    self.assertEqual(folded.op, Ops.CONST)
    self.assertEqual(folded.arg, expected_total)

  def test_full_graph_rewrite_division_by_one(self):
    # 42.0 / 1.0 should fold to the constant 42.0
    dividend = UOp.const(dtypes.float32, 42.0)
    divisor = UOp.const(dtypes.float32, 1.0)
    folded = apply_rewrite(dividend / divisor)
    self.assertEqual(folded.op, Ops.CONST)
    self.assertEqual(folded.arg, 42.0)

  def test_full_graph_rewrite_modulo_by_one(self):
    # 42 % 1 should fold to the constant 0
    dividend = UOp.const(dtypes.int32, 42)
    divisor = UOp.const(dtypes.int32, 1)
    folded = apply_rewrite(dividend % divisor)
    self.assertEqual(folded.op, Ops.CONST)
    self.assertEqual(folded.arg, 0)
class TestFoldingAndReduction(unittest.TestCase):
  """Reduction-folding checks. Every test in this class is currently skipped."""

  @unittest.skip("reduce is removed now")
  def test_full_graph_rewrite_constant_reduction_folding(self):
    # reducing a sum of three constants with no ranges should give the plain sum
    five = UOp.const(dtypes.int32, 5)
    ten = UOp.const(dtypes.int32, 10)
    twenty = UOp.const(dtypes.int32, 20)
    reduced = (five + ten + twenty).reduce(Ops.ADD)
    rewritten = apply_rewrite(reduced)
    self.assertEqual(rewritten.arg, 5 + 10 + 20)

  @unittest.skip("reduce is removed now")
  def test_full_graph_rewrite_reduction_with_unused_range(self):
    # the reduced expression does not reference the range, so the expected value is 10 copies of (15 + 25)
    fifteen = UOp.const(dtypes.int32, 15)
    twenty_five = UOp.const(dtypes.int32, 25)
    unused_range = UOp.range(dtypes.int32, 0, 10, idx=0)
    reduced = (fifteen + twenty_five).reduce(Ops.ADD, unused_range)
    rewritten = apply_rewrite(reduced)
    self.assertEqual(rewritten.arg, 10 * (15 + 25))

  @unittest.skip("currently failing")
  def test_full_graph_rewrite_range_reduction(self):
    # summing the range variable itself over [0, 5)
    loop_var = UOp.range(dtypes.int32, 0, 5, idx=0)
    rewritten = apply_rewrite(loop_var.reduce(Ops.ADD, loop_var))
    self.assertEqual(rewritten.arg, sum(range(5)))

  @unittest.skip("currently failing")
  def test_full_graph_rewrite_simple_reduction_folding(self):
    # summing (i + 1) for i over [0, 4)
    loop_var = UOp.range(dtypes.int32, 0, 4, idx=0)
    shifted = loop_var + UOp.const(dtypes.int32, 1)
    rewritten = apply_rewrite(shifted.reduce(Ops.ADD, loop_var))
    self.assertEqual(rewritten.arg, sum(i + 1 for i in range(4)))

  @unittest.skip("currently failing")
  def test_full_graph_rewrite_nested_loop_collapse(self):
    # summing outer*10 + inner over both ranges should collapse to a single constant
    outer = UOp.range(dtypes.int32, 0, 8, 0)
    inner = UOp.range(dtypes.int32, 0, 4, 1)
    body = (outer * 10) + inner
    rewritten = apply_rewrite(body.reduce(Ops.ADD, outer, inner))
    expected_total = sum((i * 10) + j for i in range(8) for j in range(4))
    self.assertEqual(rewritten.op, Ops.CONST)
    self.assertEqual(rewritten.arg, expected_total)
class TestModuloAndDivisionFolding(unittest.TestCase):
  """Checks how modulo and integer-division expressions over bounded variables are rewritten."""

  def test_full_graph_rewrite_modulo_folding_with_define_var(self):
    # (4x + 2) % 4 is expected to fold to the constant 2
    x = UOp.variable('x', 0, 100)
    folded = apply_rewrite(((x * 4) + 2) % 4)
    self.assertEqual(folded.op, Ops.CONST)
    self.assertEqual(folded.arg, 2)

  def test_full_graph_rewrite_division_folding_with_define_var(self):
    # (6n) // 3 is expected to become a MUL whose second source is the constant 2
    n = UOp.variable('n', 1, 1000)
    folded = apply_rewrite((n * 6) // 3)
    self.assertEqual(folded.op, Ops.MUL)
    self.assertEqual(folded.src[1].arg, 2)

  def test_full_graph_rewrite_complex_mod_div_folding(self):
    # ((12k + 8) % 6) // 2 is expected to fold to the constant 1
    k = UOp.variable('k', 0, 50)
    folded = apply_rewrite(((k * 12 + 8) % 6) // 2)
    self.assertEqual(folded.op, Ops.CONST)
    self.assertEqual(folded.arg, 1)

  def test_graph_rewrite_div_folding_bug(self):
    # compare (lidx0 repeated 4 times + per-lane offsets) against a vector constant 2
    int4 = dtypes.int.vec(4)
    local_idx = UOp(Ops.SPECIAL, dtypes.int, arg=('lidx0', 32), src=())
    repeated_idx = UOp(Ops.VECTORIZE, int4, arg=None, src=(local_idx,)*4)
    lane_offsets = UOp(Ops.VCONST, int4, arg=(0, 256, 512, 768), src=())
    lhs = UOp(Ops.ADD, int4, src=(repeated_idx, lane_offsets))
    rhs = UOp.const(int4, 2)
    unopt = lhs<rhs
    opt = apply_rewrite(unopt)
    print(unopt)
    print(opt)
    # only checked when the result is a VECTORIZE: its sources must not all be the same
    if opt.op is Ops.VECTORIZE:
      self.assertFalse(all_same(opt.src))

  def test_full_graph_rewrite_modulo_large_divisor(self):
    # x is bounded to [1, 5], so x % 10 is expected to rewrite to x itself (same object)
    x = UOp.variable('x', 1, 5)
    self.assertIs(apply_rewrite(x % 10), x)

  def test_full_graph_rewrite_division_with_remainder(self):
    # the rewritten x // 2 must evaluate like Python's // for every x in the variable's range
    x = UOp.variable('x', 7, 9)
    rewritten = apply_rewrite(x // 2)
    for x_value in range(7, 10):
      self.assertEqual(x_value // 2, evaluate_uop(rewritten, {'x': x_value}))

  def test_full_graph_rewrite_complex_mod_div_expression(self):
    # the rewritten ((5x) % 3) // 2 must evaluate like the Python expression for every x in range
    x = UOp.variable('x', 1, 10)
    rewritten = apply_rewrite(((x * 5) % 3) // 2)
    for x_value in range(1, 11):
      expected = ((x_value * 5) % 3) // 2
      actual = evaluate_uop(rewritten, {'x': x_value})
      self.assertEqual(expected, actual)
class TestEdgeCasesAndSpecialOperations(unittest.TestCase):
  """Edge cases: transcendental ops on special inputs, plus two skipped negative-operand checks."""

  def test_full_graph_rewrite_transcendental_edge_cases(self):
    # both expressions are rewritten together under a single sink: log2(-1.0) and reciprocal(0.0)
    log2_of_negative = UOp.const(dtypes.float32, -1.0).log2()
    reciprocal_of_zero = UOp.const(dtypes.float32, 0.0).reciprocal()
    rewritten_sink = full_graph_rewrite(log2_of_negative.sink(reciprocal_of_zero))
    optimized_log2_neg, optimized_recip_zero = rewritten_sink.src
    self.assertTrue(math.isnan(optimized_log2_neg.arg), f"Expected NaN for log2(-1.0), got {optimized_log2_neg.arg}")
    self.assertTrue(math.isinf(optimized_recip_zero.arg) and optimized_recip_zero.arg > 0,
                    f"Expected +inf for reciprocal(0.0), got {optimized_recip_zero.arg}")

  @unittest.skip("broken")
  def test_full_graph_rewrite_modulo_negative_dividend(self):
    # compares the rewritten x % 3 against Python's % for x in [-5, -1]
    x = UOp.variable('x', -5, -1)
    rewritten_sink = full_graph_rewrite((x % 3).sink())
    for x_value in range(-5, 0):
      self.assertEqual(x_value % 3, evaluate_uop(rewritten_sink.src[0], {'x': x_value}))

  @unittest.skip("broken")
  def test_full_graph_rewrite_division_negative_divisor(self):
    # compares the rewritten x // -2 against Python's // for x in [1, 5]
    x = UOp.variable('x', 1, 5)
    rewritten_sink = full_graph_rewrite((x // -2).sink())
    for x_value in range(1, 6):
      self.assertEqual(x_value // -2, evaluate_uop(rewritten_sink.src[0], {'x': x_value}))
class TestGEPAndVectorizeRewrite(unittest.TestCase):
  """GEP element extraction and VECTORIZE are checked against constant 4-wide float vectors."""

  def test_gep_single_element_extraction(self):
    # GEP on a vector dtype to extract a single element
    vec = UOp.const(dtypes.float32.vec(4), (1.0, 2.0, 3.0, 4.0))
    third = apply_rewrite(vec.gep(2))
    self.assertEqual(third.arg, 3.0)

  def test_gep_tuple_extraction(self):
    # GEP on a vector dtype to extract multiple elements as a vector
    vec = UOp.const(dtypes.float32.vec(4), (1.0, 2.0, 3.0, 4.0))
    picked = apply_rewrite(vec.gep((2, 3)))
    self.assertEqual([lane.arg for lane in picked.src], [3.0, 4.0])

  def test_gep_on_vconst(self):
    # GEP on a VCONST to extract a single element
    vconst = UOp(Ops.VCONST, dtypes.float32.vec(4), arg=(1.0, 2.0, 3.0, 4.0))
    third = apply_rewrite(vconst.gep(2))
    self.assertEqual(third.arg, 3.0)

  def test_gep_tuple_on_vconst(self):
    # GEP on a VCONST using a tuple to extract multiple elements
    vconst = UOp(Ops.VCONST, dtypes.float32.vec(4), arg=(7.0, 8.0, 9.0, 10.0))
    picked = apply_rewrite(vconst.gep((1, 3)))
    self.assertEqual([lane.arg for lane in picked.src], [8.0, 10.0])

  def test_gep_gep_simplification(self):
    # Nested GEP simplification on a vector dtype
    vec = UOp.const(dtypes.float32.vec(4), (10.0, 20.0, 30.0, 40.0))
    second = vec.gep(1)  # Extract 2nd element (20.0)
    self.assertEqual(apply_rewrite(second.gep(0)).arg, 20.0)

  def test_vectorize_multiple_elements(self):
    # Vectorizing multiple elements using GEP
    vec = UOp.const(dtypes.float32.vec(4), (5.0, 10.0, 15.0, 20.0))
    lanes = tuple(vec.gep(lane_idx) for lane_idx in range(4))
    rebuilt = UOp(Ops.VECTORIZE, dtypes.float32.vec(4), src=lanes)
    rewritten = apply_rewrite(rebuilt)
    self.assertEqual([lane.arg for lane in rewritten.src], [5.0, 10.0, 15.0, 20.0])
import inspect
from tinygrad.ops import graph_rewrite, _substitute, track_rewrites
from tinygrad.codegen.symbolic import symbolic_simple
class TestBottomUpRewrite(unittest.TestCase):
  """Compares `graph_rewrite` with and without `bottom_up=True` on the same expression."""

  def test_const_folding(self):
    # (5*3) + (5*7) rewritten with symbolic_simple must give the very same object either way
    five = UOp.const(dtypes.int, 5)
    expr = (five*3) + (five*7)
    default_result = graph_rewrite(expr, symbolic_simple)
    bottom_up_result = graph_rewrite(expr, symbolic_simple, bottom_up=True)
    self.assertIs(default_result, bottom_up_result)
# normally .substitute would be fine, but it's not tracked
@track_rewrites()
def named_substitute(name:str, uop:UOp, rel:dict[UOp, UOp]):
  """Apply the `_substitute` rewrite to `uop` bottom-up, with `rel` passed as the rewrite context.

  `name` is not read in this body; presumably the `track_rewrites` decorator uses it to label the rewrite - TODO confirm.
  """
  substituted = graph_rewrite(uop, _substitute, rel, bottom_up=True)
  return substituted
def substitute(uop:UOp, rel:dict[UOp, UOp]):
  """Call `named_substitute`, passing the name of the function that called `substitute` as its `name`."""
  # stack()[0] is this frame, so index 1 is the direct caller
  caller_name = inspect.stack()[1].function
  return named_substitute(caller_name, uop, rel)
class TestSubstitute(unittest.TestCase):
  """Exercises the `substitute` helper on small expressions; results are compared by identity."""

  # these work because the substituted things don't have parents
  def test_simple(self):
    # a + 4 with a -> b becomes b + 4
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    replaced = substitute(a + 4, {a:b})
    self.assertIs(replaced, b+4)

  def test_double(self):
    # both a and b map to c in one substitution
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    c = UOp.variable('c', 0, 10)
    expr = (a + 4) + b
    replaced = substitute(expr, {a:c, b:c})
    self.assertIs(replaced, (c + 4) + c)

  def test_diamond(self):
    # a appears on both sides of the outer add
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    expr = (a + 4) + (a + 5)
    replaced = substitute(expr, {a:b})
    self.assertIs(replaced, (b + 4) + (b + 5))

  # this works because there's nothing above the substituted node
  def test_sin(self):
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    nested_sin = a.sin().sin()
    replaced = substitute(nested_sin, {a.sin():b})
    self.assertIs(replaced, b.sin())

  # broken due to infinite recursion
  # NOTE: VIZ hangs and doesn't recover if you click this one
  def test_assert_inf_recurse(self):
    a = UOp.variable('a', 0, 10)
    sin_a = a.sin()
    # the replacement contains the node being replaced; a RecursionError is the expected outcome
    with self.assertRaises(RecursionError):
      substitute(sin_a, {sin_a:sin_a.sqrt()})

  def test_sin_to_sqrt(self):
    # only the inner sin is replaced by sqrt
    a = UOp.variable('a', 0, 10)
    sin_a = a.sin()
    expr = sin_a.sin()
    replaced = substitute(expr, {a.sin():a.sqrt()})
    self.assertIs(replaced, a.sqrt().sin())

  def test_double_sin_to_sqrt(self):
    a = UOp.variable('a', 0, 10)
    sin_a = a.sin()
    expr = sin_a.sin()
    # NOTE: this would work if it had gone in the opposite order
    replaced = substitute(expr, {a.sin():a.sqrt(), sin_a.sin():sin_a.sqrt()})
    self.assertIs(replaced, a.sqrt().sqrt())
# run every test in this module when the file is executed directly as a script
if __name__ == '__main__':
  unittest.main()