# the job of the lowerer is to do indexing
import functools, itertools, operator, math
from dataclasses import dataclass
from typing import cast
from tinygrad.dtype import dtypes, PtrDType, least_upper_dtype
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
from tinygrad.renderer import Renderer
from tinygrad.helpers import all_int, prod, partition, flatten, unwrap, QUANTIZE
from tinygrad.codegen.expander import expand_rewrite
from tinygrad.codegen.symbolic import symbolic
# returns the axes to create new_shape if new_shape can be created by combining axes from old_shape
def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
  acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
  try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
  except ValueError: return None
  return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
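# illustrative examples (not from the original source):
#   get_contraction((2,3,4), (6,4)) -> [[0, 1], [2]]   # axes 0 and 1 merge into the 6, axis 2 stays
#   get_contraction((2,3,4), (4,6)) -> None            # 4 is not a prefix product of (2,3,4)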
# ***** indexing *****
def _group_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
  # TODO: symbolic shape
  if not all_int(dims): return dims
  while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
    for i,m in enumerate(max_sizes):
      if i < (len(dims)-1) and dims[i] * dims[i+1] <= m:
        dims = dims[:i] + (dims[i] * dims[i+1],) + dims[i+2:]
        break
    else: return None
  return dims
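# illustrative examples (not from the original source): merge adjacent dims until they fit the launch limits
#   _group_dims((2,3,4,5), (32,16,16)) -> (6,4,5)   # 2*3 merged, leaving 3 dims that are all within the limits
#   _group_dims((64,64), (16,16))      -> None      # no adjacent merge can satisfy the limits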
def _split_dims(dims, max_sizes):
  if all(d <= m for d,m in zip(dims, max_sizes)): return dims
  _dims = list(dims) + [1]*(3-len(dims))
  for i in range(len(_dims)):
    while _dims[i] > max_sizes[i]:
      div = next((d for d in range(2, math.ceil(math.sqrt(_dims[i])) + 1) if (_dims[i] % d) == 0), 1)
      if div == 1: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
      _dims[i], _dims[(i+1)%len(_dims)] = _dims[i]//div, _dims[(i+1)%len(_dims)]*div
  return tuple(_dims[:2] if _dims[2] == 1 else _dims[0:1] if _dims[1:3] == [1,1] else _dims)
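# illustrative examples (not from the original source): split an oversized dim by pushing factors into the next slot
#   _split_dims((128,), (32,32,32)) -> (32, 4)   # 128 = 32*4, the trailing 1 is dropped
#   _split_dims((7,), (32,32,32))   -> (7,)      # already within limits, returned unchanged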
def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|None, reverse=False) -> list[UOp]:
  if reverse: dims = dims[::-1]
  # try to group first: (a, b, c, d) -> (ab, c, d)
  limited = (grouped if (grouped:=_group_dims(dims, max_sizes)) else dims) if max_sizes is not None else dims
  # check if grouping failed
  if max_sizes is not None and len(limited) > len(max_sizes): raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
  # try to split up dims: (a,) -> (b, c)
  if limited == dims: limited = _split_dims(dims, max_sizes) if max_sizes is not None else dims
  ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
  if len(limited) < len(dims):
    ret = []
    if (contraction:=get_contraction(dims, limited)) is None: raise AssertionError(f"get_contraction should not be None {dims=} {limited=}")
    for idx, contraction_group in zip(raw_idxs, contraction):
      for c in contraction_group[:-1]:
        ret.append(idx % dims[c])
        idx //= dims[c]
      ret.append(idx)
  elif len(limited) > len(dims):
    a, b = len(limited), len(dims)
    if a == 2 and b == 1: ret = [raw_idxs[0] * limited[1] + raw_idxs[1]]
    if a == 3 and b == 1: ret = [raw_idxs[0] * (limited[1] * limited[2]) + raw_idxs[1] * limited[2] + raw_idxs[2]]
    if a == 3 and b == 2: ret = [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]]
  return ret[::-1] if reverse else ret
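# illustrative sketch (not from the original source) of how a grouped SPECIAL is unpacked into per-axis indexes:
#   dims=(2,3,4) grouped to limited=(6,4) creates SPECIALs idx0 (size 6) and idx1 (size 4), and the returned
#   indexes are [idx0 % 2, idx0 // 2, idx1], matching get_contraction((2,3,4), (6,4)) == [[0,1], [2]]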
@dataclass
class IndexContext:
  idxs: list[UOp]
  ridxs: list[UOp]
  acc_num: int = 0
def get_index(ast:UOp, opts:Renderer) -> IndexContext:
  ki = ast.arg if isinstance(ast.arg, KernelInfo) else KernelInfo()
  # NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>
  full_shape = ast.full_shape
  first_upcasted = len(full_shape)-ki.upcasted
  # if there's no reduce, this is first_upcasted. assumes reduces are at the end
  first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.toposort if x.op is Ops.REDUCE_AXIS))
  local_loads = [x for x in ast.toposort if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL]
  # NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces
  group_for_reduces = sum([any(l.st_arg.shape[i] != ast.src[0].st_arg.shape[i] for l in local_loads) for i in range(first_reduce, first_upcasted)])
  global_dims = first_reduce-ki.local_dims
  if opts.has_local:
    if ki.dont_use_locals:
      assert ki.local_dims == 0, "can't use locals if there's no local dims"
      idxs = get_grouped_dims("idx", full_shape[:global_dims], opts.global_max, reverse=True)
    else:
      # define indexes for GPU-like execution
      idxs = get_grouped_dims("gidx", full_shape[:global_dims], opts.global_max, reverse=True) + \
             get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
  else:
    # all loops are RANGES
    idxs = [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), i) for i,g in enumerate(full_shape[:first_reduce])]
  # reduce loops
  idxs += [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), i)
    for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]
  # upcast loops
  for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
    assert isinstance(g, int), "needs to be int to upcast/unroll"
    idxs.append(UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))
  # late indexes (group for reduce)
  ridxs = idxs[:]
  for a in range(first_reduce, first_reduce+group_for_reduces):
    ridxs[a] = UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(full_shape[a])), 1000+a)
  return IndexContext(idxs, ridxs)
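# illustrative sketch (assumed shapes, not from the original source): for a GPU kernel with full_shape=(G, L, R, U),
# local_dims=1, upcasted=1, a reduce over axis 2 and no grouped reduces, idxs comes out roughly as
# [SPECIAL("gidx0", G), SPECIAL("lidx0", L), RANGE(0, R), UNROLL(U)] and ridxs is an unchanged copy of idxs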
# ***** lowering (given index) *****
def lower_reduce_axis(ctx: IndexContext, x: UOp):
  # NOTE: always using ridxs is fine here
  reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
  assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand} for {x.axis_arg}"
  alu_op: Ops = x.arg[0]
  ret = x.src[0]
  # create acc
  acc = UOp(Ops.DEFINE_ACC, x.dtype, (x.const_like(identity_element(alu_op, x.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,))
  ctx.acc_num += 1
  if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
    ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
    ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [acc] + [ret.gep(i) for i in range(ret.dtype.count)])
  else:
    ret = acc.alu(alu_op, ret)
  if not len(reduce_range): return ret
  # create ACC and assign
  return acc.assign(ret)
def lower_load_store(ctx: IndexContext, x: UOp):
  idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
  buf = x.src[0]
  if x.op is Ops.LOAD:
    barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is Ops.DEFINE_LOCAL else ()
    return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + barrier)
  # NOTE: only store the local reduceop in the threads that are actually doing the reduce
  if cast(PtrDType, x.src[0].dtype).local and x.src[2].op is Ops.ASSIGN:
    reduce_input = x.src[2].src[1].src[1] if x.src[2].src[1].src[1] is not x.src[2].src[0] else x.src[2].src[1].src[0]
    store_back = reduce_input.op is Ops.LOAD and cast(PtrDType, reduce_input.src[0].dtype).local
  else: store_back = False
  # NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
  if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if u in x.src[2].src else u for u in ctx.idxs])
  if (not cast(PtrDType, x.src[0].dtype).local) or store_back:
    for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
      if oidx is not ridx: valid = valid * oidx.eq(0)
  return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid), x.src[2]))
def lower_const(x:UOp):
  assert all(v.mask is None for v in unwrap(x.st).views), f"VIEW in CONST/DEFINE_VAR source must be unmasked, got {x.st}"
  return x.replace(src=())
pm_lowerer = PatternMatcher([
  (UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
  (UPat((Ops.CONST, Ops.DEFINE_VAR), src=(UPat(Ops.VIEW),), name="x"), lower_const),
  (UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
  # rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
  (UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
  (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b,idx: b.index(idx)),
  (UPat(Ops.IGNORE, name="x"), lambda x: x.src[0]),
])
# **** this is the "quantization preprocessor", it makes ONNX quantized models, and probably also others, actually use ints ****
FP = (1 << 16)
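# illustrative note (not from the original source): FP is the fixed-point scale (2**16), so a float multiply
# like x * 0.5 is rewritten by the patterns below as (x * int(0.5 * FP)) // FP == (x * 32768) // 65536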
pm_quant = symbolic+PatternMatcher([
  # cast after add/mul
  (UPat.var("x").cast(dtypes.float32) + UPat.var("y").cast(dtypes.float32),
   lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype)) + y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
  (UPat.var("x").cast(dtypes.float32) * UPat.var("y").cast(dtypes.float32),
   lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype)) * y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
  # masked MUL after masked ADD
  ((UPat.var("x") + UPat.var("v").where(UPat.var('cadd'), UPat(Ops.CONST, arg=0))) * UPat.var("v").where(UPat.var('cmul'), UPat(Ops.CONST, arg=0)),
   lambda x,v,cadd,cmul: x*v.where(cmul, 0) + v.where(cadd*cmul, 0)),
  # MUL after reduce
  (UPat(Ops.REDUCE_AXIS, src=(UPat.var("x") * UPat.cvar("c"),), name="r"), lambda x,c,r: r.replace(src=(x,))*c),
  # CAST after reduce (doesn't work if it's a size change)
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="r"),
   lambda x,r: r.replace(dtype=x.dtype, src=(x,)).cast(r.dtype) if dtypes.is_float(r.dtype) else None),
  # x*c1 + y*c2 -> (x+y)*c1 (if c1 and c2 are close floats)
  (UPat.var("x")*UPat.cvar("c1", dtype=dtypes.floats) + UPat.var("y")*UPat.cvar("c2", dtype=dtypes.floats),
   lambda x,y,c1,c2: (x+y)*c1 if abs(c1.arg-c2.arg) < 1e-9 else None),
  # mul 0 * c1 is 0
  (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
   UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v"))).cast(dtypes.int).cast(dtypes.float).named("ld"), lambda ld,v,c1: ld*c1),
  # mul (with plus) 0 * c1 is 0
  (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
   (UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v"))).cast(dtypes.int) + \
    UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
   lambda ld,v,c1: ld*c1),
  # fixed point mult, replace (x.float()*c1+c2).int() with an int expression
  ((UPat.var("x").cast(dtypes.float)*UPat.var("c1") + UPat.var("c2")).cast(dtypes.int),
   lambda x,c1,c2: (x*(c1*FP).cast(dtypes.int) + (c2*FP).cast(dtypes.int))//FP),
  # fixed point mult, replace (x.float()*c1 + y.float()*c2) with an int expression
  ((UPat.var("x").cast(dtypes.float)*UPat.var("c1") + UPat.var("y").cast(dtypes.float)*UPat.var("c2")),
   lambda x,y,c1,c2: ((x*(c1*FP).cast(dtypes.int) + y*(c2*FP).cast(dtypes.int))//FP).cast(dtypes.float)),
  # where move
  (UPat.var("valid").where(UPat.var("yes"), UPat(Ops.CONST, arg=0)) * UPat.var("mul"), lambda valid, yes, mul:
    (yes*mul*valid.where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))) if yes.op is not Ops.CONST or yes.arg != 1 else None),
  ((UPat.var("x")*UPat.cvar("c")) * (UPat.var().where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)).named("v")), lambda x,c,v: (x*v)*c),
  (UPat.var("x").cast().named('c') * UPat.var('valid').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)), lambda x,c,valid:
    (x*valid.where(UOp.const(x.dtype, 1), UOp.const(x.dtype, 0))).cast(c.dtype)),
  ((UPat.var('x') * UPat.var('v1').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)) *
    UPat.var('v2').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0))).named("mul"), lambda x,mul,v1,v2:
    x*(v1&v2).where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))),
  # where on two adds
  (UPat.var("x") + UPat.var("v").where(UPat.var("a0"), UPat.var("a1")) + UPat.var("v").where(UPat.var("b0"), UPat.var("b1")),
   lambda x,v,a0,a1,b0,b1: x + v.where(a0+a1, b0+b1)),
  # split REDUCE into multiple reduces
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, name="v1")+UPat.var("c1")) * UPat(Ops.CAST, name="v2"), name="r"),
   lambda v1,v2,c1,r: r.replace(src=(v1*v2,)) + r.replace(src=(c1*v2,))),
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, name="v1")+UPat.var("c1")) * (UPat(Ops.CAST, name="v2")+UPat.var("c2")), name="r"),
   lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,))),
])
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
  if QUANTIZE and opts.device in {"CPU", "DSP"}: ast = graph_rewrite(ast, pm_quant, name="quantize")
  sink = graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
  # expand_rewrite turns this into a vectorized program
  return expand_rewrite(sink)