openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

249 lines
14 KiB

from __future__ import annotations
from typing import Optional, Tuple, Union, List, Dict, Any, ClassVar, Type
import os, sys, weakref, importlib, inspect, functools
from weakref import WeakValueDictionary
from tinygrad.helpers import prod, getenv
from tinygrad.shape import ShapeTracker, get_contraction
from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, get_buffers, get_lazyops, map_buffers
from tinygrad.graph import log_op
# lazy can recurse a lot
sys.setrecursionlimit(10000)

# OPT gates the graph-rewrite optimizations below (0 = none, higher = more aggressive)
OPT = getenv("OPT", 2)
# LAZY=0 realizes every LazyBuffer immediately on construction (useful for debugging)
LAZY = getenv("LAZY", 1)
def get_buffer(name, base='tinygrad.runtime'):
  """Import `{base}.ops_{name}` and return its `<Name>Buffer` class, or None if the backend is unavailable."""
  try:
    runtime_module = importlib.import_module(f'{base}.ops_{name}')
    candidates = [cls for cname, cls in inspect.getmembers(runtime_module, inspect.isclass) if cname.lower() == name + "buffer"]
    return candidates[0]
  except Exception as e: # NOTE: this can't be put on one line due to mypy issue
    # best effort: a missing backend is reported, not fatal (caller gets None)
    print(name, "backend not available", e, file=sys.stderr)
class _Device:
  """Discovers available runtime backends (files named ops_*.py) and picks a default device."""
  def __init__(self) -> None:
    # scan the runtime directory next to this file for ops_<backend>.py modules
    runtime_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "runtime")
    backend_names = [os.path.splitext(fn)[0][len("ops_"):] for fn in sorted(os.listdir(runtime_dir)) if fn.startswith("ops_")]
    self._buffers : Dict[str, Type[DeviceBuffer]] = {bn.upper():get_buffer(bn) for bn in backend_names if bn is not None}
    self.DEFAULT : str = "CPU"
    for name in self._buffers:
      if getenv(name) == 1: self.DEFAULT = name # note: DEFAULT can be a Device that can't be imported. better than silent use of a different device
      # expose importable backends as attributes, e.g. Device.GPU == "GPU"
      if self._buffers[name] is not None: setattr(self, name, name)
# module-level singleton holding the discovered backends and default device
Device = _Device()
# TODO: movement ops that only change shape are really nops. treat them as such
# OPT>=1: basic graph rewrites
REMOVE_MOVEMENT_NOPS, MERGE_UNARY_OPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1
# OPT>=2: fuse elementwise chains and a single reduce into one kernel
MERGE_ELEMENTWISE_OPS, MERGE_ONE_REDUCE_INTO_ELEMENTWISE = OPT>=2, OPT>=2
# OPT>=3: more aggressive reorderings of permutes/contiguous
PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT>=3, OPT>=3
# **** realize functions ****
def _ast_reduceops(self:LazyBuffer) -> LazyOp:
  """Build the AST for a ReduceOps buffer, folding a single-use unrealized elementwise source into it."""
  # TODO: this can also corealize a binary op after the reduce, not just before
  inp = self.op.src[0]
  mergeable = MERGE_ELEMENTWISE_INTO_REDUCE and inp.realized is None and inp.optype == BinaryOps and len(inp.children) <= 1
  if mergeable:
    return LazyOp(self.op.op, (inp.op,), self.op.arg)
  return LazyOp(self.op.op, (inp,), self.op.arg)
# this supports late merging an upstream Reduce op and even an Elementwise op above that
def _ast_binaryops(self:LazyBuffer) -> LazyOp:
  """Build the AST for an elementwise (BinaryOps) buffer.

  May fold one upstream unrealized Reduce (and whatever _ast_reduceops folds
  into that) into this AST, so the whole thing runs as one kernel.
  """
  # map each input buffer to its replacement in the final AST (None = not decided yet)
  real_srcs : Dict[LazyBuffer, Union[None, LazyOp, LazyBuffer]] = {x:None for x in get_buffers(self.op)}
  # candidate (input, movement-root) pairs whose root is an unrealized, single-use reduce of the same size
  # NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable but requires more thought how
  psrcs : List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype == ReduceOps and x.realized is None and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
  intermediate_shape : Tuple[int, ...] = self.shape
  if len(psrcs) == 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE:
    # NOTE(review): the psrcs filter already guarantees optype == ReduceOps, so this inner if always fires
    if psrcs[0][1].optype == ReduceOps:
      top = _ast_reduceops(psrcs[0][1])
    real_srcs[psrcs[0][0]] = top
    real_srcs.update({x:x for x in get_buffers(top)}) # the reduce op buffers are not modified
    # if the ReduceOp is followed by a reshape, we push this reshape before all the ElementwiseOp inputs
    if psrcs[0][0].shape != psrcs[0][1].shape:
      intermediate_shape = psrcs[0][1].shape
      assert psrcs[0][0].shape == self.shape, f"shape mismatch {psrcs[0][0].shape} != {self.shape}"
  # reshape all the late ops into the output shape
  # NOTE: these RESHAPEs will return self if they don't change the shape
  for x in real_srcs.keys():
    if real_srcs[x] is None: real_srcs[x] = x.movement_op(MovementOps.RESHAPE, intermediate_shape)
  ast = map_buffers(real_srcs, self.op)
  # if we computed at the reduce's shape, reshape the result back to this buffer's shape
  return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape) if intermediate_shape != self.shape else ast
# **** lazy operations ****
def get_weakop(op:LazyOp) -> LazyOp:
  """Copy a LazyOp tree, replacing every LazyBuffer leaf with a weakref (used as a cache key)."""
  weak_srcs = []
  for s in op.src:
    weak_srcs.append(get_weakop(s) if isinstance(s, LazyOp) else weakref.ref(s))
  return LazyOp(op.op, tuple(weak_srcs), op.arg)
def get_single_root(root:LazyBuffer) -> LazyBuffer:
  """Walk up through ops with exactly one source and return the deepest ancestor."""
  while getattr(root, 'op', None) and len(root.op.src) == 1:
    root = root.op.src[0]
  return root
def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer:
  """Follow chains of unrealized movement ops (and, optionally, CONTIGUOUS loads of contiguous views) to their source buffer."""
  can_descend = root.realized is None and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous))
  if can_descend:
    return get_movementroot(root.op.src[0], allow_contiguous)
  return root
def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer:
  """Like get_movementroot, but skips unrealized CONTIGUOUS wrappers first and only descends through contiguous movement chains."""
  if x.realized is None and x.op.op == LoadOps.CONTIGUOUS:
    return get_movementroot_contiguous(x.op.src[0])
  if x.optype == MovementOps and x.st.contiguous:
    return get_movementroot(x, True)
  return x
def replace_with_movement_op(y:Union[LazyOp, LazyBuffer], op:MovementOps, arg:Tuple[Any, ...]) -> LazyBuffer:
  """Push a movement op through an elementwise LazyOp tree, applying it at every buffer leaf instead."""
  if isinstance(y, LazyBuffer):
    return y.movement_op(op, arg)
  # only elementwise ops can commute with movement ops
  assert y.op in BinaryOps or y.op in UnaryOps
  moved_srcs = [replace_with_movement_op(child, op, arg) for child in y.src]
  return elementwise_op(y.op, *moved_srcs) # type: ignore
def support_weakref(x):
  """Identity decorator; its presence keeps mypyc from compiling the class natively, so weakrefs to it work."""
  return x
@support_weakref # needed for mypyc, this prevents LazyBuffer from becoming a native class
class LazyBuffer:
  """A lazily-evaluated buffer.

  Construction only records a LazyOp graph; nothing runs until realize(),
  which lowers the graph to a concrete DeviceBuffer on self.device.
  """
  # 'op' is deleted after realization; __deletable__ makes that legal under mypyc
  __deletable__ = ('op',)
  # global dedup cache keyed on (device, optype, weakref'd op tree); weak values so dead buffers can be collected
  lazycache : ClassVar[WeakValueDictionary[Tuple[str, OpType, LazyOp], LazyBuffer]] = WeakValueDictionary()
  def __new__(cls, device:str, shape:Union[ShapeTracker, Tuple[int, ...]], optype:OpType, op:LazyOp):
    # fromcpu aren't cached
    if optype == LoadOps and op.op == LoadOps.FROMCPU:
      return super().__new__(cls)
    wop = (device, optype, get_weakop(op)) # NOTE: shape should be deterministic. annoying to cache with the ShapeTracker
    # NOTE: we need "ret" to prevent the new buffer from being immediately deleted
    if wop not in LazyBuffer.lazycache: LazyBuffer.lazycache[wop] = ret = super().__new__(cls)
    else: ret = LazyBuffer.lazycache[wop]
    return ret
  def __init__(self, device:str, shape:Union[ShapeTracker, Tuple[int, ...]], optype:OpType, op:LazyOp):
    # __init__ also runs on a cache hit from __new__; 'device' only exists after a first init
    if hasattr(self, 'device'):
      return # cache hit, we return and don't reinit
    self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
    self.shape, self.optype, self.op = self.st.shape, optype, op
    # realized: the DeviceBuffer once realize() has run, else None
    self.realized : Optional[DeviceBuffer] = None
    # output_buffer: optional preallocated destination handed to exec_ast
    self.output_buffer : Optional[DeviceBuffer] = None
    self.device, self.dbuffer = device, Device._buffers[device]
    # TODO: does children have to be a ref count instead of a set? can a Buffer be a double child?
    self.children : weakref.WeakSet[LazyBuffer] = weakref.WeakSet()
    # NOTE: op should be read only after construction of LazyBuffer
    for x in get_buffers(op): x.children.add(self)
    if not LAZY: self.realize()
  def __repr__(self): return f"<LB {self.shape} op:{self.op.op if self.realized is None else 'realized'}>"
  # this produces a device buffer
  def realize(self:LazyBuffer, required_device=None) -> DeviceBuffer:
    """Execute the recorded graph (if needed) and return the concrete DeviceBuffer."""
    assert required_device is None or required_device == self.device
    if self.realized is None:
      # get real ops first
      if self.op.op == LoadOps.FROMCPU:
        # arg holds the host array (copied at fromCPU time)
        self.realized = Device._buffers[self.device].fromCPU(self.op.arg)
        ast = LazyOp(self.op.op, tuple())
      elif self.op.op == LoadOps.CONTIGUOUS:
        real_src = self.op.src[0].realize(self.device)
        self.realized = real_src.contiguous()
        ast = LazyOp(self.op.op, (real_src, ))
      elif self.optype == MovementOps:
        src = self.op.src[0]
        # fuse RESHAPE and ReduceOps
        if src.realized is None and src.optype == ReduceOps and self.op.op == MovementOps.RESHAPE and len(src.children) <= 1:
          # it's okay to add a RESHAPE to the ast here
          ast = LazyOp(MovementOps.RESHAPE, (_ast_reduceops(src), ), self.op.arg)
        else:
          # movement ops aren't an AST, just run them
          real_src = src.realize(self.device)
          self.realized = real_src.movement_op(self.op.op, self.op.arg)
          ast = LazyOp(self.op.op, (real_src, ))
      elif self.optype == ReduceOps: ast = _ast_reduceops(self)
      elif self.optype == BinaryOps: ast = _ast_binaryops(self)
      # no need to keep the op after realization
      del self.op
      # run the ast if we still have to, and log the op
      if self.realized is None:
        # realize every input buffer in the AST, then hand the whole tree to the backend
        ast = map_buffers({x:x.realize(self.device) for x in get_buffers(ast)}, ast)
        self.realized = self.dbuffer.exec_ast(ast, output_buffer=self.output_buffer)
      log_op(self.realized, ast)
    assert self.realized.shape == self.shape, f"shape mismatch on realize {self.realized.shape} vs {self.shape}"
    assert isinstance(self.realized, Device._buffers[self.device])
    return self.realized
  # wrap a host array (presumably numpy — it has .shape and .copy) as a FROMCPU buffer
  @staticmethod
  def fromCPU(x, device) -> LazyBuffer: return LazyBuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x.copy()))
  # realize, then copy the result back to the host
  def toCPU(self): return self.realize().toCPU()
  def unary_op(self:LazyBuffer, op:UnaryOps) -> LazyBuffer: return elementwise_op(op, self)
  def binary_op(self:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y)
  # force a contiguous copy via a CONTIGUOUS load op
  def contiguous(self:LazyBuffer) -> LazyBuffer: return LazyBuffer(self.device, self.shape, LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,)))
  def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
    """Reduce self down to new_shape (same rank; reduced axes become 1 in new_shape)."""
    if self.shape == tuple(new_shape): return self
    reduce = list(enumerate(zip(self.shape, new_shape)))
    # move the reduce axes to the end
    x = self.movement_op(MovementOps.PERMUTE, tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n]))
    new_tmp_shape = tuple([n for _,(s,n) in reduce if s == n] + [n for _,(s,n) in reduce if s != n])
    # NOTE: this reshape can only move around 1s
    return LazyBuffer(x.device, new_tmp_shape, ReduceOps, LazyOp(op, (x,), new_tmp_shape)).movement_op(MovementOps.RESHAPE, new_shape)
  def movement_op(self:LazyBuffer, op:MovementOps, arg:Tuple[Any, ...]) -> LazyBuffer:
    """Apply a movement op, with several graph-rewrite shortcuts gated by the OPT flags."""
    # very instant nop
    if op == MovementOps.RESHAPE and self.shape == arg: return self
    # TODO: look into why that copy is needed
    local_st = ShapeTracker(self.shape).movement_op(op, arg)
    # instant nops
    if local_st.contiguous and self.shape == local_st.shape: return self
    # two ops in a row is one op. merge them if unresolved
    if self.realized is None and self.op.op == op:
      # TODO: why is deleting self from children needed? shouldn't GC do it?
      self.op.src[0].children.discard(self)
      if op in [MovementOps.RESHAPE, MovementOps.EXPAND, MovementOps.SHRINK]: return self.op.src[0].movement_op(op, arg)
      if op == MovementOps.PERMUTE: return self.op.src[0].movement_op(op, tuple(self.op.arg[i] for i in arg))
      if op == MovementOps.PAD: return self.op.src[0].movement_op(op, tuple((b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)))
      # two flips cancel on axes flipped by both; keep axes flipped by exactly one
      if op == MovementOps.FLIP: return self.op.src[0].movement_op(op, tuple(i for i in arg+self.op.arg if not (i in arg and i in self.op.arg)))
    # push permutes before reduce ops
    if op == MovementOps.PERMUTE and PUSH_PERMUTES and self.realized is None and self.optype == ReduceOps:
      # reduceops have one buffer input, permute it
      narg = tuple(self.op.arg[arg[i]] for i in range(len(arg)))
      src, rop = self.op.src[0], self.op.op
      src.children.discard(self)
      del self # TODO: why doesn't this delete remove it from the children
      return src.movement_op(op, arg).reduce_op(rop, narg)
    # some permutes are actually just reshapes
    if op == MovementOps.PERMUTE and local_st.contiguous: return self.movement_op(MovementOps.RESHAPE, tuple(self.shape[i] for i in arg))
    # move permutes before expands
    if op == MovementOps.PERMUTE and PUSH_PERMUTES and self.realized is None and self.op.op == MovementOps.EXPAND:
      self.op.src[0].children.discard(self)
      return self.op.src[0].movement_op(MovementOps.PERMUTE, arg).movement_op(MovementOps.EXPAND, tuple(self.op.arg[a] for a in arg))
    # move permutes before reshapes if we can
    if op == MovementOps.PERMUTE and PUSH_PERMUTES and self.realized is None and self.op.op == MovementOps.RESHAPE and isinstance(self.op.src[0], LazyBuffer):
      if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape):
        new_arg : List[int] = functools.reduce(lambda r, x: r + shape_idx_groups[x], arg, [])
        self.op.src[0].children.discard(self) # this changes nothing?
        return self.op.src[0].movement_op(MovementOps.PERMUTE, tuple(new_arg)) \
          .movement_op(MovementOps.RESHAPE, ShapeTracker(self.st).movement_op(op, arg).shape)
    # if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead. NOTE: UnaryOps is never an OpType
    if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and self.realized is None and len(self.children) == 0 and op != MovementOps.EXPAND and (op != MovementOps.PAD or all(x.op != BinaryOps.DIV for x in get_lazyops(self.op))):
      return replace_with_movement_op(self.op, op, arg)
    # create the buffer
    ret = LazyBuffer(self.device, ShapeTracker(self.st).movement_op(op, arg), MovementOps, LazyOp(op, (self,), arg))
    # if the ShapeTracker becomes contiguous, replace the whole thing with a reshape (or nothing if shapes match)
    # NOTE: if ret is in the cache, it can already be realized
    if REMOVE_MOVEMENT_NOPS and ret.realized is None and self.realized is None and ret.st.contiguous:
      # MovementOps aren't stacked any more, they each have one parent, find the root
      root = get_movementroot(self)
      if root.st.contiguous and root != self and prod(ret.st.shape) == prod(root.shape):
        return root.movement_op(MovementOps.RESHAPE, ret.st.shape)
    return ret
def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer) -> LazyBuffer:
  """Record an elementwise op over srcs, applying the PUSH_CONTIGUOUS and MERGE_* graph rewrites."""
  out_device, out_shape = srcs[0].device, srcs[0].shape
  # an input whose CONTIGUOUS wrapper is unrealized and whose source has at most one child can be unwrapped
  def _can_push(buf): return buf.realized is None and buf.op.op == LoadOps.CONTIGUOUS and len(buf.op.src[0].children) <= 1
  # push all contiguous to the end of BinaryOps. kernels 198 -> 196
  if PUSH_CONTIGUOUS and any(_can_push(buf) for buf in srcs):
    unwrapped = []
    for buf in srcs:
      if _can_push(buf):
        buf.op.src[0].children.discard(buf)
        unwrapped.append(buf.op.src[0])
      else:
        unwrapped.append(buf)
    return elementwise_op(op, *unwrapped).contiguous()
  if MERGE_ELEMENTWISE_OPS or (MERGE_UNARY_OPS and len(set(srcs)) == 1):
    # remove the buffers from any (childless) BinaryOps that feed into this
    srcs = tuple(buf.op if buf.optype == BinaryOps and len(buf.children) == 0 and buf.realized is None else buf for buf in srcs) # type: ignore
  return LazyBuffer(out_device, out_shape, BinaryOps, LazyOp(op, srcs))