from typing import cast
import math , struct
from tinygrad . renderer import Renderer
from tinygrad . ops import UOp , PatternMatcher , UPat , Ops , GroupOp
from tinygrad . dtype import dtypes , DType , PtrDType , truncate
def ldt(dt: DType):
    """Return the LLVM IR type spelling for `dt`.

    Every `PtrDType` layer adds one pointer suffix after the spelling of the
    innermost base dtype. Signed and unsigned integers of the same width share
    a spelling. A dtype that is missing from the table raises `KeyError`.
    """
    # peel pointer layers, remembering one suffix per layer
    pointer_suffix = ""
    while isinstance(dt, PtrDType):
        dt = dt.base
        pointer_suffix += " * "
    scalar_names = {
        dtypes.int8: " i8 ", dtypes.uint8: " i8 ",
        dtypes.int16: " i16 ", dtypes.uint16: " i16 ",
        dtypes.int32: " i32 ", dtypes.uint32: " i32 ",
        dtypes.int64: " i64 ", dtypes.uint64: " i64 ",
        dtypes.float16: " half ", dtypes.float32: " float ", dtypes.float64: " double ",
        dtypes.bool: " i1 ", dtypes.void: " void ",
    }
    return scalar_names[dt] + pointer_suffix
def lconst(x, dtype: DType):
    """Turn a Python constant into the value substituted into the rendered IR.

    Non-float dtypes give `int(x)`. Float dtypes give `truncate[dtype](x)`,
    except for infinities and NaN, which are returned as a string holding the
    hexadecimal bytes of `x` packed as a C double.
    """
    if dtype not in dtypes.floats:
        return int(x)
    if not (math.isinf(x) or math.isnan(x)):
        return truncate[dtype](x)
    # the packed bytes are printed in reverse order, i.e. last packed byte first
    packed = struct.pack(" d ", x)
    return " 0x %02X %02X %02X %02X %02X %02X %02X %02X " % tuple(packed[::-1])
def lcast(input_type: DType, output_type: DType):
    """Pick the LLVM conversion instruction for a cast between two dtypes.

    Float sources extend/truncate to other floats (by `itemsize`) or convert to
    signed/unsigned ints. Unsigned-or-bool sources and signed int sources
    convert to float or resize to another int, differing only in the
    int-to-float and widening mnemonics. Any other combination raises
    `NotImplementedError`.
    """
    if dtypes.is_float(input_type):
        if dtypes.is_float(output_type):
            if output_type.itemsize > input_type.itemsize:
                return ' fpext '
            return ' fptrunc '
        if dtypes.is_int(output_type):
            if dtypes.is_unsigned(output_type):
                return ' fptoui '
            return ' fptosi '

    # integer-like sources: (int -> float mnemonic, widening mnemonic)
    if dtypes.is_unsigned(input_type) or input_type == dtypes.bool:
        mnemonics = (' uitofp ', ' zext ')
    elif dtypes.is_int(input_type):
        mnemonics = (' sitofp ', ' sext ')
    else:
        mnemonics = None

    if mnemonics is not None:
        to_float, widen = mnemonics
        if dtypes.is_float(output_type):
            return to_float
        if dtypes.is_int(output_type):
            if output_type.itemsize < input_type.itemsize:
                return ' trunc '
            return widen

    raise NotImplementedError(f" cast from {input_type} -> {output_type} not implemented ")
# LLVM instruction mnemonics, looked up as lop[<dtype>][<op>]
# table used for bool and the unsigned integer dtypes
unsigned_lop = {
    Ops.ADD: " add ",
    Ops.MUL: " mul ",
    Ops.IDIV: " udiv ",
    Ops.MOD: " urem ",
    Ops.CMPLT: " icmp ult ",
    Ops.CMPNE: " icmp ne ",
    Ops.OR: " or ",
    Ops.AND: " and ",
    Ops.XOR: " xor ",
}
# signed integers start from the unsigned table and override compare/divide/remainder
signed_lop = dict(unsigned_lop)
signed_lop.update({Ops.CMPLT: " icmp slt ", Ops.IDIV: " sdiv ", Ops.MOD: " srem "})
# flag text spliced into every float instruction below
flags = " nsz arcp contract afn "
float_lop = {
    Ops.ADD: f" fadd {flags}",
    Ops.MUL: f" fmul {flags}",
    Ops.CMPLT: " fcmp " + flags + " ult ",
    Ops.CMPNE: " fcmp " + flags + " une ",
    Ops.FDIV: f" fdiv {flags}",
}
# every dtype in a family points at the same shared table object
lop = {
    **dict.fromkeys((dtypes.bool,) + dtypes.uints, unsigned_lop),
    **dict.fromkeys(dtypes.sints, signed_lop),
    **dict.fromkeys(dtypes.floats, float_lop),
}
# Pattern table that turns one UOp into the text of its LLVM IR instruction(s).
# Every callback receives `ctx`, a mapping from UOp to the name (or literal)
# already chosen for it, and returns the rendered text for `x`.
llvm_rewrite = PatternMatcher([
    # memory load/store
    (UPat(Ops.INDEX, name=" x "), lambda ctx, x:
     f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)} , {ldt(x.src[0].dtype)} {ctx[x.src[0]]} , {ldt(x.src[1].dtype)} {ctx[x.src[1]]} "),
    # gated load: branch on the mask, load only on the taken path, then merge
    # the loaded value with the alternative through a phi. Labels reuse the
    # rendered name of the load with its first character dropped.
    (UPat(Ops.LOAD, src=(UPat.var(' idx '), UPat.var(' alt '), UPat.var(' mask ')), name=" x "), lambda ctx, x, idx, alt, mask:
     f" br label {ctx[x]} _entry \n {ctx[x][1:]} _entry: \n "
     f" br i1 {ctx[mask]} , label {ctx[x]} _load, label {ctx[x]} _exit \n {ctx[x][1:]} _load: \n "
     f" {ctx[x]} _yes = load {ldt(x.dtype)} , {ldt(idx.dtype)} {ctx[idx]} \n "
     f" br label {ctx[x]} _exit \n {ctx[x][1:]} _exit: \n "
     f" {ctx[x]} = phi {ldt(x.dtype)} [ {ctx[x]} _yes, {ctx[x]} _load], [ {ctx[alt]} , {ctx[x]} _entry] "),
    # plain load with a single source
    (UPat(Ops.LOAD, src=(UPat.var(' idx '),), name=" x "), lambda ctx, x, idx:
     f" {ctx[x]} = load {ldt(x.dtype)} , {ldt(idx.dtype)} {ctx[idx]} "),
    # store: src[1] is the value, src[0] the destination
    (UPat(Ops.STORE, name=" x "), lambda ctx, x:
     f" store {ldt(x.src[1].dtype)} {ctx[x.src[1]]} , {ldt(x.src[0].dtype)} {ctx[x.src[0]]} "),

    # unary/binary/ternary ops
    (UPat(Ops.SQRT, name=" x "), lambda ctx, x:
     f" {ctx[x]} = call {flags} {ldt(x.dtype)} @llvm.sqrt. {ldt(x.src[0].dtype)} ( {ldt(x.src[0].dtype)} {ctx[x.src[0]]} ) "),
    (UPat(Ops.BITCAST, name=" x "), lambda ctx, x:
     f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)} "),
    (UPat(Ops.CAST, name=" x "), lambda ctx, x:
     f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)} "),
    # the mnemonic is chosen from the dtype of the first operand
    (UPat(GroupOp.Binary, name=" x "), lambda ctx, x:
     f" {ctx[x]} = {lop[x.src[0].dtype][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} , {ctx[x.src[1]]} "),
    (UPat(Ops.WHERE, name=" x "), lambda ctx, x:
     f" {ctx[x]} = select {ldt(x.src[0].dtype)} {ctx[x.src[0]]} , {ldt(x.src[1].dtype)} {ctx[x.src[1]]} , {ldt(x.src[2].dtype)} {ctx[x.src[2]]} "),

    # range: open the entry and body blocks and define the counter as a phi of
    # the start value (src[0]) and the incremented value produced in the latch
    (UPat(Ops.RANGE, name=" x "), lambda ctx, x:
     f" br label %loop_entry_ {x.arg} \n loop_entry_ {x.arg} : \n "
     f" br label %loop_body_ {x.arg} \n loop_body_ {x.arg} : \n "
     f" {ctx[x]} = phi {ldt(x.dtype)} [ {ctx[x.src[0]]} , %loop_entry_ {x.arg} ], [ {ctx[x]} phi, %loop_latch_ {x.arg} ] "),
    # endrange: increment the counter, compare against the RANGE's src[1] and
    # loop or exit. The add/icmp are typed from the RANGE's own dtype so they
    # always agree with the phi emitted for that RANGE (previously hard-coded i32).
    (UPat(Ops.ENDRANGE, name=" x "), lambda ctx, x:
     f" br label %loop_latch_ {x.src[0].arg} \n loop_latch_ {x.src[0].arg} : \n "
     f" {ctx[x.src[0]]} phi = add {ldt(x.src[0].dtype)} {ctx[x.src[0]]} , 1 \n "
     f"{ctx[x]} = icmp ult {ldt(x.src[0].dtype)} {ctx[x.src[0]]} phi, {ctx[x.src[0].src[1]]} \n "
     f" br i1 {ctx[x]} , label %loop_body_ {x.src[0].arg} , label %loop_exit_ {x.src[0].arg} \n loop_exit_ {x.src[0].arg} : "),

    # if: branch into the body or skip it; ENDIF closes the body and opens the
    # skip label named after the IF it belongs to (src[0])
    (UPat(Ops.IF, name=" x "), lambda ctx, x:
     f" br i1 {ctx[x.src[0]]} , label %ifbody_ {ctx[x][1:]} , label %ifskip_ {ctx[x][1:]} \n ifbody_ {ctx[x][1:]} : "),
    (UPat(Ops.ENDIF, name=" x "), lambda ctx, x:
     f" br label %ifskip_ {ctx[x.src[0]][1:]} \n ifskip_ {ctx[x.src[0]][1:]} : "),
])
class LLVMRenderer(Renderer):
    """Renderer that emits a list of UOps as one textual LLVM IR function."""

    # NOTE(review): these attributes are presumably read by the Renderer base
    # class / its callers; their meaning is not visible from this file.
    device = " LLVM "
    supports_float4 = False
    has_local = False
    has_shared = False
    global_max = None

    extra_matcher = PatternMatcher([
        # RECIP(x) is expressed as FDIV(1, x)
        (UPat(Ops.RECIP, name=" x "), lambda x: UOp(Ops.FDIV, x.dtype, (x.const_like(1), x.src[0]))),
        # a cast to bool is expressed as "source != 0"
        (UPat(Ops.CAST, dtype=dtypes.bool, name=" x "), lambda x: x.src[0] != x.src[0].const_like(0)),
        # MAX(a, b) is expressed as WHERE(a < b, b, a)
        (UPat(Ops.MAX, name=" m "), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
    ])

    def render(self, name: str, uops: list[UOp]) -> str:
        """Render `uops` as the text of a `define void` function called `name`.

        DEFINE_GLOBAL / DEFINE_VAR uops become the function parameters, every
        other uop that needs an instruction is rendered through `llvm_rewrite`,
        and declarations collected for SQRT are appended after the function.

        Raises:
            RuntimeError: when `llvm_rewrite` has no pattern for a uop.
            AssertionError: when the same ASSIGN target is assigned twice.
        """
        ssa: dict[UOp, str] = {}         # name (or literal) chosen for each uop
        params: list[str] = []           # rendered function parameters
        body: list[str] = []             # rendered instructions, in order
        decls: dict[str, None] = {}      # dict used as an insertion-ordered set
        counter = -1                     # shared by the %assign / %v / %acc names

        # first pass: every ASSIGN and the value it assigns share one name,
        # reserved before anything is rendered
        pending: dict[UOp, UOp] = {}     # assign target -> value assigned to it
        for uop in uops:
            if uop.op is not Ops.ASSIGN:
                continue
            counter += 1
            ssa[uop] = ssa[uop.src[1]] = f" %assign {counter} "
            assert uop.src[0] not in pending, " can ' t assign to DEFINE_ACC twice "
            pending[uop.src[0]] = uop.src[1]

        # second pass: name and render everything else
        for uop in uops:
            # hack for defining sqrt function (TODO: can we get a transcendental for this?)
            if uop.op is Ops.SQRT:
                decls[f' declare {ldt(uop.dtype)} @llvm.sqrt. {ldt(uop.dtype)} ( {ldt(uop.dtype)} % " .1 " ) '] = None

            if uop.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
                if uop.op is Ops.DEFINE_GLOBAL:
                    ssa[uop] = f" %data {uop.arg} "
                else:
                    ssa[uop] = f" % {uop.arg[0]} "
                qualifier = ' noalias ' if isinstance(uop.dtype, PtrDType) else ' '
                params.append(f" {ldt(uop.dtype)} {qualifier} {ssa[uop]} ")
                continue
            if uop.op is Ops.ASSIGN:
                # already named by the first pass
                continue
            if uop.op is Ops.DEFINE_ACC:
                # starts out as its src[0]; a define acc can be used and never be assigned to
                ssa[uop] = ssa[uop.src[0]]
                continue
            if uop.op is Ops.CONST:
                ssa[uop] = lconst(uop.arg, uop.dtype)
                continue
            if uop.op is Ops.CAST and ldt(uop.dtype) == ldt(uop.src[0].dtype):
                # same LLVM type on both sides (e.g. signed <-> unsigned of one width): reuse the source
                ssa[uop] = ssa[uop.src[0]]
                continue

            # anything not pre-named as an assign value gets a fresh %v name
            if uop not in ssa:
                counter += 1
                ssa[uop] = f" %v {counter} "

            # render the instruction text for this uop
            line = llvm_rewrite.rewrite(uop, ctx=ssa)
            if line is None:
                raise RuntimeError(f" failed to render {uop.op} with {uop.dtype} srcs {[s.dtype for s in uop.src]} ")
            body.append(cast(str, line))

            # after a RANGE, emit a phi for every assign target that has this range among its sources
            if uop.op is Ops.RANGE:
                for acc, new_value in pending.items():
                    if uop not in acc.src:
                        continue
                    counter += 1
                    incoming = f" [ {ssa[acc]} , %loop_entry_ {uop.arg} ], [ {ssa[new_value]} , %loop_latch_ {uop.arg} ] "
                    body.append(f" %acc {counter} = phi {ldt(acc.dtype)} " + incoming)
                    ssa[acc] = f" %acc {counter} "

        # output the function, followed by the collected declarations
        signature = f" define void @ {name} ( {' , '.join(params)} ) {{ \n "
        return signature + ' \n '.join(body) + " \n ret void \n } \n " + ' \n '.join(decls)