# ruff: noqa: E501, A001, A002, A006
# A001 Variable `input` is shadowing a Python builtin
# A002 Function argument `input` is shadowing a Python builtin
# A006 Lambda argument `input` is shadowing a Python builtin
from tinygrad import Tensor, dtypes, Device
from tinygrad.helpers import getenv, prod
import torch.lib
TORCH_DEBUG = getenv("TORCH_DEBUG")
import torch, pathlib, math, operator, functools, inspect
torch.autograd.grad_mode.set_multithreading_enabled(False)
from tinygrad.dtype import _from_torch_dtype, _to_torch_dtype
# https://pytorch.org/docs/stable/torch.compiler_ir.html
def _from_torch_device(device: torch.device): return f"{Device.DEFAULT}:{device.index or 0}"
def _to_torch_device(device: str): return torch.device("tiny", int(device.partition(":")[2] or 0))
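# e.g. torch.device("tiny", 0) maps to f"{Device.DEFAULT}:0" ("CUDA:0", "METAL:0", ... depending on the
# default tinygrad backend), and the reverse mapping keeps only the device index.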
import torch.utils.cpp_extension
mod = torch.utils.cpp_extension.load(name="custom_device_extension", sources=[str(pathlib.Path(__file__).parent / "wrapped_tensor.cpp")])
def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x, _to_torch_dtype(x.dtype), _to_torch_device(x.device).index)
def unwrap(x:torch.Tensor) -> Tensor:
  assert isinstance(x, torch.Tensor), f"expected a torch.Tensor, got {type(x)}"
  return mod.unwrap(x)
class TinyBackend:
  def is_initialized(self): return True
  def is_available(self): return True
  def current_device(self): return 0
  def _is_in_bad_fork(self): return False
  def manual_seed_all(self, seed: int): Tensor.manual_seed(seed)
  def device_count(self): return getenv("GPUS", 1) # TODO: device count in tiny?
torch.utils.rename_privateuse1_backend("tiny")
torch._register_device_module("tiny", TinyBackend())
torch.utils.generate_methods_for_privateuse1_backend()
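# With the backend renamed and registered, "tiny" works like any other torch device string; a rough
# usage sketch (illustrative only, after importing this module):
#   x = torch.ones(4, device="tiny")   # storage is a tinygrad Tensor wrapped by wrapped_tensor.cpp
#   y = (x + 1).cpu()                  # ops dispatch to the "privateuseone" kernels registered below
# generate_methods_for_privateuse1_backend also adds the Tensor.is_tiny property and .tiny() method
# used throughout this file.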
aten = torch.ops.aten
# track view relationships for in place operations
def is_view(tensor: Tensor): return hasattr(tensor, "_view_base")
def canonical_base(view: Tensor): return getattr(view, "_view_base", view)
def derived_views(base: Tensor): return [t for tref in getattr(base, "_views", set()) if (t:=tref()) is not None]
def wrap_view_op(fn):
  def _wrap(*args,**kwargs):
    args = [unwrap(x) if isinstance(x, torch.Tensor) else x for x in args]
    kwargs = {k:unwrap(v) if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()}
    ret = fn(*args,**kwargs)
    ret._view_base = base = canonical_base(args[0])
    if not hasattr(base, "_views"): base._views = set()
    base._views.add(weakref.ref(ret))
    return wrap(ret)
  return _wrap
view_ops = {
"aten.view": Tensor.reshape,
"aten._unsafe_view": Tensor.reshape, # when are views unsafe, and do we care?
"aten.view.dtype": lambda self,dtype: self.bitcast(_from_torch_dtype(dtype)),
"aten.expand": Tensor.expand,
"aten.t": Tensor.transpose,
"aten.transpose.int": Tensor.transpose,
"aten.squeeze.dim": Tensor.squeeze,
"aten.unsqueeze": Tensor.unsqueeze,
"aten.detach": Tensor.detach,
}
for k,v in view_ops.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_view_op(v))
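# The ops above return tinygrad Tensors that alias their input's storage, so _view_base/_views record the
# relationship for realize_with_views below. Rough flow (illustrative):
#   base = torch.zeros(4, device="tiny"); v = base.view(2, 2)  # unwrapped v gets _view_base -> unwrapped base
#   v.fill_(1)                                                 # inplace_fn sees a view target, realizes base and rebuilds its views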
# in place operations with views
def realize_with_views(self: Tensor, views: list[Tensor]):
  if not self.lazydata.st.contiguous: self.replace(self.contiguous())
  self.replace(self.clone().realize())
  for v in views:
    ret = self
    st = ShapeTracker(self.lazydata.st.views + v.lazydata.st.views) # TODO: is this right?
    for mo in cached_to_movement_ops(self.shape, st): ret = apply_mop(ret, mo)
    v.replace(ret)
def maybe_realize_storage(self: Tensor) -> bool:
  if realize:=is_view(self): realize_with_views((base:=canonical_base(self)), derived_views(base))
  return realize
def inplace_fn(outvars: str|list[str]):
  if type(outvars) is str: outvars = [outvars]
  def decorator(fn):
    sig = inspect.signature(fn)
    def wrapper(*args, **kwargs):
      bound = sig.bind(*args, **kwargs)
      outs = [kwargs.get(v, bound.arguments.get(v)) for v in outvars]
      outs = [unwrap(o) if isinstance(o, torch.Tensor) else o for o in outs]
      realize = any(maybe_realize_storage(o) for o in outs)
      ret = fn(*args, **kwargs)
      if realize: Tensor.realize(*(o for o in outs))
      return ret
    return wrapper
  return decorator
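# inplace_fn("out") marks which arguments torch treats as mutated storage: before the kernel runs, any of
# them that are views get their base storage realized via maybe_realize_storage, and afterwards they are
# realized so the mutation is visible through every alias. Names are resolved with inspect.signature().bind(),
# so they are found whether the caller passes them positionally or as keywords.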
# *** bad functions on CPU ***
@torch.library.impl("aten::_index_put_impl_", "privateuseone")
@inplace_fn("self")
def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False):
  # TODO: move to tinygrad
  ret = aten._index_put_impl_(self.cpu(), [x.cpu() if isinstance(x, torch.Tensor) else None for x in indices], values.cpu(), accumulate, unsafe).to(self.device)
  return wrap(unwrap(self).assign(unwrap(ret)))
@torch.library.impl("aten::index.Tensor", "privateuseone")
def index_tensor(x, y):
  return aten.index(x.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in y]).to(x.device)
@torch.library.impl("aten::index_put", "privateuseone")
def index_put(self, indices, values, accumulate=False):
  return aten.index_put(self.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in indices], values.cpu(), accumulate).tiny()
@torch.library.impl("aten::randperm.generator_out", "privateuseone")
def randperm_generator(n, generator=None, out=None): out.copy_(torch.randperm(n, generator=generator, device="cpu").tiny())
@torch.library.impl("aten::cummax", "privateuseone")
def cummax(self, dim):
  # TODO: support cummax with indices to match torch
  cummax, indices = aten.cummax(self.cpu(), dim)
  return (cummax.tiny(), indices.tiny())
@torch.library.impl("aten::nonzero", "privateuseone")
# TODO: move to tinygrad
def nonzero(self): return aten.nonzero(self.cpu()).tiny()
def upsample_backward(grad_out, output_size, input_size, *args, f=None): return f(grad_out.cpu(), output_size, input_size, *args).tiny()
for i in [
"upsample_linear1d_backward", "upsample_nearest1d_backward", "_upsample_nearest_exact1d_backward",
"upsample_nearest2d_backward", "_upsample_nearest_exact2d_backward",
"upsample_nearest3d_backward", "_upsample_nearest_exact3d_backward",
"upsample_trilinear3d_backward", "upsample_bilinear2d_backward"
]:
torch.library.impl(f"aten::{i}", "privateuseone")(functools.partial(upsample_backward, f=getattr(aten, i)))
# *** end bad functions on CPU ***
@torch.library.impl("aten::zero_", "privateuseone")
@inplace_fn("x")
def zero_(x):
if TORCH_DEBUG: print(f"zero_ {x.shape}")
tt = unwrap(x)
# NOTE: unconditional contiguous covers if x is contiguous (match it) or if x is view (realize for inplace)
# TODO: consolidate
tt.assign(tt.zeros_like().contiguous())
@torch.library.impl("aten::fill_.Scalar", "privateuseone")
@inplace_fn("x")
def fill_scalar(x, y):
if TORCH_DEBUG: print(f"fill_.Scalar {x.shape} {y}")
tt = unwrap(x)
tt.assign(tt.full_like(y).contiguous())
@torch.library.impl("aten::_local_scalar_dense", "privateuseone")
def _local_scalar_dense(tensor): return unwrap(tensor).item()
@functools.lru_cache(None)
def cached_to_movement_ops(shape, st) -> list:
  mops = to_movement_ops(st)
  if mops[0] == (MovementOps.RESHAPE, shape): mops = mops[1:]
  return mops
from tinygrad.shape.shapetracker import ShapeTracker, View
from extra.to_movement_ops import to_movement_ops, apply_mop, MovementOps
@torch.library.impl("aten::as_strided", "privateuseone")
def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None):
  storage_offset = storage_offset or tensor.storage_offset()
  @wrap_view_op
  def _as_strided(tensor:Tensor, size, stride, storage_offset=None):
    # multiple as_strided do not compound
    base = canonical_base(tensor)
    # TODO: this is heavyweight
    st = ShapeTracker(base.lazydata.st.views + (View.create(tuple(size), tuple(stride), storage_offset),))
    ret = base
    if TORCH_DEBUG >= 1: print("**** as_strided", tensor.shape, size, stride, st)
    if prod(size) == 1: return ret.flatten()[storage_offset].reshape(size)
    for mo in cached_to_movement_ops(tuple(base.shape), st): ret = apply_mop(ret, mo)
    return ret
  return _as_strided(tensor, size, stride, storage_offset)
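# Illustrative example of the path above: x.as_strided((2, 2), (1, 1)) overlaps elements, which no single
# tinygrad movement op expresses, so the requested View is appended to the base's ShapeTracker and then
# replayed as a sequence of elementary movement ops (reshape/expand/permute/shrink/...) by to_movement_ops.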
@torch.library.impl("aten::empty_strided", "privateuseone")
def empty_strided(size, stride, dtype, layout=None, device=None, pin_memory=False):
if TORCH_DEBUG: print(f"empty_strided {size=} {stride=} {dtype=} {layout=} {device=} {pin_memory=}")
ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype), device=_from_torch_device(device)).contiguous()
# TODO: should return with requested strides
return wrap(ret)
@torch.library.impl("aten::empty.memory_format", "privateuseone")
def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
if TORCH_DEBUG: print(f"empty.memory_format {size=} {dtype=} {layout=} {device=} {pin_memory=} {memory_format=}")
ret = Tensor.empty(*size, dtype=_from_torch_dtype(dtype or torch.get_default_dtype()), device=_from_torch_device(device)).contiguous()
return wrap(ret)
@torch.library.impl("aten::max_pool2d_with_indices", "privateuseone")
def max_pool2d_with_indices(self:torch.Tensor, kernel_size:tuple[int, ...], stride=None, padding=0, dilation=1, ceil_mode=False):
  # TODO: support stride [] in tinygrad?
  if stride is not None and len(stride) == 0: stride = None
  ret, idx = unwrap(self).max_pool2d(kernel_size, stride, dilation, padding, ceil_mode, return_indices=True)
  return (wrap(ret), wrap(idx.cast(dtypes.int64)))
@torch.library.impl("aten::max_pool2d_with_indices_backward", "privateuseone")
def max_pool2d_with_indices_backward(grad_out:torch.Tensor, self:torch.Tensor, kernel_size:tuple[int, ...], stride=None, padding=0, dilation=1, ceil_mode=False, indices=None):
  return wrap(Tensor.max_unpool2d(unwrap(grad_out), unwrap(indices), output_size=unwrap(self).shape))
@torch.library.impl("aten::max_unpool2d", "privateuseone")
def max_unpool2d(self:torch.Tensor, indices:torch.Tensor, output_size):
  return wrap(unwrap(self).max_unpool2d(unwrap(indices), output_size=output_size))
@torch.library.impl("aten::arange", "privateuseone")
def arange(end, dtype=None, device=None, pin_memory=None):
  return wrap(Tensor.arange(0, end, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())))
@torch.library.impl("aten::arange.start", "privateuseone")
def arange_start(start, end, dtype=None, device=None, pin_memory=None):
  return wrap(Tensor.arange(start, end, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())))
@torch.library.impl("aten::arange.start_step", "privateuseone")
def arange_start_step(start, end, step, dtype=None, device=None, pin_memory=None):
  return wrap(Tensor.arange(start, end, step, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())))
@torch.library.impl("aten::convolution_overrideable", "privateuseone")
def convolution_overrideable(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups):
  if TORCH_DEBUG >= 1:
    print(f"convolution {input.shape=} {weight.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}")
  input, weight, bias = unwrap(input), unwrap(weight), unwrap(bias) if bias is not None else None
  # TODO: fix test_biased_conv2d fails without realize()
  if not transposed: return wrap(input.conv2d(weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding).realize())
  return wrap(input.conv_transpose2d(weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding, output_padding=output_padding).realize())
@torch.library.impl("aten::convolution_backward_overrideable", "privateuseone")
def convolution_backward_overrideable(grad_out, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask):
  if TORCH_DEBUG >= 1:
    print(f"convolution_backward {input.shape=} {weight.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}")
  grad_out, input, weight, bias = unwrap(grad_out), unwrap(input), unwrap(weight), Tensor.zeros(weight.shape[0], device=_from_torch_device(weight.device))
  if not transposed: out = Tensor.conv2d(input, weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding)
  else:
    bias = Tensor.zeros(weight.shape[1] * groups)
    out = Tensor.conv_transpose2d(input, weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding, output_padding=output_padding)
  grads = out.gradient(*[t for t,m in zip([input, weight, bias], output_mask) if m], gradient=grad_out)
  return tuple([wrap(grads.pop(0)) if m else None for m in output_mask])
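# NOTE: rather than a hand-written convolution backward, the forward conv is rebuilt above and
# Tensor.gradient derives grads for whichever of (input, weight, bias) output_mask requests, seeded with
# grad_out. The pooling and padding backwards below use the same recompute-forward-then-gradient pattern.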
@torch.library.impl("aten::slice.Tensor", "privateuseone")
@wrap_view_op
def slice_tensor(self, dim=0, start=None, end=None, step=1):
  slices = [slice(None)] * self.ndim
  slices[dim] = slice(start, end, step)
  return self[slices]
@torch.library.impl("aten::slice_backward", "privateuseone")
def slice_backward(grad_out, input_sizes, dim, start, end, step):
  grad_input = Tensor.zeros(input_sizes).contiguous()
  slices = [slice(None)] * len(input_sizes)
  slices[dim] = slice(start, end, step)
  grad_input[slices] = unwrap(grad_out)
  return wrap(grad_input)
@torch.library.impl("aten::select_backward", "privateuseone")
def select_backward(grad_out, input_sizes, dim, index):
  grad_input = Tensor.zeros(input_sizes).contiguous()
  slices = [slice(None)] * len(input_sizes)
  slices[dim] = index
  grad_input[slices] = unwrap(grad_out)
  return wrap(grad_input)
def avg_pool(self, kernel_size, stride=[], padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None):
  return wrap(unwrap(self).avg_pool2d(kernel_size, stride if stride != [] else None, padding=padding, ceil_mode=ceil_mode, count_include_pad=count_include_pad))
def avg_pool_backward(grad_out, self, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None):
  self, grad_out = unwrap(self), unwrap(grad_out)
  out = Tensor.avg_pool2d(self, kernel_size, stride if stride != [] else None, dilation=1, padding=padding, ceil_mode=ceil_mode, count_include_pad=count_include_pad)
  return wrap(out.gradient(self, gradient=grad_out)[0])
for dim in [2, 3]:
torch.library.impl(f"aten::avg_pool{dim}d", "privateuseone")(avg_pool)
torch.library.impl(f"aten::avg_pool{dim}d_backward", "privateuseone")(avg_pool_backward)
def pad_forward(self, padding, mode=None): return wrap(Tensor.pad(unwrap(self), padding, mode=mode))
def pad_backward(grad_out, self, padding, mode):
  self, grad_out = unwrap(self), unwrap(grad_out)
  out = Tensor.pad(self, padding, mode=mode)
  return wrap(out.gradient(self, gradient=grad_out)[0])
for dim in [1, 2, 3]:
  for pad_type, mode in [("replication", "replicate"), ("reflection", "reflect")]:
    torch.library.impl(f"aten::{pad_type}_pad{dim}d", "privateuseone")(functools.partial(pad_forward, mode=mode))
    torch.library.impl(f"aten::{pad_type}_pad{dim}d_backward", "privateuseone")(functools.partial(pad_backward, mode=mode))
def upsample(self, size, align_corners=False, mode=None): return wrap(Tensor.interpolate(unwrap(self), size, mode=mode, align_corners=align_corners))
for i,pre in enumerate(["", "bi", "tri"]):
torch.library.impl(f"aten::upsample_{pre}linear{i+1}d", "privateuseone")(functools.partial(upsample, mode="linear"))
torch.library.impl(f"aten::upsample_nearest{i+1}d", "privateuseone")(functools.partial(upsample, mode="nearest"))
torch.library.impl(f"aten::_upsample_nearest_exact{i+1}d", "privateuseone")(functools.partial(upsample, mode="nearest-exact"))
@torch.library.impl("aten::scatter_add.out", "privateuseone")
@inplace_fn("out")
def scatter_add(self, dim, index, src, out):
  self, index, src, out = unwrap(self), unwrap(index), unwrap(src), unwrap(out)
  if self.shape == (): return wrap(out.assign(src))
  return wrap(out.assign(Tensor.scatter_reduce(self, dim, index, src, reduce='sum')))
@torch.library.impl("aten::_copy_from", "privateuseone")
def _copy_from(src: torch.Tensor, dest, non_blocking=False):
  realize = dest.is_tiny and maybe_realize_storage(unwrap(dest))
  cast_dtype = _from_torch_dtype(dest.dtype)
  if src.is_tiny and dest.is_tiny:
    to_device = _from_torch_device(dest.device)
    src,dest = unwrap(src),unwrap(dest)
    # TODO we need to properly match dest shape and strides, not blindly assign
    if dest.lazydata.st.contiguous or dest.lazydata.is_realized: src = src.contiguous() # this only solves some cases
    dest.assign(src.cast(cast_dtype).to(to_device))
    if realize: Tensor.realize(dest)
  elif src.is_tiny and dest.is_cpu:
    # TODO: is there a better way?
    dest.resize_(src.numel()).resize_(src.shape)
    dest.copy_(torch.from_numpy(unwrap(src).cast(cast_dtype).numpy()))
  elif src.is_cpu and dest.is_tiny:
    to_device = _from_torch_device(dest.device)
    # TODO we need to properly match dest shape and strides, not blindly assign
    unwrap(dest).assign(Tensor(src.numpy()).cast(cast_dtype).to(to_device))
    if realize: Tensor.realize(unwrap(dest))
  else:
    raise NotImplementedError(f"can't copy from {src.device} -> {dest.device}")
@torch.library.impl("aten::cat.out", "privateuseone")
@inplace_fn("out")
def cat_out(tensors, dim=0, out=None):
  unwrap(out).assign(Tensor.cat(*[unwrap(x) for x in tensors], dim=dim))
@torch.library.impl("aten::topk.values", "privateuseone")
@inplace_fn(["values", "indices"])
def topk_values(input, k, dim=None, largest=True, sorted=True, values=None, indices=None):
  out_values, out_indices = unwrap(input).topk(k, dim if dim is not None else -1, largest, sorted)
  unwrap(values).assign(out_values)
  unwrap(indices).assign(out_indices.cast(dtypes.int64))
  return wrap(out_values), wrap(out_indices)
@torch.library.impl("aten::sort.values_stable", "privateuseone")
@inplace_fn(["values", "indices"])
def sort_values(input, dim=-1, descending=False, stable=True, values=None, indices=None):
  out_values, out_indices = unwrap(input).sort(dim, descending)
  unwrap(values).assign(out_values)
  unwrap(indices).assign(out_indices.cast(dtypes.int64))
  return wrap(out_values), wrap(out_indices)
# register some decompositions
from torch._decomp import get_decompositions
decomps = [
aten.native_batch_norm, aten.native_batch_norm_backward,
aten.native_layer_norm_backward,
aten.addmm,
aten.addcmul,
aten.addcdiv,
aten._log_softmax_backward_data,
aten.threshold_backward,
aten.softplus_backward,
aten.elu, # elu has a scale + input_scale param
aten.elu_backward,
aten.softplus,
aten.threshold,
aten.nll_loss_forward,
aten.nll_loss_backward,
# AttributeError: 'int' object has no attribute '_broadcasted'
aten.sigmoid_backward,
aten.tanh_backward,
aten.sinc,
aten._prelu_kernel,
aten.softshrink,
aten.hardshrink,
aten.log_sigmoid_forward,
aten.isneginf,
aten.isposinf,
aten.nan_to_num,
aten.logit,
aten.rsub,
aten.index_select,
aten.native_dropout, aten.native_dropout_backward,
aten._softmax_backward_data, aten.embedding_dense_backward,
aten.linalg_vector_norm,
aten.binary_cross_entropy, aten.binary_cross_entropy_backward,
aten.upsample_nearest2d.out,
# activations
aten.hardswish, aten.hardswish_backward,
aten.hardtanh, aten.hardtanh_backward,
aten.gelu, aten.gelu_backward,
aten.logical_and,
aten.randint,
aten.eye,
aten.hardsigmoid_backward,
aten.leaky_relu_backward,
aten.nll_loss2d_forward,
# NOTE: many of these don't work or cause infinite loops
#aten.var_mean,
#aten.var,
#aten.rsqrt,
#aten.max_pool2d_with_indices,
# NOTE: these are prims
#aten.digamma,
#aten.erfinv,
#aten.lgamma,
# this needs copy_strided
#aten.lerp,
]
for k,v in get_decompositions(decomps).items():
  key = str(k._schema).split("(")[0]
  if TORCH_DEBUG >= 2: print("register decomp for", k)
  torch.library.impl(key, "privateuseone")(v)
# NOTE: we should only implement the "out" form, it should be 0 overhead
# TODO: due to issue with empty / is_realized, it is slow to use assign so we use replace
# the goal is to express as many ops as we can in this form
simple_tensor_methods = [
# unary (ish)
"log", "log2", "sqrt", "rsqrt", "sign", "silu", "hardsigmoid", "exp", "exp2", "neg", "reciprocal", "bitwise_not",
"sigmoid", "clamp", "mish", "erf", "leaky_relu",
# trig
"acos", "acosh", "cos", "cosh", "asin", "asinh", "sin", "sinh", "atan", "atanh", "tan", "tanh",
# rounding
"ceil", "round", "floor", "trunc",
# binary
"mul", "div", "maximum", "minimum", "copysign",
# modify
"tril", "triu",
# reduce
"all", "any", "argmax", "argmin", "cumsum", "cumprod",
# complex
"avg_pool2d", "linspace"]
tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_methods}, **{
"aten.add.out": lambda input,other,alpha=1: input+alpha*other,
"aten.sub.out": lambda input,other,alpha=1: input-alpha*other, # NOTE: this is also needed to handle reverse
"aten.div.out_mode": Tensor.div,
"aten.mul.out": operator.mul,
"aten.bmm.out": operator.matmul,
# NOTE: because these methods have a name with "Tensor" in them, they can't go in simple tensor methods
"aten.remainder.Tensor_out": Tensor.mod,
"aten.pow.Tensor_Tensor_out": Tensor.pow,
"aten.pow.Tensor_Scalar_out": Tensor.pow,
"aten.pow.Scalar_out": lambda input,exponent: input**exponent,
"aten.bitwise_and.Tensor_out": Tensor.bitwise_and,
"aten.bitwise_or.Tensor_out": Tensor.bitwise_or,
"aten.bitwise_xor.Tensor_out": Tensor.bitwise_xor,
"aten.eq.Tensor_out": Tensor.eq, "aten.eq.Scalar_out": Tensor.eq,
"aten.ne.Tensor_out": Tensor.ne, "aten.ne.Scalar_out": Tensor.ne,
"aten.ge.Tensor_out": Tensor.__ge__, "aten.ge.Scalar_out": Tensor.__ge__,
"aten.gt.Tensor_out": Tensor.__gt__, "aten.gt.Scalar_out": Tensor.__gt__,
"aten.lt.Tensor_out": Tensor.__lt__, "aten.lt.Scalar_out": Tensor.__lt__,
"aten.le.Tensor_out": Tensor.__le__, "aten.le.Scalar_out": Tensor.__le__,
"aten.clamp_max.Tensor_out": lambda input,max_: input.clamp(max_=max_),
"aten.clamp_min.Tensor_out": lambda input,min_: input.clamp(min_=min_),
"aten.fmod.Tensor_out": lambda input,other: input-input.div(other, rounding_mode="trunc")*other,
# TODO: this might result in overflow issues
"aten.round.decimals_out": lambda self,decimals: (self*10**decimals).round()/10**decimals,
# TODO: support this in tinygrad
"aten.bitwise_left_shift.Tensor_out": lambda x,y: x*(2**y),
"aten.bitwise_right_shift.Tensor_out": lambda x,y: x//(2**y),
# not in tinygrad. are there decomps for these?
"aten.log10.out": lambda self: self.log2() * (math.log(2) / math.log(10)),
"aten.log1p.out": lambda self: (self+1).log(),
"aten.expm1.out": lambda self: self.exp() - 1,
"aten.fmax.out": lambda input,other: Tensor.where(input.isnan() & ~other.isnan(), other, Tensor.where(~input.isnan() & other.isnan(), input, Tensor.maximum(input, other))),
"aten.fmin.out": lambda input,other: Tensor.where(input.isnan() & ~other.isnan(), other, Tensor.where(~input.isnan() & other.isnan(), input, Tensor.minimum(input, other))),
# TODO: this gets the shape wrong
#"aten.arange.start_out": Tensor.arange,
"aten.lerp.Scalar_out": Tensor.lerp,
"aten.scatter.value_out": Tensor.scatter,
"aten.where.self_out": Tensor.where,
"aten.prod.int_out": Tensor.prod,
"aten.scatter.src_out": Tensor.scatter,
# NOTE: axis=[] in torch means all, change tinygrad?
"aten.sum.IntList_out": lambda self,axis,keepdim=False,dtype=None:
self.sum(axis if axis is None or len(axis) else None, keepdim,
dtype = _from_torch_dtype(dtype) if dtype is not None else None),
}}
# we add the "out" here
def wrap_out(f):
@inplace_fn("out")
def _wrap_out(*args, **kwargs):
out = kwargs.pop('out')
assigned = f(*args, **kwargs)
if getenv("ALLOW_DTYPE_MISMATCH", 1): assigned = assigned.cast(out.dtype)
assert out.shape == assigned.shape, f"shape mismatch: {assigned.shape} -> {out.shape}"
assert out.device == assigned.device, f"device mismatch: {assigned.device} -> {out.device}"
assert out.dtype == assigned.dtype, f"dtype mismatch: {assigned.dtype} -> {out.dtype}"
if out.lazydata.is_realized: assigned = assigned.contiguous() # TODO: how does this map to torch's semantics
return out.assign(assigned)
return _wrap_out
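# Sketch of the resulting flow for an out= op (illustrative): torch.add(a, b, out=c) on the "tiny" device
# hits "aten.add.out" above; wrap_out builds a + alpha*b lazily, checks shape/device/dtype against c, and
# assigns the result into c's existing tinygrad Tensor rather than allocating a new output.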
tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
"aten.remainder.Scalar_Tensor": lambda x,y: x%y,
"aten.floor_divide": lambda x,y: x//y,
"aten.floor_divide_.Tensor": inplace_fn("x")(lambda x,y: x.assign(x//y)),
# TODO: use tinygrad methods, but they require x to be unsigned
"aten.__lshift__.Scalar": lambda x,y: x*(2**y),
"aten.__ilshift__.Scalar": inplace_fn("x")(lambda x,y: x.assign(x*(2**y))),
"aten.__rshift__.Scalar": lambda x,y: x//(2**y),
"aten.__irshift__.Scalar": inplace_fn("x")(lambda x,y: x.assign(x//(2**y))),
# relu doesn't have an out form?
"aten.relu": Tensor.relu,
"aten.relu_": inplace_fn("x")(lambda x: x.assign(x.relu())),
"aten.mean": Tensor.mean,
"aten.mean.dim": Tensor.mean,
"aten.min": Tensor.min,
"aten.max": Tensor.max,
"aten.mm": Tensor.matmul,
"aten.mv": Tensor.matmul,
"aten.dot": Tensor.dot,
"aten.prod": Tensor.prod,
"aten.isnan": Tensor.isnan,
"aten.std.correction": Tensor.std,
"aten.std_mean.correction": Tensor.std_mean,
"aten.var.correction": Tensor.var,
"aten.var_mean.correction": Tensor.var_mean,
"aten.scatter.value": Tensor.scatter,
"aten.scatter.value_reduce": Tensor.scatter,
"aten.gather": lambda self, dim, index: self.gather(dim, index.cast(dtypes.int)),
"aten.where.self": Tensor.where, # NOTE: this is needed as well as the out type
"aten.repeat": lambda x,*repeats: Tensor.repeat(x,*repeats).contiguous(), # not a view
"aten._softmax": lambda self,dim,half_to_float: self.softmax(dim),
"aten._log_softmax": lambda self,dim,half_to_float: self.log_softmax(dim),
"aten.random_": inplace_fn("self")(lambda self:
self.assign(Tensor.randint(*self.shape, low=dtypes.min(self.dtype), high=dtypes.max(self.dtype), device=self.device, dtype=self.dtype))),
"aten.random_.from": inplace_fn("self")(lambda self, from_, to:
self.assign(Tensor.randint(*self.shape, low=from_, high=to, device=self.device, dtype=self.dtype))),
"aten.uniform_": inplace_fn("self")(lambda self, low=0, high=1: self.assign(Tensor.uniform(*self.shape, low=low, high=high))),
"aten.normal_": inplace_fn("self")(lambda self, mean=0, std=1: self.assign(Tensor.normal(*self.shape, mean=mean, std=std))),
# these don't work in out form, they have size 0
"aten.abs": Tensor.abs,
"aten.logical_not": Tensor.logical_not,
"aten.logical_or_": inplace_fn("x")(lambda x, y: x.assign(x | y)),
"aten.multinomial": Tensor.multinomial,
"aten.masked_fill_.Scalar": inplace_fn("self")(lambda self, mask, value: self.assign(self.masked_fill(mask, value))),
"aten.masked_fill_.Tensor": inplace_fn("self")(lambda self, mask, value: self.assign(self.masked_fill(mask, value))),
"aten.masked_fill.Scalar": Tensor.masked_fill,
"aten.masked_fill.Tensor": Tensor.masked_fill,
"aten.masked_select": Tensor.masked_select,
"aten.all": Tensor.all,
"aten.sgn": Tensor.sign,
"aten.acos": Tensor.acos,
"aten.any": Tensor.any,
"aten.bitwise_not": Tensor.bitwise_not,
"aten.argmax": Tensor.argmax,
"aten.argmin": Tensor.argmin,
"aten.asinh": Tensor.asinh,
"aten.mul": Tensor.mul,
"aten.atanh": Tensor.atanh,
"aten.fill_.Tensor": Tensor.full, # TODO: looks wrong
"aten.flip": Tensor.flip,
"aten.scatter_reduce.two": Tensor.scatter_reduce,
"aten.squeeze_.dim": lambda self, dim: self.replace(self.squeeze(dim), allow_shape_mismatch=True), # TODO: inplace view op, here?
"aten.add.Tensor": lambda input,other,alpha=1: input+alpha*other,
"aten.linspace": lambda start, stop, steps, dtype=None, **kwargs:
Tensor.linspace(start, stop, steps, **({"dtype": _from_torch_dtype(dtype)} if dtype is not None else {})),
"aten.topk": Tensor.topk,
"aten.constant_pad_nd": lambda self, padding, value=0.0: self.pad(padding, mode="constant", value=value).contiguous(),
"aten.cumsum": lambda self, dim: self.cumsum(dim).contiguous(), # TODO: fix test_simple_cumsum, fails without contiguous for shapes >512
"aten.logsumexp": lambda self, axis, keepdim=False: self.logsumexp(axis[0], keepdim=keepdim),
"aten.roll": Tensor.roll,
"aten.logcumsumexp": Tensor.logcumsumexp,
"aten.lerp.Tensor": Tensor.lerp,
"aten.ones_like": lambda self, dtype=None, device=None, **kwargs:
self.ones_like(**{k: v for k, v in {"dtype": _from_torch_dtype(dtype) if dtype else None,
"device": _from_torch_device(device) if device else None}.items() if v is not None}),
"aten.max.dim": lambda self, dim, keepdim=False: (self.max(dim, keepdim), self.argmax(dim, keepdim).cast(dtype=dtypes.int64))
}}
def wrap_fxn(k,f):
  def nf(*args, **kwargs):
    if TORCH_DEBUG:
      print(k, len(args), [x.shape if isinstance(x, torch.Tensor) else x for x in args],
            {k:v.shape if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()})
    args = [unwrap(x) if isinstance(x, torch.Tensor) else x for x in args]
    kwargs = {k:unwrap(v) if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()}
    out = f(*args, **kwargs)
    if isinstance(out, Tensor): return wrap(out)
    elif isinstance(out, tuple): return tuple(wrap(x) for x in out)
    else: raise RuntimeError(f"unknown output type {type(out)}")
  return nf
for k,v in tiny_backend.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_fxn(k,v))
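# The dict keys mirror torch.ops.aten's dotted names (e.g. "aten.add.Tensor"); the loop above only rewrites
# the prefix into the "aten::" schema form that torch.library.impl expects, the same way the view_ops and
# tiny_backend_out registrations did earlier in the file.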
if TORCH_DEBUG:
  from torch.utils._python_dispatch import TorchDispatchMode
  class DispatchLog(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
      #print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
      print(f"Dispatch Log: {func}")
      return func(*args, **(kwargs or {}))
  (_dispatch_log:=DispatchLog()).__enter__() # NOTE: must be kept alive
# NOTE: patch torch optimizer step to avoid continuously growing the computation graph
import weakref
_torch_modules_with_buffers: weakref.WeakSet[torch.nn.Module] = weakref.WeakSet()
def register_torch_buffer(mod, _name, _buffer): _torch_modules_with_buffers.add(mod)
def get_real_tinygrad_buffers():
  res = set()
  for mod in _torch_modules_with_buffers:
    for _,b in mod.named_buffers(recurse=False):
      if b is not None and b.is_tiny:
        res.add(unwrap(b))
  return res
torch.nn.modules.module.register_module_buffer_registration_hook(register_torch_buffer)
from torch.nn.modules import Module
def backward_hook(model:Module, _grad_input, _grad_out):
  grads_to_realize = [unwrap(p.grad) for p in model.parameters() if p.grad is not None]
  if len(grads_to_realize): Tensor.realize(*grads_to_realize)
def module_hook(module:Module, _name, _submodule): module.register_backward_hook(backward_hook)
torch.nn.modules.module.register_module_module_registration_hook(module_hook)
def realize_optimizer_step(optimizer: torch.optim.Optimizer, *args, **kwargs):
  tinygrad_tensors = []
  for param_group in optimizer.param_groups:
    for param in param_group["params"]:
      if param is None: continue
      tinygrad_tensors.append(param.data)
  for state_dict in optimizer.state.values():
    for _, value in state_dict.items():
      if torch.is_tensor(value): tinygrad_tensors.append(value)
  real_tinygrad_tensors = [unwrap(x) for x in tinygrad_tensors if x.is_tiny]
  real_tinygrad_tensors += get_real_tinygrad_buffers()
  if len(real_tinygrad_tensors): Tensor.realize(*real_tinygrad_tensors)
_optimizer_init = torch.optim.Optimizer.__init__
def _optimizer_patched_init(self, *args, **kwargs):
  _optimizer_init(self, *args, **kwargs)
  self.register_step_post_hook(realize_optimizer_step)
torch.optim.Optimizer.__init__ = _optimizer_patched_init
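
# Minimal smoke test (a sketch, not part of the backend): running this file directly checks that tensors
# round-trip through the "tiny" device and that a couple of ops dispatch. The chosen ops are arbitrary.
if __name__ == "__main__":
  a = torch.arange(6, dtype=torch.float32).reshape(2, 3).to("tiny")
  b = torch.ones(2, 3, device="tiny")
  c = (a * 2 + b).cpu()
  print(c)
  assert c.shape == (2, 3)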