# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
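# A ShapeTracker is an immutable stack of Views: every movement op (reshape, permute, expand, shrink, pad, flip)
# either rewrites the last View or appends a new one, so the underlying buffer is never copied.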
from __future__ import annotations
from dataclasses import dataclass
import functools
from typing import Optional, Callable
from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.view import View, strides_for_shape, unravel
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context, PatternMatcher, UPat, GroupOp
from tinygrad.codegen.symbolic import split_uop, symbolic_flat, uop_given_valid, simplify_valid
# If a node overflows, its srcs need to be checked to see if the overflow is the result of an ALU operation,
# or whether the node simply inherits the dtype from its srcs. Upcast is either `Ops.CAST`+`replace` or just `replace`.
def handle_upcast(u: UOp) -> UOp|None:
  dtype = dtypes.int64.vec(u.dtype.count) if u.dtype.count > 1 else dtypes.int64
  # check for overflow, upcast this to int64
  if u.vmax > dtypes.max(dtypes.int) or u.vmin < dtypes.min(dtypes.int):
    return u.replace(dtype=dtype, src=tuple([x.cast(dtype) for x in u.src]))
  # if any inputs are int64 and this *doesn't* overflow, cast back to int
  if any(x.dtype == dtypes.int64 for x in u.src):
    return u.replace(dtype=dtype, src=tuple([x.cast(dtype) for x in u.src])).cast(u.dtype)
  return None
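# pm_upcast matches int-typed ALU UOps and applies handle_upcast so index math that can overflow int32 is carried out in int64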
pm_upcast = PatternMatcher([(UPat(GroupOp.ALU, dtype=dtypes.int, name="u"), handle_upcast),])
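# Collapse a stack of Views into a single (index, valid) UOp pair: start from the last View (the current shape),
# thread its index expression back through each earlier View down to the buffer, then simplify symbolically.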
@functools.cache
def views_to_indexed_uops(views: tuple[View, ...], _idxs:Optional[tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
  idx, valid = views[-1].to_indexed_uops(_idxs)
  for view in reversed(views[0:-1]):
    view = view.minify()
    idx, valid = view.to_indexed_uops([sint_to_uop(i) for i in unravel(view.shape, idx)], valid)
  # symbolic
  idx, valid = graph_rewrite(UOp.sink(idx, valid), symbolic_flat, name="indexing sym @ 1").src
  # simplify
  if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
  if (newidx:=uop_given_valid(valid, idx)) is not None: idx = newidx
  # symbolic again, upcast if needed
  return graph_rewrite(UOp.sink(idx, valid), symbolic_flat+pm_upcast, name="indexing sym @ 2").src
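# Recover per-axis strides from the flattened index expression: a bare RANGE term is stride 1, RANGE*CONST is that
# constant, axes that don't appear are stride 0, and (unless ignore_valid) masked axes come back as None.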
@functools.cache
def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[Optional[sint], ...]:
  # NOTE: if a stride is not always valid, it will be None
  if len(views) == 1 and views[-1].mask is None: return views[-1].strides
  ret: list[Optional[sint]] = [None] * len(views[-1].shape)
  idx, valid = views_to_indexed_uops(views)
  for c in split_uop(idx, Ops.ADD):
    if c.op is Ops.RANGE: ret[c.arg] = 1
    if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg] = c.src[1].arg
    if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg] = c.src[0].arg
  used_ranges = [x.arg for x in idx.toposort() if x.op is Ops.RANGE]
  ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
  if not ignore_valid:
    for masked_axis in [x.arg for x in valid.toposort() if x.op is Ops.RANGE]: ret[masked_axis] = None
  return tuple(ret)
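# The ShapeTracker itself: a frozen dataclass wrapping the tuple of Views. Movement ops return new trackers,
# and simplify() merges adjacent Views where possible so most trackers stay at a single View.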
@dataclass(frozen=True, order=True)
class ShapeTracker:
  views: tuple[View, ...]
  def __add__(self, st:ShapeTracker) -> ShapeTracker:
    ret = self
    for v in st.views: ret = ShapeTracker(ret.views + (v,)).simplify() # one view at a time = better simplification
    return ret
  def invert(self, out_shape:tuple[sint, ...]) -> Optional[ShapeTracker]:
    inverted_views:list[View] = []
    for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]):
      if (inverted:= v.invert(s)) is None: return None
      inverted_views.append(inverted)
    return ShapeTracker(tuple(inverted_views)).reshape(out_shape)
  @staticmethod
  def from_shape(shape:tuple[sint, ...], strides:tuple[sint, ...]|None=None) -> ShapeTracker: return ShapeTracker((View.create(shape, strides),))
  @property
  def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
  @property
  def consecutive(self) -> bool: return len(self.views) == 1 and (v:=self.views[0]).mask is None and v.strides == strides_for_shape(v.shape)
  @property
  def shape(self) -> tuple[sint, ...]: return self.views[-1].shape
  @property
  def size(self) -> int: return self.views[-1].size()
  def reduce(self, axis:tuple[int, ...]) -> tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(self.shape))
  def to_uop(self) -> UOp: return UOp(Ops.VIEW, dtypes.void, (), self)
  def to_indexed_uops(self, _idxs:Optional[list[UOp]|tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
    return views_to_indexed_uops(self.views, tuple(_idxs) if _idxs is not None else None)
  # upper bound on buffer size required to fit this shapetracker
  def real_size(self) -> int:
    if 0 in self.shape: return 0
    view = (v.shrink(v.mask) if (v:=self.views[0]).mask else v)
    idx, _ = views_to_indexed_uops((view,))
    assert idx.vmax < 1e12, f"real_size broken for {self}"
    return int(idx.vmax + 1)
  def vars(self) -> set[Variable]: return set().union(*[v.vars() for v in self.views])
  @property
  def var_vals(self) -> dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()])
  def unbind(self) -> tuple[ShapeTracker, dict[Variable, int]]:
    unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
    if all(len(x) == 0 for x in var_vals): return self, {}
    return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)
  def real_strides(self, ignore_valid=False) -> tuple[Optional[sint], ...]:
    with Context(TRACK_MATCH_STATS=0): return views_to_real_strides(self.views, ignore_valid)
  def unit_stride_axes(self, ignore_valid=False) -> list[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
  def axis_is_masked(self, axis:int) -> bool:
    with Context(TRACK_MATCH_STATS=0):
      _, valid = self.to_indexed_uops()
      return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).toposort() if x.op is Ops.RANGE]
  def simplify(self) -> ShapeTracker:
    if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
      return ShapeTracker(self.views[:-2] + (new_view,)).simplify()
    return self
  # *** under this line are the movement ops ***
  def pad(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), ))
  def shrink(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
  def expand(self, new_shape: tuple[sint, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), ))
  def permute(self, axis: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), ))
  def flip(self, mul: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].flip(mul), ))
  def reshape(self, new_shape: tuple[sint, ...]) -> ShapeTracker:
    if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,))
    return ShapeTracker(self.views + (View.create(new_shape), ))
  def mop(self, op, arg): return mops[op](self, arg)

mops: dict[Ops, Callable] = {Ops.RESHAPE: ShapeTracker.reshape, Ops.PERMUTE: ShapeTracker.permute, Ops.EXPAND: ShapeTracker.expand,
                             Ops.SHRINK: ShapeTracker.shrink, Ops.FLIP: ShapeTracker.flip, Ops.PAD: ShapeTracker.pad}
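
# A minimal usage sketch (not part of the library API; assumes the default contiguous strides from View.create).
# Run this file directly to see how movement ops compose without copying the buffer.
if __name__ == "__main__":
  st = ShapeTracker.from_shape((3, 4))  # one contiguous View, strides (4, 1)
  st = st.permute((1, 0))               # movement op: returns a new ShapeTracker, no data copy
  print(st.shape)                       # (4, 3)
  print(st.real_strides())              # (1, 4)
  idx, valid = st.to_indexed_uops()     # symbolic (index, valid) UOps used for codegen
  print(idx, valid)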