from __future__ import annotations
from typing import Any, Optional, Union, Callable, cast, TYPE_CHECKING, Type, get_args
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref
from enum import auto, IntEnum, Enum
from dataclasses import dataclass, field
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, _METADATA, flatten
from tinygrad.helpers import PICKLE_BUFFERS, dedup, cdiv, cmod
if TYPE_CHECKING:
  from tinygrad.shape.shapetracker import ShapeTracker
  from tinygrad.device import Buffer
# wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
class FastEnum(IntEnum):
  def __str__(self): return Enum.__str__(self)
  @staticmethod
  def _generate_next_value_(_, __, ___, last_values): return 1 + max([0, *last_values, *[max(c) for c in FastEnum.__subclasses__()]])
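# Illustrative sketch of the uniqueness guarantee (assumes these are the first FastEnum subclasses defined):
#   class A(FastEnum): X = auto()   # X == 1
#   class B(FastEnum): Y = auto()   # Y == 2, never reusing A's value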
class SimpleMathTrait:
  # required to implement
  def alu(self:T, arg:Ops, *src) -> T: raise NotImplementedError
  def const_like(self:T, b:ConstLike) -> T: raise NotImplementedError
  # great functions you get!
  def ufix(self, x): return self.const_like(x) if not isinstance(x, MathTrait) else x
  def _binop(self, op, x, reverse): return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))
  def logical_not(self): return self.ne(True)
  def neg(self):
    if (dtype:=getattr(self, 'dtype', None)) is None: raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")
    return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1)
  def add(self, x, reverse=False): return self._binop(Ops.ADD, x, reverse)
  def mul(self, x, reverse=False): return self._binop(Ops.MUL, x, reverse)
  def bitwise_and(self, x, reverse=False): return self._binop(Ops.AND, x, reverse)
  def bitwise_or(self, x, reverse=False): return self._binop(Ops.OR, x, reverse)
  def bitwise_xor(self, x, reverse=False): return self._binop(Ops.XOR, x, reverse)
  def idiv(self, x, reverse=False): return self._binop(Ops.IDIV, x, reverse)
  def mod(self, x, reverse=False): return self._binop(Ops.MOD, x, reverse)
  def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
  def div(self, x, reverse=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP))
  def __neg__(self): return self.neg()
  def __add__(self, x): return self.add(x)
  def __sub__(self, x): return self.sub(x)
  def __mul__(self, x): return self.mul(x)
  def __truediv__(self, x): return self.div(x)
  def __floordiv__(self, x): return self.idiv(x)  # TODO: idiv is trunc div, not floordiv
  def __mod__(self, x): return self.mod(x)
  def __and__(self, x): return self.bitwise_and(x)
  def __or__(self, x): return self.bitwise_or(x)
  def __xor__(self, x): return self.bitwise_xor(x)
  def __radd__(self, x): return self.add(x, True)
  def __rsub__(self, x): return self.sub(x, True)
  def __rmul__(self, x): return self.mul(x, True)
  def __rtruediv__(self, x): return self.div(x, True)
  def __rfloordiv__(self, x): return self.idiv(x, True)
  def __rand__(self, x): return self.bitwise_and(x, True)
  def __ror__(self, x): return self.bitwise_or(x, True)
  def __rxor__(self, x): return self.bitwise_xor(x, True)
  def __rmod__(self, x): return self.mod(x, True)
  def __lt__(self, x): return self.alu(Ops.CMPLT, self.ufix(x))
  def __gt__(self, x): return self.ufix(x).alu(Ops.CMPLT, self)
  def __ge__(self, x): return (self < x).logical_not()
  def __le__(self, x): return (self > x).logical_not()
  def ne(self, x): return self.alu(Ops.CMPNE, self.ufix(x))
  def eq(self, x): return self.ne(x).logical_not()
  def __ne__(self, x): return self.ne(x)
  # NOTE: __eq__ isn't overridden, and means the same thing as is by default
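# Illustrative sketch: on a UOp, this operator sugar builds graph nodes rather than computing values.
#   x = UOp.variable("x", 0, 10)
#   y = x + 1        # UOp(Ops.ADD, dtypes.int, (x, x.const_like(1)))
#   (x < 5).dtype    # comparisons always produce dtypes.bool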
class MathTrait(SimpleMathTrait):
  # TODO: move to Tensor when new backward is done
  def lshift(self, x, reverse=False): return self._binop(Ops.SHL, x, reverse)
  def rshift(self, x, reverse=False): return self._binop(Ops.SHR, x, reverse)
  def __lshift__(self, x): return self.lshift(x)
  def __rshift__(self, x): return self.rshift(x)
  def __rlshift__(self, x): return self.lshift(x, True)
  def __rrshift__(self, x): return self.rshift(x, True)
  def maximum(self, x): return self.alu(Ops.MAX, self.ufix(x))
  def minimum(self, x): return -(-self).maximum(-x)
  def where(self, x, y): return self.alu(Ops.WHERE, x, x.ufix(y))
  def threefry(self, seed): return self.alu(Ops.THREEFRY, seed)
  def reciprocal(self): return self.alu(Ops.RECIP)
  def sqrt(self): return self.alu(Ops.SQRT)
  def sin(self): return self.alu(Ops.SIN)
  def log2(self): return self.alu(Ops.LOG2)
  def exp2(self): return self.alu(Ops.EXP2)
  def pow(self, x): return self.alu(Ops.POW, self.ufix(x))
# the order of these Ops controls the order of the toposort
class Ops(FastEnum):
  # uops that aren't rendered
  NAME = auto(); SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); KERNEL = auto(); UNIQUE = auto() # noqa: E702
  # TODO: empty continues to exist because of tensor
  EMPTY = auto()
  # MetaOps
  COPY = auto(); BUFFER_VIEW = auto() # noqa: E702
  # blocks in linearizer
  BLOCK = auto(); BLOCKSTART = auto(); BLOCKFORK = auto(); BLOCKEND = auto() # noqa: E702
  # movement ops!
  RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
  # misc ops
  UNROLL = auto(); CONTRACT = auto() # noqa: E702
  VIEW = auto(); DEFINE_GLOBAL = auto(); BUFFER = auto() # noqa: E702
  DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
  VALID = auto(); SPECIAL = auto(); NOOP = auto() # noqa: E702
  # reduce
  REDUCE_AXIS = auto(); REDUCE = auto() # noqa: E702
  # helper ops
  GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702
  # UnaryOps
  CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
  # load/store before math
  LOAD = auto(); STORE = auto() # noqa: E702
  # early INDEX
  INDEX = auto()
  # math ops
  WMMA = auto()
  # BinaryOps
  MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); ADD = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto() # noqa: E702
  XOR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702
  # TernaryOps
  WHERE = auto(); MULACC = auto() # noqa: E702
  # assignment ops
  ASSIGN = auto()
  BIND = auto()
  # control flow ops
  BARRIER = auto(); RANGE = auto(); IF = auto(); ENDRANGE = auto(); ENDIF = auto() # noqa: E702
  # consts last!
  VCONST = auto(); CONST = auto() # noqa: E702
  # device
  DEVICE = auto()
  MULTI = auto()
  # CUSTOMI is inline
  CUSTOM = auto(); CUSTOMI = auto() # noqa: E702
  IGNORE = auto()
class GroupOp:
  Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
  Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY,
            Ops.SUB, Ops.FDIV, Ops.POW}
  Ternary = {Ops.WHERE, Ops.MULACC}
  ALU = set.union(Unary, Binary, Ternary)
  Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
  Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}
  Buffer = {Ops.LOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
  Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART}
  # BinaryOps that can be flipped
  Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.XOR, Ops.AND, Ops.OR}
  # BinaryOps where f(f(a,b),c) = f(a,f(b,c))
  Associative = {Ops.ADD, Ops.MUL, Ops.AND, Ops.OR, Ops.MAX}
  # BinaryOps that satisfy f(x,x)=x, see https://en.wikipedia.org/wiki/Idempotence
  Idempotent = {Ops.OR, Ops.AND, Ops.MAX}
  # do not preserve f(0) = 0
  UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}
  Meta = {Ops.COPY, Ops.BUFFER_VIEW}
  All = set(Ops)
# some BUFFER ops can be processed with only a view
view_supported_devices = {"LLVM", "CPU", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}

# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
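# Illustrative sketch: the identity element is the seed value for a reduce of that op.
#   assert identity_element(Ops.ADD, dtypes.int32) == 0
#   assert identity_element(Ops.MAX, dtypes.int8) == -128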
def can_pad(u:UOp, edges:dict[UOp, None], cache:dict[UOp, None]) -> bool:
  if u.op in GroupOp.UnsafePad: return False
  if u in edges or u in cache: return True
  cache[u] = None
  return all(can_pad(x.base, edges, cache) for x in u.src)
# With True as the default, this matches the old symbolic behavior
def resolve(x:UOp|bool, default:bool=True):
  if isinstance(x, bool): return x
  assert x.dtype == dtypes.bool, "UOp in resolve must be bool"
  # NOTE: generating the text for the exception is expensive, so we do this
  return bool(sx.vmin) if (sx:=x.simplify()).vmin == sx.vmax else default
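# Illustrative sketch: resolve only answers definitively when the simplified range is a single value.
#   i = UOp.variable("i", 1, 10)
#   resolve(i < 11)  # True: provably true over the whole [1, 10] range
#   resolve(i < 5)   # True, but only because the default is True; it cannot be decided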
# smax/smin are replacements for max/min that preserve symbolic
def _suop(lst, uop_fxn, python_fxn):
  uops, nums = partition(lst, lambda x: isinstance(x, UOp))
  return ssimplify(functools.reduce(uop_fxn, uops + ([python_fxn(nums)] if nums else [])))
def smax(*lst): return _suop(argfix(*lst), UOp.maximum, max)
def smin(*lst): return _suop(argfix(*lst), UOp.minimum, min)

def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop
def sym_infer(uop:Union[UOp, int], var_vals:dict[UOp, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
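# Illustrative sketch: plain ints fold immediately, UOps stay symbolic until inferred.
#   smax(2, 5)                                   # 5, plain python
#   m = smax(UOp.variable("i", 0, 10), 5)        # stays symbolic: max(i, 5)
#   sym_infer(m, {UOp.variable("i", 0, 10): 7})  # 7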
# used for UOp and UPat
def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0) -> str:
  def dfs(x:Any, cache:dict):
    for s in srcfn(x) or []:
      cache.setdefault(s, [len(cache), 0, False])[1] += 1
      if cache[s][1] == 1: dfs(s, cache)
  if cache is None: dfs(x, cache:={})
  if (cx:=cache.setdefault(x, [0, 0, False]))[2]: return f"{' '*d} x{cx[0]}"
  cx[2], srcs = True, ('None' if srcfn(x) is None else ''.join(f'\n{pretty_print(s, rep, srcfn, cache, d+2)},' for s in srcfn(x)))
  return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
class UOpMetaClass(type):
  ucache:dict[tuple, weakref.ReferenceType[UOp]] = {}
  def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, _buffer:Buffer|None=None):
    if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None and (ret:=wret()) is not None: return ret
    UOpMetaClass.ucache[key] = ref = weakref.ref(created:=super().__call__(*key))
    for s in src: s.children.add(ref)
    # NOTE: this will soon be set by Tensor once we remove function.py
    if (metadata:=_METADATA.get()) is not None: all_metadata[created] = metadata
    # NOTE: this value is set by pickle when pickling a realized tensor
    if _buffer is not None:
      assert op is Ops.BUFFER, f"trying to set Buffer {_buffer} for {op}"
      buffers[created] = _buffer
    return created

# some uops map to other stuff
buffers:weakref.WeakKeyDictionary[UOp, Buffer] = weakref.WeakKeyDictionary() # this maps BUFFER uops to their device Buffers
all_metadata:weakref.WeakKeyDictionary[UOp, Metadata] = weakref.WeakKeyDictionary()
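# Illustrative sketch: identical (op, dtype, src, arg) tuples are hash-consed into one object,
# so graph identity can be checked with `is`.
#   UOp(Ops.CONST, dtypes.int, arg=1) is UOp(Ops.CONST, dtypes.int, arg=1)  # True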
def _toposort(u:UOp, cache:set[UOp]):
  if u in cache: return {}
  nodes: dict[UOp, None] = {}
  # NOTE: this is a lot faster than the comprehension in parents
  for parent in u.src: nodes.update(_toposort(parent, cache))
  nodes[u] = None
  cache.add(u)
  return nodes
# NOTE: this should be frozen, but frozen is slower
@dataclass(eq=False, slots=True)
class UOp(MathTrait, metaclass=UOpMetaClass):
  op:Ops
  dtype:DType = dtypes.void
  src:tuple[UOp, ...] = tuple()
  arg:Any = None
  children:set[weakref.ref[UOp]] = field(default_factory=set)
  def __del__(self):
    if self.op is Ops.BUFFER and (buffer:=buffers.get(self)) is not None: buffer.ref(-1)
    if (ref:=UOpMetaClass.ucache.get(k:=(self.op, self.dtype, self.src, self.arg))) is not None:
      for s in self.src: s.children.discard(ref)
      del UOpMetaClass.ucache[k]
  def __reduce__(self):
    args = [self.op, self.dtype, self.src, self.arg]
    if self.op is Ops.BUFFER and self.realized is not None and PICKLE_BUFFERS: args.append(self.realized)
    return UOp, tuple(args)
  def replace(self, **kwargs) -> UOp:
    new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src), kwargs.pop("arg", self.arg))
    assert len(kwargs) == 0, f"unused kwargs in replace {list(kwargs)}"
    if (self.op, self.dtype, self.src, self.arg) == new_args: return self
    return UOp(*new_args)
  @functools.cached_property
  def key(self) -> bytes:
    return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
  def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
  def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else repr(self.arg)
  @property
  def toposort(self) -> dict[UOp, None]:
    return _toposort(self, cache=set())
  # returns map of UOps to their children in the graph rooted by self
  def get_children_map(self) -> dict[UOp, dict[UOp, None]]:
    ret: dict[UOp, dict[UOp, None]] = {}
    for u in self.toposort:
      for s in u.src: ret.setdefault(s, {})[u] = None
    return ret
  @functools.cached_property
  def tuplize(self:UOp) -> tuple[int, Any, Optional[DType], tuple]: return (self.op.value, self.arg, self.dtype, tuple(x.tuplize for x in self.src))
  # *** uop shape stuff ***

  @functools.cached_property
  def st(self) -> ShapeTracker|None:
    from tinygrad.shape.shapetracker import ShapeTracker
    if self.op is Ops.MULTI:
      return ShapeTracker.from_shape(
        tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a, s in enumerate(self.real_lbs[0].shape)))
    if self.op in {Ops.BUFFER, Ops.BUFFER_VIEW}: return ShapeTracker.from_shape((self.size,))
    if self.op is Ops.KERNEL: return ShapeTracker.from_shape((self.arg.ast.size,))
    # these ops define a ShapeTracker from the arg
    if self.op is Ops.VIEW: return self.arg
    if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
    # buffer ops return the ShapeTracker from sources
    if self.op in GroupOp.Buffer: return vsrc[0] if len(vsrc:=[x.st for x in self.src if x.op is Ops.VIEW]) != 0 else None
    if not (src_sts:=[x.st for x in self.src if x.st is not None]): return None
    assert all_same([x.shape for x in src_sts]), f"UOp sources must have the same shape {self} {[x.shape for x in src_sts]}"
    if self.op is Ops.BITCAST:
      shape = src_sts[0].shape
      if self.dtype.itemsize != (input_sz:=self.src[0].dtype.itemsize): shape = shape[:-1] + ((shape[-1]*input_sz) // self.dtype.itemsize,)
    # only reduce ops are allowed to change shape, everything else derives shape from sources
    elif self.op in {Ops.REDUCE_AXIS, Ops.WMMA}: shape = src_sts[0].reduce(self.axis_arg)
    else: shape = src_sts[0].shape
    return ShapeTracker.from_shape(shape)
  @functools.cached_property
  def full_shape(self) -> tuple[sint, ...]:
    if self.op is Ops.VIEW: return self.shape
    # TODO: this should check if st is None, it cannot because local reduce has implicit movement ops
    return tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL}
      # TODO: this exists because wmma creates consts without ShapeTracker in the AST, there's probably a way to fix this
      and not (x.op is Ops.CONST and x.st is None)]))
  @property
  def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
  @property
  def size(self) -> int: return self.arg[0] if self.op is Ops.BUFFER_VIEW else self.arg if self.op is Ops.BUFFER else unwrap(self.st).size
  # *** uop evaluation ***

  def simplify(self):
    # late import!
    from tinygrad.codegen.symbolic import symbolic
    with Context(TRACK_MATCH_STATS=0):
      return graph_rewrite(self, symbolic)
  def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
  def _eval(self, dtype, expected_type:Type[T]) -> T:
    assert self.dtype in dtype, f"eval with wrong dtype {self}"
    vmin, vmax = (simple_self:=self.simplify())._min_max
    if vmin != vmax: raise ValueError(f"eval failed to be a single number, range is {vmin} to {vmax} in {simple_self.render()}")
    assert isinstance(vmin, expected_type), f"vmin is wrong dtype {type(vmin)} != {expected_type}"
    return vmin
  def __bool__(self): return self._eval((dtypes.bool,), bool)
  def __int__(self): return self._eval(dtypes.ints, int)
  def __float__(self): return self._eval(dtypes.floats, float)
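  # Illustrative sketch: a UOp converts to a python number only if it simplifies to a single value.
  #   int(UOp.const(dtypes.int, 3) + 4)   # 7
  #   int(UOp.variable("i", 0, 10))       # raises ValueError: range is 0 to 10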
  def substitute(self, dvars:dict[UOp, UOp]):
    with Context(TRACK_MATCH_STATS=0):
      return graph_rewrite(self, _substitute, dvars, bottom_up=True)
  # *** uop syntactic sugar ***

  @property
  def st_arg(self) -> ShapeTracker:
    assert self.op in GroupOp.Buffer, f"st_arg called on {self.op}"
    return unwrap(self.st)
  @property
  def axis_arg(self) -> tuple[int, ...]:
    assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}"
    ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
    assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}"
    return ret
  def sink(self, *srcs:UOp, **kwargs): return UOp(Ops.SINK, dtypes.void, (self,)+srcs, **kwargs)
  def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
  def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self, idx, valid) if valid is not None else (self, idx))
  def const_like(self, b:ConstLike):
    # constants can optionally have a DEVICE source
    if self._device is None: return UOp.const(self.dtype, b)
    if isinstance(self.device, tuple): return UOp.multi(*[UOp.metaop(Ops.CONST, self.shape, self.dtype, d, b) for d in self.device], axis=None)
    return UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b)
  def broadcast(self, count:int):
    assert self.dtype.count == 1
    if count == 1: return self
    return UOp(Ops.VECTORIZE, self.dtype.vec(count), (self,)*count)
  def cast(self, dtype:DType): return UOp(Ops.CAST, dtype, (self,))
  def cast_vec(self, dtype:DType): return UOp(Ops.CAST, dtype.vec(self.dtype.count), (self,))
  def bitcast(self, dtype:DType): return UOp(Ops.BITCAST, dtype, (self,))
  def gep(self, i:Union[tuple[int, ...], int]):
    if isinstance(i, int):
      # NOTE: these are just shortcuts to not have to create and fold later
      if self.op is Ops.VECTORIZE: return self.src[i]
      if self.op is Ops.VCONST: return UOp.const(self.dtype.scalar(), self.arg[i])
      if self.op is Ops.CONST: return UOp.const(self.dtype.scalar(), self.arg)
      i = (i,)
    if (self.dtype.vcount == len(i) and i == tuple(range(len(i)))) or self.dtype == dtypes.void: return self
    return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
  def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, src=(self,)+src, **kwargs)
  def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
  def alu(self, arg, *src:UOp):
    out_dtype = (self, *src)[-1].dtype
    if arg in {Ops.CMPLT, Ops.CMPNE}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
    return UOp(arg, out_dtype, (self,)+src)
  @staticmethod
  def const(dtype:DType, b:ConstLike):
    if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
    if isinstance(b, tuple) and all_same(b): b = b[0]  # doesn't have to be a VCONST if they are all the same
    return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype))
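  # Illustrative sketch of const creation:
  #   UOp.const(dtypes.float, 1.0)                 # CONST with arg 1.0
  #   UOp.const(dtypes.int.vec(4), (2, 2, 2, 2))   # folds to a single CONST since all lanes match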
  def valid(self, st:ShapeTracker):
    assert self.op in {Ops.CONST, Ops.DEFINE_VAR}, f"can only create VALID from a constant, got {self.op}"
    from tinygrad.shape.shapetracker import ShapeTracker
    # NOTE: only VALID has a masked ShapeTracker, the CONST operands are unmasked
    unmasked_st = ShapeTracker.from_shape(()).reshape((1,)*len(st.shape)).expand(st.shape).to_uop()
    return UOp(Ops.VALID, dtypes.bool, (st.to_uop(),)).where(self.replace(src=(unmasked_st,)), UOp.const(self.dtype, 0).replace(src=(unmasked_st,)))
  @staticmethod
  def range(dtype:DType, start:sint, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(start), sint_to_uop(end)), arg=idx)
  def r(self, op:Ops, axis:tuple[int, ...]):
    axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
    return self if len(axis) == 0 else UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
  def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
  def contiguous(self): return self.alu(Ops.CONTIGUOUS)
  def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
  # *** from MultiLazyBuffer ***

  def multi(self, *more:UOp, axis:int|None, real:tuple[bool, ...]|None=None):
    parents = (self,) + more
    assert all_same([x.dtype for x in parents]), "multi parents must have the same dtype"
    return UOp(Ops.MULTI, self.dtype, parents, (axis, real if real is not None else (True,)*len(parents)))
  @property
  def bounds(self):
    if self.axis is None: raise RuntimeError("bounds is not defined when axis is None")
    return tuple(itertools.pairwise(itertools.accumulate([lb.shape[self.axis] for lb in self.src], initial=0)))
  @functools.cached_property
  def axis(self) -> Optional[int]:
    if self.op is Ops.MULTI: return self.arg[0]
    # NOTE: they all have to share an axis, we always choose [-1]
    if self.op in GroupOp.ALU: return axes[-1] if (axes:=dedup([x.axis for x in self.src if x.axis is not None])) else None
    src_axis = self.src[0].axis
    if self.op is Ops.REDUCE_AXIS: return None if src_axis is not None and src_axis in self.arg[1] else src_axis
    if self.op is Ops.RESHAPE:
      if src_axis is None: return None
      arg_acc:list[sint] = list(itertools.accumulate(self.arg, operator.mul, initial=1))
      # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
      # TODO: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
      return len(arg_acc) - arg_acc[::-1].index(prod(self.src[0].shape[:src_axis])) - 1
    if self.op is Ops.PERMUTE: return self.arg.index(src_axis) if src_axis is not None else None
    return src_axis
  @property
  def real(self):
    assert self.op is Ops.MULTI
    return self.arg[1]
  @property
  def real_lbs(self): return [lb for lb, r in zip(self.src, self.real) if r]
  def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> UOp:
    if axis is None: lbs = [self]*len(devices)
    else:
      if self.shape[axis] % len(devices) != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {len(devices)=}")
      # NOTE: this works for both even shards and uneven shards
      sz = self.shape[axis] // len(devices)
      sizes = [max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))]
      lbs = []
      for sz, off in zip(sizes, itertools.accumulate(sizes, initial=0)):
        lbs.append(self.shrink(tuple((0, s) if i != axis else (off, off+sz) for i, s in enumerate(self.shape))))
    sharded_lbs = [lb.copy_to_device(d) for lb, d in zip(lbs, devices)]
    return UOp.multi(*[lb.contiguous() for lb in sharded_lbs], axis=axis)
  # *** from LazyBuffer ***

  @staticmethod
  def metaop(op:Ops, shape:tuple[sint, ...], dtype:DType, device:str, arg=None) -> UOp:
    from tinygrad.shape.shapetracker import ShapeTracker
    # Tensor const is CONST(VIEW(DEVICE)) -> RESHAPE -> EXPAND
    if op is Ops.CONST:
      assert isinstance(arg, get_args(ConstType)), f"trying to create CONST with {arg=}"
      return UOp.const(dtype, unwrap(arg)).replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),),
                                                           ShapeTracker.from_shape(())),)).reshape((1,)*len(shape)).expand(shape)
    # Tensor variable binding is BIND(VAR(VIEW(DEVICE)), CONST(VIEW(DEVICE)))
    if op is Ops.BIND:
      var, val = arg.unbind()
      return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val)
    # otherwise it's just a RESHAPE(BUFFER)
    if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
    return UOp.new_buffer(device, size, dtype).reshape(shape)
  def copy_to_device(self, device:str|tuple[str, ...], clone:bool=False): return UOp(Ops.COPY, self.dtype, (UOp(Ops.DEVICE, arg=device), self), clone)
  def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True)
  @property
  def metadata(self) -> tuple[Metadata, ...]|Metadata|None: return self.arg.metadata if self.op is Ops.KERNEL else all_metadata.get(self, None)
  # *** uop movement ops ***

  @property
  def base(self) -> UOp:
    if (self.op is Ops.VIEW and len(self.src) != 0) or self.op in GroupOp.Movement: return self.src[0].base
    return self
  def view(self, new_st:ShapeTracker) -> UOp: return UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
  def _mop(self, op:Ops, arg):
    ret = UOp(op, self.dtype, (self,), arg)
    if self.st == ret.st: return self  # ignore NOOPs, also check ret.st
    return ret
  def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg)
  def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg)
  def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg)
  def permute(self, arg:tuple[sint, ...]): return self._mop(Ops.PERMUTE, arg)
  def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg)
  def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg)
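  # Illustrative sketch: no-op movements dedupe through the ShapeTracker equality check in _mop,
  # so they hand back the same object (assuming x is a UOp that has a ShapeTracker).
  #   x.reshape(x.shape) is x    # True
  #   x.permute((0, 1)) is x     # identity permute of a 2D x is also a no-op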
  # *** uop UNIQUE ***

  # TODO: use this in Buffer
  unique_num = itertools.count(0)
  @staticmethod
  def unique(): return UOp(Ops.UNIQUE, arg=next(UOp.unique_num))

  # *** uop Buffer stuff ***

  @staticmethod
  def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device), UOp.unique()), size)
  @property
  def device(self) -> str|tuple[str, ...]: return cast(str|tuple[str, ...], unwrap(self._device))
  @functools.cached_property
  def _device(self) -> Optional[str|tuple[str, ...]]:
    if self.op is Ops.DEVICE: return self.arg
    if self.op is Ops.MULTI: return tuple(cast(str, x.device) for x in self.src)
    return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
  @property
  def buf_uop(self) -> UOp:
    if self.op is Ops.BUFFER: return self
    assert self.op is Ops.ASSIGN, f"must be ASSIGN {self.op}"
    return self.src[0].base
  @property
  def buffer(self) -> Buffer:
    if self is not self.base:
      assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
      return self.src[0].buffer
    assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
    if (cret:=buffers.get(self)) is not None: return cret
    from tinygrad.device import Buffer
    assert isinstance(self.device, str), f"buffer not supported on multi {self.device}"
    buffers[self] = ret = Buffer(self.device, self.size, self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base)
    ret.ref(1)
    return ret
  @property
  def realized(self) -> Optional[Buffer]: return self.buffer if self.op is Ops.BUFFER and self.buffer.is_allocated() else None
  @property
  def is_realized(self) -> bool:
    return all(x.base.realized is not None for x in self.base.real_lbs) if self.base.op is Ops.MULTI else self.base.realized is not None
  # *** uop Variable stuff ***

  @staticmethod
  def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.int):
    assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}"
    return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
  @property
  def expr(self):
    assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
    return self.arg[0]
  def bind(self, val:int):
    assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
    assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]"
    return UOp(Ops.BIND, self.dtype, (self, self.const_like(val)))
  def unbind(self) -> tuple[Variable, int]:
    assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
    return self.src[0], self.src[1].arg
  @property
  def val(self) -> int: return self.unbind()[1]
  def vars(self) -> set[UOp]:
    bound_vars = set([x for x in self.toposort if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR])
    bound_var_base = set(x.src[0] for x in bound_vars)
    all_vars = set([x for x in self.toposort if x.op is Ops.DEFINE_VAR])
    return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
  def variables(self) -> list[Variable]:
    st_vars:list[set[Variable]] = [x.st_arg.vars() for x in self.toposort if x.op in GroupOp.Buffer]
    return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
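  # Illustrative sketch of the Variable lifecycle:
  #   v = UOp.variable("i", 0, 9)   # DEFINE_VAR carrying (name, min, max) in arg
  #   b = v.bind(3)                 # BIND(DEFINE_VAR, CONST)
  #   b.unbind()                    # (v, 3)
  #   b.val                         # 3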
  # *** uop symbolic stuff ***

  def is_increasing(self:UOp) -> bool:
    # is f a monotonically increasing function with respect to its input
    if self.op in GroupOp.Irreducible: return True
    if self.op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing()
    if self.op in (Ops.MUL, Ops.IDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing()
    return False  # False if not sure
  def const_factor(self) -> int:
    """largest known int that divides self"""
    if self.op is Ops.CONST: return self.arg
    if self.op is Ops.VCONST: return math.gcd(*self.arg)
    if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
    if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
    return 1
  def divides(self, v:int) -> UOp|None:
    if v == 1: return self
    if self.op is Ops.CONST: return self.const_like(self.arg // v) if self.arg % v == 0 else None
    if self.op is Ops.VCONST: return self.const_like(tuple(x // v for x in self.arg)) if all(x % v == 0 for x in self.arg) else None
    if self.op is Ops.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
    if self.op is Ops.MUL:
      if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
      if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
    return None  # generic None if we aren't sure
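  # Illustrative sketch: const_factor and divides do conservative integer factoring.
  #   x = UOp.variable("x", 0, 10)
  #   (x*4 + 8).const_factor()   # 4, the gcd of the two addends
  #   (x*4).divides(2)           # x*2
  #   (x + 1).divides(2)         # None, can't be proven divisible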
  @property
  def vmin(self) -> ConstType: return self._min_max[0]
  @property
  def vmax(self) -> ConstType: return self._min_max[1]
  @functools.cached_property
  def _min_max(self) -> tuple[ConstType, ConstType]:
    if self.op in GroupOp.Binary and not dtypes.is_float(self.dtype):
      (s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max
      if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
      if self.op is Ops.SUB: return s0_vmin-s1_vmax, s0_vmax-s1_vmin
      if self.op is Ops.AND and s1_vmin == s1_vmax and s0_vmin >= 0 and s1_vmin >= 0: return min(0, s0_vmin), min(s0_vmax, s1_vmax)
      if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
      # SHL/SHR on consts only
      if self.op is Ops.SHL and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] << t[2], t[1] << t[2]
      if self.op is Ops.SHR and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] >> t[2], t[1] >> t[2]
      if self.op is Ops.MOD and s1_vmin > 0:
        return (0, s1_vmax-1) if s0_vmin >= 0 else (-(s1_vmax-1), s1_vmax-1)
      if self.op is Ops.IDIV:
        if s1_vmin == s1_vmax:  # min/max are equal in a CONST
          if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin
          if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin)
        # don't know exact bounds, but know the sign
        if (s0_vmax <= 0 and s1_vmin < 0) or (s0_vmin >= 0 and s1_vmin > 0): return 0, dtypes.max(self.dtype)
        if (s0_vmax <= 0 and s1_vmin > 0) or (s0_vmin >= 0 and s1_vmin < 0): return dtypes.min(self.dtype), 0
      if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax)
      if self.op is Ops.CMPLT: return (s0_vmax < s1_vmin, s0_vmin < s1_vmax)
      if self.op is Ops.CMPNE: return ((s0_vmax < s1_vmin) or (s1_vmax < s0_vmin), not (s0_vmin == s0_vmax == s1_vmin == s1_vmax))
      if self.dtype == dtypes.bool:
        if self.op is Ops.OR: return s0_vmin or s1_vmin, s0_vmax or s1_vmax
        if self.op is Ops.AND: return s0_vmin and s1_vmin, s0_vmax and s1_vmax
    # float has NAN issue and we use explicit NAN in transcendental
    if self.op is Ops.WHERE and dtypes.is_int(self.dtype): return min(self.src[1].vmin, self.src[2].vmin), max(self.src[1].vmax, self.src[2].vmax)
    # NOTE: returned UOp is assumed to be CONST
    if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
    if self.op is Ops.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
    if self.op is Ops.BIND: return self.src[0]._min_max  # ignore the bound value
    if self.op in {Ops.UNROLL, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
    # TODO: Ops.SPECIAL is Ops.DEFINE_VAR
    if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else self.arg[1].vmax
    if self.op is Ops.CONST: return self.arg, self.arg
    if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
    if self.op is Ops.CAST: return max(dtypes.min(self.dtype), self.src[0].vmin), min(self.src[0].vmax, dtypes.max(self.dtype))
    return dtypes.min(self.dtype), dtypes.max(self.dtype)
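  # Illustrative sketch: integer range analysis composes through the graph.
  #   i = UOp.variable("i", 2, 5)
  #   (i*3).vmin, (i*3).vmax   # (6, 15)
  #   ((i*3) % 4).vmin         # 0, since the divisor is positive and i*3 >= 0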
  @functools.cached_property
  def _sym_fxn(self):
    sself = self.simplify()
    varnames = tuple(x.arg[0] for x in sself.toposort if x.op is Ops.DEFINE_VAR)
    # TODO: sanitize varnames, or don't use naked eval while staying fast
    return eval("lambda " + ','.join(varnames) + ": " + sself.render()), varnames  # pylint: disable=eval-used
  def sym_infer(self, var_vals:dict[UOp, int]):
    fxn, varnames = self._sym_fxn
    return fxn(**{k.arg[0]: v for k, v in var_vals.items() if k.arg[0] in varnames})
  def render(self, simplify=True) -> str:
    ret = graph_rewrite(self.simplify() if simplify else self, renderer)
    return ret.arg if ret.op is Ops.NOOP else str(ret)
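  # Illustrative sketch: render turns a symbolic UOp into a debug string via the renderer PatternMatcher below.
  #   (UOp.variable("a", 0, 10) + 2).render()   # "(a+2)"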
@dataclass(frozen=True)
class KernelInfo:
  name: str = "test"             # name of the kernel
  local_dims: int = 0            # number of local dimensions (this is remapping RANGE to SPECIAL)
  upcasted: int = 0              # count that are upcasted (this is remapping RANGE to UNROLL)
  dont_use_locals: bool = False  # don't use local indexing
# ******** ops in python ********

def safe_exp2(x):
  try: return 2**x
  except OverflowError: return math.inf

def safe_pow(x, y):
  try: return math.nan if isinstance(p:=pow(x, y), complex) else p
  except ZeroDivisionError: return math.inf
  except ValueError: return math.inf if x > 0 else -math.inf

python_alu: dict[Ops, Callable] = {
  Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2,
  Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
  Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow,
  Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
  Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
  Ops.MOD: cmod, Ops.IDIV: cdiv, Ops.MULACC: lambda x, y, z: (x*y)+z, Ops.WHERE: lambda x, y, z: y if x else z}

def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
  if dtype.count > 1:
    return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)])
  alu = python_alu[op](*operands)
  return truncate.get(dtype, lambda x: x)(alu) if truncate_output else alu
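# Illustrative sketch: exec_alu evaluates one op on python numbers, truncating to the dtype.
#   exec_alu(Ops.ADD, dtypes.int, (3, 4))       # 7
#   exec_alu(Ops.ADD, dtypes.int8, (120, 10))   # -126, wrapping like int8 arithmetic would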
# ***** uop helpers *****

def print_uops(uops:list[UOp]):
  for i, u in enumerate(uops):
    formatted_parents = [(uops.index(x) if x.op is not Ops.CONST else f"{x.arg}") if x in uops else "--" for x in u.src]
    print(f"{i:4d} {str(u.op):20s}: {str(u.dtype):30s} {str(formatted_parents):32s} {u.arg}")

# ***** pattern matcher *****

def get_location() -> tuple[str, int]:
  frm = sys._getframe(1)
  # find the real frame in the file that has the UPat, TODO: is there a better way to do this?
  while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "rewriter.py", "schedule.py", "multi.py",
                                                                                        "symbolic.py", "expander.py", "lowerer.py", "cstyle.py",
                                                                                        "linearize.py", "devectorizer.py"}:
    frm = frm.f_back
  return frm.f_code.co_filename, frm.f_lineno

@functools.lru_cache(None)
def lines(fn) -> list[str]:
  with open(fn) as f: return f.readlines()
class UPat(MathTrait):
  __slots__ = ("op", "dtype", "arg", "name", "src")
  def __init__(self, op:Optional[Union[Ops, tuple[Ops, ...], set[Ops]]]=None, dtype:Optional[Union[DType, tuple[DType, ...]]]=None,
               src:Optional[Union[tuple[UPat, ...], list[UPat], UPat]]=None, arg:Any=None,
               name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[set[Ops]]=None):
    assert op is None or isinstance(op, (Ops, tuple, set)), "op must be Ops or tuple of Ops"
    self.op: Optional[tuple[Ops, ...]] = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
    self.dtype: Optional[tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
    self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
    self.src: Any = None
    assert self.name != "ctx", "UPat can't be named ctx"
    # try all permutations if it's a list
    if isinstance(src, list): self.src = list(itertools.permutations(src)) if not all_same(src) else [src]
    # only one if it's a tuple
    elif isinstance(src, tuple): self.src = [src]
    # repeat if it's a UPat
    elif isinstance(src, UPat): self.src = [itertools.repeat(src)]
    self.allowed_len: int = -1 if allow_any_len or isinstance(src, UPat) or src is None else len(src)
    self.location = location or get_location()
    if custom_early_reject is not None: self.early_reject = custom_early_reject
    else:
      upat_match = [src] if isinstance(src, UPat) else ([] if src is None else self.src[0])
      self.early_reject = {pp.op[0] for pp in upat_match if pp.op is not None and len(pp.op) == 1}
  def named(self, name:str):
    return UPat(self.op, self.dtype, self._in_src, self.arg, name, self.allowed_len == -1, custom_early_reject=self.custom_early_reject)
  @staticmethod
  def any(*src): return UPatAny(src=src)
  def or_casted(self, name:str|None=None): return UPat.any(self if name is None else self.named(name), UPat(Ops.CAST, name=name, src=(self,)))
  @staticmethod
  @functools.lru_cache(None)
  def var(name:Optional[str]=None, dtype:Optional[Union[DType, tuple[DType, ...]]]=None): return UPat(dtype=dtype, name=name)
  @staticmethod
  @functools.lru_cache(None)
  def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True): return UPat((Ops.CONST, Ops.VCONST) if vec else Ops.CONST, dtype, name=name)
  @staticmethod
  def const(dtype:Optional[Union[DType, tuple[DType, ...]]], b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
  # copied from UOp
  def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(Ops.INDEX, self.dtype, (self, idx, valid) if valid is not None else (self, idx))
  def view(self, st=None, **kwargs): return UPat(Ops.VIEW, self.dtype, (self,), st, **kwargs)
  def cast(self, dtype=None): return UPat(Ops.CAST, dtype, (self,))
  def bitcast(self, dtype=None): return UPat(Ops.BITCAST, dtype, (self,))
  def gep(self, i:int): return UPat(Ops.GEP, None, (self,), (i,))
  def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)
  def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
  def assign(self, x:UPat, **kwargs): return UPat(Ops.ASSIGN, self.dtype, (self, x), **kwargs)
  def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
  def alu(self, op:Ops, *src:UPat):
    asrc = (self,)+src
    return UPat(op, dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
  def printable(self:UPat) -> str:
    try: return lines(self.location[0])[self.location[1]-1].strip()
    except FileNotFoundError: return "<missing>"
  def __repr__(self):
    def rep(x):
      form = "UPat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=%s)"
      return form % (None if x.op is None else ('(%s)' % ', '.join(map(str, x.op))), x.arg, repr(x.name),
                     set(x.dtype) if x.dtype else None, x.allowed_len == 0, "[%s]" if x.src and len(x.src) > 1 else "(%s)")
    return pretty_print(self, rep, srcfn=lambda x: None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0])
  def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
    if (self.op is not None and uop.op not in self.op) or \
       (self.name is not None and store.setdefault(self.name, uop) is not uop) or \
       (self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \
       (self.arg is not None and self.arg != uop.arg) or \
       (self.allowed_len != -1 and len(uop.src) != self.allowed_len): return []
    if self.src is None: return [store]
    res: list[dict[str, UOp]] = []
    for vp in self.src:
      stores, new_stores = [store.copy()], []
      for uu, vv in zip(uop.src, vp):
        for s in stores: new_stores.extend(vv.match(uu, s))
        stores, new_stores = new_stores, []
      res.extend(stores)
    return res

class UPatAny(UPat):
  def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
    matches = [x.match(uop, store.copy()) for x in self.src[0]]
    return flatten([x for x in matches if x is not None])
def deconstruct_function(fxn:Callable) -> tuple:
  new_globals = {k: v for k, v in fxn.__globals__.items() if k in fxn.__code__.co_names}
  for co in fxn.__code__.co_consts:
    if isinstance(co, types.CodeType): new_globals.update({k: v for k, v in fxn.__globals__.items() if k in co.co_names})
  # NOTE: optional round trip through pickle!
  assert fxn.__closure__ is None, "closures are not supported in pattern matchers"
  ret = fxn.__code__, new_globals, fxn.__name__, fxn.__defaults__
  return pickle.loads(pickle.dumps(ret)) if getenv("TEST_PICKLE") else ret

class PatternMatcher:
  def __init__(self, patterns:list[tuple[UPat, Callable]]):
    self.patterns = patterns
    # NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
    self.pdict: dict[Ops, list[tuple[UPat, Callable, set, bool]]] = {}
    # uop is required, arg is optional
    for p, fxn in self.patterns:
      assert p.op is not None
      tuple_fxn = fxn if isinstance(fxn, tuple) else deconstruct_function(fxn)
      real_fxn = types.FunctionType(*tuple_fxn)
      for uop in p.op: self.pdict.setdefault(uop, []).append((p, real_fxn, p.early_reject, 'ctx' in inspect.signature(real_fxn).parameters))
  def __reduce__(self): return PatternMatcher, ([(x, deconstruct_function(fxn) if fxn.__name__ == "<lambda>" else fxn) for x, fxn in self.patterns],)
  @functools.lru_cache(None)  # pylint: disable=method-cache-max-size-none
  def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns + more.patterns)
  def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
    ler = {u.op for u in uop.src}
    for p, fxn, early_reject, has_ctx in self.pdict.get(uop.op, []):
      if not early_reject.issubset(ler): continue
      for match in p.match(uop, {}):
        if (ret:=(fxn(ctx=ctx, **match) if has_ctx else fxn(**match))) is not None: return ret
    return None
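# Illustrative sketch: a one-rule matcher that folds x*1 down to x.
#   pm = PatternMatcher([(UPat.var("x")*1, lambda x: x)])
#   pm.rewrite(UOp.variable("a", 0, 5)*1)   # the bare DEFINE_VAR for "a"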
# *** tracking pattern matcher ***

TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0)
match_stats: dict[UPat, list[Union[int, float]]] = dict()

@dataclass(frozen=True)
class TrackedGraphRewrite:
  loc: tuple[str, int]                  # location that called graph_rewrite
  sink: UOp                             # the sink input to graph_rewrite
  bottom_up: bool
  matches: list[tuple[UOp, UOp, UPat]]  # before+after of all the matches
  name: str|None
tracked_keys: list[Any] = []
tracked_ctxs: list[list[TrackedGraphRewrite]] = []
_name_cnt: dict[str, int] = {}
def track_rewrites(named=False, name_fxn:Callable|None=None):
  def _decorator(func):
    def __wrapper(self, *args, **kwargs):
      if TRACK_MATCH_STATS >= 2:
        if (count_names:=(named or name_fxn)): _name_cnt[func.__name__] = _name_cnt.get(func.__name__, 0) + 1
        tracked_keys.append(f"{func.__name__}_{_name_cnt[func.__name__]}" if count_names else self)
        tracked_ctxs.append([])
      ret = func(self, *args, **kwargs)
      if TRACK_MATCH_STATS >= 2 and name_fxn is not None: tracked_keys[-1] = f"{name_fxn(ret)} n{_name_cnt[func.__name__]}"
      return ret
    return __wrapper
  return _decorator
active_rewrites: list[TrackedGraphRewrite] = []
def track_matches(func):
  def _track_func(*args, **kwargs):
    if tracking:=(TRACK_MATCH_STATS >= 2 and tracked_ctxs):
      loc = ((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno)
      tracked_ctxs[-1].append(ctx:=TrackedGraphRewrite(loc, args[0], kwargs.get("bottom_up", False), [], kwargs.get("name", None)))
      active_rewrites.append(ctx)
    ret = func(*args, **kwargs)
    if tracking: active_rewrites.pop()
    return ret
  return _track_func
class TrackedPatternMatcher(PatternMatcher):
  def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
    ret = None
    ler = {u.op for u in uop.src}
    for p, fxn, early_reject, has_ctx in self.pdict.get(uop.op, []):
      if p not in match_stats: match_stats[p] = [0, 0, 0.0, 0.0]
      st = time.perf_counter()
      if not early_reject.issubset(ler):
        match_stats[p][2] += time.perf_counter()-st
        continue
      match_stats[p][1] += 1
      for match in p.match(uop, {}):
        if (ret:=(fxn(ctx=ctx, **match) if has_ctx else fxn(**match))) is not None:
          match_stats[p][0] += 1
          match_stats[p][3] += (et:=time.perf_counter()-st)
          if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
          if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and active_rewrites: active_rewrites[-1].matches.append((uop, ret, p))
          return ret  # NOTE: if it returns None, we keep trying to match
      match_stats[p][2] += time.perf_counter()-st
    return None
if TRACK_MATCH_STATS:
  PatternMatcher = TrackedPatternMatcher  # type: ignore
  import atexit
  @atexit.register
  def print_match_stats():
    if TRACK_MATCH_STATS >= 2:
      with open(fn:=temp("rewrites.pkl", append_user=True), "wb") as f:
        print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}")
        with Context(PICKLE_BUFFERS=0): pickle.dump((tracked_keys, tracked_ctxs), f)
      if getenv("VIZ"): launch_viz("VIZ", temp("rewrites.pkl", append_user=True))
    if getenv("PRINT_MATCH_STATS", 1):
      ret = [0, 0, 0.0, 0.0]
      for k, v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]):
        loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}"
        if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {(v[2]+v[3])*1000.:9.2f} ms -- {loc_str:15s}", k.printable())
        ret = [x+y for x, y in zip(ret, v)]
      print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {(ret[2]+ret[3])*1000.:9.2f} ms -- TOTAL")

def launch_viz(env_str:str, data:str):
  os.environ[env_str] = "0"
  os.environ[f"{env_str}_DATA"] = data
  if not int(os.getenv("VIZ", "0")) and not int(os.getenv("PROFILE", "0")):
    args = ['--kernels', getenv("VIZ_DATA", "")] if getenv("VIZ_DATA", "") else []
    args += ['--profile', getenv("PROFILE_DATA", "")] if getenv("PROFILE_DATA", "") else []
    os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), ".", "viz", "serve.py")] + args)
# *** simple graph rewrite engine ***

class RewriteContext:
  def __init__(self, pm, ctx=None, children=None):
    self.pm: PatternMatcher = pm
    self.ctx = self if children is not None else ctx
    self.replace: dict[UOp, UOp] = {}
    self.children = children
  # TODO: is this function always right?
  def update_children(self):
    # add any new children from UOps that were replaced
    for u in self.replace.values():
      for s in u.src: self.children.setdefault(s, {})[u] = None
    # find any children that were replaced and replace them
    for k, v in self.children.items():
      new_child: dict[UOp, None] = {}
      for tv in v:
        while (nv:=self.replace.get(tv, None)) is not None and nv is not tv: tv = nv
        new_child[tv] = None
      self.children[k] = new_child
  def top_down_rewrite(self, n:UOp) -> UOp:
    if (rn:=self.replace.get(n)) is not None: return rn
    new_src = tuple([self.top_down_rewrite(x) for x in n.src])
    new_n = self.pm.rewrite(n, self.ctx) if new_src == n.src else UOp(n.op, n.dtype, new_src, n.arg)
    self.replace[n] = ret = n if new_n is None else self.top_down_rewrite(new_n)
    return ret
  def bottom_up_rewrite(self, n:UOp) -> UOp:
    if (rn:=self.replace.get(n)) is not None: return rn
    new_n: UOp|None = n
    while new_n is not None: last_n, new_n = new_n, self.pm.rewrite(new_n, self.ctx)
    new_src = tuple([self.bottom_up_rewrite(x) for x in last_n.src])
    self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg))
    return ret

@track_matches
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, track_children=False) -> UOp:
  rewrite_ctx = RewriteContext(pm, ctx, children=sink.get_children_map() if track_children else None)
  return rewrite_ctx.bottom_up_rewrite(sink) if bottom_up else rewrite_ctx.top_down_rewrite(sink)

@track_matches
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, track_children=False) -> dict[UOp, UOp]:
  rewrite_ctx = RewriteContext(pm, ctx, children=sink.get_children_map() if track_children else None)
  return {k: (rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]}

def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x

_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx, x: ctx.get(x, None))])
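# Illustrative sketch: graph_rewrite with _substitute swaps nodes by identity.
#   a, b = UOp.variable("a", 0, 5), UOp.variable("b", 0, 5)
#   (a + 1).substitute({a: b})   # b + 1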
# for debug
syms = {Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<", Ops.SHR: ">>",
        Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"}
renderer = PatternMatcher([
  (UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
  (UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg}")),
  (UPat((Ops.CONST, Ops.VCONST), name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
  (UPat(Ops.UNROLL, name="x"), lambda x: UOp(Ops.NOOP, arg=f"UNROLL({x.src[0].arg}, {x.arg})")),
  (UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
  (UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
  (UPat(Ops.MAX, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
  (UPat(Ops.MULACC, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
  (UPat(Ops.WHERE, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
  (UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.op]}{x.src[1].arg})")),
])

# *** what was symbolic.py ***
sint = Union[int, UOp]
Variable = UOp
ConstLike = Union[ConstType, Variable, tuple[ConstType, ...]]
# *** UOp merge views and swizzling ***

merge_views = PatternMatcher([
  # merge adjacent views
  (UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="v2"),), name="v1"), lambda v1, v2: v2.replace(arg=v2.arg+v1.arg)),
  # merge unmasked const views
  (UPat(Ops.VIEW, name="v", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const"),)),
   lambda v, const: const.replace(src=(const.src[0].replace(arg=const.st+v.st),)) if all(x.mask is None for x in (const.st+v.st).views) else None),
  # merge view on load/store/valid
  (UPat(Ops.VIEW, name="v", src=(UPat((Ops.LOAD, Ops.STORE, Ops.VALID), name="b"),)),
   lambda b, v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
  # remove view if it's contiguous and the shapes match
  (UPat(Ops.VIEW, name="v", src=(UPat(GroupOp.All-{Ops.DEVICE}, name="x"),)), lambda v, x: x if v.arg.contiguous and x.shape == v.shape else None),
  # remove mask if there's a zero in the masked dim
  (UPat(Ops.VIEW, name="v", src=(UPat(),)),
   lambda v: v.const_like(0) if (mask:=v.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
  # movement ops apply a new view on the base
  (UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop, x: x.view(mop.st)),
])

view_left = merge_views + PatternMatcher([
  # do not push masked view before unsafe pad ops
  (UPat(Ops.VIEW, src=(UPat(GroupOp.UnsafePad, name="e"),), name="view"),
   lambda e, view: e.contiguous().view(view.st) if any(v.mask is not None for v in view.st.views) else None),
  # view before elementwise ops
  (UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST}, name="e"),), name="view"),
   lambda e, view: e.replace(src=tuple(s.view(s.st+view.st) if s.op is Ops.VIEW else s.view(view.st) for s in e.src))),
])