openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

329 lines
12 KiB

1 month ago
import unittest, itertools
from typing import Tuple
from tinygrad.codegen.uopgraph import full_graph_rewrite, is_increasing
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops, simplify_valid
def get_gated_load_uop(valid:UOp, idx:UOp):
  """Build a float LOAD UOp from global buffer 0, indexed by `idx` with `valid` as the gate.

  The LOAD has two sources: the INDEX of the DEFINE_GLOBAL buffer (built via `.index(idx, valid)`)
  and a float constant 0.0 (presumably the value used when the gate is false -- confirm against UOp semantics).
  """
  return UOp(Ops.LOAD, dtypes.float, (
    UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0).index(idx, valid),
    UOp.const(dtypes.float, 0.0)
  ))
def get_load_image_uop(image_shape:Tuple[int, ...], valid:UOp, idx:Tuple[UOp, UOp]):
  """Build a float4 LOAD UOp from an image buffer of `image_shape`, gated by `valid`.

  `idx` is a pair of int UOps; it is packed into an int2 VECTORIZE and used to index the
  DEFINE_GLOBAL image buffer. The second LOAD source is a float4 VECTORIZE of four 0.0 constants
  (presumably the value used when the gate is false -- confirm against UOp semantics).
  """
  return UOp(Ops.LOAD, dtypes.float.vec(4), (
    UOp(Ops.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0).index(UOp(Ops.VECTORIZE, dtypes.int.vec(2), idx), valid),
    UOp(Ops.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
  ))
def Special(expr, nmax):
  """Return an int SPECIAL UOp with no sources and arg `(expr, nmax)`."""
  special_arg = (expr, nmax)
  return UOp(Ops.SPECIAL, dtypes.int, src=(), arg=special_arg)
def Variable(expr, nmin, nmax):
  """Thin wrapper around `UOp.variable(expr, nmin, nmax)`."""
  var = UOp.variable(expr, nmin, nmax)
  return var
def Range(n, nmax):
  """Return an int RANGE UOp with arg `n` whose sources are the constants 0 and `nmax`."""
  start = UOp.const(dtypes.int, 0)
  stop = UOp.const(dtypes.int, nmax)
  return UOp(Ops.RANGE, dtypes.int, arg=n, src=(start, stop,))
class TestHelpers(unittest.TestCase):
  """Tests for helper predicates imported from tinygrad.codegen.uopgraph."""

  def test_is_increasing(self):
    """Check `is_increasing` on several index expressions and on a RANGE."""
    idx1 = Special("idx1", 32)
    idx2 = Special("idx2", 64)
    ridx0 = Variable("ridx0", 0, 5)
    ridx1 = Variable("ridx1", 0, 2)
    ridx2 = Variable("ridx2", 0, 2)
    # f1 and f3 below form the pair: (ridx0+(idx1*48)+(ridx2*6)+(-6)),((idx2*2)+ridx1+(-1)))
    f0 = ((idx1*24)+(ridx2*3)+ridx0+765)%768
    f1 = ridx0+(idx1*48)+(ridx2*6)+(-6)
    f2 = (idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)
    f3 = (idx2*2)+ridx1+(-1)

    # f0 has an outer modulo, so it is expected not to be increasing
    self.assertFalse(is_increasing(f0))
    self.assertTrue(is_increasing(f1))
    self.assertTrue(is_increasing(f2))
    self.assertTrue(is_increasing(f3))

    # a RANGE, and a RANGE plus a constant, are expected to be increasing
    rng = UOp(Ops.RANGE, dtypes.int, arg=(2, True), src=(UOp(Ops.CONST, dtypes.int, arg=0, src=()), UOp(Ops.CONST, dtypes.int, arg=5, src=()),))
    self.assertTrue(is_increasing(rng))
    self.assertTrue(is_increasing(rng+2))
class TestValidIdxSimplification(unittest.TestCase):
  """Tests that idx and valid of a gated buffer load are simplified against each other."""

  def check(self, load, sidx, svalid):
    """Run `full_graph_rewrite` on `load` and compare the rendered idx and valid to the expected strings.

    After the rewrite, the idx is read from `load.src[0].src[1]` and the valid from `load.src[2]`.
    Both are rendered with `simplify=False` so the comparison sees the rewritten graph as-is.
    """
    load = full_graph_rewrite(load.sink()).src[0]
    idx, valid = load.src[0].src[1], load.src[2]
    self.assertEqual(idx.render(simplify=False), sidx)
    self.assertEqual(valid.render(simplify=False), svalid)

  def test_cumsum(self):
    """With gate `!(gidx0*4+lidx0<19)`, the idx `gidx0*4+lidx0-19` is expected to become "0"."""
    gidx0 = Special("gidx0", 5)
    lidx0 = Special("lidx0", 4)
    gate = (gidx0*4+lidx0<19).ne(True)
    idx = gidx0*4+lidx0-19
    load = get_gated_load_uop(gate, idx)
    self.check(load,
      "0",
      "(((lidx0+(gidx0*4))<19)!=True)")

  def test_simplify_within_valid1(self):
    """The first clause of the valid is expected to simplify the second clause (the `//8` term disappears)."""
    ridx0 = Range(0, 4)
    ridx1 = Range(1, 4)
    ridx2 = Range(2, 4)
    ridx3 = Range(3, 4)
    valid = ((ridx0*3+ridx1)<8) & ((((ridx0*3+ridx1)//8+ridx2*3+ridx3)%4)<2)
    idx = ridx0+ridx1+ridx2+ridx3
    load = get_gated_load_uop(valid, idx)
    self.check(load,
      "(((ridx0+ridx1)+ridx2)+ridx3)",
      "((((ridx0*3)+ridx1)<8)&((((ridx2*3)+ridx3)%4)<2))")

  def test_simplify_within_valid2(self):
    """`simplify_valid` is expected to return None for this valid."""
    gidx0 = Special("gidx0", 56)
    ridx0 = Range(0, 3)
    alu0 = gidx0+ridx0
    valid = (alu0 < 57) & (alu0 >= 1)
    self.assertIsNone(simplify_valid(valid))

  def test_valid_order_matters1(self):
    """Both orderings of the two clauses are expected to simplify to the same result."""
    ridx0 = Range(0, 2)
    v0 = ridx0<1
    v1 = ((ridx0*5+1)%6)<5
    self.assertEqual(simplify_valid(v0&v1).render(), "(ridx0<1)")
    self.assertEqual(simplify_valid(v1&v0).render(), "(ridx0<1)")

  def test_valid_order_matters2(self):
    """Every permutation of the four clauses is expected to simplify to "False"."""
    gidx0 = Special("gidx0", 13)
    gidx1 = Special("gidx1", 13)
    ridx0 = Range(0, 4)
    alu0 = (gidx1+(ridx0*13))
    v0 = (gidx0+11)%14<11
    v1 = (alu0+((gidx0+39)//42))%14<11
    v2 = gidx0<3
    v3 = alu0<42
    for v in itertools.permutations([v0,v1,v2,v3]):
      self.assertEqual(simplify_valid(v[0]&v[1]&v[2]&v[3]).render(), "False")

  @unittest.expectedFailure  # TODO: fix
  def test_from_merge_views(self):
    """Idx/valid taken from a merged-views case; currently marked as an expected failure."""
    # taken from test_merges_from_fuzzer1
    # generated by
    # v0 = View(shape=(2, 4), strides=(2, 1), offset=-2, mask=((0, 2), (2, 4)), contiguous=False)
    # v1 = View(shape=(2, 4, 2, 2), strides=(4, 0, -2, -1), offset=3, mask=None, contiguous=False)
    # s = ShapeTracker((v0, v1))
    # idx, valid = s.to_indexed_uops()
    # print(f"{idx.render()=}")
    # print(f"{valid.render()=}")
    # s = ShapeTracker((View(shape=(2, 4, 2, 2), strides=(2, 0, 0, -1), offset=1, mask=((0, 2), (0, 4), (0, 1), (0, 2)), contiguous=False),))
    # idx, valid = s.to_indexed_uops()
    # print(f"{idx.render()=}")
    # print(f"{valid.render()=}")
    ridx0 = Range(0, 2)
    ridx2 = Range(2, 2)
    ridx3 = Range(3, 2)
    idx = (((ridx0*2)+((((ridx2*2)+(ridx3*3))+3)%4))+-2)
    valid = ((((((ridx2*2)+(ridx3*3))+3)%4)<2)!=True) # noqa: E712
    load = get_gated_load_uop(valid, idx)
    self.check(load,
      "(((ridx0*2)+(ridx3*-1))+1)",
      "(ridx2<1)")
class TestImageSimplification(unittest.TestCase):
  """Tests that the 2-component idx and the valid of a gated image load are simplified."""

  def check(self, load, svalid, sidx0, sidx1):
    """Run `full_graph_rewrite` on `load` and compare the rendered idx components (and optionally the valid).

    The idx is read from `load.src[0].src[1]` and must be a 2-source VECTORIZE; each component is
    rendered with `simplify=False` and compared to `sidx0` / `sidx1`.
    If `svalid` is not None, `load.src[2]` is rendered and compared to it. If `svalid` is None the
    valid is not inspected at all (this helper does not assert that the valid was removed).
    """
    load = full_graph_rewrite(load.sink()).src[0]
    idx = load.src[0].src[1]
    self.assertEqual(idx.op, Ops.VECTORIZE)
    self.assertEqual(len(idx.src), 2)
    idx0, idx1 = idx.src[0], idx.src[1]
    self.assertEqual(idx0.render(simplify=False), sidx0)
    self.assertEqual(idx1.render(simplify=False), sidx1)
    if svalid is not None:
      self.assertEqual(load.src[2].render(simplify=False), svalid)

  def test_idx_gt_c(self):
    """Valids of the form `!(idx1 < c+1)` combined with an idx of `idx1` minus a constant."""
    # (idx1 < c+1).ne(True) ? (..., idx1-1+c) : 0 can drop the valid
    # (idx1 < c+1).ne(True) -> idx1 > c
    gidx0 = Special("gidx0", 32)
    gidx1 = Special("gidx1", 32)
    shape = (10, 10, 4)
    load = get_load_image_uop(shape, (gidx1<1).ne(True), (gidx0, gidx1-1))
    self.check(load, None, "gidx0", "(gidx1+-1)")
    load = get_load_image_uop(shape, (gidx1<1).ne(True), (gidx0, gidx1-2))
    self.check(load, None, "gidx0", "(gidx1+-2)")

    # should match any one of the AND clause and drop the matched statement from valid
    valid = (gidx0<1).ne(True) & (gidx1<1).ne(True)
    load = get_load_image_uop(shape, valid, (gidx0+1, gidx1-1))
    self.check(load, "((gidx0<1)!=True)", "(gidx0+1)", "(gidx1+-1)")

    valid = (gidx0<1).ne(True) & (gidx1<1).ne(True)
    load = get_load_image_uop(shape, valid, (gidx0, gidx1-1))
    self.check(load, None, "gidx0", "(gidx1+-1)")

  def test_idx_lt_bound(self):
    """Valids of the form `idx1 < bound` compared against the image shape."""
    # (idx1 < image_bound) ? (..., idx1) : 0 can drop the valid
    gidx0 = Special("gidx0", 32)
    gidx1 = Special("gidx1", 32)
    load = get_load_image_uop((10, 10, 4), gidx1<10, (gidx0, gidx1))
    self.check(load, None, "gidx0", "gidx1")

    # same thing, valid has a div
    load = get_load_image_uop((10, 10, 4), gidx1//2<5, (gidx0, gidx1))
    self.check(load, None, "gidx0", "gidx1")

    # 10x20 image, not out of bound
    load = get_load_image_uop((20, 10, 4), gidx1<10, (gidx0, gidx1))
    self.check(load, "(gidx1<10)", "gidx0", "gidx1")

  def test_generic_idx_lt_bound(self):
    """Valids of the form `idx1 < bound - c` with an idx of `idx1 + c`."""
    # (idx1 < image_bound - c) ? (..., idx1 + c) : 0 can drop the valid
    gidx0 = Special("gidx0", 32)
    gidx1 = Special("gidx1", 32)
    shape = (10, 10, 4)
    load = get_load_image_uop(shape, (gidx1<8), (gidx0, gidx1+2))
    self.check(load, None, "gidx0", "(gidx1+2)")

    load = get_load_image_uop(shape, (gidx1<5), (gidx0, gidx1+5))
    self.check(load, None, "gidx0", "(gidx1+5)")

  def test_valid_empty_set(self):
    """A contradictory valid is expected to turn the rewritten load into a 4-wide VECTORIZE."""
    gidx0 = Special("gidx0", 32)
    gidx1 = Special("gidx1", 32)
    shape = (32, 32, 4)
    idx = (gidx0%2, gidx1+2)
    # not empty
    load = get_load_image_uop(shape, gidx0<8, idx)
    self.check(load, "(gidx0<8)", "(gidx0%2)", "(gidx1+2)")

    # empty -> invalid
    load = get_load_image_uop(shape, (gidx0<8) & (gidx0<8).ne(True), idx)
    load = full_graph_rewrite(load.sink()).src[0]
    self.assertEqual(load.op, Ops.VECTORIZE)
    self.assertEqual(load.dtype.count, 4)

  def test_openpilot_conv1(self):
    """Index simplification for the first conv in openpilot."""
    # first conv in openpilot
    # kernel in tinygrad ae5d1407ee844a97a52ad3756835d38e7e2b9e1b https://gist.github.com/chenyuxyz/39c2d4e9a076b46731c67d345ff066b6
    idx1 = Special("idx1", 32)
    idx2 = Special("idx2", 64)
    # ridx0 = Variable("ridx0", 0, 5)
    # ridx1 = Variable("ridx1", 0, 2)
    # ridx2 = Variable("ridx2", 0, 2)
    ridx0 = Range(0, 6)
    ridx1 = Range(1, 3)
    ridx2 = Range(2, 3)

    alu1 = ((idx2*2)+ridx1)
    alu4 = ((idx1*48)+(ridx2*6)+ridx0)

    valid = ((((idx2*2)+(ridx1))<1).ne(True))&((((idx1*8)+(ridx2))<1).ne(True))
    shape = (128, 1536, 4)
    idx = ((alu4+1530)%1536, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))

    load = get_load_image_uop(shape, valid, idx)

    self.check(load, None, "((((idx1*48)+(ridx2*6))+ridx0)+-6)", "(((idx2*2)+ridx1)+-1)")

  def test_openpilot_conv2(self):
    """Index simplification for the conv in test/external/external_test_valid_remove.py."""
    # conv in test/external/external_test_valid_remove.py
    idx1 = Special("idx1", 32)
    idx2 = Special("idx2", 64)
    # ridx0 = Variable("ridx0", 0, 2)
    # ridx1 = Variable("ridx1", 0, 2)
    # ridx2 = Variable("ridx2", 0, 2)
    ridx0 = Range(0, 3)
    ridx1 = Range(1, 3)
    ridx2 = Range(2, 3)

    alu1 = ((idx2*2)+ridx1)
    alu3 = ((idx1*24)+(ridx2*3)+ridx0)

    valid = ((((idx2*2)+ridx1)<1).ne(True))&((((idx1*8)+ridx2)<1).ne(True))
    shape = (128, 768, 4)
    idx = ((alu3+765)%768, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))
    load = get_load_image_uop(shape, valid, idx)

    self.check(load, None, "((((idx1*24)+(ridx2*3))+ridx0)+-3)", "(((idx2*2)+ridx1)+-1)")

  def test_openpilot_conv3(self):
    """A conv from openpilot 0.9.7 where the valid is expected to be kept."""
    # in openpilot 0.9.7
    idx0 = Special("idx0", 64)
    idx1 = Special("idx1", 2)
    idx2 = Special("idx2", 4)
    ridx0 = Range(0, 7)
    ridx1 = Range(1, 7)

    alu2 = ((idx2*2)+ridx0)
    alu4 = ((idx1*8)+ridx1)
    alu6 = ((idx1*512)+(ridx1*64)+idx0)

    valid = (alu2<11)&(alu4<3).ne(True)
    shape = (8, 1024, 4)
    idx = (((alu6+832)%1024),(alu2+((idx1+((ridx1+5)//8)+1)//2)+(-4)))

    load = get_load_image_uop(shape, valid, idx)

    self.check(load,
      "((((idx2*2)+ridx0)<11)&((((idx1*8)+ridx1)<3)!=True))",
      "(((idx0+((idx1*512)+(ridx1*64)))+832)%1024)",
      "((((idx2*2)+ridx0)+(((idx1+((ridx1+5)//8))+1)//2))+-4)")

  def test_simplify1(self):
    """idx of the form (A % m, A // m + k) under a two-sided bound on gidx."""
    # idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)
    gidx = Special("gidx", 512)
    valid = (gidx<488) & (gidx<480).ne(True)
    idx = ((gidx*3+18)%26, (gidx*3+18)//26-56)
    load = get_load_image_uop((1, 26, 4), valid, idx)
    self.check(load, None, "((gidx*3)+-1438)", "0")

  def test_simplify2(self):
    """Same (A % m, A // m + k) form, taken from a padded conv2d test."""
    # from GPU=1 DEBUG=4 FORWARD_ONLY=1 IMAGE=2 python3 test/test_ops.py TestOps.test_simple_padding_conv2d
    lidx = Special("lidx", 4)
    valid = (lidx<3) & (lidx<1).ne(True)
    idx = ((lidx+1)%2, (lidx+1)//2-1)
    load = get_load_image_uop((1, 2, 4), valid, idx)
    self.check(load, None, "(lidx+-1)", "0")

  def test_simplify3(self):
    """Same (A % m, A // m + k) form with a one-sided lower bound, from openpilot."""
    # from openpilot
    idx0 = Special("idx0", 265)
    valid = (idx0<201).ne(True)
    idx = ((idx0+55)%64, (idx0+55)//64-4)
    load = get_load_image_uop((1, 64, 4), valid, idx)
    self.check(load, None, "(idx0+-201)", "0")

  def test_simplify4(self):
    """Four loads that differ only in the constant offset inside the `% 32` term."""
    idx0 = Special("idx0", 512)
    shape = (4, 64, 4)

    alu2 = ((idx0*4+1)%32)
    alu3 = ((idx0*4+2)%32)
    alu4 = ((idx0*4+3)%32)
    alu5 = (idx0*4%32)
    alu8 = (idx0//8%32//4)
    alu9 = idx0<256

    # TODO: can this be simplified further?
    load = get_load_image_uop(shape, alu9, (((alu8+(alu2*8))%64),(alu2//8)))
    self.check(load, "(idx0<256)", "(((((idx0%8)*32)+(idx0//32))+8)%64)", "((idx0%8)//2)")

    load = get_load_image_uop(shape, alu9, (((alu8+(alu3*8))%64),(alu3//8)))
    self.check(load, "(idx0<256)", "(((((idx0%8)*32)+(idx0//32))+16)%64)", "((idx0%8)//2)")

    load = get_load_image_uop(shape, alu9, (((alu8+(alu4*8))%64),(alu4//8)))
    self.check(load, "(idx0<256)", "(((((idx0%8)*32)+(idx0//32))+24)%64)", "((idx0%8)//2)")

    load = get_load_image_uop(shape, alu9, (((alu8+(alu5*8))%64),(alu5//8)))
    self.check(load, "(idx0<256)", "((((idx0%8)*32)+(idx0//32))%64)", "((idx0%8)//2)")

  def test_simplify5(self):
    """A case from openpilot 0.9.7 where both the valid and the idx are expected to be rewritten."""
    # openpilot 0.9.7, chunk replacement to simplify
    shape = (10, 384, 4)
    idx0 = Special("idx0", 16)
    idx1 = Special("idx1", 24)
    alu0 = idx0*4
    alu1 = (idx1*256)+alu0
    alu2 = idx1//3
    alu3 = ((alu1+1)%768)
    idx = ((idx0+((((alu3//640)+alu2)%8)*16)+128),((alu3//64)%10))
    valid = alu3<640
    load = get_load_image_uop(shape, valid, idx)
    self.check(load, "(((idx0+(idx1*64))%192)<160)", "((idx0+((idx1//3)*16))+128)", "(((idx0+(idx1*64))%192)//16)")
# Run every TestCase in this module when the file is executed as a script.
if __name__ == '__main__':
  unittest.main()