You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
190 lines
10 KiB
190 lines
10 KiB
1 month ago
|
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast
|
||
|
from tinygrad.codegen.kernel import Ops, MemOp, UOp
|
||
|
from tinygrad.ops import BinaryOps, UnaryOps
|
||
|
from tinygrad.dtype import DType, dtypes
|
||
|
from tinygrad.helpers import DEBUG
|
||
|
from tinygrad.ops import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
|
||
|
import functools
|
||
|
import math
|
||
|
from collections import defaultdict
|
||
|
|
||
|
# Maps a DType to the letter(s) used to name registers of that type (e.g. %f0, %i3).
# type_to_letter() uppercases the letter for scalar (warp-uniform) registers.
_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes.float.vec(4): 'x', dtypes.uint8: 'uc', dtypes.float16: 'h',
                   dtypes.int8: 'c', dtypes.uint16: 'us', dtypes.float64: 'd'}
|
||
|
|
||
|
class Register(NamedTuple):
  """A named virtual register.

  `scalar` marks warp-uniform values; `off` selects a single float lane out of
  a float4 vector register (None for a whole register).
  """
  nm:str
  dtype:DType
  scalar:bool
  off:Optional[int] = None
  def __repr__(self):
    # whole registers print as "name"; vector lanes print as "name:off"
    return f"{self.nm}:{self.off}" if self.off is not None else self.nm
  def subregs(self):
    # only a float4 register decomposes into lanes; everything else has none
    if self.dtype != dtypes.float.vec(4): return []
    return [Register(self.nm, dtypes.float, False, off=lane) for lane in range(4)]
|
||
|
|
||
|
class AssemblyInstruction(NamedTuple):
  """One flat pseudo-assembly instruction."""
  op: Ops                                 # opcode
  out: Optional[Register]                 # destination register, or None for ops with no result (STORE, LABEL, ...)
  vin: List[Union[Register, int, float]]  # input operands: registers or immediates
  arg: Any = None                         # op-specific payload (constant, label name, alu op, memory tuple, ...)
|
||
|
|
||
|
# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
class AssemblyLanguage:
  """Lowering context: maps tokens to registers and collects a flat instruction list.

  Target backends subclass this and set the capability flags to steer codegen.
  """
  supports_load3: bool = False  # target supports load/store addressed as [base_reg + index_reg + const_off]
  sin_is_sin2pi: bool = False   # target's sin takes units of 2*pi, so arguments must be pre-scaled
  no_div: bool = False          # target has no divide instruction; DIV is lowered to RECIP + MUL
  #TODO: these should be global vars
  # NOTE(review): these are class-level mutable attributes, shared by all instances;
  # uops_to_asmstyle clear()s them at the start of each kernel.
  cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)  # per-(dtype, scalar) counter used to number new registers
  tor: Dict[Any, Register] = {}  # "to register": token/uop/expression -> its allocated Register
  ins: List[AssemblyInstruction] = []  # the emitted instruction stream, in order

  # x is a (dtype, scalar) pair; scalar registers get the uppercase letter
  def type_to_letter(self,x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]

  def newreg(self, tok, dtype=dtypes.float32, scalar=False) -> Register:
    """Allocate a fresh register named %<letter><count> for tok, record it in self.tor, and return it."""
    self.tor[tok] = ret = Register(f"%{self.type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar)
    if dtype == dtypes.float.vec(4):
      for off in range(4):
        # NOTE(review): every iteration rebinds the same key, so self.tor[tok] ends up as
        # the off=3 lane while the full vec4 register is returned -- confirm this is intended.
        self.tor[tok] = Register(ret.nm, dtypes.float, ret.scalar, off)
    self.cnts[(dtype, scalar)] += 1
    return ret

  def render_numnode(self, b) -> Register:
    """Materialize integer constant b in a scalar int32 register (cached per value)."""
    key = ("num", b)
    if key not in self.tor: self.ins.append(AssemblyInstruction(Ops.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
    return self.tor[key]

  def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
    """Emit op(a, b) into a fresh register, CSE-cached on (op, a, b).

    The result is scalar only when a is scalar and b is an immediate or a scalar register.
    """
    key = (op, a, b)
    if key not in self.tor:
      #if not isinstance(b, Register): b = render_numnode(b)
      self.ins.append(AssemblyInstruction(Ops.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
    return self.tor[key]

  def render_cast(self, a:Register, new_dtype:DType) -> Register:
    """Emit a CAST of a to new_dtype (cached); returns a unchanged if it already has that dtype."""
    if a.dtype == new_dtype: return a
    key = (a, new_dtype)
    if key not in self.tor:
      self.ins.append(AssemblyInstruction(Ops.CAST, self.newreg(key, dtype=new_dtype), [a]))
    return self.tor[key]

  # symbolic-node renderers: each lambda lowers one node type to ALU instructions,
  # with ctx being this AssemblyLanguage (passed by Node.render)
  render_ops: Any = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
                      MulNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
                      DivNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
                      ModNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
                      LtNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
                      SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
                      AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }

  def addr_w_offset(self, args):
    """Lower a MemOp's symbolic index to (base_reg, index_reg_or_None, const_offset)."""
    assert isinstance(args, MemOp)
    idx = args.idx*args.memory_dtype.itemsize  # element index -> byte offset
    off = 0 # TODO: should this be None?
    if isinstance(idx, SumNode):
      # peel a small constant term out of the sum so it can become the immediate offset
      nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
      if nums and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU?
        idx -= nums[0]
        off = cast(int, nums[0])
    reg = idx.render(self.render_ops, self)
    if self.supports_load3:
      if reg.scalar:
        # the 3-operand load wants a vector index register; copy the scalar value into one
        new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype)
        self.ins.append(AssemblyInstruction(Ops.ALU, new_reg, [reg], UnaryOps.NOOP))
        reg = new_reg
      return self.tor[args.name], reg, off
    # no load3: fold base pointer + index into a single uint64 address register
    reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64)
    return reg, None, off
|
||
|
|
||
|
def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
  """Lower a linearized uop list into flat AssemblyInstructions in lang.ins.

  Mutates lang's shared state (ins/tor/cnts) in place and returns
  (global_size, local_size) for the kernel launch. function_name is unused here.
  """
  #TODO: Do not use clear()
  lang.ins.clear()
  lang.tor.clear()
  lang.cnts.clear()
  buf_to_dtype = {args:dtype for uop,dtype,_,args,_ in uops if uop == Ops.DEFINE_GLOBAL}
  global_size, local_size = [], []
  skipload_branch = 0  # counter that keeps $skipload labels unique
  # buffer arguments arrive as uint64 pointers in scalar registers
  lang.ins += [AssemblyInstruction(Ops.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
  for u in uops:
    uop,dtype,vin,args,_ = u
    if uop == Ops.DEFINE_LOCAL:
      lang.ins.append(AssemblyInstruction(Ops.DEFINE_LOCAL, None, [], args))
      # presumably args[0] names the local buffer; its address is copied into a uint64 reg -- TODO confirm
      lang.ins.append(AssemblyInstruction(Ops.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
    elif uop == Ops.LOOP:
      if args[1] == "global":
        for i,var in enumerate(args[0]):
          global_size.append(var.max+1)
          # global loop dims map to gid registers; outermost var gets the highest gid index
          lang.ins.append(AssemblyInstruction(Ops.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
      elif args[1] == "local":
        for i,var in enumerate(args[0]):
          local_size.append(var.max+1)
          lang.ins.append(AssemblyInstruction(Ops.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
      else:
        # a sequential loop: init the counter to 0 and drop a label to branch back to
        for var in args[0]:
          if not isinstance(var, NumNode): # TODO: why is this coming through?
            lang.ins.append(AssemblyInstruction(Ops.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
            lang.ins.append(AssemblyInstruction(Ops.LABEL, None, [], "$loop_"+var.expr))
    elif uop == Ops.ENDLOOP:
      if args[1] not in ["global", "local", "global+local"]:
        # close a sequential loop: increment the counter in place (violating SSA),
        # compare against the bound, and branch back while counter < max+1
        for var in reversed(args[0]):
          if not isinstance(var, NumNode): # TODO: why is this coming through?
            lang.ins.append(AssemblyInstruction(Ops.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
            pred = lang.render_alu(BinaryOps.CMPLT, lang.tor[var], var.max+1, dtypes.bool)
            lang.ins.append(AssemblyInstruction(Ops.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
      elif args[1] == "global+local":
        for i, var in enumerate(reversed(args[0])):
          lang.ins.append(AssemblyInstruction(Ops.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
      elif args[1] == 'local':
        for i, var in enumerate(reversed(args[0])):
          lang.ins.append(AssemblyInstruction(Ops.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"lid{i}")))
    elif uop == Ops.CAST:
      # TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
      out = lang.newreg(u, dtype)
      for i,sr in enumerate(out.subregs()):
        # vec4 output: one scalar move per lane from the corresponding input
        lang.ins.append(AssemblyInstruction(Ops.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
    elif uop == Ops.ALU:
      out = lang.newreg(u, dtype) if u not in lang.tor else lang.tor[u]
      # this is the only thing that can violate SSA
      if args in [BinaryOps.CMPLT]:
        # compare into a predicate register, then cast the predicate to the output dtype
        pred_reg = lang.newreg((u, 'pred'), dtype=dtypes.bool)
        lang.ins.append(AssemblyInstruction(Ops.ALU, pred_reg, [lang.tor[x] for x in vin], args))
        lang.ins.append(AssemblyInstruction(Ops.CAST, out, [pred_reg], args))
      elif args == BinaryOps.DIV and lang.no_div:
        # no hardware divide: a/b -> a * recip(b)
        tmp = lang.newreg((u, "rcp"))
        lang.ins.append(AssemblyInstruction(Ops.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
        lang.ins.append(AssemblyInstruction(Ops.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
      elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
        # hardware sin expects units of 2*pi: pre-scale the argument by 1/(2*pi)
        tmp = lang.newreg((u, "2pi"))
        lang.ins.append(AssemblyInstruction(Ops.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
        lang.ins.append(AssemblyInstruction(Ops.ALU, out, [tmp], args))
      else:
        lang.ins.append(AssemblyInstruction(Ops.ALU, out, [lang.tor[x] for x in vin], args))
    elif uop == Ops.DEFINE_ACC:
      # accumulator: fresh register loaded with its initial value (args)
      reg = lang.newreg(u, dtype=dtype)
      lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [], args))
    elif uop == Ops.SPECIAL:
      lang.tor[u] = lang.tor[args]  # alias this uop to the register already allocated for args
    elif uop == Ops.CONST:
      lang.ins.append(AssemblyInstruction(Ops.LOAD, lang.newreg(u, dtype=dtype), [], args))
    elif uop == Ops.LOAD:
      idx, treg, off = lang.addr_w_offset(args)
      reg = lang.newreg(u, dtype=dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar)))
      if args.valid.min == 0:
        # the load may be masked off: default the destination to 0 ...
        lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [], 0))
        if args.valid.max == 1:
          # ... and branch over the real load when the valid predicate is false
          pred = args.valid.render(lang.render_ops, lang)
          lang.ins.append(AssemblyInstruction(Ops.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
      if args.valid.max == 1:
        # NOTE: you can't compute the index in here, because it assumes it's all available later
        lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
      if args.valid.min == 0 and args.valid.max == 1:
        # land here when the conditional load was skipped
        lang.ins.append(AssemblyInstruction(Ops.LABEL, None, [], f"$skipload_{skipload_branch}"))
        skipload_branch += 1
    elif uop == Ops.STORE:
      if args is None:
        # register-to-register store: a plain move
        lang.ins.append(AssemblyInstruction(Ops.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP))
      else:
        idx, treg, off = lang.addr_w_offset(args)
        lang.ins.append(AssemblyInstruction(Ops.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))

  if DEBUG >= 4:
    for tins in lang.ins: print(tins)
  return global_size, local_size