# ShapeTracker allows movement operations to a buffer that don't require a copy to be made. from __future__ import annotations from dataclasses import dataclass import functools from typing import Optional, Callable from tinygrad.helpers import merge_dicts, getenv from tinygrad.shape.view import View, strides_for_shape from tinygrad.dtype import dtypes from tinygrad.ops import UOp, Ops, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid @functools.lru_cache(None) def views_to_indexed_uops(views: tuple[View, ...], _idxs:Optional[tuple[UOp, ...]]=None) -> tuple[UOp, UOp]: idx, valid = views[-1].to_indexed_uops(_idxs) for view in reversed(views[0:-1]): view = view.minify() acc, idxs = 1, [] for d in reversed(view.shape): idxs.append((idx//acc)%d) acc *= d idx, valid = view.to_indexed_uops(idxs[::-1], valid) return idx, valid @functools.lru_cache(None) def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[Optional[sint], ...]: # NOTE: if a stride is not always valid, it will be None if len(views) == 1 and views[-1].mask is None: return views[-1].strides ret: list[Optional[sint]] = [None] * len(views[-1].shape) idx, valid = (graph_rewrite(u, symbolic_flat) for u in views_to_indexed_uops(views)) # TODO: always apply these in to_indexed_uops? if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat) for c in split_uop(idx, Ops.ADD): if c.op is Ops.RANGE: ret[c.arg] = 1 if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg] = c.src[1].arg if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg] = c.src[0].arg used_ranges = [x.arg for x in idx.toposort if x.op is Ops.RANGE] ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)] if not ignore_valid: for masked_axis in [x.arg for x in valid.toposort if x.op is Ops.RANGE]: ret[masked_axis] = None return tuple(ret) @dataclass(frozen=True, order=True) class ShapeTracker: views: tuple[View, ...] def __add__(self, st:ShapeTracker) -> ShapeTracker: ret = self for v in st.views: ret = ShapeTracker(ret.views + (v,)).simplify() # one view at a time = better simplification return ret def invert(self, out_shape:tuple[sint, ...]) -> Optional[ShapeTracker]: inverted_views:list[View] = [] for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]): if (inverted:= v.invert(s)) is None: return None inverted_views.append(inverted) return ShapeTracker(tuple(inverted_views)).reshape(out_shape) @staticmethod def from_shape(shape:tuple[sint, ...]) -> ShapeTracker: return ShapeTracker((View.create(shape),)) @property def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous @property def consecutive(self) -> bool: return len(self.views) == 1 and (v:=self.views[0]).mask is None and v.strides == strides_for_shape(v.shape) @property def shape(self) -> tuple[sint, ...]: return self.views[-1].shape @property def size(self) -> int: return self.views[-1].size() def reduce(self, axis:tuple[int, ...]) -> tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(self.shape)) def to_uop(self) -> UOp: return UOp(Ops.VIEW, dtypes.void, (), self) def to_indexed_uops(self, _idxs:Optional[list[UOp]|tuple[UOp, ...]]=None) -> tuple[UOp, UOp]: return views_to_indexed_uops(self.views, tuple(_idxs) if _idxs is not None else None) def real_size(self) -> int: if 0 in self.shape: return 0 idx, valid = self.to_indexed_uops() if not valid.vmax: return 0 assert idx.vmax < 1e12, f"real_size broken for {self}" return int(idx.vmax+1) def vars(self) -> set[Variable]: return set().union(*[v.vars() for v in self.views]) @property def var_vals(self) -> dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()]) def unbind(self) -> tuple[ShapeTracker, dict[Variable, int]]: unbound_views, var_vals = zip(*[v.unbind() for v in self.views]) return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals) def real_strides(self, ignore_valid=False) -> tuple[Optional[sint], ...]: return views_to_real_strides(self.views, ignore_valid) def unit_stride_axes(self, ignore_valid=False) -> list[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1] def axis_is_masked(self, axis:int) -> bool: _, valid = self.to_indexed_uops() return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE] def simplify(self) -> ShapeTracker: if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None: return ShapeTracker(self.views[:-2] + (new_view,)).simplify() return self # *** under this line are the movement ops *** def pad(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), )) def shrink(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), )) def expand(self, new_shape: tuple[sint, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), )) def permute(self, axis: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), )) def stride(self, mul: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].stride(mul), )) def reshape(self, new_shape: tuple[sint, ...]) -> ShapeTracker: if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,)) return ShapeTracker(self.views + (View.create(new_shape), )) def mop(self, op, arg): return mops[op](self, arg) mops: dict[Ops, Callable] = {Ops.RESHAPE: ShapeTracker.reshape, Ops.PERMUTE: ShapeTracker.permute, Ops.EXPAND: ShapeTracker.expand, Ops.SHRINK: ShapeTracker.shrink, Ops.STRIDE: ShapeTracker.stride, Ops.PAD: ShapeTracker.pad}