import unittest
import numpy as np
from tinygrad import Tensor, GlobalCounters, Context, Device
from tinygrad.dtype import DTypeLike, dtypes
from tinygrad.helpers import DEBUG, get_single_element
from tinygrad.engine.realize import lower_schedule_item
from tinygrad.device import is_dtype_supported
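
# softmax written so the whole reduction can lower to one kernel: instead of a
# row-wide max/sum that every output depends on, each output element gets its
# own expanded copy of the row to reduce over, so no cross-element dependency
# is left in the graph.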
def single_kernel_softmax(x_in:Tensor, axis=-1, dtype:DTypeLike|None=None) -> Tensor:
  # only supports axis=-1
  x = x_in.reshape(-1, x_in.shape[-1])
  nr_dim, r_dim = x.shape

  # denominator: sum of exp(x - rowmax) over the row, materialized per output element
  inp = x.reshape(nr_dim, 1, 1, r_dim).expand(nr_dim, r_dim, 1, r_dim)
  imx = x.reshape(nr_dim, 1, r_dim, 1).expand(nr_dim, r_dim, r_dim, r_dim).max(axis=-2, keepdim=True)
  m = inp - imx.detach()
  if dtype is not None: m = m.cast(dtype)
  e = m.exp()
  ss = e.sum(axis=-1, keepdim=True)

  # numerator: exp(x - rowmax) for each element
  inp = x.reshape(nr_dim, r_dim, 1, 1)
  imx = x.reshape(nr_dim, 1, r_dim, 1).expand(nr_dim, r_dim, r_dim, 1).max(axis=-2, keepdim=True)
  m = inp - imx.detach()
  if dtype is not None: m = m.cast(dtype)
  e = m.exp()

  out = e.div(ss).reshape(x_in.shape)
  return out
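
# For reference, the math both paths should agree on (a minimal NumPy sketch,
# not used by the tests; the `_numpy_softmax` name is ours):
def _numpy_softmax(x: np.ndarray) -> np.ndarray:
  m = x - x.max(axis=-1, keepdims=True)  # subtract the row max for numerical stability
  e = np.exp(m)
  return e / e.sum(axis=-1, keepdims=True)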

def run_one_schedule_item(out): lower_schedule_item(get_single_element(out.schedule())).run()
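# NOTE: get_single_element raises unless the schedule has exactly one item, so
# this doubles as the single-kernel assertion. Usage (a sketch):
#   t = Tensor.rand(8, 8)
#   run_one_schedule_item((t / t.mean(axis=1)).fuse())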

class TestFuse(unittest.TestCase):
  def _test_fuse(self, fxn, *args, atol=1e-7, allow_multiple=False, **kwargs):
    GlobalCounters.reset()
    out_single = fxn(*args, **kwargs).fuse()
    if not allow_multiple: run_one_schedule_item(out_single)
    np_single = out_single.numpy()
    # the fused (single-kernel) output must match the default multi-kernel output
    GlobalCounters.reset()
    np_multi = fxn(*args, **kwargs).numpy()
    np.testing.assert_allclose(np_single, np_multi, atol=atol)

  def test_fuse_norm(self):
    a = Tensor.rand(50, 50).realize()
    self._test_fuse(lambda a: a / a.mean(axis=1), a)

  def test_fuse_argmax(self):
    a = Tensor.rand(50, 50).realize()
    self._test_fuse(lambda a: a.argmax(axis=-1), a)

  def test_fuse_softmax(self):
    a = Tensor.rand(50, 50).realize()
    self._test_fuse(lambda a: a.softmax(axis=-1), a)

  def test_fuse_gemm_softmax(self):
    a = Tensor.rand(50, 50).realize()
    b = Tensor.rand(50, 50).realize()
    self._test_fuse(lambda a, b: ((a@b).relu() + a).contiguous().softmax(axis=-1), a, b, allow_multiple=True)

  @unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
  def test_fuse_softmax_dtype(self):
    a = Tensor.rand(50, 50).realize()
    self._test_fuse(lambda a: a.softmax(axis=-1, dtype='half'), a, atol=3e-4)

  def test_fuse_arange_eye(self):
    # comparing a column arange against a row arange broadcasts to a 10x10 identity matrix
    self._test_fuse(lambda: Tensor.arange(10).reshape(10, 1).expand(10, 10) == Tensor.arange(10).reshape(1, 10).expand(10, 10))

  def test_double_gemm(self):
    N = 32
    with Context(TRACK_MATCH_STATS=0, DEBUG=0):
      a = (Tensor.rand(N, N) - 0.5).realize()
      b = (Tensor.rand(N, N) - 0.5).realize()
      c = (Tensor.rand(N, N) - 0.5).realize()
    self._test_fuse(lambda a, b, c: a@b@c, a, b, c, atol=1e-5)

  def test_embedding(self):
    with Context(TRACK_MATCH_STATS=0, DEBUG=0):
      vocab_sz = 123
      embed_sz = 16
      weight = (Tensor.rand(vocab_sz, embed_sz) - 0.5).realize()
      a = Tensor([1, 1, 2, 3]).realize()
    def embedding(idx:Tensor):
      arange = Tensor.arange(vocab_sz).unsqueeze(-1)
      big_shp = idx.shape + (vocab_sz, embed_sz)
      arange, vals = arange.expand(big_shp), weight.expand(big_shp)
      idx = idx.reshape(idx.shape + (1, 1)).expand(big_shp)
      return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)
    self._test_fuse(embedding, a, atol=1e-5)
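
  # NOTE: the one-hot construction in test_embedding is a gather in disguise; in
  # NumPy terms (a sketch, assuming 1-D idx):
  #   (np.arange(vocab_sz) == idx[:, None]).astype(weight.dtype) @ weight == weight[idx]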

  @unittest.skip("still broken")
  def test_flash_attention(self):
    BS = 4
    HEADS = 2
    MATDIM = 16
    EMB = 8
    with Context(TRACK_MATCH_STATS=0, DEBUG=0):
      q = Tensor.randn(BS, HEADS, MATDIM, EMB).realize()
      k = Tensor.randn(BS, HEADS, MATDIM, EMB).realize()
      v = Tensor.randn(BS, HEADS, MATDIM, EMB).realize()
    # TODO: OPT is breaking things. NOOPT isn't linearizing
    with Context(NOOPT=1):
      self._test_fuse(Tensor.scaled_dot_product_attention, q, k, v)

class TestSoftmaxFusion(unittest.TestCase):
  @classmethod
  def setUpClass(cls):
    with Context(TRACK_MATCH_STATS=0): cls.test = Tensor.rand(32, 10).contiguous().realize()

  def setUp(self):
    GlobalCounters.reset()

  def test_norm(self):
    print("*** norm ***")
    with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
      # NOTE: there's an implied expand on the mean here
      sout = self.test / self.test.mean(-1, keepdim=True)
      sout.realize()

    print("*** single kernel norm ***")
    with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
      inp = self.test.reshape(32, 10, 1)
      div = self.test.reshape(32, 1, 10).expand(32, 10, 10).mean(axis=-1, keepdim=True)
      out = (inp / div).reshape(32, 10)
      out.realize()

    np.testing.assert_allclose(sout.numpy(), out.numpy())
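
  # In test_norm, the (32, 10, 10) expand gives every output element its own
  # copy of the row to reduce over, so no output depends on a value computed
  # across the row it writes; single_kernel_softmax relies on the same trick.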

  def test_softmax(self):
    # this is the softmax from scaled_dot_product_attention
    # it becomes 3 kernels
    print("*** softmax ***")
    with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
      sout = self.test.softmax(-1)
      sout.realize()

    print("*** single kernel softmax ***")
    # NOTE: DONT_GROUP_REDUCES is required here
    with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2), DONT_GROUP_REDUCES=1):
      out = single_kernel_softmax(self.test)
      out.realize()

    np.testing.assert_allclose(sout.numpy(), out.numpy())

  def test_auto_softmax(self):
    print("*** softmax ***")
    with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
      sout = self.test.softmax(-1)
      sout.realize()

    print("*** auto single kernel softmax ***")
    with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
      out = self.test.contiguous().softmax(-1).fuse()
      run_one_schedule_item(out)

    np.testing.assert_allclose(sout.numpy(), out.numpy())

  @unittest.skip("recursion error no longer raised")
  def test_softmax_bw(self):
    print("*** softmax bw ***")
    self.test.requires_grad_()
    with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2)):
      self.test.softmax(-1).sum().backward()
      sg = self.test.grad.realize()

    self.test.grad = None
    print("*** single kernel softmax bw ***")
    # NOTE: DONT_GROUP_REDUCES is required here
    # TODO: fix RecursionError with DONT_GROUP_REDUCES
    with self.assertRaises(RecursionError):
      with Context(NOOPT=1, DEBUG=max(DEBUG.value, 2), DONT_GROUP_REDUCES=1):
        single_kernel_softmax(self.test).sum().backward()
        g = self.test.grad.realize()
        np.testing.assert_allclose(sg.numpy(), g.numpy(), atol=1e-7)

if __name__ == '__main__':
  unittest.main()