import struct
from platform import system
from typing import Tuple , Dict , List , Optional
from tinygrad import dtypes
from tinygrad . uop . ops import BinaryOps , UnaryOps , TernaryOps
from tinygrad . codegen . kernel import Ops , UOp
from tinygrad . helpers import CI
from tinygrad . codegen . assembly import uops_to_asmstyle , AssemblyLanguage
def float_to_hex(x):
  """Return the IEEE-754 single-precision bit pattern of `x` as 8 uppercase hex digits.

  The most significant byte comes first (1.0 -> "3F800000"), so `[:4]` is the high half-word
  and `[4:]` the low half-word, which is how the movz/movk sequences in this file consume it.
  """
  # '>f' packs big-endian explicitly, so the digit order no longer depends on the host's byte order
  return struct.pack(">f", x).hex().upper()
def compute_offsets(total):
  """Split `total` bytes into stack-adjustment amounts of at most 4096 each.

  All entries are 4096 except an optional trailing remainder; `total == 0` yields `[]`.
  Used to emit the `sub sp`/`add sp` prologue and epilogue in chunks (presumably so each
  immediate stays encodable — confirm against the A64 add/sub immediate rules).
  """
  full_chunks, leftover = divmod(total, 4096)
  chunks = [4096] * full_chunks
  if leftover:
    chunks.append(leftover)
  return chunks
#NOTE: Darwin needs symbol names to start with a "_"
def get_name(name):
  """Return `name` decorated as a global symbol for the host platform (leading '_' on Darwin)."""
  prefix = '_' if system() == 'Darwin' else ''
  return f"{prefix}{name}"
class ARM64Language(AssemblyLanguage):
  """ARM64 flavour of AssemblyLanguage; it adds no behaviour, lowering is done by specialize_to_arm64."""
def specialize_to_arm64(fn_nm, asm):
  """Lower the assembly-style instruction list `asm` into ARM64 assembly text for function `fn_nm`.

  Each entry of `asm` is unpacked as (uop, out, vin, arg). Operands in `out`/`vin` are ints
  (immediates), None, or variables with a `.nm` name and (presumably) their dtype at index 1.
  Registers are assigned from live ranges; variables that don't fit are spilled to the stack.

  Returns the full assembly source as one string: a `//varsize N` header comment, the
  prologue, the lowered body, and the epilogue.

  Stack frame (the prologue saves the entry sp in x17 and lowers sp by `var_size` bytes):
    [sp]           scratch used by mov_imm to move float bit patterns into s registers;
                   the register-save code of an external call also reuses this slot
    [sp, 16*k]     spill slot of the k-th spilled variable, k >= 1
  Every slot lies below the entry sp, so the caller's frame is never written.
  """
  var_size = 16  # frame bytes: 16 for the scratch slot, plus 16 per spilled variable
  prev_uop: Optional[Ops] = None
  ins: List[str] = []
  # Allocatable registers; pop() takes from the end, so x0 / s3 are handed out first
  x_regs = ['x' + str(i) for i in reversed(range(12))]
  s_regs = ['s' + str(i) for i in reversed(range(3, 32)) if i <= 7 or i >= 16]
  type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8: 'w',
                 dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8: 'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
  byte_dtypes = (dtypes.int8, dtypes.uint8, dtypes.bool)  # 1-byte in memory: need ldrb/ldrsb/strb
  unsigned_dtypes = (dtypes.bool, dtypes.uint8, dtypes.uint32, dtypes.uint64)  # need ucvtf/fcvtzu
  alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
         BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
         UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg",
         UnaryOps.SIN: 'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"),
         UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
         TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"}

  def mov_imm(value, reg):
    """Append instructions that materialize the constant `value` into register `reg`."""
    if reg[0] == 's':
      # s registers take no general immediate: build the IEEE bit pattern in x15 and bounce it
      # through the scratch slot. Checked first so ints headed for an s register also work.
      bits = float_to_hex(value)
      ins.append(f"movz x15, 0x{bits[4:]}")
      ins.append(f"movk x15, 0x{bits[:4]}, lsl #16")
      ins.append("str x15, [sp]")
      ins.append(f"ldr {reg}, [sp]")
    elif value.__class__ is not float and abs(value) > 65535:
      # Too wide for a single mov: build the low 32 bits in w15, then sign-extend into reg.
      # NOTE(review): values outside the signed 32-bit range are truncated by this sequence.
      ins.append(f"movz w15, #{value & 0xffff}")
      ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
      ins.append(f"sxtw {reg}, w15")
    else:
      ins.append(f"mov {reg}, #{value}")

  # Live interval [first use, last use] of every named variable, by instruction index
  live_range: Dict[str, List[int]] = {}
  for idx, (_, out, vin, _) in enumerate(asm):
    for var in (v for v in [out] + vin if v is not None and v.__class__ is not int):
      live_range.setdefault(var.nm, [idx, idx])[1] = idx

  mem_vars: Dict[str, int] = {}  # spilled variable name -> sp-relative slot offset
  rtor: Dict[str, str] = {}      # variable name -> register currently holding it

  def allocate_regs(mvars):
    """Assign a register to every variable in `mvars` that has none yet, spilling when out of registers."""
    nonlocal var_size
    for v in mvars:
      # Re-checked per operand so a variable listed twice is only allocated once
      if v is None or v.__class__ is int or v.nm in rtor: continue
      is_float = dtypes.is_float(v[1])
      available_regs = s_regs if is_float else x_regs
      if available_regs:
        rtor[v.nm] = available_regs.pop()
        continue
      #NOTE: Very simple spill, everything that doesn't fit in regs goes to mem.
      # Take the slot *before* growing the frame so it stays inside [sp, sp + var_size);
      # ARM needs the stack 16-byte aligned, hence 16-byte slots.
      mem_vars[v.nm] = var_size
      var_size += 16
      rtor[v.nm] = 's0' if is_float else 'x12'

  def operands(out, vin):
    """Comma-join the registers of `out` then `vin`; int immediates are expected to be staged in x15."""
    return ','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)

  # Scratch registers that receive spilled operands reloaded from the stack
  temp_floats = ['s0', 's1', 's2']
  temp_ints = ['x12', 'x13', 'x16']

  for idx, (uop, out, vin, arg) in enumerate(asm):
    # Release registers of variables whose live range ended before this instruction.
    # NOTE(review): names whose second character is 'B' are never released — presumably buffers; confirm.
    for var, reg in list(rtor.items()):
      if var[1] not in 'B' and var not in mem_vars and idx > live_range[var][1]:
        (s_regs if reg[0] == 's' else x_regs).append(rtor.pop(var))
    # Assign registers to the variables using live ranges
    allocate_regs([out] + vin)
    # Reload spilled inputs into temp registers before they are used directly
    for tmp_idx, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
      rtor[v.nm] = temp_floats[tmp_idx] if dtypes.is_float(v[1]) else temp_ints[tmp_idx]
      # ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
      ins.append(f"mov x15, {mem_vars[v.nm]}")
      ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")

    if uop == Ops.SPECIAL:
      if arg.startswith('data'):
        # data8 and up are read from the incoming stack through the entry sp saved in x17
        if int(arg[4:]) >= 8:
          ins.append(f"ldr x15, [x17, #{(int(arg[4:]) - 8) * 8}]")
          ins.append(f"mov {rtor[out.nm]}, x15")
      else:
        ins.append(f"mov {rtor[out.nm]}, #0")
        ins.append(f"loop_{arg}:")
    elif uop == Ops.CAST:
      if arg == BinaryOps.CMPLT:
        # Materialize the flags of the preceding compare as 1 (lt) or 0
        if rtor[out.nm][0] == 's':
          mov_imm(0.0, 's0')
          mov_imm(1.0, 's1')
          ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
        if rtor[out.nm][0] == 'x':
          mov_imm(0, 'x14')
          mov_imm(1, 'x15')
          ins.append(f"csel {rtor[out.nm]}, x15, x14, lt")
      else:
        ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
    elif uop == Ops.ALU:
      if len(vin) == 2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
      if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
        ins.append(f"ands {operands(out, vin)}")
      elif arg == TernaryOps.WHERE:
        cond = rtor[vin[0].nm]
        ins.append(f"fcmp {cond}, #0.0" if cond[0] == 's' else f"cmp {cond}, #0")
        # fcsel only selects between FP registers; integer results need csel
        select = alu[arg] if rtor[out.nm][0] == 's' else "csel"
        ins.append(f"{select} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne")
      elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
        #NOTE: Not a real instruction, used to emulate an ext call in unicorn
        if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
        else:
          # NOTE(review): x17 (entry sp) is not saved here although the callee may clobber it.
          save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
          ins.append(f"sub sp, sp, #{len(save_regs) * 16}")
          # Save the registers before they are clobbered by the function call
          for slot, k in enumerate(save_regs, 1):
            ins.append(f"str {rtor[k]}, [sp, #{16 * slot}]")
          ins.append("stp x29, x30, [sp, #0]!")
          ins.append("mov x29, sp")
          ins.append(f"fmov s0, {rtor[vin[0].nm]}")
          ins.append(alu[arg])
          ins.append(f"fmov {rtor[out.nm]}, s0")
          ins.append("mov sp, x29")
          ins.append("ldp x29, x30, [sp], #0")
          for slot, k in enumerate(save_regs, 1):
            ins.append(f"ldr {rtor[k]}, [sp, #{16 * slot}]")
          ins.append(f"add sp, sp, #{len(save_regs) * 16}")
      elif arg == BinaryOps.CMPLT:
        if dtypes.is_float(vin[0][1]): ins.append(f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
        else: ins.append(f"{alu[arg]} {operands(out, vin)}")
      elif arg == BinaryOps.MOD:
        rhs = 'x15' if vin[1].__class__ is int else rtor[vin[1].nm]
        # signed division, consistent with DIV (sdiv); udiv gave wrong results for negative operands
        ins.append(f"sdiv x14, {rtor[vin[0].nm]}, {rhs}")
        ins.append(f"msub {rtor[out.nm]}, x14, {rhs}, {rtor[vin[0].nm]}")
      elif arg == BinaryOps.MAX and not dtypes.is_float(vin[0][1]):
        # ARMv8-A has no scalar integer max instruction: compare, then select the larger operand
        lhs = rtor[vin[0].nm]
        rhs = 'x15' if vin[1].__class__ is int else rtor[vin[1].nm]
        ins.append(f"cmp {lhs}, {rhs}")
        ins.append(f"csel {rtor[out.nm]}, {lhs}, {rhs}, gt")
      else:
        prefix = 'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''
        ins.append(f"{prefix}{alu[arg]} {operands(out, vin)}")
    elif uop == Ops.LOAD:
      if arg.__class__ in (int, float):
        mov_imm(arg, rtor[out.nm])
      else:
        src_dtype = arg[2]
        #NOTE: if need casting load var in s/h0 or x/w12 temp regs
        reg_in = (type_to_reg[src_dtype] + ('0' if dtypes.is_float(src_dtype) else '12')) if src_dtype is not None else rtor[out.nm]
        mov_imm(arg[0], "x15")
        ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
        # 1-byte types: sign-extend int8 (ldrsb), zero-extend uint8/bool (ldrb)
        suffix = ''
        if src_dtype is not None and src_dtype in byte_dtypes: suffix = 'sb' if src_dtype == dtypes.int8 else 'b'
        ins.append(f"ldr{suffix} {reg_in}, [x15]")
        if src_dtype is not None:
          if src_dtype in (dtypes.half, dtypes.double): cvt = 'fcvt'
          elif src_dtype in unsigned_dtypes: cvt = 'ucvtf'
          else: cvt = 'scvtf'
          ins.append(f"{cvt} {rtor[out.nm]}, {reg_in}")
    elif uop == Ops.STORE:
      dst_dtype = arg[2]
      #NOTE: if need casting store var from s/h0 or x/w12 temp regs
      reg_out = (type_to_reg[dst_dtype] + ('0' if dtypes.is_float(dst_dtype) else '12')) if dst_dtype is not None else rtor[vin[1].nm]
      if dst_dtype is not None:
        if dst_dtype in (dtypes.half, dtypes.double): cvt = 'fcvt'
        elif dst_dtype in unsigned_dtypes: cvt = 'fcvtzu'
        else: cvt = 'fcvtzs'
        ins.append(f"{cvt} {reg_out}, {rtor[vin[1].nm]}")
      # mov_imm handles offsets too wide for a single mov (same as LOAD)
      mov_imm(arg[0], "x15")
      # 1-byte types must be stored with strb, otherwise 4 bytes are written
      store = 'strb' if dst_dtype is not None and dst_dtype in byte_dtypes else 'str'
      ins.append(f"{store} {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
    elif uop == Ops.COND_BRANCH:
      #TODO: this is a hack it shouldn't always be a cmp before a cond branch?
      if prev_uop == Ops.LOAD:
        ins.append(f"cmp {rtor[vin[0].nm]}, #0")
      ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
    elif uop == Ops.LABEL:
      ins.append(f"{arg[1:]}:")
    elif uop == Ops.ENDLOOP:
      mov_imm(arg[0], "x15")
      ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
      ins.append(f"cmp {rtor[vin[0].nm]}, x15")
      ins.append(f"b.lt loop_{arg[1]}")
    prev_uop = uop
    # Write a spilled result back to its stack slot
    if out is not None and out.nm in mem_vars:
      ins.append(f"mov x15, {mem_vars[out.nm]}")
      ins.append(f"str {rtor[out.nm]}, [sp, x15]")

  frame_chunks = compute_offsets(var_size)
  header = [f"//varsize {var_size}", ".arch armv8-a", ".text", f".global {get_name(fn_nm)}", ".p2align 2",
            f"{get_name(fn_nm)}:", "mov x17, sp"]
  prologue = [f"sub sp, sp, #{offset}" for offset in frame_chunks]
  epilogue = [f"add sp, sp, #{offset}" for offset in frame_chunks]
  return "\n".join(header + prologue + ins + epilogue + ["ret", "\n"])
def uops_to_arm64_asm(fn_nm: str, uops: List[UOp]) -> Tuple[str, List[int], List[int], bool]:
  """Compile `uops` into ARM64 assembly for function `fn_nm`.

  Returns (assembly source, global size reversed, local size reversed, True).
  """
  language = ARM64Language()
  gsz, lsz = uops_to_asmstyle(language, fn_nm, uops)
  source = specialize_to_arm64(fn_nm, language.ins)
  return source, gsz[::-1], lsz[::-1], True