import unittest
from tinygrad . helpers import prod
from tinygrad . shape . view import View
from tinygrad . shape . shapetracker import ShapeTracker
from tinygrad import Variable
from test . unit . test_shapetracker import shapetracker_getitem
class MultiShapeTracker :
def __init__ ( self , sts : list [ ShapeTracker ] ) : self . sts = sts
@property
def shape ( self ) : return self . sts [ 0 ] . shape
def reshape ( self , arg ) : self . sts = [ x . reshape ( arg ) for x in self . sts ]
def permute ( self , arg ) : self . sts = [ x . permute ( arg ) for x in self . sts ]
def expand ( self , arg ) : self . sts = [ x . expand ( arg ) for x in self . sts ]
def shrink ( self , arg ) : self . sts = [ x . shrink ( arg ) for x in self . sts ]
def flip ( self , arg ) : self . sts = [ x . flip ( arg ) for x in self . sts ]
def pad ( self , arg ) : self . sts = [ x . pad ( arg ) for x in self . sts ]
def st_equal ( st1 : ShapeTracker , st2 : ShapeTracker ) - > bool :
if st1 . shape != st2 . shape : return False
if st1 == st2 : return True
for i in range ( 0 , prod ( st1 . shape ) ) :
st1_off , st1_v = shapetracker_getitem ( st1 , i )
st2_off , st2_v = shapetracker_getitem ( st2 , i )
if st1_v != st2_v or ( st1_off != st2_off and st1_v ) :
print ( f " ST MISMATCH @ { i } , { st1_v =} != { st2_v =} , { st1_off =} != { st2_off =} " )
print ( st1 )
print ( st2 )
return False
return True
class TestShapeTrackerBasics ( unittest . TestCase ) :
def test_pad_shrink_removes_mask ( self ) :
a = ShapeTracker . from_shape ( ( 10 , 10 ) )
a = a . pad ( ( ( 0 , 2 ) , ( 0 , 2 ) ) )
a = a . shrink ( ( ( 0 , 10 ) , ( 0 , 10 ) ) )
assert len ( a . views ) == 1 and a . views [ - 1 ] . mask is None
def test_pad_shrink_leaves_mask ( self ) :
a = ShapeTracker . from_shape ( ( 10 , 10 ) )
a = a . pad ( ( ( 0 , 2 ) , ( 0 , 2 ) ) )
a = a . shrink ( ( ( 0 , 10 ) , ( 0 , 11 ) ) )
assert len ( a . views ) == 1 and a . views [ - 1 ] . mask is not None
def test_reshape_makes_same ( self ) :
a = ShapeTracker . from_shape ( ( 2 , 5 ) )
x = a . pad ( ( ( 2 , 0 ) , ( 0 , 0 ) ) )
x = x . reshape ( ( 2 , 2 , 5 ) )
x1 = x . reshape ( ( 4 , 5 ) )
x1 = x1 . reshape ( ( 2 , 2 , 5 ) )
assert x == x1 . simplify ( )
def test_simplify_is_correct ( self ) :
multiv = ShapeTracker ( views = ( View ( shape = ( 15 , 3 ) , strides = ( 9 , 1 ) , offset = 6 , mask = None , contiguous = False ) ,
View ( shape = ( 4 , 3 ) , strides = ( 12 , 4 ) , offset = 0 , mask = None , contiguous = False ) ) )
assert st_equal ( multiv , multiv . simplify ( ) )
class TestShapeTrackerAdd ( unittest . TestCase ) :
def test_simple_add_reshape ( self ) :
a = ShapeTracker . from_shape ( ( 10 , 10 ) )
a = a . reshape ( ( 100 , ) )
b = ShapeTracker . from_shape ( ( 100 , ) )
assert a + b == b
def test_simple_add_permute ( self ) :
a = ShapeTracker . from_shape ( ( 10 , 10 ) )
a = a . permute ( ( 1 , 0 ) )
b = ShapeTracker . from_shape ( ( 10 , 10 ) )
b = b . permute ( ( 1 , 0 ) )
assert a + b == ShapeTracker . from_shape ( ( 10 , 10 ) )
def test_plus_real1 ( self ) :
st = MultiShapeTracker ( [ ShapeTracker . from_shape ( ( 15 , 9 ) ) ] )
st . shrink ( ( ( 0 , 15 ) , ( 6 , 9 ) ) )
backup = st . sts [ 0 ]
st . sts . append ( ShapeTracker . from_shape ( backup . shape ) )
st . reshape ( ( 45 , ) )
st . flip ( ( True , ) )
st . reshape ( ( 15 , 3 ) )
assert st_equal ( backup + st . sts [ 1 ] , st . sts [ 0 ] )
def test_off_by_one ( self ) :
st1 = ShapeTracker ( views = ( View ( shape = ( 5 , ) , strides = ( 1 , ) , offset = 0 , mask = None , contiguous = True ) ,
View ( shape = ( 5 , ) , strides = ( 1 , ) , offset = 0 , mask = None , contiguous = True ) ) )
st2 = ShapeTracker ( views = ( View ( shape = ( 4 , ) , strides = ( 1 , ) , offset = 0 , mask = None , contiguous = True ) ,
View ( shape = ( 5 , ) , strides = ( 1 , ) , offset = 0 , mask = None , contiguous = True ) ) )
assert not ( st_equal ( st1 , st2 ) )
class TestShapeTrackerAddVariable ( unittest . TestCase ) :
def test_self_add ( self ) :
j = Variable ( " j " , 0 , 20 ) . bind ( 10 )
a = ShapeTracker . from_shape ( ( 10 , 10 ) )
x = a . reshape ( ( 10 , j ) )
out = x + x
assert out == x
def test_self_add_reshape ( self ) :
j = Variable ( " j " , 0 , 20 ) . bind ( 10 )
a = ShapeTracker . from_shape ( ( 10 , 10 ) )
x = a . reshape ( ( 10 , j ) )
out = x . reshape ( ( 5 , 2 , j ) ) + x
assert out == x
def test_merge_symbolic_views ( self ) :
var_i = Variable ( ' i ' , 1 , 10 )
var_j = Variable ( ' i ' , 1 , 10 )
vm1 = View ( shape = ( var_i , var_j , 3 ) , strides = ( 3 , 0 , 1 ) , offset = 0 , mask = None , contiguous = False )
vm2 = View ( shape = ( var_i , var_j , 3 ) , strides = ( var_j * 3 , 3 , 1 ) , offset = 0 , mask = None , contiguous = True )
ShapeTracker ( ( vm1 , ) ) + ShapeTracker ( ( vm2 , ) )
def test_merge_symbolic_views_2 ( self ) :
var_i = Variable ( ' i ' , 1 , 10 )
var_j = Variable ( ' j ' , 1 , 10 )
vm1 = View ( shape = ( var_i , var_j ) , strides = ( 0 , 0 ) , offset = 0 , mask = None , contiguous = False )
vm2 = View ( shape = ( var_i , var_j ) , strides = ( var_j , 1 ) , offset = 0 , mask = None , contiguous = True )
ret = ( ShapeTracker ( ( vm1 , ) ) + ShapeTracker ( ( vm2 , ) ) ) . reshape ( ( var_i , var_j , 1 ) )
ret_2 = ShapeTracker ( ( vm1 , ) ) + ShapeTracker ( ( vm2 , ) ) . reshape ( ( var_i , var_j , 1 ) )
assert ret == ret_2
class TestShapeTrackerInvert ( unittest . TestCase ) :
def test_invert_reshape ( self ) :
a = ShapeTracker . from_shape ( ( 10 , 10 ) )
x = a . reshape ( ( 5 , 20 ) )
ap = ShapeTracker . from_shape ( x . shape ) + x . invert ( a . shape )
assert ap == a , f " { ap } != { a } "
def test_invert_permute ( self ) :
a = ShapeTracker . from_shape ( ( 5 , 20 ) )
x = a . permute ( ( 1 , 0 ) )
ap = x + x . invert ( a . shape )
assert ap == a , f " { ap } != { a } "
def test_invert_permute_3 ( self ) :
a = ShapeTracker . from_shape ( ( 8 , 4 , 5 ) )
x = a . permute ( ( 1 , 2 , 0 ) )
ap = x + x . invert ( a . shape )
assert ap == a , f " { ap } != { a } "
def test_invert_real1 ( self ) :
a = ShapeTracker . from_shape ( ( 3 , 6 , 10 ) )
x = a . reshape ( ( 3 , 3 , 2 , 10 ) )
x = x . permute ( ( 2 , 1 , 3 , 0 ) )
ap = x + x . invert ( a . shape )
assert ap == a , f " { ap } != { a } "
def test_cant_invert_expand ( self ) :
a = ShapeTracker . from_shape ( ( 10 , 1 ) )
x = a . expand ( ( 10 , 10 ) )
assert x . invert ( a . shape ) is None
def test_cant_invert_shrink ( self ) :
a = ShapeTracker . from_shape ( ( 10 , 10 ) )
x = a . shrink ( ( ( 0 , 10 ) , ( 2 , 8 ) ) )
assert x . invert ( a . shape ) is None
def test_can_invert_flip ( self ) :
a = ShapeTracker . from_shape ( ( 20 , 10 ) )
x = a . flip ( ( True , False ) )
ap = x + x . invert ( a . shape )
assert st_equal ( ap , a )
def test_can_invert_flip_permute ( self ) :
a = ShapeTracker . from_shape ( ( 20 , 10 ) )
x = a . permute ( ( 1 , 0 ) )
x = x . flip ( ( True , False ) )
ap = x + x . invert ( a . shape )
assert st_equal ( ap , a )
def test_invert_failure ( self ) :
a = ShapeTracker . from_shape ( ( 2 , 5 ) )
x = a . pad ( ( ( 2 , 0 ) , ( 0 , 0 ) ) )
x = x . reshape ( ( 2 , 2 , 5 ) )
x = x . reshape ( ( 4 , 5 ) )
ap = x + x . invert ( a . shape )
assert st_equal ( ap , a )
if __name__ == ' __main__ ' :
unittest . main ( )