import unittest , contextlib
import numpy as np
from tinygrad import Tensor , GlobalCounters , dtypes , nn , Device
from tinygrad . helpers import CI , Context , getenv
from tinygrad . engine . realize import run_schedule
from tinygrad . codegen . kernel import Opt , OptOps , Kernel , KernelOptError
from tinygrad . engine . realize import CompiledRunner , ExecItem
from tinygrad . engine . search import get_kernel_actions
from tinygrad . ops import Ops
class TestArange(unittest.TestCase):
  """Checks that Tensor.arange compiles to a kernel whose flop count does not
  scale with N (the arange should be computed by indexing, not by reduction),
  under a variety of kernel optimization options."""

  def _get_flops(self, N, opts=None):
    """Schedule Tensor.arange(N) as a single kernel, optionally apply `opts`,
    run it, verify the output against numpy, and return the estimated op count."""
    GlobalCounters.reset()
    tt = Tensor.arange(N)
    sched = tt.schedule()
    self.assertEqual(len(sched), 1)
    k = Kernel(sched[-1].ast)
    if opts is not None:
      for o in opts: k.apply_opt(o)
    p = k.to_program()
    print(p.name)
    # print(p.src)
    ExecItem(CompiledRunner(p), [tt.lazydata.buffer]).run()
    np.testing.assert_equal(tt.numpy(), np.arange(N))
    return p.estimates.ops

  def test_complexity(self, opts=None, limit=None):
    """Flops must not grow ~10x when N grows 10x; optionally cap f1 at `limit`."""
    f1 = self._get_flops(256, opts)
    f2 = self._get_flops(2560, opts)
    print(f"{f1=}, {f2=}")
    # add 1 to avoid divide by 0. arange is 0 flops now!
    assert (f1 < 6000 and f2 < 6000) or ((f2+1)/(f1+1) < 16), f"bad complexity, flops {(f2+1)/(f1+1):.1f}X while inputs 10X"
    if limit is not None and not getenv("PTX"):
      # PTX counts index ALU in flops
      assert f1 <= limit, f"{f1=}, {limit=}"

  def test_complexity_w_upcast(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4)], limit=0)
  def test_complexity_w_unroll2(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 2)], limit=0)
  def test_complexity_w_unroll4(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 4)], limit=0)
  def test_complexity_w_unroll8(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 8)], limit=0)
  def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], limit=0)

  # local-memory opts only exist on backends whose renderer supports locals
  if Device.default.renderer.has_local:
    # TODO: fix limit
    def test_complexity_w_group(self): return self.test_complexity([Opt(OptOps.GROUP, 0, 16)], limit=81920)
    def test_complexity_w_group_top(self): return self.test_complexity([Opt(OptOps.GROUPTOP, 0, 16)], limit=106496)
    def test_complexity_w_local(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16)], limit=0)

  @unittest.skip("doesn't work yet")
  def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.PADTO, axis=1, arg=32)])

  def test_all_opts(self, opts=None, exclude=None):
    """Run test_complexity for every kernel action valid at both sizes,
    skipping any whose last opt appears in `exclude`."""
    k = Kernel(Tensor.arange(256).schedule()[-1].ast)
    if opts is not None:
      for o in opts: k.apply_opt(o)
    all_opts_256 = [kk.applied_opts for kk in get_kernel_actions(k, include_0=False).values()]
    k = Kernel(Tensor.arange(2560).schedule()[-1].ast)
    if opts is not None:
      for o in opts: k.apply_opt(o)
    all_opts_2560 = [kk.applied_opts for kk in get_kernel_actions(k, include_0=False).values()]
    # only test opt sequences that are valid for both problem sizes
    all_opts = [x for x in all_opts_256 if x in all_opts_2560]
    for opts in all_opts:
      if exclude is not None and opts[-1] in exclude: continue
      print(opts)
      self.test_complexity(opts)

  def test_all_opts_w_local(self):
    with contextlib.suppress(KernelOptError):
      return self.test_all_opts([Opt(OptOps.LOCAL, 0, 16)], [Opt(op=OptOps.PADTO, axis=1, arg=32)])
  def test_all_opts_w_upcast(self): return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4)])
  def test_all_opts_w_unroll(self): return self.test_all_opts([Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, arg=0)])
  def test_all_opts_w_upcast_and_unroll(self):
    return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, arg=0)])
class TestIndexing(unittest.TestCase):
  """Checks that fancy indexing (gather via arange+compare+reduce) fuses into
  few kernels with bounded op counts when FUSE_ARANGE is enabled, and that the
  results match numpy/torch."""

  def test_arange_2_reduce(self):
    # one-hot needle: only index 1337 is set
    needle = Tensor.zeros(16384, dtype=dtypes.int).contiguous()
    needle[1337] = 1
    needle.realize()
    with Context(NOOPT=1, FUSE_ARANGE=1):
      GlobalCounters.reset()
      out = ((Tensor.arange(1, 16385)-1)*needle).sum()
      sched = out.schedule()
      # arange must fuse with the reduce into a single kernel
      self.assertEqual(len(sched), 1)
      run_schedule(sched)
    self.assertEqual(out.item(), 1337)

  @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
  def test_manual_index(self):
    """Hand-built gather: compare a running index against idxs and reduce."""
    dataset = Tensor.rand(16384, 256).realize()
    idxs = Tensor([0, 3, 5, 6]).realize()
    real_index = dataset.numpy()[idxs.numpy()]
    print("*** indexing ***")
    with Context(NOOPT=1, FUSE_ARANGE=1):
      GlobalCounters.reset()
      # rng is an inclusive cumsum of ones, i.e. the index along the last axis
      rng = Tensor.ones(4, 256, 16384, dtype=dtypes.int)._cumalu(axis=-1, op=Ops.ADD, _include_initial=True).reshape(4, 256, 16384, 1)
      idxs = idxs.reshape(4, 1, 1, 1).expand(4, 256, 16384, 1)
      reshape_dataset = dataset.T.reshape(1, 256, 16384, 1).expand(4, 256, 16384, 1)
      full = (rng == idxs).where(reshape_dataset, Tensor.zeros(4, 256, 16384, 1))
      X = full.sum(axis=(2, 3))
      sched = X.schedule()
      self.assertEqual(len(sched), 1)
      run_schedule(sched)
      assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops}"
    np.testing.assert_allclose(real_index, X.numpy())

  def test_index(self):
    dataset = Tensor.rand(16384, 256).realize()
    idxs = Tensor([0, 3, 5, 6]).realize()
    real_index = dataset.numpy()[idxs.numpy()]
    print("*** indexing ***")
    with Context(NOOPT=1):
      GlobalCounters.reset()
      X = dataset[idxs]
      assert X.shape == (4, 256)
      sched = X.schedule()
      # TODO: enable these asserts when the scheduler can handle this
      # self.assertEqual(len(sched), 1)
      run_schedule(sched)
      # assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops}"
    np.testing.assert_allclose(real_index, X.numpy())

  def test_index_fused(self, noopt=1):
    dataset = Tensor.rand(16384, 256).realize()
    idxs = Tensor([0, 3, 5, 6]).realize()
    real_index = dataset.numpy()[idxs.numpy()]
    print("*** indexing ***")
    with Context(NOOPT=noopt, FUSE_ARANGE=1):
      GlobalCounters.reset()
      X = dataset[idxs]
      assert X.shape == (4, 256)
      sched = X.schedule()
      self.assertEqual(len(sched), 2)
      run_schedule(sched)
      assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops} != {4*16384}"
    np.testing.assert_allclose(real_index, X.numpy())

  @unittest.skip("not ready")
  def test_index_fused_opt(self): self.test_index_fused(0)

  def test_index_fused_out_of_bounds(self):
    # out-of-range indices must gather zeros, not crash
    dataset = Tensor.rand(256, 256).realize()
    idxs = Tensor([-19238, -257, 256, 495, 10982377]).realize()
    with Context(NOOPT=1, FUSE_ARANGE=1):
      X = dataset[idxs]
      np.testing.assert_equal(X.numpy(), 0)

  @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
  def test_index_mnist(self, noopt=1, op_limit=512*784*13):
    from tinygrad.nn.datasets import mnist
    X_train, Y_train, _, _ = mnist()
    with Context(NOOPT=noopt, FUSE_ARANGE=1, SPLIT_REDUCEOP=0):
      samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0]).realize()
      GlobalCounters.reset()
      x = X_train[samples].numpy()
      y = Y_train[samples].numpy()
      assert GlobalCounters.global_ops < op_limit, f"too many ops {GlobalCounters.global_ops} != {op_limit}"
    np.testing.assert_allclose(X_train.numpy()[samples.numpy()], x)
    np.testing.assert_allclose(Y_train.numpy()[samples.numpy()], y)

  @unittest.skip("not ready")
  def test_index_mnist_opt(self): self.test_index_mnist(0)

  @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
  def test_llama_embedding(self, noopt=1, op_limit=65536):
    # llama3 is 128256
    vocab_size, embed_size = (10, 3) if CI else (32000, 4096)
    emb = nn.Embedding(vocab_size, embed_size)
    # TODO: why is a new realize needed here
    emb_w = emb.weight.realize().numpy()
    x = Tensor([1, 2, 3, 4])
    with Context(NOOPT=noopt, FUSE_ARANGE=1):
      GlobalCounters.reset()
      z = emb(x).realize()
      self.assertLessEqual(GlobalCounters.global_ops, op_limit)
      self.assertEqual(GlobalCounters.kernel_count, 2)
    if getenv("CHECK", 1):
      import torch
      with torch.no_grad():
        torch_emb = torch.nn.Embedding(vocab_size, embed_size).eval()
        torch_emb.weight[:] = torch.tensor(emb_w, dtype=torch.float32)
      torch_z = torch_emb(torch.tensor(x.numpy()))
      # TODO: reshape to match torch, should we do this in nn?
      np.testing.assert_allclose(z.numpy().reshape(4, embed_size), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8)

  # at least the arange is being fused
  def test_llama_embedding_opt(self): self.test_llama_embedding(0, 1_736_704_000 if CI else 5_898_240_000)
# entry point: the mangled literal " __main__ " (with spaces) could never match
# __name__, so the tests would not run as a script; fixed to "__main__"
if __name__ == "__main__":
  unittest.main()