import unittest , random
from tinygrad . dtype import dtypes
from tinygrad . uop . ops import print_uops , UOp , Ops
from tinygrad . codegen . linearize import block_reorder
from tinygrad . renderer . cstyle import OpenCLRenderer
def is_toposorted(lst: list[UOp]):
  """Return True iff every uop in `lst` comes after all of its sources (`u.src`).

  A source that never appears earlier in `lst` (including one missing from `lst`
  entirely) makes the result False.
  """
  # uops already walked past; a node is only valid if all its sources are in here
  visited = set()
  for node in lst:
    if not all(src in visited for src in node.src):
      return False
    visited.add(node)
  return True
class TestBlockReorder(unittest.TestCase):
  """Tests for block_reorder: its output must not depend on the input toposort order, and loads precede math."""

  def _test_randomize(self, golden: list[UOp]):
    """Assert that block_reorder maps shuffled-but-valid toposorts of `golden` back to `golden`.

    For 50 trials: shuffle `golden`, rebuild a topological order from the shuffled uops,
    run block_reorder on it, and require the result to match `golden` element by element
    (compared by identity) with the same length. On any mismatch both orders are printed
    via print_uops before the assertion fails.
    """
    for _ in range(50):
      # shuffle and form a valid toposort; relies on u.toposort() listing a uop's sources
      # before the uop itself, which is_toposorted verifies below
      lst = golden[:]
      random.shuffle(lst)
      topolst, seen = [], set()
      for u in lst:
        for p in u.toposort():
          if p not in seen:
            seen.add(p)
            topolst.append(p)
      # self.assertTrue instead of a bare assert, which `python -O` strips
      self.assertTrue(is_toposorted(topolst))
      this_order = block_reorder(topolst)
      # zip() stops at the shorter list, so without the explicit length check a dropped
      # or duplicated uop in the output would go unnoticed
      if len(this_order) != len(golden) or any(x is not y for x, y in zip(golden, this_order)):
        print_uops(golden)
        print_uops(this_order)
      self.assertEqual(len(this_order), len(golden))
      for x, y in zip(golden, this_order): self.assertIs(x, y)

  def _test_render(self, golden: list[UOp]):
    """Render the uop list `golden` with OpenCLRenderer and return the rendered result."""
    return OpenCLRenderer().render(golden)

  def test_loads(self):
    """Sum eight float loads from two buffers into a store; check load order, loads-before-math, and stability."""
    a = UOp(Ops.DEFINE_GLOBAL, dtype=dtypes.float.ptr(), arg=0)
    b = UOp(Ops.DEFINE_GLOBAL, dtype=dtypes.float.ptr(), arg=1)
    c = UOp(Ops.DEFINE_GLOBAL, dtype=dtypes.float.ptr(), arg=2)
    v1 = UOp(Ops.SPECIAL, dtype=dtypes.int, arg=("gidx0", 4)) * 27
    v2 = UOp(Ops.SPECIAL, dtype=dtypes.int, arg=("gidx1", 4)) * 4
    loads = [
      a.index(v1).load(dtype=dtypes.float),
      a.index(v1+1).load(dtype=dtypes.float),
      a.index(v1+2).load(dtype=dtypes.float),
      a.index(v1+3).load(dtype=dtypes.float),
      b.index(v2).load(dtype=dtypes.float),
      b.index(v2+1).load(dtype=dtypes.float),
      b.index(v2+2).load(dtype=dtypes.float),
      b.index(v2+3).load(dtype=dtypes.float)]
    sink = c.store(sum(loads)).sink()
    # determine golden order
    golden = block_reorder(list(sink.toposort()))
    # render for test (at minimum this exercises the renderer on the reordered uops)
    print(self._test_render(golden))
    # assert the loads are in this order: all b (gidx1) loads, then all a (gidx0) loads.
    # g.src[0] is the uop built by .index(); its src[1] is the offset expression being rendered
    self.assertListEqual([g.src[0].src[1].render() for g in golden if g.op is Ops.LOAD],
                         ['(gidx1*4)', '((gidx1*4)+1)', '((gidx1*4)+2)', '((gidx1*4)+3)',
                          '(gidx0*27)', '((gidx0*27)+1)', '((gidx0*27)+2)', '((gidx0*27)+3)'])
    # assert math is after loads: nothing from the first float ADD onward may be a LOAD
    float_adds = [i for i, g in enumerate(golden) if g.op is Ops.ADD and g.dtype == dtypes.float]
    self.assertTrue(float_adds, "expected at least one float ADD in the reordered uops")
    self.assertFalse(any(x.op is Ops.LOAD for x in golden[float_adds[0]:]))
    # confirm the sort is stable
    self._test_randomize(golden)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__": unittest.main()