from typing import List
import struct

from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
from tinygrad.codegen.kernel import Ops, UOp
from tinygrad import dtypes
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.runtime.ops_cuda import arch

# Maps a dtype (or the special "bits16" key) to the PTX type suffix used when
# rendering instructions and register declarations.
dtype_to_nvtype = {
    dtypes.float32: "f32",
    dtypes.float16: "f16",
    dtypes.int64: "s64",
    dtypes.int32: "s32",
    dtypes.int8: "s8",
    dtypes.bool: "pred",
    dtypes.uint64: "u64",
    dtypes.uint32: "u32",
    dtypes.uint16: "u16",
    dtypes.uint8: "u8",
    "bits16": "b16",
    dtypes.float64: "f64",
}


def float_to_hex(x):
    """Return the single-precision bit pattern of ``x`` as 8 uppercase hex digits.

    The digits are ordered most-significant byte first, which is the form
    appended after the ``0f`` prefix when a float constant is rendered.
    """
    # Pack explicitly big-endian so the result does not depend on the byte
    # order of the host (native packing + reversal is only right on
    # little-endian machines).
    return struct.pack(">f", x).hex().upper()


def ptx_needs_cast(dest_dtype, src_dtype):
    """Tell whether storing ``src_dtype`` as ``dest_dtype`` needs a conversion.

    True for int <-> float in either direction, and for float -> float when
    the two item sizes differ.
    """
    return (
        dtypes.is_float(dest_dtype) and dtypes.is_int(src_dtype)
        or dtypes.is_int(dest_dtype) and dtypes.is_float(src_dtype)
        or (
            dtypes.is_float(src_dtype)
            and dtypes.is_float(dest_dtype)
            and dest_dtype.itemsize != src_dtype.itemsize
        )
    )


def _zero_literal(dtype):
    """Return the literal used to compare a value of ``dtype`` against zero."""
    return '0f00000000' if dtypes.is_float(dtype) else '0'


def _mem_operand(base, offset):
    """Render the bracketed address operand ``[base]`` or ``[base+offset]``."""
    return f"[{base}+{offset}]" if offset is not None else f"[{base}]"


def render_cast(ins, inp, out):
    """Append to ``ins`` the instruction converting register ``inp`` into ``out``.

    Both registers must expose a ``dtype`` attribute and render as their
    register name through ``str``/f-string formatting.

    * bool -> int/float: ``selp`` choosing between one and zero.
    * anything -> bool:  ``mov.pred`` for bool input, else ``setp.ne`` vs zero.
    * everything else:   ``cvt`` with a rounding modifier where one is chosen.
    """
    if inp.dtype == dtypes.bool and (dtypes.is_float(out.dtype) or dtypes.is_int(out.dtype)):
        # 0f3F800000 is the bit pattern float_to_hex produces for 1.0.
        choices = '0f3F800000, 0f00000000' if dtypes.is_float(out.dtype) else '1, 0'
        ins.append(f"selp.{dtype_to_nvtype[out.dtype]} {out}, {choices}, {inp};")
    elif out.dtype == dtypes.bool:
        if inp.dtype == dtypes.bool:
            ins.append(f"mov.pred {out}, {inp};")
        else:
            ins.append(f"setp.ne.{dtype_to_nvtype[inp.dtype]} {out}, {_zero_literal(inp.dtype)}, {inp};")
    else:
        if dtypes.is_int(out.dtype) and dtypes.is_float(inp.dtype):
            # float -> int
            round_mod = ".rzi"
        elif dtypes.is_float(out.dtype) and (
            dtypes.is_int(inp.dtype)
            or (dtypes.is_float(inp.dtype) and inp.dtype.itemsize > out.dtype.itemsize)
        ):
            # int -> float, or float -> narrower float
            round_mod = ".rz"
        else:
            round_mod = ""
        ins.append(f"cvt{round_mod}.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[inp.dtype]} {out}, {inp};")


# https://docs.nvidia.com/cuda/parallel-thread-execution/#


class PTXLanguage(AssemblyLanguage):
    """Assembly language description used for PTX output."""

    supports_constant_folding: bool = True


def specialize_to_ptx(lang, function_name):
    """Render the instructions collected in ``lang`` as PTX source text.

    Args:
        lang: object providing ``ins`` (an iterable of
            ``(uop, out, vin, arg)`` tuples), ``newreg(key, dtype)``,
            ``type_to_letter(key)`` and ``cnts`` (a mapping whose keys are
            indexable, with the dtype at index 0, and whose values are
            register counts).
        function_name: name given to the emitted ``.visible .entry``.

    Returns:
        The PTX program as one newline-joined string: header, register
        declarations, translated body, then ``ret;`` and the closing brace.

    Raises:
        KeyError: if an ALU op or dtype has no entry in the lookup tables.
    """
    param_cnt = 0
    ins = []
    alu = {
        BinaryOps.ADD: "add",
        BinaryOps.SUB: "sub",
        BinaryOps.MUL: "mul",
        BinaryOps.DIV: "div",
        BinaryOps.MAX: "max",
        BinaryOps.MOD: "rem",
        BinaryOps.CMPLT: "setp.lt",
        UnaryOps.SQRT: "sqrt.approx",
        UnaryOps.NOOP: "mov",
        UnaryOps.NEG: "neg",
        UnaryOps.SIN: "sin.approx",
        UnaryOps.LOG2: "lg2.approx",
        UnaryOps.EXP2: "ex2.approx.ftz",
        TernaryOps.MULACC: "fma.rn",
        TernaryOps.WHERE: "selp",
    }

    for uop, out, vin, arg in lang.ins:
        if uop == Ops.ENDLOOP:
            ins.append("bar.sync 0;")
        elif uop == Ops.DEFINE_LOCAL:
            # arg is (name, count); the byte array is sized as count * 4.
            ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
        elif uop == Ops.SPECIAL:
            if arg.startswith('data'):
                # Every 'data*' special becomes one .u64 kernel parameter.
                param_cnt += 1
                ins.append(f"ld.param.u64 {out}, [{arg}];")
                # TODO: we sometimes want this to be local, nvcc converts to global most of the time, not sure when we would need to?
                # ins.append(f"cvta.to.global.u64 {out}, {out};")
            elif arg.startswith('gid'):
                ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
            elif arg.startswith('lid'):
                ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
        elif uop == Ops.ALU:
            if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
                # Multiplying predicates is rendered as a logical AND.
                ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};")
            else:
                # Comparisons are typed by their operand, not by the bool result.
                otype = vin[0].dtype if arg in [BinaryOps.CMPLT] else out.dtype
                if arg == TernaryOps.WHERE:
                    if vin[0].dtype == dtypes.bool:
                        reg = vin[0]
                    else:
                        # selp needs a predicate: derive one as (cond != 0).
                        reg = lang.newreg((vin[0], 'bool'), dtypes.bool)
                        ins.append(
                            f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, "
                            f"{_zero_literal(vin[0].dtype)}, {vin[0]};"
                        )
                    # selp takes the predicate as its last operand.
                    vin = vin[1:] + [reg]
                # '.lo' is an integer-multiply mode; float multiplies of any
                # width must not carry it.
                lo_mode = '.lo' if arg == BinaryOps.MUL and not dtypes.is_float(out.dtype) else ''
                # The PTX ISA lists f64 division only with a rounding modifier
                # (div.rnd.f64), so apply '.rn' to f64 as well as f32.
                rn_mode = '.rn' if arg == BinaryOps.DIV and out.dtype in (dtypes.float32, dtypes.float64) else ''
                operands = ', '.join(str(x) for x in vin)
                ins.append(f"{alu[arg]}{lo_mode}{rn_mode}.{dtype_to_nvtype[otype]} {out}, {operands};")
        elif uop == Ops.LOAD:
            if arg.__class__ in (int, float):
                # Constant load: arg is the literal value itself.
                const = '0f' + float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)
                ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {const};")
            elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype):
                # Memory dtype differs from the register dtype (or is bool):
                # load into a temporary register, then convert into `out`.
                if arg[2] == dtypes.bool and out.dtype == dtypes.bool:
                    # NOTE(review): this emits a 16-bit load for a bool
                    # element; kept as in the original - confirm intended.
                    dt = ('u16', dtypes.uint16)
                elif arg[2] == dtypes.bool:
                    dt = ('u8', dtypes.uint8)
                elif arg[2] == dtypes.half:
                    dt = ('b16', dtypes.float16)
                else:
                    dt = (dtype_to_nvtype[arg[2]], arg[2])
                reg = lang.newreg((out, dt[0]), dtype=dt[1])
                ins.append(f"ld.{arg[1]}.{dt[0]} {reg}, {_mem_operand(vin[0], arg[0])};")
                render_cast(ins, reg, out)
            else:
                # A missing memory dtype is treated as dtypes.float.
                mem_dtype = dtypes.float if arg[2] is None else arg[2]
                ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[mem_dtype]} {out}, {_mem_operand(vin[0], arg[0])};")
        elif uop == Ops.STORE:
            # A missing memory dtype is treated as dtypes.float.
            mem_dtype = dtypes.float if arg[2] is None else arg[2]
            stores_bool = arg[2] == dtypes.bool
            if ptx_needs_cast(mem_dtype, vin[1].dtype) or stores_bool:
                if stores_bool and dtypes.bool != vin[1].dtype:
                    # Reduce the value to a predicate first.
                    prereg = lang.newreg((vin[1], 'bool'), dtype=dtypes.bool)
                    render_cast(ins, vin[1], prereg)
                else:
                    prereg = vin[1]
                # Bools go through a uint16 register before being stored as u8.
                # The register key uses arg[2] as given (possibly None).
                reg = lang.newreg(
                    (prereg, dtypes.uint16 if stores_bool else arg[2]),
                    dtype=dtypes.uint16 if stores_bool else mem_dtype,
                )
                render_cast(ins, prereg, reg)
                if arg[2] == dtypes.float16:
                    st_key = 'bits16'
                elif stores_bool:
                    st_key = dtypes.uint8
                else:
                    st_key = mem_dtype
                ins.append(f"st.{arg[1]}.{dtype_to_nvtype[st_key]} {_mem_operand(vin[0], arg[0])}, {reg};")
            else:
                ins.append(f"st.{arg[1]}.{dtype_to_nvtype[mem_dtype]} {_mem_operand(vin[0], arg[0])}, {vin[1]};")
        elif uop == Ops.CAST:
            render_cast(ins, vin[0], out)
        elif uop == Ops.LABEL:
            ins.append(f"{arg}:")
        elif uop == Ops.COND_BRANCH:
            # arg is (label, flag); a falsy flag negates the guard predicate.
            ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")

    params = ', '.join(f'.param .u64 data{i}' for i in range(param_cnt))
    ins_prefix = [
        ".version 7.8",
        ".target " + arch(),
        ".address_size 64",
        f".visible .entry {function_name}({params}) {{",
    ]
    # One register-bank declaration per entry in lang.cnts.
    for reg_key, count in lang.cnts.items():
        ins_prefix.append(f".reg .{dtype_to_nvtype[reg_key[0]]} %{lang.type_to_letter(reg_key)}<{count}>;")

    return '\n'.join(ins_prefix + ins + ["ret;", "}"])


def uops_to_ptx_asm(function_name: str, uops: List[UOp]):
    """Lower ``uops`` to PTX text for a kernel named ``function_name``.

    Returns:
        A 4-tuple ``(ptx_source, global_size, local_size, True)`` where both
        sizes are the ones returned by ``uops_to_asmstyle`` in reversed order.
        NOTE(review): the meaning of the trailing ``True`` is defined by the
        caller and is not visible in this file.
    """
    lang = PTXLanguage()
    global_size, local_size = uops_to_asmstyle(lang, function_name, uops)
    return specialize_to_ptx(lang, function_name), global_size[::-1], local_size[::-1], True