You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
147 lines
8.4 KiB
147 lines
8.4 KiB
from typing import Final, Dict, Callable, Any, List, Optional, Tuple
|
|
from llvmlite import ir # type: ignore
|
|
from tinygrad.codegen.linearizer import UOps, UOp
|
|
from tinygrad.helpers import dtypes
|
|
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
|
|
|
|
LLVM_FAST_MATH_FLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
|
|
|
|
code_for_op: Final[Dict[Op, Callable]] = {
|
|
UnaryOps.NEG: lambda builder,x: builder.neg(x) if isinstance(x.type, ir.IntType) else builder.fneg(x, flags=LLVM_FAST_MATH_FLAGS),
|
|
UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
|
|
UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
|
|
UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
|
|
UnaryOps.SQRT: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sqrt', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
|
|
BinaryOps.ADD: lambda builder,x,y: builder.add(x,y) if isinstance(x.type, ir.IntType) else builder.fadd(x,y, flags=LLVM_FAST_MATH_FLAGS),
|
|
BinaryOps.SUB: lambda builder,x,y: builder.sub(x,y) if isinstance(x.type, ir.IntType) else builder.fsub(x,y, flags=LLVM_FAST_MATH_FLAGS),
|
|
BinaryOps.MUL: lambda builder,x,y: builder.mul(x,y) if isinstance(x.type, ir.IntType) else builder.fmul(x,y, flags=LLVM_FAST_MATH_FLAGS),
|
|
BinaryOps.DIV: lambda builder,x,y: builder.sdiv(x,y) if isinstance(x.type, ir.IntType) else builder.fdiv(x,y, flags=LLVM_FAST_MATH_FLAGS),
|
|
# TODO: this should be casted
|
|
BinaryOps.CMPLT: lambda builder,x,y: builder.zext(builder.icmp_signed("<", x, y),ir.IntType(32)) if isinstance(x.type, ir.IntType) else builder.uitofp(builder.fcmp_ordered("<", x, y, flags=LLVM_FAST_MATH_FLAGS), ir.FloatType()),
|
|
BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=LLVM_FAST_MATH_FLAGS), x, y, flags=LLVM_FAST_MATH_FLAGS),
|
|
BinaryOps.MOD: lambda builder,x,y: builder.srem(x,y),
|
|
TernaryOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(x,y, flags=LLVM_FAST_MATH_FLAGS), z, flags=LLVM_FAST_MATH_FLAGS),
|
|
TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=LLVM_FAST_MATH_FLAGS) if isinstance(x.type, ir.FloatType) else builder.trunc(x, ir.IntType(1)), y, z, flags=LLVM_FAST_MATH_FLAGS),
|
|
}
|
|
|
|
dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32), dtypes._arg_int32: ir.IntType(32), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16), dtypes.uint32:ir.IntType(32), dtypes.uint64:ir.IntType(64)}
|
|
|
|
def cast(bb, val, input_type, output_type):
|
|
if input_type == output_type: return val
|
|
|
|
if output_type == dtypes.float32:
|
|
if dtypes.is_int(input_type) or input_type == dtypes.bool:
|
|
val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(input_type) or input_type == dtypes.bool else bb[-1].sitofp(val, ir.FloatType())
|
|
elif input_type == dtypes.bfloat16:
|
|
val = bb[-1].sext(val, ir.IntType(32))
|
|
val = bb[-1].shl(val, ir.Constant(ir.IntType(32), 16))
|
|
val = bb[-1].bitcast(val, ir.FloatType())
|
|
elif input_type == dtypes.float64:
|
|
val = bb[-1].fptrunc(val, ir.FloatType())
|
|
else:
|
|
val = bb[-1].fpext(val, ir.FloatType())
|
|
return val
|
|
|
|
if input_type == dtypes.float32:
|
|
if dtypes.is_int(output_type) or output_type == dtypes.bool:
|
|
if dtypes.is_unsigned(output_type): val = bb[-1].fptoui(val, dtype_to_llvm_dtype[output_type])
|
|
elif output_type == dtypes.bool: val = bb[-1].fcmp_ordered("!=", val, ir.Constant(ir.FloatType(), 0), flags=LLVM_FAST_MATH_FLAGS)
|
|
else: val = bb[-1].fptosi(val, dtype_to_llvm_dtype[output_type])
|
|
elif output_type == dtypes.bfloat16:
|
|
val = bb[-1].bitcast(val, ir.IntType(32))
|
|
val = bb[-1].lshr(val, ir.Constant(ir.IntType(32), 16))
|
|
val = bb[-1].trunc(val, ir.IntType(16))
|
|
elif output_type == dtypes.float64:
|
|
val = bb[-1].fpext(val, ir.DoubleType())
|
|
else:
|
|
val = bb[-1].fptrunc(val, dtype_to_llvm_dtype[output_type])
|
|
return val
|
|
|
|
raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
|
|
|
|
def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
|
|
# all llvm stuff goes into a module
|
|
module = ir.Module(name=__file__)
|
|
|
|
# extract global buffers
|
|
buf_to_dtype = {args[0]:args[1] for uop,_,_,args,_ in uops if uop == UOps.DEFINE_GLOBAL}
|
|
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
|
|
|
|
# create llvm function
|
|
func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values()]
|
|
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if dt!=dtypes._arg_int32 else x for x,dt in func_dtypes]), name=function_name)
|
|
for a in func.args:
|
|
if a.type.is_pointer: a.add_attribute("noalias")
|
|
|
|
# add the function attribute "no-nans-fp-math"="true", which informs llvm that it allowed to use vectorization optimizations
|
|
func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
|
|
func.attributes.add('"no-nans-fp-math"="true"')
|
|
|
|
bb = [ir.IRBuilder(func.append_basic_block("entry"))]
|
|
loop_blocks: List = []
|
|
reduce_phis: List = []
|
|
# TODO: newvar probably shouldn't be optional
|
|
lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type
|
|
|
|
for bufname,dtype in buf_to_dtype.items():
|
|
if dtype == dtypes._arg_int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
|
|
|
|
for u in uops:
|
|
uop,dtype,vin,args,_ = u
|
|
if uop == UOps.LOOP:
|
|
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
|
|
bb[-2].branch(bb[-1]._block)
|
|
|
|
phis = []
|
|
for rp in reduce_phis:
|
|
incoming = lvars[rp]
|
|
lvars[rp] = bb[-1].phi(ir.FloatType())
|
|
lvars[rp].add_incoming(incoming, bb[-2]._block)
|
|
phis.append((rp, lvars[rp]))
|
|
|
|
lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
|
|
lvars[u].add_incoming(lvars[vin[0]], bb[-2]._block)
|
|
loop_blocks.append((bb[-1], phis))
|
|
if uop == UOps.END:
|
|
block, phis = loop_blocks.pop()
|
|
idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1))
|
|
lvars[vin[0]].add_incoming(idx_p1, bb[-1]._block)
|
|
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block)
|
|
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
|
|
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), block._block, bb[-1]._block)
|
|
if uop == UOps.DEFINE_GLOBAL:
|
|
lvars[u] = func.args[buf_index[args[0]]]
|
|
if uop == UOps.DEFINE_ACC:
|
|
lvars[u] = ir.Constant(dtype_to_llvm_dtype[dtype], args)
|
|
reduce_phis.append(u)
|
|
if uop == UOps.SPECIAL:
|
|
lvars[u] = lvars[args.expr]
|
|
if uop == UOps.CONST:
|
|
value = int(args) if dtypes.is_int(dtype) else bool(args) if dtype == dtypes.bool else args
|
|
lvars[u] = ir.Constant(dtype_to_llvm_dtype[dtype], value)
|
|
if uop == UOps.LOAD:
|
|
assert dtype is not None
|
|
if len(vin) > 2:
|
|
gate = bb[-1].trunc(lvars[vin[2]], ir.IntType(1))
|
|
aug_idx = bb[-1].select(gate, lvars[vin[1]], ir.Constant(ir.IntType(32), 0))
|
|
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True))
|
|
val = cast(bb, val, vin[0].dtype, dtype)
|
|
val = bb[-1].select(gate, val, lvars[vin[3]])
|
|
else:
|
|
val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
|
|
val = cast(bb, val, vin[0].dtype, dtype)
|
|
lvars[u] = val
|
|
if uop == UOps.PHI:
|
|
lvars[u] = lvars[vin[1]]
|
|
# PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
|
|
backward = vin[0]
|
|
while backward.uop == UOps.PHI: backward = backward.vin[0]
|
|
lvars[backward] = lvars[u]
|
|
if uop == UOps.STORE:
|
|
element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
|
|
bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
|
|
if uop == UOps.ALU:
|
|
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])
|
|
|
|
bb[-1].ret_void()
|
|
return str(module), {}
|
|
|