# mixins add syntactic sugar to Tensor and UOp
import functools
from typing import TypeAlias, TYPE_CHECKING, Self
from tinygrad.uop import Ops
from tinygrad.helpers import prod, argfix, flatten, dedup, make_tuple, ceildiv
from tinygrad.uop.ops import resolve, smax

if TYPE_CHECKING:
    from tinygrad.uop.ops import UOp

# kept as a string so the alias exists at runtime without importing UOp
sint: TypeAlias = "UOp | int"


def _align_left(*shapes: tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
    """Left-pad every shape with 1s so that all returned shapes have the same length."""
    # unsqueeze left to make every shape same length
    max_dim = max(len(shape) for shape in shapes)
    return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)


class MovementMixin:
    # required to implement
    def _mop(self, op: Ops, arg) -> Self:
        """Apply the movement op `op` with argument `arg`. Must be provided by the concrete class."""
        raise NotImplementedError

    @property
    def shape(self) -> tuple[sint, ...]:
        """Shape of the object. Must be provided by the concrete class."""
        raise NotImplementedError

    # great functions you get!
    @property
    def ndim(self) -> int:
        """
        Returns the number of dimensions in the tensor.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor([[1, 2], [3, 4]])
        print(t.ndim)
        ```
        """
        return len(self.shape)

    def numel(self) -> sint:
        """
        Returns the total number of elements in the tensor.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
        print(t.numel())
        ```
        """
        return prod(self.shape)

    def _resolve_dim(self, dim: int, *, extra: bool = False) -> int:
        """
        Validate `dim` and shift a negative `dim` by the number of dimensions.

        With `extra=True` one additional dimension is counted, which allows the
        "one past the end" position (used by `unsqueeze`).
        The accepted range is computed as if there were at least one dimension,
        so `-1` and `0` pass the check even when `ndim` is 0.

        Raises `IndexError` if `dim` is outside the accepted range.
        """
        total = self.ndim + int(extra)
        if not -max(1, total) <= dim <= max(1, total) - 1:
            raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total) - 1]}")
        return dim + total if dim < 0 else dim

    def _broadcast_to(self, new_shape: tuple[sint, ...]) -> Self:
        """
        Broadcast to `new_shape` by left-padding the shape with 1s and expanding.

        Returns `self` unchanged when the shape already matches.
        Raises `ValueError` if `new_shape` has fewer dimensions, or if any aligned
        dimension is neither 1 nor equal to the target dimension.
        """
        if self.shape == new_shape:
            return self
        if self.ndim > len(new_shape):
            raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {new_shape=}")
        # first unsqueeze left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
        shape, _ = _align_left(self.shape, new_shape)
        # for each dimension, check either dim is 1, or it does not change
        if not all(s == ns or s == 1 for s, ns in zip(shape, new_shape)):
            raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
        reshaped = self.reshape(shape)
        ret = reshaped._mop(Ops.EXPAND, arg=new_shape)
        # drop the EXPAND when it did not change the shape
        return reshaped if ret.shape == reshaped.shape else ret

    def expand(self, shape, *args) -> Self:
        """
        Returns a tensor that is expanded to the shape that is specified.
        Expand can also increase the number of dimensions that a tensor has.

        Passing a `-1` or `None` to a dimension means that its size will not be changed.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor([1, 2, 3])
        print(t.expand(4, -1).numpy())
        ```
        """
        aligned_self, aligned_target = _align_left(self.shape, argfix(shape, *args))
        # -1 / None keep the (left-aligned) current size for that dimension
        new_shape = tuple(from_ if to == -1 or to is None else to for from_, to in zip(aligned_self, aligned_target))
        return self._broadcast_to(new_shape)

    def reshape(self, shape, *args) -> Self:
        """
        Returns a tensor with the same data as the original tensor but with a different shape.
        `shape` can be passed as a tuple or as separate arguments.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor.arange(6)
        print(t.reshape(2, 3).numpy())
        ```
        """
        # resolve None and args
        new_shape = tuple([s if s is not None else self.shape[i] for i, s in enumerate(argfix(shape, *args))])
        # resolve -1
        if (c := new_shape.count(-1)) > 1:
            raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
        if c:
            # prod(new_shape) carries the sign of the single -1, so negating the numerator yields the inferred size
            new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
        if prod(self.shape) != prod(new_shape):
            raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})")
        ret = self._mop(Ops.RESHAPE, arg=new_shape)
        # drop the RESHAPE when it did not change the shape
        return self if ret.shape == self.shape else ret

    def shrink(self, arg: tuple[tuple[sint, sint] | None, ...]) -> Self:
        """
        Returns a tensor that shrinks each axis based on input arg.
        `arg` must have the same length as `self.ndim`.
        For each axis, it can be `None`, which means no shrink, or a tuple `(start, end)` that works the same as Python slice.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor.arange(9).reshape(3, 3)
        print(t.numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.shrink(((None, (1, 3)))).numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.shrink((((0, 2), (0, 2)))).numpy())
        ```
        """
        if self.ndim != len(arg):
            raise ValueError(f"{self.ndim=} != {len(arg)=}")
        # None means "keep the full axis", i.e. (0, size)
        ret = self._mop(Ops.SHRINK, arg=[x if x is not None else (0, s) for x, s in zip(arg, self.shape)])
        return self if ret.shape == self.shape else ret

    def permute(self, order, *args) -> Self:
        """
        Returns a tensor that is a permutation of the original tensor.
        The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified.
        `order` can be passed as a tuple or as separate arguments.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor.empty(2, 3, 5)
        print(t.shape)
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.permute(2, 0, 1).shape)
        ```
        """
        order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
        if sorted(order_arg) != list(range(self.ndim)):
            raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
        # the identity permutation is a no-op
        return self._mop(Ops.PERMUTE, arg=order_arg) if order_arg != tuple(range(self.ndim)) else self

    def flip(self, axis, *args) -> Self:
        """
        Returns a tensor that reverses the order of the original tensor along given `axis`.
        `axis` can be passed as a tuple or as separate arguments.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor.arange(6).reshape(2, 3)
        print(t.numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.flip(0).numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.flip((0, 1)).numpy())
        ```
        """
        axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
        assert all(not isinstance(x, bool) and x >= 0 and x < self.ndim for x in axis_arg), f"flip args must be axis ints {axis_arg}"
        if len(axis_arg) != len(dedup(axis_arg)):
            raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
        # one boolean per dimension: True where the axis is flipped
        flip_arg = tuple([i in axis_arg for i in range(len(self.shape))])
        return self._mop(Ops.FLIP, arg=flip_arg) if any(flip_arg) else self

    # **** high level ****

    def shrink_to(self, shape, *args) -> Self:
        """Shrink each axis to `(0, size)`; a `None` entry leaves that axis untouched."""
        return self.shrink(tuple([None if ns is None else (0, ns) for ns in argfix(shape, *args)]))

    def view(self, shape, *args) -> Self:
        """`.view` is an alias for `.reshape`."""
        return self.reshape(shape, *args)

    def squeeze(self, dim: int | None = None) -> Self:
        """
        Returns a tensor with specified dimensions of input of size 1 removed.
        If `dim` is not specified, all dimensions with size 1 are removed.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor.zeros(2, 1, 2, 1, 2)
        print(t.squeeze().shape)
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.squeeze(0).shape)
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.squeeze(1).shape)
        ```
        """
        if dim is None:
            return self.reshape(tuple(size for size in self.shape if size != 1))
        dim = self._resolve_dim(dim)
        # nothing to squeeze on a 0-d tensor or when the chosen axis is not of size 1
        return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim + 1 :])

    def unsqueeze(self, dim: int) -> Self:
        """
        Returns a tensor with a new dimension of size 1 inserted at the specified `dim`.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor([1, 2, 3, 4])
        print(t.unsqueeze(0).numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.unsqueeze(1).numpy())
        ```
        """
        # extra=True: inserting after the last existing axis is allowed
        dim = self._resolve_dim(dim, extra=True)
        return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])

    @property
    def T(self) -> Self:
        """`.T` is an alias for `.transpose()`."""
        return self.transpose()

    def transpose(self, dim0=1, dim1=0) -> Self:
        """
        Returns a tensor that is a transposed version of the original tensor.
        The given dimensions `dim0` and `dim1` are swapped.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor.arange(6).reshape(2, 3)
        print(t.numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.transpose(0, 1).numpy())
        ```
        """
        order = list(range(self.ndim))
        order[dim0], order[dim1] = order[dim1], order[dim0]
        return self.permute(order)

    def flatten(self, start_dim=0, end_dim=-1) -> Self:
        """
        Flattens the tensor by reshaping it into a one-dimensional tensor.
        If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor.arange(8).reshape(2, 2, 2)
        print(t.flatten().numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.flatten(start_dim=1).numpy())
        ```
        """
        start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
        return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim : end_dim + 1]),) + self.shape[end_dim + 1 :])

    def unflatten(self, dim: int, sizes: tuple[int, ...]) -> Self:
        """
        Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.

        ```python exec="true" source="above" session="tensor" result="python"
        print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
        ```
        """
        dim = self._resolve_dim(dim)
        # tuple(...) so that a list of sizes can be concatenated with the shape tuples as well
        return self.reshape(self.shape[:dim] + tuple(sizes) + self.shape[dim + 1 :])

    def rearrange(self, formula: str, **sizes) -> Self:
        """
        Rearranges input according to formula

        See: https://einops.rocks/api/rearrange/

        ```python exec="true" source="above" session="tensor" result="python"
        x = Tensor([[1, 2], [3, 4]])
        print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy())
        ```
        """

        def parse_formula(formula: str):
            # Returns (axis names, [(start, end) index ranges into the names for every parenthesised group]).
            # Every space is doubled before substituting " 1 " so that adjacent "1" tokens each have their own
            # surrounding spaces; str.replace does not match overlapping occurrences, so with single spaces
            # only every other "1" in a run like "a 1 1 b" would be turned into an empty group "( )".
            tokens = (
                f" {formula} ".replace("…", "...")
                .replace("(", " ( ")
                .replace(")", " ) ")
                .replace(" ", "  ")
                .replace(" 1 ", " ( ) ")
                .split()
            )
            lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
            pairs = list(zip(lparens, rparens))
            # positions must interleave as ( ) ( ) ...: this rejects nesting and unbalanced brackets
            assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
            # subtract the brackets that precede each group to get indices into the bracket-free name list
            return [name for name in tokens if name not in ("(", ")")], [(s - 2 * i, e - 1 - 2 * i) for i, (s, e) in enumerate(pairs)]

        assert formula.count("->") == 1, 'need exactly one "->" in formula'
        (lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))

        for name in sizes:
            assert name in lhs, f"axis {name} is not used in transform"
        assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
        for name in flatten((lhs, rhs)):
            assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
        assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
        assert lhs.count("...") <= 1, f"too many ellipses in {formula}"

        # resolve ellipsis
        if "..." in lhs:
            # number of real dimensions the ellipsis stands for
            ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
            # replace "..." by the placeholder names "...0", "...1", ...
            lhs, rhs = map(lambda l: l[: (i := l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1 :] if "..." in l else l, (lhs, rhs))
            # shift group boundaries that lie after the expanded ellipsis
            unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
            flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]

        # apply movement ops in order unflatten -> permute -> flatten/unsqueeze
        t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
        for i, name in enumerate(lhs):
            assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
        t = t.permute([lhs.index(name) for name in rhs])
        # groups are processed right-to-left so earlier indices stay valid; an empty group inserts a size-1 axis
        return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0] < dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)

    # *** movement ops with expand ***

    def repeat_interleave(self, repeats: int, dim: int | None = None) -> Self:
        """
        Repeats elements of a tensor.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor([1, 2, 3])
        print(t.repeat_interleave(2).numpy())
        ```
        """
        # without `dim` the input is flattened first and repeated along axis 0
        x, dim = (self.flatten(), 0) if dim is None else (self, self._resolve_dim(dim))
        shp = x.shape
        # insert a unit axis right after `dim`, expand it to `repeats`, then merge it back into `dim`
        x = x.reshape(*shp[: dim + 1], 1, *shp[dim + 1 :])
        x = x.expand(*shp[: dim + 1], repeats, *shp[dim + 1 :])
        x = x.reshape(*shp[:dim], shp[dim] * repeats, *shp[dim + 1 :])
        return x

    def repeat(self, repeats, *args) -> Self:
        """
        Repeats tensor number of times along each dimension specified by `repeats`.
        `repeats` can be passed as a tuple or as separate arguments.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor([1, 2, 3])
        print(t.repeat(4, 2).numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.repeat(4, 2, 1).shape)
        ```
        """
        repeats = argfix(repeats, *args)
        base_shape = _align_left(self.shape, repeats)[0]
        # for every repeated dimension: add a unit axis in front, expand it to r, then merge (r, s) -> r * s
        unsqueezed_shape = flatten([[s] if r == 1 else [1, s] for r, s in zip(repeats, base_shape)])
        expanded_shape = flatten([[s] if r == 1 else [r, s] for r, s in zip(repeats, base_shape)])
        final_shape = [r * s for r, s in zip(repeats, base_shape)]
        return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape)

    # **** pool level ****

    def _pool(self, k_: tuple[sint, ...], stride: int | tuple[int, ...] = 1, dilation: int | tuple[int, ...] = 1) -> Self:
        """
        Build pooling windows over the trailing `len(k_)` dimensions using only movement ops.

        `k_` is the kernel size per pooled dimension; `stride` and `dilation` are
        broadcast to `len(k_)` entries with `make_tuple`.
        The result keeps the leading (non-pooled) dimensions, followed by the output
        sizes `o_` of every pooled dimension, followed by the kernel sizes `k_`.
        """
        assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
        s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_))
        assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
        # noop: placeholder entries for the leading dimensions, i_: sizes of the pooled dimensions
        noop, i_ = [None] * (self.ndim - len(k_)), self.shape[-len(k_) :]
        assert all(resolve(d * (k - 1) + 1 <= i) for k, d, i in zip(k_, d_, i_)), "kernel size cannot be greater than actual input size"
        o_ = [ceildiv(i - d * (k - 1), s) for i, d, k, s in zip(i_, d_, k_, s_)]
        # input size scaling factor to make sure shrink for stride is possible
        f_ = [smax(1, ceildiv(o * s - d, i)) for o, s, i, d in zip(o_, s_, i_, d_)]
        # repeats such that we don't need padding
        x = self.repeat([1] * len(noop) + [ceildiv(k * (i * f + d), i) for k, i, d, f in zip(k_, i_, d_, f_)])
        # handle dilation
        x = x.shrink_to(noop + [k * (i * f + d) for k, i, d, f in zip(k_, i_, d_, f_)])
        x = x.reshape(noop + flatten((k, (i * f + d)) for k, i, d, f in zip(k_, i_, d_, f_)))
        # handle stride
        x = x.shrink_to(noop + flatten((k, o * s) for k, o, s in zip(k_, o_, s_))).reshape(noop + flatten((k, o, s) for k, o, s in zip(k_, o_, s_)))
        x = x.shrink_to(noop + flatten((k, o, 1) for k, o in zip(k_, o_))).reshape(noop + flatten((k, o) for k, o in zip(k_, o_)))
        # permute to move reduce to the end
        return x.permute(*range(len(noop)), *[len(noop) + i * 2 + 1 for i in range(len(i_))], *[len(noop) + i * 2 for i in range(len(i_))])