from typing import List
import struct
from tinygrad . codegen . assembly import uops_to_asmstyle , AssemblyLanguage
from tinygrad . codegen . kernel import Ops , UOp
from tinygrad import dtypes
from tinygrad . uop . ops import BinaryOps , UnaryOps , TernaryOps
from tinygrad . runtime . ops_cuda import arch
# PTX type suffix for each supported dtype.  The values are spliced directly
# after a "." in instruction text (e.g. ``mov.f32``), so they must not carry
# any surrounding whitespace.  The string key "bits16" maps to the 16-bit
# ``b16`` suffix.
dtype_to_nvtype = {
    dtypes.float32: "f32",
    dtypes.float16: "f16",
    dtypes.int64: "s64",
    dtypes.int32: "s32",
    dtypes.int8: "s8",
    dtypes.bool: "pred",
    dtypes.uint64: "u64",
    dtypes.uint32: "u32",
    dtypes.uint16: "u16",
    dtypes.uint8: "u8",
    "bits16": "b16",
    dtypes.float64: "f64",
}
def float_to_hex(x):
    """Return the IEEE 754 single-precision bit pattern of ``x`` as hex text.

    The result is exactly 8 uppercase hex digits, most significant byte
    first, with no separators, so it can be appended to a ``0f`` prefix to
    form a PTX single-precision hexadecimal constant.

    Args:
        x: a number that ``struct`` can pack as a 32-bit float.

    Returns:
        An 8-character string such as ``"3F800000"`` for ``1.0``.

    Raises:
        OverflowError / struct.error: propagated from ``struct.pack`` when
            ``x`` cannot be represented as a 32-bit float.
    """
    # Pack explicitly big-endian instead of reversing the native layout, so
    # the digit order does not depend on the byte order of the host.
    return struct.pack(">f", x).hex().upper()
def ptx_needs_cast(dest_dtype, src_dtype):
    """Tell whether a value of ``src_dtype`` must be converted for ``dest_dtype``.

    A cast is needed when the two dtypes sit on opposite sides of the
    int/float divide, or when both are floats of different ``itemsize``.
    """
    # int -> float
    if dtypes.is_float(dest_dtype) and dtypes.is_int(src_dtype):
        return True
    # float -> int
    if dtypes.is_int(dest_dtype) and dtypes.is_float(src_dtype):
        return True
    # float -> float only matters when the width changes
    both_float = dtypes.is_float(src_dtype) and dtypes.is_float(dest_dtype)
    return both_float and dest_dtype.itemsize != src_dtype.itemsize
def render_cast(ins, inp, out):
    """Append to ``ins`` the PTX statement that converts ``inp`` into ``out``.

    Three cases are handled, chosen from the ``dtype`` of both registers:

    * bool source, int/float destination: ``selp`` picks 1 or 0
      (``0f3F800000`` is the single-precision bit pattern of 1.0).
    * bool destination: ``mov.pred`` for a bool source, otherwise ``setp.ne``
      against zero.
    * everything else: ``cvt`` with a rounding modifier where one applies.

    Args:
        ins: list of PTX statement strings; extended in place.
        inp: source register; must expose ``.dtype`` and format to its PTX name.
        out: destination register, same requirements as ``inp``.
    """
    src, dst = inp.dtype, out.dtype

    if src == dtypes.bool and (dtypes.is_float(dst) or dtypes.is_int(dst)):
        choices = '0f3F800000, 0f00000000' if dtypes.is_float(dst) else '1, 0'
        ins.append(f"selp.{dtype_to_nvtype[dst]} {out}, {choices}, {inp};")
        return

    if dst == dtypes.bool:
        if src == dtypes.bool:
            ins.append(f"mov.pred {out}, {inp};")
        else:
            zero = '0f00000000' if dtypes.is_float(src) else '0'
            ins.append(f"setp.ne.{dtype_to_nvtype[src]} {out}, {zero}, {inp};")
        return

    # Rounding modifier: float -> int truncates (.rzi); int -> float and
    # narrowing float -> float round toward zero (.rz); otherwise none.
    if dtypes.is_int(dst) and dtypes.is_float(src):
        round_mod = ".rzi"
    elif dtypes.is_float(dst) and (
        dtypes.is_int(src) or dtypes.is_float(src) and src.itemsize > dst.itemsize
    ):
        round_mod = ".rz"
    else:
        round_mod = ""
    ins.append(f"cvt{round_mod}.{dtype_to_nvtype[dst]}.{dtype_to_nvtype[src]} {out}, {inp};")
# https://docs.nvidia.com/cuda/parallel-thread-execution/#
class PTXLanguage(AssemblyLanguage):
    """``AssemblyLanguage`` variant used when generating PTX."""

    # Flag read by the shared assembly machinery (defined outside this file).
    supports_constant_folding: bool = True
def _ptx_address(base, offset):
    """Format a PTX memory operand: ``[base]``, or ``[base+offset]`` when an offset is given."""
    return f"[{base}+{offset}]" if offset is not None else f"[{base}]"


def specialize_to_ptx(lang, function_name):
    """Render the instructions collected on ``lang`` as PTX source text.

    Walks ``lang.ins`` (tuples of ``(uop, out, vin, arg)``), emits the PTX
    statements for each entry, then prepends the module header, the kernel
    entry signature and one ``.reg`` declaration per entry of ``lang.cnts``.
    Entries whose ``uop`` is not handled below produce no output.

    Args:
        lang: assembly-language state.  This function reads ``lang.ins`` and
            ``lang.cnts`` and calls ``lang.newreg`` and ``lang.type_to_letter``.
        function_name: name given to the ``.visible .entry`` kernel.

    Returns:
        The complete PTX program as a single newline-joined string.
    """
    param_cnt = 0
    ins = []
    alu = {
        BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul",
        BinaryOps.DIV: "div", BinaryOps.MAX: "max", BinaryOps.MOD: "rem",
        BinaryOps.CMPLT: "setp.lt", UnaryOps.SQRT: "sqrt.approx",
        UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg",
        UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx",
        UnaryOps.EXP2: "ex2.approx.ftz",
        TernaryOps.MULACC: "fma.rn", TernaryOps.WHERE: "selp",
    }

    for uop, out, vin, arg in lang.ins:
        if uop == Ops.ENDLOOP:
            ins.append("bar.sync 0;")

        elif uop == Ops.DEFINE_LOCAL:
            # Byte array named arg[0], sized arg[1] * 4 bytes.
            ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1] * 4}];")

        elif uop == Ops.SPECIAL:
            if arg.startswith('data'):
                # Every buffer argument becomes one .param of the entry point.
                param_cnt += 1
                ins.append(f"ld.param.u64 {out}, [{arg}];")
                # TODO: we sometimes want this to be local, nvcc converts to global most of the time, not sure when we would need to?
                # ins.append(f"cvta.to.global.u64 {out}, {out};")
            elif arg.startswith('gid'):
                # The digit after the 3-letter prefix selects the x/y/z axis.
                ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
            elif arg.startswith('lid'):
                ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")

        elif uop == Ops.ALU:
            if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
                # Multiplying predicates is a logical AND.
                ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};")
            else:
                # Comparisons are typed by their operands, not by the bool result.
                otype = vin[0].dtype if arg in [BinaryOps.CMPLT] else out.dtype
                if arg == TernaryOps.WHERE:
                    if vin[0].dtype == dtypes.bool:
                        reg = vin[0]
                    else:
                        # selp needs a predicate: derive one with "!= 0".
                        reg = lang.newreg((vin[0], 'bool'), dtypes.bool)
                        zero = '0f00000000' if dtypes.is_float(vin[0].dtype) else '0'
                        ins.append(f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {zero}, {vin[0]};")
                    # selp takes the predicate as its last operand.
                    vin = vin[1:] + [reg]
                opcode = alu[arg]
                # ".lo" is an integer-multiply modifier; it must not be
                # attached to a floating-point mul of any width.
                lo = '.lo' if arg == BinaryOps.MUL and not dtypes.is_float(out.dtype) else ''
                rn = '.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''
                operands = ', '.join(str(x) for x in vin)
                ins.append(f"{opcode}{lo}{rn}.{dtype_to_nvtype[otype]} {out}, {operands};")

        elif uop == Ops.LOAD:
            if arg.__class__ in (int, float):
                # Immediate: floats are written as an exact 0f<8 hex digits> constant.
                value = '0f' + float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)
                ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {value};")
            elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype):
                # Memory dtype differs from the register dtype (or is bool):
                # load into a temporary of the memory type, then cast.
                if arg[2] == dtypes.bool == out.dtype:
                    dt = ('u16', dtypes.uint16)
                elif arg[2] == dtypes.bool:
                    dt = ('u8', dtypes.uint8)
                elif arg[2] == dtypes.half:
                    dt = ('b16', dtypes.float16)
                else:
                    dt = (dtype_to_nvtype[arg[2]], arg[2])
                reg = lang.newreg((out, dt[0]), dtype=dt[1])
                ins.append(f"ld.{arg[1]}.{dt[0]} {reg}, {_ptx_address(vin[0], arg[0])};")
                render_cast(ins, reg, out)
            else:
                mem_type = dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]
                ins.append(f"ld.{arg[1]}.{mem_type} {out}, {_ptx_address(vin[0], arg[0])};")

        elif uop == Ops.STORE:
            # A missing memory dtype (None) is treated as float.
            mem_dtype = dtypes.float if arg[2] is None else arg[2]
            if ptx_needs_cast(mem_dtype, vin[1].dtype) or arg[2] == dtypes.bool:
                if arg[2] == dtypes.bool != vin[1].dtype:
                    # Non-bool value into a bool buffer: reduce to a predicate first.
                    prereg = lang.newreg((vin[1], 'bool'), dtype=dtypes.bool)
                    render_cast(ins, vin[1], prereg)
                else:
                    prereg = vin[1]
                # Bool values go through a u16 register and are stored as u8.
                reg = lang.newreg(
                    (prereg, dtypes.uint16 if arg[2] == dtypes.bool else arg[2]),
                    dtype=dtypes.uint16 if arg[2] == dtypes.bool else mem_dtype,
                )
                render_cast(ins, prereg, reg)
                if arg[2] == dtypes.float16:
                    st_type = "b16"  # half values are stored as raw 16-bit data
                elif arg[2] == dtypes.bool:
                    st_type = dtype_to_nvtype[dtypes.uint8]
                else:
                    st_type = dtype_to_nvtype[mem_dtype]
                ins.append(f"st.{arg[1]}.{st_type} {_ptx_address(vin[0], arg[0])}, {reg};")
            else:
                ins.append(f"st.{arg[1]}.{dtype_to_nvtype[mem_dtype]} {_ptx_address(vin[0], arg[0])}, {vin[1]};")

        elif uop == Ops.CAST:
            render_cast(ins, vin[0], out)

        elif uop == Ops.LABEL:
            ins.append(f"{arg}:")

        elif uop == Ops.COND_BRANCH:
            # arg = (target label, branch-if-true flag); "!" negates the predicate.
            ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")

    params = ', '.join(f'.param .u64 data{i}' for i in range(param_cnt))
    ins_prefix = [
        ".version 7.8",
        ".target " + arch(),
        ".address_size 64",
        f".visible .entry {function_name}({params}) {{",
    ]
    # One register-bank declaration per key counted in lang.cnts.
    for reg_key, count in lang.cnts.items():
        ins_prefix.append(f".reg .{dtype_to_nvtype[reg_key[0]]} %{lang.type_to_letter(reg_key)}<{count}>;")

    return '\n'.join(ins_prefix + ins + ["ret;", "}"])
def uops_to_ptx_asm(function_name: str, uops: List[UOp]):
    """Lower ``uops`` to PTX source for a kernel called ``function_name``.

    Returns a 4-tuple: the PTX text, the global size reversed, the local size
    reversed, and the constant ``True`` (its meaning is defined by the caller).
    """
    ptx_lang = PTXLanguage()
    sizes = uops_to_asmstyle(ptx_lang, function_name, uops)
    global_size, local_size = sizes
    ptx_source = specialize_to_ptx(ptx_lang, function_name)
    return ptx_source, global_size[::-1], local_size[::-1], True