openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

376 lines
16 KiB

# mixins add syntactic sugar to Tensor and UOp
import functools
from typing import TypeAlias, TYPE_CHECKING, Self
from tinygrad.uop import Ops
from tinygrad.helpers import prod, argfix, flatten, dedup, make_tuple, ceildiv
from tinygrad.uop.ops import resolve, smax
if TYPE_CHECKING:
from tinygrad.uop.ops import UOp
sint: TypeAlias = "UOp | int"
def _align_left(*shapes: tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
    """Left-pad every shape with 1s so that all of them have the length of the longest one.

    Shapes are returned in the order they were given.
    """
    target_len = max(map(len, shapes))
    padded = []
    for shp in shapes:
        missing = target_len - len(shp)
        # prepend the missing leading dimensions as size-1 axes
        padded.append((1,) * missing + shp)
    return tuple(padded)
class MovementMixin:
    """Movement (shape-changing) helpers built on top of `_mop` and `shape`.

    Subclasses implement `_mop` and `shape`; every other method in this class
    is expressed through those two plus the module-level helpers.
    """

    # required to implement
    def _mop(self, op: Ops, arg) -> Self:
        """Apply the movement op `op` with argument `arg`. Must be overridden."""
        raise NotImplementedError

    @property
    def shape(self) -> tuple[sint, ...]:
        """Shape of this object, one entry per dimension. Must be overridden."""
        raise NotImplementedError

    # great functions you get!
    @property
    def ndim(self) -> int:
        """
        Returns the number of dimensions in the tensor.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor([[1, 2], [3, 4]])
        print(t.ndim)
        ```
        """
        return len(self.shape)

    def numel(self) -> sint:
        """
        Returns the total number of elements in the tensor.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
        print(t.numel())
        ```
        """
        return prod(self.shape)

    def _resolve_dim(self, dim: int, *, extra: bool = False) -> int:
        """Validate `dim` and convert a negative index into its non-negative form.

        With `extra=True` one additional position is allowed (`ndim + 1` slots),
        which `unsqueeze` uses for the axis it inserts. The accepted range is
        never narrower than `[-1, 0]`, even when there are no dimensions.

        Raises `IndexError` when `dim` is outside the accepted range.
        """
        total = self.ndim + int(extra)
        limit = max(1, total)
        if not -limit <= dim <= limit - 1:
            raise IndexError(f"{dim=} out of range {[-limit, limit - 1]}")
        return dim + total if dim < 0 else dim

    def _broadcast_to(self, new_shape: tuple[sint, ...]) -> Self:
        """Broadcast to `new_shape` by left-padding the shape with 1s and expanding.

        Returns `self` when the shape already matches. Raises `ValueError` when
        `new_shape` has fewer dimensions, or when a dimension that is not 1
        would have to change.
        """
        if self.shape == new_shape:
            return self
        if self.ndim > len(new_shape):
            raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {new_shape=}")
        # first unsqueeze left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
        shape, _ = _align_left(self.shape, new_shape)
        # for each dimension, check either dim is 1, or it does not change
        if not all(s == ns or s == 1 for s, ns in zip(shape, new_shape)):
            raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
        reshaped = self.reshape(shape)
        ret = reshaped._mop(Ops.EXPAND, arg=new_shape)
        # skip the EXPAND node when it did not change the shape
        return reshaped if ret.shape == reshaped.shape else ret

    def expand(self, shape, *args) -> Self:
        """
        Returns a tensor that is expanded to the shape that is specified.
        Expand can also increase the number of dimensions that a tensor has.

        Passing a `-1` or `None` to a dimension means that its size will not be changed.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor([1, 2, 3])
        print(t.expand(4, -1).numpy())
        ```
        """
        current, target = _align_left(self.shape, argfix(shape, *args))
        new_shape = tuple(from_ if to == -1 or to is None else to for from_, to in zip(current, target))
        return self._broadcast_to(new_shape)

    def reshape(self, shape, *args) -> Self:
        """
        Returns a tensor with the same data as the original tensor but with a different shape.
        `shape` can be passed as a tuple or as separate arguments.

        A `None` entry keeps the current size at that position and a single `-1` entry is inferred
        from the remaining sizes. Raises `RuntimeError` for more than one `-1` and `ValueError`
        when the element counts do not match.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor.arange(6)
        print(t.reshape(2, 3).numpy())
        ```
        """
        # resolve None and args
        new_shape = tuple([s if s is not None else self.shape[i] for i, s in enumerate(argfix(shape, *args))])
        # resolve -1
        if (c := new_shape.count(-1)) > 1:
            raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
        if c:
            # prod(new_shape) still contains the -1, so negating the numerator yields the positive missing size
            new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
        if prod(self.shape) != prod(new_shape):
            raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})")
        ret = self._mop(Ops.RESHAPE, arg=new_shape)
        return self if ret.shape == self.shape else ret

    def shrink(self, arg: tuple[tuple[sint, sint] | None, ...]) -> Self:
        """
        Returns a tensor that shrinks each axis based on input arg.
        `arg` must have the same length as `self.ndim`.
        For each axis, it can be `None`, which means no shrink, or a tuple `(start, end)` that works the same as Python slice.

        Raises `ValueError` when `len(arg)` differs from `self.ndim`.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor.arange(9).reshape(3, 3)
        print(t.numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.shrink(((None, (1, 3)))).numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.shrink((((0, 2), (0, 2)))).numpy())
        ```
        """
        if self.ndim != len(arg):
            raise ValueError(f"{self.ndim=} != {len(arg)=}")
        # a None entry becomes the full (0, size) range of that axis
        ret = self._mop(Ops.SHRINK, arg=[x if x is not None else (0, s) for x, s in zip(arg, self.shape)])
        return self if ret.shape == self.shape else ret

    def permute(self, order, *args) -> Self:
        """
        Returns a tensor that is a permutation of the original tensor.
        The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified.
        `order` can be passed as a tuple or as separate arguments.

        Raises `RuntimeError` when `order` is not a permutation of all axes.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor.empty(2, 3, 5)
        print(t.shape)
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.permute(2, 0, 1).shape)
        ```
        """
        order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
        if sorted(order_arg) != list(range(self.ndim)):
            raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
        # the identity permutation returns self without creating an op
        return self._mop(Ops.PERMUTE, arg=order_arg) if order_arg != tuple(range(self.ndim)) else self

    def flip(self, axis, *args) -> Self:
        """
        Returns a tensor that reverses the order of the original tensor along given `axis`.
        `axis` can be passed as a tuple or as separate arguments.

        Raises `RuntimeError` when an axis is listed more than once.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor.arange(6).reshape(2, 3)
        print(t.numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.flip(0).numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.flip((0, 1)).numpy())
        ```
        """
        axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
        assert all(not isinstance(x, bool) and x >= 0 and x < self.ndim for x in axis_arg), f"flip args must be axis ints {axis_arg}"
        if len(axis_arg) != len(dedup(axis_arg)):
            raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
        # one boolean per axis: True means that axis is reversed
        flip_arg = tuple([i in axis_arg for i in range(len(self.shape))])
        return self._mop(Ops.FLIP, arg=flip_arg) if any(flip_arg) else self

    # **** high level ****
    def shrink_to(self, shape, *args) -> Self:
        """Shrink every axis to the range `(0, size)`; a `None` size leaves that axis untouched.

        `shape` can be passed as a tuple or as separate arguments.
        """
        return self.shrink(tuple([None if ns is None else (0, ns) for ns in argfix(shape, *args)]))

    def view(self, shape, *args) -> Self:
        """`.view` is an alias for `.reshape`."""
        return self.reshape(shape, *args)

    def squeeze(self, dim: int | None = None) -> Self:
        """
        Returns a tensor with specified dimensions of input of size 1 removed.
        If `dim` is not specified, all dimensions with size 1 are removed.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor.zeros(2, 1, 2, 1, 2)
        print(t.squeeze().shape)
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.squeeze(0).shape)
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.squeeze(1).shape)
        ```
        """
        if dim is None:
            return self.reshape(tuple(s for s in self.shape if s != 1))
        dim = self._resolve_dim(dim)
        # nothing to do when there are no dimensions or the selected one is not of size 1
        return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim + 1 :])

    def unsqueeze(self, dim: int) -> Self:
        """
        Returns a tensor with a new dimension of size 1 inserted at the specified `dim`.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor([1, 2, 3, 4])
        print(t.unsqueeze(0).numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.unsqueeze(1).numpy())
        ```
        """
        # extra=True: the new axis may also be placed after the last existing one
        dim = self._resolve_dim(dim, extra=True)
        return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])

    @property
    def T(self) -> Self:
        """`.T` is an alias for `.transpose()`."""
        return self.transpose()

    def transpose(self, dim0=1, dim1=0) -> Self:
        """
        Returns a tensor that is a transposed version of the original tensor.
        The given dimensions `dim0` and `dim1` are swapped.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor.arange(6).reshape(2, 3)
        print(t.numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.transpose(0, 1).numpy())
        ```
        """
        order = list(range(self.ndim))
        order[dim0], order[dim1] = order[dim1], order[dim0]
        return self.permute(order)

    def flatten(self, start_dim=0, end_dim=-1) -> Self:
        """
        Flattens the tensor by reshaping it into a one-dimensional tensor.
        If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor.arange(8).reshape(2, 2, 2)
        print(t.flatten().numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.flatten(start_dim=1).numpy())
        ```
        """
        start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
        return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim : end_dim + 1]),) + self.shape[end_dim + 1 :])

    def unflatten(self, dim: int, sizes: tuple[int, ...]) -> Self:
        """
        Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.

        `sizes` may be any sequence of ints; it is converted to a tuple before being spliced into the shape.

        ```python exec="true" source="above" session="tensor" result="python"
        print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
        ```
        """
        dim = self._resolve_dim(dim)
        # tuple(): a list passed as `sizes` would otherwise fail on tuple + list concatenation
        return self.reshape(self.shape[:dim] + tuple(sizes) + self.shape[dim + 1 :])

    def rearrange(self, formula: str, **sizes) -> Self:
        """
        Rearranges input according to formula

        See: https://einops.rocks/api/rearrange/

        `sizes` gives explicit sizes for named axes on the left-hand side; unspecified axes inside
        a group are inferred. Malformed formulas are rejected with `AssertionError`.

        ```python exec="true" source="above" session="tensor" result="python"
        x = Tensor([[1, 2], [3, 4]])
        print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy())
        ```
        """

        def parse_formula(formula: str):
            # returns (axis names without parentheses, [(start, end) index range of every parenthesized group])
            tokens = (
                f" {formula} ".replace("\u2026", "...")  # unicode ellipsis -> "...", written as an escape so it survives re-encoding
                .replace("(", " ( ")
                .replace(")", " ) ")
                .replace(" ", "  ")  # double the spaces so adjacent " 1 " tokens do not share a separator
                .replace(" 1 ", " ( ) ")  # a unit axis is written as an empty group
                .split()
            )
            lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
            pairs = list(zip(lparens, rparens))
            assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
            # subtract the parentheses that precede each group to get indices into the paren-free name list
            return [name for name in tokens if name not in ("(", ")")], [(s - 2 * i, e - 1 - 2 * i) for i, (s, e) in enumerate(pairs)]

        assert formula.count("->") == 1, 'need exactly one "->" in formula'
        (lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))

        for name in sizes:
            assert name in lhs, f"axis {name} is not used in transform"
        assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
        for name in flatten((lhs, rhs)):
            assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
        assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
        assert lhs.count("...") <= 1, f"too many ellipses in {formula}"

        # resolve ellipsis
        if "..." in lhs:
            # number of real dimensions the ellipsis stands for
            ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
            # replace "..." by one placeholder name per covered dimension
            lhs, rhs = map(
                lambda names: names[: (i := names.index("..."))] + [f"...{j}" for j in range(ell_len)] + names[i + 1 :] if "..." in names else names,
                (lhs, rhs),
            )
            # group boundaries located after the ellipsis move by the number of names it expanded into
            unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
            flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]

        # apply movement ops in order unflatten -> permute -> flatten/unsqueeze
        t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
        for i, name in enumerate(lhs):
            assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
        t = t.permute([lhs.index(name) for name in rhs])
        # go right-to-left so earlier group indices stay valid; an empty group means "insert a unit axis"
        return functools.reduce(
            lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0] < dims[1] else x.unsqueeze(dims[0]),
            reversed(flatten_dims),
            t,
        )

    # *** movement ops with expand ***
    def repeat_interleave(self, repeats: int, dim: int | None = None) -> Self:
        """
        Repeats elements of a tensor.

        When `dim` is `None` the tensor is flattened first and the repetition happens along the single resulting axis.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor([1, 2, 3])
        print(t.repeat_interleave(2).numpy())
        ```
        """
        x, dim = (self.flatten(), 0) if dim is None else (self, self._resolve_dim(dim))
        shp = x.shape
        # insert a unit axis right after `dim`, expand it to `repeats`, then merge it back into `dim`
        x = x.reshape(*shp[: dim + 1], 1, *shp[dim + 1 :])
        x = x.expand(*shp[: dim + 1], repeats, *shp[dim + 1 :])
        x = x.reshape(*shp[:dim], shp[dim] * repeats, *shp[dim + 1 :])
        return x

    def repeat(self, repeats, *args) -> Self:
        """
        Repeats tensor number of times along each dimension specified by `repeats`.
        `repeats` can be passed as a tuple or as separate arguments.

        ```python exec="true" source="above" session="tensor" result="python"
        t = Tensor([1, 2, 3])
        print(t.repeat(4, 2).numpy())
        ```
        ```python exec="true" source="above" session="tensor" result="python"
        print(t.repeat(4, 2, 1).shape)
        ```
        """
        repeats = argfix(repeats, *args)
        base_shape = _align_left(self.shape, repeats)[0]
        # for every repeated axis: add a unit axis in front, expand it to the repeat count, then fold it into the axis
        unsqueezed_shape = flatten([[s] if r == 1 else [1, s] for r, s in zip(repeats, base_shape)])
        expanded_shape = flatten([[s] if r == 1 else [r, s] for r, s in zip(repeats, base_shape)])
        final_shape = [r * s for r, s in zip(repeats, base_shape)]
        return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape)

    # **** pool level ****
    def _pool(self, k_: tuple[sint, ...], stride: int | tuple[int, ...] = 1, dilation: int | tuple[int, ...] = 1) -> Self:
        """Build sliding windows of size `k_` over the trailing `len(k_)` axes.

        `stride` and `dilation` may be a single int or one value per pooled axis.
        The result keeps the leading (non-pooled) axes, followed by one output
        axis per pooled axis, followed by one kernel axis per pooled axis.

        Raises `AssertionError` when there are more kernel dims than tensor dims,
        when stride/dilation lengths do not match the kernel, or when a dilated
        kernel is larger than the corresponding input axis.
        """
        assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
        s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_))
        assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
        noop = [None] * (self.ndim - len(k_))
        # slice from the front: `self.shape[-len(k_):]` would select every axis when `k_` is empty (-0 == 0)
        i_ = self.shape[len(noop) :]
        assert all(resolve(d * (k - 1) + 1 <= i) for k, d, i in zip(k_, d_, i_)), "kernel size cannot be greater than actual input size"
        # output size per pooled axis
        o_ = [ceildiv(i - d * (k - 1), s) for i, d, k, s in zip(i_, d_, k_, s_)]
        # input size scaling factor to make sure shrink for stride is possible
        f_ = [smax(1, ceildiv(o * s - d, i)) for o, s, i, d in zip(o_, s_, i_, d_)]
        # repeats such that we don't need padding
        x = self.repeat([1] * len(noop) + [ceildiv(k * (i * f + d), i) for k, i, d, f in zip(k_, i_, d_, f_)])
        # handle dilation
        x = x.shrink_to(noop + [k * (i * f + d) for k, i, d, f in zip(k_, i_, d_, f_)])
        x = x.reshape(noop + flatten((k, (i * f + d)) for k, i, d, f in zip(k_, i_, d_, f_)))
        # handle stride
        x = x.shrink_to(noop + flatten((k, o * s) for k, o, s in zip(k_, o_, s_))).reshape(noop + flatten((k, o, s) for k, o, s in zip(k_, o_, s_)))
        x = x.shrink_to(noop + flatten((k, o, 1) for k, o in zip(k_, o_))).reshape(noop + flatten((k, o) for k, o in zip(k_, o_)))
        # permute to move reduce to the end
        return x.permute(*range(len(noop)), *[len(noop) + i * 2 + 1 for i in range(len(i_))], *[len(noop) + i * 2 for i in range(len(i_))])