from __future__ import annotations
import importlib, inspect, functools, pathlib
from enum import Enum, auto
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT
from tinygrad.runtime.lib import RawBuffer
from tinygrad.shape.symbolic import Variable, sym_infer
from dataclasses import dataclass

# these are the llops your accelerator must implement, along with toCpu
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
# NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
# NOTE: rdna3 only has RECIP and not DIV. DIV and POW are on the chopping block
class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class BufferOps(Enum): MEM = auto(); CONST = auto() # noqa: E702
# Ops below this line are not allowed in ASTs
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702

Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps, BufferOps]
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]]

if TYPE_CHECKING:
  from tinygrad.shape.shapetracker import ShapeTracker
  from tinygrad.lazy import LazyBuffer

@dataclass(frozen=True)
class MemBuffer:
  idx: int
  dtype: DType
  st: ShapeTracker

@dataclass(frozen=True)
class ConstBuffer:
  val: Any
  dtype: DType
  st: ShapeTracker

@dataclass(frozen=True)
class ScheduleItem:
  ast: LazyOp
  out: LazyBuffer
  inputs: Tuple[LazyBuffer, ...]
  var_vals: Dict[Variable, int]

@dataclass(frozen=True)
class LazyOp:
  op: Op
  src: Tuple[Union[LazyOp, LazyBuffer], ...]
  arg: Any = None
  def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})"
  @property
  def buffers(self):
    buffers: Tuple[Union[LazyOp, LazyBuffer], ...] = ()
    try:  # NOTE: the linearizer's key function maps the buffers to ints, and LOCAL_BUFFER is used. we don't care about buffers in these cases
      for x in self.src: buffers += x.buffers
    except AttributeError: buffers = ()
    return buffers
  @property
  def key(self): return (self.op, tuple(map(lambda x: getattr(x, "key", x), self.src)), getattr(self.arg, "key", self.arg))
  def map_buffers(self, real_srcs: Mapping[Any, Union[LazyBuffer, LazyOp]]) -> LazyOp: return LazyOp(self.op, tuple([y.map_buffers(real_srcs) if y not in real_srcs else real_srcs[y] for y in self.src]), self.arg)
  def get_lazyops(self) -> List[LazyOp]: return [self] + [item for x in self.src for item in x.get_lazyops()]

  def replace_with_movement_ops(self:LazyOp, ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> 'LazyBuffer':
    assert self.op in BinaryOps or self.op in UnaryOps or self.op in TernaryOps
    srcs = [z.replace_with_movement_ops(ops) for z in self.src]
    return srcs[0].e(self.op, *srcs[1:], arg=self.arg)  # type: ignore

  @property
  def st(self): raise NotImplementedError
  @property
  def realized(self): raise NotImplementedError
  @property
  def children(self): raise NotImplementedError

  # movement ops
  def reshape(self, _): raise NotImplementedError
  def pad(self, _): raise NotImplementedError
  def expand(self, _): raise NotImplementedError
  def permute(self, _): raise NotImplementedError
  def shrink(self, _): raise NotImplementedError
  def stride(self, _): raise NotImplementedError
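
# For orientation, a rough sketch of a small AST (kept as a comment so nothing runs at import time;
# `st` stands for some ShapeTracker of shape (4,) and `dtypes` comes from tinygrad.helpers, both
# hypothetical here): the elementwise product of two buffers, then summed down to shape (1,).
#   mul = LazyOp(BinaryOps.MUL, (LazyOp(BufferOps.MEM, (), MemBuffer(1, dtypes.float32, st)),
#                                LazyOp(BufferOps.MEM, (), MemBuffer(2, dtypes.float32, st))))
#   ast = LazyOp(ReduceOps.SUM, (mul,), (1,))
#   ast.get_lazyops()   # -> [SUM, MUL, MEM, MEM], the tree flattened depth-first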
# **************** Device ****************

class _Device:
  def __init__(self) -> None: self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
  def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT
  @functools.lru_cache(maxsize=None)  # this class is a singleton, pylint: disable=method-cache-max-size-none
  def __getitem__(self, x:str) -> Union[Interpreted, Compiled]:
    x = x.split(":")[0].upper()
    return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0]
  @functools.cached_property
  def DEFAULT(self) -> str:
    device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, None)
    if device_from_env: return device_from_env
    for device in ["METAL", "CUDA", "GPU"]:
      try:
        if self[device]: return device
      except Exception: pass
    return "CPU"
Device = _Device()
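
# Hedged usage sketch (comment only; the exact device strings depend on which tinygrad/runtime/ops_*.py
# backends are present and importable on this machine):
#   Device.canonicalize("gpu:0")  # -> "GPU"   (name uppercased, ":0" suffix stripped)
#   Device.canonicalize(None)     # -> Device.DEFAULT
#   Device[Device.DEFAULT]        # -> the backend object named like "*Buffer" in its ops_*.py (an Interpreted or Compiled)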
# **************** for Interpreted Buffers ****************

class Interpreted:
  def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], from_underlying=None):
    self.buffer, self.fxn_for_op, self.from_underlying = buffer, fxn_for_op, from_underlying
    self.synchronize = lambda: None
    self.codegen = None
    self.method_cache: Dict[LazyOp, Callable] = {}

  def interpret_ast(self:Interpreted, ast:LazyOp) -> Callable:
    tglob: Dict[str, Any] = {}
    lines: List[str] = []
    f = self.fxn_for_op

    @functools.lru_cache(None)
    def gstr(x:Any, nm=None) -> str:
      ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
      tglob[ret] = x
      return ret

    @functools.lru_cache(None)
    def _interpret_ast(ast:LazyOp) -> str:
      if TernaryOps.MULACC in f and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
        ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)

      if MovementOps.AS_STRIDED in f and ast.op in BufferOps:
        tmp = f"{gstr(f[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(f[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])"
        for mop, arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(f[mop], mop)}({tmp}, {gstr(arg)})"
      else:
        inp = [_interpret_ast(src) for src in ast.src]
        tmp = f"{gstr(f[ast.op], ast.op)}({', '.join(inp + ([gstr(ast.arg)] if ast.arg else []))})"

      ret = f"a{len(lines)}"
      lines.append(f"  {ret} = {tmp}")
      return ret

    ret = _interpret_ast(ast)
    src = '\n'.join(['def run(inputs):'] + lines + [f"  return {gstr(self.from_underlying, 'from_underlying')}({ret})" if self.from_underlying else f"  return {ret}"])
    if DEBUG >= 4: print(functools.reduce(lambda x, y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
    exec(compile(src, "<ast>", "exec"), tglob)  # pylint: disable=exec-used
    return tglob['run']

  def exec_ast(self, ast:LazyOp, output=None, inputs=None, var_vals=None, **kwargs):
    if ast not in self.method_cache: self.method_cache[ast] = self.interpret_ast(ast)
    ret = self.method_cache[ast]([x.realized for x in inputs] if inputs else None)
    if output is not None and ret.dtype != output.dtype and UnaryOps.CAST in self.fxn_for_op:
      ret = self.from_underlying(self.fxn_for_op[UnaryOps.CAST](self.fxn_for_op[BufferOps.MEM](ret), (output.dtype, False)))  # do manual casting of ret if it does not match the required output dtype
    # TODO: is this used?
    if output is not None and output.output_buffer is not None:
      assert output.output_buffer.dtype == ret.dtype
      output.output_buffer._buf = ret._buf
      return output.output_buffer
    return ret
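
# Illustration only (comment, not executed): for a backend whose fxn_for_op includes MULACC and
# AS_STRIDED, interpret_ast on the MUL->SUM example above would exec() source roughly like the
# following (plus any MovementOps wrappers coming from each buffer's ShapeTracker):
#   def run(inputs):
#     a0 = BufferOps_MEM(inputs[0])
#     a1 = BufferOps_MEM(inputs[1])
#     a2 = TernaryOps_MULACC(a0, a1, m0000)   # m0000 is the reduce arg, captured into tglob via gstr
#     return from_underlying(a2)
# Backends without AS_STRIDED instead pass the MemBuffer arg straight to the BufferOps.MEM handler.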
@dataclass
class FlopCounter:
  shape: Tuple[int, ...]
  dtype: DType
  flops: int
  mem: Dict[int, int]
  @property
  def mem_estimate(self): return sum(self.mem.values()) + self.dtype.itemsize*prod(self.shape)
  def consume_flops(self):
    self.flops, ret = 0, self.flops
    return ret

InterpretedFlopCounter = Interpreted(FlopCounter, {
  BufferOps.MEM: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize*arg.st.size()}), BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {}),
  UnaryOps.CAST: lambda self, arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem),  # cast uses no flops
  **{op: lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op != UnaryOps.CAST},
  **{op: lambda self, y: FlopCounter(self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps},
  **{op: lambda self, new_shape: FlopCounter(new_shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps},
  TernaryOps.WHERE: lambda self, y, z: FlopCounter(self.shape, y.dtype, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})})

@functools.lru_cache(None)
def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.exec_ast(ast)
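
# Worked example (comment only, reusing the hypothetical MUL->SUM ast sketched above): the MUL over
# shape (4,) contributes 4 flops and the SUM reduce another 4, so get_lazyop_info(ast).flops == 8;
# mem_estimate adds the bytes of the two float32 input buffers plus the float32 output elements.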
# **************** for Compiled Buffers ****************

class ASTRunner:
  def __init__(self, name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
    if DEBUG >= 4: print(prg)
    self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}

  def build(self, compiler, runtime):
    self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg)
    self.clprg = runtime(self.name, self.lib)
    return self

  def exec(self, rawbufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False) -> Optional[float]:
    from tinygrad.jit import CacheCollector
    CacheCollector.add(self, rawbufs, var_vals if var_vals is not None else {})
    return self(rawbufs, var_vals, force_wait=force_wait)

  def launch_dims(self, var_vals):
    global_size = ([sym_infer(sz, var_vals) for sz in self.global_size] + [1]*(3-len(self.global_size))) if self.global_size is not None else self.global_size
    local_size = ([sym_infer(sz, var_vals) for sz in self.local_size] + [1]*(3-len(self.local_size))) if self.local_size is not None else self.local_size
    return global_size, local_size

  def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
    if var_vals is None: var_vals = {}
    global_size, local_size = self.launch_dims(var_vals)
    if global_size is not None and local_size is None:
      # TODO: this is copied from get_program
      from tinygrad.features.search import optimize_local_size
      local_size = self.local_size = optimize_local_size(self.clprg, global_size, rawbufs)
      global_size = self.global_size = [g//l if g % l == 0 else g/l for g, l in zip(global_size, local_size)]
    lra = self.runtime_args.copy()
    if global_size: lra['global_size'] = global_size
    if local_size and 'local_size' not in lra: lra['local_size'] = local_size
    if et := self.clprg(*rawbufs, *var_vals.values(), **lra, wait=force_wait or DEBUG >= 2): GlobalCounters.time_sum_s += et
    op_estimate = sym_infer(self.op_estimate, var_vals)
    mem_estimate = sym_infer(self.mem_estimate, var_vals)
    if DEBUG >= 2:
      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(37-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(global_size):18s} {str(local_size):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G  mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
            (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
    GlobalCounters.kernel_count += 1
    GlobalCounters.global_ops += op_estimate
    GlobalCounters.global_mem += mem_estimate
    return et
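
# Hedged lifecycle sketch (comment only; `compiler` and `runtime` are whatever the backend's Compiled
# instance was constructed with, e.g. a source->binary compiler and a binary->kernel wrapper):
#   prg = ASTRunner("kernel", src, global_size=[n, 1, 1]).build(compiler, runtime)
#   et = prg.exec(rawbufs, var_vals)  # registers with the JIT CacheCollector, then calls __call__
# __call__ resolves symbolic launch dims with sym_infer, pads them to rank 3, fills a missing
# local_size via optimize_local_size, and bumps GlobalCounters with the op/mem estimates.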
class Compiled:
  def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, compiler, runtime, synchronize=lambda: None):
    self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize = buffer, linearizer_opts, renderer, compiler, runtime, synchronize
    self.method_cache: Dict[LazyOp, ASTRunner] = {}

  def to_program(self, k):
    k.linearize()
    src, runtime_args = self.renderer(k.function_name, k.uops)
    return ASTRunner(k.function_name, src, k.global_size, k.local_size,
                     op_estimate=k.info.flops, mem_estimate=k.info.mem_estimate,
                     display_name=k.display_name, runtime_args=runtime_args).build(self.compiler, self.runtime)

  def exec_ast(self, ast:LazyOp, output, inputs, var_vals, **kwargs):
    # check if we can reuse the output buffer
    # if it's aliased, don't use it
    # NOTE: this is pretty wrong actually, who knows where else this buffer is used?
    output.realized = output.output_buffer
    if output.realized:
      for i, a in enumerate(inputs):
        # TODO: if this is contiguous it's fine
        if a.realized == output.realized:
          if any(not x.arg.st.contiguous for x in ast.get_lazyops() if x.op == BufferOps.MEM and x.arg.idx == i+1):
            output.realized = None
            break

    # we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
    if not output.realized:
      output.realized = self.buffer(prod((s if isinstance(s, int) else s.max for s in output.shape)), output.dtype, **kwargs)

    # all the rawbuffers
    rawbuffers = [output.realized] + [x.realized for x in inputs]

    # extract real vars used in ast
    from tinygrad.lazy import vars_from_ast
    ast_vars = vars_from_ast(ast)
    assert all(v.val is None for v in ast_vars), f"ast contains bound Variable {ast_vars}"

    # compilation time
    def get_program():
      from tinygrad.codegen.linearizer import Linearizer
      k = Linearizer(ast, self.linearizer_opts)
      assert k.info.dtype == output.dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {output.dtype}"
      if not NOOPT:
        if not (used_tensor_cores := k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
        if BEAM >= 1 and not vars_from_ast(ast):
          lins = [(("tc" if used_tensor_cores else "hc"), k)]
          # allocate a scratch buffer if output buffer is also input
          test_rawbuffers = [type(rawbuffers[0])(rawbuffers[0].size, rawbuffers[0].dtype), *rawbuffers[1:]] if rawbuffers[0] in rawbuffers[1:] else rawbuffers
          kb = Linearizer(ast, self.linearizer_opts)
          kb.required_optimizations()
          from tinygrad.features.search import beam_search, time_linearizer
          lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))))
          if used_tensor_cores:
            lins.append(("hc", Linearizer(ast, self.linearizer_opts)))
            lins[-1][1].hand_coded_optimizations()
          timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, disable_cache=True, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
          if DEBUG >= 1: print("  <  ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
          k = timed[0][1]
      else:
        k.required_optimizations()
      return self.to_program(k)

    if getenv("ENABLE_METHOD_CACHE", 1):
      if ast not in self.method_cache: self.method_cache[ast] = get_program()
      prg = self.method_cache[ast]
    else:
      prg = get_program()

    if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)
    prg.exec(rawbuffers, var_vals={k: var_vals[k] for k in ast_vars})
    return output.realized
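
# Summary of the knobs exec_ast reads (all present in the code above): NOOPT skips the hand-coded,
# tensor-core, and beam optimizations and applies only required_optimizations; TC toggles
# apply_tensor_cores; BEAM>=1 beam-searches kernel candidates and keeps the fastest timed one, with
# BEAM_ESTIMATE passed through to beam_search; ENABLE_METHOD_CACHE (default 1) memoizes the built
# ASTRunner per AST; and PRINT_PRG prints a kernel's source when its name matches.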