openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

144 lines
7.5 KiB

# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
from __future__ import annotations
from dataclasses import dataclass
import functools
from typing import Optional, Callable
from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.view import View, strides_for_shape, unravel
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context
from tinygrad.codegen.symbolic import sym, split_uop, symbolic_flat, uop_given_valid, simplify_valid
def overflow(u: UOp):
  # True when the value range [u.vmin, u.vmax] does not fit inside dtypes.int
  hi, lo = dtypes.max(dtypes.int), dtypes.min(dtypes.int)
  return u.vmax > hi or u.vmin < lo
# If a node overflows, its srcs need to be checked to see whether the overflow is the result of this node's own ALU
# operation, or whether the node simply inherits the dtype from its srcs. Upcast is either `Ops.CAST`+`replace` or just `replace`.
def upcast(u: UOp):
  """Return a copy of the graph rooted at `u` in which dtypes.int nodes are moved to int64 where needed.

  A dtypes.int node (vector or scalar) is rebuilt as int64 (with its srcs cast to int64):
    - kept as int64 if the node itself overflows dtypes.int,
    - cast back to its original dtype if only one of its original srcs overflows,
    - otherwise left in its original dtype with the (recursively processed) srcs.
  Sub-UOps shared by several parents are processed once per call (memoized), instead of once per path as a plain
  recursion would do.
  """
  cache: dict[UOp, UOp] = {}

  def _upcast(x: UOp) -> UOp:
    if x in cache: return cache[x]
    srcs = tuple(_upcast(s) for s in x.src)
    ret = None
    if x.dtype.scalar() is dtypes.int:
      dtype = dtypes.int64.vec(x.dtype.count) if x.dtype.count > 1 else dtypes.int64
      # NOTE(review): every src is cast to int64, including any non-int src; confirm int nodes here never carry one
      upcasted = x.replace(dtype=dtype, src=tuple([s.cast(dtype) for s in srcs]))
      if overflow(x): ret = upcasted
      # Check the original srcs: the new srcs carry an Ops.CAST whose vmin/vmax change the real bounds.
      # The cast back is required because if this node is in range, its siblings would never be upcasted.
      elif any(overflow(s) for s in x.src): ret = upcasted.cast(x.dtype)
    if ret is None: ret = x.replace(src=srcs)
    cache[x] = ret
    return ret

  return _upcast(u)
# pooling ops may overflow before folding, causing an unnecessary upcast, so fold first
def folded_upcast(u: UOp):
  """Simplify `u` with the `sym` rewrite rules, then run `upcast` on the result (match-stat tracking disabled)."""
  with Context(TRACK_MATCH_STATS=0):
    folded = graph_rewrite(u, sym, {})
    return upcast(folded)
@functools.lru_cache(None)
def views_to_indexed_uops(views: tuple[View, ...], _idxs:Optional[tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
  """Compose a stack of views into one (index, valid) pair of UOps.

  The outermost view, views[-1], is indexed first (optionally with the caller-supplied `_idxs`). Its flat index is then
  unraveled into the shape of each earlier (minified) view in turn, ending at views[0]. The result is memoized by
  lru_cache, so both arguments must be hashable.
  """
  idx, valid = views[-1].to_indexed_uops(_idxs)
  for inner in reversed(views[:-1]):
    inner = inner.minify()
    coords = [sint_to_uop(c) for c in unravel(inner.shape, idx)]
    idx, valid = inner.to_indexed_uops(coords, valid)
  return idx, valid
@functools.lru_cache(None)
def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[Optional[sint], ...]:
  """Return the stride of each axis of views[-1] in the flattened index, or None where no constant stride is found.

  - An axis whose RANGE appears as a bare term, or multiplied by a CONST, in the flattened index gets that stride.
  - An axis absent from the index gets 0.
  - Unless `ignore_valid`, any axis that appears in the validity expression is forced to None.
  """
  # NOTE: if a stride is not always valid, it will be None
  last = views[-1]
  if len(views) == 1 and last.mask is None: return last.strides
  strides: list[Optional[sint]] = [None] * len(last.shape)
  idx, valid = [graph_rewrite(u, symbolic_flat) for u in views_to_indexed_uops(views)]
  # TODO: always apply these in to_indexed_uops?
  simplified_valid = simplify_valid(valid)
  if simplified_valid is not None: valid = simplified_valid
  idx_given_valid = uop_given_valid(valid, idx)
  if idx_given_valid is not None: idx = graph_rewrite(idx_given_valid, symbolic_flat)
  for term in split_uop(idx, Ops.ADD):
    if term.op is Ops.RANGE:
      strides[term.arg] = 1
    elif term.op is Ops.MUL:
      lhs, rhs = term.src[0], term.src[1]
      if lhs.op is Ops.RANGE and rhs.op is Ops.CONST: strides[lhs.arg] = rhs.arg
      elif rhs.op is Ops.RANGE and lhs.op is Ops.CONST: strides[rhs.arg] = lhs.arg
  # axes that never show up in the index do not move through the buffer
  used_axes = {x.arg for x in idx.toposort if x.op is Ops.RANGE}
  strides = [s if axis in used_axes else 0 for axis, s in enumerate(strides)]
  if not ignore_valid:
    for x in valid.toposort:
      if x.op is Ops.RANGE: strides[x.arg] = None
  return tuple(strides)
@dataclass(frozen=True, order=True)
class ShapeTracker:
  """An immutable stack of Views describing how indices over `shape` map onto a buffer.

  views[-1] defines the visible shape. Index expressions are built starting from views[-1] and composed back through
  each earlier view down to views[0] (see `views_to_indexed_uops`). Movement ops return new ShapeTrackers.
  """
  views: tuple[View, ...]

  def __add__(self, st:ShapeTracker) -> ShapeTracker:
    # stack st's views on top of ours, simplifying after each one (one view at a time = better simplification)
    combined = self
    for extra in st.views:
      combined = ShapeTracker(combined.views + (extra,)).simplify()
    return combined

  def invert(self, out_shape:tuple[sint, ...]) -> Optional[ShapeTracker]:
    """Invert each view from views[-1] down to views[0].

    Each view is inverted against the shape of the view beneath it; views[0] is inverted against `out_shape`.
    Returns None if any View.invert returns None.
    """
    rev_views = self.views[::-1]
    target_shapes = [x.shape for x in rev_views[1:]] + [out_shape]
    inverted_views: list[View] = []
    for v, s in zip(rev_views, target_shapes):
      inverted = v.invert(s)
      if inverted is None: return None
      inverted_views.append(inverted)
    return ShapeTracker(tuple(inverted_views)).reshape(out_shape)

  @staticmethod
  def from_shape(shape:tuple[sint, ...]) -> ShapeTracker:
    # a tracker holding a single freshly created view over `shape`
    return ShapeTracker((View.create(shape),))

  @property
  def contiguous(self) -> bool:
    if len(self.views) != 1: return False
    return self.views[0].contiguous

  @property
  def consecutive(self) -> bool:
    # single unmasked view whose strides are exactly the default strides for its shape
    if len(self.views) != 1: return False
    only = self.views[0]
    return only.mask is None and only.strides == strides_for_shape(only.shape)

  @property
  def shape(self) -> tuple[sint, ...]:
    return self.views[-1].shape

  @property
  def size(self) -> int:
    return self.views[-1].size()

  def reduce(self, axis:tuple[int, ...]) -> tuple[sint, ...]:
    # shape after reducing over `axis`: every listed axis collapses to 1
    return tuple(s if i not in axis else 1 for i, s in enumerate(self.shape))

  def to_uop(self) -> UOp:
    return UOp(Ops.VIEW, dtypes.void, (), self)

  def to_indexed_uops(self, _idxs:Optional[list[UOp]|tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
    """Return the (index, valid) UOps for this tracker, folded and upcast where they overflow dtypes.int."""
    # views_to_indexed_uops is lru-cached, so the idxs must be passed as a hashable tuple
    key = None if _idxs is None else tuple(_idxs)
    idx, valid = views_to_indexed_uops(self.views, key)
    return folded_upcast(idx), folded_upcast(valid)

  def real_size(self) -> int:
    """Upper bound on the buffer size required to fit this shapetracker (from views[0] only)."""
    if 0 in self.shape: return 0
    first = self.views[0]
    view = first.shrink(first.mask) if first.mask else first
    idx, _ = views_to_indexed_uops((view,))
    assert idx.vmax < 1e12, f"real_size broken for {self}"
    return int(idx.vmax + 1)

  def vars(self) -> set[Variable]:
    return set().union(*(v.vars() for v in self.views))

  @property
  def var_vals(self) -> dict[Variable, int]:
    # merge the single-entry {Variable: value} mapping produced by unbinding each variable
    pairs = [v.unbind() for v in self.vars()]
    return merge_dicts([{var: val} for var, val in pairs])

  def unbind(self) -> tuple[ShapeTracker, dict[Variable, int]]:
    """Unbind every view; returns self unchanged (and {}) when no view had bound variables."""
    unbound_views, var_vals = zip(*(v.unbind() for v in self.views))
    if not any(len(d) for d in var_vals): return self, {}
    return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)

  def real_strides(self, ignore_valid=False) -> tuple[Optional[sint], ...]:
    with Context(TRACK_MATCH_STATS=0):
      return views_to_real_strides(self.views, ignore_valid)

  def unit_stride_axes(self, ignore_valid=False) -> list[int]:
    strides = self.real_strides(ignore_valid)
    return [axis for axis, st in enumerate(strides) if st == 1]

  def axis_is_masked(self, axis:int) -> bool:
    # an axis is masked when its RANGE appears in the (simplified) validity expression
    with Context(TRACK_MATCH_STATS=0):
      _, valid = self.to_indexed_uops()
      masked_axes = [x.arg for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE]
      return axis in masked_axes

  def simplify(self) -> ShapeTracker:
    # repeatedly merge the two outermost views until they can no longer be combined
    st = self
    while len(st.views) >= 2 and (merged := st.views[-2] + st.views[-1]) is not None:
      st = ShapeTracker(st.views[:-2] + (merged,))
    return st

  # *** under this line are the movement ops ***

  def _with_last_view(self, view:View) -> ShapeTracker:
    # copy of this tracker with the outermost view replaced by `view`
    return ShapeTracker(self.views[:-1] + (view,))

  def pad(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker:
    return self._with_last_view(self.views[-1].pad(arg))

  def shrink(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker:
    return self._with_last_view(self.views[-1].shrink(arg))

  def expand(self, new_shape: tuple[sint, ...]) -> ShapeTracker:
    return self._with_last_view(self.views[-1].expand(new_shape))

  def permute(self, axis: tuple[int, ...]) -> ShapeTracker:
    return self._with_last_view(self.views[-1].permute(axis))

  def flip(self, mul: tuple[int, ...]) -> ShapeTracker:
    return self._with_last_view(self.views[-1].flip(mul))

  def reshape(self, new_shape: tuple[sint, ...]) -> ShapeTracker:
    # try to fold the reshape into the outermost view (unless MERGE_VIEW=0); otherwise push a fresh view
    if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None:
      return self._with_last_view(new_view)
    return ShapeTracker(self.views + (View.create(new_shape), ))

  def mop(self, op, arg):
    # dispatch a movement Op through the module-level `mops` table
    return mops[op](self, arg)
# dispatch table from each movement Op to the ShapeTracker method implementing it (used by ShapeTracker.mop)
mops: dict[Ops, Callable] = {
  Ops.RESHAPE: ShapeTracker.reshape,
  Ops.PERMUTE: ShapeTracker.permute,
  Ops.EXPAND: ShapeTracker.expand,
  Ops.SHRINK: ShapeTracker.shrink,
  Ops.FLIP: ShapeTracker.flip,
  Ops.PAD: ShapeTracker.pad,
}