openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

434 lines
16 KiB

import unittest, itertools
from tinygrad.codegen import full_rewrite_to_sink
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp, Ops
from tinygrad.uop.symbolic import simplify_valid
from tinygrad.helpers import Context
from test.unit.test_uop_symbolic import check_uop_against_string
def get_gated_load_uop(valid:UOp, idx:UOp):
    """Build a float LOAD from global buffer 0 at `idx`, gated by `valid`.

    The LOAD has two sources: the buffer indexed by `idx.valid(valid)` (as a pointer),
    and a float constant 0.0 (presumably the value used when the gate is false -- confirm).
    """
    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
    gated_ptr = buf.index(idx.valid(valid), ptr=True)
    zero = UOp.const(dtypes.float, 0.0)
    return UOp(Ops.LOAD, dtypes.float, (gated_ptr, zero))
def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
    """Build a float4 LOAD from an image buffer of `image_shape`, gated by `valid`.

    `idx` is a pair of index UOps that is packed into a 2-wide VECTORIZE before indexing.
    The second LOAD source is a 4-wide vector of float 0.0 constants
    (presumably the value used when the gate is false -- confirm).
    """
    image = UOp(Ops.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0)
    coords = UOp(Ops.VECTORIZE, dtypes.index.vec(2), idx)
    gated_ptr = image.index(coords.valid(valid), ptr=True)
    zero = UOp.const(dtypes.float, 0.0)
    zeros4 = UOp(Ops.VECTORIZE, dtypes.float.vec(4), src=(zero,) * 4)
    return UOp(Ops.LOAD, dtypes.float.vec(4), (gated_ptr, zeros4))
def Special(expr, nmax):
    """Create an index-typed SPECIAL UOp.

    `nmax` is wrapped in an index constant that becomes the UOp's only source, and `expr`
    is passed as the final positional argument (the tests pass names such as "gidx0").
    """
    bound = UOp.const(dtypes.index, nmax)
    return UOp(Ops.SPECIAL, dtypes.index, (bound,), expr)
def Variable(expr, nmin, nmax):
    """Forward `(expr, nmin, nmax)` unchanged to `UOp.variable` and return the result."""
    var = UOp.variable(expr, nmin, nmax)
    return var
def Range(n, nmax):
    """Wrap `UOp.range`, swapping the argument order: `nmax` is passed first, then `n`.

    In the tests below, `Range(k, ...)` is expected to render as "r<k>".
    """
    rng = UOp.range(nmax, n)
    return rng
class TestHelpers(unittest.TestCase):
    """Checks for the `is_increasing` helper on UOp expressions."""

    def test_is_increasing(self):
        s1 = Special("idx1", 32)
        s2 = Special("idx2", 64)
        v0 = Variable("ridx0", 0, 5)
        v1 = Variable("ridx1", 0, 2)
        v2 = Variable("ridx2", 0, 2)

        # the expression wrapped in a modulo is the only one expected to be non-increasing
        modded = ((s1*24)+(v2*3)+v0+765)%768
        # (ridx0+(idx1*48)+(ridx2*6)+(-6)), ((idx2*2)+ridx1+(-1))
        increasing = (
            v0+(s1*48)+(v2*6)+(-6),
            (s2*2)+v1+((s1+((v2+7)//8)+31)//32)+(-2),
            (s2*2)+v1+(-1),
        )

        self.assertFalse(modded.is_increasing())
        for expr in increasing:
            self.assertTrue(expr.is_increasing())

        # a bare range, and a range plus a constant, are both expected to be increasing
        counter = UOp.range(5, 2)
        for expr in (counter, counter+2):
            self.assertTrue(expr.is_increasing())
class TestValidIdxSimplification(unittest.TestCase):
    """Tests that a gated load's index and valid expressions are simplified as expected."""

    def check(self, load, sidx, svalid):
        """Rewrite `load` and compare the resulting index and valid against expected strings.

        `load` is a LOAD UOp as produced by `get_gated_load_uop`. After the rewrite, the index is
        read from `load.src[0].src[1]` and the valid from `load.src[0].src[2]`; each is compared
        with `sidx` / `svalid` via `check_uop_against_string`.
        """
        with Context(NOOPT=1, SPEC=0):
            load = full_rewrite_to_sink(load.sink()).src[0]
        idx, valid = load.src[0].src[1], load.src[0].src[2]
        check_uop_against_string(self, idx, sidx)
        check_uop_against_string(self, valid, svalid)

    def test_cumsum(self):
        # under the gate, the index is expected to collapse to the constant 0
        gidx0 = Special("gidx0", 5)
        lidx0 = Special("lidx0", 4)
        gate = (gidx0*4+lidx0<19).ne(True)
        idx = gidx0*4+lidx0-19
        load = get_gated_load_uop(gate, idx)
        self.check(load,
                   "0",
                   "(((lidx0+(gidx0*4))<19)!=True)")

    def test_simplify_within_valid1(self):
        # the first clause is expected to simplify the div inside the second clause away
        ridx0 = Range(0, 4)
        ridx1 = Range(1, 4)
        ridx2 = Range(2, 4)
        ridx3 = Range(3, 4)
        valid = ((ridx0*3+ridx1)<8) & ((((ridx0*3+ridx1)//8+ridx2*3+ridx3)%4)<2)
        idx = ridx0+ridx1+ridx2+ridx3
        load = get_gated_load_uop(valid, idx)
        self.check(load,
                   "(((r0+r1)+r2)+r3)",
                   "((((r0*3)+r1)<8)&((((r2*3)+r3)%4)<2))")

    def test_simplify_within_valid2(self):
        # simplify_valid returns None here, i.e. there is nothing to simplify
        gidx0 = Special("gidx0", 56)
        ridx0 = Range(0, 3)
        alu0 = gidx0+ridx0
        valid = (alu0 < 57) & (alu0 >= 1)
        self.assertIsNone(simplify_valid(valid))

    def test_valid_order_matters1(self):
        # both clause orders must simplify to the same result
        ridx0 = Range(0, 2)
        v0 = ridx0<1
        v1 = ((ridx0*5+1)%6)<5
        self.assertEqual(simplify_valid(v0&v1).render(), "(r0<1)")
        self.assertEqual(simplify_valid(v1&v0).render(), "(r0<1)")

    def test_valid_order_matters2(self):
        # every ordering of the four clauses must simplify to False
        gidx0 = Special("gidx0", 13)
        gidx1 = Special("gidx1", 13)
        ridx0 = Range(0, 4)
        alu0 = (gidx1+(ridx0*13))
        v0 = (gidx0+11)%14<11
        v1 = (alu0+((gidx0+39)//42))%14<11
        v2 = gidx0<3
        v3 = alu0<42
        for v in itertools.permutations([v0,v1,v2,v3]):
            self.assertEqual(simplify_valid(v[0]&v[1]&v[2]&v[3]).render(), "False")

    def test_simplify_valid_from_div(self):
        x = Variable("x", -100, 100)
        valid = ((x<0)&((100%x).cast(dtypes.bool)))
        # NOTE: this simplifies the (100%x) part somehow, still has two clauses
        self.assertIsNotNone(simplify_valid(valid))
        # NOTE(review): this counts the clauses of the original `valid`, not of the simplified
        # result, so it does not check the "still has two clauses" claim above -- confirm intent.
        self.assertEqual(len(list(valid.split_uop(Ops.AND))), 2)

    @unittest.expectedFailure  # TODO: fix
    def test_from_merge_views(self):
        # taken from test_merges_from_fuzzer1
        # generated by
        # v0 = View(shape=(2, 4), strides=(2, 1), offset=-2, mask=((0, 2), (2, 4)), contiguous=False)
        # v1 = View(shape=(2, 4, 2, 2), strides=(4, 0, -2, -1), offset=3, mask=None, contiguous=False)
        # s = ShapeTracker((v0, v1))
        # idx, valid = s.to_indexed_uops()
        # print(f"{idx.render()=}")
        # print(f"{valid.render()=}")
        # s = ShapeTracker((View(shape=(2, 4, 2, 2), strides=(2, 0, 0, -1), offset=1, mask=((0, 2), (0, 4), (0, 1), (0, 2)), contiguous=False),))
        # idx, valid = s.to_indexed_uops()
        # print(f"{idx.render()=}")
        # print(f"{valid.render()=}")
        ridx0 = Range(0, 2)
        ridx2 = Range(2, 2)
        ridx3 = Range(3, 2)
        idx = (((ridx0*2)+((((ridx2*2)+(ridx3*3))+3)%4))+-2)
        valid = ((((((ridx2*2)+(ridx3*3))+3)%4)<2)!=True)  # noqa: E712
        load = get_gated_load_uop(valid, idx)
        self.check(load,
                   "(((r0*2)+(r3*-1))+1)",
                   "(r2<1)")

    def test_load_in_valid(self):
        # from FUSE_ARANGE=1 python test/test_ops.py TestOps.test_scatter_add
        # can lead to OOB
        ridx2 = Range(2, 4)
        lidx0 = Special("lidx0", 3)
        gidx0 = Special("gidx0", 2)
        idx = (((lidx0+(gidx0*3))+(ridx2*5))+40)
        valid = (lidx0+(gidx0*3)) < 5
        val7 = get_gated_load_uop(valid, idx)
        # the second clause contains a load that is itself gated by the first clause
        valid2 = valid & val7.cast(dtypes.bool).logical_not()
        self.assertIsNone(simplify_valid(valid2))

    def test_valid_becomes_const1(self):
        # from DSP mobilenetv2
        ridx0 = Range(0, 30)
        ridx1 = Range(1, 7)
        ridx2 = Range(2, 2)
        alu11 = (ridx1+ridx2)
        alu15 = ((alu11+1)//7)
        idx = (alu15*-31)+(((((alu11+218)//224)+ridx0)%30)*1568)
        valid = (ridx2<1)&(ridx1<6)
        load = get_gated_load_uop(valid, idx)
        self.check(load,
                   "(r0*1568)",
                   "((r2<1)&(r1<6))")

    def test_valid_becomes_const1_z3(self):
        # Cross-checks the simplification expected by test_valid_becomes_const1 with the z3 solver.
        from z3 import Ints, Solver, And, If, Not, unsat
        ridx0, ridx1, ridx2 = Ints('ridx0 ridx1 ridx2')
        # same variable ranges as the Range() calls in test_valid_becomes_const1
        bounds = And(0<=ridx0, ridx0<30, 0<=ridx1, ridx1<7, 0<=ridx2, ridx2<2)
        alu11 = (ridx1+ridx2)
        # `/` on z3 Int terms is integer division
        alu15 = ((alu11+1)/7)
        idx = (alu15*-31)+(((((alu11+218)/224)+ridx0)%30)*1568)
        # And() is used instead of `&` so this does not rely on z3 overloading `&` for booleans
        valid = And(ridx2<1, ridx1<6)
        load = If(valid, idx, 0)

        # correct simplification: no assignment within bounds may tell the two loads apart
        s = Solver()
        s.add(bounds)
        simplified_idx = (ridx0*1568)
        simplified_load = If(valid, simplified_idx, 0)
        s.add(Not(load == simplified_load))  # Check if they are NOT equivalent
        # the model is only rendered on failure; it is not available when the result is unsat
        if s.check() != unsat:
            self.fail(f"The expressions are not equivalent. {s.model()=}")

        # new solver for a wrong simplified expression: a counterexample must exist
        s = Solver()
        s.add(bounds)
        wrong_simplified_idx = (ridx0*1567)+ridx1
        wrong_simplified_load = If(valid, wrong_simplified_idx, 0)
        s.add(Not(load == wrong_simplified_load))  # Check if they are NOT equivalent
        self.assertNotEqual(s.check(), unsat, "The expressions are equivalent??")
        print("The expressions are not equivalent.")
        print(s.model())

    def test_valid_becomes_const2(self):
        ridx0 = Range(0, 4)
        ridx1 = Range(1, 4)
        ridx2 = Range(2, 4)
        ridx3 = Range(3, 4)
        # TODO: this should also work without the extra nesting
        idx = (((ridx0+ridx1)+(ridx2+ridx3)+28)//30)
        valid = ((ridx0+ridx1)<1).ne(True) & ((ridx2+ridx3)<1).ne(True)
        load = get_gated_load_uop(valid, idx)
        self.check(load,
                   "1",
                   "((((r0+r1)<1)!=True)&(((r2+r3)<1)!=True))")

    def test_valid_with_non_const_rhs(self):
        # the (r0 < -1) != True clause is expected to be dropped from the valid
        ridx0 = Range(0, 1024)
        ridx1 = Range(1, 4)
        ridx2 = Range(2, 4)
        valid = (ridx0<(ridx1*4 + ridx2))&(ridx0<-1).ne(True)
        idx = ridx0
        load = get_gated_load_uop(valid, idx)
        self.check(load,
                   "r0",
                   "(r0<((r1*4)+r2))")
class TestImageSimplification(unittest.TestCase):
    """Tests that image-load indices are simplified and redundant valids are dropped."""

    def check(self, load, svalid, sidx0, sidx1):
        """Rewrite `load` and compare both index components (and the valid) with expected strings.

        The rewritten index at `load.src[0].src[1]` must be a 2-source VECTORIZE whose sources
        match `sidx0` and `sidx1`. When `svalid` is None, the indexed node must have exactly two
        sources (no valid left); otherwise its third source must match `svalid`.
        """
        with Context(NOOPT=1, SPEC=0):
            load = full_rewrite_to_sink(load.sink()).src[0]
        indexed = load.src[0]
        coords = indexed.src[1]
        self.assertEqual(coords.op, Ops.VECTORIZE)
        self.assertEqual(len(coords.src), 2)
        first, second = coords.src[0], coords.src[1]
        check_uop_against_string(self, first, sidx0)
        check_uop_against_string(self, second, sidx1)
        if svalid is None:
            self.assertEqual(len(load.src[0].src), 2, "svalid is None but load still has a valid")
        else:
            check_uop_against_string(self, load.src[0].src[2], svalid)

    def test_idx_gt_c(self):
        # (idx1 < c+1).ne(True) ? (..., idx1-1+c) : 0 can drop the valid
        # (idx1 < c+1).ne(True) -> idx > c
        g0 = Special("gidx0", 32)
        g1 = Special("gidx1", 32)
        image_shape = (10, 10, 4)
        g1_not_lt_1 = (g1<1).ne(True)
        for shift, rendered in ((1, "(gidx1+-1)"), (2, "(gidx1+-2)")):
            load = get_load_image_uop(image_shape, g1_not_lt_1, (g0, g1-shift))
            self.check(load, None, "gidx0", rendered)

        # should match any one of the AND clause and drop the matched statement from valid
        both = (g0<1).ne(True) & (g1<1).ne(True)
        for first, rendered in ((g0+1, "(gidx0+1)"), (g0, "gidx0")):
            load = get_load_image_uop(image_shape, both, (first, g1-1))
            self.check(load, "((gidx0<1)!=True)", rendered, "(gidx1+-1)")

    def test_idx_lt_bound(self):
        # (idx1 < image_bound) ? (..., idx1) : 0 can drop the valid
        g0 = Special("gidx0", 32)
        g1 = Special("gidx1", 32)
        cases = (
            ((10, 10, 4), g1<10, None),
            # same thing, valid has a div
            ((10, 10, 4), g1//2<5, None),
            # 10x20 image, not out of bound
            ((20, 10, 4), g1<10, "(gidx1<10)"),
        )
        for image_shape, gate, expected_valid in cases:
            load = get_load_image_uop(image_shape, gate, (g0, g1))
            self.check(load, expected_valid, "gidx0", "gidx1")

    def test_generic_idx_lt_bound(self):
        # (idx1 < image_bound - c) ? (..., idx1 + c) : 0 can drop the valid
        g0 = Special("gidx0", 32)
        g1 = Special("gidx1", 32)
        image_shape = (10, 10, 4)
        for limit, offset, rendered in ((8, 2, "(gidx1+2)"), (5, 5, "(gidx1+5)")):
            load = get_load_image_uop(image_shape, (g1<limit), (g0, g1+offset))
            self.check(load, None, "gidx0", rendered)

    def test_valid_empty_set(self):
        g0 = Special("gidx0", 32)
        g1 = Special("gidx1", 32)
        image_shape = (32, 32, 4)
        coords = (g0%2, g1+2)

        # not empty
        load = get_load_image_uop(image_shape, g0<8, coords)
        self.check(load, "(gidx0<8)", "(gidx0%2)", "(gidx1+2)")

        # empty -> invalid
        contradiction = (g0<8) & (g0<8).ne(True)
        load = get_load_image_uop(image_shape, contradiction, coords)
        with Context(NOOPT=1, SPEC=0):
            load = full_rewrite_to_sink(load.sink()).src[0]
        # the load itself is expected to be replaced by a 4-wide VECTORIZE
        self.assertEqual(load.op, Ops.VECTORIZE)
        self.assertEqual(load.dtype.count, 4)

    def test_openpilot_conv1(self):
        # first conv in openpilot
        # kernel in tinygrad ae5d1407ee844a97a52ad3756835d38e7e2b9e1b https://gist.github.com/chenyuxyz/39c2d4e9a076b46731c67d345ff066b6
        s1 = Special("idx1", 32)
        s2 = Special("idx2", 64)
        # ridx0 = Variable("ridx0", 0, 5)
        # ridx1 = Variable("ridx1", 0, 2)
        # ridx2 = Variable("ridx2", 0, 2)
        r0 = Range(0, 6)
        r1 = Range(1, 3)
        r2 = Range(2, 3)
        second_base = ((s2*2)+r1)
        first_base = ((s1*48)+(r2*6)+r0)
        gate = ((((s2*2)+(r1))<1).ne(True))&((((s1*8)+(r2))<1).ne(True))
        image_shape = (128, 1536, 4)
        coords = ((first_base+1530)%1536, second_base+((s1+((r2+7)//8)+31)//32)+(-2))
        load = get_load_image_uop(image_shape, gate, coords)
        self.check(load, None, "((((idx1*48)+(r2*6))+r0)+-6)", "(((idx2*2)+r1)+-1)")

    def test_openpilot_conv2(self):
        # conv in test/external/external_test_valid_remove.py
        s1 = Special("idx1", 32)
        s2 = Special("idx2", 64)
        # ridx0 = Variable("ridx0", 0, 2)
        # ridx1 = Variable("ridx1", 0, 2)
        # ridx2 = Variable("ridx2", 0, 2)
        r0 = Range(0, 3)
        r1 = Range(1, 3)
        r2 = Range(2, 3)
        second_base = ((s2*2)+r1)
        first_base = ((s1*24)+(r2*3)+r0)
        gate = ((((s2*2)+r1)<1).ne(True))&((((s1*8)+r2)<1).ne(True))
        image_shape = (128, 768, 4)
        coords = ((first_base+765)%768, second_base+((s1+((r2+7)//8)+31)//32)+(-2))
        load = get_load_image_uop(image_shape, gate, coords)
        self.check(load, None, "((((idx1*24)+(r2*3))+r0)+-3)", "(((idx2*2)+r1)+-1)")

    def test_openpilot_conv3(self):
        # in openpilot 0.9.7
        s0 = Special("idx0", 64)
        s1 = Special("idx1", 2)
        s2 = Special("idx2", 4)
        r0 = Range(0, 7)
        r1 = Range(1, 7)
        a2 = ((s2*2)+r0)
        a4 = ((s1*8)+r1)
        a6 = ((s1*512)+(r1*64)+s0)
        gate = (a2<11)&(a4<3).ne(True)
        image_shape = (8, 1024, 4)
        coords = (((a6+832)%1024),(a2+((s1+((r1+5)//8)+1)//2)+(-4)))
        load = get_load_image_uop(image_shape, gate, coords)
        # here the valid is expected to be kept unchanged
        self.check(load,
                   "((((idx2*2)+r0)<11)&((((idx1*8)+r1)<3)!=True))",
                   "(((idx0+((idx1*512)+(r1*64)))+832)%1024)",
                   "((((idx2*2)+r0)+(((idx1+((r1+5)//8))+1)//2))+-4)")

    def test_simplify1(self):
        # idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)
        g = Special("gidx", 512)
        gate = (g<488) & (g<480).ne(True)
        flat = g*3+18
        coords = (flat%26, flat//26-56)
        load = get_load_image_uop((1, 26, 4), gate, coords)
        self.check(load, None, "((gidx*3)+-1438)", "0")

    def test_simplify2(self):
        # from CL=1 DEBUG=4 FORWARD_ONLY=1 IMAGE=2 python3 test/test_ops.py TestOps.test_simple_padding_conv2d
        l = Special("lidx", 4)
        gate = (l<3) & (l<1).ne(True)
        flat = l+1
        coords = (flat%2, flat//2-1)
        load = get_load_image_uop((1, 2, 4), gate, coords)
        self.check(load, None, "(lidx+-1)", "0")

    def test_simplify3(self):
        # from openpilot
        s0 = Special("idx0", 265)
        gate = (s0<201).ne(True)
        flat = s0+55
        coords = (flat%64, flat//64-4)
        load = get_load_image_uop((1, 64, 4), gate, coords)
        self.check(load, None, "(idx0+-201)", "0")

    def test_simplify4(self):
        s0 = Special("idx0", 512)
        image_shape = (4, 64, 4)
        a2 = ((s0*4+1)%32)
        a3 = ((s0*4+2)%32)
        a4 = ((s0*4+3)%32)
        a5 = (s0*4%32)
        a8 = (s0//8%32//4)
        gate = s0<256
        # TODO: can this be simplified further?
        # each case only differs in the term fed into the index and in the expected first component
        cases = (
            (a2, "(((((idx0%8)*32)+(idx0//32))+8)%64)"),
            (a3, "(((((idx0%8)*32)+(idx0//32))+16)%64)"),
            (a4, "(((((idx0%8)*32)+(idx0//32))+24)%64)"),
            (a5, "((((idx0%8)*32)+(idx0//32))%64)"),
        )
        for term, expected_first in cases:
            load = get_load_image_uop(image_shape, gate, (((a8+(term*8))%64),(term//8)))
            self.check(load, "(idx0<256)", expected_first, "((idx0%8)//2)")

    def test_simplify5(self):
        # openpilot 0.9.7, chunk replacement to simplify
        image_shape = (10, 384, 4)
        s0 = Special("idx0", 16)
        s1 = Special("idx1", 24)
        a0 = s0*4
        a1 = (s1*256)+a0
        a2 = s1//3
        a3 = ((a1+1)%768)
        coords = ((s0+((((a3//640)+a2)%8)*16)+128),((a3//64)%10))
        gate = a3<640
        load = get_load_image_uop(image_shape, gate, coords)
        self.check(load, "(((idx0+(idx1*64))%192)<160)", "((idx0+((idx1//3)*16))+128)", "(((idx0+(idx1*64))%192)//16)")

    def test_simplify6(self):
        # from openpilot
        # the valid implies the numerator of the div/mod is positive and can be simplified with floordiv rules
        s1 = Special("idx1", 16)
        s2 = Special("idx2", 64)
        r3 = Range(3, 3)
        r4 = Range(4, 3)
        r5 = Range(5, 3)
        # shared numerator of the mod (first component) and the div (second component)
        numerator = (s2*1536)+(r4*768)+r3+(s1*24)+(r5*3)+-771
        first = numerator%768
        second = numerator//768
        gate = (((s2+r4)<1)!=1)&(((s1+r5)<1)!=1)
        load = get_load_image_uop((128, 768, 4), gate, (first, second))
        self.check(load, None, "((((idx1*24)+r3)+(r5*3))+-3)", "(((idx2*2)+r4)+-1)")
# Run every test case in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()