# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py from __future__ import annotations import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref from contextlib import ContextDecorator from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate from tinygrad.dtype import _from_np_dtype, _to_np_dtype from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap from tinygrad.engine.multi import get_multi_map from tinygrad.gradient import compute_gradient from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element from tinygrad.spec import tensor_uop_spec, type_verify from tinygrad.device import Device, Buffer from tinygrad.engine.realize import run_schedule from tinygrad.engine.memory import memory_planner from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars # *** all in scope Tensors are here. this gets relevant UOps *** all_tensors: set[weakref.ref[Tensor]] = set() def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None: # get all children of keys in applied_map all_uops: set[UOp] = set() search_uops = list(applied_map) while len(search_uops): x = search_uops.pop(0) if x in all_uops: continue all_uops.add(x) search_uops.extend([u for c in x.children if (u:=c()) is not None]) # link the found UOps back to Tensors. exit early if there's no Tensors to realize # NOTE: this uses all_tensors, but it's fast fixed_tensors: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and t.lazydata in all_uops] if len(fixed_tensors): # potentially rewrite all the discovered Tensors sink = UOp.sink(*[t.lazydata for t in fixed_tensors]) new_sink = sink.substitute(applied_map) # set the relevant lazydata to the realized UOps for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src): if s is ns: continue t.lazydata = ns # **** Tensor helper functions **** def _metaop(op, shape:tuple[sint,...], dtype:DType, device:str|tuple[str, ...], arg=None) -> UOp: if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg) return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None) def _fromnp(x: 'np.ndarray') -> UOp: # type: ignore [name-defined] # noqa: F821 ret = UOp.metaop(Ops.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY") # fake realize ret.buffer.allocate(x) return ret def get_shape(x) -> tuple[int, ...]: # NOTE: str is special because __getitem__ on a str is still a str if not hasattr(x, "__len__") or not hasattr(x, "__getitem__") or isinstance(x, str) or (hasattr(x, "shape") and x.shape == ()): return () if not all_same(subs:=[get_shape(xi) for xi in x]): raise ValueError(f"inhomogeneous shape from {x}") return (len(subs),) + (subs[0] if subs else ()) def _frompy(x:list|tuple|bytes, dtype:DType) -> UOp: if isinstance(x, bytes): ret, data = UOp.metaop(Ops.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x else: ret = UOp.metaop(Ops.EMPTY, get_shape(x), dtype, "PYTHON") assert dtype.fmt is not None, f"{dtype=} has None fmt" truncate_function = truncate[dtype] data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)]) # fake realize ret.buffer.allocate(memoryview(data if Device.DEFAULT != "PYTHON" else bytearray(data))) return ret def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], device:str|tuple[str, ...], dtype:DType) -> list[list[Tensor]]: return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device, dtype=dtype) for m in mat], dim=dim) for k in range(len(mat[0]))] for dim in range(dims)] # winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308 def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor: # multiply mat_1 @ mat_2 @ t with foldable constants, where mat_i acts on vector t along dimension i; roughly kron(mat, mat) @ t # due to realize-before-expand rule in lazy.py, we must operate in this order: reshape -> expand -> arithmetic t_ = t.reshape(t.shape[:dims] + (1,) * dims + t.shape[dims:]).expand(t.shape[:dims] + (len(mat),) * dims + t.shape[dims:]) # add output dims # precalculate mat columns for each dim; prod(itertools.product(matcols)) gives the columns of kron(mat, mat, ...) matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device, t_.dtype) # multiply each element of t_ by the corresponding stacked column of kron(mat, mat), producing only one view for each element of t ret = sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims)) assert isinstance(ret, Tensor), "sum didn't return a Tensor" return ret def _align_left(*shapes:tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]: # unsqueeze left to make every shape same length max_dim = max(len(shape) for shape in shapes) return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes) def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]: return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes))) def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]) -> Tensor: # apply mask to values (already broadcasted) and reduce such that if mask contains repeated indices the last one remains values = values * mask for dim in axes: mask, values = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), values.split(1, dim))) # remove extra dims from reduce for dim in reversed(axes): mask, values = mask.squeeze(dim), values.squeeze(dim) # select from values for each True element in mask else select from self return mask.where(values, target) # `(padding_left, padding_right, padding_top, padding_bottom, ...)` -> `(..., (padding_top, padding_bottom), (padding_left, padding_right))` def _flat_to_grouped(padding:Sequence[sint]) -> tuple[tuple[sint, sint], ...]: return tuple(zip(padding[-2::-2], padding[::-2])) ReductionStr = Literal["mean", "sum", "none"] class Tensor(SimpleMathTrait): """ A `Tensor` is a multi-dimensional matrix containing elements of a single data type. ```python exec="true" session="tensor" from tinygrad import Tensor, dtypes, nn import numpy as np import math np.set_printoptions(precision=4) ``` """ __slots__ = "lazydata", "requires_grad", "grad" training: ClassVar[bool] = False no_grad: ClassVar[bool] = False def __init__(self, data:ConstType|bytes|list|tuple|UOp|'np.ndarray'|pathlib.Path|None, # type: ignore [name-defined] # noqa: F821 device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None): if dtype is not None: dtype = to_dtype(dtype) if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device) # tensors can have gradients if you have called .backward self.grad:Tensor|None = None # NOTE: this can be in three states. False and None: no gradient, True: gradient # None (the default) will be updated to True if it's put in an optimizer self.requires_grad:bool|None = requires_grad # create a LazyBuffer from the different types of inputs if isinstance(data, UOp): assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported" # NOTE: this is here because LazyBuffer = UOp if isinstance(data, UOp) and data.op is Ops.BIND: data = _metaop(Ops.BIND, tuple(), dtype or data.dtype, device, data) elif data is None: data = _metaop(Ops.EMPTY, (0,), dtype or dtypes.default_float, device) elif isinstance(data, get_args(ConstType)): data = _metaop(Ops.CONST, tuple(), dtype or dtypes.from_py(data), device, data) elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype) elif isinstance(data, (list, tuple)): if dtype is None: if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float # NOTE: this works because all_int([True, False]) is True if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata else: data = _frompy(data, dtype) elif str(type(data)) == "": import numpy as np assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}" if data.shape == (): data = _metaop(Ops.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item()) else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data) # type: ignore [name-defined] elif isinstance(data, pathlib.Path): dtype = dtype or dtypes.uint8 data = _metaop(Ops.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}") # by this point, it has to be a UOp if not isinstance(data, UOp): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}") # data might be on a different device if isinstance(device, str): self.lazydata:UOp = data if data.device == device else data.copy_to_device(device) # if device is a tuple, we should have/construct a MultiLazyBuffer elif isinstance(data, UOp) and isinstance(data.device, str): self.lazydata = Tensor(data).shard(device).lazydata else: assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}" self.lazydata = data # add to all_tensors after construction succeeds all_tensors.add(weakref.ref(self)) def __del__(self): all_tensors.discard(weakref.ref(self)) def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor: new_uop: UOp = fxn(*[t.lazydata for t in (self,)+x], **kwargs) needs_input_grad = [t.requires_grad for t in (self,)+x] return Tensor(new_uop, device=new_uop.device, requires_grad=True if any(needs_input_grad) else None if None in needs_input_grad else False) def _apply_broadcasted_uop(self, fxn:Callable, x:Tensor|ConstType, reverse=False) -> Tensor: lhs,rhs = self._broadcasted(x, reverse) return lhs._apply_uop(fxn, rhs) def requires_grad_(self, requires_grad=True) -> Tensor: self.requires_grad = requires_grad return self class train(ContextDecorator): def __init__(self, mode:bool = True): self.mode = mode def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev class test(ContextDecorator): def __init__(self, mode:bool = True): self.mode = mode def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev def __repr__(self): ld = self.lazydata ld_repr = f"" return f"" # Python has a non moving GC, so this should be okay def __hash__(self): return id(self) def __bool__(self): raise TypeError("__bool__ on Tensor is not defined") def __len__(self): if not self.shape: raise TypeError("len() of a 0-d tensor") return self.shape[0] @property def device(self) -> str|tuple[str, ...]: return self.lazydata.device @property def shape(self) -> tuple[sint, ...]: return self.lazydata.shape @property def dtype(self) -> DType: return self.lazydata.dtype # ***** data handlers **** def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]: """ Creates the schedule needed to realize these Tensor(s), with Variables. NOTE: A Tensor can only be scheduled once. """ big_sink = UOp.sink(*[x.lazydata for x in (self,)+lst]) # TODO: move this to scheduler tensor_map pass if any(x.op is Ops.MULTI for x in big_sink.toposort): # multi fixup _apply_map_to_tensors(get_multi_map(big_sink)) big_sink = UOp.sink(*flatten([x.lazydata.src if x.lazydata.op is Ops.MULTI else [x.lazydata] for x in (self,)+lst])) # verify Tensors match the spec if __debug__: type_verify(list(big_sink.toposort), tensor_uop_spec) schedule, var_vals, becomes_map = create_schedule_with_vars(big_sink) _apply_map_to_tensors(becomes_map) return memory_planner(schedule), var_vals def schedule(self, *lst:Tensor) -> list[ScheduleItem]: """Creates the schedule needed to realize these Tensor(s).""" schedule, var_vals = self.schedule_with_vars(*lst) assert len(var_vals) == 0 return schedule def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor: """Triggers the computation needed to create these Tensor(s).""" run_schedule(*self.schedule_with_vars(*lst), do_update_stats=do_update_stats) return self def replace(self, x:Tensor, allow_shape_mismatch=False) -> Tensor: """ Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match. """ # used for replacing a Tensor with a new version of it (potentially with a different device and dtype) assert self.shape == x.shape or allow_shape_mismatch, f"replace shape mismatch {self.shape} != {x.shape}" self.lazydata = x.lazydata return self def assign(self, x) -> Tensor: # TODO: this is a hack for writing to DISK. remove with working assign if isinstance(self.device, str) and self.device.startswith("DISK"): if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype) self.contiguous().realize().lazydata.base.buffer.ensure_allocated().copyin(x._data()) return self if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype) if self.lazydata is x.lazydata: return self # a self assign is a NOOP # NOTE: we allow cross device assign assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}" assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}" assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}" assert not x.requires_grad # self requires_grad is okay? if not self.lazydata.is_realized: return self.replace(x) self.lazydata = self.lazydata.assign(x.lazydata) return self def detach(self) -> Tensor: """ Returns a new tensor with the same data as this tensor, but detached from the autograd graph. """ return Tensor(self.lazydata.detach(), device=self.device, requires_grad=False) def _buffer(self) -> Buffer: return self.cast(self.dtype.base).contiguous().to("CPU").realize().lazydata.base.buffer def _data(self) -> memoryview: return self._buffer().as_buffer() def data(self) -> memoryview: """ Returns the data of this tensor as a memoryview. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([1, 2, 3, 4]) print(np.frombuffer(t.data(), dtype=np.int32)) ``` """ if 0 in self.shape: return memoryview(bytearray(0)).cast(self.dtype.base.fmt) assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}" return self._buffer().as_typed_buffer(self.shape) def item(self) -> ConstType: """ Returns the value of this tensor as a standard Python number. ```python exec="true" source="above" session="tensor" result="python" t = Tensor(42) print(t.item()) ``` """ assert self.numel() == 1, "must have one element for item" return self.data()[(0,) * len(self.shape)] # TODO: should be Tensor.tolist() -> Union[list[ConstType], ConstType]. The list is Sequence because mypy expects memoryview.tolist() -> list[int] # src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803 def tolist(self) -> Sequence[ConstType]|ConstType: """ Returns the value of this tensor as a nested list. Returns single value for const tensor. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([1, 2, 3, 4]) print(t.tolist()) ``` ```python exec="true" source="above" session="tensor" result="python" t = Tensor(5) print(t.tolist()) ``` """ return self.data().tolist() def numpy(self) -> 'np.ndarray': # type: ignore [name-defined] # noqa: F821 """ Returns the value of this tensor as a `numpy.ndarray`. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([1, 2, 3, 4]) print(repr(t.numpy())) ``` """ assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}" import numpy as np if self.dtype.base == dtypes.bfloat16: return self.float().numpy() if 0 in self.shape: return np.empty(self.shape, dtype=_to_np_dtype(self.dtype.base)) return self._buffer().numpy().reshape(self.shape) def clone(self) -> Tensor: """ Creates a clone of this tensor allocating a separate buffer for the data. """ ret = Tensor(self.lazydata.clone(), self.device, requires_grad=self.requires_grad) if self.grad is not None: ret.grad = self.grad.clone() return ret def to(self, device:str|tuple[str, ...]|None) -> Tensor: """ Moves the tensor to the given device. """ device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device) if device == self.device: return self if not isinstance(device, str): return self.shard(device) ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad) if self.grad is not None: ret.grad = self.grad.to(device) return ret def to_(self, device:str|tuple[str, ...]|None) -> Tensor: """ Moves the tensor to the given device in place. """ real = self.to(device) if self.grad is not None and real.grad is not None: self.grad.replace(real.grad) return self.replace(real) def shard(self, devices:tuple[str, ...], axis:int|None=None) -> Tensor: """ Shards the tensor across the given devices. Optionally specify which axis to shard on. ```python exec="true" source="above" session="tensor" result="python" t = Tensor.empty(2, 4) print(t.shard((t.device, t.device), axis=1).lazydata) ``` """ assert isinstance(self.device, str), "can't shard a MultiLazyBuffer" devices = tuple(Device.canonicalize(x) for x in devices) mlb = self.lazydata.shard(devices, self._resolve_dim(axis) if axis is not None else None) return Tensor(mlb, device=devices, requires_grad=self.requires_grad) def shard_(self, devices:tuple[str, ...], axis:int|None=None) -> Tensor: """ Shards the tensor across the given devices in place. """ return self.replace(self.shard(devices, axis)) @staticmethod def from_uop(y:UOp, **kwargs) -> Tensor: if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False) # this is the only UOp allowed in Tensor if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False) if y.op is Ops.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1]) if y.op is Ops.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1]) raise RuntimeError(f"unhandled UOp {y}") # ***** creation entrypoint ***** @staticmethod def _metaop(op, shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None, arg=None, **kwargs) -> Tensor: dtype = to_dtype(dtype) if dtype is not None else dtypes.default_float if isinstance(device, tuple): return Tensor(UOp.multi(*[UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], axis=None), device, dtype, **kwargs) return Tensor(UOp.metaop(op, shape, dtype, Device.canonicalize(device), arg), device, dtype, **kwargs) @staticmethod def empty(*shape, **kwargs) -> Tensor: """ Creates an empty tensor with the given shape. You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor. Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" t = Tensor.empty(2, 3) print(t.shape) ``` """ return Tensor._metaop(Ops.EMPTY, argfix(*shape), **kwargs) @staticmethod def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor: """ Exposes the pointer as a Tensor without taking ownership of the original data. The pointer must remain valid for the entire lifetime of the created Tensor. You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor. Additionally, all other keyword arguments are passed to the constructor of the tensor. """ r = Tensor._metaop(Ops.EMPTY, shape, **kwargs) r.lazydata.buffer.allocate(external_ptr=ptr) return r @staticmethod def from_url(url:str, gunzip:bool=False, **kwargs) -> Tensor: """ Create a Tensor from a URL. This is the preferred way to access Internet resources. It currently returns a DISK Tensor, but in the future it may return an HTTP Tensor. This also will soon become lazy (when possible) and not print progress without DEBUG. THe `gunzip` flag will gzip extract the resource and return an extracted Tensor. """ return Tensor(fetch(url, gunzip=gunzip), **kwargs) _seed: int = int(time.time()) _device_seeds: dict[str, Tensor] = {} _device_rng_counters: dict[str, Tensor] = {} @staticmethod def manual_seed(seed=0) -> None: """ Sets the seed for random operations. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) print(Tensor.rand(5).numpy()) print(Tensor.rand(5).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) # reset to the same seed print(Tensor.rand(5).numpy()) print(Tensor.rand(5).numpy()) ``` """ Tensor._seed, Tensor._device_seeds, Tensor._device_rng_counters = seed, {}, {} @staticmethod def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor) -> Tensor: x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64) x = x._apply_uop(UOp.threefry, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64)) counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32) return counts0.cat(counts1) @staticmethod def rand(*shape, device:str|None=None, dtype:DTypeLike|None=None, contiguous:bool=True, **kwargs) -> Tensor: """ Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`. You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor. Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.rand(2, 3) print(t.numpy()) ``` """ if not dtypes.is_float(dtype := to_dtype(dtype or dtypes.default_float)): raise ValueError(f"rand only supports float dtypes, got {dtype}") if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}") if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}") device = Device.canonicalize(device) # if shape has 0, return zero tensor if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs) num = ceildiv(numel * dtype.itemsize, 4) # generate per device seeds and rng counter if we haven't seen this device yet if device not in Tensor._device_seeds: Tensor._device_seeds[device] = Tensor( [int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed], device=device, dtype=dtypes.uint32, requires_grad=False) Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False) # increment rng counter for devices else: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous() # threefry random bits counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._device_rng_counters[device]) counts1 = counts0 + ceildiv(num, 2) bits = Tensor._threefry_random_bits(Tensor._device_seeds[device], counts0, counts1)[:num] # bitcast to uint with same number of bits _, nmant = dtypes.finfo(dtype) uint_dtype = {1: dtypes.uint8, 2: dtypes.uint16, 4: dtypes.uint32, 8: dtypes.uint64}[dtype.itemsize] bits = bits.bitcast(uint_dtype) # only randomize the mantissa bits and set the exponent to 1 one = Tensor.ones_like(bits, device=bits.device, dtype=dtype).bitcast(uint_dtype) bits = bits.rshift((dtype.itemsize * 8) - nmant).bitwise_or(one) # bitcast back to the original dtype and reshape out = bits.bitcast(dtype)[:numel].sub(1).reshape(shape).requires_grad_(kwargs.get("requires_grad")) return out.contiguous() if contiguous else out # ***** creation helper functions ***** @staticmethod def full(shape:tuple[sint, ...], fill_value:ConstType, **kwargs) -> Tensor: """ Creates a tensor with the given shape, filled with the given value. You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor. Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" print(Tensor.full((2, 3), 42).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor.full((2, 3), False).numpy()) ``` """ return Tensor(fill_value, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape) @staticmethod def zeros(*shape, **kwargs) -> Tensor: """ Creates a tensor with the given shape, filled with zeros. You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor. Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" print(Tensor.zeros(2, 3).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor.zeros(2, 3, dtype=dtypes.int32).numpy()) ``` """ return Tensor.full(argfix(*shape), 0.0, **kwargs) @staticmethod def ones(*shape, **kwargs) -> Tensor: """ Creates a tensor with the given shape, filled with ones. You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor. Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" print(Tensor.ones(2, 3).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor.ones(2, 3, dtype=dtypes.int32).numpy()) ``` """ return Tensor.full(argfix(*shape), 1.0, **kwargs) @staticmethod def arange(start, stop=None, step=1, **kwargs) -> Tensor: """ Returns a 1-D tensor of size `ceil((stop - start) / step)` with values from `[start, stop)`, with spacing between values given by `step`. If `stop` is not specified, values are generated from `[0, start)` with the given `step`. If `stop` is specified, values are generated from `[start, stop)` with the given `step`. You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor. Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" print(Tensor.arange(5).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor.arange(5, 10).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor.arange(5, 10, 2).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor.arange(5.5, 10, 2).numpy()) ``` """ if stop is None: stop, start = start, 0 dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int) # NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs if (output_len:=ceildiv(stop-start, step)) <= 0: return Tensor([], dtype=dtype, **kwargs) return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype) @staticmethod def linspace(start:int|float, stop:int|float, steps:int, **kwargs) -> Tensor: """ Returns a 1-D tensor of `steps` evenly spaced values from `start` to `stop`, inclusive. You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor. Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" print(Tensor.linspace(0, 10, 5).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor.linspace(-1, 1, 5).numpy()) ``` """ if steps < 0: raise ValueError("number of steps must be non-negative") if (dtype := to_dtype(kwargs.pop("dtype", dtypes.default_float))) == dtypes.bool: raise ValueError("linspace with bool dtype is not supported") if steps == 1: return Tensor([start], dtype=dtype, **kwargs) return (start + Tensor.arange(steps, **kwargs) * ((stop - start) / (steps - 1))).cast(dtype) @staticmethod def eye(n:int, m:int|None=None, **kwargs) -> Tensor: """ Returns a 2-D tensor with `n` rows and `m` columns, with ones on the diagonal and zeros elsewhere. You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor. Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" print(Tensor.eye(3).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor.eye(2, 4).numpy()) ``` """ if n < 0 or (m is not None and m < 0): raise ValueError(f"cannot have negative {n=}, {m=}") x = Tensor.ones((n,1),**kwargs).pad((None,(0,n))).flatten().shrink(((0,n*n),)).reshape(n,n) return x if m is None else x.pad((None, (0, m-n))) if m > n else x.shrink((None, (0, m))) def full_like(self, fill_value:ConstType, **kwargs) -> Tensor: """ Creates a tensor with the same shape as `self`, filled with the given value. If `dtype` is not specified, the dtype of `self` is used. You can pass in the `device` keyword argument to control device of the tensor. Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" t = Tensor.ones(2, 3) print(Tensor.full_like(t, 42).numpy()) ``` """ return Tensor.full(self.shape, fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs) def zeros_like(self, **kwargs) -> Tensor: """ Creates a tensor with the same shape as `self`, filled with zeros. You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor. Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" t = Tensor.ones(2, 3) print(Tensor.zeros_like(t).numpy()) ``` """ return self.full_like(0, **kwargs) def ones_like(self, **kwargs) -> Tensor: """ Creates a tensor with the same shape as `self`, filled with ones. You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor. Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" t = Tensor.zeros(2, 3) print(Tensor.ones_like(t).numpy()) ``` """ return self.full_like(1, **kwargs) def rand_like(self, **kwargs) -> Tensor: """ Creates a tensor with the same shape and sharding as `self`, filled with random values from a uniform distribution over the interval `[0, 1)`. You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor. Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" t = Tensor.ones(2, 3) print(Tensor.rand_like(t).numpy()) ``` """ dtype = kwargs.pop("dtype", self.dtype) if isinstance(self.device, tuple): if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor") if self.lazydata.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device) contiguous = kwargs.pop("contiguous", True) sharded_shape = tuple(s//len(self.device) if a==self.lazydata.axis else s for a,s in enumerate(self.shape)) rands = [Tensor.rand(sharded_shape, device=d, dtype=dtype, contiguous=contiguous, **kwargs).lazydata for d in self.device] return Tensor(UOp.multi(*rands, axis=self.lazydata.axis), device=self.device, dtype=dtype, **kwargs) return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs) # ***** rng hlops ***** @staticmethod def randn(*shape, dtype:DTypeLike|None=None, requires_grad:bool|None=None, **kwargs) -> Tensor: """ Creates a tensor with the given shape, filled with random values from a normal distribution with mean `0` and standard deviation `1`. If `dtype` is not specified, the default type is used. You can pass in the `device` keyword argument to control device of the tensor. Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) print(Tensor.randn(2, 3).numpy()) ``` """ # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform src = Tensor.rand((2, *argfix(*shape)), **{**kwargs, "dtype": dtypes.float32}) return (src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)).requires_grad_(requires_grad) @staticmethod def randint(*shape, low=0, high=10, dtype=dtypes.int32, **kwargs) -> Tensor: """ Creates a tensor with the given shape, filled with random integer values generated uniformly from the interval `[low, high)`. If `dtype` is not specified, the default type is used. You can pass in the `device` keyword argument to control device of the tensor. Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) print(Tensor.randint(2, 3, low=5, high=10).numpy()) ``` """ if not isinstance(low, int) or not isinstance(high, int): raise TypeError(f"{low=} and {high=} must be integers") dtype = to_dtype(dtype) if not dtypes.is_int(dtype): raise TypeError(f"{dtype=} must be int") return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs) @staticmethod def normal(*shape, mean=0.0, std=1.0, requires_grad:bool|None=None, **kwargs) -> Tensor: """ Creates a tensor with the given shape, filled with random values from a normal distribution with the given `mean` and standard deviation `std`. You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor. Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) print(Tensor.normal(2, 3, mean=10, std=2).numpy()) ``` """ return ((std * Tensor.randn(*shape, **kwargs)) + mean).requires_grad_(requires_grad) @staticmethod def uniform(*shape, low=0.0, high=1.0, dtype:DTypeLike|None=None, requires_grad:bool|None=None, **kwargs) -> Tensor: """ Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[low, high)`. You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor. Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) print(Tensor.uniform(2, 3, low=2, high=10).numpy()) ``` """ return (((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype or dtypes.default_float) + low).requires_grad_(requires_grad) @staticmethod def scaled_uniform(*shape, **kwargs) -> Tensor: """ Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[-prod(shape)**-0.5, prod(shape)**-0.5)`. You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor. Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) print(Tensor.scaled_uniform(2, 3).numpy()) ``` """ return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5) # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform @staticmethod def glorot_uniform(*shape, **kwargs) -> Tensor: """ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor. Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) print(Tensor.glorot_uniform(2, 3).numpy()) ``` """ return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul((6/(argfix(*shape)[0]+prod(argfix(*shape)[1:])))**0.5) # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_ @staticmethod def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor: """ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor. Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) print(Tensor.kaiming_uniform(2, 3).numpy()) ``` """ bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:])) return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs) # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_ @staticmethod def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor: """ You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor. Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) print(Tensor.kaiming_normal(2, 3).numpy()) ``` """ std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:])) return Tensor.normal(*shape, mean=0.0, std=std, **kwargs) def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor: assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive" assert replacement or num_samples == 1, "no replacement only supports num_samples = 1" weight = self.unsqueeze(0) if self.ndim == 1 else self cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1) unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1).to(self.device) indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0)) return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32) # ***** toposort and backward pass ***** def gradient(self, *targets:Tensor, gradient:Tensor|None=None, materialize_grads=False) -> list[Tensor]: """ Compute the gradient of the targets with respect to self. ```python exec="true" source="above" session="tensor" result="python" x = Tensor.eye(3) y = Tensor([[2.0,0,-2.0]]) z = y.matmul(x).sum() dx, dy = z.gradient(x, y) print(dx.tolist()) # dz/dx print(dy.tolist()) # dz/dy ``` """ assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor" if gradient is None: gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False) rets = [] target_uops = [x.lazydata for x in targets] grads = compute_gradient(self.lazydata, gradient.lazydata, set(target_uops)) ret = [] for x in target_uops: if (y:=grads.get(x)) is None: if materialize_grads: y = x.const_like(0) else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}") ret.append(y) rets.append(ret) # create returned Tensors return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])] def backward(self, gradient:Tensor|None=None) -> Tensor: """ Propagates the gradient of a tensor backwards through the computation graph. If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True) t.sum().backward() print(t.grad.numpy()) ``` """ all_uops = self.lazydata.toposort tensors_need_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and \ t.lazydata in all_uops and t.requires_grad and not Tensor.no_grad] # clear contexts for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient, materialize_grads=True)): assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}" t.grad = g if t.grad is None else (t.grad + g) return self # ***** movement low level ops ***** def view(self, shape:tuple[sint, ...], *args) -> Tensor: """`.view` is an alias for `.reshape`.""" return self.reshape(argfix(shape, *args)) def reshape(self, shape, *args) -> Tensor: """ Returns a tensor with the same data as the original tensor but with a different shape. `shape` can be passed as a tuple or as separate arguments. ```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(6) print(t.reshape(2, 3).numpy()) ``` """ # resolve None and args new_shape = tuple([s if s is not None else self.shape[i] for i,s in enumerate(argfix(shape, *args))]) # resolve -1 if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}") if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]) return self._apply_uop(UOp.reshape, arg=new_shape) if new_shape != self.shape else self def expand(self, shape, *args) -> Tensor: """ Returns a tensor that is expanded to the shape that is specified. Expand can also increase the number of dimensions that a tensor has. Passing a `-1` or `None` to a dimension means that its size will not be changed. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([1, 2, 3]) print(t.expand(4, -1).numpy()) ``` """ new_shape = tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_align_left(self.shape, argfix(shape, *args))))) return self._broadcast_to(new_shape) def permute(self, order, *args) -> Tensor: """ Returns a tensor that is a permutation of the original tensor. The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified. `order` can be passed as a tuple or as separate arguments. ```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(6).reshape(2, 3) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.permute(1, 0).numpy()) ``` """ order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args)) if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}") return self._apply_uop(UOp.permute, arg=order_arg) def flip(self, axis, *args) -> Tensor: """ Returns a tensor that reverses the order of the original tensor along given `axis`. `axis` can be passed as a tuple or as separate arguments. ```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(6).reshape(2, 3) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.flip(0).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.flip((0, 1)).numpy()) ``` """ axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args)) if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}") return self._apply_uop(UOp.flip, arg=tuple([i in axis_arg for i in range(len(self.shape))])) def shrink(self, arg:tuple[tuple[sint, sint]|None, ...]) -> Tensor: """ Returns a tensor that shrinks the each axis based on input arg. `arg` must have the same length as `self.ndim`. For each axis, it can be `None`, which means no shrink, or a tuple `(start, end)` that works the same as Python slice. ```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(9).reshape(3, 3) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.shrink(((None, (1, 3)))).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.shrink((((0, 2), (0, 2)))).numpy()) ``` """ if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg)) def pad(self, padding:Sequence[sint]|Sequence[tuple[sint, sint]|None], mode:str="constant", value:float=0.0) -> Tensor: """ Returns a tensor with padding applied based on the input `padding`. `padding` supports two padding structures: 1. Flat padding: `(padding_left, padding_right, padding_top, padding_bottom, ...)` - This structure matches PyTorch's pad. - `padding` length must be even. 2. Group padding: `(..., (padding_top, padding_bottom), (padding_left, padding_right))` - This structure matches pad for JAX, NumPy, TensorFlow, and others. - For each axis, padding can be `None`, meaning no padding, or a tuple `(start, end)`. - `padding` must have the same length as `self.ndim`. Padding values can be negative, resulting in dimension shrinks that work similarly to Python negative slices. Padding modes is selected with `mode` which supports `constant`, `reflect` and `replicate`. ```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(9).reshape(1, 1, 3, 3) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.pad((1, 2, 0, -1)).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.pad(((None, None, (0, -1), (1, 2)))).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.pad((1, 2, 0, -1), value=-float('inf')).numpy()) ``` """ if mode not in {"constant", "reflect", "replicate", "circular"}: raise NotImplementedError(f"{mode=} is not supported") # flat padding if all(isinstance(p, (int,UOp)) for p in padding): if len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads") pX = _flat_to_grouped(tuple(cast(Sequence[sint], padding)) + (0,0)*(self.ndim - len(padding)//2)) # group padding else: pX = tuple((0,0) if p is None else p for p in cast(Sequence[tuple[sint, sint]|None], padding)) if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}") X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX) if mode == "constant": def _constant(x:Tensor,px,v) -> Tensor: return x._apply_uop(UOp.pad, arg=px) if v == 0 else (x._apply_uop(UOp.pad, arg=px)+Tensor.ones_like(x)._apply_uop(UOp.pad, arg=px).where(0,v)) return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \ _constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value) assert all_int(self.shape), f"does not support symbolic shape {self.shape}" if mode == "circular": if any(pB>sh or pA>sh for (pB,pA),sh in zip(pX, X.shape)): raise ValueError('Padding value causes wrapping around more than once.') if any(pB<0 or pA<0 for pB,pA in pX): raise NotImplementedError("Negative pads with circular pads is not supported") orig_shape, X = X.shape, X.repeat(tuple(1 + bool(pB) + bool(pA) for pB,pA in pads)) return X.shrink(tuple((0 if pB == 0 else osh-pB, xsh if pA == 0 else xsh-osh+pA) for (pB,pA),osh,xsh in zip(pads, orig_shape, X.shape))) for d,(pB,pA) in enumerate(pads): if mode == "reflect": if pB >= (s:=X.shape[d]) or pA>=s: raise ValueError(f"Padding ({pB}, {pA}) should be less than the input size={s} for dim={d}.") slcB, slcA, = slice(pB,0,-1), slice(s-2 if s-2>=0 else None, s-2-pA if s-2-pA>=0 else None, -1) xB, xA = (X[[slc if i == d else slice(None) for i in range(X.ndim)]] if p > 0 else None for slc, p in ((slcB, pB), (slcA, pA))) if mode == "replicate": shrB, shrA, = tuple((0,1) if i==d else None for i in range(X.ndim)), tuple((X.shape[i]-1,X.shape[i]) if i==d else None for i in range(X.ndim)) xB, xA = (X.shrink(shr).expand(tuple(p if i==d else None for i in range(X.ndim))) if p > 0 else None for shr, p in ((shrB, pB), (shrA, pA))) X = Tensor.cat(*(X_ for X_ in (xB, X, xA) if X_ is not None), dim=d) return X.shrink(tuple((-min(pB,0), min(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))) # ***** movement high level ops ***** def _getitem(self, indices, v: Tensor|None = None) -> Tensor: # wrap single index into a list if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices] x, indices = self, list(indices) # filter ellipsis and fill with slice(None) or fill rest of indices with slice(None) if len(ellipsis_idx := [dim for dim, i in enumerate(indices) if i is Ellipsis]) > 1: raise IndexError("indices can only have a single ellipsis") fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices) num_indices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None) if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}") indices[fill_idx:fill_idx+1] = [slice(None)] * (self.ndim - num_indices) indices_parsed, dim = [], 0 for index in indices: size = 1 if index is None else self.shape[dim] boundary, stride = [0, size], 1 # defaults match index: case list() | tuple() | Tensor(): if not isinstance(index, Tensor): index = Tensor(index, self.device, requires_grad=False) if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported") index = (index.to(self.device) < 0).where(index+size, index) # treat negative index values case int() | UOp(): # sint if index >= size or index < -size: raise IndexError(f"{index=} is out of bounds with {size=}") boundary = [index, index+1] if index >= 0 else [index+size, index+size+1] case slice(): if index.step == 0: raise ValueError(f"{index=} cannot have 0 as step") if not all(isinstance(s,int) or s is None for s in (index.start,index.stop,index.step)): raise TypeError("only int slicing is supported") # handle int slicing *boundary, stride = index.indices(cast(SupportsIndex, size)) if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0] elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1] # update size for slice size = ceildiv((boundary[1] - boundary[0]), abs(stride)) case None: pass # do nothing case _: raise IndexError(f"{type(index).__name__} indexing is not supported") indices_parsed.append({"index":index, "size":size, "boundary":tuple(boundary), "stride":stride}) if index is not None: dim += 1 # movement op indexing if mops := [i for i in indices_parsed if i['index'] is not None]: # flip negative strides shrinks, strides = zip(*((i['boundary'], i['stride']) for i in mops)) x = x.shrink(shrinks).flip(tuple(i for i,st in enumerate(strides) if st < 0)) # handle stride != 1 or -1 if any(abs(st) != 1 for st in strides): strides = tuple(abs(s) for s in strides) # pad shape to multiple of stride if not all_int(x.shape): raise RuntimeError("symbolic shape not supported") x = x.pad(tuple((0, round_up(s, st) - s) for s, st in zip(x.shape, strides))) x = x.reshape(tuple(flatten((s // st, st) for s, st in zip(x.shape, strides)))) x = x.shrink(tuple(flatten(((0, s), (0, 1)) for s in x.shape[::2]))).reshape(x.shape[::2]) # dim injection from None by including None dim size (which is 1) and dim collapse by skipping int dim size x = x.reshape(tuple(index['size'] for index in indices_parsed if not isinstance(index['index'], int))) # tensor indexing if tops := [(d,i) for d,i in enumerate(i_ for i_ in indices_parsed if not isinstance(i_['index'], int)) if isinstance(i['index'], Tensor)]: # unload the tensor object into actual tensors dims, tensors, masks = [d for d,_ in tops], cast(list[Tensor], [i['index'] for _,i in tops]), [] pre_reduce_shape = x.shape[:dims[0]] + (big_shape := _broadcast_shape(*(t.shape for t in tensors))) + x.shape[dims[0]:] # create index masks for dim, tensor in zip(dims, tensors): try: i = tensor.reshape(tensor.shape + (1,)*(x.ndim - dims[0])).expand(pre_reduce_shape) except ValueError as e: raise IndexError(f"cannot broadcast indices: {e}") from e masks.append(i._one_hot_along_dim(num_classes=x.shape[dim], dim=(dim - x.ndim))) # reduce masks to 1 mask mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks) # inject 1's for the extra dims added in create masks reshape_arg = x.shape[:dims[0]] + (1,) * len(big_shape) + x.shape[dims[0]:] # sum reduce the extra dims introduced in create masks x = (x.reshape(reshape_arg) * mask).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), dtype=x.dtype) # special permute case if dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1)): x = x.permute(*range(dims[0], dims[0]+len(big_shape)), *range(0, dims[0]), *range(dims[0]+len(big_shape), x.ndim)) # for advanced setitem, returns whole tensor with indices replaced if v is not None: vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(x.shape, v.shape)) # add back reduced dims from sum for dim in sum_axis: vb = vb.unsqueeze(dim) # run _masked_setitem on tuple of axis that is to be reduced to match self.shape x = _masked_setitem(self, vb, mask, tuple(range(dims[0], dims[0] + len(big_shape)))) return x def __getitem__(self, indices) -> Tensor: """ Retrieve a sub-tensor using indexing. Supported Index Types: `int | slice | Tensor | None | list | tuple | Ellipsis` Examples: ```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(12).reshape(3, 4) print(t.numpy()) ``` - Int Indexing: Select an element or sub-tensor using integers for each dimension. ```python exec="true" source="above" session="tensor" result="python" print(t[1, 2].numpy()) ``` - Slice Indexing: Select a range of elements using slice notation (`start:end:stride`). ```python exec="true" source="above" session="tensor" result="python" print(t[0:2, ::2].numpy()) ``` - Tensor Indexing: Use another tensor as indices for advanced indexing. Using `tuple` or `list` here also works. ```python exec="true" source="above" session="tensor" result="python" print(t[Tensor([2, 0, 1]), Tensor([1, 2, 3])].numpy()) ``` - `None` Indexing: Add a new dimension to the tensor. ```python exec="true" source="above" session="tensor" result="python" print(t[:, None].shape) ``` NOTE: Out-of-bounds indexing results in a value of `0`. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([1, 2, 3]) print(t[Tensor([4, 3, 2])].numpy()) ``` """ return self._getitem(indices) def __setitem__(self, indices, v:Tensor|ConstType) -> None: if isinstance(self.device, str) and self.device.startswith("DISK"): self._getitem(indices).assign(v) return # NOTE: check that setitem target is valid first if not unwrap(self.lazydata.st).contiguous: raise RuntimeError("setitem target needs to be contiguous") if isinstance(v, get_args(ConstType)): v = Tensor(v, device=self.device, dtype=self.dtype) if not isinstance(v, Tensor): raise TypeError(f"can't set a {type(v).__name__} to a Tensor") if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported") res = self.realize()._getitem(indices, v) # if shapes match and data is not shared it's a copy and we assign to self if res.shape == self.shape and res.lazydata is not self.lazydata: self.assign(res).realize() else: # no copy, basic setitem v = v.cast(res.dtype)._broadcast_to(_broadcast_shape(res.shape, v.shape)).contiguous() res.assign(v).realize() def gather(self:Tensor, dim:int, index:Tensor) -> Tensor: """ Gathers values along an axis specified by `dim`. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([[1, 2], [3, 4]]) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy()) ``` """ assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}" dim = self._resolve_dim(dim) assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim" index = index.to(self.device) x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim) return (x * index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim])).sum(-1, dtype=self.dtype) def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor: """ Concatenates self with other `Tensor` in `args` along an axis specified by `dim`. All tensors must have the same shape except in the concatenating dimension. ```python exec="true" source="above" session="tensor" result="python" t0, t1, t2 = Tensor([[1, 2]]), Tensor([[3, 4]]), Tensor([[5, 6]]) print(t0.cat(t1, t2, dim=0).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t0.cat(t1, t2, dim=1).numpy()) ``` """ dim = self._resolve_dim(dim) for arg in args: assert arg.ndim==self.ndim and all(ti==ai for i,(ti,ai) in enumerate(zip(self.shape, arg.shape)) if i!=dim) tensors = [self, *args] dim_cumsum = list(itertools.accumulate([t.shape[dim] for t in tensors], initial=0)) for i,t in enumerate(tensors): tensors[i] = t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim)]) return functools.reduce(Tensor.add, tensors) def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor: """ Concatenates self with other `Tensor` in `args` along a new dimension specified by `dim`. ```python exec="true" source="above" session="tensor" result="python" t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6]) print(t0.stack(t1, t2, dim=0).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t0.stack(t1, t2, dim=1).numpy()) ``` """ # checks for shapes and number of dimensions delegated to cat return Tensor.cat(*[t.unsqueeze(dim) for t in [self, *args]], dim=dim) def repeat_interleave(self, repeats:int, dim:int|None=None) -> Tensor: """ Repeat elements of a tensor. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([1, 2, 3]) print(t.repeat_interleave(2).numpy()) ``` """ x, dim = (self.flatten(), 0) if dim is None else (self, self._resolve_dim(dim)) shp = x.shape return x.reshape(*shp[:dim+1], 1, *shp[dim+1:]).expand(*shp[:dim+1], repeats, *shp[dim+1:]).reshape(*shp[:dim], shp[dim]*repeats, *shp[dim+1:]) def repeat(self, repeats, *args) -> Tensor: """ Repeats tensor number of times along each dimension specified by `repeats`. `repeats` can be passed as a tuple or as separate arguments. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([1, 2, 3]) print(t.repeat(4, 2).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.repeat(4, 2, 1).shape) ``` """ repeats = argfix(repeats, *args) base_shape = _align_left(self.shape, repeats)[0] unsqueezed_shape = flatten([[1, s] for s in base_shape]) expanded_shape = flatten([[r, s] for r,s in zip(repeats, base_shape)]) final_shape = [r*s for r,s in zip(repeats, base_shape)] return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape) def _resolve_dim(self, dim:int, *, extra:bool=False) -> int: total = self.ndim + int(extra) if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}") return dim + total if dim < 0 else dim def split(self, sizes:int|Sequence[int], dim:int=0) -> tuple[Tensor, ...]: """ Splits the tensor into chunks along the dimension specified by `dim`. If `sizes` is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller. If `sizes` is a list, it splits into `len(sizes)` chunks with size in `dim` according to `size`. ```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(10).reshape(5, 2) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" split = t.split(2) print("\\n".join([repr(x.numpy()) for x in split])) ``` ```python exec="true" source="above" session="tensor" result="python" split = t.split([1, 4]) print("\\n".join([repr(x.numpy()) for x in split])) ``` """ assert all_int(self.shape), f"does not support symbolic shape {self.shape}" dim = self._resolve_dim(dim) if isinstance(sizes, int): sizes = [min(sizes, self.shape[dim]-i) for i in range(0, max(1, self.shape[dim]), max(1, sizes))] assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}" return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))]) def chunk(self, chunks:int, dim:int=0) -> list[Tensor]: """ Splits the tensor into `chunks` number of chunks along the dimension `dim`. If the tensor size along `dim` is not divisible by `chunks`, all returned chunks will be the same size except the last one. The function may return fewer than the specified number of chunks. ```python exec="true" source="above" session="tensor" result="python" chunked = Tensor.arange(11).chunk(6) print("\\n".join([repr(x.numpy()) for x in chunked])) ``` ```python exec="true" source="above" session="tensor" result="python" chunked = Tensor.arange(12).chunk(6) print("\\n".join([repr(x.numpy()) for x in chunked])) ``` ```python exec="true" source="above" session="tensor" result="python" chunked = Tensor.arange(13).chunk(6) print("\\n".join([repr(x.numpy()) for x in chunked])) ``` """ assert all_int(self.shape), f"does not support symbolic shape {self.shape}" assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}" dim = self._resolve_dim(dim) return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim)) def meshgrid(self:Tensor, *args:Tensor, indexing:Literal["ij", "xy"]="ij") -> tuple[Tensor, ...]: """ Generates coordinate matrices from coordinate vectors. Input tensors can be scalars or 1D tensors. `indexing` determines how the output grids are aligned. `ij` indexing follows matrix-style indexing and `xy` indexing follows Cartesian-style indexing. ```python exec="true" source="above" session="tensor" result="python" x, y = Tensor([1, 2, 3]), Tensor([4, 5, 6]) grid_x, grid_y = x.meshgrid(y) print(grid_x.numpy()) print(grid_y.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" grid_x, grid_y = x.meshgrid(y, indexing="xy") print(grid_x.numpy()) print(grid_y.numpy()) ``` """ if indexing not in ("ij", "xy"): raise RuntimeError(f'indexing must be in ("ij", "xy"), got {indexing}') if len(tensors:=(self, *args)) == 1: return tensors basis = tuple(range(len(tensors))) if indexing == "ij" else (1, 0) + tuple(range(2, len(tensors))) tensors = tuple(t.reshape((-1,) + (1,)*(len(args) - i)) for i,t in zip(basis, tensors)) output_shape = _broadcast_shape(*(t.shape for t in tensors)) return tuple(t._broadcast_to(output_shape) for t in tensors) def squeeze(self, dim:int|None=None) -> Tensor: """ Returns a tensor with specified dimensions of input of size 1 removed. If `dim` is not specified, all dimensions with size 1 are removed. ```python exec="true" source="above" session="tensor" result="python" t = Tensor.zeros(2, 1, 2, 1, 2) print(t.squeeze().shape) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.squeeze(0).shape) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.squeeze(1).shape) ``` """ if dim is None: return self.reshape(tuple(dim for dim in self.shape if dim != 1)) dim = self._resolve_dim(dim) return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:]) def unsqueeze(self, dim:int) -> Tensor: """ Returns a tensor with a new dimension of size 1 inserted at the specified `dim`. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([1, 2, 3, 4]) print(t.unsqueeze(0).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.unsqueeze(1).numpy()) ``` """ dim = self._resolve_dim(dim, extra=True) return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:]) @property def T(self) -> Tensor: """`.T` is an alias for `.transpose()`.""" return self.transpose() def transpose(self, dim0=1, dim1=0) -> Tensor: """ Returns a tensor that is a transposed version of the original tensor. The given dimensions `dim0` and `dim1` are swapped. ```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(6).reshape(2, 3) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.transpose(0, 1).numpy()) ``` """ order = list(range(self.ndim)) order[dim0], order[dim1] = order[dim1], order[dim0] return self.permute(order) def flatten(self, start_dim=0, end_dim=-1) -> Tensor: """ Flattens the tensor by reshaping it into a one-dimensional tensor. If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened. ```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(8).reshape(2, 2, 2) print(t.flatten().numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.flatten(start_dim=1).numpy()) ``` """ start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim) return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:]) def unflatten(self, dim:int, sizes:tuple[int,...]) -> Tensor: """ Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function. ```python exec="true" source="above" session="tensor" result="python" print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape) ``` """ dim = self._resolve_dim(dim) return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:]) def roll(self, shifts:int|tuple[int, ...], dims:int|tuple[int, ...]) -> Tensor: """ Rolls the tensor along specified dimension(s). The rolling operation is circular, meaning that elements that go beyond the edge are wrapped around to the beginning of the dimension. ```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(4) print(t.roll(shifts=1, dims=0).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.roll(shifts=-1, dims=0).numpy()) ``` """ dims, rolled = tuple(self._resolve_dim(d) for d in make_tuple(dims, 1)), self for dim, shift in zip(dims, make_tuple(shifts, 1)): shift = shift % self.shape[dim] rolled = Tensor.cat(rolled[tuple(slice(None) if i != dim else slice(-shift, None) for i in range(rolled.ndim))], rolled[tuple(slice(None) if i != dim else slice(None, -shift) for i in range(rolled.ndim))], dim=dim) return rolled def rearrange(self, formula:str, **sizes) -> Tensor: """ Rearranges input according to formula See: https://einops.rocks/api/rearrange/ ```python exec="true" source="above" session="tensor" result="python" x = Tensor([[1, 2], [3, 4]]) print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy()) ``` """ def parse_formula(formula: str): tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", " ").replace(" 1 ", " ( ) ").split() lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")")) pairs = list(zip(lparens, rparens)) assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch" return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)] assert formula.count("->") == 1, 'need exactly one "->" in formula' (lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->")) for name in sizes: assert name in lhs, f"axis {name} is not used in transform" assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}" for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}" assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}" assert lhs.count("...") <= 1, f"too many ellipses in {formula}" # resolve ellipsis if "..." in lhs: ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims) lhs, rhs = map(lambda l: l[:(i:=l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs)) unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims] flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims] # apply movement ops in order unflatten -> permute -> flatten/unsqueeze t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self) for i, name in enumerate(lhs): assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect" t = t.permute([lhs.index(name) for name in rhs]) return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0] Tensor: axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1))) if self.ndim == 0: axis = () ret = self._apply_uop(UOp.r, op=op, axis=axis) return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis)) def sum(self, axis:int|Sequence[int]|None=None, keepdim=False, dtype:DTypeLike|None=None) -> Tensor: """ Returns the sum of the elements of the tensor along the specified axis or axes. You can pass in `axis` and `keepdim` keyword arguments to control the axis along which the maximum is computed and whether the reduced dimensions are retained. You can pass in `dtype` keyword argument to control the data type of the accumulation. If not specified, the accumulation data type is chosen based on the input tensor's data type. ```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(6).reshape(2, 3) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.sum().numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.sum(axis=0).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.sum(axis=1).numpy()) ``` """ ret = self.cast(sum_acc_dtype(self.dtype) if dtype is None else dtype)._reduce(Ops.ADD, axis, keepdim) return ret.cast(self.dtype) if dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret def prod(self, axis:int|Sequence[int]|None=None, keepdim=False, dtype:DTypeLike|None=None) -> Tensor: """ Returns the product of the elements of the tensor along the specified axis or axes. You can pass in `axis` and `keepdim` keyword arguments to control the axis along which the maximum is computed and whether the reduced dimensions are retained. You can pass in `dtype` keyword argument to control the data type of the accumulation. If not specified, the accumulation data type is chosen based on the input tensor's data type. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([-1, -2, -3, 1, 2, 3]).reshape(2, 3) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.prod().numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.prod(axis=0).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.prod(axis=1).numpy()) ``` """ return self.cast(dtype if dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim) def max(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor: """ Returns the maximum value of the tensor along the specified axis or axes. You can pass in `axis` and `keepdim` keyword arguments to control the axis along which the maximum is computed and whether the reduced dimensions are retained. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([[1, 0, 2], [5, 4, 3]]) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.max().numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.max(axis=0).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.max(axis=1, keepdim=True).numpy()) ``` """ return self._reduce(Ops.MAX, axis, keepdim) def _inverse(self) -> Tensor: return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not() def min(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor: """ Returns the minimum value of the tensor along the specified axis or axes. You can pass in `axis` and `keepdim` keyword arguments to control the axis along which the minimum is computed and whether the reduced dimensions are retained. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([[1, 0, 2], [5, 4, 3]]) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.min().numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.min(axis=0).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.min(axis=1, keepdim=True).numpy()) ``` """ return self._inverse().max(axis=axis, keepdim=keepdim)._inverse() def any(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor: """ Tests if any element evaluates to `True` along the specified axis or axes. You can pass in `axis` and `keepdim` keyword arguments to control the reduce axis and whether the reduced dimensions are retained. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([[True, True], [True, False], [False, False]]) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.any().numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.any(axis=0).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.any(axis=1, keepdim=True).numpy()) ``` """ return self.bool().max(axis, keepdim) def all(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor: """ Tests if all element evaluates to `True` along the specified axis or axes. You can pass in `axis` and `keepdim` keyword arguments to control the reduce axis and whether the reduced dimensions are retained. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([[True, True], [True, False], [False, False]]) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.all().numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.all(axis=0).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.all(axis=1, keepdim=True).numpy()) ``` """ return self.logical_not().any(axis, keepdim).logical_not() def isclose(self, other:Tensor, rtol:float=1e-05, atol:float=1e-08, equal_nan=False) -> Tensor: """ Returns a new tensor with element-wise comparison of closeness to `other` within a tolerance. The `rtol` and `atol` keyword arguments control the relative and absolute tolerance of the comparison. By default, two `NaN` values are not close to each other. If `equal_nan` is `True`, two `NaN` values are considered close. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([1e-7, 1e-8, 1e-9, float('nan')]).isclose(Tensor([0.0, 0.0, 0.0, float('nan')])).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor([float('nan')]).isclose(Tensor([float('nan')]), equal_nan=True).numpy()) ``` """ is_finite_close = self.isfinite() & other.isfinite() & ((self - other).abs() <= atol + rtol * other.abs()) is_infinite_close = (self.isinf() | other.isinf()) & (self == other) is_nan_close = (self.isnan() & other.isnan()) & equal_nan return is_finite_close | is_infinite_close | is_nan_close def mean(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor: """ Returns the mean value of the tensor along the specified axis or axes. You can pass in `axis` and `keepdim` keyword arguments to control the axis along which the mean is computed and whether the reduced dimensions are retained. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.normal(2, 3, mean=2.5, std=0.5) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.mean().numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.mean(axis=0).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.mean(axis=1).numpy()) ``` """ output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32 numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim) return numerator.div(prod([cast(int, si) for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])) \ .cast(output_dtype) def var(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> Tensor: """ Returns the variance of the tensor along the specified axis or axes. You can pass in `axis`, `keepdim`, and `correction` keyword arguments to control the axis along which the variance is computed, whether the reduced dimensions are retained, and the Bessel's correction applied. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.normal(2, 3, mean=2.5, std=0.5) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.var().numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.var(axis=0).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.var(axis=1).numpy()) ``` """ squares = (self - self.mean(axis=axis, keepdim=True)).square() n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if resolve(si != so)]) return squares.sum(axis=axis, keepdim=keepdim).div(smax([0, n-correction])) def var_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> tuple[Tensor, Tensor]: """ Calculates the variance and mean over the dimensions specified by dim. Syntactic sugar around `Tensor.var` and `Tensor.mean` to match `torch.var_mean`. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.normal(2, 3, mean=2.5, std=0.5) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" var, mean = t.var_mean() print(var.numpy(), mean.numpy()) ``` """ return self.var(axis, keepdim, correction), self.mean(axis, keepdim) def std(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> Tensor: """ Returns the standard deviation of the tensor along the specified axis or axes. You can pass in `axis`, `keepdim`, and `correction` keyword arguments to control the axis along which the standard deviation is computed, whether the reduced dimensions are retained, and the Bessel's correction applied. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.normal(2, 3, mean=2.5, std=0.5) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.std().numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.std(axis=0).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.std(axis=1).numpy()) ``` """ return self.var(axis, keepdim, correction).sqrt() def std_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> tuple[Tensor, Tensor]: """ Calculates the standard deviation and mean over the dimensions specified by dim. Syntactic sugar around `Tensor.std` and `Tensor.mean` to match `torch.std_mean`. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.normal(2, 3, mean=2.5, std=0.5) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" std, mean = t.std_mean() print(std.numpy(), mean.numpy()) ``` """ return self.std(axis, keepdim, correction), self.mean(axis, keepdim) def _softmax(self, axis, dtype:DTypeLike|None=None) -> tuple[Tensor, Tensor, Tensor]: m = self - self.max(axis=axis, keepdim=True).detach() if dtype is not None: m = m.cast(dtype) e = m.exp() return m, e, e.sum(axis=axis, keepdim=True) def softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Tensor: """ Applies the softmax function to the tensor along the specified axis. Rescales the elements of the tensor such that they lie in the range [0, 1] and sum to 1. You can pass in the `axis` keyword argument to control the axis along which the softmax is computed. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.randn(2, 3) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.softmax().numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.softmax(axis=0).numpy()) ``` """ _, e, ss = self._softmax(axis, dtype) return e.div(ss) def log_softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Tensor: """ Applies the log-softmax function to the tensor along the specified axis. The log-softmax function is a numerically stable alternative to the softmax function in log space. You can pass in the `axis` keyword argument to control the axis along which the log-softmax is computed. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.randn(2, 3) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.log_softmax().numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.log_softmax(axis=0).numpy()) ``` """ m, _, ss = self._softmax(axis, dtype) return m - ss.log() def logsumexp(self, axis=None, keepdim=False) -> Tensor: """ Computes the log-sum-exp of the tensor along the specified axis or axes. The log-sum-exp function is a numerically stable way to compute the logarithm of the sum of exponentials. You can pass in `axis` and `keepdim` keyword arguments to control the axis along which the log-sum-exp is computed and whether the reduced dimensions are retained. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.randn(2, 3) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.logsumexp().numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.logsumexp(axis=0).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.logsumexp(axis=1).numpy()) ``` """ m = self.max(axis=axis, keepdim=True) return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + m.squeeze(axis) def logcumsumexp(self, axis=0) -> Tensor: """ Computes the log-cumsum-exp of the tensor along the specified axis or axes. The log-cumsum-exp function is a numerically stable way to compute the logarithm of the cumulative sum of exponentials. You can pass in the `axis` keyword argument to control the axis along which the log-cum-sum-exp is computed. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.randn(2, 3) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.logcumsumexp().numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.logcumsumexp(axis=0).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.logcumsumexp(axis=1).numpy()) ``` """ if self.ndim == 0: return self axis = self._resolve_dim(axis) x = self.transpose(axis, -1) last_dim_size = x.shape[-1] x_reshaped = x.reshape(-1, last_dim_size) x_cummax = x_reshaped.cummax(-1).unsqueeze(-1) x_expand = x_reshaped.unsqueeze(1).expand(*x_reshaped.shape, last_dim_size) mask = Tensor.ones(last_dim_size, last_dim_size, requires_grad=False, device=self.device).tril().unsqueeze(0) ret = ((x_expand - x_cummax).exp() * mask).sum(-1).log() + x_cummax.squeeze(-1) return ret.reshape(*x.shape).transpose(-1, axis) def argmax(self, axis=None, keepdim=False) -> Tensor: """ Returns the indices of the maximum value of the tensor along the specified axis. You can pass in `axis` and `keepdim` keyword arguments to control the axis along which the maximum is computed and whether the reduced dimensions are retained. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([[1, 0, 2], [5, 4, 3]]) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.argmax().numpy()) # Returns the index of the maximum value in the flattened tensor. ``` ```python exec="true" source="above" session="tensor" result="python" print(t.argmax(axis=0).numpy()) # Returns the indices of the maximum values along axis 0. ``` ```python exec="true" source="above" session="tensor" result="python" print(t.argmax(axis=1).numpy()) # Returns the indices of the maximum values along axis 1. ``` """ if axis is None: return self.flatten().argmax(0) axis = self._resolve_dim(axis) m = self == self.max(axis=axis, keepdim=True) idx = m * Tensor.arange(self.shape[axis],0,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1)) return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)).cast(dtypes.int32) def argmin(self, axis=None, keepdim=False) -> Tensor: """ Returns the indices of the minimum value of the tensor along the specified axis. You can pass in `axis` and `keepdim` keyword arguments to control the axis along which the minimum is computed and whether the reduced dimensions are retained. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([[1, 0, 2], [5, 4, 3]]) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.argmin().numpy()) # Returns the index of the minimum value in the flattened tensor. ``` ```python exec="true" source="above" session="tensor" result="python" print(t.argmin(axis=0).numpy()) # Returns the indices of the minimum values along axis 0. ``` ```python exec="true" source="above" session="tensor" result="python" print(t.argmin(axis=1).numpy()) # Returns the indices of the minimum values along axis 1. ``` """ return self._inverse().argmax(axis=axis, keepdim=keepdim) @staticmethod def einsum(formula:str, *operands:Tensor|Sequence[Tensor], dtype:DTypeLike|None=None) -> Tensor: """ Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention. See: https://pytorch.org/docs/stable/generated/torch.einsum.html ```python exec="true" source="above" session="tensor" result="python" x = Tensor([[1, 2], [3, 4]]) y = Tensor([[5, 6], [7, 8]]) print(Tensor.einsum("ij,ij->", x, y).numpy()) ``` """ def parse_formula(formula:str, *operands:Tensor): if "..." in (formula := formula.replace(" ", "")): ell_chars, ell_longest = "".join(set(string.ascii_letters) - set(formula)), 0 for i, inp in enumerate(filter(lambda x: "..." in x, inputs := formula.split("->")[0].split(","))): if (ell_count := max(operands[i].ndim, 1) - (len(inp) - len("..."))) > ell_longest: ell_longest = ell_count inputs[i] = inp.replace("...", ell_chars[-ell_count:]) inputs_str, out_ellipse = ",".join(inputs), ell_chars[-ell_longest:] return (inputs_str, formula.split("->")[1].replace("...", out_ellipse)) if "->" in formula else \ (inputs_str, out_ellipse + ''.join(sorted(c for c in inputs_str if inputs_str.count(c) == 1 and c.isalpha() and c not in out_ellipse))) return formula.split("->") if "->" in formula else (formula, ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha())) xs:tuple[Tensor, ...] = argfix(*operands) inputs_str, output = parse_formula(formula, *xs) inputs = inputs_str.split(",") assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}" # map the value of each letter in the formula letter_val = sorted(merge_dicts([dict(zip(letters, tensor.shape)) for letters, tensor in zip(inputs, xs)]).items()) xs_:list[Tensor] = [] lhs = [sorted(enumerate(s), key=lambda e:e[1]) for s in inputs] for x,(order,letters) in zip(xs, [list(zip(*l)) for l in lhs]): # permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val])) # ordinal encode the output alphabet rhs_order = argsort(argsort(list(output))) # sum over all axes that's not in the output, then permute to the output order return functools.reduce(lambda a,b:a*b, xs_) \ .sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output], dtype=dtype).permute(rhs_order) # ***** processing ops ***** def _pool(self, k_:tuple[sint, ...], stride:int|tuple[int, ...]=1, dilation:int|tuple[int, ...]=1) -> Tensor: assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}" s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_)) assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}" noop, i_ = [None] * (self.ndim-len(k_)), self.shape[-len(k_):] assert all(resolve(d*(k-1)+1 <= i) for k,d,i in zip(k_,d_,i_)), "kernel size cannot be greater than actual input size" o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)] if any(resolve(k > s) for k,s in zip(k_,s_)) or any(d != 1 for d in d_): # input size scaling factor to make sure shrink for stride is possible f_ = [1 + int(resolve(o*s > (i - d*(k-1)))) for o,s,i,d,k in zip(o_,s_,i_,d_,k_)] # # repeats such that we don't need padding x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)]) # handle dilation x = x.shrink(tuple(noop + [(0,k*(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)])).reshape(noop + flatten((k,(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_))) # handle stride x = x.shrink(tuple(noop + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_,o_,s_)))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_))) x = x.shrink(tuple(noop + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_,o_)))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_))) # permute to move reduce to the end return x.permute(*range(len(noop)), *[len(noop)+i*2+1 for i in range(len(i_))], *[len(noop)+i*2 for i in range(len(i_))]) # TODO: once the shapetracker can optimize well, remove this alternative implementation x = self.pad(tuple(noop + [(0, max(0,o*s-i)) for i,o,s in zip(i_,o_,s_)])).shrink(tuple(noop + [(0,o*s) for o,s in zip(o_,s_)])) x = x.reshape(noop + flatten(((o,s) for o,s in zip(o_,s_)))) x = x.shrink(tuple(noop + flatten(((0,o), (0,k)) for o,k in zip(o_,k_)))) return x.permute(*range(len(noop)), *[len(noop)+i*2 for i in range(len(i_))], *[len(noop)+i*2+1 for i in range(len(i_))]) def _resolve_pool_pads(self, padding:int|Sequence[int], dims:int) -> Sequence[int]: if not isinstance(padding, int) and not (len(padding) == 2*dims or len(padding) == dims): raise ValueError(f"Padding must be an int or a sequence of length {dims} or {2*dims}, but got {padding=} for {self.shape=} with {dims=}.") return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1]) def _apply_ceil_mode(self, pads:Sequence[int], k_:tuple[sint, ...], s_:int|tuple[int, ...], d_:int|tuple[int, ...]) -> list[int]: (d_,s_), i_ = (make_tuple(x, len(k_)) for x in (d_,s_)), self.shape[-len(k_):] pads, grouped_pads = list(pads), _flat_to_grouped(pads) # https://arxiv.org/pdf/1603.07285 section 5.1, relationship 15. o_ = [ceildiv(i+pB+pA - (d*(k-1)+1), s) + 1 for i,d,k,s,(pB,pA) in zip(i_,d_,k_,s_,grouped_pads)] for dim,(o,i,s,k,d,(pB,pA)) in enumerate(zip(o_,i_,s_,k_,d_,grouped_pads)): # we have to do additional padding before `_pool` so that `o_` in `_pool` is calculated correctly # `s*(o-1) + (d*(k-1)+1) - (i+pB+pA)` -> last_sliding_window_start + full_kernel_size - padded_input_shape # we decrease padding in the case that a sliding window starts in the end padded region, thereby decreasing `o_` in `_pool` # `smax(s*(o-1) - (pB+i-1), 0)` -> last_sliding_window_start - (pad_before + input_size - zero_offset) pads[-1-dim*2] += s*(o-1) + (d*(k-1)+1) - (i+pB+pA) - smax(s*(o-1) - (pB+i-1), 0) return pads # NOTE: these work for more than 2D def avg_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0, ceil_mode=False, count_include_pad=True) -> Tensor: """ Applies average pooling over a tensor. This function supports three different types of `padding` 1. `int` (single value): Applies the same padding value uniformly to all spatial dimensions. 2. `tuple[int, ...]` (length = number of spatial dimensions): Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`. 3. `tuple[int, ...]` (length = 2 * number of spatial dimensions): Specifies explicit padding for each side of each spatial dimension in the form `(padding_left, padding_right, padding_top, padding_bottom, ...)`. When `ceil_mode` is set to `True`, output shape will be determined using ceil division. When `count_include_pad` is set to `False`, zero padding will not be included in the averaging calculation. NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions. See: https://paperswithcode.com/method/average-pooling ```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(25).reshape(1, 1, 5, 5) print(t.avg_pool2d().numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.avg_pool2d(ceil_mode=True).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.avg_pool2d(padding=1).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.avg_pool2d(padding=1, count_include_pad=False).numpy()) ``` """ axis = tuple(range(-len(k_ := make_tuple(kernel_size, 2)), 0)) def pool(x:Tensor, padding_:Sequence[int]) -> Tensor: return x.pad(padding_)._pool(k_, stride if stride is not None else k_, dilation) reg_pads = self._resolve_pool_pads(padding, len(k_)) ceil_pads = self._apply_ceil_mode(reg_pads, k_, stride if stride is not None else k_, dilation) if not count_include_pad: pads = ceil_pads if ceil_mode else reg_pads return pool(self, pads).sum(axis) / pool(self.ones_like(), pads).sum(axis) if not ceil_mode: return pool(self, reg_pads).mean(axis) return pool(self, ceil_pads).sum(axis) / pool(self.pad(reg_pads).ones_like(), tuple(cp-rp for cp,rp in zip(ceil_pads, reg_pads))).sum(axis) def max_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0, ceil_mode=False, return_indices=False) -> Tensor | tuple[Tensor, Tensor]: """ Applies max pooling over a tensor. This function supports three different types of `padding` 1. `int` (single value): Applies the same padding value uniformly to all spatial dimensions. 2. `tuple[int, ...]` (length = number of spatial dimensions): Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`. 3. `tuple[int, ...]` (length = 2 * number of spatial dimensions): Specifies explicit padding for each side of each spatial dimension in the form `(padding_left, padding_right, padding_top, padding_bottom, ...)`. When `ceil_mode` is set to `True`, output shape will be determined using ceil division. When `return_indices` is set to `True`, the argmax will be returned along with the max values. NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions. See: https://paperswithcode.com/method/max-pooling ```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(25).reshape(1, 1, 5, 5) print(t.max_pool2d().numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.max_pool2d(ceil_mode=True).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.max_pool2d(padding=1).numpy()) ``` """ axis = tuple(range(-len(k_ := make_tuple(kernel_size, 2)), 0)) pads = self._resolve_pool_pads(padding, len(k_)) if ceil_mode: pads = self._apply_ceil_mode(pads, k_, stride if stride is not None else k_, dilation) pooled = self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation) if not return_indices: return pooled.max(axis) spatial_sz = math.prod(spatial_shape := self.shape[-len(k_):]) idx = Tensor.arange(spatial_sz,0,-1, requires_grad=False, device=self.device).reshape(spatial_shape) m = pooled == pooled.max(axis, keepdim=True) idx = m * idx.pad(pads, value=dtypes.min(idx.dtype))._pool(k_, stride if stride is not None else k_, dilation) return pooled.max(axis), spatial_sz - idx.max(axis) def max_unpool2d(self, indices:Tensor, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0, output_size=None): """ Performs a partial inverse of `max_pool2d` using the indices from the argmax. When `output_size` is provided, the output shape disambiguates to the provided shape. NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions. ```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(1, 17).reshape(1, 1, 4, 4) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" output, indices = Tensor.max_pool2d(t, return_indices=True) print(output.numpy()) print(indices.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor.max_unpool2d(output, indices).numpy()) ``` """ bs,c,*spatial_shape = self.shape if output_size is None: k_,d_,s_ = (make_tuple(x, len(spatial_shape)) for x in (kernel_size, dilation, stride if stride is not None else kernel_size)) p_ = _flat_to_grouped(self._resolve_pool_pads(padding, len(spatial_shape))) # https://arxiv.org/pdf/1603.07285 inverse of relationship 15 in section 5.1. output_size = tuple((i-1)*s - (pB+pA) + (d*(k-1)+1) for i,k,d,s,(pA,pB) in zip(spatial_shape,k_,d_,s_,p_)) else: output_size = output_size[-len(spatial_shape):] ret = (indices.reshape(bs,c,1,-1)._one_hot_along_dim(prod(output_size), 2) * self.reshape(bs,c,1,-1)).sum(3) return ret.reshape(bs,c,*output_size) def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0, dtype:DTypeLike|None=None) -> Tensor: """ Applies a convolution over a tensor with a given `weight` and optional `bias`. This function supports three different types of `padding` 1. `int` (single value): Applies the same padding value uniformly to all spatial dimensions. 2. `tuple[int, ...]` (length = number of spatial dimensions): Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`. 3. `tuple[int, ...]` (length = 2 * number of spatial dimensions): Specifies explicit padding for each side of each spatial dimension in the form `(padding_left, padding_right, padding_top, padding_bottom, ...)`. NOTE: unlike PyTorch, this implementation is not limited to only 2d convolutions and instead works for any number of dimensions. See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html ```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(9).reshape(1, 1, 3, 3) w = Tensor.ones(1, 1, 2, 2) print(t.conv2d(w).numpy()) ``` """ if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, dtype) (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:] padding_ = self._resolve_pool_pads(padding, len(HW)) assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501 # conv2d is a pooling op (with padding) x = self.pad(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W) rcout, oyx = cout//groups, x.shape[2:-len(HW)] if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not WINO: # normal conv x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) # noqa: E501 # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW) ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, dtype=dtype).reshape(bs, cout, *oyx) # noqa: E501 return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW))) HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]] winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]] winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order doubles compile time # todo: stride == dilation # use padding to round up to 4x4 output tiles # (bs, cin_, tyx, HWI) d = self.pad(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # noqa: E501 # move HW to the front: # (HWI, bs, cin_, tyx) d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW))) tyx = d.shape[-len(HWI):] # dim of tiling g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front # compute 6x6 winograd tiles: GgGt, BtdB # (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1)) gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx))) # (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx) dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx) # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx) ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), dtype=dtype), len(HW)) # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO) ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]]) # merge groups and rcout, tyx and HWO: (bs, groups, cout, *yx), shrink to final ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink(tuple((0, s) for s in [bs, cout, *oyx])) return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward() def conv_transpose2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor: """ Applies a transposed convolution over a tensor with a given `weight` and optional `bias`. This function supports three different types of `padding` 1. `int` (single value): Applies the same padding value uniformly to all spatial dimensions. 2. `tuple[int, ...]` (length = number of spatial dimensions): Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`. 3. `tuple[int, ...]` (length = 2 * number of spatial dimensions): Specifies explicit padding for each side of each spatial dimension in the form `(padding_left, padding_right, padding_top, padding_bottom, ...)`. NOTE: unlike PyTorch, this implementation is not limited to only 2d transposed convolutions and instead works for any number of dimensions. See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html ```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(9).reshape(1, 1, 3, 3) w = Tensor.ones(1, 1, 2, 2) print(t.conv_transpose2d(w).numpy()) ``` """ x, w = self, weight.unflatten(0, (groups, -1)).transpose(1, 2).flip(*range(3, len(weight.shape)+1)) HW = weight.shape[2:] padding = _flat_to_grouped(self._resolve_pool_pads(padding, len(HW))) stride, dilation, output_padding = [make_tuple(x, len(HW)) for x in (stride, dilation, output_padding)] if any(s>1 for s in stride): # handle strides: (k) -> reshape -> (k,1) -> pad -> (k,s) -> reshape -> (k*s) -> shrink (k-(s-1)) x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:])) x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride))) x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)]) x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)])) padding = flatten((((k-1)*d-pB,(k-1)*d-pA+op) for k,d,(pB,pA),op in reversed(list(zip(HW, dilation, padding, output_padding))))) return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding) def dot(self, w:Tensor, dtype:DTypeLike|None=None) -> Tensor: """ Performs dot product between two tensors. If `w` is 1-D, it's a sum product over the last axis of `self` and `w`. If `w` is N-D with N>=2, it's a sum product over the last axis of `self` and the second-to-last axis of `w`. You can pass in the optional `dtype` keyword argument to control the data type of the accumulation. ```python exec="true" source="above" session="tensor" result="python" a = Tensor([1, 2, 3]) b = Tensor([1, 1, 0]) print(a.dot(b).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" a = Tensor([[1, 2], [3, 4]]) b = Tensor([[5, 6], [7, 8]]) print(a.dot(b).numpy()) ``` """ if IMAGE: return self.image_dot(w, dtype) x, dx, dw = self, self.ndim, w.ndim if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D") if x.shape[-1] != w.shape[axis_w:=-min(w.ndim,2)]: raise RuntimeError(f"cannot dot {x.shape} and {w.shape}") x = x.reshape(*x.shape[0:-1], *[1]*min(dx-1, dw-1, 1), x.shape[-1]) w = w.reshape(*w.shape[0:-2], *[1]*min(dx-1, dw-1, 1), *w.shape[axis_w:]).transpose(-1, axis_w) return (x*w).sum(-1, dtype=dtype).cast(least_upper_dtype(x.dtype, w.dtype) if dtype is None else dtype) def matmul(self, x:Tensor, reverse=False, dtype:DTypeLike|None=None) -> Tensor: """ Performs matrix multiplication between two tensors. You can pass in the `reverse` keyword argument to control the order of the matrix multiplication. You can pass in the optional `dtype` keyword argument to control the data type of the accumulation. ```python exec="true" source="above" session="tensor" result="python" a = Tensor([[1, 2], [3, 4]]) b = Tensor([[5, 6], [7, 8]]) print(a.matmul(b).numpy()) ``` """ return x.dot(self, dtype=dtype) if reverse else self.dot(x, dtype=dtype) def _cumalu(self, axis:int, op:Ops, _include_initial=False) -> Tensor: assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX, Ops.MUL) pl_sz = self.shape[axis] - int(not _include_initial) pooled = self.transpose(axis,-1).pad((pl_sz, -int(_include_initial)), value=identity_element(op, self.dtype))._pool((self.shape[axis],)) return {Ops.ADD: pooled.sum(-1), Ops.MAX: pooled.max(-1), Ops.MUL: pooled.prod(-1)}[op].transpose(axis, -1) def _split_cumalu(self, axis:int, op:Ops) -> Tensor: axis = self._resolve_dim(axis) if self.ndim == 0 or 0 in self.shape: return self # TODO: someday the optimizer will find this on it's own # for now this is a two stage cumsum SPLIT = 256 if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumalu(axis, op) ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0), value=identity_element(op, self.dtype)).unflatten(-1, (-1, SPLIT))._cumalu(-1, op) base = ret[..., -1]._cumalu(-1, op, _include_initial=True) base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1]) def fix(x: Tensor) -> Tensor: return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1) return {Ops.ADD: Tensor.__add__, Ops.MAX: Tensor.maximum, Ops.MUL: Tensor.__mul__}[op](fix(ret), fix(base)) def cumsum(self, axis:int=0) -> Tensor: """ Computes the cumulative sum of the tensor along the specified `axis`. ```python exec="true" source="above" session="tensor" result="python" t = Tensor.ones(2, 3) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.cumsum(1).numpy()) ``` """ return self._split_cumalu(axis, Ops.ADD) def cumprod(self, axis:int) -> Tensor: """ Computes the cumulative product of the elements of the tensor along the specified `axis`. ```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(1, 7).reshape(2, 3) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.cumprod(axis=0).numpy()) ``` """ return self._split_cumalu(axis, Ops.MUL) def cummax(self, axis:int=0) -> Tensor: """ Computes the cumulative max of the tensor along the specified `axis`. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([0, 1, -1, 2, -2, 3, -3]) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.cummax(0).numpy()) ``` """ return self._split_cumalu(axis, Ops.MAX) @staticmethod def _tri(r:sint, c:sint, diagonal:int=0, **kwargs) -> Tensor: assert isinstance(r, int) and isinstance(c, int), f"does not support symbolic, getting {r=}, {c=}" if r == 0 or c == 0 or diagonal >= c: return Tensor.zeros(r,c,**kwargs) if r+diagonal <= 0: return Tensor.ones(r,c,**kwargs) s = r+c-1 # build a (s, s) upper triangle t = Tensor.ones(s,s,**kwargs).pad((None,(0,s))).flatten().shrink(((0,s*(2*s-1)),)).reshape(s,-1).shrink((None,(0,s))) return t[:r,-diagonal:c-diagonal] if diagonal <= 0 else t[diagonal:r+diagonal,:c] def triu(self, diagonal:int=0) -> Tensor: """ Returns the upper triangular part of the tensor, the other elements are set to 0. The argument `diagonal` determines which diagonal is on the boundary. `diagonal = 0` means the main diagonal. Positive `diagonal` means above the main diagonal, and negative `diagonal` means below the main diagonal. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.triu(diagonal=0).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.triu(diagonal=1).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.triu(diagonal=-1).numpy()) ``` """ return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal, device=self.device, dtype=dtypes.bool).where(self, 0).cast(self.dtype) def tril(self, diagonal:int=0) -> Tensor: """ Returns the lower triangular part of the tensor, the other elements are set to 0. The argument `diagonal` determines which diagonal is on the boundary. `diagonal = 0` means the main diagonal. Positive `diagonal` means above the main diagonal, and negative `diagonal` means below the main diagonal. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.tril(diagonal=0).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.tril(diagonal=1).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.tril(diagonal=-1).numpy()) ``` """ return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device, dtype=dtypes.bool).where(0, self).cast(self.dtype) def interpolate(self, size:tuple[int, ...], mode:str="linear", align_corners:bool=False) -> Tensor: """ Downsamples or Upsamples to the input `size`, accepts 0 to N batch dimensions. The interpolation algorithm is selected with `mode` which currently only supports `linear`, `nearest` and `nearest-exact`. To run `bilinear` or `trilinear`, pass in a 2D or 3D size. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([[1, 2, 3, 4], [21, 22, 23, 24], [41, 42, 43, 44]]) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.interpolate(size=(2,3), mode="linear").numpy()) ``` """ assert isinstance(size, (tuple,list)) and all_int(size) and 0 < len(size) <= self.ndim, f"invalid {size=}" assert mode in ("linear", "nearest", "nearest-exact"), "only supports linear, nearest or nearest-exact interpolate" assert not (align_corners and mode != "linear"), "align_corners option can only be set with the interpolating mode linear" x, expand = self, list(self.shape) for i in range(-1,-len(size)-1,-1): scale = (self.shape[i] - int(align_corners)) / (size[i] - int(align_corners)) arr, reshape = Tensor.arange(size[i], dtype=dtypes.float32, device=self.device), [1] * self.ndim reshape[i] = expand[i] = size[i] if mode == "linear": index = (scale*arr if align_corners else (scale*(arr+0.5))-0.5).clip(0, self.shape[i]-1) low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor().int(), index.ceil().int(), index - index.floor())] x = x.gather(i, low).lerp(x.gather(i, high), perc) else: index = (scale*(arr+0.5) if mode=="nearest-exact" else scale*arr).cast(dtypes.int32).reshape(reshape).expand(expand) x = x.gather(i, index) return x.cast(self.dtype) def _pre_scatter(self, dim:int, index:Tensor, src:Tensor) -> tuple[Tensor, Tensor]: index, dim = index.to(self.device), self._resolve_dim(dim) assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.dim must all equal, {self.ndim=} {index.ndim=} {src.ndim=}" assert all((d == dim or self_ >= index_) and src_ >= index_ for d,(self_,index_,src_) in enumerate(zip(self.shape, index.shape, src.shape))), \ f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}" if self.dtype != src.dtype: raise RuntimeError(f"expect {self.dtype=} to be equal to {src.dtype=}") # shrink src to index shape to shrink away the unused values src = src.shrink(tuple((0,s) for s in index.shape)) # prepare src and mask for reduce with respect to dim src = src.unsqueeze(-1).expand(*src.shape, self.shape[dim]).transpose(-1, dim) mask = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).transpose(-1, dim) # pad src and mask to self.shape so that reduce can be done with padded values as no-ops src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask)) return src, mask def scatter(self, dim:int, index:Tensor, src:Tensor|ConstType, reduce:Literal['multiply', 'add']|None=None) -> Tensor: """ Scatters `src` values along an axis specified by `dim`. Apply `add` or `multiply` reduction operation with `reduce`. NOTE: To use the `reduce` argument with a Tensor `src`, see `Tensor.scatter_reduce`. ```python exec="true" source="above" session="tensor" result="python" src = Tensor.arange(1, 11).reshape(2, 5) print(src.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" index = Tensor([[0, 1, 2, 0]]) print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(0, index, src).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" index = Tensor([[0, 1, 2], [0, 1, 4]]) print(Tensor.zeros(3, 5, dtype=src.dtype).scatter(1, index, src).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='multiply').numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor.full((2, 4), 2.0).scatter(1, Tensor([[2], [3]]), 1.23, reduce='add').numpy()) ``` """ if reduce not in {None, "add", "multiply"}: raise TypeError(f"{reduce=} must be one of None, 'multiply', or 'add'") if reduce and isinstance(src, Tensor): raise TypeError("Tensor src is not supported with reduce arg. see scatter_reduce") if not isinstance(src, Tensor): src = index.full_like(src, device=self.device, dtype=self.dtype) if reduce == "add": return self.scatter_reduce(dim, index, src, "sum", include_self=True) if reduce == "multiply": return self.scatter_reduce(dim, index, src, "prod", include_self=True) src, mask = self._pre_scatter(dim, index, src) return _masked_setitem(self, src, mask, (-1,)) def scatter_reduce(self, dim:int, index:Tensor, src:Tensor, reduce:Literal["sum", "prod", "mean", "amax", "amin"], include_self:bool=True) -> Tensor: """ Scatters `src` values along an axis specified by `dim`. Apply `"sum"`, `"prod"`, `"mean"`, `"amax"`, or `"amin"` reduction operations with `reduce`. Set `include_self=False` to exclude values in the `self` Tensor from the reduction. ```python exec="true" source="above" session="tensor" result="python" src = Tensor.arange(1, 11).cast(dtypes.float).reshape(2, 5) print(src.numpy()) index = Tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]) print(index.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='sum').numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='prod').numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='mean', include_self=False).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor([[-10, 20, 0, 5, 10]], dtype=src.dtype).scatter_reduce(0, index, src, reduce='amax').numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor([[-10, 20, 0, 5, 10]], dtype=src.dtype).scatter_reduce(0, index, src, reduce='amin').numpy()) ``` """ src, mask = self._pre_scatter(dim, index, src) def _inv_mask(a:Tensor|ConstType, b:Tensor|ConstType) -> Tensor: return mask.any(-1).logical_not().where(a, b) # TODO: should not overwrite dtype here? if reduce == "sum": return mask.where(src, 0).sum(-1, dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)) if reduce == "prod": return mask.where(src, 1).prod(-1, dtype=self.dtype).mul(self if include_self else _inv_mask(self, 1)) if reduce == "amax": return mask.where(src, m := dtypes.min(src.dtype)).max(-1).maximum(self if include_self else _inv_mask(self, m)) if reduce == "amin": return mask.where(src, m := dtypes.max(src.dtype)).min(-1).minimum(self if include_self else _inv_mask(self, m)) if reduce == "mean": count = mask.where(1, 0).sum(-1, dtype=self.dtype).add(1 if include_self else _inv_mask(1, 0)) return mask.where(src, 0).sum(-1, dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)).div(count) raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'") def sort(self, dim:int=-1, descending:bool=False) -> tuple[Tensor, Tensor]: """ Performs a bitonic sort on the tensor along the specified dimension. Order of indices for equivalent elements is always preserved. See: https://en.wikipedia.org/wiki/Bitonic_sorter ```python exec="true" source="above" session="tensor" result="python" t = Tensor([[0.1, 0.5, 1.2, 3.4, 2.1], [2.2, 1.9, 0.3, 4.5, 0.8]]) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" sorted_values, indices = t.sort(dim=1, descending=True) print(sorted_values.numpy()) print(indices.numpy()) ``` """ x, dim = self, self._resolve_dim(dim) # pad to power of 2 orig_len = x.shape[dim] n_stages = math.ceil(math.log2(orig_len)) fill_value = dtypes.min(x.dtype) if descending else dtypes.max(x.dtype) pads = tuple((0, 2**n_stages - orig_len) if i == dim else None for i in range(x.ndim)) x = x.pad(pads, value=fill_value).unflatten(dim, (2,)*n_stages) # https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort1.svg for stage in range(1, n_stages+1): if stage != n_stages: # flip so arrows of green boxes point the same way as blue boxes crossover_dim = dim + n_stages - stage - 1 blue_box, green_box = x.split(1, crossover_dim) flip_dims = tuple(-i for i in range(1, stage+1+(self.ndim-dim))) x = (blue_box.cat(green_box.flip(flip_dims), dim=crossover_dim)).contiguous() for substage in range(stage-1, -1, -1): partner_dim = dim + n_stages - substage - 1 x_top, x_bottom = x.split(1, partner_dim) x_larger, x_smaller = x_top.maximum(x_bottom), x_top.minimum(x_bottom) x = (x_larger.cat(x_smaller, dim=partner_dim) if descending else x_smaller.cat(x_larger, dim=partner_dim)).contiguous() if stage != n_stages: # flip wires back to undo the crossover blue_box, flipped_green_box = x.split(1, crossover_dim) x = blue_box.cat(flipped_green_box.flip(flip_dims), dim=crossover_dim) x = x.flatten(dim, dim+n_stages-1).shrink(tuple((0, orig_len) if i == dim else None for i in range(x.ndim))) # compute indices for sorted values idx = Tensor.arange(orig_len, requires_grad=False, device=self.device).reshape(tuple(orig_len if i == dim else 1 for i in range(x.ndim))) idx = idx.expand(x.shape) def compute_counts(t:Tensor): return ((idx.unsqueeze(dim) <= idx.unsqueeze(dim+1)) & (t.unsqueeze(dim) == t.unsqueeze(dim+1))).sum(dim+1) count_orig, count_sorted = compute_counts(self), compute_counts(x) cond = (self.unsqueeze(dim+1) == x.unsqueeze(dim)) & (count_orig.unsqueeze(dim+1) == count_sorted.unsqueeze(dim)) idx = (cond * idx.unsqueeze(dim+1)).sum(dim) return x, idx def topk(self, k:int, dim:int=-1, largest:bool=True, sorted_:bool=True) -> tuple[Tensor, Tensor]: """ Computes the top-k elements of the tensor along the specified `dim`. Order of indices for equivalent elements is always preserved. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([[0.1, 0.5, 1.2, 3.4, 2.1], [2.2, 1.9, 0.3, 4.5, 0.8]]) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" topk_values, topk_indices = t.topk(2, dim=1) print(topk_values.numpy()) print(topk_indices.numpy()) ``` """ if not sorted_: raise NotImplementedError("topk with sorted_=False is not supported") if k > self.shape[dim:=self._resolve_dim(dim)]: raise ValueError(f"selected index {k=} is out of range") x, idx = self.sort(dim, descending=largest) shrink_to_k = tuple((0, k) if i == dim else None for i in range(self.ndim)) return x.shrink(shrink_to_k), idx.shrink(shrink_to_k) # ***** unary ops ***** def logical_not(self) -> Tensor: """ Computes the logical NOT of the tensor element-wise. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([False, True]).logical_not().numpy()) ``` """ return self.cast(dtypes.bool)._apply_broadcasted_uop(UOp.ne, True) def neg(self) -> Tensor: """ Negates the tensor element-wise. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).neg().numpy()) ``` """ return self*-1 if self.dtype != dtypes.bool else self.logical_not() def contiguous(self) -> Tensor: """ Returns a contiguous tensor. """ return self._apply_uop(UOp.contiguous) def contiguous_backward(self) -> Tensor: """ Inserts a contiguous operation in the backward pass. """ return self._apply_uop(UOp.contiguous_backward) def log(self) -> Tensor: """ Computes the natural logarithm element-wise. See: https://en.wikipedia.org/wiki/Logarithm ```python exec="true" source="above" session="tensor" result="python" print(Tensor([1., 2., 4., 8.]).log().numpy()) ``` """ return self.log2()*math.log(2) def log2(self) -> Tensor: """ Computes the base-2 logarithm element-wise. See: https://en.wikipedia.org/wiki/Logarithm ```python exec="true" source="above" session="tensor" result="python" print(Tensor([1., 2., 4., 8.]).log2().numpy()) ``` """ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.log2) def exp(self) -> Tensor: """ Computes the exponential function element-wise. See: https://en.wikipedia.org/wiki/Exponential_function ```python exec="true" source="above" session="tensor" result="python" print(Tensor([0., 1., 2., 3.]).exp().numpy()) ``` """ return self.mul(1/math.log(2)).exp2() def exp2(self) -> Tensor: """ Computes the base-2 exponential function element-wise. See: https://en.wikipedia.org/wiki/Exponential_function ```python exec="true" source="above" session="tensor" result="python" print(Tensor([0., 1., 2., 3.]).exp2().numpy()) ``` """ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.exp2) def relu(self) -> Tensor: """ Applies the Rectified Linear Unit (ReLU) function element-wise. - Described: https://paperswithcode.com/method/relu ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy()) ``` """ return (self>0).where(self, 0) def sigmoid(self) -> Tensor: """ Applies the Sigmoid function element-wise. - Described: https://en.wikipedia.org/wiki/Sigmoid_function ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sigmoid().numpy()) ``` """ return (1 + (self * (-1/math.log(2))).exp2()).reciprocal() def hardsigmoid(self, alpha:float=1/6, beta:float=0.5) -> Tensor: """ Applies the Hardsigmoid function element-wise. NOTE: default `alpha` and `beta` values is taken from torch - Described: https://paperswithcode.com/method/hard-sigmoid - See: https://pytorch.org/docs/stable/generated/torch.nn.functional.hardsigmoid.html ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).hardsigmoid().numpy()) ``` """ return (alpha * self + beta).relu() - (alpha * self + beta - 1).relu() def sqrt(self) -> Tensor: """ Computes the square root of the tensor element-wise. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([1., 2., 3., 4.]).sqrt().numpy()) ``` """ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sqrt) def rsqrt(self) -> Tensor: """ Computes the reciprocal of the square root of the tensor element-wise. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([1., 2., 3., 4.]).rsqrt().numpy()) ``` """ return self.sqrt().reciprocal() def sin(self) -> Tensor: """ Computes the sine of the tensor element-wise. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy()) ``` """ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sin) def cos(self) -> Tensor: """ Computes the cosine of the tensor element-wise. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).cos().numpy()) ``` """ return ((math.pi/2)-self).sin() def tan(self) -> Tensor: """ Computes the tangent of the tensor element-wise. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([0., math.pi/4, math.pi/2, 3*math.pi/4, math.pi]).tan().numpy()) ``` """ return self.sin() / self.cos() def asin(self) -> Tensor: """ Computes the inverse sine (arcsine) of the tensor element-wise. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).asin().numpy()) ``` """ # https://personal.math.ubc.ca/~cbm/aands/page_81.htm 4.4.46 coefficients = [-0.0012624911, 0.0066700901, -0.0170881256, 0.0308918810, -0.0501743046, 0.0889789874, -0.2145988016, 1.5707963050] x = math.pi / 2 - (1.0 - self.abs()).sqrt() * polyN(self.abs(), coefficients) return self.sign() * x def acos(self) -> Tensor: """ Computes the inverse cosine (arccosine) of the tensor element-wise. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).acos().numpy()) ``` """ return math.pi / 2 - self.asin() def atan(self) -> Tensor: """ Computes the inverse tangent (arctan) of the tensor element-wise. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).atan().numpy()) ``` """ return (self / (1 + self * self).sqrt()).asin() # ***** math functions ***** def trunc(self: Tensor) -> Tensor: """ Truncates the tensor element-wise. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).trunc().numpy()) ``` """ return self.cast(dtypes.int32).cast(self.dtype) def ceil(self: Tensor) -> Tensor: """ Rounds the tensor element-wise towards positive infinity. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).ceil().numpy()) ``` """ return (self > (b := self.trunc())).where(b+1, b) def floor(self: Tensor) -> Tensor: """ Rounds the tensor element-wise towards negative infinity. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).floor().numpy()) ``` """ return (self < (b := self.trunc())).where(b-1, b) def round(self: Tensor) -> Tensor: """ Rounds the tensor element-wise with rounding half to even. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).round().numpy()) ``` """ return ((self > 0) == ((b := self.cast(dtypes.int32) / 2.0).cast(dtypes.int32) == b)).where((self - 0.5).ceil(), (self + 0.5).floor()) def isinf(self:Tensor, detect_positive:bool=True, detect_negative:bool=True) -> Tensor: """ Checks the tensor element-wise to return True where the element is infinity, otherwise returns False ```python exec="true" source="above" session="tensor" result="python" print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isinf().numpy()) ``` """ return (self == float("inf")) * detect_positive + (self == float("-inf")) * detect_negative def isnan(self:Tensor) -> Tensor: """ Checks the tensor element-wise to return True where the element is NaN, otherwise returns False ```python exec="true" source="above" session="tensor" result="python" print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isnan().numpy()) ``` """ return self != self def isfinite(self:Tensor) -> Tensor: """ Checks the tensor element-wise to return True where the element is finite, otherwise returns False ```python exec="true" source="above" session="tensor" result="python" print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isfinite().numpy()) ``` """ return (self.isinf()|self.isnan()).logical_not() def lerp(self, end:Tensor, weight:Tensor|float) -> Tensor: """ Linearly interpolates between `self` and `end` by `weight`. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([1., 2., 3.]).lerp(Tensor([4., 5., 6.]), 0.5).numpy()) ``` """ if self.dtype == dtypes.uint8 and isinstance(weight, Tensor): w_i = (weight * (1<<(W_PREC:=7)) + 0.5).cast(dtypes.int16) return (self+(((end - self).cast(dtypes.int8) * w_i + (1<> W_PREC)).cast(dtypes.uint8) return self + (end - self) * weight def square(self) -> Tensor: """ Squares the tensor element-wise. Equivalent to `self*self`. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).square().numpy()) ``` """ return self*self def clamp(self, min_=None, max_=None) -> Tensor: """ Clips (clamps) the values in the tensor between `min_` and `max_` element-wise. If `min_` is `None`, there is no lower bound. If `max_` is None, there is no upper bound. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).clip(-1, 1).numpy()) ``` """ if min_ is None and max_ is None: raise RuntimeError("at least one of 'min_' or 'max_' must not be None") ret = self.maximum(min_) if min_ is not None else self return ret.minimum(max_) if max_ is not None else ret def clip(self, min_=None, max_=None) -> Tensor: """ Alias for `Tensor.clamp`. """ return self.clamp(min_, max_) def sign(self) -> Tensor: """ Returns the sign of the tensor element-wise. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy()) ``` """ return self.ne(0).where((self<0).where(self.full_like(-1), self.full_like(1)), self.full_like(0)) + self*0 def abs(self) -> Tensor: """ Computes the absolute value of the tensor element-wise. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).abs().numpy()) ``` """ return self * self.sign() def reciprocal(self) -> Tensor: """ Compute `1/x` element-wise. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([1., 2., 3., 4.]).reciprocal().numpy()) ``` """ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.reciprocal) # ***** activation functions ***** def elu(self, alpha=1.0) -> Tensor: """ Applies the Exponential Linear Unit (ELU) function element-wise. - Described: https://paperswithcode.com/method/elu - Paper: https://arxiv.org/abs/1511.07289v5 ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).elu().numpy()) ``` """ return self.relu() - alpha*(1-self.exp()).relu() def celu(self, alpha=1.0) -> Tensor: """ Applies the Continuously differentiable Exponential Linear Unit (CELU) function element-wise. - Described: https://paperswithcode.com/method/celu - Paper: https://arxiv.org/abs/1704.07483 ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).celu().numpy()) ``` """ return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0) def selu(self, alpha=1.67326, gamma=1.0507) -> Tensor: """ Applies the Scaled Exponential Linear Unit (SELU) function element-wise. - Described: https://paperswithcode.com/method/selu - Paper: https://arxiv.org/abs/1706.02515v5 ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).selu().numpy()) ``` """ return gamma * (self >= 0).detach().where(self, alpha * (self.exp() - 1)) def swish(self) -> Tensor: """ See `.silu()` - Paper: https://arxiv.org/abs/1710.05941v1 ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).swish().numpy()) ``` """ return self * self.sigmoid() def silu(self) -> Tensor: """ Applies the Sigmoid Linear Unit (SiLU) function element-wise. - Described: https://paperswithcode.com/method/silu - Paper: https://arxiv.org/abs/1606.08415 ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).silu().numpy()) ``` """ return self.swish() # The SiLU function is also known as the swish function. def relu6(self) -> Tensor: """ Applies the ReLU6 function element-wise. - Described: https://paperswithcode.com/method/relu6 - Paper: https://arxiv.org/abs/1704.04861v1 ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-9., -6., -3., 0., 3., 6., 9.]).relu6().numpy()) ``` """ return self.relu() - (self-6).relu() def hardswish(self) -> Tensor: """ Applies the Hardswish function element-wise. - Described: https://paperswithcode.com/method/hard-swish - Paper: https://arxiv.org/abs/1905.02244v5 ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).hardswish().numpy()) ``` """ return self * (self+3).relu6() * (1/6) def tanh(self) -> Tensor: """ Applies the Hyperbolic Tangent (tanh) function element-wise. - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Tanh ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).tanh().numpy()) ``` """ return 2.0 * ((2.0 * self).sigmoid()) - 1.0 def sinh(self) -> Tensor: """ Applies the Hyperbolic Sine (sinh) function element-wise. - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Sinh ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sinh().numpy()) ``` """ return (self.exp() - self.neg().exp()) / 2 def cosh(self) -> Tensor: """ Applies the Hyperbolic Cosine (cosh) function element-wise. - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Cosh ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).cosh().numpy()) ``` """ return (self.exp() + self.neg().exp()) / 2 def atanh(self) -> Tensor: """ Applies the Inverse Hyperbolic Tangent (atanh) function element-wise. - Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#atanh ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).atanh().numpy()) ``` """ return ((1 + self)/(1 - self)).log() / 2 def asinh(self) -> Tensor: """ Applies the Inverse Hyperbolic Sine (asinh) function element-wise. - Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#asinh ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).asinh().numpy()) ``` """ return (self + (self.square() + 1).sqrt()).log() def acosh(self) -> Tensor: """ Applies the Inverse Hyperbolic Cosine (acosh) function element-wise. - Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#acosh ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).acosh().numpy()) ``` """ return (self + (self.square() - 1).sqrt()).log() def hardtanh(self, min_val=-1, max_val=1) -> Tensor: """ Applies the Hardtanh function element-wise. - Described: https://paperswithcode.com/method/hardtanh-activation ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).hardtanh().numpy()) ``` """ return self.clip(min_val, max_val) def erf(self) -> Tensor: """ Applies error function element-wise. - Described: https://en.wikipedia.org/wiki/Error_function ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).erf().numpy()) ``` """ # https://personal.math.ubc.ca/~cbm/aands/page_299.htm 7.1.26 t = 1.0 / (1.0 + 0.3275911 * self.abs()) return self.sign() * (1.0 - t * polyN(t, [1.061405429, -1.453152027, 1.421413741, -0.284496736, 0.254829592]) * (-self.square()).exp()) def gelu(self) -> Tensor: """ Applies the Gaussian Error Linear Unit (GELU) function element-wise. - Described: https://paperswithcode.com/method/gelu - Paper: https://arxiv.org/abs/1606.08415v5 ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).gelu().numpy()) ``` """ return 0.5 * self * (1 + (math.sqrt(2 / math.pi) * (self + 0.044715 * self ** 3)).tanh()) def quick_gelu(self) -> Tensor: """ Applies the Sigmoid GELU approximation element-wise. - Described: https://paperswithcode.com/method/gelu ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).quick_gelu().numpy()) ``` """ return self * (self * 1.702).sigmoid() def leaky_relu(self, neg_slope=0.01) -> Tensor: """ Applies the Leaky ReLU function element-wise. - Described: https://paperswithcode.com/method/leaky-relu ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu().numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu(neg_slope=0.42).numpy()) ``` """ return (self<0).where(neg_slope*self, self) def mish(self) -> Tensor: """ Applies the Mish function element-wise. - Described: https://paperswithcode.com/method/mish - Paper: https://arxiv.org/abs/1908.08681v3 ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).mish().numpy()) ``` """ return self * self.softplus().tanh() def softplus(self, beta=1) -> Tensor: """ Applies the Softplus function element-wise. - Described: https://paperswithcode.com/method/softplus ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softplus().numpy()) ``` """ return (1/beta) * (1 + (self*beta).exp()).log() def softsign(self) -> Tensor: """ Applies the Softsign function element-wise. - Described: https://paperswithcode.com/method/softsign ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softsign().numpy()) ``` """ return self / (1 + self.abs()) # ***** broadcasted elementwise ops ***** def _broadcast_to(self, new_shape:tuple[sint, ...]) -> Tensor: if self.shape == new_shape: return self if self.ndim > len(new_shape): raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {new_shape=}") # first unsqueeze left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html shape, _ = _align_left(self.shape, new_shape) # for each dimension, check either dim is 1, or it does not change if not all(resolve(s == ns) or resolve(s == 1) for s,ns in zip(shape, new_shape)): raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}") return self.reshape(shape)._apply_uop(UOp.expand, arg=new_shape) def _broadcasted(self, y:Tensor|ConstType|UOp, reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]: x: Tensor = self if not isinstance(y, Tensor): # make y a Tensor assert isinstance(y, (*get_args(ConstType), UOp)), f"{type(y)=}, {y=}" if isinstance(x.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype elif not isinstance(y, UOp): y_dtype = dtypes.from_py(y) if isinstance(y, UOp): y = Tensor.from_uop(y, device=x.device) else: y = Tensor(dtypes.as_const(y, y_dtype), x.device, y_dtype, requires_grad=False) if match_dtype and x.dtype != y.dtype: output_dtype = least_upper_dtype(x.dtype, y.dtype) x, y = x.cast(output_dtype), y.cast(output_dtype) if reverse: x, y = y, x # broadcast return x._broadcast_to(out_shape:=_broadcast_shape(x.shape, y.shape)), y._broadcast_to(out_shape) def add(self, x:Tensor|ConstType, reverse=False) -> Tensor: """ Adds `self` and `x`. Equivalent to `self + x`. Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.randn(4) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.add(20).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.add(Tensor([[2.0], [3.5]])).numpy()) ``` """ return self._apply_broadcasted_uop(UOp.add, x, reverse) def sub(self, x:Tensor|ConstType, reverse=False) -> Tensor: """ Subtracts `x` from `self`. Equivalent to `self - x`. Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.randn(4) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.sub(20).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.sub(Tensor([[2.0], [3.5]])).numpy()) ``` """ a, b = self._broadcasted(x, reverse) return a + (-b) def mul(self, x:Tensor|ConstType, reverse=False) -> Tensor: """ Multiplies `self` and `x`. Equivalent to `self * x`. Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.randn(4) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.mul(3).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.mul(Tensor([[-1.0], [2.0]])).numpy()) ``` """ return self._apply_broadcasted_uop(UOp.mul, x, reverse) def idiv(self, x:Tensor|ConstType, reverse=False) -> Tensor: """ Divides `self` by `x`. Equivalent to `self // x`. Supports broadcasting to a common shape, type promotion, and integer inputs. `idiv` performs integer division (truncate towards zero). ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-4, 7, 5, 4, -7, 8]).idiv(Tensor([2, -3, 8, -2, 3, 5])).numpy()) ``` """ return self._apply_broadcasted_uop(UOp.idiv, x, reverse) def div(self, x:Tensor|ConstType, reverse=False, rounding_mode:Literal["trunc", "floor"]|None=None) -> Tensor: """ Divides `self` by `x`. Equivalent to `self / x`. Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs. `div` performs true division. ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.randn(4) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.div(3).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor([1, 4, 10]).div(Tensor([2, 3, 4])).numpy()) ``` """ numerator, denominator = self._broadcasted(x, reverse) d = numerator.cast(least_upper_float(numerator.dtype)) * denominator.cast(least_upper_float(denominator.dtype)).reciprocal() output_dtype = numerator.dtype if dtypes.is_int(numerator.dtype) else d.dtype if rounding_mode == "trunc": return d.trunc().cast(output_dtype) if rounding_mode == "floor": return d.floor().cast(output_dtype) if rounding_mode is not None: raise RuntimeError(f"{rounding_mode=} is not supported") return d def mod(self, x:Tensor|ConstType, reverse=False) -> Tensor: """ Mod `self` by `x`. Equivalent to `self % x`. Supports broadcasting to a common shape, type promotion, and integer inputs. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-4, 7, 5, 4, -7, 8]).mod(Tensor([2, -3, 8, -2, 3, 5])).numpy()) ``` """ a, b = self._broadcasted(x, reverse) return a - a.div(b, rounding_mode="floor") * b def bitwise_xor(self, x:Tensor|ConstType, reverse=False) -> Tensor: """ Computes bitwise xor of `self` and `x`. Equivalent to `self ^ x`. Supports broadcasting to a common shape, type promotion, and integer, boolean inputs. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-1, -2, 3]).bitwise_xor(Tensor([1, 0, 3])).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor([True, True, False, False]).bitwise_xor(Tensor([True, False, True, False])).numpy()) ``` """ if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported") return self._apply_broadcasted_uop(UOp.bitwise_xor, x, reverse) def bitwise_and(self, x:Tensor|ConstType, reverse=False) -> Tensor: """ Compute the bitwise AND of `self` and `x`. Equivalent to `self & x`. Supports broadcasting to a common shape, type promotion, and integer, boolean inputs. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([2, 5, 255]).bitwise_and(Tensor([3, 14, 16])).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor([True, True, False, False]).bitwise_and(Tensor([True, False, True, False])).numpy()) ``` """ if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported") return self._apply_broadcasted_uop(UOp.bitwise_and, x, reverse) def bitwise_or(self, x:Tensor|ConstType, reverse=False) -> Tensor: """ Compute the bitwise OR of `self` and `x`. Equivalent to `self | x`. Supports broadcasting to a common shape, type promotion, and integer, boolean inputs. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([2, 5, 255]).bitwise_or(Tensor([4, 4, 4])).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor([True, True, False, False]).bitwise_or(Tensor([True, False, True, False])).numpy()) ``` """ if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported") return self._apply_broadcasted_uop(UOp.bitwise_or, x, reverse) def bitwise_not(self) -> Tensor: """ Compute the bitwise NOT of `self`. Equivalent to `~self`. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([0, 2, 5, 255], dtype="int8").bitwise_not().numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor([True, False]).bitwise_not().numpy()) ``` """ if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported") return self.logical_not() if self.dtype == dtypes.bool else self ^ -1 def lshift(self, x:int) -> Tensor: """ Computes left arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype. Equivalent to `self << x`. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([1, 3, 31], dtype=dtypes.uint8).lshift(2).numpy()) ``` """ assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}" return self.mul(2 ** x) def rshift(self, x:int) -> Tensor: """ Computes right arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype. Equivalent to `self >> x`. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([4, 13, 125], dtype=dtypes.uint8).rshift(2).numpy()) ``` """ assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}" return self.idiv(2 ** x) def pow(self, x:Tensor|ConstType, reverse=False) -> Tensor: """ Computes power of `self` with `x`. Equivalent to `self ** x`. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-1, 2, 3]).pow(2.0).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-1, 2, 3]).pow(Tensor([-1.5, 0.5, 1.5])).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print((2.0 ** Tensor([-1, 2, 3])).numpy()) ``` """ base, exponent = self._broadcasted(x, reverse=reverse) # TODO: int pow if not base.is_floating_point(): raise RuntimeError("base needs to be float") # NOTE: pow(int, float) -> int ret = base._apply_uop(UOp.pow, exponent) return ret.round().cast(self.dtype) if not dtypes.is_float(self.dtype) else ret def maximum(self, x:Tensor|ConstType) -> Tensor: """ Computes element-wise maximum of `self` and `x`. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-1, 2, 3]).maximum(1).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy()) ``` """ return self._apply_broadcasted_uop(UOp.maximum, x) def minimum(self, x:Tensor|ConstType) -> Tensor: """ Computes element-wise minimum of `self` and `x`. ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-1, 2, 3]).minimum(1).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print(Tensor([-1, 2, 3]).minimum(Tensor([-4, -2, 9])).numpy()) ``` """ t, x = self._broadcasted(x) return t._inverse().maximum(x._inverse())._inverse() def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor: """ Return a tensor of elements selected from either `x` or `y`, depending on `self`. `output_i = x_i if self_i else y_i`. ```python exec="true" source="above" session="tensor" result="python" cond = Tensor([[True, True, False], [True, False, False]]) print(cond.where(1, 3).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) cond = Tensor.randn(2, 3) print(cond.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" print((cond > 0).where(cond, -float("inf")).numpy()) ``` """ if isinstance(x, Tensor): x, y = x._broadcasted(y) elif isinstance(y, Tensor): y, x = y._broadcasted(x) cond, x = self._broadcasted(x, match_dtype=False) cond, y = cond._broadcasted(y, match_dtype=False) return cond.cast(dtypes.bool)._apply_uop(UOp.where, *x._broadcasted(y)) def masked_fill(self:Tensor, mask:Tensor, value:Tensor|ConstType) -> Tensor: return mask.where(value, self) def copysign(self, other) -> Tensor: """ Return a tensor of with the magnitude of `self` and the sign of `other`, elementwise. """ # NOTE: torch always return in float, we return based on the broadcasting rule. other = self._broadcasted(other)[1] # TODO: remove other*0? return (other < 0).where(-self.abs(), self.abs()) + other*0 # ***** op wrappers ***** def __invert__(self) -> Tensor: return self.bitwise_not() def __lshift__(self, x:int) -> Tensor: return self.lshift(x) def __rshift__(self, x:int) -> Tensor: return self.rshift(x) def __pow__(self, x) -> Tensor: return self.pow(x) def __matmul__(self, x) -> Tensor: return self.matmul(x) def __rpow__(self, x) -> Tensor: return self.pow(x, True) def __rmatmul__(self, x) -> Tensor: return self.matmul(x, True) def __iadd__(self, x) -> Tensor: return self.assign(self.add(x)) def __isub__(self, x) -> Tensor: return self.assign(self.sub(x)) def __imul__(self, x) -> Tensor: return self.assign(self.mul(x)) def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x)) def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x)) def __ifloordiv__(self, x) -> Tensor: return self.assign(self.idiv(x)) def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x)) def __iand__(self, x) -> Tensor: return self.assign(self.bitwise_and(x)) def __ior__(self, x) -> Tensor: return self.assign(self.bitwise_or(x)) def __ixor__(self, x) -> Tensor: return self.assign(self.bitwise_xor(x)) def __ilshift__(self, x:int) -> Tensor: return self.assign(self.lshift(x)) def __irshift__(self, x:int) -> Tensor: return self.assign(self.rshift(x)) def __lt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, False) def __gt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, True) def ne(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.ne, x, False) def __eq__(self, x) -> Tensor: return self.eq(x) # type: ignore[override] # ***** functional nn ops ***** def linear(self, weight:Tensor, bias:Tensor|None=None) -> Tensor: """ Applies a linear transformation to `self` using `weight` and `bias`. See: https://pytorch.org/docs/stable/generated/torch.nn.Linear.html ```python exec="true" source="above" session="tensor" result="python" t = Tensor([[1, 2], [3, 4]]) weight = Tensor([[1, 2], [3, 4]]) bias = Tensor([1, 2]) print(t.linear(weight, bias).numpy()) ``` """ x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight) return x.add(bias) if bias is not None else x def sequential(self, ll:list[Callable[[Tensor], Tensor]]) -> Tensor: """ Applies a sequence of functions to `self` chaining the output of each function to the input of the next. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([1, 2, 3]) print(t.sequential([lambda x: x * 2, lambda x: x + 1]).numpy()) ``` """ return functools.reduce(lambda x,f: f(x), ll, self) def layernorm(self, axis:int|tuple[int,...]=-1, eps:float=1e-5) -> Tensor: """ Applies Layer Normalization over a mini-batch of inputs. - Described: https://paperswithcode.com/method/layer-normalization - Paper: https://arxiv.org/abs/1607.06450v1 ```python exec="true" source="above" session="tensor" result="python" t = Tensor.randn(8, 10, 16) * 2 + 8 print(t.mean().item(), t.std().item()) ``` ```python exec="true" source="above" session="tensor" result="python" t = t.layernorm() print(t.mean().item(), t.std().item()) ``` """ y = (self - self.mean(axis, keepdim=True)) return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt()) def batchnorm(self, weight:Tensor|None, bias:Tensor|None, mean:Tensor, invstd:Tensor, axis:int|tuple[int, ...]=1) -> Tensor: """ Applies Batch Normalization over a mini-batch of inputs. - Described: https://paperswithcode.com/method/batch-normalization - Paper: https://arxiv.org/abs/1502.03167 ```python exec="true" source="above" session="tensor" result="python" t = Tensor.randn(8, 4, 16, 16) * 2 + 8 print(t.mean().item(), t.std().item()) ``` ```python exec="true" source="above" session="tensor" result="python" t = t.batchnorm(None, None, t.mean(axis=(0,2,3)), t.var(axis=(0,2,3)).add(1e-5).rsqrt()) print(t.mean().item(), t.std().item()) ``` """ axis_ = argfix(axis) shape = tuple(s if ax in axis_ else 1 for ax, s in enumerate(self.shape)) x = self - mean.reshape(shape) if weight is not None: x = x * weight.reshape(shape) ret = x.mul(invstd.reshape(shape) if len(invstd.shape) == len(axis_) else invstd) return (ret + bias.reshape(shape)) if bias is not None else ret def dropout(self, p=0.5) -> Tensor: """ Applies dropout to `self`. NOTE: dropout is only applied when `Tensor.training` is `True`. - Described: https://paperswithcode.com/method/dropout - Paper: https://jmlr.org/papers/v15/srivastava14a.html ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.randn(2, 2) with Tensor.train(): print(t.dropout().numpy()) ``` """ if not Tensor.training or p == 0: return self return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p) # helper function commonly used for indexing def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1) -> Tensor: if not dtypes.is_int(self.dtype): raise RuntimeError(f"_one_hot_along_dim expects int index tensor, getting {self.dtype}") offset = self.ndim - self._resolve_dim(dim) - 1 return self == Tensor.arange(num_classes, device=self.device, requires_grad=False).reshape((num_classes,) + (1,) * offset) def one_hot(self, num_classes:int=-1) -> Tensor: """ Converts `self` to a one-hot tensor. `num_classes` defaults to -1, which means num_classes will be inferred as max(self) + 1. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([0, 1, 3, 3, 4]) print(t.one_hot(5).numpy()) ``` """ if not dtypes.is_int(self.dtype): raise RuntimeError(f"expect integer dtype, getting {self.dtype=}") if num_classes == -1: num_classes = (self.max()+1).item() return self[..., None]._one_hot_along_dim(num_classes).where(1, 0) def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor: """ Computes scaled dot-product attention. `self` is the query tensor, `key` is the key tensor, and `value` is the value tensor. - Described: https://paperswithcode.com/method/scaled - Paper: https://arxiv.org/abs/1706.03762v7 ```python exec="true" source="above" session="tensor" result="python" q = Tensor.randn(2, 4, 8) k = Tensor.randn(2, 4, 8) v = Tensor.randn(2, 4, 8) print(q.scaled_dot_product_attention(k, v).numpy()) ``` """ # NOTE: it also works when `key` and `value` have symbolic shape. assert all_int(self.shape), f"does not support symbolic shape {self.shape}" qk = self.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1]) # handle attention mask if is_causal: if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True") attn_mask = qk.ones_like(requires_grad=False, device=self.device, dtype=dtypes.bool).tril() if attn_mask is not None: if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf")) qk = qk + attn_mask return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor: if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}") reductions: dict[str, Callable[[Tensor], Tensor]] = {"mean": Tensor.mean, "sum": Tensor.sum, "none": lambda x: x} return reductions[reduction](self) def binary_crossentropy(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor: """ Computes the binary cross-entropy loss between `self` and `Y`. See: https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html ```python exec="true" source="above" session="tensor" result="python" t = Tensor([0.1, 0.9, 0.2]) Y = Tensor([0, 1, 0]) print(t.binary_crossentropy(Y).item()) ``` """ return (-Y*self.log() - (1-Y)*(1-self).log())._do_reduction(reduction) def binary_crossentropy_logits(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor: """ Computes the binary cross-entropy loss between `self` and `Y` where `self` is logits. See: https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html ```python exec="true" source="above" session="tensor" result="python" t = Tensor([-1, 2, -3]) Y = Tensor([0, 1, 0]) print(t.binary_crossentropy_logits(Y).item()) ``` """ return (self.maximum(0) - Y * self + (1 + self.abs().neg().exp()).log())._do_reduction(reduction) def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index:int=-1, label_smoothing=0.0, reduction:ReductionStr="mean") -> Tensor: """ Computes the sparse categorical cross-entropy loss between `self` and `Y`. NOTE: `self` is logits and `Y` is the target labels. NOTE: unlike PyTorch, this function expects the class axis to be -1 See: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html ```python exec="true" source="above" session="tensor" result="python" t = Tensor([[-1, 2, -3], [1, -2, 3]]) Y = Tensor([1, 2]) print(t.sparse_categorical_crossentropy(Y).item()) ``` """ assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]" assert reduction in ("mean", "sum", "none"), "reduction must be one of ['mean', 'sum', 'none']" log_probs, loss_mask = self.log_softmax(), (Y != ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool) y_counted = Y.to(self.device).flatten().reshape(-1, 1)._one_hot_along_dim(self.shape[-1]) y = (y_counted * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1]) smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask) unreduced = ((1 - label_smoothing) * (log_probs * y).sum(-1) + smoothing) # NOTE: because of ignore_index, we can't use Tensor.mean (so can't use `_do_reduction` here) return -(unreduced.sum() / loss_mask.sum() if reduction == "mean" else (unreduced.sum() if reduction == "sum" else unreduced)) def cross_entropy(self, Y:Tensor, reduction:ReductionStr="mean", label_smoothing:float=0.0) -> Tensor: """ Compute the cross entropy loss between input logits and target. NOTE: `self` are logits and `Y` are the target labels or class probabilities. See: https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html ```python exec="true" source="above" session="tensor" result="python" t = Tensor([[-1, 2, -3], [1, -2, 3]]) Y = Tensor([1, 2]) print(t.cross_entropy(Y).item()) ``` ```python exec="true" source="above" session="tensor" result="python" t = Tensor([[-1, 2, -3], [1, -2, 3]]) Y = Tensor([1, 2]) print(t.cross_entropy(Y, reduction='none').numpy()) ``` """ assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]" Y = Y.one_hot(num_classes=cast(int, self.shape[1])) if Y.ndim < 2 else Y Y = (1 - label_smoothing)*Y + label_smoothing / cast(int, Y.shape[1]) ret = -self.log_softmax(axis=1).mul(Y).sum(axis=1) return ret._do_reduction(reduction) def nll_loss(self, Y:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean") -> Tensor: """ Compute the negative log likelihood loss between log-probabilities and target labels. NOTE: `self` is log-probabilities and `Y` is the Y labels or class probabilities. See: https://pytorch.org/docs/stable/generated/torch.nn.functional.nll_loss.html ```python exec="true" source="above" session="tensor" result="python" t = Tensor([[-1, 2, -3], [1, -2, 3]]) Y = Tensor([1, 2]) print(t.log_softmax().nll_loss(Y).item()) ``` ```python exec="true" source="above" session="tensor" result="python" t = Tensor([[-1, 2, -3], [1, -2, 3]]) Y = Tensor([1, 2]) print(t.log_softmax().nll_loss(Y, reduction='none').numpy()) ``` """ weight = Tensor.ones_like(Y, requires_grad=False) if weight is None else weight[Y] masked_weight = weight if ignore_index is None else weight * (Y != ignore_index) nll = -self.gather(1, Y.unsqueeze(1)).squeeze(1) * masked_weight return nll.sum() / masked_weight.sum() if reduction == "mean" else nll._do_reduction(reduction) # ***** Tensor Properties ***** @property def ndim(self) -> int: """ Returns the number of dimensions in the tensor. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([[1, 2], [3, 4]]) print(t.ndim) ``` """ return len(self.shape) def numel(self) -> sint: """ Returns the total number of elements in the tensor. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) print(t.numel()) ``` """ return prod(self.shape) def element_size(self) -> int: """ Returns the size in bytes of an individual element in the tensor. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([5], dtype=dtypes.int16) print(t.element_size()) ``` """ return self.dtype.itemsize def nbytes(self) -> int: """ Returns the total number of bytes of all elements in the tensor. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([8, 9], dtype=dtypes.float) print(t.nbytes()) ``` """ return self.numel() * self.element_size() def is_floating_point(self) -> bool: """ Returns `True` if the tensor contains floating point types, i.e. is one of `dtype.float64`, `dtype.float32`, `dtype.float16`, `dtype.bfloat16`. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([8, 9], dtype=dtypes.float32) print(t.is_floating_point()) ``` """ return dtypes.is_float(self.dtype) def size(self, dim:int|None=None) -> sint|tuple[sint, ...]: """ Return the size of the tensor. If `dim` is specified, return the length along dimension `dim`. Otherwise return the shape of the tensor. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([[4, 5, 6], [7, 8, 9]]) print(t.size()) ``` ```python exec="true" source="above" session="tensor" result="python" print(t.size(dim=1)) ``` """ return self.shape if dim is None else self.shape[dim] # ***** cast ops ***** def llvm_bf16_cast(self, dtype:DTypeLike) -> Tensor: # hack for devices that don't support bfloat16 assert self.dtype == dtypes.bfloat16 return self.to("LLVM").cast(dtype) def cast(self, dtype:DTypeLike) -> Tensor: """ Casts `self` to the given `dtype`. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([-1, 2.5, 3], dtype=dtypes.float) print(t.dtype, t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" t = t.cast(dtypes.int32) print(t.dtype, t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" t = t.cast(dtypes.uint8) print(t.dtype, t.numpy()) ``` """ if (dt:=to_dtype(dtype)) in {dtypes.uint8, dtypes.uint16} and dtypes.is_float(self.dtype): # NOTE: values within the int32 range and outside the unsigned dtype range will cause values to wrap around return self._apply_uop(UOp.cast, dtype=dtypes.int32)._apply_uop(UOp.cast, dtype=dt) return self if self.dtype == dt else self._apply_uop(UOp.cast, dtype=dt) def bitcast(self, dtype:DTypeLike) -> Tensor: """ Bitcasts `self` to the given `dtype` of the same itemsize. `self` must not require a gradient. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([-1, 2, 3], dtype=dtypes.int32) print(t.dtype, t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" t = t.bitcast(dtypes.uint32) print(t.dtype, t.numpy()) ``` """ if self.requires_grad: raise RuntimeError("can't backprop through bitcast") dt = to_dtype(dtype) if (ns:=dt.itemsize) != (os:=self.dtype.itemsize) and (self.shape[-1]*os) % ns != 0: raise RuntimeError("unsupported size in bitcast") if (not isinstance(self.device, str) or not self.device.startswith("DISK")) and ns != os: new_uint, old_uint = to_dtype(f"uint{8*ns}"), to_dtype(f"uint{8*os}") tmp = self.bitcast(old_uint) if ns > os: return functools.reduce(Tensor.add, (tmp[..., i::ns//os].cast(new_uint) << 8*i*os for i in range(ns//os))).bitcast(dtype) return Tensor.stack(*(tmp>>8*i*ns for i in range(os//ns)), dim=-1).flatten(-2).cast(new_uint).bitcast(dtype) return self._apply_uop(UOp.bitcast, dtype=dt) if self.dtype != dt else self def float(self) -> Tensor: """ Convenience method to cast `self` to a `float32` Tensor. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([-1, 2, 3], dtype=dtypes.int32) print(t.dtype, t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" t = t.float() print(t.dtype, t.numpy()) ``` """ return self.cast(dtypes.float32) def half(self) -> Tensor: """ Convenience method to cast `self` to a `float16` Tensor. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([-1, 2, 3], dtype=dtypes.int32) print(t.dtype, t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" t = t.half() print(t.dtype, t.numpy()) ``` """ return self.cast(dtypes.float16) def int(self) -> Tensor: """ Convenience method to cast `self` to a `int32` Tensor. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([-1.5, -0.5, 0.0, 0.5, 1.5]) print(t.dtype, t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" t = t.int() print(t.dtype, t.numpy()) ``` """ return self.cast(dtypes.int32) def bool(self) -> Tensor: """ Convenience method to cast `self` to a `bool` Tensor. ```python exec="true" source="above" session="tensor" result="python" t = Tensor([-1, 0, 1]) print(t.dtype, t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" t = t.bool() print(t.dtype, t.numpy()) ``` """ return self.cast(dtypes.bool) # *** image Tensor function replacements *** def image_dot(self, w:Tensor, dtype:DTypeLike|None=None) -> Tensor: # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1) x, dx, dw = self, self.ndim, w.ndim if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D") if x.shape[-1] != w.shape[-min(w.ndim, 2)]: raise RuntimeError(f"cannot image_dot {x.shape} and {w.shape}") bs, groups, cin, cout = prod(self.shape[0:-2]), prod(w.shape[0:-2]), w.shape[-2], w.shape[-1] out_shape_t = self.shape[0:-2] + (cout,-1) if len(self.shape) > 1 else (cout, ) # NOTE: with NHWC we can remove the transposes # bs x groups*cin x H x W cx = self.transpose(self.ndim-1, self.ndim-2).reshape((bs//groups, groups*cin, -1, 1)) # groups*cout x cin x H, W cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1)) return cx.image_conv2d(cw, groups=groups, dtype=dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2) def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, dtype=None) -> Tensor: base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W) # hack for non multiples of 4 on cin if cin % 4 != 0 and not (cin == 1 and groups%4 == 0): x = x.reshape(bs, groups, cin, iy, ix) # do this always? added_input_channels = 4 - (cin % 4) w = w.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(w.ndim))) x = x.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(x.ndim))) cin = cin + added_input_channels x = x.reshape(bs, groups*cin, iy, ix) # hack for non multiples of 4 on rcout added_output_channels = 0 if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0): added_output_channels = 4 - (rcout % 4) rcout += added_output_channels cout = groups * rcout w = w.pad(tuple((0, added_output_channels) if i == 1 else None for i in range(w.ndim))) # packed (note: flipping bs and iy would make the auto-padding work) x = x.permute(0,2,3,1) cin_last = iy == 1 and ix == 1 if cin == 1: w = w.reshape(cout//4,4,H,W).permute(0,2,3,1) elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3) else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1) # contiguous creates the image, and early realize static weights (TODO: test for the static weight) if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4))) x, w = x.contiguous(), w.contiguous() # expand out rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1 cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1] x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo) if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo) else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4) # prepare input x = x.permute(0,3,4,5,1,2).pad(self._resolve_pool_pads(padding,2))._pool((H,W), stride, dilation)# -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W) x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W) # prepare weights w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W)) # the conv! ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), dtype=dtype) # undo hack for non multiples of 4 on C.rcout if added_output_channels != 0: ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels] cout = groups * (rcout - added_output_channels) # NCHW output ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2) return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1)) P = ParamSpec("P") T = TypeVar("T") def _metadata_wrapper(fn: Callable[P, T]) -> Callable[P, T]: def _wrapper(*args: P.args, **kwargs: P.kwargs) -> T: if _METADATA.get() is not None: return fn(*args, **kwargs) if TRACEMETA >= 2: caller_frame = sys._getframe(frame := 1) caller_module = caller_frame.f_globals.get("__name__", None) caller_func = caller_frame.f_code.co_name if caller_module is None: return fn(*args, **kwargs) # if its called from nn we want to step up frames until we are out of nn while caller_module.startswith("tinygrad.nn") and "optim" not in caller_module: caller_frame = sys._getframe(frame := frame + 1) caller_module = caller_frame.f_globals.get("__name__", None) if caller_module is None: return fn(*args, **kwargs) # if its called from a lambda in tinygrad we want to look two more frames up if caller_module.startswith("tinygrad") and caller_func == "": caller_frame = sys._getframe(frame := frame + 2) caller_module = caller_frame.f_globals.get("__name__", None) if caller_module is None: return fn(*args, **kwargs) caller_func = caller_frame.f_code.co_name caller_lineno = caller_frame.f_lineno caller = f"{caller_module}:{caller_lineno}::{caller_func}" else: caller = "" token = _METADATA.set(Metadata(name=fn.__name__, caller=caller)) ret = fn(*args, **kwargs) _METADATA.reset(token) return ret return _wrapper if TRACEMETA >= 1: for name, fn in inspect.getmembers(Tensor, inspect.isfunction): if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential", "gradient"]: continue setattr(Tensor, name, functools.wraps(fn)(_metadata_wrapper(fn)))