You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
376 lines
16 KiB
376 lines
16 KiB
# mixins add syntactic sugar to Tensor and UOp
|
|
import functools
|
|
from typing import TypeAlias, TYPE_CHECKING, Self
|
|
from tinygrad.uop import Ops
|
|
from tinygrad.helpers import prod, argfix, flatten, dedup, make_tuple, ceildiv
|
|
from tinygrad.uop.ops import resolve, smax
|
|
|
|
if TYPE_CHECKING:
|
|
from tinygrad.uop.ops import UOp
|
|
# A shape dimension: either a plain int or a UOp. Written as a string so that UOp only has to be
# imported under TYPE_CHECKING (it is never imported at runtime in this module).
sint: TypeAlias = "UOp | int"
|
|
|
|
|
|
def _align_left(*shapes: tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
    """
    Left-pad every shape with 1s so that all shapes end up with the same number of dimensions.

    Example: `_align_left((2, 3), (4,))` gives `((2, 3), (1, 4))`.
    """
    target_rank = max(map(len, shapes))
    padded = []
    for shp in shapes:
        # prepend as many size-1 axes as needed to reach the largest rank
        padded.append((1,) * (target_rank - len(shp)) + shp)
    return tuple(padded)
|
|
|
|
|
|
class MovementMixin:
|
|
# required to implement
|
|
def _mop(self, op: Ops, arg) -> Self:
|
|
raise NotImplementedError
|
|
|
|
@property
|
|
def shape(self) -> tuple[sint, ...]:
|
|
raise NotImplementedError
|
|
|
|
# great functions you get!
|
|
@property
|
|
def ndim(self) -> int:
|
|
"""
|
|
Returns the number of dimensions in the tensor.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([[1, 2], [3, 4]])
|
|
print(t.ndim)
|
|
```
|
|
"""
|
|
return len(self.shape)
|
|
|
|
def numel(self) -> sint:
|
|
"""
|
|
Returns the total number of elements in the tensor.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
|
|
print(t.numel())
|
|
```
|
|
"""
|
|
return prod(self.shape)
|
|
|
|
def _resolve_dim(self, dim: int, *, extra: bool = False) -> int:
|
|
total = self.ndim + int(extra)
|
|
if not -max(1, total) <= dim <= max(1, total) - 1:
|
|
raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total) - 1]}")
|
|
return dim + total if dim < 0 else dim
|
|
|
|
def _broadcast_to(self, new_shape: tuple[sint, ...]) -> Self:
|
|
if self.shape == new_shape:
|
|
return self
|
|
if self.ndim > len(new_shape):
|
|
raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {new_shape=}")
|
|
# first unsqueeze left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
|
|
shape, _ = _align_left(self.shape, new_shape)
|
|
# for each dimension, check either dim is 1, or it does not change
|
|
if not all(s == ns or s == 1 for s, ns in zip(shape, new_shape)):
|
|
raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
|
|
reshaped = self.reshape(shape)
|
|
ret = reshaped._mop(Ops.EXPAND, arg=new_shape)
|
|
return reshaped if ret.shape == reshaped.shape else ret
|
|
|
|
def expand(self, shape, *args) -> Self:
|
|
"""
|
|
Returns a tensor that is expanded to the shape that is specified.
|
|
Expand can also increase the number of dimensions that a tensor has.
|
|
|
|
Passing a `-1` or `None` to a dimension means that its size will not be changed.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([1, 2, 3])
|
|
print(t.expand(4, -1).numpy())
|
|
```
|
|
"""
|
|
new_shape = tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_align_left(self.shape, argfix(shape, *args)))))
|
|
return self._broadcast_to(new_shape)
|
|
|
|
def reshape(self, shape, *args) -> Self:
|
|
"""
|
|
Returns a tensor with the same data as the original tensor but with a different shape.
|
|
`shape` can be passed as a tuple or as separate arguments.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor.arange(6)
|
|
print(t.reshape(2, 3).numpy())
|
|
```
|
|
"""
|
|
# resolve None and args
|
|
new_shape = tuple([s if s is not None else self.shape[i] for i, s in enumerate(argfix(shape, *args))])
|
|
# resolve -1
|
|
if (c := new_shape.count(-1)) > 1:
|
|
raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
|
|
if c:
|
|
new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
|
|
if prod(self.shape) != prod(new_shape):
|
|
raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})")
|
|
ret = self._mop(Ops.RESHAPE, arg=new_shape)
|
|
return self if ret.shape == self.shape else ret
|
|
|
|
def shrink(self, arg: tuple[tuple[sint, sint] | None, ...]) -> Self:
|
|
"""
|
|
Returns a tensor that shrinks the each axis based on input arg.
|
|
`arg` must have the same length as `self.ndim`.
|
|
For each axis, it can be `None`, which means no shrink, or a tuple `(start, end)` that works the same as Python slice.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor.arange(9).reshape(3, 3)
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.shrink(((None, (1, 3)))).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.shrink((((0, 2), (0, 2)))).numpy())
|
|
```
|
|
"""
|
|
if self.ndim != len(arg):
|
|
raise ValueError(f"{self.ndim=} != {len(arg)=}")
|
|
ret = self._mop(Ops.SHRINK, arg=[x if x is not None else (0, s) for x, s in zip(arg, self.shape)])
|
|
return self if ret.shape == self.shape else ret
|
|
|
|
def permute(self, order, *args) -> Self:
|
|
"""
|
|
Returns a tensor that is a permutation of the original tensor.
|
|
The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified.
|
|
`order` can be passed as a tuple or as separate arguments.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor.empty(2, 3, 5)
|
|
print(t.shape)
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.permute(2, 0, 1).shape)
|
|
```
|
|
"""
|
|
order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
|
|
if sorted(order_arg) != list(range(self.ndim)):
|
|
raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
|
|
return self._mop(Ops.PERMUTE, arg=order_arg) if order_arg != tuple(range(self.ndim)) else self
|
|
|
|
def flip(self, axis, *args) -> Self:
|
|
"""
|
|
Returns a tensor that reverses the order of the original tensor along given `axis`.
|
|
`axis` can be passed as a tuple or as separate arguments.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor.arange(6).reshape(2, 3)
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.flip(0).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.flip((0, 1)).numpy())
|
|
```
|
|
"""
|
|
axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
|
|
assert all(not isinstance(x, bool) and x >= 0 and x < self.ndim for x in axis_arg), f"flip args must be axis ints {axis_arg}"
|
|
if len(axis_arg) != len(dedup(axis_arg)):
|
|
raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
|
|
flip_arg = tuple([i in axis_arg for i in range(len(self.shape))])
|
|
return self._mop(Ops.FLIP, arg=flip_arg) if any(flip_arg) else self
|
|
|
|
# **** high level ****
|
|
|
|
def shrink_to(self, shape, *args) -> Self:
|
|
return self.shrink(tuple([None if ns is None else (0, ns) for ns in argfix(shape, *args)]))
|
|
|
|
def view(self, shape, *args) -> Self:
|
|
"""`.view` is an alias for `.reshape`."""
|
|
return self.reshape(shape, *args)
|
|
|
|
def squeeze(self, dim: int | None = None) -> Self:
|
|
"""
|
|
Returns a tensor with specified dimensions of input of size 1 removed.
|
|
If `dim` is not specified, all dimensions with size 1 are removed.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor.zeros(2, 1, 2, 1, 2)
|
|
print(t.squeeze().shape)
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.squeeze(0).shape)
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.squeeze(1).shape)
|
|
```
|
|
"""
|
|
if dim is None:
|
|
return self.reshape(tuple(dim for dim in self.shape if dim != 1))
|
|
dim = self._resolve_dim(dim)
|
|
return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim + 1 :])
|
|
|
|
def unsqueeze(self, dim: int) -> Self:
|
|
"""
|
|
Returns a tensor with a new dimension of size 1 inserted at the specified `dim`.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([1, 2, 3, 4])
|
|
print(t.unsqueeze(0).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.unsqueeze(1).numpy())
|
|
```
|
|
"""
|
|
dim = self._resolve_dim(dim, extra=True)
|
|
return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
|
|
|
|
@property
|
|
def T(self) -> Self:
|
|
"""`.T` is an alias for `.transpose()`."""
|
|
return self.transpose()
|
|
|
|
def transpose(self, dim0=1, dim1=0) -> Self:
|
|
"""
|
|
Returns a tensor that is a transposed version of the original tensor.
|
|
The given dimensions `dim0` and `dim1` are swapped.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor.arange(6).reshape(2, 3)
|
|
print(t.numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.transpose(0, 1).numpy())
|
|
```
|
|
"""
|
|
order = list(range(self.ndim))
|
|
order[dim0], order[dim1] = order[dim1], order[dim0]
|
|
return self.permute(order)
|
|
|
|
def flatten(self, start_dim=0, end_dim=-1) -> Self:
|
|
"""
|
|
Flattens the tensor by reshaping it into a one-dimensional tensor.
|
|
If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor.arange(8).reshape(2, 2, 2)
|
|
print(t.flatten().numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.flatten(start_dim=1).numpy())
|
|
```
|
|
"""
|
|
start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
|
|
return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim : end_dim + 1]),) + self.shape[end_dim + 1 :])
|
|
|
|
def unflatten(self, dim: int, sizes: tuple[int, ...]) -> Self:
|
|
"""
|
|
Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
|
|
```
|
|
"""
|
|
dim = self._resolve_dim(dim)
|
|
return self.reshape(self.shape[:dim] + sizes + self.shape[dim + 1 :])
|
|
|
|
def rearrange(self, formula: str, **sizes) -> Self:
|
|
"""
|
|
Rearranges input according to formula
|
|
|
|
See: https://einops.rocks/api/rearrange/
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
x = Tensor([[1, 2], [3, 4]])
|
|
print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy())
|
|
```
|
|
"""
|
|
|
|
def parse_formula(formula: str):
|
|
tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", " ").replace(" 1 ", " ( ) ").split()
|
|
lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
|
|
pairs = list(zip(lparens, rparens))
|
|
assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
|
|
return [name for name in tokens if name not in ("(", ")")], [(s - 2 * i, e - 1 - 2 * i) for i, (s, e) in enumerate(pairs)]
|
|
|
|
assert formula.count("->") == 1, 'need exactly one "->" in formula'
|
|
|
|
(lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))
|
|
|
|
for name in sizes:
|
|
assert name in lhs, f"axis {name} is not used in transform"
|
|
assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
|
|
for name in flatten((lhs, rhs)):
|
|
assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
|
|
assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
|
|
assert lhs.count("...") <= 1, f"too many ellipses in {formula}"
|
|
|
|
# resolve ellipsis
|
|
if "..." in lhs:
|
|
ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
|
|
lhs, rhs = map(lambda l: l[: (i := l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1 :] if "..." in l else l, (lhs, rhs))
|
|
unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
|
|
flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]
|
|
|
|
# apply movement ops in order unflatten -> permute -> flatten/unsqueeze
|
|
t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
|
|
for i, name in enumerate(lhs):
|
|
assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
|
|
t = t.permute([lhs.index(name) for name in rhs])
|
|
return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0] < dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
|
|
|
|
# *** movement ops with expand ***
|
|
|
|
def repeat_interleave(self, repeats: int, dim: int | None = None) -> Self:
|
|
"""
|
|
Repeats elements of a tensor.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([1, 2, 3])
|
|
print(t.repeat_interleave(2).numpy())
|
|
```
|
|
"""
|
|
x, dim = (self.flatten(), 0) if dim is None else (self, self._resolve_dim(dim))
|
|
shp = x.shape
|
|
x = x.reshape(*shp[: dim + 1], 1, *shp[dim + 1 :])
|
|
x = x.expand(*shp[: dim + 1], repeats, *shp[dim + 1 :])
|
|
x = x.reshape(*shp[:dim], shp[dim] * repeats, *shp[dim + 1 :])
|
|
return x
|
|
|
|
def repeat(self, repeats, *args) -> Self:
|
|
"""
|
|
Repeats tensor number of times along each dimension specified by `repeats`.
|
|
`repeats` can be passed as a tuple or as separate arguments.
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
t = Tensor([1, 2, 3])
|
|
print(t.repeat(4, 2).numpy())
|
|
```
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
print(t.repeat(4, 2, 1).shape)
|
|
```
|
|
"""
|
|
repeats = argfix(repeats, *args)
|
|
base_shape = _align_left(self.shape, repeats)[0]
|
|
unsqueezed_shape = flatten([[s] if r == 1 else [1, s] for r, s in zip(repeats, base_shape)])
|
|
expanded_shape = flatten([[s] if r == 1 else [r, s] for r, s in zip(repeats, base_shape)])
|
|
final_shape = [r * s for r, s in zip(repeats, base_shape)]
|
|
return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape)
|
|
|
|
# **** pool level ****
|
|
|
|
def _pool(self, k_: tuple[sint, ...], stride: int | tuple[int, ...] = 1, dilation: int | tuple[int, ...] = 1) -> Self:
|
|
assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
|
|
s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_))
|
|
assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
|
|
noop, i_ = [None] * (self.ndim - len(k_)), self.shape[-len(k_) :]
|
|
assert all(resolve(d * (k - 1) + 1 <= i) for k, d, i in zip(k_, d_, i_)), "kernel size cannot be greater than actual input size"
|
|
o_ = [ceildiv(i - d * (k - 1), s) for i, d, k, s in zip(i_, d_, k_, s_)]
|
|
# input size scaling factor to make sure shrink for stride is possible
|
|
f_ = [smax(1, ceildiv(o * s - d, i)) for o, s, i, d in zip(o_, s_, i_, d_)]
|
|
# repeats such that we don't need padding
|
|
x = self.repeat([1] * len(noop) + [ceildiv(k * (i * f + d), i) for k, i, d, f in zip(k_, i_, d_, f_)])
|
|
# handle dilation
|
|
x = x.shrink_to(noop + [k * (i * f + d) for k, i, d, f in zip(k_, i_, d_, f_)])
|
|
x = x.reshape(noop + flatten((k, (i * f + d)) for k, i, d, f in zip(k_, i_, d_, f_)))
|
|
# handle stride
|
|
x = x.shrink_to(noop + flatten((k, o * s) for k, o, s in zip(k_, o_, s_))).reshape(noop + flatten((k, o, s) for k, o, s in zip(k_, o_, s_)))
|
|
x = x.shrink_to(noop + flatten((k, o, 1) for k, o in zip(k_, o_))).reshape(noop + flatten((k, o) for k, o in zip(k_, o_)))
|
|
# permute to move reduce to the end
|
|
return x.permute(*range(len(noop)), *[len(noop) + i * 2 + 1 for i in range(len(i_))], *[len(noop) + i * 2 for i in range(len(i_))])
|
|
|