openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

148 lines
8.4 KiB

from typing import Final, Dict, Callable, Any, List, Optional, Tuple
from llvmlite import ir # type: ignore
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.helpers import dtypes
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
LLVM_FAST_MATH_FLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
# Emitters for each ALU op. Every entry takes an llvmlite IRBuilder plus the operand values and returns the result value.
# Operands whose LLVM type is an integer get the (signed) integer instruction. All other operands are treated as floating
# point and get LLVM_FAST_MATH_FLAGS. Fast-math flags are never attached to integer-typed results, because LLVM rejects them.
code_for_op: Final[Dict[Op, Callable]] = {
  UnaryOps.NEG: lambda builder,x: builder.neg(x) if isinstance(x.type, ir.IntType) else builder.fneg(x, flags=LLVM_FAST_MATH_FLAGS),
  # the math intrinsics are declared for the operand's own float type (f32 for float32 inputs, as before)
  UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
  UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
  UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
  UnaryOps.SQRT: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
  BinaryOps.ADD: lambda builder,x,y: builder.add(x,y) if isinstance(x.type, ir.IntType) else builder.fadd(x,y, flags=LLVM_FAST_MATH_FLAGS),
  BinaryOps.SUB: lambda builder,x,y: builder.sub(x,y) if isinstance(x.type, ir.IntType) else builder.fsub(x,y, flags=LLVM_FAST_MATH_FLAGS),
  BinaryOps.MUL: lambda builder,x,y: builder.mul(x,y) if isinstance(x.type, ir.IntType) else builder.fmul(x,y, flags=LLVM_FAST_MATH_FLAGS),
  BinaryOps.DIV: lambda builder,x,y: builder.sdiv(x,y) if isinstance(x.type, ir.IntType) else builder.fdiv(x,y, flags=LLVM_FAST_MATH_FLAGS),
  # TODO: this should be casted
  # the i1 comparison result is widened to i32 0/1 for int operands and to float 0.0/1.0 otherwise
  BinaryOps.CMPLT: lambda builder,x,y: builder.zext(builder.icmp_signed("<", x, y),ir.IntType(32)) if isinstance(x.type, ir.IntType) else builder.uitofp(builder.fcmp_ordered("<", x, y, flags=LLVM_FAST_MATH_FLAGS), ir.FloatType()),
  # ints use a signed compare (consistent with sdiv/srem above); floats keep the unordered fast-math compare
  BinaryOps.MAX: lambda builder,x,y: builder.select(builder.icmp_signed(">", x, y), x, y) if isinstance(x.type, ir.IntType) else builder.select(builder.fcmp_unordered(">", x, y, flags=LLVM_FAST_MATH_FLAGS), x, y, flags=LLVM_FAST_MATH_FLAGS),
  BinaryOps.MOD: lambda builder,x,y: builder.srem(x,y),
  TernaryOps.MULACC: lambda builder,x,y,z: builder.add(builder.mul(x,y), z) if isinstance(x.type, ir.IntType) else builder.fadd(builder.fmul(x,y, flags=LLVM_FAST_MATH_FLAGS), z, flags=LLVM_FAST_MATH_FLAGS),
  # WHERE: a float condition means "x != 0". An int condition uses its lowest bit; an i1 is used as-is because
  # `trunc i1 to i1` is invalid IR. Fast-math flags go on the select only when it yields a floating-point value.
  TernaryOps.WHERE: lambda builder,x,y,z: builder.select(
    (x if x.type.width == 1 else builder.trunc(x, ir.IntType(1))) if isinstance(x.type, ir.IntType)
      else builder.fcmp_unordered("!=", x, ir.Constant(x.type, 0), flags=LLVM_FAST_MATH_FLAGS),
    y, z, flags=() if isinstance(y.type, ir.IntType) else LLVM_FAST_MATH_FLAGS),
}
# LLVM storage type for each tinygrad dtype. LLVM integer types carry no signedness, so signed and unsigned dtypes of
# the same width map to the same IntType. bfloat16 has no native type here and is held as a raw 16-bit integer.
dtype_to_llvm_dtype = {
  dtypes.float64: ir.DoubleType(),
  dtypes.float16: ir.HalfType(),
  dtypes.bfloat16: ir.IntType(16),
  dtypes.float32: ir.FloatType(),
  dtypes.int8: ir.IntType(8),
  dtypes.uint8: ir.IntType(8),
  dtypes.bool: ir.IntType(1),
  dtypes.int64: ir.IntType(64),
  dtypes.int32: ir.IntType(32),
  dtypes._arg_int32: ir.IntType(32),
  dtypes.int16: ir.IntType(16),
  dtypes.uint16: ir.IntType(16),
  dtypes.uint32: ir.IntType(32),
  dtypes.uint64: ir.IntType(64),
}
def cast(bb, val, input_type, output_type):
  """Emit IR with the builder bb[-1] that converts `val` from tinygrad dtype `input_type` to `output_type`.

  Supported conversions:
    * any handled dtype -> float32 and float32 -> any handled dtype;
    * integer/bool dtypes -> integer/bool dtypes (trunc / sext / zext / compare with 0).
  bfloat16 values are held as 16-bit integers. They are widened by placing their bits in the upper half of a float32,
  and narrowed by keeping only the upper 16 bits of a float32.

  Returns the converted llvm value, or `val` itself when no conversion is needed.

  Raises:
    NotImplementedError: for any other pair of dtypes.
  """
  if input_type == output_type: return val
  if output_type == dtypes.float32:
    if dtypes.is_int(input_type) or input_type == dtypes.bool:
      val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(input_type) or input_type == dtypes.bool else bb[-1].sitofp(val, ir.FloatType())
    elif input_type == dtypes.bfloat16:
      # bf16 bits become the high half of an f32 (the sign-extended high bits are shifted out by shl)
      val = bb[-1].sext(val, ir.IntType(32))
      val = bb[-1].shl(val, ir.Constant(ir.IntType(32), 16))
      val = bb[-1].bitcast(val, ir.FloatType())
    elif input_type == dtypes.float64:
      val = bb[-1].fptrunc(val, ir.FloatType())
    else:
      val = bb[-1].fpext(val, ir.FloatType())
    return val
  if input_type == dtypes.float32:
    if dtypes.is_int(output_type) or output_type == dtypes.bool:
      if dtypes.is_unsigned(output_type): val = bb[-1].fptoui(val, dtype_to_llvm_dtype[output_type])
      elif output_type == dtypes.bool: val = bb[-1].fcmp_ordered("!=", val, ir.Constant(ir.FloatType(), 0), flags=LLVM_FAST_MATH_FLAGS)
      else: val = bb[-1].fptosi(val, dtype_to_llvm_dtype[output_type])
    elif output_type == dtypes.bfloat16:
      # keep the high 16 bits of the f32 (truncation, no rounding)
      val = bb[-1].bitcast(val, ir.IntType(32))
      val = bb[-1].lshr(val, ir.Constant(ir.IntType(32), 16))
      val = bb[-1].trunc(val, ir.IntType(16))
    elif output_type == dtypes.float64:
      val = bb[-1].fpext(val, ir.DoubleType())
    else:
      val = bb[-1].fptrunc(val, dtype_to_llvm_dtype[output_type])
    return val
  # integer/bool <-> integer/bool: LLVM ints are signless (int8 and uint8 share i8), so only the width has to change
  if (dtypes.is_int(input_type) or input_type == dtypes.bool) and (dtypes.is_int(output_type) or output_type == dtypes.bool):
    if output_type == dtypes.bool: return bb[-1].icmp_unsigned("!=", val, ir.Constant(val.type, 0))
    llvm_out = dtype_to_llvm_dtype[output_type]
    in_width, out_width = val.type.width, llvm_out.width
    if out_width < in_width: return bb[-1].trunc(val, llvm_out)
    if out_width > in_width:
      # bool and unsigned sources are zero-extended; signed sources keep their sign
      return bb[-1].zext(val, llvm_out) if dtypes.is_unsigned(input_type) or input_type == dtypes.bool else bb[-1].sext(val, llvm_out)
    return val  # same width (e.g. int32 <-> uint32): the bits are reinterpreted unchanged
  raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
  """Lower a linearized list of UOps into an LLVM IR module that contains a single void function `function_name`.

  Each DEFINE_GLOBAL becomes a function argument, in first-seen order. Buffers become noalias pointers; scalar
  dtypes._arg_int32 arguments are passed as plain i32 values. Each LOOP/END pair becomes a loop_body_N block with
  phi nodes for the loop index and for every accumulator (DEFINE_ACC) defined so far, followed by a loop_exit_N block.

  Returns the module as IR text, plus an empty dict.
  """
  # all llvm stuff goes into a module
  module = ir.Module(name=__file__)
  # extract global buffers; dict insertion order fixes the argument order
  buf_to_dtype = {args[0]:args[1] for uop,_,_,args,_ in uops if uop == UOps.DEFINE_GLOBAL}
  buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
  # create llvm function: buffers are passed by pointer, scalar int arguments by value
  func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values()]
  func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if dt!=dtypes._arg_int32 else x for x,dt in func_dtypes]), name=function_name)
  for a in func.args:
    if a.type.is_pointer: a.add_attribute("noalias")
  # add the function attribute "no-nans-fp-math"="true", which informs llvm that it is allowed to use vectorization optimizations
  # _known is extended first so that llvmlite's attribute-name check accepts this custom string attribute
  func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
  func.attributes.add('"no-nans-fp-math"="true"')
  bb = [ir.IRBuilder(func.append_basic_block("entry"))]
  loop_blocks: List = []  # stack of (loop body builder, [(DEFINE_ACC uop, phi)]) for the currently open loops
  reduce_phis: List = []  # every DEFINE_ACC seen so far; each one gets a phi in every loop opened afterwards
  # TODO: newvar probably shouldn't be optional
  lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type
  # scalar int arguments are already i32 values, so they are used directly (LLVM rejects an i32 -> i32 sext)
  for bufname,dtype in buf_to_dtype.items():
    if dtype == dtypes._arg_int32: lvars[bufname] = func.args[buf_index[bufname]]
  for u in uops:
    uop,dtype,vin,args,_ = u
    if uop == UOps.LOOP:
      # open a new body block; its phis merge the value coming from the preheader with the back-edge value added at END
      bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
      bb[-2].branch(bb[-1]._block)
      phis = []
      for rp in reduce_phis:
        incoming = lvars[rp]
        # the phi takes the accumulator's own llvm type (float for float32 accumulators)
        lvars[rp] = bb[-1].phi(incoming.type)
        lvars[rp].add_incoming(incoming, bb[-2]._block)
        phis.append((rp, lvars[rp]))
      lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
      lvars[u].add_incoming(lvars[vin[0]], bb[-2]._block)
      loop_blocks.append((bb[-1], phis))
    elif uop == UOps.END:
      block, phis = loop_blocks.pop()
      idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1))
      lvars[vin[0]].add_incoming(idx_p1, bb[-1]._block)
      for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block)
      bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
      # jump back to the body while idx+1 < the LOOP's bound vin[0].vin[1] (unsigned compare), else fall into the exit block
      bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), block._block, bb[-1]._block)
    elif uop == UOps.DEFINE_GLOBAL:
      lvars[u] = func.args[buf_index[args[0]]]
    elif uop == UOps.DEFINE_ACC:
      lvars[u] = ir.Constant(dtype_to_llvm_dtype[dtype], args)
      reduce_phis.append(u)
    elif uop == UOps.SPECIAL:
      # looked up by variable name among the scalar int arguments registered above
      lvars[u] = lvars[args.expr]
    elif uop == UOps.CONST:
      value = int(args) if dtypes.is_int(dtype) else bool(args) if dtype == dtypes.bool else args
      lvars[u] = ir.Constant(dtype_to_llvm_dtype[dtype], value)
    elif uop == UOps.LOAD:
      assert dtype is not None
      if len(vin) > 2:
        # gated load: read index 0 when the gate is off (so the access stays in bounds), then substitute the alt value vin[3]
        gate = bb[-1].trunc(lvars[vin[2]], ir.IntType(1))
        aug_idx = bb[-1].select(gate, lvars[vin[1]], ir.Constant(ir.IntType(32), 0))
        val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True))
        val = cast(bb, val, vin[0].dtype, dtype)
        val = bb[-1].select(gate, val, lvars[vin[3]])
      else:
        val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
        val = cast(bb, val, vin[0].dtype, dtype)
      lvars[u] = val
    elif uop == UOps.PHI:
      lvars[u] = lvars[vin[1]]
      # PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
      backward = vin[0]
      while backward.uop == UOps.PHI: backward = backward.vin[0]
      lvars[backward] = lvars[u]
    elif uop == UOps.STORE:
      element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
      bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True))
    elif uop == UOps.ALU:
      lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])
  bb[-1].ret_void()
  return str(module), {}