from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict
import math
from collections import defaultdict
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.helpers import ImageDType, dtypes, prod, DType, strip_parens

class CStyleLanguage(NamedTuple):
  """Settings and string-rendering helpers for emitting a kernel in a C-like language.

  Every field has a default; the render_* methods read these fields to decide
  which prefixes, suffixes and syntax variants go into the generated source.
  """
  size_prefix: str = "int"          # type used when declaring SPECIAL index variables
  generic_var_prefix: str = ""      # when non-empty, used instead of the dtype name in declarations
  kernel_prefix: str = ""           # text placed before "void <name>(" in the kernel signature
  buffer_prefix: str = ""           # text placed before the element type of pointer arguments
  buffer_suffix: str = ""           # text placed after the "*" of pointer arguments
  smem_align: str = ""              # text placed first in a local buffer declaration
  smem_prefix: str = ""             # text placed before "float" in a local buffer declaration
  smem_prefix_for_cast: bool = True # use smem_prefix (not buffer_prefix) in vector pointer casts on local buffers
  arg_int_prefix: str = ""          # type text used for _arg_int32 kernel arguments
  barrier: str = ""                 # statement emitted for UOps.BARRIER
  xid: List[str] = []               # expressions for SPECIAL names starting with "i"
  gid: List[str] = []               # expressions for SPECIAL names starting with "g"
  lid: List[str] = []               # expressions for every other SPECIAL name
  global_max: List[int] = []        # not read in this file
  local_max: List[int] = []         # not read in this file
  extra_args: List[str] = []        # extra parameter strings appended to the kernel signature
  float4: Optional[str] = None      # constructor text for float4; float2/int2 are derived from it
  half_prekernel: Optional[str] = None  # text prepended to the program when a float16 buffer is present
  uses_vload: bool = False          # render float16 access with vload_half/vstore_half
  external_local_bufs: bool = False # collect local buffer declarations in prekernel instead of the body
  uses_ptr_arithmetic: bool = False # render scalar access as *(buf+idx) instead of buf[idx]
  launch_bounds: bool = False       # add a __launch_bounds__ attribute to the kernel signature
  # maps an op to a callable that takes rendered operand strings and returns the rendered expression
  code_for_op: Dict = {
    UnaryOps.NEG: lambda x: f"(-{x})",
    UnaryOps.EXP2: lambda x: f"exp2({x})",
    UnaryOps.LOG2: lambda x: f"log2({x})",
    UnaryOps.SIN: lambda x: f"sin({x})",
    UnaryOps.SQRT: lambda x: f"sqrt({x})",
    BinaryOps.ADD: lambda a,b: f"({a}+{b})",
    BinaryOps.SUB: lambda a,b: f"({a}-{b})",
    BinaryOps.MUL: lambda a,b: f"({a}*{b})",
    BinaryOps.DIV: lambda a,b: f"({a}/{b})",
    BinaryOps.MAX: lambda a,b: f"max({a},{b})",
    BinaryOps.MOD: lambda a,b: f"({a}%{b})",
    BinaryOps.CMPLT: lambda a,b: f"({a}<{b})",
    TernaryOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})",
    TernaryOps.WHERE: lambda a,b,c: f"({a}!=0?{b}:{c})",
  }

  def render_cast(self, x:List[str], var_dtype:DType) -> str:
    """Return a str expression that casts the rendered values in x to var_dtype.

    A single value becomes a plain C cast. Several values are packed with a
    vector constructor derived from self.float4; only _float4, _float2 and
    _int2 are handled, anything else raises NotImplementedError.
    """
    if len(x) == 1:
      return f"({var_dtype.name})({x[0]})"
    assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
    assert self.float4 is not None, "cast is not supported on this platform"
    if var_dtype == dtypes._float4:
      constructor = self.float4
    elif var_dtype == dtypes._float2:
      constructor = self.float4.replace('float4', 'float2')
    elif var_dtype == dtypes._int2:
      constructor = self.float4.replace('float4', 'int2')
    else:
      raise NotImplementedError(f"no cast for {var_dtype}")
    return f"{constructor}({','.join(x)})"

  def render_const(self, x:Union[float,int], var_dtype) -> str:
    """Return a str expression for the constant x with the given type.

    NaN and infinities use the NAN / INFINITY spellings. A Python float with a
    float dtype gets an "f" suffix; everything else is rendered through int().
    When var_dtype.sz > 1 the value is repeated and passed to render_cast.
    """
    if math.isnan(x):
      val = "NAN"
    elif math.isinf(x):
      val = ("-" if x < 0 else "") + "INFINITY"
    elif dtypes.is_float(var_dtype) and isinstance(x, float):
      val = f"{x}f"
    else:
      val = f"{int(x)}"
    if var_dtype.sz > 1:
      return self.render_cast([val]*var_dtype.sz, var_dtype)
    return val

  def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
    """Return a str expression that loads from buf_name at idx as output_dtype.

    Image buffers use read_imagef, float16 buffers use vload_half when
    self.uses_vload is set, loads with output_dtype.sz > 1 go through a pointer
    cast, and scalar loads use indexing or pointer arithmetic. The result is
    wrapped in a cast when output_dtype differs from buf_dtype.
    """
    if isinstance(buf_dtype, ImageDType):
      assert output_dtype == dtypes._float4, f"images must be float4, getting {output_dtype}"
      return f"read_imagef({buf_name}, smp, {idx})"
    if self.uses_vload and buf_dtype == dtypes.float16:
      width = '' if output_dtype.sz == 1 else str(output_dtype.sz)
      return f"vload_half{width}(0, {buf_name}+{idx})"
    if output_dtype.sz > 1:
      # local buffers take smem_prefix in the pointer cast unless smem_prefix_for_cast is off
      ptr_prefix = self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix
      out_val = f"*(({ptr_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))"
    elif self.uses_ptr_arithmetic:
      out_val = f"*({buf_name}+{idx})"
    else:
      out_val = f"{buf_name}[{idx}]"
    if output_dtype != buf_dtype:
      return self.render_cast([out_val], output_dtype)
    return out_val

  def render_local(self, name:str, size:int):
    """Return the declaration of a local float array called name with size elements."""
    declaration = f"float {name}[{size}];"
    return self.smem_align + self.smem_prefix + declaration

  def render_for(self, expr: str, _min:Union[int,str], _max:Union[int,str]) -> str:
    """Return the opening line of a for loop running expr from _min up to (excluding) _max."""
    return f"for (int {expr} = {_min}; {expr} < {_max}; ++{expr}) {{"

  def render_if(self, cond: str):
    """Return the opening line of an if statement on cond."""
    return f"if ({cond}) {{"

  def render_conditional(self, cond: str, x:str, y:str) -> str:
    """Return a ternary expression that selects x when cond holds and y otherwise."""
    return f"({cond})?({x}):{y}"

  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
    """Return the full source of the kernel function.

    kernel holds the already-rendered body lines and bufs the (name, dtype)
    pairs that become the parameters. local_size is only used for the
    __launch_bounds__ attribute. prekernel is accepted but not used here.
    """
    # a sampler declaration is needed as soon as any argument is an image
    has_image = any(isinstance(dtype, ImageDType) for _,dtype in bufs)
    sampler = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if has_image else ""

    params: List[str] = []
    for i,(name,dtype) in enumerate(bufs):
      # the first buffer is rendered writable; every later one is read only / const
      if dtype.name.startswith('image'):
        ctype = f"{'read_only' if i > 0 else 'write_only'} image2d_t"
      elif dtype == dtypes._arg_int32:
        ctype = self.arg_int_prefix
      else:
        ctype = ("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix
      params.append(f"{ctype} {name}")

    bounds = f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''
    signature = f"{self.kernel_prefix}void {bounds}{function_name}("
    prg = signature + ', '.join(params + self.extra_args) + ") {\n" + sampler + '\n'.join(kernel) + "\n}"
    if self.half_prekernel and any(dtype == dtypes.float16 for _,dtype in bufs):
      prg = f"{self.half_prekernel}\n{prg}"
    return prg

  def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx:str, local=False) -> str:
    """Return a str statement that stores var_name into buf_name at idx.

    Mirrors render_load: write_imagef for images, vstore_half for float16 when
    self.uses_vload is set, a pointer cast when var_dtype.sz > 1, and plain
    indexing or pointer arithmetic for scalars.
    """
    if isinstance(buf_dtype, ImageDType):
      assert var_dtype == dtypes._float4, "images must be float4"
      return f"write_imagef({buf_name}, {idx}, {var_name});"
    if self.uses_vload and buf_dtype == dtypes.float16:
      width = '' if var_dtype.sz == 1 else str(var_dtype.sz)
      return f"vstore_half{width}({var_name}, 0, {buf_name}+{idx});"
    if var_dtype.sz > 1:
      ptr_prefix = self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix
      vec_type = f"{buf_dtype.name}{var_dtype.sz}"
      return f"*(({ptr_prefix}{vec_type}*)({buf_name}+{idx})) = ({vec_type}){var_name};"
    if self.uses_ptr_arithmetic:
      return f"*({buf_name}+{idx}) = {var_name};"
    return f"{buf_name}[{idx}] = {var_name};"

def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
  """Render a list of UOps into the source of one kernel function.

  The uops are walked once in order. Each one either appends a line to the
  kernel body or only records the string expression that later uops use to
  refer to it. Returns the rendered program and an empty dict. Raises
  NotImplementedError for an unknown WMMA target and RuntimeError for a uop
  that has no rendering rule.
  """
  local_size: List[int] = []
  kernel,prekernel,bufs = [],[],[]
  depth = 1

  def emit(line):
    # indentation follows the current nesting depth, which the main loop updates
    kernel.append(" "*depth+line)

  counters: DefaultDict[str, int] = defaultdict(int)
  rendered: Dict[UOp, str] = {}

  def ssa(u, prefix="t"):
    # hand out the next "<prefix><n>" name and remember it as the rendering of u
    index = counters[prefix]
    counters[prefix] = index + 1
    name = rendered[u] = f"{prefix}{index}"
    return name

  def decl_type(dt):
    # type text used when declaring a new variable
    return lang.generic_var_prefix if lang.generic_var_prefix else dt.name

  # number of times each uop appears as an input of another uop
  child_count: DefaultDict[UOp, int] = defaultdict(int)
  for parent in uops:
    for child in parent.vin:
      child_count[child] += 1

  for u in uops:
    uop,dtype,vin,args,_ = u
    if uop == UOps.LOOP:
      emit(lang.render_for(ssa(u,'ridx'), rendered[vin[0]], rendered[vin[1]]))
      depth += 1
    elif uop == UOps.IF:
      emit(lang.render_if(rendered[vin[0]]))
      depth += 1
    elif uop == UOps.BARRIER:
      emit(lang.barrier)
    elif uop == UOps.END:
      depth -= 1
      emit("}")
    elif uop == UOps.WMMA:
      if args[0] == "METAL":
        # ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
        emit("{ simdgroup_float8x8 a,b,c;")
        emit(f"a.thread_elements()[0] = {rendered[vin[0]]}; a.thread_elements()[1] = {rendered[vin[1]]};")
        emit(f"b.thread_elements()[0] = {rendered[vin[2]]}; b.thread_elements()[1] = {rendered[vin[3]]};")
        emit(f"c.thread_elements()[0] = {rendered[vin[4]]}; c.thread_elements()[1] = {rendered[vin[5]]};")
        emit("simdgroup_multiply_accumulate(c, a, b, c);")
        emit(f"{rendered[vin[4]]} = c.thread_elements()[0]; {rendered[vin[5]]} = c.thread_elements()[1]; }}")
      elif args[0] == "HIP":
        # inputs 0-15 and 16-31 are cast to half for the two operands, the rest are the accumulators
        emit("{")
        emit(f"half16 a_frag = {{ {','.join(['(half)'+rendered[x] for x in vin[0:16]])} }};")
        emit(f"half16 b_frag = {{ {','.join(['(half)'+rendered[x] for x in vin[16:32]])} }};")
        emit(f"float8 c_frag = {{ {','.join([rendered[x] for x in vin[32:]])} }};")
        emit("c_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, c_frag);")
        for i in range(8):
          emit(f"{rendered[vin[32+i]]} = c_frag[{i}];")
        emit("}")
      else:
        raise NotImplementedError(f"WMMA not implemented for {args}")
    elif uop == UOps.ALU:
      assert dtype is not None
      # remove parens if ALU types are the same. TODO: can do more here
      same_op_chain = vin[0].uop == UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL}
      render_op = lang.code_for_op[args]
      first = rendered[vin[0]]
      if same_op_chain:
        first = strip_parens(first)
      val = render_op(first, *[rendered[x] for x in vin[1:]])
      assert child_count[u] != 0, f"childless ALU op found {u}"
      # results used once, and all int results, are inlined rather than stored in a variable
      if child_count[u] <= 1 or dtypes.is_int(dtype):  # fix index rendering issue
        rendered[u] = val
      else:
        emit(f"{decl_type(dtype)} {ssa(u,'alu')} = {val};")
    elif uop == UOps.DEFINE_ACC:
      assert dtype is not None
      emit(f"{decl_type(dtype)} {ssa(u,'acc')} = {lang.render_const(args, dtype)};")
    elif uop == UOps.SPECIAL:
      # the first letter of the variable name selects which list of expressions to index
      if args[1].startswith("g"):
        xid = lang.gid
      elif args[1].startswith("i"):
        xid = lang.xid
      else:
        xid = lang.lid
      emit(f"{lang.size_prefix} {args[1]} = {xid[args[0]]}; /* {args[2]} */")
      if args[1].startswith("l"):
        local_size.append(args[2])
      rendered[u] = args[1]
    elif uop == UOps.CONST:
      # negative constants are parenthesized
      non_negative = args >= 0
      const = lang.render_const(args, dtype)
      rendered[u] = const if non_negative else f"({const})"
    elif uop == UOps.LOAD:
      assert dtype is not None
      val = lang.render_load(dtype, rendered[vin[0]], vin[0].dtype, strip_parens(rendered[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL)
      if len(vin) > 2:
        # gated load: vin[2] is the condition and vin[3] the value used when it fails
        val = lang.render_conditional(rendered[vin[2]], val, rendered[vin[3]])
      emit(f"{decl_type(dtype)} {ssa(u,'val')} = {val};")
    elif uop == UOps.PHI:
      emit(f"{rendered[vin[0]]} = {rendered[vin[1]]};")
      rendered[u] = rendered[vin[0]]
    elif uop == UOps.STORE:
      assert vin[0].dtype is not None and vin[2].dtype is not None
      emit(lang.render_store(rendered[vin[0]], vin[0].dtype, rendered[vin[2]], vin[2].dtype, strip_parens(rendered[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL))
    elif uop == UOps.CAST and dtype is not None and dtype.sz > 1:
      val = lang.render_cast([rendered[x] for x in vin], dtype)
      if child_count[u] <= 1:
        rendered[u] = val
      else:
        emit(f"{decl_type(dtype)} {ssa(u,'cast')} = {val};")
    elif uop == UOps.DEFINE_LOCAL:
      if lang.external_local_bufs:
        prekernel.append(lang.render_local(args[0], args[1]))
      else:
        emit(lang.render_local(args[0], args[1]))
      rendered[u] = args[0]
    elif uop == UOps.DEFINE_GLOBAL:
      bufs.append(args)
      rendered[u] = args[0]
    elif uop == UOps.GEP:
      rendered[u] = f"({rendered[vin[0]]}).{'xyzw'[args]}"
    else:
      raise RuntimeError(f"failed to render {uop}")

  return lang.render_kernel(function_name, kernel, bufs, local_size, prekernel), {}