You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
144 lines
7.5 KiB
144 lines
7.5 KiB
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
|
|
from __future__ import annotations
|
|
from dataclasses import dataclass
|
|
import functools
|
|
from typing import Optional, Callable
|
|
from tinygrad.helpers import merge_dicts, getenv
|
|
from tinygrad.shape.view import View, strides_for_shape, unravel
|
|
from tinygrad.dtype import dtypes
|
|
from tinygrad.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context
|
|
from tinygrad.codegen.symbolic import sym, split_uop, symbolic_flat, uop_given_valid, simplify_valid
|
|
|
|
def overflow(u: UOp): return u.vmax > dtypes.max(dtypes.int) or u.vmin < dtypes.min(dtypes.int)
|
|
|
|
# If a node overflow, its srcs need to be checked to see if this overflow is the result of an ALU operation,
|
|
# or that the node simply inherits the dtype from srcs. Upcast is either `Ops.CAST`+`replace` or just `replace`.
|
|
def upcast(u: UOp):
|
|
srcs = tuple(upcast(_src) for _src in u.src)
|
|
if u.dtype.scalar() is dtypes.int:
|
|
dtype = dtypes.int64.vec(u.dtype.count) if u.dtype.count > 1 else dtypes.int64
|
|
upcasted = u.replace(dtype=dtype, src=tuple([_src.cast(dtype) for _src in srcs]))
|
|
if overflow(u): return upcasted
|
|
# Check the original src, new srcs has Ops.CAST whose vmin, vmax change the real bounds
|
|
# Cast back is required because if the node is in range, siblings would never be upcasted
|
|
if any((overflow(src) for src in u.src)): return upcasted.cast(u.dtype)
|
|
return u.replace(src=tuple(srcs))
|
|
|
|
# pooling op may overflow before folding causing unnecessary upcast
|
|
def folded_upcast(u: UOp):
|
|
with Context(TRACK_MATCH_STATS=0):
|
|
return upcast(graph_rewrite(u, sym, {}))
|
|
|
|
@functools.lru_cache(None)
|
|
def views_to_indexed_uops(views: tuple[View, ...], _idxs:Optional[tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
|
|
idx, valid = views[-1].to_indexed_uops(_idxs)
|
|
for view in reversed(views[0:-1]):
|
|
view = view.minify()
|
|
idx, valid = view.to_indexed_uops([sint_to_uop(i) for i in unravel(view.shape, idx)], valid)
|
|
return idx, valid
|
|
|
|
@functools.lru_cache(None)
|
|
def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[Optional[sint], ...]:
|
|
# NOTE: if a stride is not always valid, it will be None
|
|
if len(views) == 1 and views[-1].mask is None: return views[-1].strides
|
|
ret: list[Optional[sint]] = [None] * len(views[-1].shape)
|
|
idx, valid = (graph_rewrite(u, symbolic_flat) for u in views_to_indexed_uops(views))
|
|
# TODO: always apply these in to_indexed_uops?
|
|
if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
|
|
if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat)
|
|
for c in split_uop(idx, Ops.ADD):
|
|
if c.op is Ops.RANGE: ret[c.arg] = 1
|
|
if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg] = c.src[1].arg
|
|
if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg] = c.src[0].arg
|
|
used_ranges = [x.arg for x in idx.toposort if x.op is Ops.RANGE]
|
|
ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
|
|
if not ignore_valid:
|
|
for masked_axis in [x.arg for x in valid.toposort if x.op is Ops.RANGE]: ret[masked_axis] = None
|
|
return tuple(ret)
|
|
|
|
@dataclass(frozen=True, order=True)
|
|
class ShapeTracker:
|
|
views: tuple[View, ...]
|
|
|
|
def __add__(self, st:ShapeTracker) -> ShapeTracker:
|
|
ret = self
|
|
for v in st.views: ret = ShapeTracker(ret.views + (v,)).simplify() # one view at a time = better simplification
|
|
return ret
|
|
|
|
def invert(self, out_shape:tuple[sint, ...]) -> Optional[ShapeTracker]:
|
|
inverted_views:list[View] = []
|
|
for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]):
|
|
if (inverted:= v.invert(s)) is None: return None
|
|
inverted_views.append(inverted)
|
|
return ShapeTracker(tuple(inverted_views)).reshape(out_shape)
|
|
|
|
@staticmethod
|
|
def from_shape(shape:tuple[sint, ...]) -> ShapeTracker: return ShapeTracker((View.create(shape),))
|
|
|
|
@property
|
|
def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
|
|
|
|
@property
|
|
def consecutive(self) -> bool: return len(self.views) == 1 and (v:=self.views[0]).mask is None and v.strides == strides_for_shape(v.shape)
|
|
|
|
@property
|
|
def shape(self) -> tuple[sint, ...]: return self.views[-1].shape
|
|
|
|
@property
|
|
def size(self) -> int: return self.views[-1].size()
|
|
|
|
def reduce(self, axis:tuple[int, ...]) -> tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(self.shape))
|
|
|
|
def to_uop(self) -> UOp: return UOp(Ops.VIEW, dtypes.void, (), self)
|
|
def to_indexed_uops(self, _idxs:Optional[list[UOp]|tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
|
|
idx, valid = views_to_indexed_uops(self.views, tuple(_idxs) if _idxs is not None else None)
|
|
return folded_upcast(idx), folded_upcast(valid)
|
|
|
|
# upper bound on buffer size required to fit this shapetracker
|
|
def real_size(self) -> int:
|
|
if 0 in self.shape: return 0
|
|
view = (v.shrink(v.mask) if (v:=self.views[0]).mask else v)
|
|
idx, _ = views_to_indexed_uops((view,))
|
|
assert idx.vmax < 1e12, f"real_size broken for {self}"
|
|
return int(idx.vmax + 1)
|
|
|
|
def vars(self) -> set[Variable]: return set().union(*[v.vars() for v in self.views])
|
|
|
|
@property
|
|
def var_vals(self) -> dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()])
|
|
|
|
def unbind(self) -> tuple[ShapeTracker, dict[Variable, int]]:
|
|
unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
|
|
if all(len(x) == 0 for x in var_vals): return self, {}
|
|
return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)
|
|
|
|
def real_strides(self, ignore_valid=False) -> tuple[Optional[sint], ...]:
|
|
with Context(TRACK_MATCH_STATS=0): return views_to_real_strides(self.views, ignore_valid)
|
|
def unit_stride_axes(self, ignore_valid=False) -> list[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
|
|
|
|
def axis_is_masked(self, axis:int) -> bool:
|
|
with Context(TRACK_MATCH_STATS=0):
|
|
_, valid = self.to_indexed_uops()
|
|
return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE]
|
|
|
|
def simplify(self) -> ShapeTracker:
|
|
if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
|
|
return ShapeTracker(self.views[:-2] + (new_view,)).simplify()
|
|
return self
|
|
|
|
# *** under this line are the movement ops ***
|
|
|
|
def pad(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), ))
|
|
def shrink(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
|
|
def expand(self, new_shape: tuple[sint, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), ))
|
|
def permute(self, axis: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), ))
|
|
def flip(self, mul: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].flip(mul), ))
|
|
|
|
def reshape(self, new_shape: tuple[sint, ...]) -> ShapeTracker:
|
|
if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,))
|
|
return ShapeTracker(self.views + (View.create(new_shape), ))
|
|
|
|
def mop(self, op, arg): return mops[op](self, arg)
|
|
|
|
mops: dict[Ops, Callable] = {Ops.RESHAPE: ShapeTracker.reshape, Ops.PERMUTE: ShapeTracker.permute, Ops.EXPAND: ShapeTracker.expand,
|
|
Ops.SHRINK: ShapeTracker.shrink, Ops.FLIP: ShapeTracker.flip, Ops.PAD: ShapeTracker.pad}
|
|
|