from typing import Optional, Union, Literal, Callable, cast
import os, math, sys
from collections import defaultdict, Counter
from tinygrad.ops import GroupOp, Ops, UOp, PatternMatcher, UPat
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
from tinygrad.renderer import Renderer, TensorCore
from tinygrad.codegen.devectorizer import no_vectorized_alu
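
# NOTE: in the rewrite lambdas below, ctx is the CStyleLanguage instance doing the render;
# ctx[uop] (via the __getitem__ helper further down) looks up the already-rendered string for a UOp.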
base_rewrite = PatternMatcher([
  (UPat(Ops.DEFINE_ACC, name="x"), lambda ctx,x: ctx[x.src[0]]),
  (UPat(Ops.ASSIGN, name="x"), lambda ctx,x: f"{ctx[x.src[0]]} = {ctx[x.src[1]]};"),
  (UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
  (UPat((Ops.ENDIF, Ops.ENDRANGE)), lambda ctx: "}"),
  (UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"),
  # r method accesses
  (UPat(Ops.RANGE, name="x"),
   lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = {ctx[x.src[0]]}; {ctx[x]} < {ctx[x.src[1]]}; {ctx[x]}++) {{"),
  (UPat(Ops.VECTORIZE, name="x"),
   lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
    (f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device in {'CPU', 'DSP'} else f"({','.join([ctx[y] for y in x.src])})")),
  (UPat(Ops.CAST, name="x"), lambda ctx,x:
    f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})" if x.dtype.count > 1 and not isinstance(x.dtype, PtrDType) else None),
  (UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
  (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
  (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
  (UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
  (UPat(Ops.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]),
  (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"),
  # const
  (UPat(Ops.CONST, arg=math.inf, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx.infinity)})"),
  (UPat(Ops.CONST, arg=-math.inf, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'-{ctx.infinity}')})"),
  (UPat(Ops.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx.nan)})" if math.isnan(x.arg) else None),
  (UPat(Ops.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"),
  (UPat(Ops.CONST, dtype=dtypes.int64, name="x"), lambda ctx,x: f"{x.arg}ll"),
  (UPat(Ops.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"),
  (UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"),
  (UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
  # consts are rendered to larger type and casted
  (UPat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}f')})"),
  (UPat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}u')})"),
  (UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, x.arg)})"),
  # default const render
  (UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
  # new load/store
  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True),
   lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.op is Ops.ADD else ctx[idx]})"),
  (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("gate"))).or_casted('bidx'), UPat.var("var")), allow_any_len=True),
   lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
  (UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"),
  (UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
  # alu/gep
  (UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
    *([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR} else ctx[v] for v in x.src]), x.dtype)),
  (UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
    (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device in {'CPU', 'DSP'} else \
    f".{'xyzwabcd'[x.arg[0]]}")),
  # custom passes through with format
  (UPat((Ops.CUSTOM, Ops.CUSTOMI), name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])),
])
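
# Illustrative only: with a CStyleLanguage instance as ctx, the float CONST rule above turns
# UOp.const(dtypes.float, 1.0) into the C literal "1.0f" via
#   base_rewrite.rewrite(UOp.const(dtypes.float, 1.0), ctx=renderer)
# (this mirrors the string_rewrite.rewrite(u, ctx=self) call in CStyleLanguage._render below).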
extra_pm = PatternMatcher([
  # insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
  (UPat(Ops.BITCAST, name="x"),
   lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
  # rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
  (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
  # devectorize any bools
  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), dtype=dtypes.bool, name="alu"), no_vectorized_alu),
  # CAST (from bool) can't be vectorized
  (UPat(Ops.CAST, src=(UPat(dtype=dtypes.bool),), name="alu"), no_vectorized_alu),
  # WHERE can't be vectorized
  (UPat(Ops.WHERE, name="alu"), no_vectorized_alu),
])
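
# extra_pm rewrites the UOp graph before stringification: e.g. the MAX rule above means a
# max(a, b) never reaches base_rewrite; it arrives as CMPLT + WHERE and renders through
# code_for_op as the plain C ternary "((a<b)?b:a)", so backends don't need a max() builtin.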
def uops_to_dtypes(uops:list[UOp]) -> list[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
class CStyleLanguage(Renderer):
  kernel_prefix: str = ""
  buffer_prefix: str = ""
  buffer_suffix: str = ""
  smem_align: str = ""
  smem_prefix: str = ""
  smem_prefix_for_cast: bool = True
  arg_int_prefix: str = "const int"
  barrier: str = ""
  code_for_workitem: dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
  extra_args: list[str] = []
  float4: Optional[str] = None
  type_map: dict[DType, str] = {}
  infinity: str = "INFINITY"
  nan: str = "NAN"
  code_for_op: dict = {
    Ops.SQRT: lambda x,dtype: f"sqrt({x})", Ops.RECIP: lambda x,dtype: f"(1/{x})", Ops.NEG: lambda x,dtype: f"-{x}",
    Ops.EXP2: lambda x,dtype: f"exp2({x})", Ops.LOG2: lambda x,dtype: f"log2({x})", Ops.SIN: lambda x,dtype: f"sin({x})",
    Ops.AND: lambda a,b,dtype: f"({a}&{b})", Ops.XOR: lambda a,b,dtype: f"({a}^{b})", Ops.OR: lambda a,b,dtype: f"({a}|{b})",
    Ops.ADD: lambda a,b,dtype: f"({a}+{b})", Ops.SUB: lambda a,b,dtype: f"({a}-{b})", Ops.MUL: lambda a,b,dtype: f"({a}*{b})",
    Ops.MOD: lambda a,b,dtype: f"({a}%{b})", Ops.IDIV: lambda a,b,dtype: f"({a}/{b})", Ops.CMPNE: lambda a,b,dtype: f"({a}!={b})",
    Ops.SHR: lambda a,b,dtype: f"({a}>>{b})", Ops.SHL: lambda a,b,dtype: f"({a}<<{b})", Ops.CMPLT: lambda a,b,dtype: f"({a}<{b})",
    Ops.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
  string_rewrite = base_rewrite
  extra_matcher = extra_pm

  def get_kernel_modifier(self, uops:list[UOp]) -> str: return ""
  def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
    tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else ""  # noqa: E501
    buftypes = [(name, self.render_dtype(dtype, mutable)+self.buffer_suffix if isinstance(dtype, (ImageDType, PtrDType)) else
                self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
    prg = ''.join([f"{self.kernel_prefix}void {self.get_kernel_modifier(uops)}{function_name}(",] +
    [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
    [") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
    return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"

  def render_cast(self, dt:DType, val:str) -> str: return f"({self.render_dtype(dt)})({val})"
  def render_dtype(self, dt:DType, mutable=True) -> str:
    if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t"
    if isinstance(dt, PtrDType):
      return (self.smem_prefix if dt.local and self.smem_prefix_for_cast else self.buffer_prefix) + self.render_dtype(dt.base) + "*"
    if dt.count > 1: return self.type_map.get(scalar:=dt.scalar(), scalar.name).replace(" ", "_") + str(dt.count)
    return self.type_map.get(scalar:=dt.scalar(), scalar.name)

  def __getitem__(self, key): return self.r[key]  # hacky helper
  def _render(self, uops:list[UOp]) -> tuple[str, list[str], list[tuple[str,tuple[DType,bool]]]]:
    r: dict[UOp, str] = {}
    self.r = r

    child_count = Counter(v for ru in uops for v in ru.src)
    bufs: dict[UOp, tuple[str, tuple[DType, bool]]] = {}
    kernel = []
    depth = 1
    c: defaultdict[str, int] = defaultdict(int)
    name = "test"
    for u in uops:
      if u.op is Ops.NAME:
        name = u.arg
        continue
      if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
        r[u] = f"data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else u.arg[0]
        bufs[u] = (r[u], (u.dtype, False))
        continue

      # mark buffers that we store to writable
      if u.op is Ops.STORE:
        for up in u.src[0].toposort:
          if up.op is Ops.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True))

      # naming
      prefix = None
      if u.op is Ops.SPECIAL:
        r[u] = u.arg[0]
      else:
        prefix = {Ops.RANGE: "ridx", Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
                  Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.NOOP: "precast",
                  Ops.INDEX: "bidx", Ops.DEFINE_ACC: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
        r[u] = f"{prefix}{c[prefix]}"

      l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
      assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op, x.dtype) for x in u.src]} {u.arg}"

      if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
      if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or
          (u.op in {Ops.VECTORIZE, *GroupOp.ALU, Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
        r[u] = l
      else:
        if u.op in {Ops.RANGE, Ops.ASSIGN, Ops.DEFINE_LOCAL} or u.dtype == dtypes.void:
          if u.op is Ops.ASSIGN: r[u] = r[u.src[0]]
        else:
          l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
        kernel.append("  "*depth + l)
        if prefix: c[prefix] += 1  # if it was used, increment
      if u.op in {Ops.IF, Ops.RANGE}: depth += 1
    del self.r

    # NOTE: this relies on bufs dict preserving order
    return (name, kernel, list(bufs.values()))
  def render(self, uops:list[UOp]) -> str: return self.render_kernel(*self._render(uops), uops)
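
# A minimal usage sketch (hypothetical driver, assuming `uops` is a linearized UOp list from
# tinygrad's codegen, including a NAME uop for the kernel name):
#   src = OpenCLRenderer().render(uops)
# _render walks the uops once, emitting one C statement per rendered op and collecting the
# buffer arguments; render_kernel then wraps the body in the backend's function signature.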
class ClangRenderer(CStyleLanguage):
  device = "CPU"
  float4 = "(float4)"
  has_local = False
  global_max = None
  infinity = "__builtin_inff()"
  nan = '__builtin_nanf("")'
  amx_tc = [TensorCore(dims=(sz,sz,1), threads=1, elements_per_thread=(sz,sz,sz*sz), dtype_in=dt, dtype_out=dt, swizzle=(None, ((),(4,5,6,7,0,1,2,3))),
            opts=("u0","u0","u0","u0","u1","u1","u1","u1")) for dt,sz in [(dt, 64 // dt.itemsize) for dt in [dtypes.float]]]
  if AMX: tensor_cores = amx_tc

  # language options
  buffer_suffix = " restrict"
  type_map = {dtypes.bool: "_Bool", dtypes.half: "__fp16"}
  code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2]}),
                 Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})"}
  # LLVM legalizes double => half cast on systems that don't support it natively (like x86 cpus without AVX512-FP16) into a compiler-rt libcall.
  extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16)),
                                  (UPat(Ops.SQRT, name="alu"), no_vectorized_alu)]) + CStyleLanguage.extra_matcher

  if sys.platform == 'win32':
    kernel_prefix = "__attribute__((ms_abi)) "
  def render_vector_prefix(self, dt:DType) -> str:
    # round (down) to power of two (this is actually the default clang behavior)
    alignment = 2**int(math.log2(dt.itemsize)) if getenv("ALIGNED", 1) else 1
    return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({alignment}),vector_size({dt.itemsize})));"

  def _render_defines(self, uops) -> list[str]:
    prefix = [self.render_vector_prefix(dt) for dt in uops_to_dtypes(uops) if dt.count > 1]
    # https://github.com/corsix/amx
    for name, (N, M, _), dtype_in, _, _, _, _, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
      prefix += [
        '#define AMX_SET(imm5) __asm("nop\\nnop\\nnop\\n.word (0x201000+(%0<<5)+%1)" : : "i"(17), "i"(imm5) : "memory")',
        '#define AMX(op, gpr, btf) __asm(".word (0x201000+(%0 << 5)+0%1-((0%1>>4)*6))" : : "i"(op), "r"((unsigned long long)(gpr)+(btf)) : "memory")',
      ]
      # 'static' in C roughly means that function symbol isn't exported. LLVM puts those symbols at the end of object file which allows Clang JIT
      # to just jump at the start of a shellcode without having to deal with symbols or trampolines at all. This is better than having to inline
      # wmma function every time it is called or wasting complexity on a symbol parsing and a memory page on trampoline.
      prefix += [f"""static {(out := self.render_dtype(dtype_in.vec(N*N)))} __{name}({self.render_dtype(dtype_in.vec(N))} data1, {self.render_dtype(dtype_in.vec(M))} data2, {out} data0){{
  AMX_SET(0);\n  for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(4, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}
  AMX(0, (int *)(&data2), 0ull<<62); AMX(1, (int *)(&data1), 0ull<<62); AMX(12, 0, 0ull);
  for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(5, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}\n  AMX_SET(1);\n  return data0;\n}}"""]  # noqa: E501
    return prefix
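
  # Reading the generated __{name} body above, per the corsix/amx repo linked there (an
  # assumption on the op numbering, not something this file states): AMX_SET(0)/AMX_SET(1)
  # enable/disable the AMX unit, ops 4/5 load/store rows of the z accumulator, ops 0/1 load
  # the x/y operand registers, and op 12 is the fp32 outer-product accumulate, so the
  # function returns data0 accumulated with the outer product of data1 and data2.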
  def _render_body(self, function_name, kernel, bufs, uops, pref=None) -> str: return super().render_kernel(function_name, kernel, bufs, uops, pref)
  def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]]) -> str: return ""

  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
    defines = '\n'.join(self._render_defines(uops))
    return defines + "\n" + self._render_body(function_name, kernel, bufs, uops, prefix) + "\n" + self._render_entry(function_name, bufs)
class OpenCLRenderer(CStyleLanguage):
  device = "GPU"

  # language options
  kernel_prefix = "__kernel "
  buffer_prefix = "__global "
  smem_align = "__attribute__ ((aligned (16))) "
  smem_prefix = "__local "
  barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
  float4 = "(float4)"
  code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
  type_map = {dtypes.int8: "char", dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong",
              dtypes.bfloat16: "ushort"}

  string_rewrite = PatternMatcher([
    (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
    # load/store image (OpenCL)
    (UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2)), UPat.var("gate")), UPat.var("var"))),
     lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"),
    (UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)),
     lambda ctx,buf,idx: f"read_imagef({ctx[buf]}, smp, {ctx[idx]})"),
    (UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
     lambda ctx,buf,idx,var: f"write_imagef({ctx[buf]}, {ctx[idx]}, {ctx[var]});"),
  ]) + base_rewrite

  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
    if any(uop.dtype.base == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or []))
    return super().render_kernel(function_name, kernel, bufs, uops, prefix)
class IntelRenderer(OpenCLRenderer):
  device, suffix, kernel_prefix = "GPU", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel "
  tensor_cores = [TensorCore(dims=(8,8,16), threads=8, elements_per_thread=(16,16,8), dtype_in=dtypes.half, dtype_out=dtypes.float,
    opts=("l0","l0","l0","u1","u1","u1"), swizzle=(((4,5,6),(0,1,2,3,7,8,9)), ((0,1,2),(7,8,9,3,4,5,6))))]

  string_rewrite = PatternMatcher([
    (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x]})"),
    (UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x]})"),
  ]) + OpenCLRenderer.string_rewrite

  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
    prefix = []
    for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
      dt_in = ("ushort", "bf16") if arg[2] == dtypes.bfloat16 else (arg[2].name, "f16")
      prefix.append(f"""{arg[3].name}8 __{arg[0]}({dt_in[0]}16 a, {dt_in[0]}16 b, {arg[3].name}8 c) {{
    return intel_sub_group_{dt_in[1]}_{dt_in[1]}_matrix_mad_k16(as_int8(a), as_int8(b), c);\n}}""")
    return super().render_kernel(function_name, kernel, bufs, uops, prefix or None)
class MetalRenderer(CStyleLanguage):
  device = "METAL"
  shared_max = 32768
  tensor_cores = [TensorCore(dims=(8,8,8), threads=32, elements_per_thread=(2,2,2), dtype_in=di, dtype_out=do, opts=("u0","l0","l1","l1","l0","l1"),
    swizzle=(((6,1,2,7,4),(8,0,3,5)), ((0,5,6,3,7),(1,2,4,8)))) for di,do in [(dtypes.float,dtypes.float), (dtypes.half,dtypes.float),
    (dtypes.half,dtypes.half), (dtypes.bfloat16,dtypes.float), (dtypes.bfloat16,dtypes.bfloat16)]]
  def __init__(self): self.tensor_cores = MetalRenderer.tensor_cores if hasattr(os, 'uname') and os.uname().machine == "arm64" else []

  # language options
  kernel_prefix = "kernel "
  buffer_prefix = "device "
  smem_prefix = "threadgroup "
  arg_int_prefix = "constant int&"
  barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
  float4 = "float4"
  code_for_workitem = {"g": lambda x: f"gid.{chr(120+int(x))}", "l": lambda x: f"lid.{chr(120+int(x))}"}
  # uint3 used for gid/lid - TODO: this should probably be `ushort3 lid [[thread_position_in_threadgroup]]`
  extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
  type_map = {dtypes.bfloat16: "bfloat"}

  # precise::sin
  code_for_op = {**CStyleLanguage.code_for_op, Ops.SIN: lambda x,dtype: f"precise::sin({x})"}

  # upcast to float32 all the ops that don't support bfloat16
  extra_matcher = PatternMatcher([
    # NOTE: this is copied from PTX
    (UPat((Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN), dtype=dtypes.bfloat16, name="x"),
     lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))),
  ]) + extra_pm

  string_rewrite = PatternMatcher([
    (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_type<{ctx.render_dtype(x.dtype)}>({ctx[x.src[0]]})"),
  ]) + base_rewrite

  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
    prefix, wmma_args = ["#include <metal_stdlib>", "using namespace metal;"], set([uop.arg for uop in uops if uop.op is Ops.WMMA])
    for arg in wmma_args: prefix.append(
  f"""{(dtype_out := self.render_dtype(arg[3].vec(2)))} __{arg[0]}({(dtype_in := self.render_dtype(arg[2].vec(2)))} a, {dtype_in} b, {dtype_out} c) {{
  simdgroup_{self.render_dtype(arg[2])}8x8 mat_a, mat_b; simdgroup_{self.render_dtype(arg[3])}8x8 mat_c;
  mat_a.thread_elements()[0] = a[0]; mat_b.thread_elements()[0] = b[0]; mat_c.thread_elements()[0] = c[0];
  mat_a.thread_elements()[1] = a[1]; mat_b.thread_elements()[1] = b[1]; mat_c.thread_elements()[1] = c[1];
  simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);\n  return {dtype_out}(mat_c.thread_elements()[0], mat_c.thread_elements()[1]);\n}}""")
    return super().render_kernel(function_name, kernel, bufs, uops, prefix)
_nms = "xyzwabcdefghijkl"
cuda_tc_opts = ("u0","l0","l0","l1","l1","l1","u1")  # shared by all shapes with M=16 N=8
class CUDARenderer(CStyleLanguage):
  device = "CUDA"
  global_max = (2147483647, 65535, 65535)
  local_max = (1024, 1024, 64)
  shared_max = 49152
  # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-multiply-accumulate-instructions
  tc_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
    swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float),
    (dtypes.half,dtypes.half)]]
  tc_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
    swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.half,dtypes.half)]]
  tc_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts,
    swizzle=(((5,6,2,3,4),(0,1,8,9,7)), ((5,6,8,0,1),(2,3,4,9,7))))]

  tc_sm80 = tc_81616 + tc_8168_f16
  if getenv("ALLOW_TF32", 0): tc_sm80 += tc_8168_tf32
  tc_sm75 = tc_8168_f16
  def __init__(self, arch:str):
    self.tensor_cores, self.arch = CUDARenderer.tc_sm80 if int(arch[3:]) >= 80 else CUDARenderer.tc_sm75 if int(arch[3:]) >= 75 else [], arch
  def __reduce__(self): return self.__class__, (self.arch,)

  # language options
  kernel_prefix = "extern \"C\" __global__ "
  smem_prefix = "__shared__ "
  smem_prefix_for_cast = False
  barrier = "__syncthreads();"
  float4 = "make_float4"
  code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+int(x))}", "l": lambda x: f"threadIdx.{chr(120+int(x))}",
                       "i": lambda x: f"(blockIdx.{chr(120+int(x))}*blockDim.{chr(120+int(x))}+threadIdx.{chr(120+int(x))})"}
  code_for_op = {**CStyleLanguage.code_for_op,
    Ops.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",
    Ops.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
    Ops.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",
    Ops.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
    Ops.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})"}
  type_map = {dtypes.bfloat16: "nv_bfloat16"}

  def render_vector_prefix(self, dt:DType) -> str:
    vec, scal = self.render_dtype(dt), self.render_dtype(dt.scalar())
    elems, header = ', '.join(_nms[:dt.count]), ', '.join([f"{scal} {x}" for x in _nms[:dt.count]])
    return f"struct __align__({dt.itemsize}) {vec} {{ {scal} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
    # TODO: why is dtypes.bfloat16.name == "__bf16"? it would be easier not to override dtypes.name
    prefix = ["#define INFINITY (__int_as_float(0x7f800000))", "#define NAN (__int_as_float(0x7fffffff))"]

    used_dtypes = uops_to_dtypes(uops)
    if any(dt.scalar() == dtypes.half for dt in used_dtypes): prefix.append("#include <cuda_fp16.h>")
    if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("#include <cuda_bf16.h>")
    prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count in (4, 8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}]

    dt_map_in = {dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16"}
    dt_map_out = {dtypes.float: "f32", dtypes.half: "f16"}
    for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
      upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
      wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
      n_operands = [size*dtype.itemsize//4 for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]  # 4 => CUDA reg size in bytes
      operands = [f"%{i}" for i in range(sum(n_operands))]

      # mma operands => {c}, {a}, {b}, {c}
      prefix.append(f"""__device__ {wmma_dtypes[2]} __{name}({wmma_dtypes[0]} a, {wmma_dtypes[1]} b, {wmma_dtypes[2]} c){{
  int *a_pk = (int *)(&a), *b_pk = (int *)(&b), *c_pk = (int *)(&c);
  asm("mma.sync.aligned.m{M}n{N}k{K}.row.col.{dt_map_out[dtype_out]}.{dt_map_in[dtype_in]}.{dt_map_in[dtype_in]}.{dt_map_out[dtype_out]}"
      "{{{", ".join(operands[:n_operands[2]])}}}, {{{", ".join(operands[n_operands[2]:n_operands[2]+n_operands[0]])}}},"
      "{{{", ".join(operands[-n_operands[1]:])}}}, {{{", ".join(operands[:n_operands[2]])}}};"
    : {", ".join([f'"+r"(c_pk[{i}])' for i in range(n_operands[2])])}
    : {", ".join([f'"r"(a_pk[{i}])' for i in range(n_operands[0])])}, {", ".join([f'"r"(b_pk[{i}])' for i in range(n_operands[1])])});
  return c;\n}}""")
    return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)

  def get_kernel_modifier(self, uops:list[UOp]) -> str:
    maxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
    return f"__launch_bounds__({maxThreadsPerBlock}) "
def cast_float_to_bf16(x: UOp) -> UOp:
  assert x.dtype == dtypes.float, "cast float -> bf16 must start with float"
  x = x.bitcast(dtypes.uint)
  x = (-x & 0x7f800000).where(x + ((x >> 16) & 1) + 0x7fff, (x & 0xffff).where((x | 0x10000), x))
  return (x >> 16).cast(dtypes.ushort).bitcast(dtypes.bfloat16)
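
# The where() above selects round-to-nearest-even for everything whose negation still has
# exponent bits set (all values except NaN and +/-0): add the low bit of the kept mantissa
# plus 0x7fff, then truncate. NaNs instead get the quiet bit (0x10000) forced so the kept
# payload stays a NaN. Worked example (illustrative): 1.5f is 0x3fc00000, bit 16 is 0, and
# 0x3fc00000 + 0x7fff = 0x3fc07fff, so >>16 yields bf16 0x3fc0, which is still exactly 1.5.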
class AMDRenderer(CStyleLanguage):
  device = "AMD"
  shared_max = 65536
  # NOTE: this is only really needed on gfx12, even though gfx11 reports the same limitation
  global_max = (2147483647, 65535, 65535)
  # https://gpuopen.com/learn/wmma_on_rdna3/
  tensor_cores = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(16,16,8), dtype_in=di, dtype_out=do,
    opts=("l0","l0","l0","l0","l1","u1","u1","u1"), swizzle=(((4,9,10,11,0),(1,2,3,5,6,7,8)), ((0,1,2,3,4),(9,10,11,5,6,7,8))))
    for di,do in [(dtypes.half,dtypes.float), (dtypes.half,dtypes.half)]]
  # https://gpuopen.com/learn/amd-lab-notes/amd-lab-notes-matrix-cores-readme
  tensor_cores_mfma = [TensorCore(dims=(16,16,16), threads=64, elements_per_thread=(4,4,4), dtype_in=di, dtype_out=do,
    opts=("l0","l0","l0","l0","u1","u1","l1","l1"), swizzle=(((10,11,4,5,8,9),(0,1,2,3,6,7)), ((0,1,2,3,8,9),(4,5,10,11,6,7))))
    for di,do in [(dtypes.half,dtypes.float)]]

  def __init__(self, arch:str):  # gfx942 => MI300, gfx1100 => RX 7900, gfx1201 => RX 9700
    self.arch = arch
    # TODO: fix tensor cores for gfx1201
    self.tensor_cores = \
      AMDRenderer.tensor_cores_mfma if arch.split(":")[0] == "gfx942" else AMDRenderer.tensor_cores if arch.split(":")[0] != "gfx1201" else []
    if self.arch.split(":")[0] == "gfx942":
      self.string_rewrite = PatternMatcher([
        (UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]}, 0, 0, 0)")]) + base_rewrite
  def __reduce__(self): return self.__class__, (self.arch,)
  # language options
  ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
  ocml = [(f"__ocml_{name}_f{n}", f"{dt}, {dt}" if "fmax" == name else dt, dt, atr)
          for dt, n in [(dtype.name, dtype.itemsize * 8) for dtype in [dtypes.float, dtypes.double, dtypes.half]]
          for name, atr in [("fmax", "const"), ("exp2", "pure"), ("log2", "pure"), ("sqrt", "const"), ("sin", "")]]

  kernel_prefix = "\n".join(f'extern "C" __attribute__((device{f", {atr}" if atr else ""})) {dto} {meth}({dti});' for meth,dti,dto,atr in ockl+ocml)
  kernel_prefix += '\nextern "C" __attribute__((global))'
  code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
                       "i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
  code_for_op = {**CStyleLanguage.code_for_op,
    Ops.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
    Ops.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
    Ops.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
    Ops.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})"}
  smem_prefix = "__attribute__((shared)) "
  smem_prefix_for_cast: bool = False
  barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \
            '__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");'
  float4 = "make_float4"
  type_map = {dtypes.bfloat16: "hip_bfloat16"}

  extra_matcher = PatternMatcher([
    # cast bfloat16 alus to float
    (UPat(Ops.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
     lambda b,x,y: UOp(Ops.WHERE, dtype=dtypes.float, src=(b, x.cast(dtypes.float), y.cast(dtypes.float))).cast(dtypes.bfloat16)),
    (UPat(GroupOp.ALU, dtype=dtypes.bfloat16, name="x"),
     lambda x: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)),
    (UPat(GroupOp.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
     lambda alu,x,y: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg)),
    # add float intermediate casting for bfloat16
    (UPat(Ops.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)), lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype != dtypes.float else None),
    (UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")), lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype != dtypes.float else None),
    # bfloat16 casting
    (UPat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))),
    (UPat(Ops.CAST, dtypes.float, UPat.var("x", dtypes.bfloat16)), lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
    (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm
  def render_vector_prefix(self, dtype:DType) -> str:
    vec, scal = self.render_dtype(dtype), self.render_dtype(dtype.scalar())
    return f"typedef {scal} {vec} __attribute__((ext_vector_type({dtype.count})));\nstatic inline __attribute__((device)) " + \
           f"{vec} make_{vec}({', '.join([f'{scal} {x}' for x in _nms[:dtype.count]])}) {{ return {{ {', '.join(_nms[:dtype.count])} }}; }}"

  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
    prefix = ["#define INFINITY (__builtin_inff())", "#define NAN (__builtin_nanf(\"\"))", "typedef long unsigned int size_t;", "#define half _Float16"]

    used_dtypes = uops_to_dtypes(uops)
    if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("typedef unsigned short hip_bfloat16;")
    prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count > 1]

    for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):  # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
      if self.arch.split(":")[0] == "gfx942": prefix.append(f"#define __{arg[0]} __builtin_amdgcn_mfma_f32_16x16x16f16")
      elif arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
      else: prefix.append(f"static inline __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
  half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
  c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a, b, c_frag, false);
  for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
    return super().render_kernel(function_name, kernel, bufs, uops, prefix)

  def get_kernel_modifier(self, uops:list[UOp]) -> str:
    requiredMaxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
    # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
    # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
    return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
class NVRenderer(CUDARenderer): device = "NV"
class HIPRenderer(AMDRenderer): device = "HIP"
class QCOMRenderer(OpenCLRenderer): device = "QCOM"
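
# End-to-end sketch (hypothetical): CUDARenderer("sm_80") selects the sm80 tensor cores
# (int(arch[3:]) >= 80), and CUDARenderer("sm_80").render(uops) returns the CUDA C source.
# NVRenderer, HIPRenderer and QCOMRenderer only retag the device; rendering is inherited.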