import unittest , itertools
from tinygrad . codegen . devectorizer import full_graph_rewrite
from tinygrad . dtype import dtypes
from tinygrad . ops import UOp , Ops
from tinygrad . codegen . symbolic import simplify_valid
def get_gated_load_uop ( valid : UOp , idx : UOp ) :
return UOp ( Ops . LOAD , dtypes . float , (
UOp ( Ops . DEFINE_GLOBAL , dtypes . float . ptr ( ) , arg = 0 ) . index ( idx , valid ) ,
UOp . const ( dtypes . float , 0.0 )
) )
def get_load_image_uop ( image_shape : tuple [ int , . . . ] , valid : UOp , idx : tuple [ UOp , UOp ] ) :
return UOp ( Ops . LOAD , dtypes . float . vec ( 4 ) , (
UOp ( Ops . DEFINE_GLOBAL , dtypes . imagef ( image_shape ) , arg = 0 ) . index ( UOp ( Ops . VECTORIZE , dtypes . int . vec ( 2 ) , idx ) , valid ) ,
UOp ( Ops . VECTORIZE , dtypes . float . vec ( 4 ) , src = ( UOp . const ( dtypes . float , 0.0 ) , ) * 4 )
) )
def Special ( expr , nmax ) : return UOp ( Ops . SPECIAL , dtypes . int , ( ) , ( expr , nmax ) )
def Variable ( expr , nmin , nmax ) : return UOp . variable ( expr , nmin , nmax )
def Range ( n , nmax ) : return UOp ( Ops . RANGE , dtypes . int , arg = n , src = ( UOp . const ( dtypes . int , 0 ) , UOp . const ( dtypes . int , nmax ) , ) )
class TestHelpers ( unittest . TestCase ) :
def test_is_increasing ( self ) :
idx1 = Special ( " idx1 " , 32 )
idx2 = Special ( " idx2 " , 64 )
ridx0 = Variable ( " ridx0 " , 0 , 5 )
ridx1 = Variable ( " ridx1 " , 0 , 2 )
ridx2 = Variable ( " ridx2 " , 0 , 2 )
# (ridx0+(idx1*48)+(ridx2*6)+(-6)),((idx2*2)+ridx1+(-1)))
f0 = ( ( idx1 * 24 ) + ( ridx2 * 3 ) + ridx0 + 765 ) % 768
f1 = ridx0 + ( idx1 * 48 ) + ( ridx2 * 6 ) + ( - 6 )
f2 = ( idx2 * 2 ) + ridx1 + ( ( idx1 + ( ( ridx2 + 7 ) / / 8 ) + 31 ) / / 32 ) + ( - 2 )
f3 = ( idx2 * 2 ) + ridx1 + ( - 1 )
self . assertFalse ( f0 . is_increasing ( ) )
self . assertTrue ( f1 . is_increasing ( ) )
self . assertTrue ( f2 . is_increasing ( ) )
self . assertTrue ( f3 . is_increasing ( ) )
rng = UOp ( Ops . RANGE , dtypes . int , arg = ( 2 , True ) , src = ( UOp ( Ops . CONST , dtypes . int , arg = 0 , src = ( ) ) , UOp ( Ops . CONST , dtypes . int , arg = 5 , src = ( ) ) , ) )
self . assertTrue ( rng . is_increasing ( ) )
self . assertTrue ( ( rng + 2 ) . is_increasing ( ) )
class TestValidIdxSimplification ( unittest . TestCase ) :
def check ( self , load , sidx , svalid ) :
load = full_graph_rewrite ( load . sink ( ) ) . src [ 0 ]
idx , valid = load . src [ 0 ] . src [ 1 ] , load . src [ 0 ] . src [ 2 ]
self . assertEqual ( idx . render ( simplify = False ) , sidx )
self . assertEqual ( valid . render ( simplify = False ) , svalid )
def test_cumsum ( self ) :
gidx0 = Special ( " gidx0 " , 5 )
lidx0 = Special ( " lidx0 " , 4 )
gate = ( gidx0 * 4 + lidx0 < 19 ) . ne ( True )
idx = gidx0 * 4 + lidx0 - 19
load = get_gated_load_uop ( gate , idx )
self . check ( load ,
" 0 " ,
" (((lidx0+(gidx0*4))<19)!=True) " )
def test_simplify_within_valid1 ( self ) :
ridx0 = Range ( 0 , 4 )
ridx1 = Range ( 1 , 4 )
ridx2 = Range ( 2 , 4 )
ridx3 = Range ( 3 , 4 )
valid = ( ( ridx0 * 3 + ridx1 ) < 8 ) & ( ( ( ( ridx0 * 3 + ridx1 ) / / 8 + ridx2 * 3 + ridx3 ) % 4 ) < 2 )
idx = ridx0 + ridx1 + ridx2 + ridx3
load = get_gated_load_uop ( valid , idx )
self . check ( load ,
" (((ridx0+ridx1)+ridx2)+ridx3) " ,
" ((((ridx0*3)+ridx1)<8)&((((ridx2*3)+ridx3) % 4)<2)) " )
def test_simplify_within_valid2 ( self ) :
gidx0 = Special ( " gidx0 " , 56 )
ridx0 = Range ( 0 , 3 )
alu0 = gidx0 + ridx0
valid = ( alu0 < 57 ) & ( alu0 > = 1 )
self . assertIsNone ( simplify_valid ( valid ) )
def test_valid_order_matters1 ( self ) :
ridx0 = Range ( 0 , 2 )
v0 = ridx0 < 1
v1 = ( ( ridx0 * 5 + 1 ) % 6 ) < 5
self . assertEqual ( simplify_valid ( v0 & v1 ) . render ( ) , " (ridx0<1) " )
self . assertEqual ( simplify_valid ( v1 & v0 ) . render ( ) , " (ridx0<1) " )
def test_valid_order_matters2 ( self ) :
gidx0 = Special ( " gidx0 " , 13 )
gidx1 = Special ( " gidx1 " , 13 )
ridx0 = Range ( 0 , 4 )
alu0 = ( gidx1 + ( ridx0 * 13 ) )
v0 = ( gidx0 + 11 ) % 14 < 11
v1 = ( alu0 + ( ( gidx0 + 39 ) / / 42 ) ) % 14 < 11
v2 = gidx0 < 3
v3 = alu0 < 42
for v in itertools . permutations ( [ v0 , v1 , v2 , v3 ] ) :
self . assertEqual ( simplify_valid ( v [ 0 ] & v [ 1 ] & v [ 2 ] & v [ 3 ] ) . render ( ) , " False " )
@unittest . expectedFailure # TODO: fix
def test_from_merge_views ( self ) :
# taken from test_merges_from_fuzzer1
# generated by
# v0 = View(shape=(2, 4), strides=(2, 1), offset=-2, mask=((0, 2), (2, 4)), contiguous=False)
# v1 = View(shape=(2, 4, 2, 2), strides=(4, 0, -2, -1), offset=3, mask=None, contiguous=False)
# s = ShapeTracker((v0, v1))
# idx, valid = s.to_indexed_uops()
# print(f"{idx.render()=}")
# print(f"{valid.render()=}")
# s = ShapeTracker((View(shape=(2, 4, 2, 2), strides=(2, 0, 0, -1), offset=1, mask=((0, 2), (0, 4), (0, 1), (0, 2)), contiguous=False),))
# idx, valid = s.to_indexed_uops()
# print(f"{idx.render()=}")
# print(f"{valid.render()=}")
ridx0 = Range ( 0 , 2 )
ridx2 = Range ( 2 , 2 )
ridx3 = Range ( 3 , 2 )
idx = ( ( ( ridx0 * 2 ) + ( ( ( ( ridx2 * 2 ) + ( ridx3 * 3 ) ) + 3 ) % 4 ) ) + - 2 )
valid = ( ( ( ( ( ( ridx2 * 2 ) + ( ridx3 * 3 ) ) + 3 ) % 4 ) < 2 ) != True ) # noqa: E712
load = get_gated_load_uop ( valid , idx )
self . check ( load ,
" (((ridx0*2)+(ridx3*-1))+1) " ,
" (ridx2<1) " )
def test_valid_becomes_const1 ( self ) :
# from DSP mobilenetv2
ridx0 = Range ( 0 , 30 )
ridx1 = Range ( 1 , 7 )
ridx2 = Range ( 2 , 2 )
alu11 = ( ridx1 + ridx2 )
alu15 = ( ( alu11 + 1 ) / / 7 )
idx = ( alu15 * - 31 ) + ( ( ( ( ( alu11 + 218 ) / / 224 ) + ridx0 ) % 30 ) * 1568 )
valid = ( ridx2 < 1 ) & ( ridx1 < 6 )
load = get_gated_load_uop ( valid , idx )
self . check ( load ,
" (ridx0*1568) " ,
" ((ridx2<1)&(ridx1<6)) " )
class TestImageSimplification ( unittest . TestCase ) :
def check ( self , load , svalid , sidx0 , sidx1 ) :
load = full_graph_rewrite ( load . sink ( ) ) . src [ 0 ]
idx = load . src [ 0 ] . src [ 1 ]
self . assertEqual ( idx . op , Ops . VECTORIZE )
self . assertEqual ( len ( idx . src ) , 2 )
idx0 , idx1 = idx . src [ 0 ] , idx . src [ 1 ]
self . assertEqual ( idx0 . render ( simplify = False ) , sidx0 )
self . assertEqual ( idx1 . render ( simplify = False ) , sidx1 )
if svalid is not None : self . assertEqual ( load . src [ 0 ] . src [ 2 ] . render ( simplify = False ) , svalid )
def test_idx_gt_c ( self ) :
# (idx1 < c+1).ne(True) ? (..., idx1-1+c) : 0 can drop the valid
# (idx1 < c+1).ne(True) -> idx > c
gidx0 = Special ( " gidx0 " , 32 )
gidx1 = Special ( " gidx1 " , 32 )
shape = ( 10 , 10 , 4 )
load = get_load_image_uop ( shape , ( gidx1 < 1 ) . ne ( True ) , ( gidx0 , gidx1 - 1 ) )
self . check ( load , None , " gidx0 " , " (gidx1+-1) " )
load = get_load_image_uop ( shape , ( gidx1 < 1 ) . ne ( True ) , ( gidx0 , gidx1 - 2 ) )
self . check ( load , None , " gidx0 " , " (gidx1+-2) " )
# should match any one of the AND clause and drop the matched statement from valid
valid = ( gidx0 < 1 ) . ne ( True ) & ( gidx1 < 1 ) . ne ( True )
load = get_load_image_uop ( shape , valid , ( gidx0 + 1 , gidx1 - 1 ) )
self . check ( load , " ((gidx0<1)!=True) " , " (gidx0+1) " , " (gidx1+-1) " )
valid = ( gidx0 < 1 ) . ne ( True ) & ( gidx1 < 1 ) . ne ( True )
load = get_load_image_uop ( shape , valid , ( gidx0 , gidx1 - 1 ) )
self . check ( load , None , " gidx0 " , " (gidx1+-1) " )
def test_idx_lt_bound ( self ) :
# (idx1 < image_bound) ? (..., idx1) : 0 can drop the valid
gidx0 = Special ( " gidx0 " , 32 )
gidx1 = Special ( " gidx1 " , 32 )
load = get_load_image_uop ( ( 10 , 10 , 4 ) , gidx1 < 10 , ( gidx0 , gidx1 ) )
self . check ( load , None , " gidx0 " , " gidx1 " )
# same thing, valid has a div
load = get_load_image_uop ( ( 10 , 10 , 4 ) , gidx1 / / 2 < 5 , ( gidx0 , gidx1 ) )
self . check ( load , None , " gidx0 " , " gidx1 " )
# 10x20 image, not out of bound
load = get_load_image_uop ( ( 20 , 10 , 4 ) , gidx1 < 10 , ( gidx0 , gidx1 ) )
self . check ( load , " (gidx1<10) " , " gidx0 " , " gidx1 " )
def test_generic_idx_lt_bound ( self ) :
# (idx1 < image_bound - c) ? (..., idx1 + c) : 0 can drop the valid
gidx0 = Special ( " gidx0 " , 32 )
gidx1 = Special ( " gidx1 " , 32 )
shape = ( 10 , 10 , 4 )
load = get_load_image_uop ( shape , ( gidx1 < 8 ) , ( gidx0 , gidx1 + 2 ) )
self . check ( load , None , " gidx0 " , " (gidx1+2) " )
load = get_load_image_uop ( shape , ( gidx1 < 5 ) , ( gidx0 , gidx1 + 5 ) )
self . check ( load , None , " gidx0 " , " (gidx1+5) " )
def test_valid_empty_set ( self ) :
gidx0 = Special ( " gidx0 " , 32 )
gidx1 = Special ( " gidx1 " , 32 )
shape = ( 32 , 32 , 4 )
idx = ( gidx0 % 2 , gidx1 + 2 )
# not empty
load = get_load_image_uop ( shape , gidx0 < 8 , idx )
self . check ( load , " (gidx0<8) " , " (gidx0 % 2) " , " (gidx1+2) " )
# empty -> invalid
load = get_load_image_uop ( shape , ( gidx0 < 8 ) & ( gidx0 < 8 ) . ne ( True ) , idx )
load = full_graph_rewrite ( load . sink ( ) ) . src [ 0 ]
self . assertEqual ( load . op , Ops . VECTORIZE )
self . assertEqual ( load . dtype . count , 4 )
def test_openpilot_conv1 ( self ) :
# first conv in openpilot
# kernel in tinygrad ae5d1407ee844a97a52ad3756835d38e7e2b9e1b https://gist.github.com/chenyuxyz/39c2d4e9a076b46731c67d345ff066b6
idx1 = Special ( " idx1 " , 32 )
idx2 = Special ( " idx2 " , 64 )
# ridx0 = Variable("ridx0", 0, 5)
# ridx1 = Variable("ridx1", 0, 2)
# ridx2 = Variable("ridx2", 0, 2)
ridx0 = Range ( 0 , 6 )
ridx1 = Range ( 1 , 3 )
ridx2 = Range ( 2 , 3 )
alu1 = ( ( idx2 * 2 ) + ridx1 )
alu4 = ( ( idx1 * 48 ) + ( ridx2 * 6 ) + ridx0 )
valid = ( ( ( ( idx2 * 2 ) + ( ridx1 ) ) < 1 ) . ne ( True ) ) & ( ( ( ( idx1 * 8 ) + ( ridx2 ) ) < 1 ) . ne ( True ) )
shape = ( 128 , 1536 , 4 )
idx = ( ( alu4 + 1530 ) % 1536 , alu1 + ( ( idx1 + ( ( ridx2 + 7 ) / / 8 ) + 31 ) / / 32 ) + ( - 2 ) )
load = get_load_image_uop ( shape , valid , idx )
self . check ( load , None , " ((((idx1*48)+(ridx2*6))+ridx0)+-6) " , " (((idx2*2)+ridx1)+-1) " )
def test_openpilot_conv2 ( self ) :
# conv in test/external/external_test_valid_remove.py
idx1 = Special ( " idx1 " , 32 )
idx2 = Special ( " idx2 " , 64 )
# ridx0 = Variable("ridx0", 0, 2)
# ridx1 = Variable("ridx1", 0, 2)
# ridx2 = Variable("ridx2", 0, 2)
ridx0 = Range ( 0 , 3 )
ridx1 = Range ( 1 , 3 )
ridx2 = Range ( 2 , 3 )
alu1 = ( ( idx2 * 2 ) + ridx1 )
alu3 = ( ( idx1 * 24 ) + ( ridx2 * 3 ) + ridx0 )
valid = ( ( ( ( idx2 * 2 ) + ridx1 ) < 1 ) . ne ( True ) ) & ( ( ( ( idx1 * 8 ) + ridx2 ) < 1 ) . ne ( True ) )
shape = ( 128 , 768 , 4 )
idx = ( ( alu3 + 765 ) % 768 , alu1 + ( ( idx1 + ( ( ridx2 + 7 ) / / 8 ) + 31 ) / / 32 ) + ( - 2 ) )
load = get_load_image_uop ( shape , valid , idx )
self . check ( load , None , " ((((idx1*24)+(ridx2*3))+ridx0)+-3) " , " (((idx2*2)+ridx1)+-1) " )
def test_openpilot_conv3 ( self ) :
# in openpilot 0.9.7
idx0 = Special ( " idx0 " , 64 )
idx1 = Special ( " idx1 " , 2 )
idx2 = Special ( " idx2 " , 4 )
ridx0 = Range ( 0 , 7 )
ridx1 = Range ( 1 , 7 )
alu2 = ( ( idx2 * 2 ) + ridx0 )
alu4 = ( ( idx1 * 8 ) + ridx1 )
alu6 = ( ( idx1 * 512 ) + ( ridx1 * 64 ) + idx0 )
valid = ( alu2 < 11 ) & ( alu4 < 3 ) . ne ( True )
shape = ( 8 , 1024 , 4 )
idx = ( ( ( alu6 + 832 ) % 1024 ) , ( alu2 + ( ( idx1 + ( ( ridx1 + 5 ) / / 8 ) + 1 ) / / 2 ) + ( - 4 ) ) )
load = get_load_image_uop ( shape , valid , idx )
self . check ( load ,
" ((((idx2*2)+ridx0)<11)&((((idx1*8)+ridx1)<3)!=True)) " ,
" (((idx0+((idx1*512)+(ridx1*64)))+832) % 1024) " ,
" (((((idx1+((ridx1+5)//8))+1)//2)+((idx2*2)+ridx0))+-4) " )
def test_simplify1 ( self ) :
# idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)
gidx = Special ( " gidx " , 512 )
valid = ( gidx < 488 ) & ( gidx < 480 ) . ne ( True )
idx = ( ( gidx * 3 + 18 ) % 26 , ( gidx * 3 + 18 ) / / 26 - 56 )
load = get_load_image_uop ( ( 1 , 26 , 4 ) , valid , idx )
self . check ( load , None , " ((gidx*3)+-1438) " , " 0 " )
def test_simplify2 ( self ) :
# from GPU=1 DEBUG=4 FORWARD_ONLY=1 IMAGE=2 python3 test/test_ops.py TestOps.test_simple_padding_conv2d
lidx = Special ( " lidx " , 4 )
valid = ( lidx < 3 ) & ( lidx < 1 ) . ne ( True )
idx = ( ( lidx + 1 ) % 2 , ( lidx + 1 ) / / 2 - 1 )
load = get_load_image_uop ( ( 1 , 2 , 4 ) , valid , idx )
self . check ( load , None , " (lidx+-1) " , " 0 " )
def test_simplify3 ( self ) :
# from openpilot
idx0 = Special ( " idx0 " , 265 )
valid = ( idx0 < 201 ) . ne ( True )
idx = ( ( idx0 + 55 ) % 64 , ( idx0 + 55 ) / / 64 - 4 )
load = get_load_image_uop ( ( 1 , 64 , 4 ) , valid , idx )
self . check ( load , None , " (idx0+-201) " , " 0 " )
def test_simplify4 ( self ) :
idx0 = Special ( " idx0 " , 512 )
shape = ( 4 , 64 , 4 )
alu2 = ( ( idx0 * 4 + 1 ) % 32 )
alu3 = ( ( idx0 * 4 + 2 ) % 32 )
alu4 = ( ( idx0 * 4 + 3 ) % 32 )
alu5 = ( idx0 * 4 % 32 )
alu8 = ( idx0 / / 8 % 32 / / 4 )
alu9 = idx0 < 256
# TODO: can this be simplified further?
load = get_load_image_uop ( shape , alu9 , ( ( ( alu8 + ( alu2 * 8 ) ) % 64 ) , ( alu2 / / 8 ) ) )
self . check ( load , " (idx0<256) " , " (((((idx0 % 8)*32)+(idx0//32))+8) % 64) " , " ((idx0 % 8)//2) " )
load = get_load_image_uop ( shape , alu9 , ( ( ( alu8 + ( alu3 * 8 ) ) % 64 ) , ( alu3 / / 8 ) ) )
self . check ( load , " (idx0<256) " , " (((((idx0 % 8)*32)+(idx0//32))+16) % 64) " , " ((idx0 % 8)//2) " )
load = get_load_image_uop ( shape , alu9 , ( ( ( alu8 + ( alu4 * 8 ) ) % 64 ) , ( alu4 / / 8 ) ) )
self . check ( load , " (idx0<256) " , " (((((idx0 % 8)*32)+(idx0//32))+24) % 64) " , " ((idx0 % 8)//2) " )
load = get_load_image_uop ( shape , alu9 , ( ( ( alu8 + ( alu5 * 8 ) ) % 64 ) , ( alu5 / / 8 ) ) )
self . check ( load , " (idx0<256) " , " ((((idx0 % 8)*32)+(idx0//32)) % 64) " , " ((idx0 % 8)//2) " )
def test_simplify5 ( self ) :
# openpilot 0.9.7, chunk replacement to simplify
shape = ( 10 , 384 , 4 )
idx0 = Special ( " idx0 " , 16 )
idx1 = Special ( " idx1 " , 24 )
alu0 = idx0 * 4
alu1 = ( idx1 * 256 ) + alu0
alu2 = idx1 / / 3
alu3 = ( ( alu1 + 1 ) % 768 )
idx = ( ( idx0 + ( ( ( ( alu3 / / 640 ) + alu2 ) % 8 ) * 16 ) + 128 ) , ( ( alu3 / / 64 ) % 10 ) )
valid = alu3 < 640
load = get_load_image_uop ( shape , valid , idx )
self . check ( load , " (((idx0+(idx1*64)) % 192)<160) " , " ((idx0+((idx1//3)*16))+128) " , " (((idx0+(idx1*64)) % 192)//16) " )
if __name__ == ' __main__ ' :
unittest . main ( )