openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

3876 lines
176 KiB

# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib
from contextlib import ContextDecorator
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
from tinygrad.multi import MultiLazyBuffer
from tinygrad.gradient import compute_gradient
from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element
from tinygrad.device import Device, Buffer, BufferSpec
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.memory import memory_planner
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
# **** start with two base classes, Tensor and Function ****
class Function:
  """
  Base class for autograd operations.
  Subclasses override `forward` (and `backward`); `apply` runs `forward` on the inputs' `lazydata`
  and wraps the result in a new Tensor.
  """
  def __init__(self, device:Union[str, tuple[str, ...]], *tensors:Tensor, metadata:Optional[Metadata]=None):
    self.device = device
    self.needs_input_grad = [t.requires_grad for t in tensors]
    # tri-state: True if any input needs a gradient, otherwise None if any input is undecided (None), else False
    if any(self.needs_input_grad): self.requires_grad = True
    elif None in self.needs_input_grad: self.requires_grad = None
    else: self.requires_grad = False
    # parents are only recorded when a gradient can flow back through this op
    if self.requires_grad: self.parents = tensors
    self.metadata = metadata

  def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
  def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")

  @classmethod
  def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
    """Instantiate `fxn` on the inputs `x` (device taken from x[0]), run forward, and return the resulting Tensor."""
    ctx = fxn(x[0].device, *x, metadata=_METADATA.get())
    ret = Tensor.__new__(Tensor)
    ret.lazydata = ctx.forward(*[t.lazydata for t in x], **kwargs)
    ret.requires_grad = ctx.requires_grad
    ret.grad = None
    # keep the graph edge only when gradients are wanted and autograd is not globally disabled
    ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None  # used by autograd engine
    return ret
import tinygrad.function as F
def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str, ...]], arg=None, src:tuple[UOp, ...]=()):
  """
  Create a meta UOp (`op` with `shape`/`dtype`/`arg`/`src`) on a single device.
  When `device` is a tuple, one UOp per device is created and wrapped in a MultiLazyBuffer with axis None.
  """
  if not isinstance(device, str):
    return MultiLazyBuffer([UOp.metaop(op, shape, dtype, dev, arg, src) for dev in device], None)
  return UOp.metaop(op, shape, dtype, device, arg, src)
def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
  """Return the tinygrad DType registered under numpy's name for `npdtype` (e.g. 'float32')."""
  import numpy as np
  np_name = np.dtype(npdtype).name
  return dtypes.fields()[np_name]
def _to_np_dtype(dtype:DType) -> Optional[type]:
import numpy as np
return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
def _fromnp(x: 'np.ndarray') -> UOp: # type: ignore [name-defined] # noqa: F821
  """Wrap numpy array `x` as a buffer UOp on the "NPY" device whose buffer is allocated directly from `x`."""
  buf_op = UOp.metaop(Ops.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
  # fake realize: the buffer is allocated from the array instead of being computed by a schedule
  buf_op.buffer.allocate(x)
  return buf_op.buf_uop_view()
def get_shape(x) -> tuple[int, ...]:
  """
  Infer the shape of a (possibly nested) Python sequence.
  Scalars, strings and objects with `shape == ()` have shape (); raises ValueError for ragged nesting.
  """
  # NOTE: str is special because __getitem__ on a str is still a str
  is_leaf = isinstance(x, str) or not (hasattr(x, "__len__") and hasattr(x, "__getitem__")) or (hasattr(x, "shape") and x.shape == ())
  if is_leaf: return ()
  child_shapes = [get_shape(item) for item in x]
  if not all_same(child_shapes): raise ValueError(f"inhomogeneous shape from {x}")
  return (len(child_shapes), *(child_shapes[0] if child_shapes else ()))
def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> UOp:
  """
  Build a buffer UOp on the "PYTHON" device holding `x`.
  `bytes` are used verbatim (length // itemsize elements); nested lists/tuples are flattened,
  passed through `truncate[dtype]` and packed with `struct` using `dtype.fmt`.
  """
  if isinstance(x, bytes):
    ret = UOp.metaop(Ops.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON")
    data = x
  else:
    ret = UOp.metaop(Ops.EMPTY, get_shape(x), dtype, "PYTHON")
    assert dtype.fmt is not None, f"{dtype=} has None fmt"
    truncate_fn = truncate[dtype]
    data = struct.pack(f"@{ret.size}{dtype.fmt}", *(truncate_fn(v) for v in fully_flatten(x)))
  # fake realize; on the PYTHON default device the payload is handed over as a mutable bytearray
  payload = bytearray(data) if Device.DEFAULT == "PYTHON" else data
  ret.buffer.allocate(memoryview(payload))
  return ret.buf_uop_view()
def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], device:Union[str, tuple[str, ...]], dtype:DType) -> list[list[Tensor]]:
  """
  For every axis `dim` < `dims` and every column index k of `mat`, build a constant Tensor holding column k of `mat`
  laid out along axis `dim`: its shape is `shp` with axis `dim` of length len(mat), constant along all other axes.
  Returns a nested list indexed as [dim][k].
  """
  def column(k:int, dim:int) -> Tensor:
    piece_shape = shp[:dim] + (1,) + shp[dim+1:]
    return Tensor.cat(*[Tensor.full(piece_shape, float(row[k]), device=device, dtype=dtype) for row in mat], dim=dim)
  return [[column(k, dim) for k in range(len(mat[0]))] for dim in range(dims)]
# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
  """
  Multiply the constant matrix `mat` into each of the first `dims` axes of `t` (roughly kron(mat, ..., mat) @ t).
  The first `dims` axes of `t` are indexed with range(len(mat[0])); the result has shape
  (len(mat),)*dims + t.shape[dims:].
  """
  # multiply mat_1 @ mat_2 @ t with foldable constants, where mat_i acts on vector t along dimension i; roughly kron(mat, mat) @ t
  # due to realize-before-expand rule in lazy.py, we must operate in this order: reshape -> expand -> arithmetic
  t_ = t.reshape(t.shape[:dims] + (1,) * dims + t.shape[dims:]).expand(t.shape[:dims] + (len(mat),) * dims + t.shape[dims:]) # add output dims
  # precalculate mat columns for each dim; prod(itertools.product(matcols)) gives the columns of kron(mat, mat, ...)
  matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device, t_.dtype)
  # multiply each element of t_ by the corresponding stacked column of kron(mat, mat), producing only one view for each element of t
  # builtin sum() starts from the int 0, so the result is only a Tensor if at least one term was added (checked below)
  ret = sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims))
  assert isinstance(ret, Tensor), "sum didn't return a Tensor"
  return ret
def _align_left(*shapes:tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
  """Left-pad every shape with 1s so that all shapes have the rank of the longest one."""
  rank = max(map(len, shapes))
  return tuple((1,) * (rank - len(s)) + s for s in shapes)
def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]:
  """
  Combine shapes after left-aligning them: a dim is 0 if any input has 0 there, else `smax` of the sizes.
  NOTE: compatibility (e.g. 2 vs 3) is not validated here; callers must check it.
  """
  def merge(sizes): return 0 if 0 in sizes else smax(sizes)
  return tuple(merge(sizes) for sizes in zip(*_align_left(*shapes)))
def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]):
  """
  Select `values` where `mask` is True and `target` elsewhere, after collapsing each axis in `axes`.
  Along every axis in `axes` the size-1 slices are folded left to right: the masks are OR-ed together and a
  later slice's value overrides an earlier one wherever the later mask is True, so for repeated indices the
  last one wins. Returns `mask.where(values, target)` on the collapsed tensors.
  """
  # apply mask to values (already broadcasted) and reduce such that if mask contains repeated indices the last one remains
  values = values * mask
  for dim in axes: mask, values = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), values.split(1, dim)))
  # remove extra dims from reduce
  # (reversed so that squeezing a higher axis does not shift the index of a lower axis still to be squeezed)
  for dim in reversed(axes): mask, values = mask.squeeze(dim), values.squeeze(dim)
  # select from values for each True element in mask else select from self
  return mask.where(values, target)
# accepted values for a `reduction=` argument; presumably used by loss functions later in this file — TODO confirm
ReductionStr = Literal["mean", "sum", "none"]
class Tensor(SimpleMathTrait):
"""
A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
```python exec="true" session="tensor"
from tinygrad import Tensor, dtypes, nn
import numpy as np
import math
np.set_printoptions(precision=4)
```
"""
__slots__ = "lazydata", "requires_grad", "grad", "_ctx"
# NOTE(review): presumably for mypyc compilation, to allow `del t._ctx` on a slotted attribute — confirm
__deletable__ = ('_ctx',)
# class-wide training flag, toggled by the nested `train` context manager
training: ClassVar[bool] = False
# class-wide flag toggled by the nested `test` context manager; when True, Function.apply does not record _ctx
no_grad: ClassVar[bool] = False
def __init__(self, data:Union[None, ConstType, bytes, List, Tuple, UOp, MultiLazyBuffer, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
             device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
  """
  Create a Tensor from `data`, dispatching on its Python type:
  - UOp / MultiLazyBuffer: used as-is (`dtype`, if given, must equal data.dtype); a BIND UOp becomes a CONST
  - None: an EMPTY tensor of shape (0,)
  - Python scalar (ConstType): a CONST with dtype inferred by dtypes.from_py unless `dtype` is given
  - bytes: a 1-D buffer, uint8 unless `dtype` is given
  - list/tuple: dtype inferred as bool, default_int or default_float when not given; bfloat16 goes through float32
  - numpy.ndarray: 0-d arrays become a CONST, others are wrapped via _fromnp (after astype when `dtype` maps to numpy)
  - pathlib.Path: an EMPTY tensor on a DISK device named after the resolved path (uint8 unless `dtype` is given)
  `device` may be a single device or a tuple/list of devices; the latter produces a MultiLazyBuffer.
  Raises RuntimeError for unsupported `data` types.
  """
  if dtype is not None: dtype = to_dtype(dtype)
  if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
  device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)

  # tensors can have gradients if you have called .backward
  self.grad: Optional[Tensor] = None

  # NOTE: this can be in three states. False and None: no gradient, True: gradient
  # None (the default) will be updated to True if it's put in an optimizer
  self.requires_grad: Optional[bool] = requires_grad

  # internal variable used for autograd graph construction
  self._ctx: Optional[Function] = None

  # create a LazyBuffer from the different types of inputs
  if isinstance(data, (UOp, MultiLazyBuffer)):
    assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
    # NOTE: this is here because LazyBuffer = UOp
    if isinstance(data, UOp) and data.op is Ops.BIND: data = _metaop(Ops.CONST, tuple(), dtype or data.dtype, device, data)
  elif data is None: data = _metaop(Ops.EMPTY, (0,), dtype or dtypes.default_float, device)
  elif isinstance(data, get_args(ConstType)): data = _metaop(Ops.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
  elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
  elif isinstance(data, (list, tuple)):
    if dtype is None:
      if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
      else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float  # NOTE: this works because all_int([True, False]) is True
    # bfloat16 lists are built as float32 on `device`, then cast
    if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata
    else: data = _frompy(data, dtype)
  # compared by type name so numpy is only imported when a numpy array is actually passed
  elif str(type(data)) == "<class 'numpy.ndarray'>":
    import numpy as np
    assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}"
    if data.shape == (): data = _metaop(Ops.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
    else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data)  # type: ignore [name-defined]
  elif isinstance(data, pathlib.Path):
    dtype = dtype or dtypes.uint8
    data = _metaop(Ops.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}")

  # by this point, it has to be a LazyBuffer
  if not isinstance(data, (UOp, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")

  # data might be on a different device
  if isinstance(device, str): self.lazydata:Union[UOp, MultiLazyBuffer] = data if data.device == device else data.copy_to_device(device)
  # if device is a tuple, we should have/construct a MultiLazyBuffer
  elif isinstance(data, UOp): self.lazydata = MultiLazyBuffer.from_sharded(data, device, None, None)
  else:
    assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
    self.lazydata = data
def requires_grad_(self, requires_grad=True) -> Tensor:
  """Set `requires_grad` on this tensor in place and return the tensor itself so calls can be chained."""
  self.requires_grad = requires_grad
  return self
class train(ContextDecorator):
  """Context manager / decorator that sets `Tensor.training` to `mode` and restores the previous value on exit."""
  def __init__(self, mode:bool = True): self.mode = mode
  def __enter__(self):
    self.prev = Tensor.training
    Tensor.training = self.mode
  def __exit__(self, exc_type, exc_value, traceback):
    Tensor.training = self.prev
class test(ContextDecorator):
  """Context manager / decorator that sets `Tensor.no_grad` to `mode` and restores the previous value on exit."""
  def __init__(self, mode:bool = True): self.mode = mode
  def __enter__(self):
    self.prev = Tensor.no_grad
    Tensor.no_grad = self.mode
  def __exit__(self, exc_type, exc_value, traceback):
    Tensor.no_grad = self.prev
def __repr__(self):
  """Debug representation showing the underlying UOp/MultiLazyBuffer, device and the gradient's lazydata."""
  ld = self.lazydata
  if isinstance(ld, MultiLazyBuffer):
    ld_repr = f"{self.lazydata!r}"
  else:
    # view tracker for non-base UOps, otherwise the op and its realized buffer
    detail = ld.st if ld.base is not ld else (ld.op, ld.realized)
    ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]} {detail}>"
  grad_ld = self.grad.lazydata if self.grad is not None else None
  return f"<Tensor {ld_repr} on {self.device} with grad {grad_ld!r}>"
# Python has a non moving GC, so this should be okay
def __hash__(self):
  """Identity hash: id(self) is constant for the object's lifetime."""
  return id(self)
def __bool__(self):
  """A Tensor has no truth value; this always raises TypeError."""
  raise TypeError("__bool__ on Tensor is not defined")
def __len__(self):
  """Size of the first dimension; raises TypeError for 0-d tensors."""
  shape = self.shape
  if len(shape) == 0: raise TypeError("len() of a 0-d tensor")
  return shape[0]
@property
def device(self) -> Union[str, tuple[str, ...]]:
  """Device of the underlying lazydata: a name, or a tuple of names for a MultiLazyBuffer."""
  return self.lazydata.device
@property
def shape(self) -> tuple[sint, ...]:
  """Shape of the underlying lazydata (dims may be symbolic `sint`)."""
  return self.lazydata.shape
@property
def dtype(self) -> DType:
  """Data type of the underlying lazydata."""
  return self.lazydata.dtype
# ***** data handlers ****
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]:
  """
  Creates the schedule needed to realize these Tensor(s), with Variables.
  NOTE: A Tensor can only be scheduled once.
  Returns the memory-planned schedule items and the Variable bindings.
  """
  tensors = (self, *lst)
  buffers = flatten([t.lazydata.lbs for t in tensors])
  schedule, var_vals = create_schedule_with_vars(buffers)
  return memory_planner(schedule), var_vals
def schedule(self, *lst:Tensor) -> list[ScheduleItem]:
  """Creates the schedule needed to realize these Tensor(s). Asserts that no Variables are bound."""
  sched, bound_vars = self.schedule_with_vars(*lst)
  assert len(bound_vars) == 0
  return sched
def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
  """Triggers the computation needed to create these Tensor(s). Returns self."""
  schedule, var_vals = self.schedule_with_vars(*lst)
  run_schedule(schedule, var_vals, do_update_stats=do_update_stats)
  return self
def replace(self, x:Tensor) -> Tensor:
  """
  Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match.
  `x` must not require grad and this tensor must not carry an autograd context. Returns self.
  """
  # used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
  has_ctx = getattr(self, '_ctx', None) is not None
  assert not x.requires_grad and not has_ctx
  assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
  self.lazydata = x.lazydata
  return self
def assign(self, x) -> Tensor:
  """
  Assign the contents of `x` to this tensor and return self.
  A non-Tensor `x` is first converted to a Tensor with this tensor's dtype. On DISK the bytes are copied straight into
  the realized disk buffer. Otherwise shape, device, dtype (and MultiLazyBuffer axis) must match, and an unrealized
  target is simply `replace`d.
  """
  # TODO: this is a hack for writing to DISK. remove with working assign
  if isinstance(self.device, str) and self.device.startswith("DISK"):
    # stage non-Tensor values on CLANG, then copy their raw bytes into the realized disk buffer (no shape/dtype checks here)
    if x.__class__ is not Tensor: x = Tensor(x, device="CLANG", dtype=self.dtype)
    self.contiguous().realize().lazydata.base.realized.copyin(x._data())
    return self
  if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
  if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
  if self.lazydata is x.lazydata: return self  # a self assign is a NOOP
  # NOTE: we allow cross device assign
  # NOTE(review): the device assert below rejects cross device assign, contradicting the note above — confirm intent
  assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
  assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
  assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
  assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
  assert not x.requires_grad  # self requires_grad is okay?
  # nothing realized yet, so there is no existing buffer to write into: just swap the lazydata
  if not self.lazydata.is_realized: return self.replace(x)
  self.lazydata = self.lazydata.assign(x.lazydata)
  return self
def detach(self) -> Tensor:
  """
  Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
  The result is on the same device and has requires_grad=False.
  """
  detached = self.lazydata.detach()
  return Tensor(detached, device=self.device, requires_grad=False)
def _data(self) -> memoryview:
  """
  Return the raw bytes of this tensor as a memoryview, realizing a contiguous CLANG version first.
  Tensors with a 0 in their shape return an empty memoryview without realizing anything.
  """
  if 0 in self.shape: return memoryview(bytearray(0))
  # NOTE: this realizes on the object from as_buffer being a Python object
  cpu = self.cast(self.dtype.base).contiguous().to("CLANG").realize()
  buf = cast(Buffer, cast(UOp, cpu.lazydata).base.realized)
  copied_from_other_device = self.device != "CLANG"
  # for non-CLANG sources the CLANG buffer was created just for this read: mark it nolru and allow zero-copy access
  if copied_from_other_device: buf.options = BufferSpec(nolru=True)
  return buf.as_buffer(allow_zero_copy=copied_from_other_device)
def data(self) -> memoryview:
  """
  Returns the data of this tensor as a memoryview.
  The view is cast to the dtype's struct format and, unless the shape contains 0, to the tensor's shape.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3, 4])
  print(np.frombuffer(t.data(), dtype=np.int32))
  ```
  """
  fmt = self.dtype.base.fmt
  assert fmt is not None, f"no fmt dtype for {self.dtype.base}"
  assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
  # presumably because memoryview only supports the half-float "e" format from Python 3.12 on
  if TYPE_CHECKING or sys.version_info < (3, 12): assert fmt != "e"
  raw = self._data()
  # memoryview.cast rejects shapes containing 0, so empty tensors stay flat
  view = raw.cast(fmt) if 0 in self.shape else raw.cast(fmt, self.shape)
  return cast(memoryview, view)
def item(self) -> ConstType:
  """
  Returns the value of this tensor as a standard Python number.
  The tensor must contain exactly one element.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor(42)
  print(t.item())
  ```
  """
  assert self.numel() == 1, "must have one element for item"
  first_index = (0,) * len(self.shape)
  return self.data()[first_index]
# TODO: should be Tensor.tolist() -> Union[list[ConstType], ConstType]. The List is Sequence because mypy expects memoryview.tolist() -> list[int]
# src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
def tolist(self) -> Union[Sequence[ConstType], ConstType]:
  """
  Returns the value of this tensor as a nested list (via memoryview.tolist on `data()`).
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3, 4])
  print(t.tolist())
  ```
  """
  view = self.data()
  return view.tolist()
def numpy(self) -> 'np.ndarray': # type: ignore [name-defined] # noqa: F821
  """
  Returns the value of this tensor as a `numpy.ndarray`.
  bfloat16 tensors are converted with `.float()` first.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3, 4])
  print(repr(t.numpy()))
  ```
  """
  import numpy as np
  base = self.dtype.base
  # numpy has no native bfloat16 dtype
  if base == dtypes.bfloat16: return self.float().numpy()
  np_dtype = _to_np_dtype(base)
  assert np_dtype is not None, f"no np dtype for {self.dtype.base}"
  assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
  return np.frombuffer(self._data(), dtype=np_dtype).reshape(self.shape)
def clone(self) -> Tensor:
  """
  Creates a clone of this tensor allocating a separate buffer for the data.
  The gradient (if any) is cloned too; the autograd context `_ctx` is shared with the original.
  """
  copy = Tensor(self.lazydata.clone(), self.device, requires_grad=self.requires_grad)
  if self.grad is not None: copy.grad = self.grad.clone()
  if hasattr(self, '_ctx'): copy._ctx = self._ctx
  return copy
def to(self, device:Optional[Union[str, tuple[str, ...]]]) -> Tensor:
  """
  Moves the tensor to the given device.
  Returns self if it is already there; a tuple/list of devices delegates to `shard` (with no axis).
  The gradient is moved along with the data and `_ctx` is carried over.
  """
  if isinstance(device, (tuple, list)): device = tuple(Device.canonicalize(d) for d in device)
  else: device = Device.canonicalize(device)
  if device == self.device: return self
  if not isinstance(device, str): return self.shard(device)
  moved = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
  if self.grad is not None: moved.grad = self.grad.to(device)
  if hasattr(self, '_ctx'): moved._ctx = self._ctx
  return moved
def to_(self, device:Optional[Union[str, tuple[str, ...]]]) -> Tensor:
  """
  Moves the tensor to the given device in place and returns self, like `shard_` and `requires_grad_`, so calls can chain.
  The gradient's lazydata is updated too when both this tensor and the moved tensor have a gradient.
  """
  real = self.to(device)
  # TODO: is this assign?
  if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata
  self.lazydata = real.lazydata
  return self
def shard(self, devices:tuple[str, ...], axis:Optional[int]=None, splits:Optional[tuple[int, ...]]=None) -> Tensor:
  """
  Shards the tensor across the given devices. Optionally specify which axis to shard on, and how to split it across devices.
  Without `axis`, `splits` is ignored and no bounds are passed on. Without `splits`, the axis is cut into chunks of
  ceil(size / len(devices)); trailing chunks may be smaller or empty.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.empty(2, 3)
  print(t.shard((t.device, t.device), axis=1, splits=(2, 1)).lazydata)
  ```
  """
  assert isinstance(self.lazydata, UOp), "can't shard a MultiLazyBuffer"
  devices = tuple(Device.canonicalize(x) for x in devices)
  bounds = None
  if axis is not None:
    axis = self._resolve_dim(axis)
    if splits is None:
      total = self.shape[axis]
      if not isinstance(total, int): raise RuntimeError(f"cannot shard symbolic shape {self.shape=}, {axis=}")
      chunk = ceildiv(total, len(devices))
      splits = tuple(max(0, min(chunk, total - chunk*i)) for i in range(len(devices)))
    assert sum(splits) == self.shape[axis], "specified splits do not sum up to axis shape"
    # running totals 0, s0, s0+s1, ... paired up give each device's (start, end) along the axis
    bounds = tuple(itertools.pairwise(itertools.accumulate(splits, initial=0)))
  return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, devices, axis, bounds), device=devices, requires_grad=self.requires_grad)
def shard_(self, devices:tuple[str, ...], axis:Optional[int]=None, splits:Optional[tuple[int, ...]]=None):
  """Shards the tensor across the given devices in place (same arguments as `shard`) and returns self."""
  sharded = self.shard(devices, axis, splits)
  self.lazydata = sharded.lazydata
  return self
@staticmethod
def from_uop(y:UOp, **kwargs) -> Tensor:
  """
  Rebuild a Tensor expression from a UOp tree whose leaves are BIND or CONST and whose inner nodes are MUL, ADD or MAX.
  `kwargs` (e.g. `device`) are forwarded to every leaf, so all operands of a compound expression are created
  consistently instead of only a bare leaf honouring them. Leaves always have requires_grad=False.
  Raises RuntimeError for any other UOp.
  """
  if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False)  # this is the only UOp allowed in Tensor
  if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
  if y.op is Ops.MUL: return Tensor.from_uop(y.src[0], **kwargs) * Tensor.from_uop(y.src[1], **kwargs)
  if y.op is Ops.ADD: return Tensor.from_uop(y.src[0], **kwargs) + Tensor.from_uop(y.src[1], **kwargs)
  if y.op is Ops.MAX: return Tensor.from_uop(y.src[0], **kwargs).maximum(Tensor.from_uop(y.src[1], **kwargs))
  raise RuntimeError(f"unhandled UOp {y}")
# ***** creation entrypoint *****
@staticmethod
def _metaop(op, shape, device:Optional[Union[tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs):
  """
  Build a Tensor around a fresh meta UOp (`op`, `shape`, `arg`); dtype defaults to dtypes.default_float.
  A tuple `device` yields a MultiLazyBuffer (axis None) holding one UOp per device. Extra kwargs go to the constructor.
  """
  dtype = dtypes.default_float if dtype is None else to_dtype(dtype)
  if not isinstance(device, tuple):
    return Tensor(UOp.metaop(op, shape, dtype, Device.canonicalize(device), arg), device, dtype, **kwargs)
  per_device = [UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device]
  return Tensor(MultiLazyBuffer(per_device, None), device, dtype, **kwargs)
@staticmethod
def empty(*shape, **kwargs):
  """
  Creates an empty tensor with the given shape.
  You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
  Additionally, all other keyword arguments are passed to the constructor of the tensor.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.empty(2, 3)
  print(t.shape)
  ```
  """
  dims = argfix(*shape)
  return Tensor._metaop(Ops.EMPTY, dims, **kwargs)
@staticmethod
def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor:
  """
  Exposes the pointer as a Tensor without taking ownership of the original data.
  The pointer must remain valid for the entire lifetime of the created Tensor.
  You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
  Additionally, all other keyword arguments are passed to the constructor of the tensor.
  """
  blob = Tensor._metaop(Ops.EMPTY, shape, **kwargs)
  blob.lazydata.buffer.allocate(external_ptr=ptr)
  # NOTE(review): the returned view is discarded; kept in case the call has side effects — confirm it is needed
  blob.lazydata.buf_uop_view()
  return blob
@staticmethod
def from_url(url:str, gunzip:bool=False, **kwargs) -> Tensor:
  """
  Create a Tensor from a URL.
  This is the preferred way to access Internet resources.
  It currently returns a DISK Tensor, but in the future it may return an HTTP Tensor.
  This also will soon become lazy (when possible) and not print progress without DEBUG.
  The `gunzip` flag will gzip extract the resource and return an extracted Tensor.
  """
  fetched = fetch(url, gunzip=gunzip)
  return Tensor(fetched, **kwargs)
# global RNG seed: the wall-clock time when the class body runs, replaced by manual_seed
_seed: int = int(time.time())
# per-device threefry key tensors, created by rand the first time a device is used
_device_seeds: dict[str, Tensor] = {}
# per-device uint32 counters, created at 0 and advanced by rand on each later call for that device
_device_rng_counters: dict[str, Tensor] = {}
@staticmethod
def manual_seed(seed=0):
  """
  Sets the seed for random operations.
  This also clears the per-device RNG keys and counters, so rand re-derives them from the new seed.
  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.rand(5).numpy())
  print(Tensor.rand(5).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)  # reset to the same seed
  print(Tensor.rand(5).numpy())
  print(Tensor.rand(5).numpy())
  ```
  """
  Tensor._seed = seed
  Tensor._device_seeds = {}
  Tensor._device_rng_counters = {}
@staticmethod
def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor):
  """
  Apply Threefry to the 64-bit counters (counts1 << 32) | counts0, keyed by (key[1] << 32) | key[0].
  Returns the low 32-bit words followed by the high 32-bit words (both uint32), concatenated.
  """
  def as_u64(t:Tensor) -> Tensor: return t.cast(dtypes.uint64)
  counters = (as_u64(counts1) << 32) | as_u64(counts0)
  key64 = (as_u64(key[1]._broadcast_to(counters.shape)) << 32) | as_u64(key[0]._broadcast_to(counters.shape))
  mixed = F.Threefry.apply(counters, key64)
  low = (mixed & 0xffffffff).cast(dtypes.uint32)
  high = ((mixed >> 32) & 0xffffffff).cast(dtypes.uint32)
  return low.cat(high)
@staticmethod
def rand(*shape, device:Optional[str]=None, dtype:Optional[DTypeLike]=None, contiguous:bool=True, **kwargs) -> Tensor:
  """
  Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
  You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
  Other keyword arguments are only forwarded to `Tensor.zeros` when the shape has a 0; otherwise only
  `requires_grad` is read from them. Raises ValueError for non-float dtypes, negative/symbolic shapes or multiple devices.
  Bits come from Threefry with a per-device key (derived from `Tensor._seed`) and a per-device counter that advances
  on every later call for that device.
  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.rand(2, 3)
  print(t.numpy())
  ```
  """
  if not dtypes.is_float(dtype := to_dtype(dtype or dtypes.default_float)): raise ValueError(f"rand only supports float dtypes, got {dtype}")
  if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
  if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
  # _device keeps the requested device; `device` may be redirected below
  _device = device = Device.canonicalize(device)

  # if shape has 0, return zero tensor
  if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
  # number of 32-bit random words needed to cover numel * itemsize bytes
  num = ceildiv(numel * dtype.itemsize, 4)

  # when using MOCKGPU and NV generate rand on CLANG
  if getenv("MOCKGPU") and device.startswith("NV"): device = "CLANG"

  # generate per device seeds and rng counter if we haven't seen this device yet
  if device not in Tensor._device_seeds:
    # key = [sha256 of this device's first-use index, global seed]
    Tensor._device_seeds[device] = Tensor(
      [int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed],
      device=device, dtype=dtypes.uint32, requires_grad=False)
    Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False)
  # increment rng counter for devices
  # NOTE(review): the trailing .contiguous() result is discarded
  else: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()

  # threefry random bits
  counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._device_rng_counters[device])
  counts1 = counts0 + ceildiv(num, 2)
  bits = Tensor._threefry_random_bits(Tensor._device_seeds[device], counts0, counts1)[:num]

  # bitcast to uint with same number of bits
  _, nmant = dtypes.finfo(dtype)
  uint_dtype = {1: dtypes.uint8, 2: dtypes.uint16, 4: dtypes.uint32, 8: dtypes.uint64}[dtype.itemsize]
  bits = bits.bitcast(uint_dtype)
  # only randomize the mantissa bits and set the exponent to 1
  # (keep the top nmant random bits as the mantissa and OR in the bit pattern of 1.0, giving floats in [1, 2))
  one = Tensor.ones_like(bits, device=bits.device, dtype=dtype).bitcast(uint_dtype)
  bits = bits.rshift((dtype.itemsize * 8) - nmant).bitwise_or(one)
  # bitcast back to the original dtype and reshape
  # (subtracting 1 maps [1, 2) onto [0, 1))
  out = bits.bitcast(dtype)[:numel].sub(1).reshape(shape)

  # move back to the original device if we were using MOCKGPU
  if getenv("MOCKGPU") and _device: out = out.to(_device)

  out.requires_grad = kwargs.get("requires_grad")
  return out.contiguous() if contiguous else out
# ***** creation helper functions *****
@staticmethod
def full(shape:tuple[sint, ...], fill_value:ConstType, **kwargs) -> Tensor:
  """
  Creates a tensor with the given shape, filled with the given value.
  You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
  Additionally, all other keyword arguments are passed to the constructor of the tensor.
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.full((2, 3), 42).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.full((2, 3), False).numpy())
  ```
  """
  scalar = Tensor(fill_value, **kwargs)
  new_shape = argfix(shape)
  # give the scalar one size-1 dim per target dim, then expand to the target shape
  return scalar.reshape((1,) * len(new_shape)).expand(new_shape)
@staticmethod
def zeros(*shape, **kwargs) -> Tensor:
  """
  Creates a tensor with the given shape, filled with zeros.
  You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
  Additionally, all other keyword arguments are passed to the constructor of the tensor.
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.zeros(2, 3).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.zeros(2, 3, dtype=dtypes.int32).numpy())
  ```
  """
  dims = argfix(*shape)
  return Tensor.full(dims, 0.0, **kwargs)
@staticmethod
def ones(*shape, **kwargs) -> Tensor:
  """
  Creates a tensor with the given shape, filled with ones.
  You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
  Additionally, all other keyword arguments are passed to the constructor of the tensor.
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.ones(2, 3).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.ones(2, 3, dtype=dtypes.int32).numpy())
  ```
  """
  dims = argfix(*shape)
  return Tensor.full(dims, 1.0, **kwargs)
@staticmethod
def arange(start, stop=None, step=1, **kwargs) -> Tensor:
  """
  Returns a 1-D tensor of size `ceil((stop - start) / step)` with values from `[start, stop)`, with spacing between values given by `step`.
  If `stop` is not specified, values are generated from `[0, start)` with the given `step`.
  If `stop` is specified, values are generated from `[start, stop)` with the given `step`.
  The default dtype is default_float if any of start/stop/step is a float, else default_int.
  You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
  Additionally, all other keyword arguments are passed to the constructor of the tensor.
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.arange(5).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.arange(5, 10).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.arange(5, 10, 2).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.arange(5.5, 10, 2).numpy())
  ```
  """
  if stop is None: start, stop = 0, start
  has_float = any(isinstance(v, float) for v in (start, stop, step))
  dtype = kwargs.pop("dtype", dtypes.default_float if has_float else dtypes.default_int)
  # NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs
  output_len = ceildiv(stop-start, step)
  if output_len <= 0: return Tensor([], dtype=dtype, **kwargs)
  # presumably a cumulative ADD over a tensor of `step`s (step, 2*step, ...); shifting by start-step gives start, start+step, ...
  steps = Tensor.full((output_len,), step, dtype=dtype, **kwargs)
  return (steps._cumalu(0, Ops.ADD) + (start - step)).cast(dtype)
@staticmethod
def linspace(start:Union[int, float], stop:Union[int, float], steps:int, **kwargs) -> Tensor:
  """
  Returns a 1-D tensor of `steps` evenly spaced values from `start` to `stop`, inclusive.
  Raises ValueError for negative `steps` or a bool dtype.
  You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
  Additionally, all other keyword arguments are passed to the constructor of the tensor.
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.linspace(0, 10, 5).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.linspace(-1, 1, 5).numpy())
  ```
  """
  if steps < 0: raise ValueError("number of steps must be non-negative")
  dtype = to_dtype(kwargs.pop("dtype", dtypes.default_float))
  if dtype == dtypes.bool: raise ValueError("linspace with bool dtype is not supported")
  # a single step would divide by zero below, so it is special-cased to just [start]
  if steps == 1: return Tensor([start], dtype=dtype, **kwargs)
  positions = Tensor.arange(steps, **kwargs)
  return (start + positions * ((stop - start) / (steps - 1))).cast(dtype)
@staticmethod
def eye(n:int, m:Optional[int]=None, **kwargs) -> Tensor:
  """
  Returns a 2-D tensor with `n` rows and `m` columns, with ones on the diagonal and zeros elsewhere.
  `m` defaults to `n`; negative sizes raise ValueError.
  You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
  Additionally, all other keyword arguments are passed to the constructor of the tensor.
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.eye(3).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.eye(2, 4).numpy())
  ```
  """
  if n < 0 or (m is not None and m < 0): raise ValueError(f"cannot have negative {n=}, {m=}")
  # an (n, 1) column of ones padded to (n, n+1) and flattened puts each 1 at offset i*(n+1); keeping the first
  # n*n entries and reshaping to (n, n) lands those ones at (i, i)
  square = Tensor.ones((n,1),**kwargs).pad((None,(0,n))).flatten().shrink(((0,n*n),)).reshape(n,n)
  if m is None: return square
  if m > n: return square.pad((None, (0, m-n)))
  return square.shrink((None, (0, m)))
def full_like(self, fill_value:ConstType, **kwargs) -> Tensor:
  """
  Creates a tensor with the same shape as `self`, filled with the given value.
  If `dtype` is not specified, the dtype of `self` is used.
  You can pass in the `device` keyword argument to control device of the tensor (defaults to `self.device`).
  Additionally, all other keyword arguments are passed to the constructor of the tensor.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.ones(2, 3)
  print(Tensor.full_like(t, 42).numpy())
  ```
  """
  dtype = kwargs.pop("dtype", self.dtype)
  device = kwargs.pop("device", self.device)
  return Tensor.full(self.shape, fill_value, dtype=dtype, device=device, **kwargs)
def zeros_like(self, **kwargs) -> Tensor:
  """
  Returns a tensor shaped like `self` and filled with `0`.
  `dtype` / `device` keyword arguments override the ones inherited from `self`; everything else reaches the tensor constructor.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.ones(2, 3)
  print(Tensor.zeros_like(t).numpy())
  ```
  """
  return self.full_like(fill_value=0, **kwargs)
def ones_like(self, **kwargs) -> Tensor:
  """
  Returns a tensor shaped like `self` and filled with `1`.
  `dtype` / `device` keyword arguments override the ones inherited from `self`; everything else reaches the tensor constructor.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.zeros(2, 3)
  print(Tensor.ones_like(t).numpy())
  ```
  """
  return self.full_like(fill_value=1, **kwargs)
def rand_like(self, **kwargs) -> Tensor:
  """
  Creates a tensor with the same shape and sharding as `self`, filled with random values from a uniform distribution over the interval `[0, 1)`.
  You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
  Additionally, all other keyword arguments are passed to the constructor of the tensor.
  For a multi-device tensor, `device` must be omitted (or `None`): the result follows the sharding of `self`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.ones(2, 3)
  print(Tensor.rand_like(t).numpy())
  ```
  """
  dtype = kwargs.pop("dtype", self.dtype)
  if isinstance(self.device, tuple) and isinstance(self.lazydata, MultiLazyBuffer):
    # pop rather than get: an explicit `device=None` must not be forwarded below, where `device=` is also passed
    # explicitly to Tensor.rand / Tensor (that would raise "multiple values for keyword argument 'device'")
    if kwargs.pop("device", None) is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor")
    # unsharded (replicated) multi tensor: draw once, then copy to every device
    if self.lazydata.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device)
    contiguous = kwargs.pop("contiguous", True)
    # sharded along an axis: draw an independent random buffer per shard, on that shard's device
    rands = [Tensor.rand(*lb.shape, device=lb.device, dtype=dtype, contiguous=contiguous, **kwargs).lazydata for lb in self.lazydata.lbs]
    return Tensor(MultiLazyBuffer(cast(list[UOp], rands), self.lazydata.axis), device=self.device, dtype=dtype, **kwargs)
  return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs)
# ***** rng hlops *****
@staticmethod
def randn(*shape, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
  """
  Returns a tensor of the given shape whose entries are drawn from the standard normal distribution (mean `0`, standard deviation `1`).
  When `dtype` is omitted, the default float type is used.
  The `device` keyword argument selects the device; other keyword arguments are forwarded to the tensor constructor.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.randn(2, 3).numpy())
  ```
  """
  # Box-Muller transform: https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
  # both uniform streams come from a single float32 rand call, stacked along a leading axis of size 2
  uniforms = Tensor.rand((2, *argfix(*shape)), **{**kwargs, "dtype": dtypes.float32})
  cos_term = uniforms[0].mul(2*math.pi).cos()
  radius = (1 - uniforms[1]).log().mul(-2).sqrt()
  return cos_term.mul(radius).cast(dtype or dtypes.default_float).requires_grad_(requires_grad)
@staticmethod
def randint(*shape, low=0, high=10, dtype=dtypes.int32, **kwargs) -> Tensor:
  """
  Returns a tensor of the given shape filled with integers drawn uniformly from the half-open interval `[low, high)`.
  `dtype` must be an integer type and defaults to `dtypes.int32`.
  The `device` keyword argument selects the device; other keyword arguments are forwarded to the tensor constructor.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.randint(2, 3, low=5, high=10).numpy())
  ```
  """
  if not (isinstance(low, int) and isinstance(high, int)): raise TypeError(f"{low=} and {high=} must be integers")
  if not dtypes.is_int(dtype := to_dtype(dtype)): raise TypeError(f"{dtype=} must be int")
  return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
@staticmethod
def normal(*shape, mean=0.0, std=1.0, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
  """
  Returns a tensor of the given shape sampled from a normal distribution with mean `mean` and standard deviation `std`.
  `dtype` and `device` keyword arguments pick the data type and device; other keyword arguments are forwarded to the tensor constructor.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.normal(2, 3, mean=10, std=2).numpy())
  ```
  """
  # scale and shift standard normal samples
  standard = Tensor.randn(*shape, **kwargs)
  shifted = (std * standard) + mean
  return shifted.requires_grad_(requires_grad)
@staticmethod
def uniform(*shape, low=0.0, high=1.0, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None, **kwargs) -> Tensor:
  """
  Returns a tensor of the given shape sampled uniformly from the half-open interval `[low, high)`.
  `dtype` and `device` keyword arguments pick the data type and device; other keyword arguments are forwarded to the tensor constructor.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.uniform(2, 3, low=2, high=10).numpy())
  ```
  """
  width = high - low
  # stretch [0, 1) samples to the interval width, cast, then move them to start at `low`
  scaled = (width * Tensor.rand(*shape, **kwargs)).cast(dtype or dtypes.default_float)
  return (scaled + low).requires_grad_(requires_grad)
@staticmethod
def scaled_uniform(*shape, **kwargs) -> Tensor:
  """
  Returns a tensor of the given shape sampled uniformly from
  `[-prod(shape)**-0.5, prod(shape)**-0.5)`, i.e. `[-1, 1)` scaled by the inverse square root of the element count.
  `dtype` and `device` keyword arguments pick the data type and device; other keyword arguments are forwarded to the tensor constructor.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.scaled_uniform(2, 3).numpy())
  ```
  """
  samples = Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs)
  numel = prod(argfix(*shape))
  return samples.mul(numel**-0.5)
# https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
@staticmethod
def glorot_uniform(*shape, **kwargs) -> Tensor:
  """
  Glorot (Xavier) uniform initialization, see <https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform>.
  Samples `[-1, 1)` scaled by `sqrt(6 / (shape[0] + prod(shape[1:])))`.
  `dtype` and `device` keyword arguments pick the data type and device; other keyword arguments are forwarded to the tensor constructor.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.glorot_uniform(2, 3).numpy())
  ```
  """
  samples = Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs)
  dims = argfix(*shape)
  limit = (6/(dims[0]+prod(dims[1:])))**0.5
  return samples.mul(limit)
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
@staticmethod
def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
  """
  Kaiming (He) uniform initialization, see <https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_>.
  Samples `[-bound, bound)` with `bound = sqrt(3) * sqrt(2 / (1 + a**2)) / sqrt(prod(shape[1:]))`.
  `dtype` and `device` keyword arguments pick the data type and device; other keyword arguments are forwarded to the tensor constructor.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.kaiming_uniform(2, 3).numpy())
  ```
  """
  gain = math.sqrt(2.0 / (1 + a ** 2))
  fan_in = prod(argfix(*shape)[1:])
  bound = math.sqrt(3.0) * gain / math.sqrt(fan_in)
  return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
@staticmethod
def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
  """
  Kaiming (He) normal initialization, see <https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_>.
  Samples a zero-mean normal with `std = sqrt(2 / (1 + a**2)) / sqrt(prod(shape[1:]))`.
  `dtype` and `device` keyword arguments pick the data type and device; other keyword arguments are forwarded to the tensor constructor.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  print(Tensor.kaiming_normal(2, 3).numpy())
  ```
  """
  gain = math.sqrt(2.0 / (1 + a ** 2))
  std = gain / math.sqrt(prod(argfix(*shape)[1:]))
  return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor:
  """
  Draws category indices with probabilities proportional to the weights in `self` (presumably non-negative; not checked here).
  `self` must be 1-D (a single distribution) or 2-D (one distribution per row).
  Sampling without replacement is only supported for `num_samples == 1`.
  Returns an `int32` tensor of indices with shape `(num_samples,)` for 1-D input, or `(rows, num_samples)` for 2-D input.
  """
  assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
  assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"
  # treat a 1-D input as a batch with a single row
  weight = self.unsqueeze(0) if self.ndim == 1 else self
  # per-row cumulative weights divided by the row total, so the last column of the cdf is 1
  cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
  # one uniform draw per (sample, row); the trailing axis of size 1 is broadcast against the cdf columns
  unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1).to(self.device)
  # inverse-cdf sampling: the chosen index is the count of cdf entries <= the draw; then move rows to the front
  indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
  return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
# ***** toposort and backward pass *****
def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None) -> list[Tensor]:
  """
  Compute the gradient of `self` with respect to each tensor in `targets`, returned in the same order.
  If `gradient` is not provided, `self` must be a scalar and the incoming gradient is taken to be 1.
  Raises RuntimeError if no gradient was computed for one of the targets.

  ```python exec="true" source="above" session="tensor" result="python"
  x = Tensor.eye(3)
  y = Tensor([[2.0,0,-2.0]])
  z = y.matmul(x).sum()
  dx, dy = z.gradient(x, y)
  print(dx.tolist()) # dz/dx
  print(dy.tolist()) # dz/dy
  ```
  """
  assert isinstance(self.lazydata, UOp), "multi isn't supported yet"
  # NOTE(review): targets whose lazydata is not a UOp (e.g. multi-device) are silently dropped here, so the returned
  # list can be shorter than `targets` and no longer line up with it positionally -- confirm this is intended
  target_uops: list[UOp] = [x.lazydata for x in targets if isinstance(x.lazydata, UOp)]
  assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
  # seed with a constant 1 (built via const_like) when no gradient is given, otherwise with the caller's gradient
  grads = compute_gradient(self.lazydata, self.lazydata.const_like(1) if gradient is None else cast(UOp, gradient.lazydata), target_uops)
  ret = []
  for x in target_uops:
    if (y:=grads.get(x)) is None: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}")
    ret.append(Tensor(y, device=x.device))
  return ret
def _deepwalk(self):
  """
  Depth-first, post-order walk of the autograd graph rooted at `self`.
  Returns only the tensors that carry a `_ctx`, each listed after all of its parents that were reached through it.
  As a side effect, `.grad` is reset to None on every visited tensor whose `_ctx` has at least one parent.
  """
  order, seen = [], set()
  def visit(node):
    seen.add(node)
    ctx = getattr(node, "_ctx", None)
    # if tensor is not leaf, reset grad
    if ctx is not None and len(ctx.parents) != 0: node.grad = None
    if not ctx: return  # leaves are visited but not reported
    for parent in ctx.parents:
      if parent not in seen: visit(parent)
    order.append(node)
  visit(self)
  return order
def backward(self, gradient:Optional[Tensor]=None, retain_graph:bool=False) -> Tensor:
  """
  Propagates the gradient of a tensor backwards through the computation graph.
  If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0.
  If 'retain_graph' is false, the graph used to compute the grads will be freed. Otherwise, it will be kept. Keeping it can increase memory usage.
  Gradients are added into the existing `.grad` of each tensor with `requires_grad` set (non-leaf grads are cleared first).
  Returns `self`. Raises RuntimeError if a tensor on the path ends up without a gradient.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
  t.sum().backward()
  print(t.grad.numpy())
  ```
  """
  # non-leaf tensors (those with a _ctx) in post-order; _deepwalk also clears their stale .grad
  toposorted = self._deepwalk()
  if gradient is None:
    assert self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
    # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
    # this is "implicit gradient creation"
    gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
  assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}"
  self.grad = gradient
  # reversed post-order: every consumer is processed before the tensors it was computed from
  for t0 in reversed(toposorted):
    if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
    # presumably tags ops created by this backward step with the forward op's metadata, marked backward=True
    token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := t0._ctx.metadata) is not None else None)
    grads = t0._ctx.backward(t0.grad.lazydata)
    _METADATA.reset(token)
    # a single-parent Function returns one gradient instead of a sequence; normalize, then wrap as Tensors
    grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
             for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
    for t, g in zip(t0._ctx.parents, grads):
      if g is not None and t.requires_grad:
        assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
        # accumulate: a tensor used by several ops receives the sum of their contributions
        t.grad = g if t.grad is None else (t.grad + g)
    if not retain_graph: del t0._ctx
  return self
# ***** movement low level ops *****
def view(self, *shape) -> Tensor:
  """
  `.view` is an alias for `.reshape`.
  The shape may be given as separate arguments (`t.view(2, 3)`) or as a single tuple/list (`t.view((2, 3))`).
  """
  # unwrap a single positional argument so `view((2, 3))` reaches reshape as (2, 3) instead of ((2, 3),),
  # which reshape's argfix would otherwise keep as a one-element shape containing a tuple
  return self.reshape(shape[0] if len(shape) == 1 else shape)
def reshape(self, shape, *args) -> Tensor:
  """
  Returns a tensor holding the same data as `self` laid out in a new shape.
  The shape may be passed as one tuple or as separate arguments. A `None` entry keeps the size of the
  corresponding existing dimension, and at most one `-1` entry is inferred from the remaining sizes.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(6)
  print(t.reshape(2, 3).numpy())
  ```
  """
  requested = argfix(shape, *args)
  new_shape = tuple(self.shape[i] if s is None else s for i, s in enumerate(requested))
  inferred = new_shape.count(-1)
  if inferred > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
  if inferred == 1:
    # new_shape still contains the -1, so prod(new_shape) is negative (given positive other sizes);
    # negating the numerator makes the floor division yield the positive missing size
    new_shape = tuple(-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape)
  if new_shape == self.shape: return self
  return F.Reshape.apply(self, shape=new_shape)
def expand(self, shape, *args) -> Tensor:
  """
  Broadcasts `self` to the requested shape, possibly adding new leading dimensions.
  An entry of `-1` or `None` leaves the size of that (left-aligned) dimension unchanged.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3])
  print(t.expand(4, -1).numpy())
  ```
  """
  # left-pad the shorter of the two shapes with 1s so they line up dimension by dimension
  current, target = _align_left(self.shape, argfix(shape, *args))
  new_shape = tuple(cur if want == -1 or want is None else want for cur, want in zip(current, target))
  return self._broadcast_to(new_shape)
def permute(self, order, *args) -> Tensor:
  """
  Returns a view of `self` whose dimensions are reordered according to `order`.
  `order` may be one tuple or separate arguments, and negative axes are allowed.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(6).reshape(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.permute(1, 0).numpy())
  ```
  """
  order_arg = tuple(map(self._resolve_dim, argfix(order, *args)))
  is_permutation = sorted(order_arg) == list(range(self.ndim))
  if not is_permutation: raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
  return F.Permute.apply(self, order=order_arg)
def flip(self, axis, *args) -> Tensor:
  """
  Reverses the element order of `self` along every axis listed in `axis`.
  `axis` may be one tuple or separate arguments; each axis may be named at most once.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(6).reshape(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.flip(0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.flip((0, 1)).numpy())
  ```
  """
  axis_arg = tuple(map(self._resolve_dim, argfix(axis, *args)))
  if len(set(axis_arg)) != len(axis_arg): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
  return F.Flip.apply(self, axis=axis_arg)
def shrink(self, arg:tuple[Optional[tuple[sint, sint]], ...]) -> Tensor:
  """
  Returns a tensor that keeps only a sub-range of each axis, as described by `arg`.
  `arg` has one entry per dimension: `None` keeps the whole axis, and a tuple `(start, end)` keeps
  `start:end` like a Python slice.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(9).reshape(3, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.shrink(((None, (1, 3)))).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.shrink((((0, 2), (0, 2)))).numpy())
  ```
  """
  shrink_arg = [(0, s) if x is None else x for x, s in zip(arg, self.shape)]
  # nothing to do when every axis keeps its full extent
  if shrink_arg == [(0, s) for s in self.shape]: return self
  return F.Shrink.apply(self, arg=tuple(shrink_arg))
def pad(self, padding:Union[Sequence[sint], Sequence[Optional[tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor:
  """
  Returns a tensor with padding applied based on the input `padding`.

  `padding` supports two padding structures:

  1. Flat padding: (padding_left, padding_right, padding_top, padding_bottom, ...)
      - This structure matches PyTorch's pad.
      - Pairs apply to dimensions starting from the last one; unmentioned leading dimensions are not padded.
      - `padding` length must be even.

  2. Group padding: (..., (padding_top, padding_bottom), (padding_left, padding_right))
      - This structure matches pad for jax, numpy, tensorflow and others.
      - For each axis, padding can be `None`, meaning no padding, or a tuple `(start, end)`.
      - `padding` must have the same length as `self.ndim`.

  Padding values can be negative, resulting in dimension shrinks that work similarly to Python negative slices.
  The padding mode is selected with `mode`, which supports `constant`, `reflect`, `replicate` and `circular`;
  `value` is the fill value for `constant` mode. `reflect` requires each pad to be smaller than the dimension size,
  and `circular` rejects negative pads and pads larger than the dimension (wrapping around more than once).

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(9).reshape(1, 1, 3, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.pad((1, 2, 0, -1)).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.pad(((None, None, (0, -1), (1, 2)))).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.pad((1, 2, 0, -1), value=-float('inf')).numpy())
  ```
  """
  if mode not in {"constant", "reflect", "replicate", "circular"}: raise NotImplementedError(f"{mode=} is not supported")
  if (flat:=all(isinstance(p, (int,UOp)) for p in padding)) and len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads")
  # turn flat padding into group padding
  pX = ((0,0),)*(self.ndim - len(padding)//2) + tuple(zip(padding[-2::-2], padding[::-2])) if flat else padding
  if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
  X, pX = self, cast(tuple[tuple[sint, sint]], tuple((0,0) if p is None else p for p in pX))
  # non-negative part of every pad; negative pads are handled as shrinks
  pads = tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
  if mode == "constant":
    # a nonzero fill adds `v` in the padded region, found by padding a ones tensor (pad region is 0 there)
    def _constant(x,px,v): return F.Pad.apply(x, arg=px) if v == 0 else F.Pad.apply(x, arg=px) + F.Pad.apply(Tensor.ones_like(x), arg=px).where(0,v)
    # with any negative pad: shrink those sides first, then pad with only the non-negative amounts
    return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \
           _constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value)
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
  if mode == "circular":
    if any(pB>sh or pA>sh for (pB,pA),sh in zip(pX, X.shape)): raise ValueError('Padding value causes wrapping around more than once.')
    if any(pB<0 or pA<0 for pB,pA in pX): raise NotImplementedError("Negative pads with circular pads is not supported")
    # tile one extra copy on each padded side, then cut the window [osh-pB, 2*osh+pA) style out of the tiling
    orig_shape, X = X.shape, X.repeat(tuple(1 + bool(pB) + bool(pA) for pB,pA in pads))
    return X.shrink(tuple((0 if pB == 0 else osh-pB, xsh if pA == 0 else xsh-osh+pA) for (pB,pA),osh,xsh in zip(pads, orig_shape, X.shape)))
  # reflect / replicate: build the before (xB) and after (xA) pieces per dim and concatenate them around X
  for d,(pB,pA) in enumerate(pads):
    if mode == "reflect":
      if pB >= (s:=X.shape[d]) or pA>=s: raise ValueError(f"Padding ({pB}, {pA}) should be less than the input size={s} for dim={d}.")
      # mirror without repeating the edge element: X[pB:0:-1] before, X[s-2:s-2-pA:-1] after
      slcB, slcA, = slice(pB,0,-1), slice(s-2 if s-2>=0 else None, s-2-pA if s-2-pA>=0 else None, -1)
      xB, xA = (X[[slc if i == d else slice(None) for i in range(X.ndim)]] if p > 0 else None for slc, p in ((slcB, pB), (slcA, pA)))
    if mode == "replicate":
      # repeat the first / last slice along d, pB / pA times
      shrB, shrA, = tuple((0,1) if i==d else None for i in range(X.ndim)), tuple((X.shape[i]-1,X.shape[i]) if i==d else None for i in range(X.ndim))
      xB, xA = (X.shrink(shr).expand(tuple(p if i==d else None for i in range(X.ndim))) if p > 0 else None for shr, p in ((shrB, pB), (shrA, pA)))
    X = Tensor.cat(*(X_ for X_ in (xB, X, xA) if X_ is not None), dim=d)
  # apply any negative pads as shrinks on the already padded tensor
  return X.shrink(tuple((-min(pB,0), min(pA+s,s)) for (pB,pA),s in zip(pX, X.shape)))
# ***** movement high level ops *****
# Supported Indexing Implementations:
# 1. Int indexing (no copy)
# - for all dims where there's int, shrink -> reshape
# - negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
# - X = Tensor.rand(4,5,9); X[2,-2] shrinks the Tensor to X.shrink(((2, 3), (3, 4), (0, 9))) -> X.shape=(1,1,9)
# - Then we reshape (collapse) the int dim away such that for X: (1,1,9) -> (9,)
# 2. Slice indexing (no copy)
# - for all dims where slice is start:end:stride, shrink -> Optional[flip] -> pad -> reshape -> shrink
# - first shrink the Tensor to X.shrink(((start, end),))
# - then we apply stride through Optional[flip] -> pad -> reshape -> shrink
# - flip the dims whose stride is negative
# - pad on dims to be multiple of strides, such that reshaping [dim_size_padded] -> [dim_size_padded // stride, stride] is possible
# - shrink [dim_size_padded // stride, stride] -> [dim_size_padded // stride, 1]
# - reshape [dim_size_padded // stride, 1] -> [dim_size_padded // stride] and now you have your stride
# 3. None indexing (no copy)
# - reshape (inject) a dim at the dim where there's None
# 4. Tensor indexing (copy)
# - use Tensor.arange == tensor_index to create masks for dims with Tensors (adds a dim for each mask)
# - combine masks together with mul
# - apply mask to self by mask * self
# - sum reduce away the extra dims added from creating masks
# Tiny Things:
# 1. Supported indices: Union[int, slice, Tensor, None, List, Tuple, Ellipsis]
# - for any list, list[Union[List, Tuple, int]], must have homogeneous shape
# - for any tuple, tuple[Union[List, Tuple, int]], must have homogeneous shape
# 2. Bool indexing is not supported
# 3. Out of bounds Tensor indexing results in 0
# - e.g: Tensor([1, 2, 3])[Tensor([4, 3, 2])] -> [0, 0, 3] index 4 and 3 are out of bounds
def _getitem(self, indices, v: Optional[Tensor] = None) -> Tensor:
# wrap single index into a list
if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices]
# turn scalar Tensors into const val for int indexing if possible
x, indices = self, [self._to_const_val(i) if isinstance(i, Tensor) and i.shape == () else i for i in indices]
# filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
if len(ellipsis_idx := [dim for dim, i in enumerate(indices) if i is Ellipsis]) > 1: raise IndexError("indices can only have a single ellipsis")
fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
num_indices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}")
indices[fill_idx:fill_idx+1] = [slice(None)] * (self.ndim - num_indices)
indices_parsed, dim = [], 0
for index in indices:
size = 1 if index is None else self.shape[dim]
boundary, stride = [0, size], 1 # defaults
match index:
case list() | tuple() | Tensor():
if not isinstance(index, Tensor): index = Tensor(index, self.device, requires_grad=False)
if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported")
index = (index.to(self.device) < 0).where(size, 0) + index # treat negative index values
case int() | UOp(): # sint
if index >= size or index < -size: raise IndexError(f"{index=} is out of bounds with {size=}")
boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]
case slice():
if index.step == 0: raise ValueError(f"{index=} cannot have 0 as step")
if not all(isinstance(s,int) or s is None for s in (index.start,index.stop,index.step)): raise TypeError("only int slicing is supported")
# handle int slicing
*boundary, stride = index.indices(cast(SupportsIndex, size))
if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0]
elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1]
# update size for slice
size = ceildiv((boundary[1] - boundary[0]), abs(stride))
case None: pass # do nothing
case _: raise IndexError(f"{type(index).__name__} indexing is not supported")
indices_parsed.append({"index":index, "size":size, "boundary":tuple(boundary), "stride":stride})
if index is not None: dim += 1
# movement op indexing
if mops := [i for i in indices_parsed if i['index'] is not None]:
# flip negative strides
shrinks, strides = zip(*((i['boundary'], i['stride']) for i in mops))
x = x.shrink(shrinks).flip(tuple(i for i,st in enumerate(strides) if st < 0))
# handle stride != 1 or -1
if any(abs(st) != 1 for st in strides):
strides = tuple(abs(s) for s in strides)
# pad shape to multiple of stride
if not all_int(x.shape): raise RuntimeError("symbolic shape not supprted")
x = x.pad(tuple((0, round_up(s, st) - s) for s, st in zip(x.shape, strides)))
x = x.reshape(tuple(flatten((s // st, st) for s, st in zip(x.shape, strides))))
x = x.shrink(tuple(flatten(((0, s), (0, 1)) for s in x.shape[::2]))).reshape(x.shape[::2])
# dim injection from None by including None dim size (which is 1) and dim collapse by skipping int dim size
x = x.reshape(tuple(index['size'] for index in indices_parsed if not isinstance(index['index'], int)))
# tensor indexing
if tops := [(d,i) for d,i in enumerate(i_ for i_ in indices_parsed if not isinstance(i_['index'], int)) if isinstance(i['index'], Tensor)]:
# unload the tensor object into actual tensors
dims, tensors, masks = [d for d,_ in tops], cast(list[Tensor], [i['index'] for _,i in tops]), []
pre_reduce_shape = x.shape[:dims[0]] + (big_shape := _broadcast_shape(*(t.shape for t in tensors))) + x.shape[dims[0]:]
# create index masks
for dim, tensor in zip(dims, tensors):
try: i = tensor.reshape(tensor.shape + (1,)*(x.ndim - dims[0])).expand(pre_reduce_shape)
except ValueError as e: raise IndexError(f"cannot broadcast indices: {e}") from e
masks.append(i._one_hot_along_dim(num_classes=x.shape[dim], dim=(dim - x.ndim)))
# reduce masks to 1 mask
mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks)
# inject 1's for the extra dims added in create masks
reshape_arg = x.shape[:dims[0]] + (1,) * len(big_shape) + x.shape[dims[0]:]
# sum reduce the extra dims introduced in create masks
x = (x.reshape(reshape_arg) * mask).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), acc_dtype=x.dtype)
# special permute case
if dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1)):
x = x.permute(*range(dims[0], dims[0]+len(big_shape)), *range(0, dims[0]), *range(dims[0]+len(big_shape), x.ndim))
# for advanced setitem, returns whole tensor with indices replaced
if v is not None:
vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(x.shape, v.shape))
# add back reduced dims from sum
for dim in sum_axis: vb = vb.unsqueeze(dim)
# run _masked_setitem on tuple of axis that is to be reduced to match self.shape
x = _masked_setitem(self, vb, mask, tuple(range(dims[0], dims[0] + len(big_shape))))
return x
def __getitem__(self, indices) -> Tensor:
  """Indexes `self` with int, slice, None, Ellipsis, list/tuple or Tensor indices (delegates to `_getitem` without a value)."""
  return self._getitem(indices, v=None)
def __setitem__(self, indices, v:Union[Tensor, ConstType]) -> None:
  """
  Writes `v` into the positions of `self` selected by `indices`.
  Outside the DISK path, the write is realized immediately and a scalar `v` is first turned into a Tensor of `self.dtype`.

  Raises:
    RuntimeError: if any underlying buffer of `self` is not contiguous.
    TypeError: if `v` is not a Tensor, float, int or bool.
    NotImplementedError: if `self` or `v` requires grad.
  """
  # DISK tensors: assign straight into the indexed view, without the checks below
  if isinstance(self.device, str) and self.device.startswith("DISK"):
    self._getitem(indices).assign(v)
    return
  # NOTE: check that setitem target is valid first
  if not all(unwrap(lb.st).contiguous for lb in self.lazydata.lbs): raise RuntimeError("setitem target needs to be contiguous")
  if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
  if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
  if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
  # with tensor indices _getitem returns the whole updated tensor; otherwise it returns the indexed view
  res = self.realize()._getitem(indices, v)
  # if shapes match and data is not shared it's a copy and we assign to self
  if res.shape == self.shape and res.lazydata is not self.lazydata:
    self.assign(res).realize()
  else:  # no copy, basic setitem
    v = v.cast(res.dtype)._broadcast_to(_broadcast_shape(res.shape, v.shape)).contiguous()
    res.assign(v).realize()
def gather(self:Tensor, dim:int, index:Tensor) -> Tensor:
  """
  Collects values from `self` along axis `dim`, at the positions given by `index`.
  `index` must have the same number of dimensions as `self`, and may not be larger than `self` in any other dimension.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[1, 2], [3, 4]])
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy())
  ```
  """
  assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
  dim = self._resolve_dim(dim)
  assert all(s >= i for d, (s, i) in enumerate(zip(self.shape, index.shape)) if d != dim), \
    "requires self.shape[d] >= index.shape[d] for all d != dim"
  index = index.to(self.device)
  # crop every non-gather dim to the index shape, then move the gather dim to a new trailing axis
  crop = tuple(None if d == dim else (0, i) for d, i in enumerate(index.shape))
  src = self.shrink(crop).unsqueeze(-1).transpose(-1, dim)
  # a one-hot over source positions picks the requested element; summing the trailing axis extracts it
  selector = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim])
  return (src * selector).sum(-1, acc_dtype=self.dtype)
def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
  """
  Concatenates self with other `Tensor` in `args` along an axis specified by `dim`.
  All tensors must have the same shape except in the concatenating dimension (AssertionError otherwise).

  ```python exec="true" source="above" session="tensor" result="python"
  t0, t1, t2 = Tensor([[1, 2]]), Tensor([[3, 4]]), Tensor([[5, 6]])
  print(t0.cat(t1, t2, dim=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t0.cat(t1, t2, dim=1).numpy())
  ```
  """
  dim = self._resolve_dim(dim)
  for arg in args:
    assert arg.ndim == self.ndim and all(ti == ai for i, (ti, ai) in enumerate(zip(self.shape, arg.shape)) if i != dim), \
      f"cannot cat shapes {self.shape} and {arg.shape} along {dim=}: all other dimensions must match"
  tensors = [self, *args]
  # start offset of each input along `dim` in the output: [0, n0, n0+n1, ...]
  dim_cumsum = list(itertools.accumulate([t.shape[dim] for t in tensors], initial=0))
  # zero-pad every input into its own slot of the full output; slots never overlap, so summing concatenates
  padded = [t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j == dim else None for j in range(t.ndim)])
            for i, t in enumerate(tensors)]
  return functools.reduce(Tensor.add, padded)
def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
  """
  Joins self and the tensors in `args` along a brand-new axis inserted at position `dim`.
  All inputs must share the same shape.

  ```python exec="true" source="above" session="tensor" result="python"
  t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])
  print(t0.stack(t1, t2, dim=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t0.stack(t1, t2, dim=1).numpy())
  ```
  """
  # give every input a unit axis at `dim`; cat joins along it and performs the shape checks
  with_axis = [t.unsqueeze(dim) for t in (self, *args)]
  return Tensor.cat(*with_axis, dim=dim)
def repeat_interleave(self, repeats:int, dim:Optional[int]=None) -> Tensor:
  """
  Repeats every element `repeats` times, consecutively, along `dim`.
  When `dim` is None the tensor is flattened first and the result is 1-D.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3])
  print(t.repeat_interleave(2).numpy())
  ```
  """
  if dim is None: x, dim = self.flatten(), 0
  else: x, dim = self, self._resolve_dim(dim)
  head, n, tail = x.shape[:dim], x.shape[dim], x.shape[dim+1:]
  # add a unit axis right after `dim`, broadcast it to `repeats`, then merge it back into `dim`
  return x.reshape(*head, n, 1, *tail).expand(*head, n, repeats, *tail).reshape(*head, n*repeats, *tail)
def repeat(self, repeats, *args) -> Tensor:
  """
  Tiles `self` along each dimension the number of times given in `repeats`.
  `repeats` may be one tuple or separate arguments; extra leading entries add new dimensions.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3])
  print(t.repeat(4, 2).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.repeat(4, 2, 1).shape)
  ```
  """
  repeats = argfix(repeats, *args)
  base_shape = _align_left(self.shape, repeats)[0]
  # (s0, s1, ...) -> (1, s0, 1, s1, ...) -> (r0, s0, r1, s1, ...) -> (r0*s0, r1*s1, ...)
  interleaved, tiled, result = [], [], []
  for s in base_shape: interleaved += [1, s]
  for r, s in zip(repeats, base_shape):
    tiled += [r, s]
    result.append(r*s)
  return self.reshape(interleaved).expand(tiled).reshape(result)
def _resolve_dim(self, dim:int, *, extra:bool=False) -> int:
  """
  Turns a possibly negative axis index into a non-negative one by adding `ndim` (plus one when `extra` is set).
  The accepted range is `[-max(1, total), max(1, total) - 1]` with `total = ndim + extra`; anything else raises IndexError.
  """
  total = self.ndim + int(extra)
  limit = max(1, total)
  if dim < -limit or dim > limit - 1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
  return dim if dim >= 0 else dim + total
def split(self, sizes:Union[int, list[int]], dim:int=0) -> tuple[Tensor, ...]:
  """
  Splits the tensor into chunks along dimension `dim`.
  If `sizes` is an int, every chunk has `sizes` entries along `dim`, except that the last chunk may be smaller.
  If `sizes` is a list, one chunk is produced per entry, holding that many entries along `dim`;
  the entries must sum to `self.shape[dim]`.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(10).reshape(5, 2)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  split = t.split(2)
  print("\\n".join([repr(x.numpy()) for x in split]))
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  split = t.split([1, 4])
  print("\\n".join([repr(x.numpy()) for x in split]))
  ```
  """
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
  dim = self._resolve_dim(dim)
  length = self.shape[dim]
  if isinstance(sizes, int):
    # equal chunks; the final one keeps the remainder (an empty dim still yields one empty chunk)
    sizes = [min(sizes, length - start) for start in range(0, max(1, length), max(1, sizes))]
  assert sum(sizes) == length, f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
  # chunk i spans [bounds[i], bounds[i+1]) along `dim`
  bounds = [0, *itertools.accumulate(sizes)]
  lead = [slice(None)] * dim
  return tuple(self[tuple(lead + [slice(lo, hi)])] for lo, hi in zip(bounds, bounds[1:]))
def chunk(self, chunks:int, dim:int=0) -> list[Tensor]:
  """
  Splits the tensor into at most `chunks` pieces along dimension `dim`.
  Every piece has `ceil(size / chunks)` entries along `dim` except possibly the last, so fewer than `chunks`
  pieces may be returned. A zero-length `dim` yields exactly `chunks` empty pieces.
  ```python exec="true" source="above" session="tensor" result="python"
  chunked = Tensor.arange(11).chunk(6)
  print("\\n".join([repr(x.numpy()) for x in chunked]))
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  chunked = Tensor.arange(12).chunk(6)
  print("\\n".join([repr(x.numpy()) for x in chunked]))
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  chunked = Tensor.arange(13).chunk(6)
  print("\\n".join([repr(x.numpy()) for x in chunked]))
  ```
  """
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
  assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
  dim = self._resolve_dim(dim)
  length = self.shape[dim]
  piece = ceildiv(length, chunks) if length else [0] * chunks
  return list(self.split(piece, dim=dim))
def meshgrid(self:Tensor, *args:Tensor, indexing:Union[Literal["ij"], Literal["xy"]]="ij") -> tuple[Tensor, ...]:
  """
  Builds coordinate grids from coordinate vectors (scalars or 1D tensors).
  With `indexing="ij"` (matrix-style) input k varies along output axis k; with `indexing="xy"`
  (Cartesian-style) the first two output axes are swapped.
  Raises `RuntimeError` for any other `indexing` value.
  ```python exec="true" source="above" session="tensor" result="python"
  x, y = Tensor([1, 2, 3]), Tensor([4, 5, 6])
  grid_x, grid_y = x.meshgrid(y)
  print(grid_x.numpy())
  print(grid_y.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  grid_x, grid_y = x.meshgrid(y, indexing="xy")
  print(grid_x.numpy())
  print(grid_y.numpy())
  ```
  """
  if indexing not in ("ij", "xy"): raise RuntimeError(f'indexing must be in ("ij", "xy"), got {indexing}')
  tensors = (self, *args)
  if len(tensors) == 1: return tensors
  n = len(tensors)
  axes = list(range(n))
  if indexing == "xy": axes[0], axes[1] = 1, 0
  # view input k as a column along output axis axes[k]: (-1,) followed by one size-1 axis per later output axis
  columns = tuple(t.reshape((-1,) + (1,) * (n - 1 - ax)) for ax, t in zip(axes, tensors))
  grid_shape = _broadcast_shape(*(c.shape for c in columns))
  return tuple(c._broadcast_to(grid_shape) for c in columns)
def squeeze(self, dim:Optional[int]=None) -> Tensor:
  """
  Removes size-1 dimensions.
  Without `dim`, every size-1 dimension is dropped; with `dim`, only that dimension is dropped, and only if its size is 1
  (otherwise, or for a 0-d tensor, the tensor itself is returned).
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.zeros(2, 1, 2, 1, 2)
  print(t.squeeze().shape)
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.squeeze(0).shape)
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.squeeze(1).shape)
  ```
  """
  if dim is None:
    return self.reshape(tuple(s for s in self.shape if s != 1))
  dim = self._resolve_dim(dim)
  if self.ndim == 0 or self.shape[dim] != 1: return self
  return self.reshape(self.shape[:dim] + self.shape[dim+1:])
def unsqueeze(self, dim:int) -> Tensor:
  """
  Inserts a new dimension of size 1 at position `dim`; negative `dim` counts from the end of the result's shape.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3, 4])
  print(t.unsqueeze(0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.unsqueeze(1).numpy())
  ```
  """
  # extra=True: the new axis may also go after the current last axis
  at = self._resolve_dim(dim, extra=True)
  head, tail = self.shape[:at], self.shape[at:]
  return self.reshape((*head, 1, *tail))
@property
def T(self) -> Tensor:
  """Shorthand for `.transpose()` with its default arguments, i.e. swaps dimensions 1 and 0."""
  swapped = self.transpose()
  return swapped
def transpose(self, dim0=1, dim1=0) -> Tensor:
  """
  Swaps dimensions `dim0` and `dim1` (negative values index from the end) and returns the permuted tensor.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(6).reshape(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.transpose(0, 1).numpy())
  ```
  """
  perm = list(range(self.ndim))
  a, b = perm[dim0], perm[dim1]
  perm[dim0], perm[dim1] = b, a
  return self.permute(perm)
def flatten(self, start_dim=0, end_dim=-1):
  """
  Merges dimensions `start_dim` through `end_dim` (both inclusive) into one dimension.
  With the defaults, the whole tensor becomes one-dimensional.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(8).reshape(2, 2, 2)
  print(t.flatten().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.flatten(start_dim=1).numpy())
  ```
  """
  lo, hi = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
  merged = prod(self.shape[lo:hi+1])
  return self.reshape(self.shape[:lo] + (merged,) + self.shape[hi+1:])
def unflatten(self, dim:int, sizes:Sequence[int]):
  """
  Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
  `sizes` may be any sequence of ints (tuple or list); an entry of -1 is inferred by `reshape` (see the second example).
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
  ```
  """
  dim = self._resolve_dim(dim)
  # tuple() so that a list of sizes can be concatenated with the shape tuples
  return self.reshape(self.shape[:dim] + tuple(sizes) + self.shape[dim+1:])
def roll(self, shifts:Union[int, tuple[int, ...]], dims:Union[int, tuple[int, ...]]) -> Tensor:
  """
  Rolls the tensor along specified dimension(s).
  The rolling operation is circular, meaning that elements that go beyond the edge are wrapped around to the beginning of the dimension.
  `shifts[i]` is applied along `dims[i]`; both may be ints or tuples, but they must have the same length
  (a `RuntimeError` is raised otherwise). Zero-length dimensions and zero shifts are left untouched.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(4)
  print(t.roll(shifts=1, dims=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.roll(shifts=-1, dims=0).numpy())
  ```
  """
  dims = tuple(self._resolve_dim(d) for d in make_tuple(dims, 1))
  shifts = make_tuple(shifts, 1)
  # zip() would silently drop the extra shifts/dims, so reject mismatches explicitly
  if len(shifts) != len(dims): raise RuntimeError(f"shifts and dims must align, got {len(shifts)=} and {len(dims)=}")
  rolled = self
  for dim, shift in zip(dims, shifts):
    size = rolled.shape[dim]
    # nothing to roll on an empty axis (and `shift % 0` would raise); a whole-period shift is a no-op
    if size == 0 or (shift := shift % size) == 0: continue
    # move the last `shift` entries along `dim` to the front
    tail = tuple(slice(-shift, None) if i == dim else slice(None) for i in range(rolled.ndim))
    head = tuple(slice(None, -shift) if i == dim else slice(None) for i in range(rolled.ndim))
    rolled = Tensor.cat(rolled[tail], rolled[head], dim=dim)
  return rolled
# ***** reduce ops *****
def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
  """Applies the reduce `Function` `fxn` over `axis` (None means every axis); reduced axes are removed unless `keepdim`."""
  requested = range(self.ndim) if axis is None else make_tuple(axis, 1)
  axis = tuple(self._resolve_dim(a) for a in requested)
  # a 0-d tensor has no axes to reduce over
  if self.ndim == 0: axis = ()
  ret = fxn.apply(self, axis=axis)
  if keepdim: return ret
  return ret.reshape(tuple(s for i, s in enumerate(self.shape) if i not in axis))
def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
  """
  Returns the sum of the elements of the tensor along the specified axis or axes.
  You can pass in `axis` and `keepdim` keyword arguments to control the axis along
  which the sum is computed and whether the reduced dimensions are retained.
  You can pass in `acc_dtype` keyword argument to control the data type of the accumulation.
  If not specified, the accumulation data type is chosen based on the input tensor's data type,
  and float16/bfloat16 results are cast back to the input data type.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(6).reshape(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.sum().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.sum(axis=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.sum(axis=1).numpy())
  ```
  """
  ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(F.Sum, axis, keepdim)
  # with the default accumulator, half-precision inputs are returned in their own dtype
  return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
def prod(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
  """
  Returns the product of the elements of the tensor along the specified axis or axes.
  You can pass in `axis` and `keepdim` keyword arguments to control the axis along
  which the product is computed and whether the reduced dimensions are retained.
  You can pass in `acc_dtype` keyword argument to control the data type of the accumulation.
  If not specified, the product is accumulated in the input tensor's own data type.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([-1, -2, -3, 1, 2, 3]).reshape(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.prod().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.prod(axis=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.prod(axis=1).numpy())
  ```
  """
  # unlike `sum`, the default accumulator is the input dtype and the result stays in the accumulation dtype
  return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(F.Prod, axis, keepdim)
def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
  """
  Returns the largest value of the tensor over `axis` (an int, a sequence of ints, or None for all axes).
  Reduced dimensions are dropped unless `keepdim=True`.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[1, 0, 2], [5, 4, 3]])
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.max().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.max(axis=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.max(axis=1, keepdim=True).numpy())
  ```
  """
  return self._reduce(F.Max, axis=axis, keepdim=keepdim)
def _inverse(self):
  # order-reversing map used by min/argmin: negate floats, bitwise-not ints, logical-not anything else (e.g. bool)
  if self.is_floating_point(): return -self
  if dtypes.is_int(self.dtype): return ~self
  return self.logical_not()
def min(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
  """
  Returns the smallest value of the tensor over `axis` (an int, a sequence of ints, or None for all axes).
  Reduced dimensions are dropped unless `keepdim=True`.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[1, 0, 2], [5, 4, 3]])
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.min().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.min(axis=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.min(axis=1, keepdim=True).numpy())
  ```
  """
  # min(x) == inverse(max(inverse(x))) because `_inverse` reverses the ordering
  flipped = self._inverse()
  return flipped.max(axis=axis, keepdim=keepdim)._inverse()
def any(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
  """
  Returns whether at least one element is truthy over `axis` (None means all axes).
  Reduced dimensions are dropped unless `keepdim=True`.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[True, True], [True, False], [False, False]])
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.any().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.any(axis=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.any(axis=1, keepdim=True).numpy())
  ```
  """
  # on booleans, max is logical OR
  as_bool = self.bool()
  return as_bool.max(axis, keepdim)
def all(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
  """
  Returns whether all elements are truthy over `axis` (None means all axes).
  Reduced dimensions are dropped unless `keepdim=True`.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[True, True], [True, False], [False, False]])
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.all().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.all(axis=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.all(axis=1, keepdim=True).numpy())
  ```
  """
  # De Morgan: all(x) == not any(not x)
  negated = self.logical_not()
  return negated.any(axis, keepdim).logical_not()
def mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
  """
  Returns the arithmetic mean over `axis` (None means all axes); reduced dimensions are dropped unless `keepdim=True`.
  The sum is accumulated in `sum_acc_dtype(self.dtype)`; the result keeps a float input's dtype and is float32 otherwise.
  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.normal(2, 3, mean=2.5, std=0.5)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.mean().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.mean(axis=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.mean(axis=1).numpy())
  ```
  """
  out_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
  total = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
  # element count per output = product of the input sizes on the axes whose size changed under the reduction
  reduced_shape = self.sum(axis=axis, keepdim=True).shape
  count = prod([si for si, so in zip(self.shape, reduced_shape) if resolve(si != so)])
  return total.div(count).cast(out_dtype)
def var(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
  """
  Returns the variance over `axis` (None means all axes): the sum of squared deviations divided by `N - correction`,
  where `N` is the number of reduced elements (`correction=1` is Bessel's correction).
  Reduced dimensions are dropped unless `keepdim=True`.
  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.normal(2, 3, mean=2.5, std=0.5)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.var().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.var(axis=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.var(axis=1).numpy())
  ```
  """
  centered = self - self.mean(axis=axis, keepdim=True)
  squares = centered.square()
  kept = squares.sum(axis=axis, keepdim=True)
  n = prod([si for si, so in zip(self.shape, kept.shape) if resolve(si != so)])
  # the divisor is clamped at 0 so it never goes negative when correction > n
  return squares.sum(axis=axis, keepdim=keepdim).div(smax([0, n-correction]))
def std(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
  """
  Returns the standard deviation over `axis`, i.e. the square root of `var(axis, keepdim, correction)`.
  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.normal(2, 3, mean=2.5, std=0.5)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.std().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.std(axis=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.std(axis=1).numpy())
  ```
  """
  variance = self.var(axis, keepdim, correction)
  return variance.sqrt()
def std_mean(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
  """
  Returns the pair `(std, mean)` computed over `axis`, mirroring `torch.std_mean`.
  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.normal(2, 3, mean=2.5, std=0.5)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  std, mean = t.std_mean()
  print(std.numpy(), mean.numpy())
  ```
  """
  deviation = self.std(axis, keepdim, correction)
  average = self.mean(axis, keepdim)
  return deviation, average
def _softmax(self, axis, dtype:Optional[DTypeLike]=None):
  """Shared helper for softmax/log_softmax: returns (shifted input, its exp, sum of the exp along `axis` with keepdim)."""
  x = self if dtype is None else self.cast(dtype)
  # subtract the (gradient-detached) max along `axis` so the largest exponent is 0
  shifted = x - x.max(axis=axis, keepdim=True).detach()
  exps = shifted.exp()
  return shifted, exps, exps.sum(axis=axis, keepdim=True)
def softmax(self, axis=-1, dtype:Optional[DTypeLike]=None):
  """
  Normalizes the tensor along `axis` into probabilities: `exp(x) / sum(exp(x))`, computed on max-shifted values.
  Every slice along `axis` then lies in [0, 1] and sums to 1. If `dtype` is given, the input is cast to it first.
  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.softmax().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.softmax(axis=0).numpy())
  ```
  """
  _shifted, exps, total = self._softmax(axis, dtype)
  return exps.div(total)
def log_softmax(self, axis=-1, dtype:Optional[DTypeLike]=None):
  """
  Returns `log(softmax(x))` along `axis`, evaluated as `(x - max) - log(sum(exp(x - max)))` for numerical stability.
  If `dtype` is given, the input is cast to it first.
  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.log_softmax().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.log_softmax(axis=0).numpy())
  ```
  """
  shifted, _exps, total = self._softmax(axis, dtype)
  return shifted - total.log()
def logsumexp(self, axis=None, keepdim=False):
  """
  Computes the log-sum-exp of the tensor along the specified axis or axes.
  The log-sum-exp function is a numerically stable way to compute the logarithm of the sum of exponentials.
  You can pass in `axis` (None, an int, or a sequence of ints) and `keepdim` keyword arguments to control the axis along
  which the log-sum-exp is computed and whether the reduced dimensions are retained.
  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.logsumexp().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.logsumexp(axis=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.logsumexp(axis=1).numpy())
  ```
  """
  m = self.max(axis=axis, keepdim=True)
  ret = (self - m).exp().sum(axis=axis, keepdim=keepdim).log()
  # `m` keeps size-1 reduced axes. With keepdim it already matches `ret`; otherwise reshape it to `ret`'s shape.
  # (squeeze(axis) only accepts a single int and would also drop the kept axes when keepdim=True)
  return ret + (m if keepdim else m.reshape(ret.shape))
def logcumsumexp(self, axis=0):
  """
  Returns `log(cumsum(exp(x)))` along a single `axis`, shifted by the per-slice maximum for numerical stability.
  NOTE: the shift is one maximum for the whole axis, so a running prefix whose terms are all far below that
  maximum can underflow to 0 and yield -inf.
  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.logcumsumexp().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.logcumsumexp(axis=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.logcumsumexp(axis=1).numpy())
  ```
  """
  peak = self.max(axis=axis, keepdim=True)
  running = (self - peak).exp().cumsum(axis=axis)
  return running.log() + peak
def argmax(self, axis=None, keepdim=False):
  """
  Returns the int32 index of the first maximum along `axis`.
  With `axis=None` the tensor is flattened and a single index into the flattened data is returned.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[1, 0, 2], [5, 4, 3]])
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.argmax().numpy()) # Returns the index of the maximum value in the flattened tensor.
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.argmax(axis=0).numpy()) # Returns the indices of the maximum values along axis 0.
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.argmax(axis=1).numpy()) # Returns the indices of the maximum values along axis 1.
  ```
  """
  if axis is None: return self.flatten().argmax(0)
  axis = self._resolve_dim(axis)
  n = self.shape[axis]
  is_max = self == self.max(axis=axis, keepdim=True)
  # weights n, n-1, ..., 1 along `axis`: the earliest maximum carries the largest weight
  weights = Tensor.arange(n, 0, -1, requires_grad=False, device=self.device).reshape(n, *[1]*(self.ndim-axis-1))
  # a winning weight w corresponds to index n - w
  return (n - (is_max * weights).max(axis=axis, keepdim=keepdim)).cast(dtypes.int32)
def argmin(self, axis=None, keepdim=False):
  """
  Returns the int32 index of the first minimum along `axis` (flattened when `axis=None`).
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[1, 0, 2], [5, 4, 3]])
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.argmin().numpy()) # Returns the index of the minimum value in the flattened tensor.
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.argmin(axis=0).numpy()) # Returns the indices of the minimum values along axis 0.
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.argmin(axis=1).numpy()) # Returns the indices of the minimum values along axis 1.
  ```
  """
  # the first minimum of x is the first maximum of its order-reversed image
  flipped = self._inverse()
  return flipped.argmax(axis=axis, keepdim=keepdim)
def rearrange(self, formula: str, **sizes) -> Tensor:
  """
  Rearranges input according to formula
  `...` (or the unicode `…`) stands for any number of dimensions, including zero; a bare `1` is an anonymous size-1 axis.
  See: https://einops.rocks/api/rearrange/
  ```python exec="true" source="above" session="tensor" result="python"
  x = Tensor([[1, 2], [3, 4]])
  print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy())
  ```
  """
  def parse_formula(formula: str):
    # returns (axis names without brackets, [start, end) span of each bracketed group in that name list)
    tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").split()
    # an anonymous "1" is an empty group; expanding per token also handles adjacent "1 1"
    tokens = [t for tok in tokens for t in (("(", ")") if tok == "1" else (tok,))]
    lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
    pairs = list(zip(lparens, rparens))
    assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
    # every earlier group removed two bracket tokens, hence the 2*i correction
    return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]
  assert formula.count("->") == 1, 'need exactly one "->" in formula'
  (lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))
  for name in sizes: assert name in lhs, f"axis {name} is not used in transform"
  assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
  for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
  assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
  assert lhs.count("...") <= 1, f"too many ellipses in {formula}"
  # resolve ellipsis: it covers every input dim not named explicitly (possibly none)
  ell_len = 0
  if "..." in lhs:
    ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
    assert ell_len >= 0, f"formula {formula} names more axes than the input shape {self.shape} has"
  def expand_ellipsis(names: list[str], dims: list[tuple[int, int]]):
    # replace the single "..." slot by ell_len names and move every group index after it by (ell_len - 1)
    if "..." not in names: return names, dims
    i = names.index("...")
    def shift(p: int) -> int: return p + ell_len - 1 if p > i else p
    return names[:i] + [f"...{j}" for j in range(ell_len)] + names[i + 1:], [(shift(s), shift(e)) for s, e in dims]
  (lhs, unflatten_dims), (rhs, flatten_dims) = expand_ellipsis(lhs, unflatten_dims), expand_ellipsis(rhs, flatten_dims)
  # apply movement ops in order unflatten -> permute -> flatten/unsqueeze
  t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
  for i, name in enumerate(lhs): assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
  t = t.permute([lhs.index(name) for name in rhs])
  # groups are flattened right-to-left so earlier indices stay valid; an empty group becomes a new size-1 axis
  return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
@staticmethod
def einsum(formula:str, *operands:Tensor|Sequence[Tensor], acc_dtype:Optional[DTypeLike]=None) -> Tensor:
  """
  Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
  `...` stands for an operand's remaining (right-aligned) dimensions and may match zero dimensions.
  `acc_dtype` is forwarded to `sum` for the final reduction.
  See: https://pytorch.org/docs/stable/generated/torch.einsum.html
  ```python exec="true" source="above" session="tensor" result="python"
  x = Tensor([[1, 2], [3, 4]])
  y = Tensor([[5, 6], [7, 8]])
  print(Tensor.einsum("ij,ij->", x, y).numpy())
  ```
  """
  def parse_formula(formula:str, *operands:Tensor):
    # returns (comma-joined input subscripts, output subscripts) with every "..." replaced by concrete letters
    if "..." in (formula := formula.replace(" ", "")):
      # unused letters in sorted order, so the chosen letters (and the traced graph) do not depend on set hashing
      ell_chars, ell_longest = "".join(sorted(set(string.ascii_letters) - set(formula))), 0
      inputs = formula.split("->")[0].split(",")
      # walk inputs and operands together; indexing must not skip inputs without an ellipsis
      for i, (inp, operand) in enumerate(zip(inputs, operands)):
        if "..." not in inp: continue
        ell_count = operand.ndim - (len(inp) - len("..."))
        assert 0 <= ell_count <= len(ell_chars), f"ellipsis in {inp} does not fit an operand of shape {operand.shape}"
        ell_longest = max(ell_longest, ell_count)
        # right-aligned: each operand takes a suffix of ell_chars; slicing from len-count gives "" for a count of 0
        inputs[i] = inp.replace("...", ell_chars[len(ell_chars)-ell_count:])
      inputs_str, out_ellipse = ",".join(inputs), ell_chars[len(ell_chars)-ell_longest:]
      return (inputs_str, formula.split("->")[1].replace("...", out_ellipse)) if "->" in formula else \
        (inputs_str, out_ellipse + ''.join(sorted(c for c in inputs_str if inputs_str.count(c) == 1 and c.isalpha() and c not in out_ellipse)))
    # implicit output: the letters that appear exactly once, in alphabetical order
    return formula.split("->") if "->" in formula else (formula, ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha()))
  xs:tuple[Tensor, ...] = argfix(*operands)
  inputs_str, output = parse_formula(formula, *xs)
  inputs = inputs_str.split(",")
  assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"
  # map the value of each letter in the formula
  letter_val = sorted(merge_dicts([dict(zip(letters, tensor.shape)) for letters, tensor in zip(inputs, xs)]).items())
  xs_:list[Tensor] = []
  lhs = [sorted(enumerate(s), key=lambda e:e[1]) for s in inputs]
  for x, l in zip(xs, lhs):
    # a 0-d operand has no subscripts, and zip(*[]) would yield nothing to unpack
    order, letters = tuple(zip(*l)) or ((), ())
    # permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
    xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val]))
  # ordinal encode the output alphabet
  rhs_order = argsort(argsort(list(output)))
  # sum over all axes that's not in the output, then permute to the output order
  return functools.reduce(lambda a,b:a*b, xs_) \
    .sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output], acc_dtype=acc_dtype).permute(rhs_order)
# ***** processing ops *****
def _pool(self, k_:tuple[sint, ...], stride:Union[tuple[int, ...], int]=1, dilation:Union[tuple[int, ...], int]=1) -> Tensor:
  """
  Gathers sliding windows over the last `len(k_)` dimensions using movement ops only.
  `k_` is the kernel size per pooled dim; `stride` and `dilation` are ints or per-dim tuples.
  Returns a view shaped `(*leading_dims, *o_, *k_)`: output positions first, kernel positions last,
  so callers reduce over the trailing `len(k_)` axes (as `avg_pool2d` does).
  Any border padding is the caller's job (e.g. `avg_pool2d` pads before calling this).
  """
  assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
  s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_))
  assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
  # `noop` holds one None per leading (non-pooled) dim; i_ are the pooled input sizes
  noop, i_ = [None] * (self.ndim-len(k_)), self.shape[-len(k_):]
  assert all(resolve(d*(k-1)+1 <= i) for k,d,i in zip(k_,d_,i_)), "kernel size cannot be greater than actual input size"
  # number of window positions per pooled dim: ceil((i - d*(k-1)) / s)
  o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)]
  # general path, needed when windows overlap (k > s) or are dilated
  if any(resolve(k > s) for k,s in zip(k_,s_)) or any(d != 1 for d in d_):
    # input size scaling factor to make sure shrink for stride is possible
    # (f=2 doubles the row length when o*s would not fit into a row of length i+d)
    f_ = [1 + int(resolve(o*s > i+d)) for o,s,i,d in zip(o_,s_,i_,d_)]
    # repeats such that we don't need padding: tile the input so it holds at least k rows of length i*f+d
    x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)])
    # handle dilation: read the periodic tiled data as k rows of length i*f+d; row r starts r*d elements
    # later in the input (since r*(i*f+d) mod i == r*d mod i), which produces the dilation offsets
    x = x.shrink(tuple(noop + [(0,k*(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)])).reshape(noop + flatten((k,(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)))
    # handle stride: keep o*s entries per row, view them as (o, s) and take column 0 of each group of s
    x = x.shrink(tuple(noop + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_,o_,s_)))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_)))
    x = x.shrink(tuple(noop + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_,o_)))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_)))
    # permute to move reduce to the end: layout (k, o) per dim -> (*o_, *k_)
    return x.permute(*range(len(noop)), *[len(noop)+i*2+1 for i in range(len(i_))], *[len(noop)+i*2 for i in range(len(i_))])
  # TODO: once the shapetracker can optimize well, remove this alternative implementation
  # fast path (k <= s, no dilation): windows do not overlap. Pad/shrink each dim to exactly o*s, view it as (o, s),
  # and keep the first k entries of every stride group; padded entries fall outside every window.
  x = self.pad(tuple(noop + [(0, max(0,o*s-i)) for i,o,s in zip(i_,o_,s_)])).shrink(tuple(noop + [(0,o*s) for o,s in zip(o_,s_)]))
  x = x.reshape(noop + flatten(((o,s) for o,s in zip(o_,s_))))
  x = x.shrink(tuple(noop + flatten(((0,o), (0,k)) for o,k in zip(o_,k_))))
  # layout is already (o, k) per dim -> (*o_, *k_)
  return x.permute(*range(len(noop)), *[len(noop)+i*2 for i in range(len(i_))], *[len(noop)+i*2+1 for i in range(len(i_))])
def _padding2d(self, padding:Union[int, Sequence[int]], dims:int) -> Sequence[int]:
  """
  Normalizes `padding` for `dims` pooled/convolved dims into a flat list with two entries per dim.
  - int: the same padding on every side.
  - sequence of length `2*dims`: returned unchanged (already flat).
  - sequence of length `dims` (one value per dim, first to last): each value is duplicated and the list reversed,
    so the last dim's pair comes first.
  Raises ValueError for any other sequence length, instead of silently returning a list of the wrong length.
  """
  if isinstance(padding, int): return [padding]*2*dims
  if len(padding) == 2*dims: return padding
  if len(padding) == dims: return [p for p in padding for _ in range(2)][::-1]
  raise ValueError(f"padding must be an int or a sequence of length {dims} or {2*dims}, got {len(padding)}: {padding}")
def _ceil_mode_padding2d(self,k_:tuple[sint, ...], s_:Union[tuple[int, ...], int], d_:Union[tuple[int, ...], int],
                         p_:Union[tuple[int, ...], int]) -> Sequence[int]:
  """
  Returns the flat padding list to use for `ceil_mode=True` pooling over the last `len(k_)` dims.
  It starts from the regular padding that `_padding2d` builds from `p_`, and adds extra padding to one side of each
  pooled dim so that `_pool` produces the ceil-mode output size `o_` computed below.
  `k_` is the kernel size, `s_` the stride, `d_` the dilation and `p_` the padding; ints are broadcast via `make_tuple`.
  NOTE(review): `p_` goes through `make_tuple(p_, len(k_))`, so presumably only one (symmetric) value per dim is expected — confirm.
  """
  (d_,s_,p_), i_ = (make_tuple(x, len(k_)) for x in (d_,s_,p_)), self.shape[-len(k_):]
  # https://arxiv.org/pdf/1603.07285 section 5.1, relationship 15.
  # ceil-mode output size per pooled dim
  o_ = [ceildiv(i+2*p - (d*(k-1)+1), s) + 1 for i,d,k,s,p in zip(i_,d_,k_,s_,p_)]
  pads = list(self._padding2d(p_, len(k_)))
  # we have to do additional padding before `_pool` so that `o_` in `_pool` is calculated correctly
  # `s*(o-1) + (d*(k-1)+1) - (i+2*p)` -> last_sliding_window_start + full_kernel_size - padded_input_shape
  # we decrease padding in the case that a sliding window starts in the end padded region, thereby decreasing `o_` in `_pool`
  # `smax(s*(o-1) - (p+i-1), 0)` -> last_sliding_window_start - (left_pad + input_size - zero_offset)
  # pads[-1-dim*2] is the second entry of pooled dim `dim`'s pair in the reversed (last-dim-first) list from `_padding2d`;
  # assuming `pad` reads that list torch-style as (begin, end) pairs, this is the end side of that dim
  for dim,(o,i,s,p,k,d) in enumerate(zip(o_,i_,s_,p_,k_,d_)): pads[-1-dim*2] += s*(o-1) + (d*(k-1)+1) - (i+2*p) - smax(s*(o-1) - (p+i-1), 0)
  return pads
# NOTE: these work for more than 2D
def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, ceil_mode=False, count_include_pad=True):
  """
  Average-pools the trailing `len(kernel_size)` dims of the tensor; `stride` defaults to the kernel size.

  With `ceil_mode=True` the output size is computed with ceil division.
  With `count_include_pad=False` the zero padding does not count towards each window's divisor.

  NOTE: unlike PyTorch, pooling is not restricted to 2d; any number of trailing dims can be pooled.

  See: https://paperswithcode.com/method/average-pooling

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(25).reshape(1, 1, 5, 5)
  print(t.avg_pool2d().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.avg_pool2d(ceil_mode=True).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.avg_pool2d(padding=1).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.avg_pool2d(padding=1, count_include_pad=False).numpy())
  ```
  """
  k_ = make_tuple(kernel_size, 2)
  stride_ = k_ if stride is None else stride
  axis = tuple(range(-len(k_), 0))
  reg_pads = self._padding2d(padding, len(k_))
  ceil_pads = self._ceil_mode_padding2d(k_, stride_, dilation, padding)
  def pool(x:Tensor, padding_:Sequence[int]) -> Tensor:
    return x.pad(padding_)._pool(k_, stride_, dilation)
  if count_include_pad:
    if not ceil_mode: return pool(self, reg_pads).mean(axis)
    # the divisor counts the user padding but not the extra padding added for ceil mode
    extra_pads = tuple(cp-rp for cp,rp in zip(ceil_pads, reg_pads))
    return pool(self, ceil_pads).sum(axis) / pool(self.pad(reg_pads).ones_like(), extra_pads).sum(axis)
  pads = ceil_pads if ceil_mode else reg_pads
  return pool(self, pads).sum(axis) / pool(self.ones_like(), pads).sum(axis)
def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, ceil_mode=False):
  """
  Max-pools the trailing `len(kernel_size)` dims of the tensor; `stride` defaults to the kernel size.

  Padding is filled with the smallest value of the tensor's dtype so it never wins the max.
  With `ceil_mode=True` the output size is computed with ceil division.

  NOTE: unlike PyTorch, pooling is not restricted to 2d; any number of trailing dims can be pooled.

  See: https://paperswithcode.com/method/max-pooling

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(25).reshape(1, 1, 5, 5)
  print(t.max_pool2d().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.max_pool2d(ceil_mode=True).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.max_pool2d(padding=1).numpy())
  ```
  """
  k_ = make_tuple(kernel_size, 2)
  stride_ = k_ if stride is None else stride
  if ceil_mode: pads = self._ceil_mode_padding2d(k_, stride_, dilation, padding)
  else: pads = self._padding2d(padding, len(k_))
  windows = self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride_, dilation)
  return windows.max(tuple(range(-len(k_), 0)))
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
           acc_dtype:Optional[DTypeLike]=None) -> Tensor:
  """
  Applies a convolution over a tensor with a given `weight` and optional `bias`.

  `self` is `(bs, groups*cin, *spatial)` and `weight` is `(cout, cin, *HW)` with the same rank as `self`.
  `padding` is an int, one value per spatial dim, or two values per spatial dim (flat, last dim first; see `_padding2d`).
  `acc_dtype` sets the dtype used for the reduction. With `IMAGE` set, the call is delegated to `image_conv2d`.
  When every kernel dim is 3, `stride == 1`, `dilation == 1` and `WINO` is enabled, a Winograd F(4x4,3x3) path is used.

  NOTE: unlike PyTorch, this implementation is not limited to only 2d convolutions and instead works for any number of dimensions.

  See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(9).reshape(1, 1, 3, 3)
  w = Tensor.ones(1, 1, 2, 2)
  print(t.conv2d(w).numpy())
  ```
  """
  if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, acc_dtype)
  (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
  assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
  if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501
  padding_ = self._padding2d(padding, len(HW))
  # conv2d is a pooling op (with padding)
  x = self.pad(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
  rcout, oyx = cout//groups, x.shape[2:-len(HW)]
  if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not WINO:
    # normal conv
    x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) # noqa: E501
    # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
    ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, acc_dtype=acc_dtype).reshape(bs, cout, *oyx) # noqa: E501
    return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
  HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
  winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]]
  winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
  winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order doubles compile time
  # todo: stride == dilation
  # use padding to round up to 4x4 output tiles
  # padding_ and the flat list passed to pad are both ordered last-dim-first (torch convention), so pair i belongs to the
  # i-th spatial dim counted from the end; the tile round-up must use that same dim's size, hence the reversed shape.
  # (bs, cin_, tyx, HWI)
  d = self.pad(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):][::-1])], []))._pool(HWI, HWO) # noqa: E501
  # move HW to the front: # (HWI, bs, cin_, tyx)
  d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW)))
  tyx = d.shape[-len(HWI):] # dim of tiling
  g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front
  # compute 6x6 winograd tiles: GgGt, BtdB
  # (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
  gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
  # (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
  dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx)
  # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
  ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), acc_dtype=acc_dtype), len(HW))
  # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
  ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]])
  # merge groups and rcout, tyx and HWO: (bs, groups, cout, *yx), shrink to final
  ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink(tuple((0, s) for s in [bs, cout, *oyx]))
  return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
  """
  Applies a transposed convolution over a tensor with a given `weight` and optional `bias`.

  Implemented as a regular `conv2d` on a (for stride > 1) spread-out input with a regrouped, spatially flipped kernel:
  `weight` of shape `(A, B, *HW)` becomes a conv kernel of shape `(groups*B, A//groups, *HW)`.
  NOTE(review): presumably `weight` is `(cin, cout//groups, *HW)` as in PyTorch — confirm.
  `stride`, `dilation`, `padding` and `output_padding` may each be an int or one value per spatial dim.

  NOTE: unlike PyTorch, this implementation is not limited to only 2d transposed convolutions and instead works for any number of dimensions.

  See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(9).reshape(1, 1, 3, 3)
  w = Tensor.ones(1, 1, 2, 2)
  print(t.conv_transpose2d(w).numpy())
  ```
  """
  # split dim 0 into groups, swap the two channel dims and flip every spatial dim of the kernel
  x, w = self, weight.unflatten(0, (groups, -1)).transpose(1, 2).flip(*range(3, len(weight.shape)+1))
  HW = weight.shape[2:]
  stride, dilation, padding, output_padding = [make_tuple(x, len(HW)) for x in (stride, dilation, padding, output_padding)]
  if any(s>1 for s in stride):
    # handle strides: (k) -> reshape -> (k,1) -> pad -> (k,s) -> reshape -> (k*s) -> shrink (k-(s-1))
    # i.e. insert s-1 padded cells (pad's default fill value) after every input element, then drop the trailing s-1
    x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
    x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
    x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
    x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
  # per spatial dim: (k-1)*d - p before, plus output_padding after; flat list ordered last dim first
  padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, dilation, padding, output_padding)))))
  return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
def dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
  """
  Performs dot product between two tensors.

  If `w` is 1-D, it's a sum product over the last axis of `self` and `w`.
  If `w` is N-D with N>=2, it's a sum product over the last axis of `self` and the second-to-last axis of `w`.
  You can pass in the optional `acc_dtype` keyword argument to control the data type of the accumulation.
  The result is cast to `acc_dtype` if given, otherwise to the promoted dtype of the two inputs.
  With `IMAGE` set, the call is delegated to `image_dot`.

  Raises RuntimeError if either tensor is 0-D or the contracted axes differ in size.

  ```python exec="true" source="above" session="tensor" result="python"
  a = Tensor([1, 2, 3])
  b = Tensor([1, 1, 0])
  print(a.dot(b).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  a = Tensor([[1, 2], [3, 4]])
  b = Tensor([[5, 6], [7, 8]])
  print(a.dot(b).numpy())
  ```
  """
  if IMAGE: return self.image_dot(w, acc_dtype)
  x, dx, dw = self, self.ndim, w.ndim
  if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D")
  # contracted axis of w: -1 when w is 1-D, otherwise -2
  if x.shape[-1] != w.shape[axis_w:=-min(w.ndim,2)]: raise RuntimeError(f"cannot dot {x.shape} and {w.shape}")
  # when both inputs are at least 2-D, insert a size-1 axis so x's rows and w's columns broadcast against each other
  x = x.reshape(*x.shape[0:-1], *[1]*min(dx-1, dw-1, 1), x.shape[-1])
  # same size-1 axis for w, then move its contracted axis last to line up with x's last axis
  w = w.reshape(*w.shape[0:-2], *[1]*min(dx-1, dw-1, 1), *w.shape[axis_w:]).transpose(-1, axis_w)
  return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype)
def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
  """
  Matrix-multiplies `self @ x`, or `x @ self` when `reverse=True`, via `dot`.

  `acc_dtype` optionally selects the accumulation dtype.

  ```python exec="true" source="above" session="tensor" result="python"
  a = Tensor([[1, 2], [3, 4]])
  b = Tensor([[5, 6], [7, 8]])
  print(a.matmul(b).numpy())
  ```
  """
  lhs, rhs = (x, self) if reverse else (self, x)
  return lhs.dot(rhs, acc_dtype=acc_dtype)
def _cumalu(self, axis:int, op:Ops, _include_initial=False) -> Tensor:
  """
  Single-stage cumulative `op` (`Ops.ADD` or `Ops.MAX`) along `axis`, built from a sliding-window `_pool`.

  The axis is moved last and left-padded with `op`'s identity element, so the window of length `self.shape[axis]`
  ending at position j covers elements 0..j (inclusive scan). With `_include_initial=True`, one more identity is
  padded on the left and the last element is dropped (negative right pad). Output j then covers elements 0..j-1 and
  output 0 is the identity.
  NOTE(review): this reading assumes `_pool` defaults to stride 1 — confirm against `_pool`.
  """
  assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX)
  pl_sz = self.shape[axis] - int(not _include_initial)
  pooled = self.transpose(axis,-1).pad((pl_sz, -int(_include_initial)), value=identity_element(op, self.dtype))._pool((self.shape[axis],))
  # reduce each window, then move the axis back into place
  return (pooled.sum(-1) if op is Ops.ADD else pooled.max(-1)).transpose(axis,-1)
def _split_cumalu(self, axis:int, op:Ops) -> Tensor:
  """
  Cumulative `op` (`Ops.ADD` or `Ops.MAX`) along `axis`, using a two-stage scan for long concrete axes.

  0-D and empty tensors are returned unchanged. Axes that are symbolic or at most `2*SPLIT` long use `_cumalu` directly.
  Otherwise the axis is scanned in chunks of `SPLIT`, and each chunk is combined with the running result of all
  previous chunks.
  """
  axis = self._resolve_dim(axis)
  if self.ndim == 0 or 0 in self.shape: return self
  # TODO: someday the optimizer will find this on its own
  # for now this is a two stage cumsum
  SPLIT = 256
  if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumalu(axis, op)
  # move axis last, left-pad with the identity to a multiple of SPLIT, view as (..., chunks, SPLIT), scan inside each chunk
  ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0), value=identity_element(op, self.dtype)).unflatten(-1, (-1, SPLIT))._cumalu(-1, op)
  # scan over each chunk's last value with _include_initial=True: the carry-in from all preceding chunks
  base = ret[..., -1]._cumalu(-1, op, _include_initial=True)
  # broadcast every chunk's carry-in across its SPLIT positions
  base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1])
  # merge (chunks, SPLIT) back into one axis, drop the identity padding on the left and restore the axis position
  def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
  return fix(ret) + fix(base) if op is Ops.ADD else fix(ret).maximum(fix(base))
def cumsum(self, axis:int=0) -> Tensor:
  """
  Returns the running sum of the tensor along `axis`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.ones(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.cumsum(1).numpy())
  ```
  """
  return self._split_cumalu(axis=axis, op=Ops.ADD)
def cummax(self, axis:int=0) -> Tensor:
  """
  Returns the running maximum of the tensor along `axis`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([0, 1, -1, 2, -2, 3, -3])
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.cummax(0).numpy())
  ```
  """
  return self._split_cumalu(axis=axis, op=Ops.MAX)
@staticmethod
def _tri(r:sint, c:sint, diagonal:int=0, **kwargs) -> Tensor:
  """
  Builds an `(r, c)` tensor that is 1 where `col - row >= diagonal` and 0 elsewhere.

  `kwargs` (e.g. `device`, `dtype`) are forwarded to `Tensor.ones` / `Tensor.zeros`.
  Only concrete (non-symbolic) `r` and `c` are supported.
  """
  assert isinstance(r, int) and isinstance(c, int), f"does not support symbolic, getting {r=}, {c=}"
  # empty matrix, or the boundary diagonal lies entirely to the right of the last column
  if r == 0 or c == 0 or diagonal >= c: return Tensor.zeros(r,c,**kwargs)
  # every element lies on or above the boundary diagonal
  if r+diagonal <= 0: return Tensor.ones(r,c,**kwargs)
  s = r+c-1
  # build a (s, s) upper triangle
  # append s zeros to each row of an (s, s) block of ones, flatten, drop the last s elements and re-split into rows of
  # length 2s-1: each row is shifted right by one relative to the previous, so the first s columns hold ones where col >= row
  t = Tensor.ones(s,s,**kwargs).pad((None,(0,s))).flatten().shrink(((0,s*(2*s-1)),)).reshape(s,-1).shrink((None,(0,s)))
  # slice the (r, c) window whose offset realizes the requested diagonal
  return t[:r,-diagonal:c-diagonal] if diagonal <= 0 else t[diagonal:r+diagonal,:c]
def triu(self, diagonal:int=0) -> Tensor:
  """
  Keeps the upper-triangular part of the last two dims and zeroes everything else.

  `diagonal` picks the boundary: 0 is the main diagonal, positive values move it above, negative values below.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.triu(diagonal=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.triu(diagonal=1).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.triu(diagonal=-1).numpy())
  ```
  """
  rows, cols = self.shape[-2], self.shape[-1]
  keep = Tensor._tri(rows, cols, diagonal=diagonal, device=self.device, dtype=dtypes.bool)
  return keep.where(self, 0).cast(self.dtype)
def tril(self, diagonal:int=0) -> Tensor:
  """
  Keeps the lower-triangular part of the last two dims and zeroes everything else.

  `diagonal` picks the boundary: 0 is the main diagonal, positive values move it above, negative values below.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.tril(diagonal=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.tril(diagonal=1).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.tril(diagonal=-1).numpy())
  ```
  """
  rows, cols = self.shape[-2], self.shape[-1]
  # everything strictly above the boundary diagonal is zeroed out
  above = Tensor._tri(rows, cols, diagonal=diagonal+1, device=self.device, dtype=dtypes.bool)
  return above.where(0, self).cast(self.dtype)
def interpolate(self, size:tuple[int, ...], mode:str="linear", align_corners:bool=False) -> Tensor:
  """
  Downsamples or Upsamples to the input `size`, accepts 0 to N batch dimensions.

  `size` gives the new sizes of the trailing `len(size)` dims.
  The interpolation algorithm is selected with `mode` which currently only supports `linear`, `nearest` and `nearest-exact`.
  To run `bilinear` or `trilinear`, pass in a 2D or 3D size.
  `align_corners` (linear mode only) maps the corner samples of input and output onto each other. A dim resized to a
  single sample with `align_corners=True` samples input index 0, matching PyTorch, instead of dividing by zero.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[1, 2, 3, 4], [21, 22, 23, 24], [41, 42, 43, 44]])
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.interpolate(size=(2,3), mode="linear").numpy())
  ```
  """
  assert isinstance(size, (tuple,list)) and all_int(size) and 0 < len(size) <= self.ndim, f"invalid {size=}"
  assert mode in ("linear", "nearest", "nearest-exact"), "only supports linear, nearest or nearest-exact interpolate"
  assert not (align_corners and mode != "linear"), "align_corners option can only be set with the interpolating mode linear"
  x, expand = self, list(self.shape)
  for i in range(-1,-len(size)-1,-1):
    # input step per output sample. With align_corners and size[i] == 1 (or an empty output dim) the denominator is 0;
    # use scale 0 like PyTorch does instead of raising ZeroDivisionError.
    denom = size[i] - int(align_corners)
    scale = (self.shape[i] - int(align_corners)) / denom if denom > 0 else 0.0
    arr, reshape = Tensor.arange(size[i], dtype=dtypes.float32, device=self.device), [1] * self.ndim
    reshape[i] = expand[i] = size[i]
    if mode == "linear":
      # fractional source coordinate of every output sample, clamped into the valid input range
      index = (scale*arr if align_corners else (scale*(arr+0.5))-0.5).clip(0, self.shape[i]-1)
      # gather the two neighbouring samples and blend them by the fractional part
      low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
      x = x.gather(i, low).lerp(x.gather(i, high), perc)
    else:
      # nearest truncates the scaled coordinate; nearest-exact samples at pixel centers (arr+0.5)
      index = (scale*(arr+0.5) if mode=="nearest-exact" else scale*arr).cast(dtypes.int32).reshape(reshape).expand(expand)
      x = x.gather(i, index)
  return x.cast(self.dtype)
def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Union[None, Literal['multiply'], Literal['add']]=None) -> Tensor:
  """
  Scatters `src` values along an axis specified by `dim`.

  Every position of `index` selects, along `dim`, where the `src` value at that same position is written in `self`.
  Apply `add` or `multiply` reduction operation with `reduce`: scattered values are then summed / multiplied into
  `self` instead of replacing it. A constant `src` is broadcast to `index.shape`; a tensor `src` is cast to `self.dtype`.

  Raises TypeError if `reduce` is not None, 'add' or 'multiply'.

  ```python exec="true" source="above" session="tensor" result="python"
  src = Tensor.arange(1, 11).reshape(2, 5)
  print(src.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  index = Tensor([[0, 1, 2, 0]])
  print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(0, index, src).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  index = Tensor([[0, 1, 2], [0, 1, 4]])
  print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(1, index, src).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='multiply').numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='add').numpy())
  ```
  """
  if reduce not in {None, "add", "multiply"}: raise TypeError(f"{reduce=} must be one of None, 'multiply', or 'add'")
  index, dim = index.to(self.device), self._resolve_dim(dim)
  src = src.cast(self.dtype) if isinstance(src, Tensor) else Tensor(src, device=self.device, dtype=self.dtype)._broadcast_to(index.shape)
  assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.dim must all equal, {self.ndim=} {index.ndim=} {src.ndim=}"
  assert all((d == dim or self_ >= index_) and src_ >= index_ for d,(self_,index_,src_) in enumerate(zip(self.shape, index.shape, src.shape))), \
    f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}"
  # shrink src to index shape to shrink away the unused values
  src = src.shrink(tuple((0,s) for s in index.shape))
  # prepare src and mask for reduce with respect to dim
  # src gets a new trailing axis of length self.shape[dim], swapped into position dim (the index positions go last)
  src = src.unsqueeze(-1).expand(*src.shape, self.shape[dim]).transpose(-1, dim)
  # mask: one-hot of each index value over self.shape[dim], laid out the same way as src
  mask = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).transpose(-1, dim)
  # pad src and mask to self.shape so that reduce can be done with padded values as no-ops
  src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
  # reduce over the trailing axis; masked-out entries contribute the identity (0 for add, 1 for multiply)
  if reduce == "add": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype) + self
  if reduce == "multiply": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype) * self
  return _masked_setitem(self, src, mask, (-1,))
# ***** unary ops *****
def logical_not(self):
  """
  Element-wise boolean negation; the tensor is first cast to bool and then compared `!= True`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([False, True]).logical_not().numpy())
  ```
  """
  lhs, rhs = self.cast(dtypes.bool)._broadcasted(True)
  return F.Neq.apply(lhs, rhs)
def neg(self):
  """
  Element-wise negation; boolean tensors are logically inverted instead.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).neg().numpy())
  ```
  """
  if self.dtype == dtypes.bool: return self.logical_not()
  return self*-1
def contiguous(self):
  """
  Returns the tensor passed through the `Contiguous` function, i.e. a contiguous tensor.
  """
  fn = F.Contiguous
  return fn.apply(self)
def contiguous_backward(self):
  """
  Leaves the forward value alone and inserts a contiguous operation in the backward pass.
  """
  fn = F.ContiguousBackward
  return fn.apply(self)
def log(self):
  """
  Element-wise natural logarithm; non-float inputs are promoted to a float dtype first.

  See: https://en.wikipedia.org/wiki/Logarithm

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([1., 2., 4., 8.]).log().numpy())
  ```
  """
  as_float = self.cast(least_upper_float(self.dtype))
  return F.Log.apply(as_float)
def log2(self):
  """
  Element-wise base-2 logarithm, computed as `log(x) / log(2)`.

  See: https://en.wikipedia.org/wiki/Logarithm

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([1., 2., 4., 8.]).log2().numpy())
  ```
  """
  ln2 = math.log(2)
  return self.log() / ln2
def exp(self):
  """
  Element-wise `e**x`; non-float inputs are promoted to a float dtype first.

  See: https://en.wikipedia.org/wiki/Exponential_function

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([0., 1., 2., 3.]).exp().numpy())
  ```
  """
  as_float = self.cast(least_upper_float(self.dtype))
  return F.Exp.apply(as_float)
def exp2(self):
  """
  Element-wise `2**x`, computed as `exp(x * log(2))`.

  See: https://en.wikipedia.org/wiki/Exponential_function

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([0., 1., 2., 3.]).exp2().numpy())
  ```
  """
  scaled = self*math.log(2)
  return F.Exp.apply(scaled)
def relu(self):
  """
  Element-wise Rectified Linear Unit: `max(x, 0)`.

  - Described: https://paperswithcode.com/method/relu

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy())
  ```
  """
  fn = F.Relu
  return fn.apply(self)
def sigmoid(self):
  """
  Element-wise logistic sigmoid `1 / (1 + e**-x)`.

  - Described: https://en.wikipedia.org/wiki/Sigmoid_function

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sigmoid().numpy())
  ```
  """
  # 2**(-x / ln 2) == e**-x
  exp_neg = (self * (-1/math.log(2))).exp2()
  return (1 + exp_neg).reciprocal()
def hardsigmoid(self, alpha:float=1/6, beta:float=0.5):
  """
  Element-wise Hardsigmoid: `alpha * x + beta` clamped to [0, 1].

  NOTE: the default `alpha` and `beta` values are taken from torch.

  - Described: https://paperswithcode.com/method/hard-sigmoid
  - See: https://pytorch.org/docs/stable/generated/torch.nn.functional.hardsigmoid.html

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).hardsigmoid().numpy())
  ```
  """
  line = alpha * self + beta
  # relu(y) - relu(y - 1) clamps y into [0, 1]
  return line.relu() - (line - 1).relu()
def sqrt(self):
  """
  Element-wise square root; non-float inputs are promoted to a float dtype first.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([1., 2., 3., 4.]).sqrt().numpy())
  ```
  """
  as_float = self.cast(least_upper_float(self.dtype))
  return F.Sqrt.apply(as_float)
def rsqrt(self):
  """
  Element-wise `1 / sqrt(x)`, computed as `sqrt(1 / x)`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([1., 2., 3., 4.]).rsqrt().numpy())
  ```
  """
  inverse = self.reciprocal()
  return inverse.sqrt()
def sin(self):
  """
  Element-wise sine; non-float inputs are promoted to a float dtype first.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy())
  ```
  """
  as_float = self.cast(least_upper_float(self.dtype))
  return F.Sin.apply(as_float)
def cos(self):
  """
  Element-wise cosine, computed as `sin(pi/2 - x)`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).cos().numpy())
  ```
  """
  shifted = (math.pi/2)-self
  return shifted.sin()
def tan(self):
  """
  Element-wise tangent, computed as `sin(x) / cos(x)`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([0., math.pi/4, math.pi/2, 3*math.pi/4, math.pi]).tan().numpy())
  ```
  """
  numer, denom = self.sin(), self.cos()
  return numer / denom
def asin(self):
  """
  Element-wise arcsine via a polynomial approximation.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).asin().numpy())
  ```
  """
  # https://personal.math.ubc.ca/~cbm/aands/page_81.htm 4.4.46
  coefficients = [-0.0012624911, 0.0066700901, -0.0170881256, 0.0308918810, -0.0501743046, 0.0889789874, -0.2145988016, 1.5707963050]
  mag = self.abs()
  # asin(|x|) ~= pi/2 - sqrt(1 - |x|) * P(|x|); asin is odd, so the sign is restored afterwards
  unsigned = math.pi / 2 - (1.0 - mag).sqrt() * polyN(mag, coefficients)
  return self.sign() * unsigned
def acos(self):
  """
  Element-wise arccosine, computed as `pi/2 - asin(x)`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).acos().numpy())
  ```
  """
  half_pi = math.pi / 2
  return half_pi - self.asin()
def atan(self):
  """
  Element-wise arctangent, computed as `asin(x / sqrt(1 + x*x))`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).atan().numpy())
  ```
  """
  hyp = (1 + self * self).sqrt()
  return (self / hyp).asin()
# ***** math functions *****
def trunc(self: Tensor) -> Tensor:
  """
  Element-wise truncation toward zero, via a round trip through int32.

  NOTE(review): because of the int32 cast, values outside the int32 range are presumably not truncated correctly.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).trunc().numpy())
  ```
  """
  as_int = self.cast(dtypes.int32)
  return as_int.cast(self.dtype)
def ceil(self: Tensor) -> Tensor:
  """
  Element-wise rounding toward positive infinity.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).ceil().numpy())
  ```
  """
  truncated = self.trunc()
  # truncation rounds positives down, so bump them up by one when there was a fractional part
  return (self > truncated).where(truncated+1, truncated)
def floor(self: Tensor) -> Tensor:
  """
  Element-wise rounding toward negative infinity.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).floor().numpy())
  ```
  """
  truncated = self.trunc()
  # truncation rounds negatives up, so drop them by one when there was a fractional part
  return (self < truncated).where(truncated-1, truncated)
def round(self: Tensor) -> Tensor:
  """
  Element-wise rounding to the nearest integer, with ties going to the even neighbour.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).round().numpy())
  ```
  """
  halved = self.cast(dtypes.int32) / 2.0
  # the truncated value is even exactly when halving it leaves no fractional part
  trunc_is_even = halved.cast(dtypes.int32) == halved
  # positive with even truncation, or non-positive with odd truncation: ties round down; otherwise ties round up
  return ((self > 0) == trunc_is_even).where((self - 0.5).ceil(), (self + 0.5).floor())
def isinf(self:Tensor, detect_positive:bool=True, detect_negative:bool=True):
  """
  Element-wise test for infinity. Each sign can be switched off with `detect_positive` / `detect_negative`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isinf().numpy())
  ```
  """
  pos_hits = (self == float("inf")) * detect_positive
  neg_hits = (self == float("-inf")) * detect_negative
  return pos_hits + neg_hits
def isnan(self:Tensor):
  """
  Element-wise NaN test, relying on NaN being the only value that is not equal to itself.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isnan().numpy())
  ```
  """
  not_self_equal = self != self
  return not_self_equal
def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor:
  """
  Linearly interpolates between `self` and `end` by `weight`: `self + (end - self) * weight`.

  For a `uint8` tensor with a tensor `weight`, a fixed-point path is used. `weight` is quantized to 7 fractional bits
  and the result is rounded back to `uint8`. That path assumes `weight` lies in [0, 1].

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([1., 2., 3.]).lerp(Tensor([4., 5., 6.]), 0.5).numpy())
  ```
  """
  if self.dtype == dtypes.uint8 and isinstance(weight, Tensor):
    w_i = (weight * (1<<(W_PREC:=7)) + 0.5).cast(dtypes.int16)
    # The difference of two uint8 values spans [-255, 255] and does not fit in int8: casting it there wraps and
    # corrupts the result whenever |end - self| > 127. Compute it in int16 instead.
    # With weight in [0, 1], |diff * w_i| <= 255*128 = 32640, so adding the rounding term still fits in int16.
    end_i = end.cast(dtypes.int16) if isinstance(end, Tensor) else end
    diff = end_i - self.cast(dtypes.int16)
    # the uint16 cast + logical shift is only correct modulo 256, which is all the final uint8 cast keeps
    return (self+((diff * w_i + (1<<W_PREC-1)).cast(dtypes.uint16) >> W_PREC)).cast(dtypes.uint8)
  return self + (end - self) * weight
def square(self):
  """
  Element-wise square, i.e. `self*self`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).square().numpy())
  ```
  """
  squared = self * self
  return squared
def clamp(self, min_=None, max_=None):
  """
  Element-wise clamp into [`min_`, `max_`]; a `None` bound is left open.

  Raises RuntimeError when both bounds are `None`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).clamp(-1, 1).numpy())
  ```
  """
  if min_ is None and max_ is None: raise RuntimeError("at least one of 'min_' or 'max_' must not be None")
  out = self
  if min_ is not None: out = out.maximum(min_)
  if max_ is not None: out = out.minimum(max_)
  return out
def clip(self, min_=None, max_=None):
  """
  Same as `Tensor.clamp`.
  """
  return self.clamp(min_=min_, max_=max_)
def sign(self):
  """
  Element-wise sign of the tensor.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy())
  ```
  """
  fn = F.Sign
  return fn.apply(self)
def abs(self):
  """
  Element-wise absolute value, computed as `x * sign(x)`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).abs().numpy())
  ```
  """
  signs = self.sign()
  return self * signs
def reciprocal(self):
  """
  Element-wise `1/x`; non-float inputs are promoted to a float dtype first.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([1., 2., 3., 4.]).reciprocal().numpy())
  ```
  """
  as_float = self.cast(least_upper_float(self.dtype))
  return F.Reciprocal.apply(as_float)
# ***** activation functions *****
def elu(self, alpha=1.0):
  """
  Element-wise Exponential Linear Unit: `x` for positive inputs, `alpha * (e**x - 1)` otherwise.

  - Described: https://paperswithcode.com/method/elu
  - Paper: https://arxiv.org/abs/1511.07289v5

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).elu().numpy())
  ```
  """
  positive = self.relu()
  negative = (1-self.exp()).relu()
  return positive - alpha*negative
def celu(self, alpha=1.0):
  """
  Element-wise Continuously differentiable ELU: `max(0, x) + min(0, alpha * (e**(x/alpha) - 1))`.

  - Described: https://paperswithcode.com/method/celu
  - Paper: https://arxiv.org/abs/1704.07483

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).celu().numpy())
  ```
  """
  positive = self.maximum(0)
  negative = (alpha * ((self / alpha).exp() - 1)).minimum(0)
  return positive + negative
def selu(self, alpha=1.67326, gamma=1.0507):
  """
  Element-wise Scaled ELU: `gamma * x` for `x >= 0`, `gamma * alpha * (e**x - 1)` otherwise.

  - Described: https://paperswithcode.com/method/selu
  - Paper: https://arxiv.org/abs/1706.02515v5

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).selu().numpy())
  ```
  """
  non_negative = (self >= 0).detach()
  return gamma * non_negative.where(self, alpha * (self.exp() - 1))
def swish(self):
  """
  Element-wise swish `x * sigmoid(x)`; identical to `.silu()`.

  - Paper: https://arxiv.org/abs/1710.05941v1

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).swish().numpy())
  ```
  """
  gate = self.sigmoid()
  return self * gate
def silu(self):
  """
  Element-wise Sigmoid Linear Unit `x * sigmoid(x)`, the same function as swish.

  - Described: https://paperswithcode.com/method/silu
  - Paper: https://arxiv.org/abs/1606.08415

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).silu().numpy())
  ```
  """
  swished = self.swish()
  return swished
def relu6(self):
  """
  Element-wise ReLU capped at 6: `min(max(x, 0), 6)`.

  - Described: https://paperswithcode.com/method/relu6
  - Paper: https://arxiv.org/abs/1704.04861v1

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-9., -6., -3., 0., 3., 6., 9.]).relu6().numpy())
  ```
  """
  overflow = (self-6).relu()
  return self.relu() - overflow
def hardswish(self):
  """
  Element-wise Hardswish: `x * relu6(x + 3) / 6`.

  - Described: https://paperswithcode.com/method/hard-swish
  - Paper: https://arxiv.org/abs/1905.02244v5

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).hardswish().numpy())
  ```
  """
  gate = (self+3).relu6()
  return self * gate * (1/6)
def tanh(self):
  """
  Element-wise hyperbolic tangent, using the identity `tanh(x) = 2 * sigmoid(2x) - 1`.

  - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Tanh

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).tanh().numpy())
  ```
  """
  sig = (2.0 * self).sigmoid()
  return 2.0 * sig - 1.0
def sinh(self):
  """
  Element-wise hyperbolic sine `(e**x - e**-x) / 2`.

  - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Sinh

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sinh().numpy())
  ```
  """
  exp_pos, exp_neg = self.exp(), self.neg().exp()
  return (exp_pos - exp_neg) / 2
def cosh(self):
  """
  Applies the Hyperbolic Cosine (cosh) function element-wise: `(e^x + e^-x) / 2`.

  - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Cosh
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).cosh().numpy())
  ```
  """
  pos_exp = self.exp()
  neg_exp = self.neg().exp()
  return (pos_exp + neg_exp) / 2
def atanh(self):
  """
  Applies the Inverse Hyperbolic Tangent (atanh) function element-wise: `log((1 + x) / (1 - x)) / 2`.

  - Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#atanh
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).atanh().numpy())
  ```
  """
  ratio = (1 + self) / (1 - self)
  return ratio.log() / 2
def asinh(self):
  """
  Applies the Inverse Hyperbolic Sine (asinh) function element-wise.

  Uses the odd symmetry `asinh(x) = sign(x) * log(|x| + sqrt(x^2 + 1))`, so the logarithm only ever sees
  arguments >= 1. The direct form `log(x + sqrt(x^2 + 1))` cancels catastrophically for large negative `x`.
  For example, `x = -1e4` in float32 gives `log(0) = -inf`.

  - Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#asinh
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).asinh().numpy())
  ```
  """
  # detached +1/-1 factor (+1 for x >= 0): |x| = x * s, and the derivative of s * f(x * s) is f'(|x|),
  # which keeps the gradient equal to 1/sqrt(x^2 + 1) everywhere, including at x == 0
  s = (self < 0).detach().where(-1, 1)
  ax = self * s
  return (ax + (ax.square() + 1).sqrt()).log() * s
def acosh(self):
  """
  Applies the Inverse Hyperbolic Cosine (acosh) function element-wise: `log(x + sqrt(x^2 - 1))`.
  Inputs below 1 are outside the real domain.

  - Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#acosh
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).acosh().numpy())
  ```
  """
  root = (self.square() - 1).sqrt()
  return (self + root).log()
def hardtanh(self, min_val=-1, max_val=1):
  """
  Applies the Hardtanh function element-wise, i.e. clamps `self` into `[min_val, max_val]`.

  - Described: https://paperswithcode.com/method/hardtanh-activation
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).hardtanh().numpy())
  ```
  """
  lo, hi = min_val, max_val
  return self.clip(lo, hi)
def erf(self):
  """
  Applies the error function element-wise.
  Approximated with Abramowitz & Stegun formula 7.1.26, evaluated on `|x|` and mirrored by `sign(x)`.

  - Described: https://en.wikipedia.org/wiki/Error_function
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).erf().numpy())
  ```
  """
  # https://personal.math.ubc.ca/~cbm/aands/page_299.htm 7.1.26
  t = 1.0 / (1.0 + 0.3275911 * self.abs())
  series = t * polyN(t, [1.061405429, -1.453152027, 1.421413741, -0.284496736, 0.254829592])
  return self.sign() * (1.0 - series * (-self.square()).exp())
def gelu(self):
  """
  Applies the Gaussian Error Linear Unit (GELU) function element-wise.
  Uses the tanh approximation `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.

  - Described: https://paperswithcode.com/method/gelu
  - Paper: https://arxiv.org/abs/1606.08415v5
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).gelu().numpy())
  ```
  """
  inner = math.sqrt(2 / math.pi) * (self + 0.044715 * self ** 3)
  return 0.5 * self * (1 + inner.tanh())
def quick_gelu(self):
  """
  Applies the sigmoid approximation of GELU element-wise: `x * sigmoid(1.702 * x)`.

  - Described: https://paperswithcode.com/method/gelu
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).quick_gelu().numpy())
  ```
  """
  gate = (self * 1.702).sigmoid()
  return self * gate
def leakyrelu(self, neg_slope=0.01):
  """
  Applies the Leaky ReLU function element-wise.
  Positive inputs pass through unchanged; negative inputs are scaled by `neg_slope`.

  - Described: https://paperswithcode.com/method/leaky-relu
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu(neg_slope=0.42).numpy())
  ```
  """
  positive_part = self.relu()
  negative_part = (-neg_slope * self).relu()
  return positive_part - negative_part
def mish(self):
  """
  Applies the Mish function element-wise: `x * tanh(softplus(x))`.

  - Described: https://paperswithcode.com/method/mish
  - Paper: https://arxiv.org/abs/1908.08681v3
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).mish().numpy())
  ```
  """
  gate = self.softplus().tanh()
  return self * gate
def softplus(self, beta=1):
  """
  Applies the Softplus function element-wise: `log(1 + exp(beta * x)) / beta`.

  Evaluated in the overflow-free form `(max(z, 0) + log(1 + exp(-|z|))) / beta` with `z = beta * x`.
  This is the same formulation `binary_crossentropy_logits` uses. The direct form overflows `exp` to `inf`
  once `beta * x` exceeds roughly 88 in float32.

  - Described: https://paperswithcode.com/method/softplus
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softplus().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-100., 0., 100.]).softplus().numpy())
  ```
  """
  z = self * beta
  # exp only ever sees non-positive arguments here, so it cannot overflow
  return (1/beta) * (z.maximum(0) + (1 + z.abs().neg().exp()).log())
def softsign(self):
  """
  Applies the Softsign function element-wise: `x / (1 + |x|)`.

  - Described: https://paperswithcode.com/method/softsign
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softsign().numpy())
  ```
  """
  denominator = 1 + self.abs()
  return self / denominator
# ***** broadcasted elementwise ops *****
def _broadcast_to(self, new_shape:tuple[sint, ...]) -> Tensor:
  # Expands `self` to `new_shape` following array-API broadcasting; raises ValueError when that is impossible.
  if self.shape == new_shape: return self
  if self.ndim > len(new_shape): raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {new_shape=}")
  # left-pad the shape with 1s to the target rank https://data-apis.org/array-api/latest/API_specification/broadcasting.html
  padded, _ = _align_left(self.shape, new_shape)
  # every axis must already match the target or be 1 (a 1 gets expanded)
  for old_dim, new_dim in zip(padded, new_shape):
    if not (resolve(old_dim == new_dim) or resolve(old_dim == 1)):
      raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
  return F.Expand.apply(self.reshape(padded), shape=new_shape)
def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]:
  """
  Prepares `self` and `y` as the two operands of a binary op.

  A non-Tensor `y` (Python const or UOp) is wrapped into a Tensor on `self.device`. A const `y` takes `self.dtype`
  in two cases: when `self` has an image or float dtype, or when `self` has an int dtype and `y` is a Python int.
  Otherwise the const's dtype comes from `dtypes.from_py(y)`.
  If `match_dtype` is True and the dtypes differ, both operands are cast to their `least_upper_dtype`.
  With `reverse=True` the operands are swapped, so the pair is `(y, self)`.
  Finally both operands are broadcast to their common shape.

  Returns the (possibly swapped) pair of broadcast operands.
  """
  x: Tensor = self
  if not isinstance(y, Tensor):
    # make y a Tensor
    assert isinstance(y, (*get_args(ConstType), UOp)), f"{type(y)=}, {y=}"
    # a const adopts x's dtype for image/float x, or when both are ints, so e.g. `int_tensor + 1` is not promoted
    if isinstance(x.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
    elif not isinstance(y, UOp): y_dtype = dtypes.from_py(y)
    # NOTE: y_dtype is only read on the const path; a UOp y is converted by Tensor.from_uop instead
    if isinstance(y, UOp): y = Tensor.from_uop(y, device=x.device)
    else: y = Tensor(dtypes.as_const(y, y_dtype), x.device, y_dtype, requires_grad=False)
  if match_dtype and x.dtype != y.dtype:
    # promote both operands to a common dtype
    output_dtype = least_upper_dtype(x.dtype, y.dtype)
    x, y = x.cast(output_dtype), y.cast(output_dtype)
  if reverse: x, y = y, x
  # broadcast both operands to the common shape
  return x._broadcast_to(out_shape:=_broadcast_shape(x.shape, y.shape)), y._broadcast_to(out_shape)
def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
  # Unwraps `x` to its Python constant when it is an unrealized, unmasked constant Tensor that doesn't require grad
  # and whose broadcast against `self` keeps `self.shape`; anything else is returned unchanged.
  if not isinstance(x, Tensor): return x
  if not (isinstance(x.lazydata, UOp) and x.lazydata.is_unrealized_unmasked_const()): return x
  if x.requires_grad: return x
  if self._broadcasted(x)[0].shape != self.shape: return x
  return x.lazydata.const_arg
def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """
  Returns the element-wise sum of `self` and `x` (same as `self + x`).
  The operands are broadcast to a common shape and type-promoted; integer, float and boolean inputs are accepted.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(4)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.add(20).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.add(Tensor([[2.0], [3.5]])).numpy())
  ```
  """
  lhs, rhs = self._broadcasted(x, reverse)
  return F.Add.apply(lhs, rhs)
def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """
  Returns `self - x` element-wise (or `x - self` when `reverse` is True).
  The operands are broadcast to a common shape and type-promoted; integer, float and boolean inputs are accepted.
  Implemented as addition of the negated right-hand operand.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(4)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.sub(20).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.sub(Tensor([[2.0], [3.5]])).numpy())
  ```
  """
  minuend, subtrahend = self._broadcasted(x, reverse)
  negated = -subtrahend
  return minuend + negated
def mul(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """
  Returns the element-wise product of `self` and `x` (same as `self * x`).
  The operands are broadcast to a common shape and type-promoted; integer, float and boolean inputs are accepted.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(4)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.mul(3).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.mul(Tensor([[-1.0], [2.0]])).numpy())
  ```
  """
  lhs, rhs = self._broadcasted(x, reverse)
  return F.Mul.apply(lhs, rhs)
def idiv(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """
  Integer-divides `self` by `x` (same as `self // x`) via `F.IDiv`.
  The operands are broadcast to a common shape and type-promoted; intended for integer inputs.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([1, 4, 10]).idiv(Tensor([2, 3, 4])).numpy())
  ```
  """
  dividend, divisor = self._broadcasted(x, reverse)
  return F.IDiv.apply(dividend, divisor)
def div(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """
  Returns the true (floating point) quotient `self / x`.
  The operands are broadcast to a common shape; each is cast to its least upper float dtype before dividing,
  so integer and boolean inputs produce a float result.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(4)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.div(3).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4])).numpy())
  ```
  """
  num, den = self._broadcasted(x, reverse)
  num_f = num.cast(least_upper_float(num.dtype))
  # division is expressed as multiplication by the reciprocal
  inv_den = den.cast(least_upper_float(den.dtype)).reciprocal()
  return num_f * inv_den
def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """
  Returns the bitwise XOR of `self` and `x` (same as `self ^ x`).
  The operands are broadcast and type-promoted. `self` must be an integer or boolean tensor, otherwise RuntimeError is raised.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-1, -2, 3]).xor(Tensor([1, 0, 3])).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([True, True, False, False]).xor(Tensor([True, False, True, False])).numpy())
  ```
  """
  if not (self.dtype == dtypes.bool or dtypes.is_int(self.dtype)): raise RuntimeError(f"{self.dtype} is not supported")
  lhs, rhs = self._broadcasted(x, reverse)
  return F.Xor.apply(lhs, rhs)
def bitwise_and(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """
  Returns the bitwise AND of `self` and `x` (same as `self & x`).
  The operands are broadcast and type-promoted. `self` must be an integer or boolean tensor, otherwise RuntimeError is raised.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([2, 5, 255]).bitwise_and(Tensor([3, 14, 16])).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([True, True, False, False]).bitwise_and(Tensor([True, False, True, False])).numpy())
  ```
  """
  if not (self.dtype == dtypes.bool or dtypes.is_int(self.dtype)): raise RuntimeError(f"{self.dtype} is not supported")
  lhs, rhs = self._broadcasted(x, reverse)
  return F.BitwiseAnd.apply(lhs, rhs)
def bitwise_or(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """
  Returns the bitwise OR of `self` and `x` (same as `self | x`).
  The operands are broadcast and type-promoted. `self` must be an integer or boolean tensor, otherwise RuntimeError is raised.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([2, 5, 255]).bitwise_or(Tensor([4, 4, 4])).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([True, True, False, False]).bitwise_or(Tensor([True, False, True, False])).numpy())
  ```
  """
  if not (self.dtype == dtypes.bool or dtypes.is_int(self.dtype)): raise RuntimeError(f"{self.dtype} is not supported")
  lhs, rhs = self._broadcasted(x, reverse)
  return F.BitwiseOr.apply(lhs, rhs)
def bitwise_not(self) -> Tensor:
  """
  Returns the bitwise NOT of `self` (same as `~self`).
  Boolean tensors are logically negated; integer tensors are XORed with -1 (all bits set).
  Any other dtype raises RuntimeError.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([0, 2, 5, 255], dtype="int8").bitwise_not().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([True, False]).bitwise_not().numpy())
  ```
  """
  if not (self.dtype == dtypes.bool or dtypes.is_int(self.dtype)): raise RuntimeError(f"{self.dtype} is not supported")
  if self.dtype == dtypes.bool: return self.logical_not()
  return self ^ -1
def lshift(self, x:int):
  """
  Shifts `self` left by `x` bits (same as `self << x`), implemented as multiplication by `2 ** x`.
  `self` must have an unsigned dtype and `x` must be a non-negative int (checked with an assert).

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([1, 3, 31], dtype=dtypes.uint8).lshift(2).numpy())
  ```
  """
  assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
  factor = 2 ** x
  return self.mul(factor)
def rshift(self, x:int):
  """
  Shifts `self` right by `x` bits (same as `self >> x`), implemented as integer division by `2 ** x`.
  `self` must have an unsigned dtype and `x` must be a non-negative int (checked with an assert).

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([4, 13, 125], dtype=dtypes.uint8).rshift(2).numpy())
  ```
  """
  assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
  divisor = 2 ** x
  return self.idiv(divisor)
def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """
  Computes power of `self` with `x`.
  Equivalent to `self ** x`; with `reverse=True` it computes `x ** self` instead (used by `__rpow__`).

  A constant exponent is first simplified with exact identities for negative, zero, half-integer and integer exponents.
  A positive constant base with `reverse=True` uses `exp(self * log(x))`.
  Everything else goes through `exp(exponent * log(|base|))`, which then:
  - flips the sign for a negative base with an odd exponent,
  - gives NaN for a negative base with a non-integer exponent,
  - forces `0 ** 0 = 1`.
  For non-float `self` the result is rounded and cast back to `self.dtype`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-1, 2, 3]).pow(2).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-1, 2, 3]).pow(Tensor([-1.5, 0.5, 1.5])).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print((2 ** Tensor([-1, 2, 3])).numpy())
  ```
  """
  # a Tensor exponent that is really a broadcast-safe constant is unwrapped to its Python value
  x = self._to_const_val(x)
  if not isinstance(x, Tensor) and not reverse:
    # simple pow identities
    # a ** -n == (1/a) ** n, cast back to self.dtype
    if x < 0: return self.reciprocal().pow(-x).cast(self.dtype)
    # a ** 0 == 1, written as `1 + self * 0` so the result has self's shape
    if x == 0: return 1 + self * 0
    # NOTE(review): a non-finite const exponent (inf/nan) reaches the `int(...)` calls below, which raise in Python — confirm intended
    # rewrite pow 0.5 to sqrt (more generally a ** (n + 0.5) == a ** n * sqrt(a))
    if int(x - 0.5) + 0.5 == x: return self.pow(int(x - 0.5)) * self.sqrt()
    # integer exponent by recursive squaring: a ** n == (a ** (n // 2)) ** 2, times a once more when n is odd
    if int(x) == x: return self.pow(x // 2).square() * (1 if x % 2 == 0 else self)
  # positive const ** self
  if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
  base, exponent = self._broadcasted(x, reverse=reverse)
  # start with |b| ** e = exp(e * log(|b|)); the sign of b is fixed up below
  ret = base.abs().log().mul(exponent).exp()
  # correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent)
  negative_base = (base < 0).detach().where(1, 0)
  # 1 for non-negative base or negative even exponent, -1 for negative odd exponent, don't care about non-integer exponent
  correct_sign = 1 + negative_base * ((exponent * math.pi).cos() - 1)
  # inject nan for negative base and non-integer exponent
  inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1)
  # apply correct_sign inject_nan, and fix 0 ** 0 = 1
  ret = ((base == 0) * (exponent == 0)).detach().where(1, ret * correct_sign * inject_nan)
  # integer/bool inputs: round away float error and return in the original dtype
  return ret.round().cast(self.dtype) if not dtypes.is_float(self.dtype) else ret
def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
  """
  Returns the element-wise maximum of `self` and `x`.
  For floating point tensors, positions where both are equal yield the average `0.5 * self + 0.5 * x`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-1, 2, 3]).maximum(1).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy())
  ```
  """
  x_bigger = (self < x).detach()
  if not self.is_floating_point(): return x_bigger.where(x, self)
  # NOTE: the mid-point is for backward, revisit after new gradient API
  midpoint = (self * 0.5 + x * 0.5).cast(self.dtype)
  return x_bigger.where(x, (self == x).detach().where(midpoint, self))
def minimum(self, x:Union[Tensor, ConstType]) -> Tensor:
  """
  Returns the element-wise minimum of `self` and `x`.
  Computed as `inverse(maximum(inverse(a), inverse(b)))` on the broadcast operands.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-1, 2, 3]).minimum(1).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-1, 2, 3]).minimum(Tensor([-4, -2, 9])).numpy())
  ```
  """
  a, b = self._broadcasted(x)
  # presumably `_inverse` is order-reversing (e.g. negation for floats) so this turns max into min — confirm
  flipped_max = a._inverse().maximum(b._inverse())
  return flipped_max._inverse()
def where(self:Tensor, x:Union[Tensor, ConstType, sint], y:Union[Tensor, ConstType, sint]):
  """
  Selects element-wise from `x` where `self` is true and from `y` elsewhere: `output_i = x_i if self_i else y_i`.
  `self` is cast to bool. `x` and `y` are broadcast and promoted against each other; `self` is broadcast with them without dtype matching.

  ```python exec="true" source="above" session="tensor" result="python"
  cond = Tensor([[True, True, False], [True, False, False]])
  print(cond.where(1, 3).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  cond = Tensor.randn(2, 3)
  print(cond.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print((cond > 0).where(cond, -float("inf")).numpy())
  ```
  """
  # when a branch is already a Tensor, line the two branches up with each other first
  if isinstance(x, Tensor):
    x, y = x._broadcasted(y)
  elif isinstance(y, Tensor):
    y, x = y._broadcasted(x)
  cond, x = self._broadcasted(x, match_dtype=False)
  cond, y = cond._broadcasted(y, match_dtype=False)
  lhs, rhs = x._broadcasted(y)
  return F.Where.apply(cond.cast(dtypes.bool), lhs, rhs)
def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]):
  """
  Returns `self` with the positions where `mask` is true replaced by `value`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([1, 2, 3]).masked_fill(Tensor([True, False, True]), 0).numpy())
  ```
  """
  filled = mask.where(value, self)
  return filled
# ***** op wrappers *****
# unary / binary operators forward to the named methods
def __invert__(self) -> Tensor:
  return self.bitwise_not()
def __lshift__(self, x) -> Tensor:
  return self.lshift(x)
def __rshift__(self, x) -> Tensor:
  return self.rshift(x)
def __pow__(self, x) -> Tensor:
  return self.pow(x)
def __matmul__(self, x) -> Tensor:
  return self.matmul(x)
# reflected operators pass reverse=True
def __rpow__(self, x) -> Tensor:
  return self.pow(x, True)
def __rmatmul__(self, x) -> Tensor:
  return self.matmul(x, True)
# in-place operators compute the out-of-place result and write it back with assign
def __iadd__(self, x) -> Tensor:
  return self.assign(self.add(x))
def __isub__(self, x) -> Tensor:
  return self.assign(self.sub(x))
def __imul__(self, x) -> Tensor:
  return self.assign(self.mul(x))
def __ipow__(self, x) -> Tensor:
  return self.assign(self.pow(x))
def __itruediv__(self, x) -> Tensor:
  return self.assign(self.div(x))
def __ifloordiv__(self, x) -> Tensor:
  return self.assign(self.idiv(x))
def __imatmul__(self, x) -> Tensor:
  return self.assign(self.matmul(x))
def __iand__(self, x) -> Tensor:
  return self.assign(self.bitwise_and(x))
def __ior__(self, x) -> Tensor:
  return self.assign(self.bitwise_or(x))
def __ixor__(self, x) -> Tensor:
  return self.assign(self.xor(x))
def __ilshift__(self, x) -> Tensor:
  return self.assign(self.lshift(x))
def __irshift__(self, x) -> Tensor:
  return self.assign(self.rshift(x))
# comparisons: `self > x` is computed as `x < self` by swapping the broadcast operands
def __lt__(self, x) -> Tensor:
  return F.Less.apply(*self._broadcasted(x, False))
def __gt__(self, x) -> Tensor:
  return F.Less.apply(*self._broadcasted(x, True))
def ne(self, x) -> Tensor:
  return F.Neq.apply(*self._broadcasted(x))
def __eq__(self, x) -> Tensor:  # type: ignore[override]
  return self.eq(x)
# ***** functional nn ops *****
def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
  """
  Applies an affine map: `self.dot(weight) + bias`.
  A 1-D `weight` is applied element-wise (`self * weight`) instead of as a matrix; `bias` is optional.
  NOTE: `weight` is used as-is (no transpose), unlike the `(out, in)` layout of PyTorch's `nn.Linear`.
  See: https://pytorch.org/docs/stable/generated/torch.nn.Linear.html

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[1, 2], [3, 4]])
  weight = Tensor([[1, 2], [3, 4]])
  bias = Tensor([1, 2])
  print(t.linear(weight, bias).numpy())
  ```
  """
  out = self.dot(weight) if len(weight.shape) != 1 else self.mul(weight)
  if bias is None: return out
  return out.add(bias)
def sequential(self, ll:list[Callable[[Tensor], Tensor]]):
  """
  Feeds `self` through the callables in `ll` in order, each one receiving the previous one's output.
  An empty list returns `self` unchanged.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3])
  print(t.sequential([lambda x: x * 2, lambda x: x + 1]).numpy())
  ```
  """
  out = self
  for fn in ll: out = fn(out)
  return out
def layernorm(self, axis:Union[int,tuple[int,...]]=-1, eps:float=1e-5) -> Tensor:
  """
  Normalizes `self` to zero mean and unit variance over `axis`, computing `(x - mean) / sqrt(var + eps)`.
  No learnable scale or shift is applied.

  - Described: https://paperswithcode.com/method/layer-normalization
  - Paper: https://arxiv.org/abs/1607.06450v1
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.randn(8, 10, 16) * 2 + 8
  print(t.mean().item(), t.std().item())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = t.layernorm()
  print(t.mean().item(), t.std().item())
  ```
  """
  centered = (self - self.mean(axis, keepdim=True))
  variance = (centered*centered).mean(axis, keepdim=True)
  return centered.mul(variance.add(eps).rsqrt())
def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,tuple[int,...]]=1) -> Tensor:
  """
  Applies Batch Normalization with precomputed statistics: `(self - mean) * weight * invstd + bias`.

  `mean`, `weight`, `bias` (and `invstd` when it has one dim per entry of `axis`) are reshaped to broadcast along `axis`.
  `weight` and `bias` may be None to skip them. `axis` may be negative.

  - Described: https://paperswithcode.com/method/batch-normalization
  - Paper: https://arxiv.org/abs/1502.03167
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.randn(8, 4, 16, 16) * 2 + 8
  print(t.mean().item(), t.std().item())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = t.batchnorm(None, None, t.mean(axis=(0,2,3)), t.var(axis=(0,2,3)).add(1e-5).rsqrt())
  print(t.mean().item(), t.std().item())
  ```
  """
  # normalize negative axes; otherwise they never match the non-negative indices from enumerate below
  axis_ = tuple(self._resolve_dim(ax) for ax in argfix(axis))
  # keep the size of the normalized axes, 1 everywhere else, so the statistics broadcast against self
  shape = tuple(s if ax in axis_ else 1 for ax, s in enumerate(self.shape))
  x = self - mean.reshape(shape)
  if weight is not None: x = x * weight.reshape(shape)
  ret = x.mul(invstd.reshape(shape) if len(invstd.shape) == len(axis_) else invstd)
  return (ret + bias.reshape(shape)) if bias is not None else ret
def dropout(self, p=0.5) -> Tensor:
  """
  Applies dropout to `self`.

  NOTE: dropout is only applied when `Tensor.training` is `True`; otherwise (or with `p == 0`) `self` is returned unchanged.
  While training, each element is kept when a uniform sample is `>= p` and the kept values are scaled by `1 / (1 - p)`.
  With `p == 1` every element is dropped and `self * 0` is returned, instead of the NaN that rescaling by `1 / 0` would give.

  - Described: https://paperswithcode.com/method/dropout
  - Paper: https://jmlr.org/papers/v15/srivastava14a.html
  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 2)
  with Tensor.train():
    print(t.dropout().numpy())
  ```
  """
  if not Tensor.training or p == 0: return self
  # every element is dropped; the rescale below would compute 0 / 0 = NaN
  if p == 1: return self * 0
  return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
# helper function commonly used for indexing
def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1):
  # Compares `self` against `arange(num_classes)` laid out along axis `dim`, giving a boolean one-hot via broadcasting.
  # The class range gets one trailing size-1 axis per axis of `self` after `dim`.
  trailing = self.ndim - self._resolve_dim(dim) - 1
  classes = Tensor.arange(num_classes, device=self.device, requires_grad=False)
  return self == classes.reshape((num_classes,) + (1,) * trailing)
def one_hot(self, num_classes:int=-1) -> Tensor:
  """
  Returns a one-hot encoding of `self` with a new trailing class axis (1 at the label position, 0 elsewhere).
  With the default `num_classes=-1` the class count is `max(self) + 1`, which requires evaluating `self` via `.item()`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([0, 1, 3, 3, 4])
  print(t.one_hot(5).numpy())
  ```
  """
  if num_classes == -1:
    num_classes = (self.max()+1).item()
  expanded = self[..., None]
  return expanded._one_hot_along_dim(num_classes).where(1, 0)
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None,
                                 dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
  """
  Computes `softmax(self @ key^T / sqrt(d) + mask) @ value`, where `self` is the query and `d` is its last dim.

  A boolean `attn_mask` keeps positions where it is True and sets the rest to `-inf`; other masks are added as-is.
  `is_causal=True` builds a lower-triangular mask that replaces any `attn_mask` passed in.
  The scores are accumulated in at least float32, and the softmax output is cast back to `self.dtype` before dropout.

  - Described: https://paperswithcode.com/method/scaled
  - Paper: https://arxiv.org/abs/1706.03762v7
  ```python exec="true" source="above" session="tensor" result="python"
  q = Tensor.randn(2, 4, 8)
  k = Tensor.randn(2, 4, 8)
  v = Tensor.randn(2, 4, 8)
  print(q.scaled_dot_product_attention(k, v).numpy())
  ```
  """
  # NOTE: it also works when `key` and `value` have symbolic shape.
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
  if is_causal:
    lower_tri = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0)
    attn_mask = lower_tri.cast(dtypes.bool)
  if attn_mask is not None and attn_mask.dtype == dtypes.bool:
    attn_mask = (attn_mask == 0).where(-float("inf"), 0)
  acc = least_upper_dtype(self.dtype, key.dtype, dtypes.float32)
  scores = self.matmul(key.transpose(-2,-1), acc_dtype=acc) / math.sqrt(self.shape[-1])
  if attn_mask is not None: scores = scores + attn_mask
  weights = scores.softmax(-1).cast(self.dtype).dropout(dropout_p)
  return weights @ value
def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
  # Reduces a loss tensor: "mean" and "sum" collapse it to a scalar, "none" returns it untouched; anything else is a ValueError.
  allowed = get_args(ReductionStr)
  if reduction not in allowed: raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
  reducers: dict[str, Callable[[Tensor], Tensor]] = {"none": lambda t: t, "sum": Tensor.sum, "mean": Tensor.mean}
  return reducers[reduction](self)
def binary_crossentropy(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
  """
  Returns the binary cross-entropy `-(Y * log(p) + (1 - Y) * log(1 - p))` with `p = self` (probabilities),
  reduced according to `reduction`.
  See: https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([0.1, 0.9, 0.2])
  Y = Tensor([0, 1, 0])
  print(t.binary_crossentropy(Y).item())
  ```
  """
  positive_term = -Y*self.log()
  negative_term = (1-Y)*(1-self).log()
  return (positive_term - negative_term)._do_reduction(reduction)
def binary_crossentropy_logits(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
  """
  Returns the binary cross-entropy between `sigmoid(self)` and `Y`, computed directly from the logits `self`,
  reduced according to `reduction`.
  Uses the form `max(x, 0) - Y * x + log(1 + exp(-|x|))`, where `exp` never sees a positive argument.
  See: https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([-1, 2, -3])
  Y = Tensor([0, 1, 0])
  print(t.binary_crossentropy_logits(Y).item())
  ```
  """
  log_term = (1 + self.abs().neg().exp()).log()
  return (self.maximum(0) - Y * self + log_term)._do_reduction(reduction)
def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index:int=-1, label_smoothing=0.0, reduction:ReductionStr="mean") -> Tensor:
  """
  Returns the cross-entropy between the logits `self` and the integer class labels `Y`.

  NOTE: unlike PyTorch, this function expects the class axis to be -1.
  Targets equal to `ignore_index` are excluded from both the loss and the "mean" denominator; the default `-1` ignores nothing.
  `label_smoothing` in `[0, 1]` mixes in the mean log-probability over classes.
  See: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[-1, 2, -3], [1, -2, 3]])
  Y = Tensor([1, 2])
  print(t.sparse_categorical_crossentropy(Y).item())
  ```
  """
  assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
  assert reduction in ("mean", "sum", "none"), "reduction must be one of ['mean', 'sum', 'none']"
  log_probs = self.log_softmax()
  if ignore_index == -1: loss_mask = Y.ones_like(dtype=dtypes.bool)
  else: loss_mask = (Y != ignore_index)
  num_classes = self.shape[-1]
  y_counted = Y.to(self.device).flatten().reshape(-1, 1)._one_hot_along_dim(num_classes)
  y = (y_counted * loss_mask.reshape(-1, 1)).reshape(*Y.shape, num_classes)
  smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
  unreduced = ((1 - label_smoothing) * (log_probs * y).sum(-1) + smoothing)
  # NOTE: because of ignore_index, we can't use Tensor.mean (so can't use `_do_reduction` here)
  if reduction == "mean": return -(unreduced.sum() / loss_mask.sum())
  if reduction == "sum": return -unreduced.sum()
  return -unreduced
def cross_entropy(self, Y:Tensor, reduction:ReductionStr="mean", label_smoothing:float=0.0) -> Tensor:
  """
  Compute the cross entropy loss between input logits and target.

  NOTE: `self` are logits and `Y` are the target labels or class probabilities.
  The class axis is 1 for batched input `(N, C, d1, ...)` and 0 for unbatched input `(C,)`.
  `Y` is read as class indices when it has one dimension fewer than `self` (e.g. `(N,)` or `(N, d1, ...)`).
  Otherwise it is read as class probabilities with the same shape as `self`.
  `label_smoothing` in `[0, 1]` blends the target with a uniform distribution over the classes.
  See: https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[-1, 2, -3], [1, -2, 3]])
  Y = Tensor([1, 2])
  print(t.cross_entropy(Y).item())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[-1, 2, -3], [1, -2, 3]])
  Y = Tensor([1, 2])
  print(t.cross_entropy(Y, reduction='none').numpy())
  ```
  """
  assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
  classes_dim = 0 if self.ndim == 1 else 1
  num_classes = cast(int, self.shape[classes_dim])
  # class-index targets lack the class axis: insert it and one-hot along it, so (N, d1, ...) -> (N, C, d1, ...)
  if Y.ndim < self.ndim: Y = Y.unsqueeze(classes_dim)._one_hot_along_dim(num_classes, dim=classes_dim)
  Y = (1 - label_smoothing)*Y + label_smoothing / num_classes
  ret = -self.log_softmax(axis=classes_dim).mul(Y).sum(axis=classes_dim)
  return ret._do_reduction(reduction)
def nll_loss(self, Y:Tensor, weight:Optional[Tensor]=None, ignore_index:Optional[int]=None, reduction:ReductionStr="mean") -> Tensor:
  """
  Returns the negative log likelihood loss of the log-probabilities `self` at the target class indices `Y`.

  NOTE: `self` is log-probabilities and `Y` holds the target class indices (gathered along axis 1).
  `weight`, if given, is a per-class weight indexed by `Y`. Targets equal to `ignore_index` get weight 0.
  The "mean" reduction divides by the sum of the effective weights.
  See: https://pytorch.org/docs/stable/generated/torch.nn.functional.nll_loss.html

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[-1, 2, -3], [1, -2, 3]])
  Y = Tensor([1, 2])
  print(t.log_softmax().nll_loss(Y).item())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[-1, 2, -3], [1, -2, 3]])
  Y = Tensor([1, 2])
  print(t.log_softmax().nll_loss(Y, reduction='none').numpy())
  ```
  """
  if weight is None: weight = Tensor.ones_like(Y, requires_grad=False)
  else: weight = weight[Y]
  masked_weight = weight * (Y != ignore_index) if ignore_index is not None else weight
  picked = self.gather(1, Y.unsqueeze(1)).squeeze(1)
  nll = -picked * masked_weight
  if reduction == "mean": return nll.sum() / masked_weight.sum()
  return nll._do_reduction(reduction)
# ***** Tensor Properties *****
@property
def ndim(self) -> int:
  """
  The rank of the tensor, i.e. how many dimensions its shape has.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[1, 2], [3, 4]])
  print(t.ndim)
  ```
  """
  dims = self.shape
  return len(dims)
def numel(self) -> sint:
  """
  Returns the element count of the tensor, i.e. the product of its shape (possibly symbolic).

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
  print(t.numel())
  ```
  """
  dims = self.shape
  return prod(dims)
def element_size(self) -> int:
  """
  Returns how many bytes a single element occupies (the dtype's itemsize).

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([5], dtype=dtypes.int16)
  print(t.element_size())
  ```
  """
  dt = self.dtype
  return dt.itemsize
def nbytes(self) -> int:
  """
  Returns the total storage size in bytes: element count times per-element size.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([8, 9], dtype=dtypes.float)
  print(t.nbytes())
  ```
  """
  count = self.numel()
  return count * self.element_size()
def is_floating_point(self) -> bool:
  """
  Returns `True` if the tensor's dtype is a floating point type according to `dtypes.is_float`
  (e.g. `dtypes.float64`, `dtypes.float32`, `dtypes.float16`, `dtypes.bfloat16`).

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([8, 9], dtype=dtypes.float32)
  print(t.is_floating_point())
  ```
  """
  dt = self.dtype
  return dtypes.is_float(dt)
def size(self, dim:Optional[int]=None) -> Union[sint, tuple[sint, ...]]:
  """
  Returns the full shape when `dim` is None, otherwise the length of axis `dim` (negative indices allowed).

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[4, 5, 6], [7, 8, 9]])
  print(t.size())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.size(dim=1))
  ```
  """
  if dim is not None: return self.shape[dim]
  return self.shape
# ***** cast ops *****
def llvm_bf16_cast(self, dtype:DTypeLike):
  # hack for devices that don't support bfloat16: on LLVM, the raw 16 bf16 bits are widened to uint32 and shifted into
  # the high half (bfloat16 is the upper 16 bits of a float32), reinterpreted as float32, then cast to `dtype`
  assert self.dtype == dtypes.bfloat16
  raw_bits = self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32)
  as_f32 = raw_bits.mul(1<<16).bitcast(dtypes.float32)
  return as_f32.cast(dtype)
def cast(self, dtype:DTypeLike) -> Tensor:
  """
  Returns `self` converted to `dtype`; a no-op when the dtype already matches.
  Float to `uint8`/`uint16` conversions go through `int32` first.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([-1, 2.5, 3], dtype=dtypes.float)
  print(t.dtype, t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = t.cast(dtypes.int32)
  print(t.dtype, t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = t.cast(dtypes.uint8)
  print(t.dtype, t.numpy())
  ```
  """
  target = to_dtype(dtype)
  if target in {dtypes.uint8, dtypes.uint16} and dtypes.is_float(self.dtype):
    # NOTE: values within the int32 range and outside the unsigned dtype range will cause values to wrap around
    via_int32 = F.Cast.apply(self, dtype=dtypes.int32)
    return F.Cast.apply(via_int32, dtype=target)
  if self.dtype == target: return self
  return F.Cast.apply(self, dtype=target)
def bitcast(self, dtype:DTypeLike) -> Tensor:
  """
  Reinterprets the raw bits of `self` as `dtype`.
  `self` must not require a gradient (RuntimeError otherwise).

  On DISK devices the itemsize must stay the same. On other devices the itemsize may change, and the last axis is regrouped:
  - `shape[-1] * old_itemsize` must be divisible by the new itemsize (RuntimeError otherwise);
  - widening packs `new_itemsize // old_itemsize` elements into one;
  - narrowing splits each element into `old_itemsize // new_itemsize` elements.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([-1, 2, 3], dtype=dtypes.int32)
  print(t.dtype, t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = t.bitcast(dtypes.uint32)
  print(t.dtype, t.numpy())
  ```
  """
  if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
  dt = to_dtype(dtype)
  # itemsize change on a non-DISK device: `ns` / `os` are the new / old itemsizes in bytes
  if (not isinstance(self.device, str) or not self.device.startswith("DISK")) and (ns:=dt.itemsize) != (os:=self.dtype.itemsize):
    if (self.shape[-1]*os) % ns != 0: raise RuntimeError("unsupported size in bitcast")
    # work on unsigned ints of each width so the shifts and adds below act on raw bits
    new_uint, old_uint = to_dtype(f"uint{8*ns}"), to_dtype(f"uint{8*os}")
    tmp = self.bitcast(old_uint)
    # widening: element i of each group of ns//os is shifted left by 8*i*os bits and summed, so lower-indexed elements
    # land in lower bits (NOTE(review): i.e. a little-endian layout is assumed — confirm)
    if ns > os: return functools.reduce(Tensor.add, (tmp[..., i::ns//os].cast(new_uint) << 8*i*os for i in range(ns//os))).bitcast(dtype)
    # narrowing: piece i is the element shifted right by 8*i*ns bits; pieces are stacked on a new last axis, merged into
    # the previous axis, and cast to the narrow uint (presumably keeping the low bits — confirm)
    return Tensor.stack(*(tmp>>8*i*ns for i in range(os//ns)), dim=-1).flatten(-2).cast(new_uint).bitcast(dtype)
  # same itemsize (or a DISK device): a plain reinterpreting Cast, skipped when the dtype already matches
  return F.Cast.apply(self, dtype=dt, bitcast=True) if self.dtype != dt else self
def float(self) -> Tensor:
  """
  Shorthand for `self.cast(dtypes.float32)`: returns `self` as a `float32` Tensor.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([-1, 2, 3], dtype=dtypes.int32)
  print(t.dtype, t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = t.float()
  print(t.dtype, t.numpy())
  ```
  """
  target = dtypes.float32
  return self.cast(target)
def half(self) -> Tensor:
  """
  Shorthand for `self.cast(dtypes.float16)`: returns `self` as a `float16` Tensor.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([-1, 2, 3], dtype=dtypes.int32)
  print(t.dtype, t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = t.half()
  print(t.dtype, t.numpy())
  ```
  """
  target = dtypes.float16
  return self.cast(target)
def int(self) -> Tensor:
  """
  Shorthand for `self.cast(dtypes.int32)`: returns `self` as an `int32` Tensor.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([-1.5, -0.5, 0.0, 0.5, 1.5])
  print(t.dtype, t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = t.int()
  print(t.dtype, t.numpy())
  ```
  """
  target = dtypes.int32
  return self.cast(target)
def bool(self) -> Tensor:
  """
  Shorthand for `self.cast(dtypes.bool)`: returns `self` as a `bool` Tensor.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([-1, 0, 1])
  print(t.dtype, t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = t.bool()
  print(t.dtype, t.numpy())
  ```
  """
  target = dtypes.bool
  return self.cast(target)
# *** image Tensor function replacements ***
def image_dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
  """
  Computes `self @ w` by lowering the matmul to a 1x1 `image_conv2d`.

  The leading (batch) dims of `self` are flattened into `bs` and those of `w` into `groups`.
  `acc_dtype` is passed through unchanged to `image_conv2d`.
  Raises RuntimeError if either operand is 0-D, or if `self.shape[-1]` does not match the contraction dim of `w`.
  """
  # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
  x, dx, dw = self, self.ndim, w.ndim
  if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D")
  if x.shape[-1] != w.shape[-min(w.ndim, 2)]: raise RuntimeError(f"cannot image_dot {x.shape} and {w.shape}")
  # NOTE(review): a 1-D `w` passes the checks above, but `w.shape[-2]` below would raise IndexError — confirm 1-D weights are unsupported here
  bs, groups, cin, cout = prod(self.shape[0:-2]), prod(w.shape[0:-2]), w.shape[-2], w.shape[-1]
  # shape to reshape the conv output into before transposing back; a 1-D `self` gives just (cout,)
  out_shape_t = self.shape[0:-2] + (cout,-1) if len(self.shape) > 1 else (cout, )
  # NOTE: with NHWC we can remove the transposes
  # input as (bs//groups) x groups*cin x rest x 1: the last two dims of self are swapped so cin leads
  cx = self.transpose(self.ndim-1, self.ndim-2).reshape((bs//groups, groups*cin, -1, 1))
  # weight as groups*cout x cin x 1 x 1 (a 1x1 kernel)
  cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
  return cx.image_conv2d(cw, groups=groups, acc_dtype=acc_dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
def image_conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups:int=1, stride=1, dilation=1, padding=0, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
  """
  2D convolution lowered to a packed, 4-channel-aligned layout suited to image dtypes.

  Inputs and channel padding:
  - `self` is read as (bs, groups*cin, iy, ix) and `weight` as (cout, cin, H, W).
  - Per-group input/output channel counts that are not multiples of 4 are zero-padded.
  - The extra output channels are sliced off again before returning.

  Dtypes:
  - With IMAGE >= 2, the packed input and weights are cast to `dtypes.imageh` (FLOAT16 env var set) or `dtypes.imagef`.
  - The elementwise product is cast to that image dtype (IMAGE >= 2) or to float32 before the reduction.

  Other parameters:
  - `padding` goes to `_padding2d`; `stride` and `dilation` go to `_pool`.
  - `acc_dtype` goes to the final `sum`.

  Returns an NCHW tensor of shape (bs, cout, oy, ox), plus `bias` (if given) broadcast over the channel dim.
  """
  base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
  (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
  # view weights per group: (groups, rcout, cin, H, W), rcout = output channels per group
  x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)
  # hack for non multiples of 4 on cin: zero-pad input channels of each group up to a multiple of 4
  # (skipped for cin == 1 with groups % 4 == 0, where groups are split 4-wide instead; see cout_expand below)
  if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
    x = x.reshape(bs, groups, cin, iy, ix)   # do this always?
    added_input_channels = 4 - (cin % 4)
    w = w.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(w.ndim)))
    x = x.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(x.ndim)))
    cin = cin + added_input_channels
    x = x.reshape(bs, groups*cin, iy, ix)
  # hack for non multiples of 4 on rcout: zero-pad output channels per group; undone after the conv
  added_output_channels = 0
  if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
    added_output_channels = 4 - (rcout % 4)
    rcout += added_output_channels
    cout = groups * rcout
    w = w.pad(tuple((0, added_output_channels) if i == 1 else None for i in range(w.ndim)))
  # packed (note: flipping bs and iy would make the auto-padding work)
  # x: NCHW -> NHWC
  x = x.permute(0,2,3,1)
  # 1x1 spatial input: choose a weight layout whose innermost 4 lanes are input channels rather than output channels
  cin_last = iy == 1 and ix == 1
  if cin == 1: w = w.reshape(cout//4,4,H,W).permute(0,2,3,1)
  elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3)
  else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1)
  # contiguous creates the image, and early realize static weights (TODO: test for the static weight)
  if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
  x, w = x.contiguous(), w.contiguous()
  # expand out: split cin into (rcin_hi, rcin_lo) = (cin//4, 4), or (1, 1) when cin < 4
  rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
  # output channel split: (group hi, group lo, rcout hi, rcout lo); groups go 4-wide only when cin == 1
  cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
  x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
  if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
  else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
  # prepare input
  x = x.permute(0,3,4,5,1,2).pad(self._padding2d(padding, 2))._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
  # oy/ox are read from the pooled x (dims 4, 5) before x is reassigned by this statement
  x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
  # prepare weights
  w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))
  # the conv! broadcast multiply, then reduce over (rcin_hi, rcin_lo, H, W)
  ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), acc_dtype=acc_dtype)
  # undo hack for non multiples of 4 on rcout: drop the zero-padded output channels of each group
  if added_output_channels != 0:
    ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
    cout = groups * (rcout - added_output_channels)
  # NCHW output
  ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
  return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
def _metadata_wrapper(fn):
  """
  Wraps `fn` (a Tensor method) so its outermost call runs with `_METADATA` set to `Metadata(name=fn.__name__, caller=...)`.

  If `_METADATA` is already set (we are inside another wrapped call), `fn` runs directly, so only the outermost
  call is recorded.

  With TRACEMETA >= 2, `caller` is "module:lineno::func" taken from the Python stack:
  - frames in `tinygrad.nn` (other than optim modules) are skipped;
  - if the resulting caller is a tinygrad `<lambda>`, two further frames are skipped.
  Otherwise `caller` is "".

  `_METADATA` is restored whether `fn` returns or raises.
  """
  def _wrapper(*args, **kwargs):
    if _METADATA.get() is not None: return fn(*args, **kwargs)
    if TRACEMETA >= 2:
      caller_frame = sys._getframe(frame := 1)
      caller_module = caller_frame.f_globals.get("__name__", None)
      caller_func = caller_frame.f_code.co_name
      if caller_module is None: return fn(*args, **kwargs)
      # if its called from nn we want to step up frames until we are out of nn
      while caller_module.startswith("tinygrad.nn") and "optim" not in caller_module:
        caller_frame = sys._getframe(frame := frame + 1)
        caller_module = caller_frame.f_globals.get("__name__", None)
        if caller_module is None: return fn(*args, **kwargs)
      # if its called from a lambda in tinygrad we want to look two more frames up
      # NOTE(review): caller_func still names the immediate caller (frame 1), not the frame reached by the nn loop — confirm intended
      if caller_module.startswith("tinygrad") and caller_func == "<lambda>": caller_frame = sys._getframe(frame := frame + 2)
      caller_module = caller_frame.f_globals.get("__name__", None)
      if caller_module is None: return fn(*args, **kwargs)
      caller_func = caller_frame.f_code.co_name
      caller_lineno = caller_frame.f_lineno
      caller = f"{caller_module}:{caller_lineno}::{caller_func}"
    else: caller = ""
    token = _METADATA.set(Metadata(name=fn.__name__, caller=caller))
    # reset in `finally`: if fn raised and _METADATA stayed set, every later wrapped call in this context
    # would take the early-return path above and never record its own metadata
    try: return fn(*args, **kwargs)
    finally: _METADATA.reset(token)
  return _wrapper
if TRACEMETA >= 1:
  # wrap every function found on Tensor with _metadata_wrapper so calls record Metadata; the names listed are left unwrapped
  for name, fn in inspect.getmembers(Tensor, inspect.isfunction):
    if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential"]: continue
    # getmembers resolves staticmethods to their plain function; storing the wrapper bare would make it an instance
    # method, so calling it through an instance would bind that instance as the first argument. Keep the descriptor.
    if isinstance(inspect.getattr_static(Tensor, name), staticmethod):
      setattr(Tensor, name, staticmethod(functools.wraps(fn)(_metadata_wrapper(fn))))
    else: setattr(Tensor, name, functools.wraps(fn)(_metadata_wrapper(fn)))