# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
# A ShapeTracker is a tuple of Views: views[-1] carries the logical shape, and each movement op either
# rewrites views[-1] or (for a reshape that cannot be merged into it) appends a new View.
from __future__ import annotations
from dataclasses import dataclass
import functools
from typing import Callable
from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.view import View, unravel
from tinygrad.uop.symbolic import sym
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context

@functools.cache
def views_to_valid_uop(views: tuple[View, ...], _idxs:tuple[UOp, ...]|None=None) -> UOp:
  """Compose a stack of views into one simplified index/valid UOp.

  The last view is indexed first (using `_idxs` when given). The remaining views are then walked
  from back to front. Each one is minified and re-indexed by unravelling the index computed so far
  into that view's shape, so views[0] produces the final expression. The result is simplified with
  the `sym` rewrite rules. Callers in this module read it back with `.get_idx()` / `.get_valid()`.

  Memoized with functools.cache, so both arguments must be hashable (tuples, not lists).
  """
  idx = views[-1].to_valid_uop(_idxs)
  for view in reversed(views[0:-1]):
    view = view.minify()
    # split the index produced so far into per-dimension indices for this view's shape
    idx = view.to_valid_uop([sint_to_uop(i) for i in unravel(view.shape, idx)])
  # disable pattern-match stat tracking for this internal rewrite
  with Context(TRACK_MATCH_STATS=0):
    return graph_rewrite(idx, sym, name="indexing sym @ 1")

@functools.cache
def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[sint|None, ...]:
  """Recover a per-axis stride for the logical shape (views[-1].shape) from the composed index.

  For each axis the result is:
    * 1 if that axis's RANGE appears as a bare term of the index sum;
    * c if it appears as a `RANGE*CONST` or `CONST*RANGE` term (a later matching term overwrites
      an earlier one);
    * 0 if that axis's RANGE does not appear anywhere in the index;
    * None if it appears but not as one of the simple terms above. Unless `ignore_valid` is set,
      it is also None if the RANGE appears anywhere in the validity expression.
  Fast path: a single unmasked view returns its own strides unchanged.
  """
  # NOTE: if a stride is not always valid, it will be None
  if len(views) == 1 and views[-1].mask is None: return views[-1].strides
  ret: list[sint|None] = [None] * len(views[-1].shape)
  idx, valid = (vidx:=views_to_valid_uop(views)).get_idx(), vidx.get_valid()
  # RANGE uops are looked up by arg[0], which this code treats as the axis number into `ret`
  for c in idx.split_uop(Ops.ADD):
    if c.op is Ops.RANGE: ret[c.arg[0]] = 1
    if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg[0]] = c.src[1].arg
    if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg[0]] = c.src[0].arg
  # an axis whose range never reaches the index does not move it at all: stride 0
  used_ranges = [x.arg[0] for x in idx.toposort() if x.op is Ops.RANGE]
  ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
  if not ignore_valid:
    # any axis that takes part in the validity (mask) expression has no single always-valid stride
    for masked_axis in [x.arg[0] for x in valid.toposort() if x.op is Ops.RANGE]:
      ret[masked_axis] = None
  return tuple(ret)

@dataclass(frozen=True, order=True)
class ShapeTracker:
  """An immutable stack of Views describing how a logical shape indexes into a buffer.

  `views[-1]` defines the logical shape (see `shape`). `views[0]` produces the final index expression
  (see views_to_valid_uop) and is the view measured by `real_size`. Every method returns a new
  ShapeTracker instead of mutating this one. frozen=True makes instances immutable and hashable.
  """
  # ordered stack of views; the last entry holds the logical shape
  views: tuple[View, ...]

  def __add__(self, st:ShapeTracker) -> ShapeTracker:
    """Append `st`'s views on top of this tracker's views, simplifying after each one."""
    ret = self
    for v in st.views: ret = ShapeTracker(ret.views + (v,)).simplify() # one view at a time = better simplification
    return ret

  def invert(self, out_shape:tuple[sint, ...]) -> ShapeTracker|None:
    """Build an inverse tracker that ends in `out_shape`, or return None if any view cannot be inverted.

    Views are processed last to first. Each view is inverted against the shape of the view before it
    in the original order; views[0] is inverted against `out_shape`. The inverted stack is then
    reshaped to `out_shape`.
    """
    inverted_views:list[View] = []
    for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]):
      if (inverted:= v.invert(s)) is None: return None
      inverted_views.append(inverted)
    return ShapeTracker(tuple(inverted_views)).reshape(out_shape)

  @staticmethod
  def from_shape(shape:tuple[sint, ...], strides:tuple[sint, ...]|None=None) -> ShapeTracker:
    """Create a single-view tracker via View.create(shape, strides)."""
    return ShapeTracker((View.create(shape, strides),))

  @property
  def contiguous(self) -> bool:
    """True only when there is exactly one view and that view reports itself contiguous."""
    return len(self.views) == 1 and self.views[0].contiguous

  @property
  def shape(self) -> tuple[sint, ...]:
    """The logical shape: the shape of the last view."""
    return self.views[-1].shape

  @property
  def size(self) -> int:
    """The size reported by the last view's size()."""
    return self.views[-1].size()

  def reduce(self, axis:tuple[int, ...]) -> tuple[sint, ...]:
    """Return the shape with each axis listed in `axis` replaced by 1; the tracker is not changed."""
    return tuple(1 if i in axis else s for i,s in enumerate(self.shape))

  def to_valid_uop(self, _idxs:list[UOp]|tuple[UOp, ...]|None=None) -> UOp:
    """Return the composed index/valid UOp for this tracker's views (see views_to_valid_uop)."""
    # convert a list to a tuple: views_to_valid_uop is functools.cache'd and needs hashable args
    return views_to_valid_uop(self.views, tuple(_idxs) if _idxs is not None else None)

  # upper bound on buffer size required to fit this shapetracker
  def real_size(self) -> int:
    """Return max index + 1 of views[0] (shrunk to its mask if it has one), or 0 if any dim of shape is 0."""
    if 0 in self.shape: return 0
    view = (v.shrink(v.mask) if (v:=self.views[0]).mask else v)
    idx = views_to_valid_uop((view,)).get_idx()
    # sanity bound on the symbolic max index. NOTE(review): asserts are stripped under `python -O`
    assert idx.vmax < 1e12, f"real_size broken for {self}"
    return int(idx.vmax + 1)

  def vars(self) -> set[Variable]:
    """Return the union of the Variables referenced by every view."""
    return set().union(*[v.vars() for v in self.views])

  @property
  def var_vals(self) -> dict[str, int]:
    """Map each variable's `.expr` (after unbind) to its bound value.

    NOTE(review): assumes every Variable returned by vars() is bound; unbind() on an unbound
    variable is not visible here.
    """
    return merge_dicts([{(vu:=v.unbind())[0].expr:vu[1]} for v in self.vars()])

  def unbind(self) -> tuple[ShapeTracker, dict[Variable, int]]:
    """Return (tracker with unbound views, merged {Variable: value}); `self` and {} if nothing was bound."""
    unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
    if all(len(x) == 0 for x in var_vals): return self, {}
    return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)

  def substitute(self, dvars:dict[UOp, UOp]) -> ShapeTracker:
    """Return a tracker with `dvars` substituted into every view."""
    return ShapeTracker(tuple(x.substitute(dvars) for x in self.views))

  def real_strides(self, ignore_valid=False) -> tuple[sint|None, ...]:
    """Per-axis strides of the logical shape; None where not determinable (see views_to_real_strides)."""
    with Context(TRACK_MATCH_STATS=0): return views_to_real_strides(self.views, ignore_valid)

  def unit_stride_axes(self, ignore_valid=False) -> list[int]:
    """Return the axes whose real stride equals 1."""
    return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]

  def simplify(self) -> ShapeTracker:
    """Repeatedly merge the last two views while View.__add__ can combine them (returns non-None)."""
    if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
      return ShapeTracker(self.views[:-2] + (new_view,)).simplify()
    return self

  # *** under this line are the movement ops ***
  # pad/shrink/expand/permute/flip only ever replace the last view; reshape may append a new one

  def pad(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), ))
  def shrink(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
  def expand(self, new_shape: tuple[sint, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), ))
  def permute(self, axis: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), ))
  def flip(self, mul: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].flip(mul), ))

  def reshape(self, new_shape: tuple[sint, ...]) -> ShapeTracker:
    """Reshape to `new_shape`.

    When MERGE_VIEW (env, default 1) is set and the last view can reshape itself, the last view is
    replaced. Otherwise a fresh View.create(new_shape) is appended on top of the existing views.
    """
    if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,))
    return ShapeTracker(self.views + (View.create(new_shape), ))

  def mop(self, op, arg) -> ShapeTracker:
    """Dispatch a movement op (an Ops key of `mops`) with its argument."""
    return mops[op](self, arg)

# movement-op dispatch table: Ops member -> unbound ShapeTracker method taking (tracker, arg)
mops: dict[Ops, Callable] = {Ops.RESHAPE: ShapeTracker.reshape, Ops.PERMUTE: ShapeTracker.permute, Ops.EXPAND: ShapeTracker.expand,
                             Ops.SHRINK: ShapeTracker.shrink, Ops.FLIP: ShapeTracker.flip, Ops.PAD: ShapeTracker.pad}