# openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

# 221 lines
# 10 KiB

# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
from __future__ import annotations
import functools, operator
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, cast
from tinygrad.ops import MovementOps
from tinygrad.helpers import prod, DEBUG, dedup
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, sint
from tinygrad.shape.view import View
@functools.lru_cache(maxsize=None)
def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]:
  """Collapse a (shape, strides) pair into a minimal list of (size, stride) runs.

  Walking the dimensions left to right, a dimension is folded into the
  previous run when the previous run's stride equals ``dim * stride`` (the two
  are laid out back to back) or when the previous run has size 1. Remaining
  size-1 dimensions are dropped; everything else starts a new run.

  Args:
    shape: size of each dimension.
    strides: stride of each dimension; must have the same length as ``shape``.

  Returns:
    A tuple of (size, stride) pairs; empty when ``shape`` is empty.
  """
  assert len(shape) == len(strides)
  ret = [(shape[0], strides[0])] if shape else []
  for dim, st in zip(shape[1:], strides[1:]):
    if ret[-1][1] == dim*st or ret[-1][0] == 1:
      # merge into the previous run, which takes over this dimension's stride
      ret[-1] = (ret[-1][0] * dim, st)
    elif dim != 1:
      # size-1 dimensions that could not be merged contribute nothing and are skipped
      ret.append((dim, st))
  return tuple(ret)
def expr_node_mask(view:View, idx, valid=None) -> Node:
  """Build the validity condition of a flat index ``idx`` under ``view.mask``.

  Dimensions are visited from last to first. The coordinate along a dimension
  of size ``ns`` is recovered as ``(idx//acc) % ns``, where ``acc`` is the
  product of the sizes of the dimensions already visited. For every dimension
  whose mask ``(x, y)`` is not the full range ``(0, ns)``, the bounds
  ``x <= coord < y`` are added.

  Args:
    view: the View whose ``shape`` and ``mask`` are read.
    idx: flat index expression into ``view``.
    valid: optional existing condition, included as the first term.

  Returns:
    ``Variable.ands`` of all collected terms.
  """
  expr = [valid] if valid is not None else []
  if view.mask is not None:
    acc = 1
    for ns,(x,y) in reversed(list(zip(view.shape, view.mask))):
      if x != 0 or y != ns:
        base = ((idx//acc) % ns)
        expr += [base >= x, base < y]
      # the divisor advances for every dimension, masked or not
      acc *= ns
  return Variable.ands(expr)
# generate an expression if you have a single idx variable
def expr_node(view:View, idx=None) -> Node:
  """Build the buffer-index expression of ``view`` for a single flat index.

  Args:
    view: the View whose ``shape``, ``strides`` and ``offset`` are read.
    idx: flat index expression; when None, a Variable named 'idx' ranging over
      ``0 .. prod(view.shape)-1`` is created.

  Returns:
    ``Variable.sum`` of the offset term (omitted when the offset is falsy) and
    one ``((idx//acc) % d) * s`` term per merged (size, stride) run returned by
    ``to_shape_strides``, visited from last to first.
  """
  if idx is None: idx = Variable('idx', 0, prod(view.shape)-1)
  ret: List[Node] = [Variable.num(view.offset) if isinstance(view.offset, int) else view.offset] if view.offset else []
  acc = 1
  for d,s in reversed(to_shape_strides(view.shape, view.strides)):
    ret.append(((idx//acc)%d)*s)
    acc *= d
  return Variable.sum(ret)
# generate an expression if you have a variable or expression for each index
def expr_idxs(view:View, idxs) -> Node:
  """Build the buffer-index expression of ``view`` from one index per dimension.

  Args:
    view: the View whose ``shape``, ``strides`` and ``offset`` are read.
    idxs: one index variable or expression per dimension of ``view.shape``.

  Returns:
    ``Variable.sum`` of the offset and ``idx * stride`` for each dimension;
    dimensions of size 1 or with stride 0 are left out.

  Raises:
    AssertionError: if ``len(idxs)`` differs from ``len(view.shape)``.
  """
  assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"
  return Variable.sum([Variable.num(view.offset) if isinstance(view.offset, int) else view.offset] + [idx*st for idx,sh,st in zip(idxs, view.shape, view.strides) if sh != 1 and st != 0])
@functools.lru_cache(maxsize=None)
def merge_views(vm2:View, vm1:View) -> Optional[View]:
  """Try to collapse the two stacked views ``(vm2, vm1)`` into a single View.

  Returns:
    A View with ``vm1``'s shape and mask and the real strides/offset of the
    two-view ShapeTracker, or None when ``vm2`` has a mask (not supported yet)
    or when any real stride is None.
  """
  if vm2.mask: return None # this isn't supported yet
  mst = ShapeTracker((vm2, vm1))
  strides = mst.real_strides()
  if None in strides: return None
  return View.create(vm1.shape, cast(Tuple[sint, ...], strides), mst.real_offset(), vm1.mask)
@functools.lru_cache(maxsize=None)
def idxs_to_idx(shape:Tuple[int, ...], idxs) -> Node:
  """Combine per-dimension indices into one flat index expression.

  The last dimension gets multiplier 1 and each earlier dimension is multiplied
  by the product of the sizes of all later dimensions.

  Args:
    shape: size of each dimension.
    idxs: one index per dimension. Results are cached, so this must be
      hashable (e.g. a tuple).

  Returns:
    ``Variable.sum`` of ``idx * multiplier`` over all dimensions.

  Raises:
    AssertionError: if ``len(idxs)`` differs from ``len(shape)``.
  """
  assert len(idxs) == len(shape), "need an idx for all dimensions"
  acc = 1
  ret = []
  for tidx,d in reversed(list(zip(idxs, shape))):
    ret.append(tidx * acc)
    acc *= d
  return Variable.sum(ret)
@dataclass(frozen=True)
class ShapeTracker:
  """An immutable stack of Views describing how a shape maps onto a buffer.

  The last view defines the visible shape. Every movement op returns a new
  ShapeTracker whose last view has been replaced (or, for reshape, possibly
  followed by an extra view); existing instances are never modified.
  """
  # views are applied as a stack; views[-1] is the one that defines `shape`
  views: Tuple[View, ...]

  def __post_init__(self): assert isinstance(self.views, tuple) and all(isinstance(v, View) for v in self.views), "ShapeTracker must be created with a tuple of Views"

  # build a single-view tracker from View.create(shape)
  @staticmethod
  def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))

  # True only when there is exactly one view and that view is contiguous
  @property
  def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous

  # shape of the last view
  @property
  def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape

  # this is the real size (ish)
  def size(self): return self.views[-1].size()

  # deduplicated variables collected from every view
  def vars(self) -> List[Variable]: return dedup(functools.reduce(operator.add, [v.vars() for v in self.views], []))

  @property
  def var_vals(self) -> Dict[Variable, int]:
    """Map each unbound variable to its value; asserts that no variable has two different values."""
    ret:Dict[Variable, int] = {}
    for v in self.vars():
      var, val = v.unbind()
      assert var not in ret or ret[var] == val, f"{var} has conflicted values {val} and {ret[var]}"
      ret[var] = val
    return ret

  # new tracker built from the result of unbind() on every view
  def unbind(self) -> ShapeTracker: return ShapeTracker(tuple(v.unbind() for v in self.views))

  def to_movement_ops(self) -> List[Tuple[MovementOps, Tuple]]:
    """Express every view as a list of (MovementOps, arg) pairs.

    Per view the sequence is: AS_STRIDED, an optional PAD for masked
    dimensions with non-zero stride, an optional EXPAND, and an optional PAD
    for masked dimensions with zero stride.
    """
    to_apply:List[Tuple[MovementOps, Tuple]] = []
    for v in self.views:
      # with a mask, only the unmasked region is addressed: shrink the shape and shift the offset
      real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
      real_offset = v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
      # first, we apply the offset
      # then, we make it the correct shape
      # then, we apply permutations
      # TODO: don't use as_strided
      to_apply.append((MovementOps.AS_STRIDED, (tuple([s if st != 0 else 1 for s,st in zip(real_shape, v.strides)]), v.strides, real_offset)))
      # then, we apply pre expand pads
      if v.mask is not None:
        pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
        post_expand_pads = tuple((x,s-y) if st == 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
        if any(x != (0,0) for x in pre_expand_pads):
          to_apply.append((MovementOps.PAD, pre_expand_pads))
          real_shape = tuple(x+s[0]+s[1] for x,s in zip(real_shape, pre_expand_pads))
      # then, we do any expands
      if any(s != 1 and st == 0 for s,st in zip(real_shape, v.strides)): to_apply.append((MovementOps.EXPAND, real_shape))
      # lastly, we apply post expand pads
      if v.mask is not None and any(x != (0,0) for x in post_expand_pads): to_apply.append((MovementOps.PAD, post_expand_pads))
    return to_apply

  # these are multiview strides, value is None if it's not a simple strided dimension
  # TODO: this can be shared code between simplify and merge_views
  def real_offset(self) -> sint:
    """Return the index expression evaluated at a variable fixed to 0 (as an int when it is a NumNode)."""
    real_offset, _ = self.expr_node(Variable('zero', 0, 0))
    return real_offset.b if isinstance(real_offset, NumNode) else real_offset

  # NOTE: if a stride is not always valid, it will be None
  def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
    """Return one stride per dimension of the last view, taken across all views.

    A dimension's stride is read from the index expression when its variable
    appears as a bare term (stride 1) or as ``variable * b`` (stride ``b``).
    It is 0 when the variable does not appear in the index expression at all,
    and None otherwise. Unless ``ignore_valid`` is set, any dimension whose
    variable appears in the valid expression is also None.
    """
    # single unmasked view: its strides are already the real strides
    if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
    idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
    idx, valid = self.expr_idxs(idxs)
    ret: List[Optional[sint]] = [None] * len(self.views[-1].shape)
    for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
      if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable) and this_dim.a in idxs:
        ret[idxs.index(this_dim.a)] = this_dim.b
      elif isinstance(this_dim, Variable) and this_dim in idxs:
        ret[idxs.index(this_dim)] = 1
    idx_vars, valid_vars = idx.vars(), valid.vars()
    for i,tidx in enumerate(idxs):
      if tidx in valid_vars and not ignore_valid: ret[i] = None
      elif tidx not in idx_vars: ret[i] = 0
    return tuple(ret)

  # indices of the dimensions whose real stride is exactly 1
  def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]

  def _expr_idx(self, idx, valid) -> Tuple[Node, Node]:
    """Push (idx, valid), already computed for the last view, through the remaining views from last to first."""
    for v in reversed(self.views[0:-1]):
      # valid can never be true: stop early and return index -1
      if valid.max == 0: return Variable.num(-1), valid
      valid = expr_node_mask(v, idx, valid)
      idx = expr_node(v, idx)
    return idx, valid

  def simplify(self) -> ShapeTracker:
    """Repeatedly merge the last two views while merge_views succeeds; returns self when nothing merges."""
    if len(self.views) >= 2:
      new_view = merge_views(self.views[-2], self.views[-1])
      if new_view:
        if DEBUG >= 4: print(f"st simplify : {self.views[-2]} + {self.views[-1]} = {new_view}")
        return ShapeTracker(self.views[:-2] + (new_view,)).simplify()
    return self

  def expr_idxs(self, idxs=None):
    """Return (idx, valid) expressions given one index per dimension (idx0, idx1, ... are created when omitted)."""
    if idxs is None: idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
    idx = expr_idxs(self.views[-1], tuple(idxs))
    valid = expr_node_mask(self.views[-1], idxs_to_idx(self.views[-1].shape, tuple(idxs)))
    return self._expr_idx(idx, valid)

  def expr_node(self, idx='idx'):
    """Return (idx, valid) expressions for a single flat index; a str is turned into a Variable over the full shape."""
    if isinstance(idx, str): idx = Variable(idx, 0, prod(self.shape)-1)
    return self._expr_idx(expr_node(self.views[-1], idx), expr_node_mask(self.views[-1], idx))

  def axis_is_masked(self, axis) -> bool:
    """Return True when the variable for ``axis`` appears in the valid expression."""
    _, valid = self.expr_idxs()
    return f'idx{axis}' in [v.expr for v in valid.vars()]

  # *** under this line are the movement ops ***
  # each op below delegates to the same-named method of the last view and returns a new tracker

  def pad(self, arg: Tuple[Tuple[int, int], ...]) -> ShapeTracker:
    return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), ))

  def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker:
    return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))

  def expand(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
    return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), ))

  def permute(self, axis: Tuple[int, ...]) -> ShapeTracker:
    return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), ))

  def stride(self, mul: Tuple[int, ...]) -> ShapeTracker:
    return ShapeTracker(self.views[0:-1] + (self.views[-1].stride(mul), ))

  def reshape(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
    """Reshape the last view in place if it can; otherwise add (or merge in) a fresh view of ``new_shape``."""
    new_view = self.views[-1].reshape(new_shape)
    if new_view is None:
      extra_view = View.create(new_shape)
      # last chance to merge. TODO: move into View
      if (merged_view := merge_views(self.views[-1], extra_view)) is not None:
        return ShapeTracker(self.views[0:-1] + (merged_view,))
      return ShapeTracker(self.views + (extra_view, ))
    return ShapeTracker(self.views[0:-1] + (new_view,))
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
# TODO: if we remove movementops from lazy.py we can delete this
def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
  """Group consecutive axes of ``old_shape`` so that each group forms one axis of ``new_shape``.

  Args:
    old_shape: the shape whose axes are being combined.
    new_shape: the target shape.

  Returns:
    One list of ``old_shape`` axis indices per entry of ``new_shape`` (a size-1
    entry that exists only in ``new_shape`` gets an empty list), or None when a
    group overshoots its target size or a remaining axis cannot be placed in
    any group.
  """
  # With no target axes there is no group to put an old axis into.
  if not new_shape: return [] if not old_shape else None
  # Pre-allocate all groups.
  axis_groups: List[List[int]] = [[] for _ in range(len(new_shape))]
  last_group = len(new_shape) - 1
  # Index for new_shape and axis_groups.
  i: int = 0
  old_shape_i: int = 0
  # Running product of the old dimensions placed in the current group.
  axis_group_size = 1
  while old_shape_i < len(old_shape):
    # 1s exist in new_shape only will lead to empty axes group creations.
    if new_shape[i] == 1 and old_shape[old_shape_i] != 1:
      # No group left to take this axis: report failure instead of spinning forever.
      if i == last_group: return None
      i += 1
      axis_group_size = 1
    else:
      axis_groups[i].append(old_shape_i)
      axis_group_size = axis_group_size * old_shape[old_shape_i]
      # Move to next axes group if total size of all dimensions match.
      if axis_group_size == new_shape[i]:
        if i < last_group:
          i += 1
          axis_group_size = 1
      elif axis_group_size > new_shape[i]: return None
      old_shape_i += 1
  return axis_groups