from typing import TypeVar, Generic, Callable, Union, cast, Optional, Any
import functools, collections
from tinygrad.tensor import Tensor
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition, unwrap
from tinygrad.device import Buffer, Compiled, Device
from tinygrad.dtype import DType
from tinygrad.ops import UOp, Variable, sym_infer, Ops
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates
from tinygrad.engine.memory import _internal_memory_planner
from tinygrad.nn.state import get_parameters
from dataclasses import dataclass
from weakref import WeakKeyDictionary

class GraphException(Exception): pass

def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], max_batch_size=0) -> list[ExecItem]:
  # Split JIT cache into batches for faster graph execution.
  # This allows the accelerator to run some batches while subsequent graphs are still being updated.
  graphed_jit_cache: list[ExecItem] = []
  current_batch: list[ExecItem] = []
  current_device: Optional[Compiled] = None

  def flush_batch():
    nonlocal current_batch, current_device, max_batch_size
    try:
      if current_device is None: raise GraphException("no device for graph")
      if len(current_batch) <= 1 and not getenv("GRAPH_ONE_KERNEL"): raise GraphException("only one kernel doesn't graph")

      graph_runner = current_device.graph(current_batch, input_rawbuffers, var_vals)
      # clear jit inputs to allow their memory to be freed/reused
      for (j, i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
      graphed_jit_cache.append(ExecItem(graph_runner, cast(list[Optional[Buffer]], input_rawbuffers)))
      max_batch_size *= 2
      if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
    except GraphException as e:
      graphed_jit_cache.extend(current_batch)
      if DEBUG >= 2: print(f"JIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_device}: {e}")
    current_batch = []
    current_device = None

  for ji in jit_cache:
    if isinstance(ji.prg, ViewOp): continue

    ji_graph_dev: Optional[Compiled] = None  # device on which the ji will be graphed. Not graphed if None.
    if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.dev
    elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
      ji_graph_dev = Device[ji.bufs[0].device]

    graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None
    can_be_graphed = ji_graph_dev and ji_graph_dev.graph
    can_share_graph = (ji_graph_dev == current_device or (isinstance(graph_class, type) and issubclass(graph_class, MultiGraphRunner)) and
                       type(ji_graph_dev) is type(current_device))
    can_extend_graph_batch = can_be_graphed and (max_batch_size == 0 or len(current_batch) < max_batch_size) and can_share_graph
    if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()

    if can_be_graphed: current_batch.append(ji)
    else: graphed_jit_cache.append(ji)

    current_device = ji_graph_dev

  if len(current_batch) > 0: flush_batch()
  return graphed_jit_cache

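# Illustrative sketch (not part of the library; kernel counts are hypothetical): with
# max_batch_size=0 every graphable kernel lands in a single batch, while a positive limit
# produces batches that double in size after each successful flush.
#   apply_graph_to_jit(jit_cache, bufs, var_vals, 0)    # one GraphRunner wrapping all graphable kernels
#   apply_graph_to_jit(jit_cache, bufs, var_vals, 32)   # GraphRunners of 32, 64, 128, ... kernels
# ViewOps are dropped, and items whose device has no graph support pass through untouched.
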
def get_input_replace(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer]) -> dict[tuple[int, int], int]:
  input_replace: dict[tuple[int, int], int] = {}
  for j, ji in enumerate(jit_cache):
    for i, a in enumerate(ji.bufs):
      if a in input_rawbuffers:
        input_replace[(j, i)] = input_rawbuffers.index(a)
  return input_replace

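# Illustrative sketch (hypothetical buffers): if jit_cache[1].bufs == [out, x, w] and
# input_rawbuffers == [x, y], this returns {(1, 1): 0}, i.e. before every run, ExecItem 1's
# buffer slot 1 must be patched with input buffer 0.
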
class GraphRunner(Runner):
  def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
    self.jit_cache = jit_cache  # NOTE: this is not used, but you have to keep these objects alive for the Graph
    self.input_replace: dict[tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
    self.var_vals_replace: dict[int, list[int]] = {}
    self.launch_dims_replace: dict[int, tuple[Optional[int], Optional[int]]] = {}
    self.launch_dims_base: dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {}

    def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)

    self.vars = sorted(var_vals.keys(), key=lambda v: v.expr)
    self.symbolic_dims = dedup([tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d := ji.prg.p.local_size) and is_sym_dim(d)] +
                               [tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d := ji.prg.p.global_size) and is_sym_dim(d)])
    def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None

    estimates = Estimates()
    for j, ji in enumerate(jit_cache):
      estimates += ji.prg.estimates
      if isinstance(ji.prg, CompiledRunner):
        if ji.prg.p.vars: self.var_vals_replace[j] = [self.vars.index(v) for v in ji.prg.p.vars]

        global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
        if global_dim_idx is not None or local_dim_idx is not None:
          self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
          assert ji.prg.p.global_size is not None and ji.prg.p.local_size is not None
          self.launch_dims_base[j] = (tuple(ji.prg.p.global_size), tuple(ji.prg.p.local_size))

    # used in MultiGraphRunner. the ints are id() of _bufs
    self.w_dependency_map: dict[int, Any] = {}
    self.r_dependency_map: dict[int, list[Any]] = collections.defaultdict(list)

    super().__init__(colored(f"<batched {len(jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0], estimates.simplify())

  def updated_vars(self, var_vals: dict[Variable, int]):
    vals = [var_vals[v] for v in self.vars]
    for j, vidxs in self.var_vals_replace.items():
      for i, v in enumerate(vidxs): yield j, i, vals[v]

  def updated_launch_dims(self, var_vals: dict[Variable, int]):
    dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
    for j, (gl, lc) in self.launch_dims_replace.items():
      yield j, (dims[gl] if gl is not None else self.launch_dims_base[j][0]), (dims[lc] if lc is not None else self.launch_dims_base[j][1])

  def _access_resources(self, rawbufs: list[Buffer], write: list[int], new_dependency: Any):
    # To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
    # whether for write or read operations. A resource can be accessed by either a single writer or multiple readers.
    wait_nodes = []

    for i, rawbuf in enumerate(rawbufs):
      if id(rawbuf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf.base._buf)])
      if i in write:
        if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf)))

    for i, rawbuf in enumerate(rawbufs):
      if i in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
      else: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency)

    return list({id(x): x for x in wait_nodes}.values())

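# Illustrative sketch of the dependency bookkeeping (hypothetical graph nodes a/b/c): a buffer
# may be held by one writer or many readers, so a new writer waits on all of them.
#   runner._access_resources([buf], write=[], new_dependency=a)   # -> []       (a reads buf)
#   runner._access_resources([buf], write=[], new_dependency=b)   # -> []       (b also reads buf)
#   runner._access_resources([buf], write=[0], new_dependency=c)  # -> [a, b]   (c must wait for both readers)
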
# a marker for your graph supporting multiple devices of the same type
class MultiGraphRunner(GraphRunner): pass

def update_depends(depends: set[Buffer|None], jit_cache: list[ExecItem]):
  for ei in jit_cache:
    if any(b in depends for b in ei.bufs):
      if isinstance(ei.prg, CompiledRunner):
        depends.update(cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins)
      if isinstance(ei.prg, (BufferCopy, BufferXfer)):
        depends.add(cast(Buffer, ei.bufs[0]))

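# Illustrative sketch (hypothetical ExecItems): seeding `depends` with a buffer marks everything
# transitively written from it.
#   depends: set[Buffer|None] = {x}
#   update_depends(depends, [k1, cp])  # k1 reads x and writes y; cp copies y into z
#   # depends is now {x, y, z}
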
ReturnType = TypeVar('ReturnType')
@dataclass
class CapturedJit(Generic[ReturnType]):
  ret: Any  # includes the Tensors or any other returned object
  jit_cache: list[ExecItem]
  input_replace: dict[tuple[int, int], int]
  extra_view_inputs: list[tuple[int, int, str, int, DType]]
  expected_names: list[Union[int, str]]
  expected_st_vars_dtype_device: list[tuple[ShapeTracker, tuple[Variable, ...], DType, str]]

  def __reduce__(self):
    # TODO: free_intermediates here?
    self.optimize_weights()
    return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
                            self.expected_names, self.expected_st_vars_dtype_device)

  def __post_init__(self):
    self._jit_cache: list[ExecItem] = self.jit_cache
    self._input_replace: dict[tuple[int, int], int] = self.input_replace
    self._first_run = True
    self._clear_inputs()

  def _clear_inputs(self):
    for (j, i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None

  def free_intermediates(self):
    depends: set[Buffer|None] = set([None])
    update_depends(depends, self.jit_cache)
    for b in depends:
      if b is not None:
        b.deallocate()
        if b._base is not None and b._base.allocated_views == 0: b._base.deallocate()
    self.__post_init__()  # reset the graph state

  def optimize_weights(self):
    blacklist = [t.lazydata.buffer for t in get_parameters(self.ret)]
    asgn = _internal_memory_planner([[b for item in self.jit_cache for b in item.bufs if b is not None and b not in blacklist]], ignore_checks=True)
    self.jit_cache = [ExecItem(item.prg, [asgn.get(b, b) if b is not None else None for b in item.bufs]) for item in self.jit_cache]
    for old, new in asgn.items():
      if old.is_allocated(): new.ensure_allocated().copyin(old.as_buffer())
    self.__post_init__()

  # jit exec
  def __call__(self, input_buffers: list[Buffer], var_vals: dict[Variable, int]) -> ReturnType:
    # assign inputs
    for idx, offset, device, size, dtype in self.extra_view_inputs:
      input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
    for (j, i), input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]

    # Condense the items into a graph executor.
    if self._first_run:
      # allocate intermediates if freed
      for ji in self.jit_cache:
        for b in ji.bufs:
          if b is not None: b.ensure_allocated()
      # create graph if needed
      if JIT < 2:
        self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, max_batch_size=getenv("JIT_BATCH_SIZE", 32))
        self._input_replace = get_input_replace(self._jit_cache, input_buffers)
      self._first_run = False

    if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
    for ei in self._jit_cache: ei.run(var_vals, jit=True)
    self._clear_inputs()
    return self.ret

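# Illustrative sketch of the replay path (hypothetical cj/input_buffers/var_vals): the first call
# re-allocates any freed intermediates and, if JIT < 2, condenses the cache into GraphRunners;
# every call after that only patches the input slots, runs the ExecItems, and clears the slots.
#   out = cj(input_buffers, var_vals)  # first call: (re)allocate, build graphs, run
#   out = cj(input_buffers, var_vals)  # later calls: patch inputs and replay
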
def _prepare_jit_inputs(args, kwargs):
  input_tensors: list[tuple[int|str, Tensor]] = [(name, t) for name, t in list(enumerate(args)) + sorted(kwargs.items()) if t.__class__ is Tensor]
  names, tensors = [name for name, _ in input_tensors], [t for _, t in input_tensors]
  if len(unrealized_tensors := [x for x in tensors if not x.lazydata.is_realized]): Tensor.realize(*unrealized_tensors)
  # TODO: should we be unpacking multi here?
  lbs: list[UOp] = flatten([t.lazydata.src if t.lazydata.op is Ops.MULTI else [t.lazydata] for t in tensors])
  input_buffers: list[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
  assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
  st_varval_dtype_device = [(*unwrap(lb.st).unbind(), lb.dtype, lb.device) for lb in lbs]
  var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
  st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varval_dtype_device]
  return input_buffers, var_vals, names, st_vars_dtype_device

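# Illustrative sketch (hypothetical jitted f(x, y=t) with two Tensor arguments): names would be
# [0, 'y'], input_buffers their realized Buffers (a MULTI tensor contributes one buffer per shard),
# var_vals the bound symbolic values, and st_vars_dtype_device the (ShapeTracker, vars, dtype,
# device) signature that later calls are checked against.
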
class TinyJit(Generic[ReturnType]):
  def __init__(self, fxn: Optional[Callable[..., ReturnType]], captured: Optional[CapturedJit] = None, prune=False):
    assert fxn or captured, "need either a function or a CapturedJit"
    self.fxn = fxn
    self.captured: Optional[CapturedJit] = captured
    self.cnt: int = 2 if self.fxn is None else 0
    self.prune = prune

  def add_buffer(self, b: Buffer) -> Buffer:
    if found := self._buffer_replace.get(b, None): return found
    if b.is_allocated() or b.lb_refcount > 0: return b
    if b._base is not None:
      self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.add_buffer(b._base), offset=b.offset)
    else:
      self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
    return ret

  def add(self, ei: ExecItem):
    self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None]))

  def reset(self):
    assert self.fxn is not None, "can't reset without function"
    self.cnt = 0
    self.captured = None

  def __reduce__(self):
    assert self.captured is not None, "can't pickle an uncaptured JIT"
    return self.__class__, (None, self.captured)

  # keep legacy code working
  @property
  def jit_cache(self) -> list[ExecItem]: return self.captured._jit_cache if self.captured is not None else []
  @property
  def input_replace(self) -> dict[tuple[int, int], int]: return self.captured._input_replace if self.captured is not None else {}

  def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)  # add support for instance methods

  def __call__(self, *args, **kwargs) -> ReturnType:
    input_buffers, var_vals, names, st_vars_dtype_device = _prepare_jit_inputs(args, kwargs)
    if not JIT or self.cnt == 0:
      # jit ignore
      assert self.fxn is not None
      with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
        ret = self.fxn(*args, **kwargs)
        if len(params := get_parameters(ret)): Tensor.realize(params[0], *params[1:])
    elif self.cnt == 1:
      # jit capture
      assert self.fxn is not None
      if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
      self._jit_cache: list[ExecItem] = []
      self._buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
      # TODO: should we always disable the memory planner here? it must be off for prune
      with Context(BEAM=getenv("JITBEAM", BEAM.value), NO_MEMORY_PLANNER=int(self.prune)):
        capturing.append(self)
        try:
          ret = self.fxn(*args, **kwargs)
          if len(params := get_parameters(ret)): Tensor.realize(params[0], *params[1:])
        except Exception as e: raise e
        finally: capturing.clear()
      jit_cache = self._jit_cache
      del self._buffer_replace, self._jit_cache
      assert len(jit_cache), "didn't JIT anything!"
      if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_buffers)} inputs")

      # track inputs that are views of buffers
      # TODO: eventually expected_buffers should live in ExecItem
      extra_view_inputs: list[tuple[int, int, str, int, DType]] = []
      for item in jit_cache:
        for b in item.bufs:
          if b is not None and b._base is not None and b._base in input_buffers:
            input_buffers.append(b)
            extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype))

      # prune independent kernels (optional)
      if self.prune:
        depends = set(input_buffers)
        update_depends(depends, jit_cache)
        pruned, onetime = partition(jit_cache,
                                    lambda ei: not isinstance(ei.prg, CompiledRunner) or any(ei.bufs[out] in depends for out in ei.prg.p.outs))
        if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
        # run the onetime kernels here
        for ei in onetime:
          for b in ei.bufs: cast(Buffer, b).ensure_allocated()
          ei.run(var_vals, jit=True)
        jit_cache = pruned

      # memory planning (optional)
      # Exclude buffers involved in transfer ops to preserve parallelism.
      noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, BufferXfer) for b in ji.bufs}
      assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
      jit_cache = [ExecItem(item.prg, [assigned.get(b, b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache]

      input_replace = get_input_replace(jit_cache, input_buffers)
      if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")

      # set this for next run
      self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
    elif self.cnt >= 2:
      # jit exec
      assert self.captured is not None
      assert self.captured.expected_names == names, f"args mismatch in JIT: {self.captured.expected_names=} != {names}"
      assert self.captured.expected_st_vars_dtype_device == st_vars_dtype_device, \
        f"args mismatch in JIT: {self.captured.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}"
      ret = self.captured(input_buffers, var_vals)

    self.cnt += 1
    return ret

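# Usage sketch (hypothetical model/step; input shapes and dtypes must stay fixed across calls):
# call 0 runs the function eagerly, call 1 captures the kernels into a CapturedJit, and calls
# from 2 on replay the captured cache.
#   @TinyJit
#   def step(x: Tensor) -> Tensor:
#     return model(x).realize()
#   for _ in range(4): out = step(Tensor.randn(1, 16))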