from typing import cast
import math, struct, sys
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import ClangRenderer, AMDRenderer
from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
from tinygrad.dtype import dtypes, DType, PtrDType, truncate
from tinygrad.helpers import prod, AMX

def ldt(dt:DType):
  if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
  if isinstance(dt, PtrDType): return ldt(dt.base) + (" addrspace(3)*" if dt.local else "*")
  return {dtypes.void: "void", dtypes.bool: "i1", dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
          dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64",
          dtypes.float16: "half", dtypes.bfloat16: "bfloat", dtypes.float32: "float", dtypes.float64: "double"}[dt]

def lconst(x, dtype:DType):
  if dtype in dtypes.floats:
    if math.isinf(x) or math.isnan(x): return "0x%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
    return truncate[dtype](x)
  return int(x)

def lcast(input_type:DType, output_type:DType):
  if dtypes.is_float(input_type):
    if dtypes.is_float(output_type): return 'fpext' if output_type.itemsize > input_type.itemsize else 'fptrunc'
    if dtypes.is_int(output_type): return 'fptoui' if dtypes.is_unsigned(output_type) else 'fptosi'
  if dtypes.is_unsigned(input_type) or dtypes.is_bool(input_type):
    if dtypes.is_float(output_type): return 'uitofp'
    if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'zext'
  if dtypes.is_int(input_type):
    if dtypes.is_float(output_type): return 'sitofp'
    if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'sext'
  raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
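
# For illustration, what the helpers above emit:
#   ldt(dtypes.float32.vec(4))           -> "<4 x float>"
#   ldt(dtypes.int8.ptr())               -> "i8*"
#   lconst(float("inf"), dtypes.float32) -> "0x7FF0000000000000"  (LLVM hex-float notation, packed as a double)
#   lcast(dtypes.int32, dtypes.float32)  -> "sitofp"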

# https://github.com/corsix/amx
def render_wmma_amx(ctx, wmma: UOp) -> str:
  def AMX(op, gpr): return f'call void asm sideeffect ".word (0x201000+($0<<5)+0$1-((0$1>>4)*6))", "i,r,~{{memory}}"(i32 {op}, i64 {gpr}) #0; AMX'
  return "\n".join([
    *[f'  store {ldt(src.dtype)} {ctx[src]}, {ldt(src.dtype.ptr())} {ctx[wmma]}_amx{i}, align {src.dtype.itemsize}' for i,src in enumerate(wmma.src)],
    f'  call void asm sideeffect "nop\\0Anop\\0Anop\\0A.word ({0x201000 + (17 << 5) + 0})", "~{{memory}}"() #0; AMX set', # set
    *[f'  {ctx[wmma]}_ld{i} = add i64 {ctx[wmma]}_ptr_amx2, {i*4<<56 | i*64}\n  {AMX(4,f"{ctx[wmma]}_ld{i}")} ldz' for i in range(16)], # ldz
    f'  {AMX(0, f"{ctx[wmma]}_ptr_amx1")} ldx\n  {AMX(1, f"{ctx[wmma]}_ptr_amx0")} ldy\n  {AMX(12, 0)} fma32', # ldx ldy fma
    *[f'  {ctx[wmma]}_st{i} = add i64 {ctx[wmma]}_ptr_amx2, {i*4<<56 | i*64}\n  {AMX(5,f"{ctx[wmma]}_st{i}")} stz' for i in range(16)], # stz
    f'  call void asm sideeffect "nop\\0Anop\\0Anop\\0A.word ({0x201000 + (17 << 5) + 1})", "~{{memory}}"() #0; AMX clr', # clr
    f'  {ctx[wmma]} = load {ldt(wmma.dtype)}, ptr {ctx[wmma]}_amx2, align {wmma.dtype.itemsize}'])

def render_wmma_amd(ctx, wmma: UOp) -> str:
  dt_map = {dtypes.half: "f16", dtypes.float: "f32"}
  # https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_32.ll
  # example: %wmma0 = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16(<16 x half> %v99,<16 x half> %v100,<8 x float> %v101)
  return f"  {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.wmma.{dt_map[wmma.src[-1].dtype.scalar()]}.16x16x16." + \
    f"{dt_map[wmma.src[0].dtype.scalar()]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + \
    (", i1 false)" if wmma.dtype.scalar() != dtypes.float else ")")

# llvm ops, lop[][]
unsigned_lop = {Ops.ADD: "add", Ops.MUL: "mul", Ops.IDIV: "udiv", Ops.MOD: "urem",
                Ops.CMPLT: "icmp ult", Ops.CMPNE: "icmp ne", Ops.OR: "or", Ops.AND: "and", Ops.XOR: "xor"}
signed_lop = {**unsigned_lop, Ops.CMPLT: "icmp slt", Ops.IDIV: "sdiv", Ops.MOD: "srem"}
flags = " nsz arcp contract afn"
float_lop = {Ops.ADD: "fadd"+flags, Ops.MUL: "fmul"+flags, Ops.CMPLT: f"fcmp{flags} ult", Ops.CMPNE: f"fcmp{flags} une", Ops.FDIV: "fdiv"+flags}
lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop for x in dtypes.sints}, **{x:float_lop for x in dtypes.floats}}

base_rewrite = PatternMatcher([
  # memory load/store
  (UPat(Ops.INDEX, name="x"), lambda ctx,x:
   f"  {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
  (UPat(Ops.LOAD, src=(UPat(src=(UPat(), UPat(), UPat.var('mask'))).or_casted('idx'), UPat.var('alt')), name="x"),
   lambda ctx,x,idx,alt,mask:
   f"  br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
   f"  br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n"
   f"  {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n"
   f"  br label {ctx[x]}_exit\n{ctx[x][1:]}_exit:\n"
   f"  {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"),
  (UPat(Ops.LOAD, src=(UPat.var('idx'),), allow_any_len=True, name="x"),
   lambda ctx,x,idx: f"  {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
  (UPat(Ops.STORE, name="x"), lambda ctx,x: f"  store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"),
  # GEP/VECTORIZE/CAST for float4 support
  (UPat(Ops.GEP, name="x"), lambda ctx,x: f"  {ctx[x]} = extractelement {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i32 {x.arg[0]}"),
  (UPat(Ops.VECTORIZE, src=UPat.var('y'), name="x"), lambda ctx,x,y:
   f"  {ctx[x]}_z = insertelement <1 x {ldt(y.dtype)}> poison, {ldt(y.dtype)} {ctx[y]}, i32 0\n"
   f"  {ctx[x]} = shufflevector <1 x {ldt(y.dtype)}> {ctx[x]}_z, <1 x {ldt(y.dtype)}> poison, <{x.dtype.count} x i32> zeroinitializer"),
  (UPat(Ops.VECTORIZE, name="x"), lambda ctx,x: "\n".join([(f"  {ctx[x]}_{i}" if i+1 != len(x.src) else f"  {ctx[x]}")+
   f" = insertelement {ldt(x.dtype)} "+(f"{ctx[x]}_{i-1}" if i != 0 else "poison")+
   f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])),
  (UPat(Ops.CAST, name="x"), lambda ctx,x:
   f"  {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}" if isinstance(x.dtype, PtrDType) else None),
  # unary/binary/ternary ops
  (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"  {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
  (UPat(Ops.CAST, name="x"), lambda ctx,x: f"  {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
  (UPat(GroupOp.Binary, name="x"), lambda ctx,x:
   f"  {ctx[x]} = {lop[x.src[0].dtype.scalar()][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"),
  (UPat(Ops.WHERE, name="x"), lambda ctx,x:
   f"  {ctx[x]} = select {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[2].dtype)} {ctx[x.src[2]]}"),
  # range
  (UPat(Ops.RANGE, name="x"), lambda ctx,x:
   f"  br label %loop_entry_{x.arg}\nloop_entry_{x.arg}:\n"
   f"  br label %loop_body_{x.arg}\nloop_body_{x.arg}:\n"
   f"  {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg}], [{ctx[x]}phi, %loop_latch_{x.arg}]"),
  (UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
   f"  br label %loop_latch_{x.src[0].arg}\nloop_latch_{x.src[0].arg}:\n"
   f"  {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n  {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[1]]}\n"
   f"  br i1 {ctx[x]}, label %loop_body_{x.src[0].arg}, label %loop_exit_{x.src[0].arg}\nloop_exit_{x.src[0].arg}:"),
  # if
  (UPat(Ops.IF, name="x"), lambda ctx,x:
   f"  br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),
  (UPat(Ops.ENDIF, name="x"), lambda ctx,x: f"  br label %ifskip_{ctx[x.src[0]][1:]}\nifskip_{ctx[x.src[0]][1:]}:"),
])
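
# For illustration: given an ADD of two int32 values already named %v1 and %v2, the GroupOp.Binary rule
# above renders "  %v3 = add i32 %v1, %v2"; for float32 sources it picks the fast-math form from
# float_lop, e.g. "  %v3 = fadd nsz arcp contract afn float %v1, %v2".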

def llvm_bf16_cast(buf:UOp, idx:UOp, root:UOp):
  u16_buf = buf.replace(dtype=dtypes.ushort.ptr(size=cast(PtrDType, buf.dtype).size))
  return UOp.load(UOp.index(u16_buf, idx), dtype=dtypes.ushort).cast(dtypes.uint).mul(1<<16).bitcast(dtypes.float32).cast(root.dtype)

class LLVMRenderer(Renderer):
  device = "LLVM"
  abi = 'win64cc' if sys.platform == 'win32' else None
  supports_float4 = True
  has_local = False
  has_shared = False
  global_max: tuple[int, ...] | None = None
  string_rewrite = base_rewrite + PatternMatcher([(UPat(Ops.WMMA, name="wmma"), render_wmma_amx)])
  if AMX: tensor_cores = ClangRenderer.amx_tc

  extra_matcher = PatternMatcher([
    # rewrite RECIP with FDIV
    (UPat(Ops.RECIP, name="x"), lambda x: UOp(Ops.FDIV, x.dtype, (x.const_like(1), x.src[0]))),
    # rewrite cast to bool to CMPNE 0
    (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)),
    # rewrite MAX to CMPLT + WHERE
    (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
    # rewrite bf16 CAST(LOAD) to CAST(BITCAST)
    (UPat(Ops.CAST, name="root", src=(UPat.load(UPat.index(UPat.var("buf"), UPat.var("idx")), dtype=dtypes.bfloat16),)), llvm_bf16_cast),
    # copied from cstyle.py, upcast to float32 all the ops that don't support bfloat16
    (UPat((Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN), dtype=dtypes.bfloat16, name="x"),
     lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))),
    # copied from cstyle.py, add float intermediate casting
    (UPat(Ops.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)), lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
    (UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")), lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
  ])
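
  # Note on the bf16 rules above: llvm_bf16_cast widens a bfloat16 load by loading the raw u16,
  # zero-extending it to u32, shifting it into the high 16 bits (mul by 1<<16), and bitcasting to
  # float32 -- bf16 is the upper half of an f32, so this reproduces the value exactly without
  # relying on native bfloat loads.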

  def render(self, uops: list[UOp]) -> str:
    r: dict[UOp, str] = {}
    args: list[str] = []
    kernel: list[str] = []
    end_lines: dict[str, None] = {}
    vc = -1

    local_args: list[str] = []

    acc_to_assign: dict[UOp, UOp] = {}
    for u in uops:
      if u.op is Ops.ASSIGN: # prealloc all assigns
        vc += 1
        r[u] = r[u.src[1]] = f"%assign{vc}"
        assert u.src[0] not in acc_to_assign, "can't assign to DEFINE_ACC twice"
        acc_to_assign[u.src[0]] = u.src[1]
      if AMX and u.op is Ops.WMMA: # prealloc aux buffers as AMX can only load from memory
        vc += 1
        r[u] = f"%wmma{vc}"
        for i, dtype in enumerate(u.arg[2].vec(sz) for sz in [prod(size for _, size in upcast) for upcast in u.arg[6]]):
          kernel += [f"  {r[u]}_amx{i} = alloca {ldt(dtype)}, align {dtype.itemsize}",
                     f"  {r[u]}_ptr_amx{i} = ptrtoint {ldt(dtype.ptr())} {r[u]}_amx{i} to i64"]

    name = "test"
    for u in uops:
      if u.op is Ops.NAME:
        name = u.arg
        continue
      if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
        r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
        # NOTE: MallocAllocator promises 0x20 alignment
        args.append(f"{ldt(u.dtype)}{' noalias align 32' if isinstance(u.dtype, PtrDType) else ''} {r[u]}")
      elif u.op == Ops.DEFINE_LOCAL:
        r[u] = f"@local_{u.arg}"
        assert isinstance(u.dtype, PtrDType)
        local_args.append(f"{r[u]} = internal unnamed_addr addrspace(3) global [{u.dtype.size} x {ldt(u.dtype)}] undef, align 16")
      elif u.op is Ops.ASSIGN: pass  # assign is already handled by the first pass
      elif u.op is Ops.DEFINE_ACC: r[u] = r[u.src[0]]  # a define acc can be used and never be assigned to
      elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
      elif u.op is Ops.CAST and ldt(u.dtype) == ldt(u.src[0].dtype): r[u] = r[u.src[0]]  # cast from signed to unsigned of the same size is a noop
      else:
        # if it's an assign target, it's already preallocated
        if u not in r:
          vc += 1
          r[u] = f"%v{vc}"

        # do the rendering of the llvm ir code
        if (l:=self.string_rewrite.rewrite(u, ctx=r)) is None:
          raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
        kernel.append(cast(str, l))

        # generate the phi nodes for the assigns
        if u.op is Ops.RANGE:
          for x in acc_to_assign:
            if u in x.src:  # if this range is relevant for this acc
              vc += 1
              kernel.append(f"  %acc{vc} = phi {ldt(x.dtype)} "
                            f"[{r[x]}, %loop_entry_{u.arg}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg}]")
              r[x] = f"%acc{vc}"

    # output the function. chr(10) is '\n' (python < 3.12 doesn't support backslashes in f-strings)
    prg = f'''\
define{(' '+self.abi) if self.abi is not None else ''} void @{name}({','.join(args)}) #0 {{
{chr(10).join(kernel)}
  ret void
}}
{chr(10).join(end_lines.keys())}
attributes #0 = {{ nounwind "no-builtins" "no-trapping-math"="true" }}
'''
    return prg if len(local_args) == 0 else "\n".join(local_args)+f"\n{prg}"

barrier = 'fence syncscope("workgroup") release\ntail call void @llvm.amdgcn.s.barrier()\nfence syncscope("workgroup") acquire\n'
code_for_workitem = {"g": lambda x: f"tail call i32 @llvm.amdgcn.workgroup.id.{chr(120+int(x))}()",
                     "l": lambda x: f"tail call i32 @llvm.amdgcn.workitem.id.{chr(120+int(x))}()"}

class AMDLLVMRenderer(LLVMRenderer):
  device = "AMD"
  has_local = True
  has_shared = True
  shared_max = AMDRenderer.shared_max
  global_max = AMDRenderer.global_max
  tensor_cores = AMDRenderer.tensor_cores
  abi = "amdgpu_kernel"
  string_rewrite = PatternMatcher([
    (UPat(Ops.SPECIAL, name="x"), lambda ctx, x: f"  {ctx[x]} = " + f"{code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; "),
    (UPat(Ops.BARRIER), lambda ctx: barrier),
    (UPat(Ops.CAST, name="x", dtype=dtypes.half.vec(16), src=UPat.var("y", dtypes.half.vec(8))), lambda ctx, x, y:
     f"  {ctx[x]} = shufflevector <8 x half> {ctx[y]}, <8 x half> zeroinitializer, "
     f"<16 x i32> <{', '.join([f'i32 {i}, i32 {j}' for i, j in zip(range(0, 8), range(8, 16))])}>"),
    (UPat(Ops.CAST, name="x", dtype=dtypes.half.vec(8), src=UPat.var("y", dtypes.half.vec(16))), lambda ctx, x, y:
     f"  {ctx[x]} = shufflevector <16 x half> {ctx[y]}, <16 x half> undef, <8 x i32> <{', '.join([f'i32 {x}' for x in range(0, 16, 2)])}>"),
    (UPat(Ops.WMMA, name="wmma"), render_wmma_amd),
  ]) + base_rewrite
  extra_matcher = PatternMatcher([
    (UPat(Ops.WMMA, name="x", dtype=dtypes.half.vec(8)), lambda x: UOp(Ops.WMMA, dtypes.half.vec(16),
      (x.src[0], x.src[1], x.src[2].cast(dtypes.half.vec(16))), (*x.arg,)).cast(dtypes.half.vec(8)))
  ]) + LLVMRenderer.extra_matcher
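
# Illustrative usage sketch (not part of the original tinygrad source): a minimal smoke test of
# LLVMRenderer.render() that builds a tiny hand-written UOp list storing the constant 1.0 into
# element 0 of a float32 buffer and prints the generated LLVM IR. The UOp construction below is an
# assumption about the linearized format render() expects (DEFINE_GLOBAL -> CONST -> INDEX -> STORE);
# it only prints text and is not a compile test.
if __name__ == "__main__":
  buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float32.ptr(size=16), (), 0)   # kernel argument %data0
  i0  = UOp.const(dtypes.int32, 0)                                   # element index
  val = UOp.const(dtypes.float32, 1.0)                               # value to store
  idx = UOp(Ops.INDEX, dtypes.float32.ptr(size=16), (buf, i0))       # rendered as getelementptr
  st  = UOp(Ops.STORE, dtypes.void, (idx, val))                      # rendered as store
  print(LLVMRenderer().render([buf, i0, val, idx, st]))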