import unittest , pickle , types
import numpy as np
from tinygrad import Tensor , TinyJit , Variable , dtypes
from tinygrad . helpers import GlobalCounters , ContextVar , Context
from tinygrad . ops import PatternMatcher , UPat , UOp , Ops
class TestPickle(unittest.TestCase):
  """Pickle round trips for tinygrad objects: code objects, pattern matchers, tensors, uops, jits,
  context vars, schedules and renderers."""

  def test_pickle_code_object(self):
    # a lambda's code object survives a round trip and can be rebuilt into a callable
    y = lambda x: x*2  # noqa: E731
    code_str = pickle.dumps(y.__code__)
    fxn = types.FunctionType(pickle.loads(code_str), globals())
    self.assertEqual(fxn(2), 4)

  def test_pickle_pattern_matcher(self):
    # a PatternMatcher whose rule is a lambda rewrites to the same result after a round trip
    pm = PatternMatcher([(UPat.cvar('x'), lambda x: x*2)])
    sink = UOp.const(dtypes.int, 2)
    tt = pm.rewrite(sink)
    pm_str = pickle.dumps(pm)
    pm2 = pickle.loads(pm_str)
    self.assertEqual(pm2.rewrite(sink).key, tt.key)

  def test_pickle_main_pattern_matcher(self):
    # only checks that the devectorizer's `sym` matcher can be serialized
    from tinygrad.codegen.devectorizer import sym
    pickle.dumps(sym)

  def test_pickle_realized_tensor(self):
    # a realized tensor's data survives a round trip after the original is deleted
    print("** init")
    t = Tensor.rand(10, 10).realize()
    st = pickle.dumps(t)
    t_values = t.numpy()
    del t  # free buffers
    print("** post pickle")
    init = GlobalCounters.kernel_count
    t2:Tensor = pickle.loads(st)
    np.testing.assert_equal(t_values, t2.numpy())
    # expect at most one COPY kernel
    self.assertLessEqual(GlobalCounters.kernel_count - init, 1)

  def test_pickle_realized_tensor_alt(self):
    # same as above for a tensor moved to CPU: loading it must launch no kernels at all
    print("** init")
    t = Tensor.rand(10, 10).to("CPU").realize()
    st = pickle.dumps(t)
    t_values = t.numpy()
    del t  # free buffers
    print("** post pickle")
    init = GlobalCounters.kernel_count
    t2:Tensor = pickle.loads(st)
    np.testing.assert_equal(t_values, t2.numpy())
    self.assertEqual(GlobalCounters.kernel_count - init, 0)

  def test_pickle_realized_tensor_alt2(self):
    # the unpickled tensor must already be realized, even with the original tensor and uop deleted
    print("** init")
    t = Tensor.rand(10, 10).to("CPU").realize()
    tensor_uop = t.lazydata
    assert tensor_uop.is_realized, f"expected {tensor_uop} to be realized"
    t_values = t.numpy()
    # pickle
    st = pickle.dumps(t)
    # free buffers
    del t
    del tensor_uop
    print("** post pickle")
    t2:Tensor = pickle.loads(st)
    assert t2.lazydata.is_realized, f"expected {t2.lazydata} to be realized"
    np.testing.assert_equal(t_values, t2.numpy())

  # NOTE: currently Buffer exists on the uop, not tensor
  def test_pickle_buffer_uop(self):
    # a realized BUFFER uop pickles together with its data
    t = Tensor.arange(4).realize()
    a = t.lazydata
    assert a.op is Ops.BUFFER
    self.assertIsNotNone(buffer:=a.realized)
    s = pickle.dumps(a)
    # free buffers: `t` also references the uop through t.lazydata, so it must be dropped too,
    # otherwise the original buffer stays alive and the load is not exercised on freed data
    del t, a, buffer
    a2:UOp = pickle.loads(s)
    self.assertListEqual(a2.realized.as_buffer().cast("I").tolist(), [0, 1, 2, 3])

  def test_pickle_unrealized_tensor(self):
    # an unrealized tensor pickles, and both copies compute the same values
    t = Tensor.ones(10, 10)
    st = pickle.dumps(t)
    t2:Tensor = pickle.loads(st)
    np.testing.assert_equal(t.numpy(), t2.numpy())

  def test_pickle_variable(self):
    # a graph with a bound symbolic dim (i=10) survives pickling: each row of ones+ones sums to 2*10
    v = Variable("i", 1, 20).bind(10)
    t1 = Tensor.ones(10, v).contiguous()
    t2 = Tensor.ones(10, v).contiguous()
    ret = (t1+t2).sum(1)
    st = pickle.dumps(ret)
    del ret
    vt2 = pickle.loads(st)
    np.testing.assert_equal(vt2.numpy(), 20)

  def test_pickle_buffer_view(self):
    # a slice of a realized CPU tensor keeps a buffer exposing `base`, and its values, across a round trip
    t = Tensor.arange(10, device="CPU").contiguous().realize()
    vt = t[3:5].contiguous().realize()
    assert hasattr(vt.lazydata.buffer, 'base')
    ref_value = vt.tolist()
    st = pickle.dumps(vt)
    del t, vt
    vt2 = pickle.loads(st)
    assert hasattr(vt2.lazydata.buffer, 'base')
    assert ref_value == vt2.tolist()

  def test_pickle_numpy(self):
    # a tensor built from a numpy array round-trips
    t = Tensor(np.array([1, 2, 3, 4.]), dtype=dtypes.float32)
    st = pickle.dumps(t)
    t2:Tensor = pickle.loads(st)
    np.testing.assert_equal(t.numpy(), t2.numpy())

  def test_pickle_jit(self):
    # a TinyJit warmed up with three calls can be pickled and replayed on new inputs:
    # sum(ones(10,10)) + 1 + 1 == 102
    @TinyJit
    def add(a, b): return a.sum()+b+1
    for _ in range(3): add(Tensor.rand(10, 10), Tensor.rand(10, 10))
    st = pickle.dumps(add)
    del add

    add_fxn = pickle.loads(st)
    x = Tensor.ones(10, 10).contiguous().realize()
    y = Tensor.ones(10, 10).contiguous().realize()
    print("post jit")
    out = add_fxn(x, y)
    np.testing.assert_equal(out.numpy(), 102)

  def test_pickle_context_var(self):
    # a ContextVar pickled while overridden by Context carries the overridden value
    v = ContextVar("test_var", 0)
    with Context(test_var=1):
      vs = pickle.dumps(v)
    v2 = pickle.loads(vs)
    self.assertEqual(v2.value, 1)

  def test_pickle_schedule(self):
    # the last ScheduleItem's AST compares equal after a round trip
    a = Tensor([1, 2])
    out = a + 2
    sched = out.schedule()
    pk = pickle.dumps(sched)
    sched_pk = pickle.loads(pk)
    self.assertEqual(sched_pk[-1].ast, sched[-1].ast)

  def test_pickle_renderer(self):
    # the default device's renderer can be pickled and loaded back
    from tinygrad.device import Device
    pk = pickle.dumps(Device.default.renderer)
    pickle.loads(pk)
class TestPickleJIT(unittest.TestCase):
  """Inspect the pickle of a TinyJit that was warmed up on 1000x1000 inputs."""

  @classmethod
  def setUpClass(cls):
    @TinyJit
    def jit_add(a, b): return a.sum() + b + 1
    # warm the jit up with three calls before serializing it
    for _ in range(3):
      jit_add(Tensor.rand(1000, 1000), Tensor.rand(1000, 1000))
    cls.st = pickle.dumps(jit_add)
    del jit_add

  def test_inspect(self):
    from io import BytesIO

    class FakeClass:
      # stand-in for any global the pickle names; when its __init__ runs it prints where it came from
      def __init__(self, *args, **kwargs):
        print(self.module, self.name)

    class InspectUnpickler(pickle.Unpickler):
      def find_class(self, module, name):
        # resolve every global to a fresh FakeClass subclass tagged with the requested module/name
        attrs = {"name": name, "module": module}
        return type("SpecializedFakeClass", (FakeClass,), attrs)

    stream = BytesIO(self.st)
    InspectUnpickler(stream).load()

  @unittest.skip("we are still saving intermediate buffers")
  def test_size(self):
    # confirm no intermediate buffers are saved: the serialized jit should stay under this many bytes
    limit = 1_000_000
    self.assertLess(len(self.st), limit)
# run the test cases when this file is executed directly as a script
if __name__ == "__main__":
  unittest.main()