from typing import Any, Callable, cast
import functools, operator, itertools
from collections import defaultdict
from dataclasses import dataclass
from tinygrad.device import is_dtype_supported
from tinygrad.dtype import dtypes, ImageDType, PtrDType, promo_lattice, DType
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element
from tinygrad.uop.symbolic import split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
from tinygrad.helpers import getenv, flatten, AMX, prod, partition
from tinygrad.uop.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
from tinygrad.renderer import Renderer
# ***** image load valid simplification *****
def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
  if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.const_like(0)
  if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx, valid)

  # wait for it to be image indexed before running simplification
  if start_idx.dtype.count != 2: return None

  # can drop valid if idx is out of bound when valid is False
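  # NOTE: dropping is safe when out-of-bounds image reads return 0, which is the same value the masked load would have produced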
  drop_stmt = []
  for stmt in split_uop(valid, Ops.AND):
    try: X, is_upper_bound, c = parse_valid(stmt)
    except ValueError: return None

    # for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
    if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in split_uop(X, Ops.ADD)):
      testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), split_uop(X, Ops.ADD), idx)
      testidx = testidx.simplify()
      if testidx.gep(0).vmax < 0 or testidx.gep(1).vmax < 0:
        drop_stmt.append(stmt)
        continue

    # if X <= c, check if it's out of bound when X = c+1
    # if X >= c, check if it's out of bound when X = c-1
    test_value = c + 1 if is_upper_bound else c - 1
    for i,b in zip(idx.src, (buf.dtype.shape[1], buf.dtype.shape[0])):
      if i.is_increasing():
        rw = i.substitute({X:X.const_like(test_value)}).simplify()
        if rw.vmin >= b or rw.vmax < 0:
          drop_stmt.append(stmt)
          break

  if not drop_stmt and idx is start_idx: return None
  new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, Ops.AND) if s not in drop_stmt]) else None
  return buf.index(idx, new_valid)
def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
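  # if the stored value is already computed under an IF on the same gate, gating the INDEX as well is redundant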
  if store_gate not in [gate.src[0] for gate in val.toposort() if gate.op is Ops.IF]: return None
  # remove the gate from the index
  return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val)
load_store_indexing = PatternMatcher([
  # simplify valid
  (UPat(Ops.AND, name="valid"), simplify_valid),
  # image load valid idx simplification
  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
  # index True is just Index
  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat(Ops.CONST, arg=True))), lambda buf,start_idx: buf.index(start_idx)),
  # delete_redundant_gates (after expand)
  (UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
                        UPat.var("val"))), delete_redundant_gates),
])
# ***** load/store grouping *****
def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
  if getenv("UNSAFE_DISABLE_MASK", 0): mask = None

  # generate the individual indexes
  midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), mask.gep(i) if mask is not None else None) for i in range(vec.dtype.count)]),
                       symbolic_flat+load_store_indexing, name=f"index_buf_{buf.arg}")

  # extract all the relevant offsets
  offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
  for i in range(vec.dtype.count):
    idx: Any = midx.src[i].src[1]
    if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
    elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
    elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
    else: root_src, arg = idx, 0
    if len(midx.src[i].src) == 3: root_src = (midx.src[i].src[2], root_src)
    offsets_rootsrc[root_src].setdefault(arg, []).append(i)

  # the buf.dtype is always a pointer
  ptrdtype = cast(PtrDType, buf.dtype)

  # then rewrite everything we can into groups
  ret = []
  idxs: list[int|None] = [None]*vec.dtype.count
  global_offset = 0
  for offsets in offsets_rootsrc.values():
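    # group consecutive offsets into runs: within a run of consecutive ints, (value - enumeration index) is constant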
    grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
    for grp in grouped_offsets:
      # get the index offset for this element. using [0] is okay, because they are the same
      lidx = midx.src[offsets[grp[0]][0]]
      if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, local=ptrdtype.local))
      # set the idxs of the output
      for i,g in enumerate(grp):
        for oo in offsets[g]: idxs[oo] = global_offset+i
      # add this lidx to the CAT
      ret.append(lidx)
      global_offset += len(grp)
  assert None not in idxs, f"some idxs are missing {idxs}"

  # this base thing is for image, we want the CAT to be a normal pointer
  post_cat = UOp(Ops.PTRCAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret))
  return post_cat.gep(tuple(cast(list[int], idxs)))
def cat_after_store(cat:UOp, data:UOp):
  # TODO: this is written in many places
  offset = 0
  ret = []
  for s in cat.src:
    ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count)))))
    offset += s.dtype.count
  return UOp.sink(ret[0], *ret[1:])
def gep_on_store(gep:UOp, st:UOp):
  # NOTE: we need to invert the gep here, but it may be an expanding gep
  # fake argsort. TODO: handle duplicates
  a = {}
  for i,x in enumerate(gep.arg): a[x] = i
  new_arg = tuple(x[1] for x in sorted(a.items()))
  return UOp(Ops.STORE, src=(gep.src[0], st.gep(new_arg)))
load_store_folding = PatternMatcher([
  (UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL), name="buf")), UPat.var("vec"))), expand_index),
  (UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL), name="buf")), UPat.var("vec"),
                        UPat.var("mask"))), expand_index),
  # GEP after LOAD
  (UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True),
   lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
  # GEP on data of STORE
  (UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st"))), gep_on_store),
  # put PTRCAT after LOAD
  (UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True),
   lambda cat, ld: UOp(Ops.CAT, ld.dtype, tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
  # put PTRCAT after STORE
  (UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data"))), cat_after_store),
])
# ***** optional patterns *****
@functools.lru_cache(None)
def magicgu(vmax:int, d:int) -> tuple[int, int]:
  # calculate m,s such that x//d == (x*m) >> s for all 0 <= x <= vmax, d > 0; adapted from Hacker's Delight, Chapter 10
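  # e.g. magicgu(255, 3) == (171, 9): x//3 == (x*171) >> 9 for all 0 <= x <= 255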
  nc = (vmax+1)//d*d - 1
  nbits = vmax.bit_length()
  for s in range(0, 2*nbits+1):
    if 2**s > nc*(d - 1 - (2**s - 1) % d):
      m = (2**s + d - 1 - (2**s - 1) % d)//d
      return m, s
  assert False
def fast_idiv(ctx: Renderer|None, x: UOp, d: int) -> UOp|None:
  # idiv is truncated division, but arithmetic shift is floored division, so this only works for non-negative numbers!
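  # e.g. for an int32 x with x.vmin >= 0 and x.vmax == 255, x//3 becomes (x*171) >> 9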
  if x.vmin < 0: return None
  sign = 1 if d > 0 else -1
  m, s = magicgu(vmax := min(x.vmax, dtypes.max(x.dtype)), abs(d))
  if m * vmax <= dtypes.max(x.dtype): return sign * ((x*m) >> s)
  # promo_lattice needs to return an unsigned type
  if ctx is not None and dtypes.is_int(next_dtype := promo_lattice[x.dtype][-1]) and is_dtype_supported(next_dtype, ctx.device):
    if m * vmax <= dtypes.max(next_dtype): return sign * ((x.cast(next_dtype)*m) >> s).cast(x.dtype)
  return None
powers_of_two = {2**i: i for i in range(64)}
@functools.cache
def get_late_rewrite_patterns(ops, force_transcendental=False):
  pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
    ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
  # rewrite SQRT to xpow 0.5
  if Ops.SQRT not in ops: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
  # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
  if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints) % UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
  # rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
  if Ops.SHL in ops: pat += [(UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << v if (v:=powers_of_two.get(c.arg, 0)) else None)]
  if Ops.SHR in ops:
    # no reason to check x<0 for uints
    pat += [(UPat.var("x", dtypes.uints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) else None)]
    pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: (x+(l.const_like(l.vmin) if (l:=(x<0)).vmin == l.vmax else l).where(
      c-1, 0)) >> v if (v:=powers_of_two.get(c.arg, 0)) else None)]  # (x+(x<0).where(c-1, 0)) >> v
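    # NOTE: the bias is needed because idiv truncates toward zero while SHR floors: for negative x, (x + c-1) >> v rounds toward zero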
  if not getenv("DISABLE_FAST_IDIV"):
    pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("d"), lambda ctx, x, d: fast_idiv(ctx, x, d.arg))]
    pat += [(UPat.var("x", dtypes.ints) % UPat.cvar("d"), lambda ctx, x, d: x - d*f if (f:=fast_idiv(ctx, x, d.arg)) is not None else None)]
  if Ops.NEG in ops:
    pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
    if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]
  if Ops.MULACC in ops: pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))]
  return PatternMatcher(pat)
# *** correct load/store ***
def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
  # this splits loads and stores into multiple chunks

  # if there's only one element to load/store, no splitting needed
  if (sz:=ls.src[0].dtype.count) == 1: return None
  buf = idx.src[0]

  # determine fold lengths
  lengths = []
  must_divide = True
  if ctx is not None and ctx.device == "DSP":
    lengths = [128, 64, 32, 16, 8, 4]
    must_divide = False
  elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
    pass
  elif isinstance(buf.dtype, ImageDType):
    lengths = [4]
  elif ctx is not None and ctx.supports_float4:
    # TODO: a better way to get this than ctx
    lengths = [8, 4, 2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16, 8, 4, 2] if AMX else [4, 2])
  lengths.append(1)  # worst case, it's not folded

  # filter fold lengths that don't divide
  if must_divide: lengths = [x for x in lengths if idx.src[1].divides(x) is not None]

  # split based on the fold lengths
  global_offset = 0
  ret = []
  ptrdtype = cast(PtrDType, buf.dtype)
  while global_offset < sz:
    # with 1 at the end of the lengths list, this will always hit
    for fold_length in lengths:
      if global_offset+fold_length > sz: continue
      lidx = buf.index(idx.src[1] + global_offset, idx.src[2] if len(idx.src) > 2 else None)
      if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, local=ptrdtype.local))
      if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx, ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
      else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length)))
      global_offset += fold_length
      break

  # if it wasn't split, we return None. otherwise we CAT them
  return UOp(Ops.CAT, ls.dtype, tuple(ret)) if len(ret) > 1 else None
def image_fixup(ls:UOp):
  # normal image load or store, with the CAST from expand_index
  if ls.src[0].op is Ops.CAST and isinstance(image_dtype:=ls.src[0].src[0].dtype, ImageDType):
    assert ls.src[0].dtype.count == 4, "image must be casted to 4"
    idx = ls.src[0].src[0]
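    # convert the flat index into (x, y) image coordinates: x = (idx//4) % shape[1], y = idx // (4*shape[1])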
    oidx = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1]))))
    idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:])
    return ls.replace(src=(idx,)+ls.src[1:])

  # this is an unprocessed image without a cast, aka unfoldable image load. this doesn't work for stores
  if isinstance(image_dtype:=ls.src[0].dtype, ImageDType) and ls.src[0].src[1].dtype != dtypes.int.vec(2):
    assert ls.op is Ops.LOAD, "if an image store isn't upcasted to 4, we can't store it"
    idx = ls.src[0]
    id4 = idx.src[1] % 4
    oidx = UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((idx.src[1] // 4) % image_dtype.shape[1], (idx.src[1] // (4*image_dtype.shape[1]))))
    idx = idx.replace(src=(idx.src[0], oidx)+idx.src[2:])
    vec_load = ls.replace(dtype=ls.dtype.vec(4), src=(idx,)+ls.src[1:])
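    # select lane id4 out of the 4-wide load with nested WHEREs; the NaN initializer is unreachable since id4 is always in [0, 3]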
    return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), ls.const_like(float('nan')))
  return None
correct_load_store = PatternMatcher([
  # split LOAD/STORE
  (UPat((Ops.LOAD, Ops.STORE), src=(UPat(Ops.INDEX, name="idx").cast(),), name="ls", allow_any_len=True), split_load_store),
  # image indexing, including unfoldable images
  (UPat((Ops.LOAD, Ops.STORE), name="ls"), image_fixup),
])
# *** uop expander ***
# TODO: there's a lot shared with gep_through_wmma here
def no_vectorized_wmma(wmma:UOp):
  out_sz = prod(x[1] for x in wmma.arg[6][-1])
  if wmma.dtype.count == out_sz: return None
  tsrcs = []
  for s,sz in zip(wmma.src, wmma.arg[6]):
    ssz = prod(x[1] for x in sz)
    tsrcs.append([s.gep(tuple(range(grp, grp+ssz))) for grp in range(0, s.dtype.count, ssz)])
  wmmas = [UOp(Ops.WMMA, wmma.dtype.scalar().vec(out_sz), tsrc, wmma.arg) for tsrc in zip(*tsrcs)]
  wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
  return UOp(Ops.VECTORIZE, wmma.dtype, tuple(wmma_ex))
def no_vectorized_alu(alu:UOp):
  if alu.dtype.vcount == 1: return None
  alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
  return UOp(Ops.VECTORIZE, alu.dtype, alus)
def no_vectorized_acc(acc:UOp):
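  # split a vector DEFINE_ACC into scalars: only src[0] (the init value) is indexed, and each scalar acc gets a unique arg via acc.arg+(i,)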
  if acc.dtype.count == 1: return None
  alus = tuple(UOp(acc.op, acc.dtype.scalar(),
    tuple(s.gep(i) if j == 0 else s for j,s in enumerate(acc.src)), acc.arg+(i,)) for i in range(acc.dtype.count))
  return UOp(Ops.VECTORIZE, acc.dtype, alus)
devectorize = PatternMatcher([
  # no ALU on vectorized dtypes
  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN), name="alu"), no_vectorized_alu),
  (UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
  (UPat(Ops.DEFINE_ACC, name="acc"), no_vectorized_acc),
])
pm_render = PatternMatcher([
  # for rendering, we use explicit VECTORIZE
  (UPat(Ops.CONST, name='c'),
   lambda c: UOp(Ops.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
  (UPat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
  (UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
  (UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
  (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
  # give any loads that are masked an alt value
  (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"),
   lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:]) if len(x.src) == 1 or x.src[1].op is Ops.CUSTOM else None),
  # gate any stores that aren't gated with ifs
  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store"),
   lambda store, idx: UOp(Ops.STORE, src=store.src+(UOp(Ops.IF, src=(idx.src[2],)),))),
])
# *** Ops.REDUCE -> Ops.DEFINE_ACC+Ops.ASSIGN ***
@dataclass
class ReduceContext:
  acc_num: int = 0
def horizontal_reduce(inp:UOp, out_dtype:DType) -> list[UOp]:
  # if this has a horizontal reduction component, do that first
  if inp.dtype != out_dtype:
    # NOTE: [0 1 2 3 4 5 6 7] -> [0+4, 1+5, 2+6, 3+7]
    horizontal_amount = inp.dtype.count//out_dtype.count
    return [inp.gep(tuple(range(i, inp.dtype.count, horizontal_amount))) for i in range(0, horizontal_amount)]
  return [inp]
def reduce_to_acc(ctx:ReduceContext, red:UOp):
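  # REDUCE(inp, *ranges) becomes DEFINE_ACC(identity, *ranges) and ASSIGN(acc, acc <red.arg> inp)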
  inp, reduce_range = red.src[0], red.src[1:]
  lst = horizontal_reduce(inp, red.dtype)
  assert all(x.dtype == red.dtype for x in lst), f"horizontal reduction mismatch {lst[0].dtype} != {red.dtype}"
  # if we have a range
  if len(reduce_range) != 0:
    acc = UOp(Ops.DEFINE_ACC, red.dtype, (red.const_like(identity_element(red.arg, red.dtype.scalar())),)+tuple(reduce_range), (ctx.acc_num,))
    lst = [acc] + lst  # put acc as the first element
    ctx.acc_num += 1
  ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
  return acc.assign(ret) if len(reduce_range) != 0 else ret
def no_vectorized_reduce(inp:UOp, red:UOp):
  if inp.dtype != red.dtype:
    red = red.replace(src=(functools.reduce(lambda x,y: x.alu(red.arg, y), horizontal_reduce(inp, red.dtype)),)+red.src[1:])
    if red.dtype.vcount == 1: return red
  # like no_vectorized_alu, but ignoring the ranges
  if red.dtype.vcount == 1: return None
  alus = tuple(UOp(red.op, red.dtype.scalar(), (red.src[0].gep(i),)+red.src[1:], red.arg) for i in range(red.dtype.vcount))
  return UOp(Ops.VECTORIZE, red.dtype, alus)
def reduce_rangeless(red:UOp):
  # TODO: share code with reduce_unparented
  if red.arg not in {Ops.ADD, Ops.MAX}: return None
  if red.src[0].dtype != red.dtype: return None
  if any(x.op in {Ops.RANGE} for x in red.src[0].toposort()): return None
  ret = red.src[0]
  if red.arg is Ops.ADD:
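    # summing a range-invariant value over a RANGE is just the value times the range length (r.src[0])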
    for r in red.src[1:]:
      ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
  return ret
def no_range(u:UOp) -> bool: return not any(x.op is Ops.RANGE for x in u.sparents)
pm_reduce_collapse = PatternMatcher([
  # lift x+y out of reduce on lt
  ((UPat.var("x")+UPat.var("y")) < UPat.var("c"), lambda x,y,c: (x < (c-y)) if no_range(y) and no_range(c) else None),
  # lift x*y out of reduce
  ((UPat.var("x")*UPat.var("y")) < UPat.var("c"),
   lambda x,y,c: (x < ((c+y-1)//y)) if no_range(y) and no_range(c) and y.vmin > 0 else None),
  # lift x+y out of reduce on ne
  ((UPat.var("x")+UPat.var("y")) != UPat.var("c"), lambda x,y,c: (x != (c-y)) if no_range(y) and no_range(c) else None),
  # fold the range
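  # sum over i in [0, end) of (i < cut ? 0 : val) has clamp(end-cut, 0, end) nonzero terms; the mirrored where counts clamp(cut, 0, end)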
  ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(0, UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True),
   lambda r,cut,val: (r.src[0]-cut).maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
  ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(UPat.cvar("val"), 0).reduce(arg=Ops.ADD, allow_any_len=True),
   lambda r,cut,val: cut.maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
  # REDUCE on ADD
  ((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
   lambda x,y,r: x.reduce(*r.src[1:], arg=Ops.ADD) + y.reduce(*r.src[1:], arg=Ops.ADD)),
  # MUL casted bool
  ((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast().or_broadcasted(name="b")),
   lambda x,gate,b=None: gate.broadcast(x.dtype.count).where(x, 0) if b is not None else gate.where(x, 0)),
  # WHERE on LOAD (works on max too)
  (UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD, allow_any_len=True),
   lambda buf,idx,gate: buf.index(idx, gate).load()),
  (UPat.var("gate").where(0, UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD, allow_any_len=True),
   lambda buf,idx,gate: buf.index(idx, gate.logical_not()).load()),
  # INDEX on RANGE / gated RANGE
  (UPat.var("buf").index(UPat.var("expr"), UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted())),
   lambda buf,r,idx,expr: buf.index(expr.substitute({r:idx.cast(r.dtype)}), (idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0]))),
  # AND on WHERE
  ((UPat.any(UPat(Ops.DEFINE_VAR, name="x"), UPat(Ops.DEFINE_VAR).gep(name="x")) & UPat.var("y")) \
    .where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
   lambda x,y,c,r: y.where(c, 0).reduce(*r.src[1:], arg=Ops.ADD) * x.cast(c.dtype)),
  # remove REDUCEs that no longer have a RANGE in the src
  (UPat(Ops.REDUCE, name="red"), reduce_rangeless),
  # devectorize REDUCE
  (UPat(Ops.VECTORIZE, name="inp").reduce(name="red", allow_any_len=True), no_vectorized_reduce),
  # index/load/where. TODO: this is more aggressive than needed
  (UPat((Ops.INDEX, Ops.LOAD, Ops.WHERE), name="alu"), no_vectorized_alu),
]) + sym
def reduce_collapse(red:UOp):
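  # try to compute the reduce in closed form: isolate the subgraph that depends on the reduce ranges, replace its external
  # inputs with DEFINE_VARs, rewrite with pm_reduce_collapse, and only accept the result if all REDUCEs/RANGEs are gone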
  included, not_included = partition(red.parents, lambda x: any(y in x.sparents for y in red.src[1:]))
  if any(x.op in {Ops.STORE, Ops.REDUCE} for x in included): return None
  replaces: dict[UOp, UOp] = {}
  for u in included:
    for s in u.src:
      if s in not_included and s not in replaces and s.op not in {Ops.CONST, Ops.VCONST, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}:
        replaces[s] = UOp(Ops.DEFINE_VAR, dtype=s.dtype, arg=(f'in{len(replaces)}', s.vmin, s.vmax))
  collapse_fxn = red.substitute(replaces)
  sink = graph_rewrite(collapse_fxn, pm_reduce_collapse, name="reduce_collapse")
  # TODO: why is REDUCE needed here and just RANGE isn't enough?
  if any(x.op in {Ops.REDUCE, Ops.RANGE} for x in sink.toposort()): return None
  return sink.substitute({v:k for k,v in replaces.items()})
def reduce_unparented(red:UOp):
  if red.arg not in {Ops.ADD, Ops.MAX}: return None
  reduce_parented, reduce_unparented = partition(red.src[1:], lambda x: x in red.src[0].sparents)
  if len(reduce_unparented) == 0: return None
  ret = red.replace(src=(red.src[0],)+tuple(reduce_parented)) if len(reduce_parented) or red.dtype != red.src[0].dtype else red.src[0]
  if red.arg is Ops.ADD:
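    # for ADD, dropping a range the source doesn't use means multiplying by the range length; for MAX it's a plain drop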
    for r in reduce_unparented: ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
  return ret
pm_reduce = PatternMatcher([
  # remove any ranges from a REDUCE that aren't referenced in the reduce source
  (UPat(Ops.REDUCE, name="red"), reduce_unparented),
  # remove REDUCE without loads (generic arange opt / indexing). TODO: support multi range
  (UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_collapse),
  # REDUCE -> DEFINE_ACC+ASSIGN
  (UPat(Ops.REDUCE, name="red"), reduce_to_acc),
  # tensor core built in accumulate
  (UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
   lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
]) + sym