import unittest
from tinygrad import Tensor
from tinygrad.helpers import getenv, GlobalCounters
from tinygrad.engine.realize import lower_schedule_item, ProgramSpec
from tinygrad.renderer import Estimates
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.ops import Ops, UOp
from tinygrad.dtype import dtypes
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
from tinygrad.device import Device
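
# Estimates.from_uops tallies costs from a linearized UOp list: est.ops should be the arithmetic op count and est.lds the
# bytes moved by loads/stores; ignore_indexing=True is expected to drop address-computation ops from the count.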
def flops_mem(uops, ignore_indexing=False):
  est = Estimates.from_uops(uops, ignore_indexing)
  return est.ops, est.lds

# **************** new FlopCounter ****************
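# get_stats schedules the tensor, lowers the last schedule item to a runnable program, and reads the op and memory
# estimates attached to that program. mem here should be the unique global-memory bytes the kernel touches.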
def get_stats(x:Tensor):
  si = x.schedule()[-1]
  ei = lower_schedule_item(si)
  return ei.prg.estimates.ops, ei.prg.estimates.mem

class TestMemoryCount(unittest.TestCase):
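  # These tests check the memory-traffic estimate for simple elementwise kernels. The expected byte counts assume repeated
  # reads of the same buffer (a+a) count once and broadcasted/expanded inputs only contribute their underlying storage.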
  def test_add(self):
    a = Tensor.empty(1024, 1024, dtype=dtypes.uint8)
    b = Tensor.empty(1024, 1024, dtype=dtypes.uint8)
    _, mem = get_stats(a+b)
    self.assertEqual(mem, 1024*1024*3)  # 2 reads + 1 write

  def test_add_const(self):
    a = Tensor.empty(1024, 1024, dtype=dtypes.uint8)
    _, mem = get_stats(a+3)
    self.assertEqual(mem, 1024*1024*2)  # 1 read + 1 write

  def test_add_slice(self):
    a = Tensor.empty(1024, 1024, dtype=dtypes.uint8)[:512]
    _, mem = get_stats(a+3)
    self.assertEqual(mem, 512*1024*2)  # 1 read + 1 write

  def test_expanded(self):
    a = Tensor.empty(1024, 1, dtype=dtypes.uint8).expand(1024, 1024)
    b = Tensor.empty(1024, 1024, dtype=dtypes.uint8)
    _, mem = get_stats(a+b)
    self.assertEqual(mem, 1024*1024*2 + 1024)  # 1 full read + 1 lil read + 1 write

  def test_both_expanded(self):
    # TODO: this probably should be a full write
    a = Tensor.empty(1024, 1, dtype=dtypes.uint8).expand(1024, 1024)
    b = Tensor.empty(1024, 1, dtype=dtypes.uint8).expand(1024, 1024)
    _, mem = get_stats(a+b)
    self.assertEqual(mem, 1024*1024 + 2*1024)  # 2 lil reads + 1 write

  def test_self_add(self):
    a = Tensor.empty(1024, 1024, dtype=dtypes.uint8)
    _, mem = get_stats(a+a)
    self.assertEqual(mem, 1024*1024*2)  # 1 read + 1 write

  def test_self_add_transposed(self):
    a = Tensor.empty(1024, 1024, dtype=dtypes.uint8)
    _, mem = get_stats(a+a.T)
    self.assertEqual(mem, 1024*1024*2)  # 1 read + 1 write

  def test_self_add_assign(self):
    a = Tensor.empty(1024, 1024, dtype=dtypes.uint8).realize()
    _, mem = get_stats(a.assign(a+a))
    self.assertEqual(mem, 1024*1024*2)  # 1 read + 1 write

  @unittest.skipIf(Device.DEFAULT == "CPU", "test copy to CPU from other device")
  def test_copyout(self):
    a = Tensor.empty(32, dtype=dtypes.uint8).to("CPU")
    _, mem = get_stats(a)
    self.assertEqual(mem, 32 * 1)
    a = Tensor.empty(32, dtype=dtypes.uint32).to("CPU")
    _, mem = get_stats(a)
    self.assertEqual(mem, 32 * 4)

# NOTE: this still isn't testing unroll using the acc
@unittest.skipUnless(getenv("PYTHON"), "only run test on emulated tensor cores")
class TestUOpsStatsMatmulHalf(unittest.TestCase):
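  # These tests realize a half-precision matmul and compare GlobalCounters.global_ops against 2*N**3 per matmul,
  # presumably one multiply and one add per inner-loop step.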
  def test_simple_matmul_half(self, N=16):
    GlobalCounters.reset()
    a, b = Tensor.empty(N, N, dtype=dtypes.half), Tensor.empty(N, N, dtype=dtypes.half)
    c = a.matmul(b)
    c.realize()
    expected_ops = N**3 * 2
    self.assertEqual(expected_ops, GlobalCounters.global_ops)

  def test_bigger_matmul_half(self): self.test_simple_matmul_half(64)

  def test_batched_matmul_half(self, N=16):
    GlobalCounters.reset()
    a, b = Tensor.empty(4, N, N, dtype=dtypes.half), Tensor.empty(1, N, N, dtype=dtypes.half)
    c = a.matmul(b)
    c.realize()
    expected_ops = 4 * N**3 * 2
    self.assertEqual(expected_ops, GlobalCounters.global_ops)

class TestUOpsStats(unittest.TestCase):
  @unittest.skipIf(getenv("PTX"), "wrong in PTX")
  def test_simple_add(self):
    a = Tensor.empty(100, 100)
    b = Tensor.empty(100, 100)
    c = a+b
    ops, mem = get_stats(c)
    expected_ops = c.numel()
    expected_mem = a.nbytes() + b.nbytes() + c.nbytes()
    self.assertEqual(mem, expected_mem)
    # NOTE: ops also include indexing ops
    assert expected_ops <= ops and ops <= expected_ops * 2

  @unittest.skipIf(getenv("PTX"), "wrong in PTX")
  def test_simple_add_sq(self):
    a = Tensor.empty(100, 100)
    b = Tensor.empty(100, 100)
    c = (a+b)*(a+b)
    ops, mem = get_stats(c)
    expected_ops = c.numel() * 2
    expected_mem = a.nbytes() + b.nbytes() + c.nbytes()
    self.assertEqual(mem, expected_mem)
    # NOTE: ops also include indexing ops
    assert expected_ops <= ops and ops <= expected_ops * 2

  def test_simple_matmul(self):
    a = Tensor.empty(1024, 1024)
    b = Tensor.empty(1024, 1024)
    c = a @ b
    ops, mem = get_stats(c)
    expected_ops = c.numel() * 1024 * 2
    required_mem = a.nbytes() + b.nbytes() + c.nbytes()
    assert expected_ops <= ops and ops <= expected_ops * 1.2
    # NOTE: it's hard to assert on the memory here, all depends on caching
    assert required_mem <= mem

  # MULACC should have the same stats as MUL + ADD
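  # a fused multiply-accumulate is a single UOp, but it should be counted as two ops so the estimate matches the unfused form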
  def test_mulacc(self):
    globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
    o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
    o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2)
    u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),))
    u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),))
    u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
    u4 = UOp(Ops.MUL, dtypes.int, (u1, u2))
    u5 = UOp(Ops.ADD, dtypes.int, (u4, u3))
    uops = linearize_uop(u5.sink())

    globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
    o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
    o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2)
    u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),))
    u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),))
    u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
    u4 = UOp(Ops.MULACC, dtypes.int, (u1, u2, u3))
    uops_fma = linearize_uop(u4.sink())

    self.assertEqual(flops_mem(uops), flops_mem(uops_fma))

N = 100

@unittest.skipIf(getenv("PTX"), "wrong in PTX")  # maybe?
class TestStatsOptimized(unittest.TestCase):
  @classmethod
  def setUpClass(cls):
    cls.ast_gemm = (Tensor.empty(N, N) @ Tensor.empty(N, N)).schedule()[-1].ast
    cls.ast_reduce = (Tensor.empty(N*N).sum()).schedule()[-1].ast
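
  # The ASTs above are scheduled once and reused: each test builds a fresh Kernel from them, applies different Opts,
  # and checks the estimates on the resulting ProgramSpec.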
  def check_gemm(self, p:ProgramSpec, extra_flops=0):
    #p.uops.print()
    #print(p.src)
    print(p.name, p.estimates.ops, p.estimates.mem, p.estimates.lds)
    self.assertEqual(p.estimates.ops, 2*N*N*N + extra_flops)  # N**3 mulaccs
    self.assertEqual(p.estimates.mem, 3*N*N*4)  # 3 NxN mats with floats

  def test_gemm(self):
    p = Kernel(self.ast_gemm).to_program()
    self.check_gemm(p)
    self.assertEqual(p.estimates.lds, 2*N*N*N*4 + 4*N*N)
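
  # lds for the naive gemm above: both inputs are loaded on every one of the N**3 inner-loop iterations (2*N*N*N loads of
  # 4-byte floats) and the N*N float output is stored once (4*N*N bytes). The upcasted variants below shrink the load terms.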
  # this is a good lesson about why UPCASTing is a good idea
  def test_gemm_one_upcasted(self):
    k = Kernel(self.ast_gemm)
    k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
    p = k.to_program()
    self.check_gemm(p)
    self.assertEqual(p.estimates.lds, N*N*N*4 + N*N*N*4//4 + 4*N*N)
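
  # upcasting one output axis by 4 lets one input's loads be shared across 4 output elements, so one N**3 load term drops
  # by 4x; upcasting both output axes and unrolling the reduce (below) is expected to cut both load terms by 4x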
  def test_gemm_upcasted(self):
    k = Kernel(self.ast_gemm)
    k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
    k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
    k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
    p = k.to_program()
    self.check_gemm(p)
    self.assertEqual(p.estimates.lds, 2*N*N*N*4//4 + 4*N*N)

  def test_gemm_upcasted_locals(self):
    k = Kernel(self.ast_gemm)
    k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
    k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
    try:
      k.apply_opt(Opt(OptOps.LOCAL, 0, 5))
      k.apply_opt(Opt(OptOps.LOCAL, 1, 5))
    except KernelOptError:
      raise unittest.SkipTest("no locals")
    p = k.to_program()
    self.check_gemm(p)
    self.assertEqual(p.estimates.lds, 2*N*N*N*4//4 + 4*N*N)

  def test_gemm_group(self):
    k = Kernel(self.ast_gemm)
    try:
      k.apply_opt(Opt(OptOps.GROUP, 0, 4))
    except KernelOptError:
      raise unittest.SkipTest("no locals")
    SZ = N*N*4
    p = k.to_program()
    # NOTE: these are sort of wrong. they aren't honoring the IF statement
    self.check_gemm(p, extra_flops=SZ*4)
    self.assertEqual(p.estimates.lds, 2*N*N*N*4 + SZ*4 + (SZ*4 + 4*N*N)*4)

  def test_reduce(self):
    k = Kernel(self.ast_reduce)
    p = k.to_program()
    print(p.name, p.estimates.ops, p.estimates.mem, p.estimates.lds)
    self.assertEqual(p.estimates.ops, N*N)
    self.assertEqual(p.estimates.mem, N*N*4 + 4)

  def test_reduce_group(self):
    k = Kernel(self.ast_reduce)
    try:
      k.apply_opt(Opt(OptOps.GROUP, 0, 50))
    except KernelOptError:
      raise unittest.SkipTest("no locals")
    p = k.to_program()
    # NOTE: these are wrong, they don't respect the if statement
    print(p.name, p.estimates.ops, p.estimates.mem, p.estimates.lds)

if __name__ == '__main__':
  unittest.main(verbosity=2)