import unittest
from test . helpers import ast_const
from tinygrad . codegen . kernel import Opt , OptOps
from tinygrad . codegen . kernel import Kernel
from tinygrad . ops import UOp , Ops
from tinygrad . engine . search import bufs_from_lin , actions , beam_search
from tinygrad . device import Device , Buffer
from tinygrad . tensor import Tensor
from tinygrad . dtype import dtypes
from tinygrad . helpers import Context , GlobalCounters
from tinygrad . engine . realize import capturing
from tinygrad . shape . shapetracker import ShapeTracker
from tinygrad . shape . view import View
from extra . optimization . helpers import time_linearizer
class TestTimeLinearizer(unittest.TestCase):
  """Sanity checks for time_linearizer and bufs_from_lin."""

  @unittest.skipIf(Device.DEFAULT == "WEBGPU", "WebGPU timestamps are low precision, tm is 0")
  def test_reasonable_time(self):
    # realize a tiny tensor, then time the kernel for a simple elementwise add
    src = Tensor([1, 2, 3, 4]).realize()
    item = (src + 1).schedule()[0]
    # create fresh empty buffers
    fresh_bufs = [Buffer(b.device, b.size, b.dtype).allocate() for b in item.bufs]
    tm = time_linearizer(Kernel(item.ast), fresh_bufs, allow_test_size=False, cnt=10, disable_cache=True)
    # a real measurement must be positive and finite
    assert tm > 0 and tm != float('inf')

  def test_bufs_from_lin(self):
    src = Tensor([1, 2, 3, 4]).realize()
    item = (src + 1).schedule()[0]
    lin = Kernel(item.ast)
    rawbufs = bufs_from_lin(lin)
    # one buffer per memory buffer of the kernel: output + input
    assert len(rawbufs) == len(lin.membufs) == 2
    for buf in rawbufs:
      assert buf is not None
      assert isinstance(buf, Buffer)
      assert buf.size > 0

  def test_bufs_from_lin_alt(self):
    # broadcasted add: a (4,4) tensor plus its first row
    base = Tensor.randn(4, 4).realize()
    item = (base + base[0]).schedule()[0]
    k = Kernel(item.ast)
    rawbufs = bufs_from_lin(k)
    assert len(rawbufs) == len(k.membufs) == 2
    for buf in rawbufs:
      assert buf is not None
      assert isinstance(buf, Buffer)
      assert buf.size > 0

  def test_kernel_count(self):
    """
    Ensure that the kernel count is not incremented by time_linearizer when clearing l2
    """
    # ast of Tensor.zeros(16).contiguous().realize()
    ast = UOp(Ops.SINK, src=(
      UOp(Ops.STORE, src=(
        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
        UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),))),
        ast_const(dtypes.float, 0.0, st_src=(
          UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(0,), offset=0, mask=None, contiguous=False),))),)),)),))
    lin = Kernel(ast)
    bufs = bufs_from_lin(lin)
    count_before = GlobalCounters.kernel_count
    time_linearizer(lin, bufs, allow_test_size=False, cnt=2, disable_cache=True, clear_l2=True)
    assert GlobalCounters.kernel_count == count_before, "kernel count was incremented by time_linearizer"
class TestBEAM(unittest.TestCase):
  """Tests for BEAM search: action enumeration, state preservation, and beam_search itself."""

  def test_dynamic_beam(self):
    # BEAM=1 and BEAM=0 should produce different program source for the same realize.
    # TODO: make this infra globally usable
    class Capture:
      def __init__(self): self.captured = []
      def add(self, x): self.captured.append(x)

    # capture the kernels realized with BEAM=1
    capturing.append(Capture())
    kernel_count = GlobalCounters.kernel_count
    with Context(BEAM=1): Tensor.zeros(16).contiguous().realize()
    assert GlobalCounters.kernel_count == kernel_count + 1
    k_beam_1 = capturing[0].captured
    capturing.clear()

    # capture the kernels realized with BEAM=0 (no search)
    capturing.append(Capture())
    kernel_count = GlobalCounters.kernel_count
    with Context(BEAM=0): Tensor.zeros(16).contiguous().realize()
    assert GlobalCounters.kernel_count == kernel_count + 1
    k_beam_0 = capturing[0].captured
    capturing.clear()
    # the searched kernel's source must differ from the unsearched one
    self.assertNotEqual(k_beam_0[-1].prg.p.src, k_beam_1[-1].prg.p.src)

  def test_get_kernel_actions(self):
    from test.test_linearizer import helper_realized_ast
    a = Tensor.rand(4, 3)
    b = Tensor.rand(3)
    realized_ast, _ = helper_realized_ast(a @ b)
    from tinygrad.engine.search import get_kernel_actions
    lins = get_kernel_actions(Kernel(realized_ast), False).values()
    # ensure amt=0 are not duplicated
    # (an amt=0 opt means "use the full axis size"; for these shapes that equals
    #  the explicit-arg opt below, so both appearing would be a duplicate action)
    if Opt(OptOps.UPCAST, 0, 0) in actions:
      assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.UPCAST, axis=0, arg=4)]) == 0, "did not de-dup UPCAST"
    if Opt(OptOps.LOCAL, 0, 0) in actions:
      assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.LOCAL, axis=0, arg=4)]) == 0, "did not de-dup LOCAL"
    if Opt(OptOps.UNROLL, 0, 0) in actions:
      assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.UNROLL, axis=0, arg=3)]) == 0, "did not de-dup UNROLL"
    if Opt(OptOps.GROUP, 0, 0) in actions:
      assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUP, axis=0, arg=3)]) == 0, "did not de-dup GROUP"
    if Opt(OptOps.GROUPTOP, 0, 0) in actions:
      assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUPTOP, axis=0, arg=3)]) == 0, "did not de-dup GROUPTOP"

  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
  def test_search_over_shape(self):
    # when multiple tensor cores exist for the same (dtype_in, dtype_out) pair,
    # the action space should cover more than one tensor-core shape
    from test.test_linearizer import helper_realized_ast
    from tinygrad.engine.search import get_kernel_actions
    dtype_pairs = [(tc.dtype_in, tc.dtype_out) for tc in Device[Device.DEFAULT].renderer.tensor_cores]
    multi_shape_dtype_pairs = [dts for dts in dtype_pairs if dtype_pairs.count(dts) > 1]
    if len(multi_shape_dtype_pairs) == 0: raise unittest.SkipTest("only one tc available per dtype pair to search over")
    for (dtype_in, dtype_out) in multi_shape_dtype_pairs:
      a = Tensor.rand(16, 16, dtype=dtype_in)
      b = Tensor.rand(16, 16, dtype=dtype_in)
      realized_ast, _ = helper_realized_ast(a.matmul(b, dtype=dtype_out))
      lins = get_kernel_actions(Kernel(realized_ast)).values()
      assert len(set(lin.tensor_core.dims for lin in lins if lin.tensor_core is not None)) > 1

  def test_get_kernel_actions_preserves_actions_state(self):
    # get_kernel_actions must not mutate the module-level `actions` list
    from test.test_linearizer import helper_realized_ast
    from tinygrad.engine.search import get_kernel_actions
    a = Tensor.rand(16, 16)
    b = Tensor.rand(16, 16)
    realized_ast, _ = helper_realized_ast(a @ b)
    actions_before = actions.copy()
    get_kernel_actions(Kernel(realized_ast))
    actions_after = actions.copy()
    assert actions_after == actions_before, "actions state was not preserved"

  def test_filter_global_buffer(self):
    # taken from https://github.com/tinygrad/tinygrad/issues/4612
    # a max-reduce over the elementwise sum of six shifted loads of the same
    # 384768-element buffer, scaled by a constant — transcribed verbatim
    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
      UOp(Ops.STORE, dtypes.void, arg=None, src=(
        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),  # noqa: E501
        UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.MAX, (1,)), src=(
          UOp(Ops.MUL, dtypes.float, arg=None, src=(
            UOp(Ops.ADD, dtypes.float, arg=None, src=(
              UOp(Ops.ADD, dtypes.float, arg=None, src=(
                UOp(Ops.ADD, dtypes.float, arg=None, src=(
                  UOp(Ops.ADD, dtypes.float, arg=None, src=(
                    UOp(Ops.ADD, dtypes.float, arg=None, src=(
                      UOp(Ops.LOAD, dtypes.float, arg=None, src=(
                        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
                        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),  # noqa: E501
                      UOp(Ops.LOAD, dtypes.float, arg=None, src=(
                        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
                        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)),  # noqa: E501
                    UOp(Ops.LOAD, dtypes.float, arg=None, src=(
                      UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
                      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)),  # noqa: E501
                  UOp(Ops.LOAD, dtypes.float, arg=None, src=(
                    UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()),
                    UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)),  # noqa: E501
                UOp(Ops.LOAD, dtypes.float, arg=None, src=(
                  UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()),
                  UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)),  # noqa: E501
              UOp(Ops.LOAD, dtypes.float, arg=None, src=(
                UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()),
                UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)),  # noqa: E501
            ast_const(dtypes.float, 1.4285714285714286, st_src=(
              UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 501, 256), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))  # noqa: E501
    lin = Kernel(ast)
    bufs = bufs_from_lin(lin)
    best_lin = beam_search(lin, bufs, 2)
    assert best_lin
    # need disable_cache to trigger.
    tm = time_linearizer(best_lin, bufs, allow_test_size=False, cnt=2, disable_cache=True)
    assert tm

  def test_beam_unnamed_kernels(self):
    a = Tensor.rand(100)
    b = Tensor.rand(100)
    si = (a + b).schedule()[-1]
    lin = Kernel(si.ast)
    bufs = bufs_from_lin(lin)
    # TODO: beam should have better instrumentation so we don't have to check this indirect thing
    kcount = len(Kernel.kernel_cnt)
    beam_search(lin, bufs, 3, disable_cache=True)
    # beam_search must not register new names in Kernel.kernel_cnt
    self.assertEqual(kcount, len(Kernel.kernel_cnt))
# Standard script entry point: run the test suite when executed directly.
# Fixed: the extracted guard compared against ' __main__ ' (spurious spaces
# inside the literal), which can never equal __name__, so the tests would
# never run from the command line.
if __name__ == '__main__':
  unittest.main()