openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

720 lines
44 KiB

from tinygrad.tensor import Tensor
from tinygrad.helpers import prod, dtypes, ImageDType
from extra.onnx import safe_numpy
from onnx.helper import tensor_dtype_to_np_dtype
from onnx.onnx_pb import TensorProto
import os
import numpy as np
import functools
from typing import Union, Tuple, Optional, List, Any
import math
# **************** Free Ops ****************
def Identity(input: Tensor):
  """ONNX Identity: hand the input back unchanged."""
  return input
def Neg(input: Tensor):
  """ONNX Neg: unary minus of the input."""
  return -input
def Add(input: Tensor, other: Tensor, broadcast=None):
  """ONNX Add. The sum is cast back to input.dtype unless input is float or an image dtype; `broadcast` is accepted but unused."""
  if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType): return input + other
  return (input + other).cast(input.dtype)
def Sub(input: Union[Tensor, Any], other: Tensor):
  """ONNX Sub: input - other."""
  # some test has input as int
  return input - other
def Mul(input: Tensor, other: Tensor):
  """ONNX Mul. The product is cast back to input.dtype unless input is float or an image dtype."""
  if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType): return (input * other)
  return (input * other).cast(input.dtype)
# in openpilot, due to SHUFFLE_PAD_OPS issues, we are spending an extra kernel
def Div(input: Tensor, other: Tensor):
  """ONNX Div: plain division for float/image dtypes, otherwise the quotient is floored."""
  if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType): return input / other
  return input.div(other).floor()
def Pow(input: Tensor, other: Tensor):
  """ONNX Pow: both operands are converted with .float(), the result is cast back to input.dtype."""
  base, exponent = input.float(), other.float()
  return (base ** exponent).cast(input.dtype)
def Reciprocal(input: Tensor):
  """ONNX Reciprocal: delegates to Tensor.reciprocal."""
  return input.reciprocal()
def Sqrt(input: Tensor):
  """ONNX Sqrt: delegates to Tensor.sqrt."""
  return input.sqrt()
def Sign(input: Tensor):
  """ONNX Sign: delegates to Tensor.sign."""
  return input.sign()
def Abs(input: Tensor):
  """ONNX Abs: delegates to Tensor.abs."""
  return input.abs()
def Exp(input: Tensor):
  """ONNX Exp: delegates to Tensor.exp."""
  return input.exp()
def Log(input: Tensor):
  """ONNX Log: delegates to Tensor.log."""
  return input.log()
def Mish(input: Tensor):
  """ONNX Mish: delegates to Tensor.mish."""
  return input.mish()
def Sin(x: Tensor):
  """ONNX Sin: delegates to Tensor.sin."""
  return x.sin()
def Cos(x: Tensor):
  """ONNX Cos: delegates to Tensor.cos."""
  return x.cos()
def Tan(x: Tensor):
  """ONNX Tan: delegates to Tensor.tan."""
  return x.tan()
def Relu(input: Tensor):
  """ONNX Relu: delegates to Tensor.relu."""
  return input.relu()
def Sigmoid(input: Tensor):
  """ONNX Sigmoid: delegates to Tensor.sigmoid."""
  return input.sigmoid()
def Tanh(input: Tensor):
  """ONNX Tanh: delegates to Tensor.tanh."""
  return input.tanh()
def MatMul(input: Tensor, other: Tensor):
  """ONNX MatMul: delegates to Tensor.matmul."""
  return input.matmul(other)
def Floor(x:Tensor):
  """ONNX Floor: delegates to Tensor.floor."""
  return x.floor()
def Ceil(x:Tensor):
  """ONNX Ceil: delegates to Tensor.ceil."""
  return x.ceil()
def Less(x:Tensor,y:Tensor):
  """ONNX Less: x < y, cast to bool."""
  flags = x<y
  return flags.cast(dtypes.bool)
def LessOrEqual(x:Tensor,y:Tensor):
  """ONNX LessOrEqual: x <= y, cast to bool."""
  flags = x<=y
  return flags.cast(dtypes.bool)
def Greater(x:Tensor,y:Tensor):
  """ONNX Greater: x > y, cast to bool."""
  flags = x>y
  return flags.cast(dtypes.bool)
def GreaterOrEqual(x:Tensor,y:Tensor):
  """ONNX GreaterOrEqual: x >= y, cast to bool."""
  flags = x>=y
  return flags.cast(dtypes.bool)
def Equal(x:Tensor,y:Tensor):
  """ONNX Equal: x == y, cast to bool."""
  flags = x==y
  return flags.cast(dtypes.bool)
def Max(*data_0):
  """ONNX Max: left fold of Tensor.maximum over all inputs."""
  return functools.reduce(Tensor.maximum, data_0)
def Min(*data_0):
  """ONNX Min: left fold of Tensor.minimum over all inputs."""
  return functools.reduce(Tensor.minimum, data_0)
def Sum(*data_0):
  """ONNX Sum: left fold of Tensor.__add__ over all inputs."""
  return functools.reduce(Tensor.__add__, data_0)
def Mean(*data_0):
  """ONNX Mean: sum of all inputs divided by how many were passed."""
  total = functools.reduce(Tensor.__add__, data_0)
  return total / len(data_0)
def Where(condition:Tensor,X:Tensor,Y:Tensor):
  """ONNX Where: select X where condition holds, else Y; result is cast to X.dtype."""
  picked = condition.where(X, Y)
  return picked.cast(X.dtype)
def Cast(input: Tensor, to):
  """ONNX Cast: `to` is an ONNX TensorProto dtype enum, mapped through numpy to a tinygrad dtype."""
  np_dtype = tensor_dtype_to_np_dtype(to)
  return input.cast(dtypes.from_np(np_dtype))
# **************** Simple Ops ****************
def Constant(value: "Optional[Tensor]"=None, value_float=None, value_floats=None, value_int=None, value_ints=None, value_string=None, value_strings=None):
  """ONNX Constant: build the constant from whichever attribute was supplied.

  Attributes are checked in the order value, value_float, value_floats,
  value_int, value_ints. Presence is tested with `is not None` so that
  legitimate falsy payloads (0, 0.0, empty sequences) are not skipped.

  Returns the given `value` as-is, or a new non-differentiable Tensor
  (float32 for the float attributes, int64 for the int attributes).
  Returns None when no attribute was supplied.
  Raises NotImplementedError for value_string / value_strings.
  """
  if value is not None: return value
  if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
  if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
  if value_int is not None: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False)
  if value_ints is not None: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False)
  if value_string is not None or value_strings is not None:
    raise NotImplementedError('value_string or value_strings not implemented for Constant op')
  return None
def Softsign(input: Tensor):
  """ONNX Softsign: input / (1 + |input|)."""
  denom = 1+input.abs()
  return input / denom
def Cosh(x):
  """ONNX Cosh, written as (e**x + e**-x) / 2."""
  grow = math.e ** x
  decay = math.e ** -x
  return (grow + decay) / 2
def Sinh(x):
  """ONNX Sinh, written as (e**x - e**-x) / 2."""
  grow = math.e ** x
  decay = math.e ** -x
  return (grow - decay) / 2
def Tanh(x):
  """ONNX Tanh: delegates to x.tanh(). Re-declares the Tanh from the Free Ops section with the same behavior."""
  return x.tanh()
def HardSigmoid(input: Tensor, alpha=0.2, beta=0.5):
  """ONNX HardSigmoid: alpha*input + beta, clipped to [0, 1]."""
  line = alpha*input + beta
  return line.clip(0, 1)
def HardSwish(input: Tensor):
  """ONNX HardSwish: input times HardSigmoid(input) with alpha=1/6, beta=0.5."""
  gate = HardSigmoid(input, 1/6, 0.5)
  return input * gate
def Celu(X: Tensor, alpha=1.0):
  """ONNX Celu, built from two relu terms."""
  pos_part = X.relu()
  neg_part = (-alpha*(X/alpha).exp()+1).relu()
  return pos_part - neg_part
def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875):
  """ONNX Selu, built from two relu terms and scaled by gamma."""
  pos_part = X.relu()
  neg_part = (-alpha*X.exp()+alpha).relu()
  return gamma * (pos_part - neg_part)
def Softplus(X: Tensor):
  """ONNX Softplus: delegates to Tensor.softplus."""
  return X.softplus()
def PRelu(X:Tensor, slope:Tensor):
  """ONNX PRelu: positive part of X plus slope times negative part of X."""
  # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
  if slope.shape[-1] != X.shape[-1]: slope = slope[0]
  positive = X.clip(0, float("inf"))
  negative = X.clip(float("-inf"), 0)
  return positive + negative * slope
def LeakyRelu(X: Tensor, alpha=0.01):
  """ONNX LeakyRelu: delegates to Tensor.leakyrelu."""
  return X.leakyrelu(alpha)
def ThresholdedRelu(X: Tensor, alpha=1.0):
  """ONNX ThresholdedRelu: relu(X - alpha), with alpha added back wherever that relu is non-zero."""
  shifted = (X-alpha).relu()
  restore = (X-alpha).relu().sign() * alpha
  return shifted + restore
def Softmax_1(input: Tensor, axis=1):
  """Softmax for opset 1 (default axis 1)."""
  return input.softmax(axis)
def Softmax_13(input: Tensor, axis=-1):
  """Softmax for opset 13 (default axis -1)."""
  return input.softmax(axis)
# Softmax default axis changed
Softmax = {1: Softmax_1, 13: Softmax_13}
def LogSoftmax(input: Tensor, axis=-1):
  """ONNX LogSoftmax: delegates to Tensor.log_softmax."""
  return input.log_softmax(axis)
def Clip(input: Tensor, min=None, max=None):
  """ONNX Clip: a missing bound is replaced by -inf / +inf."""
  lower = float('-inf') if min is None else min
  upper = float('inf') if max is None else max
  return input.clip(lower, upper)
# NOTE ReduceProd would require a new llop
def _axes(axes, noop_with_empty_axes): return [int(x) for x in safe_numpy(axes)] if axes is not None and not (isinstance(axes, Tensor) and axes.shape == (0,)) else ([] if noop_with_empty_axes else None)
def ReduceMax(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
  """ONNX ReduceMax over the axes resolved by _axes."""
  axis = _axes(axes, noop_with_empty_axes)
  return data.max(axis, keepdim=keepdims)
def ReduceMin(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
  """ONNX ReduceMin over the axes resolved by _axes."""
  axis = _axes(axes, noop_with_empty_axes)
  return data.min(axis, keepdim=keepdims)
def ReduceSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
  """ONNX ReduceSum over the axes resolved by _axes."""
  axis = _axes(axes, noop_with_empty_axes)
  return data.sum(axis, keepdim=keepdims)
def ReduceMean(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
  """ONNX ReduceMean over the axes resolved by _axes."""
  axis = _axes(axes, noop_with_empty_axes)
  return data.mean(axis, keepdim=keepdims)
def ReduceSumSquare(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
  """ONNX ReduceSumSquare: sum of squares."""
  squared = data.square()
  return squared.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceL1(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
  """ONNX ReduceL1: sum of absolute values."""
  magnitude = data.abs()
  return magnitude.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceL2(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
  """ONNX ReduceL2: square root of the sum of squares."""
  squared = data.square()
  return squared.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).sqrt()
def ReduceLogSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
  """ONNX ReduceLogSum: log of the sum."""
  summed = data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
  return summed.log()
def ReduceLogSumExp(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
  """ONNX ReduceLogSumExp: log of the sum of exponentials (computed directly, without max subtraction)."""
  summed = data.exp().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
  return summed.log()
def GlobalAveragePool(X: Tensor):
  """ONNX GlobalAveragePool: mean over every axis from index 2 on, keeping dims."""
  pooled_axes = tuple(range(2, len(X.shape)))
  return X.mean(axis=pooled_axes, keepdim=True)
def GlobalMaxPool(X: Tensor):
  """ONNX GlobalMaxPool: max over every axis from index 2 on, keeping dims."""
  pooled_axes = tuple(range(2, len(X.shape)))
  return X.max(axis=pooled_axes, keepdim=True)
def OptionalHasElement(x: Tensor=None):
  """ONNX OptionalHasElement: bool Tensor, true when x is given and has at least one element."""
  present = x is not None and x.numel() > 0
  return Tensor(present, dtype=dtypes.bool)
def OptionalGetElement(x: Tensor=None):
  """ONNX OptionalGetElement: x itself, or an empty float32 Tensor when x is None."""
  if x is not None: return x
  return Tensor([], dtype=dtypes.float32)
def Tile(input: Tensor, repeats):
  """ONNX Tile: repeat input by the per-axis counts read from `repeats`."""
  counts = [int(x) for x in safe_numpy(repeats)]
  return input.repeat(counts)
def Range(start: Tensor, limit, delta):
  """ONNX Range: arange over int-converted start/limit/delta, cast to start.dtype."""
  first, last, stride = int(safe_numpy(start)), int(safe_numpy(limit)), int(safe_numpy(delta))
  return Tensor.arange(start=first, stop=last, step=stride).cast(dtype=start.dtype)
def Shape(data: Tensor, end=None, start=0):
  """ONNX Shape: data.shape[start:end] as a Tensor; int32 when the file /TICI exists, int64 otherwise."""
  # TODO: really?
  out_dtype = dtypes.int32 if os.path.isfile("/TICI") else dtypes.int64
  return Tensor(list(data.shape)[start:end], dtype=out_dtype)
def Size(data: Tensor):
  """ONNX Size: product of data.shape (or of `data` itself when it is a list), returned as a Python number."""
  dims = data if isinstance(data, list) else data.shape
  return prod(dims)
def Flatten(input: Tensor, axis=1):
  """ONNX Flatten: reshape to 2-D, the leading dims before `axis` collapsed into the first."""
  rows = prod((1,) + input.shape[0:axis])
  return input.reshape(rows, -1)
def Reshape(data: Tensor, shape: Tensor, allowzero=None):
  """ONNX Reshape: a 0 entry in `shape` copies the matching dim of data; `allowzero` is accepted but unused."""
  target = [int(x) if x != 0 else data.shape[i] for i,x in enumerate(safe_numpy(shape))]
  return data.reshape(target)
def Shrink(input: Tensor, bias=0.0, lambd=0.5):
  """ONNX Shrink: input+bias below -lambd, input-bias above lambd, zero in between."""
  low_side = (input < -lambd)*(input+bias)
  high_side = (input > lambd)*(input-bias)
  return low_side + high_side
def And(x:Tensor, y:Tensor):
  """ONNX And: x where x equals y, zero elsewhere, cast to bool."""
  same = (x==y)
  return Where(same, x, Tensor.zeros(*x.shape)).cast(dtypes.bool)
def Or(x:Tensor, y:Tensor):
  """ONNX Or: x where x equals y, one elsewhere, cast to bool."""
  same = (x==y)
  return Where(same, x, Tensor.ones(*x.shape)).cast(dtypes.bool)
def Xor(x:Tensor, y:Tensor):
  """ONNX Xor: zero where x equals y, one elsewhere, cast to bool."""
  same = (x==y)
  return Where(same, Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
def Not(x:Tensor):
  """ONNX Not: zero where x equals 1, one elsewhere, cast to bool."""
  is_true = (x==1)
  return Where(is_true, Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
def Asin(x):
  """ONNX Asin via Atan(x / sqrt(1 - x*x))."""
  ratio = x / Tensor.sqrt(1 - x * x)
  return Atan(ratio)
def Asinh(x):
  """ONNX Asinh via log(x + sqrt(x*x + 1))."""
  inner = x + Tensor.sqrt(x * x + 1)
  return Tensor.log(inner)
def Acosh(x):
  """ONNX Acosh via log(x + sqrt(x*x - 1))."""
  inner = x + Tensor.sqrt(x * x - 1)
  return Tensor.log(inner)
def Atanh(x):
  """ONNX Atanh via 0.5 * log((1 + x) / (1 - x))."""
  quotient = (1 + x)/(1 - x)
  return 0.5 * Tensor.log(quotient)
def Acos(x: Tensor):
  """ONNX Acos using a cubic polynomial times sqrt(1 - |x|), mirrored for negative inputs."""
  is_neg = (x < 0)
  mag = x.abs()
  poly = (((-0.0187293 * mag) + 0.0742610)*mag - 0.2121144) * mag + 1.5707288
  approx = poly * Tensor.sqrt(1.0 - mag)
  # flips the sign of the approximation where the input was negative
  approx = approx - 2 * is_neg * approx
  return is_neg * 3.14159265358979 + approx
def Atan(y: Tensor):
  """ONNX Atan using a polynomial in min(1,|y|)/max(1,|y|), with a fix-up for |y| > 1 and for negative y."""
  ones = Tensor.ones(y.shape)
  mag = y.abs()
  larger = (ones > mag).where(ones, mag)
  smaller = (ones < mag).where(ones, mag)
  ratio = smaller / larger
  ratio_sq = ratio * ratio
  poly = ((((-0.013480470 * ratio_sq + 0.057477314) * ratio_sq - 0.121239071) * ratio_sq + 0.195635925) * ratio_sq - 0.332994597) * ratio_sq + 0.999995630
  angle = poly * ratio
  # where |y| exceeds 1 the ratio was inverted, so take the complement against 1.570796327
  angle = (y.abs() > ones.abs()).where(1.570796327 - angle, angle)
  return (y < 0).where(-angle, angle)
def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1):
  """ONNX Trilu: keep the upper (upper truthy) or lower triangle of x.

  k is the diagonal offset. onnx passes k as a tensor int64 with one element,
  default is 0; a plain Python int (including non-zero values) is also accepted.
  """
  # only Tensors have .numpy(); plain ints are used directly
  k = int(k.numpy().item()) if isinstance(k, Tensor) else int(k)
  return x.triu(k) if upper else x.tril(k)
def Squeeze(input: Tensor, axes):
  """ONNX Squeeze: reshape input with the listed axes removed (negative axes count from the end)."""
  if isinstance(axes, Tensor): axes = safe_numpy(axes)
  dropped = [int(a) if a >= 0 else int(a+input.ndim) for a in axes]
  kept_dims = [dim for pos,dim in enumerate(input.shape) if pos not in dropped]
  return input.reshape(kept_dims)
def Unsqueeze(data: Tensor, axes):
  """ONNX Unsqueeze: insert size-1 dims at the given output positions."""
  axes = [len(data.shape) + int(x) if x < 0 else int(x) for x in safe_numpy(axes)]
  out_rank = len(data.shape) + len(axes)
  remaining = iter(data.shape)
  # positions named in axes become 1; every other position takes the next original dim in order
  new_shape = [1 if pos in axes else next(remaining) for pos in range(out_rank)]
  return data.reshape(new_shape)
def Binarizer(input, threshold=0.0):
  """Binarizer: result of the comparison input > threshold."""
  above = input > threshold
  return above
def ArgMax(x: Tensor, axis=0, keepdims=1, select_last_index=0):
  """ONNX ArgMax: int64 index of the maximum along `axis`.

  When the maximum occurs more than once, the first occurrence is returned
  unless select_last_index is truthy, in which case the last one is.
  keepdims controls whether the reduced axis is kept in the result.
  """
  axis = axis + x.ndim if axis < 0 else axis
  n = x.shape[axis]
  # mask of positions that hold the maximum along `axis`
  is_max = x == x.max(axis=axis, keepdim=True)
  idx = Tensor.arange(n).reshape(*[1]*(axis), n, *[1]*(x.ndim - axis-1))
  if select_last_index:
    ret = (idx * is_max).max(axis=axis, keepdim=keepdims)
  else:
    # max over the reversed index (n-1-idx) picks the smallest original index among ties
    reversed_idx = -idx + (n-1)
    ret = -(reversed_idx * is_max).max(axis=axis, keepdim=keepdims) + (n-1)
  return ret.cast(dtypes.int64)
def ArgMin(x, axis=0, keepdims=1, select_last_index=0):
  """ONNX ArgMin: ArgMax applied to the negated input."""
  negated = -x
  return ArgMax(negated, axis=axis, keepdims=keepdims, select_last_index=select_last_index)
def Elu(input: Tensor, alpha=1.0):
  """ONNX Elu: delegates to Tensor.elu."""
  return input.elu(alpha=alpha)
def Concat(*inputs: List[Tensor], axis):
  """ONNX Concat: join all inputs along `axis`."""
  first, rest = inputs[0], inputs[1:]
  return first.cat(*rest, dim=axis)
def Transpose(input: Tensor, perm=None):
  """ONNX Transpose: permute by `perm`, or reverse all axes when perm is None."""
  order = list(range(len(input.shape))[::-1]) if perm is None else perm
  return input.permute(order=order)
# NOTE: several dtypes are in use in this file (see Cast/Where), so the cast has to actually happen.
def CastLike(input, target_type):
  """ONNX CastLike: cast `input` to the dtype carried by the Tensor `target_type`."""
  assert isinstance(target_type, Tensor), "can only CastLike Tensor"
  return input.cast(target_type.dtype)
def ConstantOfShape(input, value:Tensor=None):
  """ONNX ConstantOfShape: Tensor of the shape read from `input`, filled with `value` (default 0.0)."""
  if value is None: value=Tensor([0.0])
  shape = [int(x) for x in safe_numpy(input)]
  base = Tensor.ones(*shape, dtype=value.dtype)
  # when the leading dim is 0 the ones are multiplied by 1 instead of by value
  fill = value if shape[0]!=0 else 1
  return base * fill
# TODO: abstract out the broadcast logic in tensor
def Expand(input: Tensor, shape):
  """ONNX Expand: broadcast input against the requested shape."""
  requested = [int(x) for x in safe_numpy(shape)]
  rank = max(len(input.shape), len(requested))
  # copied from _broadcasted: left-pad both shapes with 1s to a common rank
  x_shape = [1]*(rank-len(input.shape)) + list(input.shape)
  y_shape = [1]*(rank-len(requested)) + list(requested)
  shape_ret = tuple(max(sx, sy) for sx,sy in zip(x_shape, y_shape))
  return input.reshape(x_shape).expand(shape_ret)
# **************** Complex Ops ****************
def Gemm(A: Tensor, B: Tensor, C: Tensor=None, alpha=1.0, beta=1.0, transA=0, transB=0, broadcast=0):
  """ONNX Gemm: alpha * (A' @ B') plus, when C is given, beta * C."""
  # transA / transB are passed straight to Tensor.transpose
  lhs, rhs = A.transpose(transA), B.transpose(transB)
  ret = alpha * (lhs @ rhs)
  if C is not None:
    if broadcast == 0: addend = C
    else: addend = C.reshape([-1 if i < len(C.shape) else 1 for i in range(len(ret.shape))][::-1])
    ret += beta * addend
  return ret
# works with Tensors.ndim != 4
def _batchnorm(self:Tensor, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor):
  """Apply (self - mean) * weight * invstd + bias, with the statistics laid out along axis 1.

  weight and bias are optional and are tested with `is not None`, so the
  check does not depend on how a Tensor evaluates in a boolean context.
  """
  # [1, C, 1, ...]: per-channel values broadcast over every other axis
  shape = [1, -1] + [1] * (self.ndim-2)
  x = (self - mean.reshape(shape=shape))
  if weight is not None: x = x * weight.reshape(shape=shape)
  # a 1-D invstd is per-channel and gets reshaped; anything else is used as given
  ret = x.mul(invstd.reshape(shape=shape) if len(invstd.shape) == 1 else invstd)
  return (ret + bias.reshape(shape=shape)) if bias is not None else ret
# TODO: this is copied from tinygrad/nn/__init__.py
# spatial is from opset 7 and has since been removed
def BatchNormalization(X: Tensor, scale, B, input_mean, input_var, epsilon=1e-05, momentum=0.9, training_mode=0, spatial=1, is_test=0):
  """ONNX BatchNormalization.

  Inference (training_mode falsy): normalizes X with input_mean / input_var
  and returns the normalized Tensor.
  Training: normalizes X with the statistics of the current batch and returns
  (output, running_mean, running_var), where the running values blend the
  inputs with the batch statistics using `momentum`.
  The training path reduces over every axis except axis 1, so it accepts
  inputs of any rank >= 2, not only 4-D ones. `spatial` and `is_test` are
  accepted but unused.
  """
  if training_mode:
    x_detached = X.detach()
    # every axis except 1; equals (0,2,3) for 4-D input
    reduce_axes = tuple(i for i in range(X.ndim) if i != 1)
    stat_shape = [1, -1] + [1] * (X.ndim-2)
    current_mean = x_detached.mean(axis=reduce_axes)
    y = (x_detached - current_mean.reshape(shape=stat_shape))
    current_var = (y*y).mean(axis=reduce_axes)
    current_invstd = current_var.add(epsilon).pow(-0.5)
    running_mean = input_mean * momentum + current_mean * (1 - momentum)
    running_var = input_var * momentum + current_var * (1 - momentum)
    return _batchnorm(X, scale, B, current_mean, current_invstd), running_mean, running_var
  invstd = (input_var + epsilon)**-0.5
  return _batchnorm(X, scale, B, input_mean, invstd)
def InstanceNormalization(x: Tensor, scale: Tensor, bias: Tensor, epsilon=1e-05):
  """ONNX InstanceNormalization: normalize over every axis from index 2 on, then scale and shift.

  scale and bias are reshaped to [-1, 1, ...] with one trailing 1 per
  reduced axis, so the op broadcasts for any input rank >= 3 (for 4-D input
  this is the [-1, 1, 1] shape).
  """
  axis = tuple(range(2, len(x.shape)))
  param_shape = [-1] + [1] * (len(x.shape)-2)
  mean = x.mean(axis=axis, keepdim=True)
  invstd = x.sub(mean).pow(2).mean(axis=axis, keepdim=True).add(epsilon).pow(-0.5)
  return x.sub(mean).mul(scale.reshape(shape=param_shape)).mul(invstd).add(bias.reshape(shape=param_shape))
def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_type=1):
  """ONNX LayerNormalization: returns (normalized output, mean, inverse standard deviation)."""
  assert stash_type == 1, "only float32 is supported"
  rank = len(x.shape)
  first = axis if axis >= 0 else rank + axis
  # normalize over every axis from `first` to the last
  axis = tuple(i for i in range(first, rank))
  mean = x.mean(axis=axis, keepdim=True)
  normed = x.layernorm(axis, epsilon).mul(scale).add(bias)
  inv_std = (x.sub(mean)).pow(2).mean(axis=axis, keepdim=True).add(epsilon).sqrt().reciprocal()
  return normed, mean, inv_std
def GroupNormalization(x: Tensor, scale: Tensor, bias: Tensor, num_groups, epsilon=1e-05):
  """ONNX GroupNormalization: layernorm over each group, then scale and shift, restored to x.shape."""
  grouped = x.reshape(x.shape[0], num_groups, -1)
  normed = grouped.layernorm(axis=-1, eps=epsilon)
  affine = normed.mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1))
  return affine.reshape(x.shape)
# onnx: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
# numpy.pad: ((x1_begin, x1_end), (x2_begin, x2_end), ...)
def _format_padding(onnx_pads, ndims=None, axes=None):
if ndims and len(onnx_pads)//2 != ndims: onnx_pads = onnx_pads * ndims # for OnnxBackendPyTorchConvertedModelTest the len(onnx_pads) == 2
if ndims is None: ndims = len(onnx_pads) // 2
if axes is None: axes = list(range(ndims))
num_axes = len(axes)
np_pads = [(0,0)] * ndims
for i in range(num_axes):
np_pads[axes[i]] = (onnx_pads[i], onnx_pads[i + num_axes])
return np_pads
def _padding(X: Tensor, pads=None, auto_pad="NOTSET", axes=None, constant_value=0., strides=None, kernel_shape=None, dilations=None):
  """Pad X with constant_value, using explicit ONNX pads or pads derived from auto_pad; X is returned untouched when there are none."""
  if auto_pad != "NOTSET":
    pads = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
  if pads is None:
    return X
  per_axis = _format_padding(pads, ndims=len(X.shape), axes=axes)
  return X.pad(tuple(per_axis), value=constant_value)
def _auto_pad(X, auto_pad, strides, kernel_shape, dilations):
strides = [strides]*len(kernel_shape) if isinstance(strides, int) else strides if strides else [1]*len(kernel_shape)
dilations = [1]*len(kernel_shape) if dilations == 1 else dilations
pad_shape = [(math.ceil(sh/st)-1)*st+((ks-1)*di+1)-sh for sh, st, ks, di in zip(X.shape[-len(strides):], strides, kernel_shape, dilations)]
if auto_pad == "SAME_UPPER": return [pad_shape[0]//2, pad_shape[1]//2, pad_shape[0]-pad_shape[0]//2, pad_shape[1]-pad_shape[1]//2]
elif auto_pad == "SAME_LOWER": return [pad_shape[0]-pad_shape[0]//2, pad_shape[1]-pad_shape[1]//2, pad_shape[0]//2, pad_shape[1]//2]
else: raise NotImplementedError(f"auto_pad={auto_pad} not implemented, yet")
def Pad(x: Tensor, pads: Union[Tensor, Tuple[int, ...]], constant_value: Tensor=None, axes: Tensor=None, mode="constant", value: float=0.):
  """ONNX Pad with modes "constant", "reflect", "edge" and "wrap".

  pads is in ONNX order (all begins, then all ends) and is rounded up to
  ints. axes optionally restricts which axes the pads apply to.
  constant_value (a Tensor) takes precedence over the legacy `value` attribute.
  Raises NotImplementedError for any other mode.
  """
  constant_value = value if constant_value is None else float(safe_numpy(constant_value)[0])
  seq_pads = list(pads) if isinstance(pads, tuple) else safe_numpy(pads)
  seq_pads = [math.ceil(i) for i in seq_pads]
  seq_axes = safe_numpy(axes).astype(np.int32).tolist() if axes is not None else None
  base_shape = x.shape
  pads = _format_padding(seq_pads, ndims=len(x.shape), axes=seq_axes)

  # helpers that act on a single axis `i` and leave every other axis as it is
  def pad_axis(t, i, amount): return t.pad(tuple([amount if i_ == i else (0,0) for i_ in range(t.ndim)]))
  def shrink_axis(t, i, bounds): return t.shrink(tuple([bounds if i_ == i else (0,s_) for i_,s_ in enumerate(t.shape)]))
  def expand_axis(t, i, size): return t.expand([size if i_ == i else s_ for i_,s_ in enumerate(t.shape)])

  if mode == "constant":
    return _padding(x, seq_pads, axes=seq_axes, constant_value=constant_value)
  if mode == "wrap":
    repeat_args = [math.ceil(dim[0]/sh) + math.ceil(dim[1]/sh) + 1 for dim, sh in zip(pads, base_shape)]
    new_shape = [s*r for s,r in zip(base_shape, repeat_args)]
    shrink_args = [(sh-dim[0]%sh if dim[0]%sh != 0 else 0, nsh-(sh-dim[1]%sh) if dim[1]%sh != 0 else nsh) for dim, sh, nsh in zip(pads, base_shape, new_shape)]
    return x.repeat(tuple(repeat_args)).shrink(tuple(shrink_args))
  if mode == "reflect":
    # each piece is zero-padded to the final size along axis i, so the pieces never overlap and can be summed
    for i,s in enumerate(x.shape):
      before, after = pads[i]
      if not before and not after: continue
      flipped = x.flip(i)
      out = pad_axis(x, i, (before, after))
      if before: out = pad_axis(shrink_axis(flipped, i, (s-before-1, s-1)), i, (0, s+after)) + out
      if after: out = pad_axis(shrink_axis(flipped, i, (1, after+1)), i, (s+before, 0)) + out
      x = out
    return x
  if mode == "edge":
    for i,s in enumerate(x.shape):
      before, after = pads[i]
      if not before and not after: continue
      out = pad_axis(x, i, (before, after))
      # leading side: the first slice repeated `before` times
      if before: out = pad_axis(expand_axis(shrink_axis(x, i, (0, 1)), i, before), i, (0, s+after)) + out
      # trailing side: the last slice repeated `after` times
      if after: out = pad_axis(expand_axis(shrink_axis(x, i, (s-1, s)), i, after), i, (s+before, 0)) + out
      x = out
    return x
  raise NotImplementedError(f"mode={mode} not supported for Pad")
def AveragePool(X: Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, count_include_pad=0, dilations=1, pads=None, strides=1):
  """ONNX AveragePool over the last two axes; only dilations == 1 is supported.

  With count_include_pad falsy, the pooled result is divided by the pooled
  result of an all-ones Tensor padded the same way.
  """
  if dilations != 1: raise NotImplementedError(f"dilations != 1 not supported, dilations:{dilations}")
  pixel_axes = tuple(range(len(X.shape)))[-2:]
  # ceil_mode is handled by switching auto_pad to "SAME_UPPER"
  if ceil_mode: auto_pad = "SAME_UPPER"
  def pad_then_pool(t):
    padded = _padding(t, pads, auto_pad, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations)
    return padded.avg_pool2d(kernel_shape, stride=strides)
  padding_included = pad_then_pool(X)
  if count_include_pad: return padding_included
  div = pad_then_pool(Tensor.ones(*X.shape))
  return padding_included / div
def MaxPool(X: Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, dilations=1, pads=None, storage_order=0, strides=1):
  """ONNX MaxPool: returns (pooled values, int64 indices into the flattened X)."""
  # ceil_mode is handled by switching auto_pad to "SAME_UPPER"
  if ceil_mode: auto_pad = "SAME_UPPER"
  # pad the trailing len(kernel_shape) axes with -inf so padded cells never win the max
  ret = _padding(X, pads, auto_pad, constant_value=-np.inf, axes=tuple(range(len(X.shape)))[-len(kernel_shape):], strides=strides, kernel_shape=kernel_shape, dilations=dilations)
  ret = ret.max_pool2d(kernel_shape, stride=strides, dilation=dilations)
  ret_len, X_len = ret.numel(), X.numel()
  # compare every pooled value with every input value; matches contribute the input's flat position, summed per output
  # NOTE(review): a pooled value that appears more than once in X would add up several positions -- confirm inputs are expected to be tie-free
  indices = ((ret.flatten().unsqueeze(1).expand(ret_len, X_len) == X.flatten().reshape(1, X_len).expand(ret_len, X_len)) * Tensor.arange(X_len).reshape(1, X_len).expand(ret_len, X_len)).sum(1).reshape(ret.shape).cast(dtypes.int64)
  # storage_order swaps the last two axes of the index Tensor
  if storage_order: indices = indices.transpose(indices.ndim-2, indices.ndim-1)
  return ret, indices
def MaxUnpool(xT: Tensor, xI: Tensor, outshape: Tensor=None, kernel_shape=None, pads=None, strides=None):
  """ONNX MaxUnpool: scatter the values xT to the flat positions xI of a [1, 1, ...] output.

  `pads` is accepted but not used by this implementation.
  """
  # NOTE(review): zip pairs xI.shape from its first axis with strides/kernel_shape -- confirm this is the intended pairing
  out_sh = [(ks//2)*2 + st * inps for inps, st, ks in zip(xI.shape, strides, kernel_shape)]
  outlength = prod(out_sh)
  # one row per input value, one column per output position
  xI = xI.flatten().unsqueeze(1).expand(prod(xT.shape), outlength)
  arange = Tensor.arange(outlength, requires_grad=False).reshape(1, outlength).expand(xI.shape)
  xT = xT.flatten().unsqueeze(1).expand(prod(xT.shape), outlength)
  # keep each value only in the column matching its index, then collapse the rows
  ret = ((xI == arange) * xT).sum(0).reshape([1, 1] + out_sh)
  if outshape is not None:
    outshape = safe_numpy(outshape).tolist()
    # NOTE(review): outshape is a list here; if ret.shape is a tuple this comparison is always true -- confirm
    if outshape != ret.shape:
      diff = [outshape[2] - ret.shape[2], outshape[3] - ret.shape[3]]
      # split the size difference between both sides of each of the last two axes
      pad_args = [diff[0]//2, diff[1]//2, diff[0]-diff[0]//2, diff[1]-diff[1]//2]
      ret = ret.pad2d((pad_args[1], pad_args[3], pad_args[0], pad_args[2]))
  return ret
def Conv(X: Tensor, W: Tensor, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, strides=1):
  """ONNX Conv on top of Tensor.conv2d.

  Pads come either from the `pads` attribute or, when auto_pad is set, from
  _auto_pad (in which case `pads` is ignored). Both are in ONNX order
  (all begins, then all ends) and go through the same reorder before being
  handed to conv2d. kernel_shape falls back to the spatial dims of W when
  it is needed and was not supplied.
  """
  if auto_pad != "NOTSET":
    if kernel_shape is None: kernel_shape = W.shape[2:]
    pads = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
  if pads is None:
    padding = 0
  else:
    # reorder padding: interleave (begin, end) pairs, last spatial axis first
    half = len(pads)//2
    padding = [p for ps in zip(pads[:half][::-1], pads[half:][::-1]) for p in ps]
  return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=padding)
def ConvTranspose(X: Tensor, W: Tensor, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, output_shape=None, output_padding=0, strides=1):
  """ONNX ConvTranspose on top of Tensor.conv_transpose2d."""
  # NOTE(review): the fallback uses the full W.shape, not only its spatial dims -- confirm this is intended
  if not kernel_shape: kernel_shape = W.shape
  if pads is None and auto_pad != "NOTSET": pads = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
  elif pads is None and auto_pad == "NOTSET": pads = [0,0] * (X.ndim - 2)
  # left-pad strides / dilations with 1s up to W.ndim entries
  strides_ = [1]*(W.ndim-1) + [strides] if isinstance(strides, int) else [1]*(W.ndim-len(strides)) + list(strides)
  dilations_ = [1]*(W.ndim-1) + [dilations] if isinstance(dilations, int) else [1]*(W.ndim-len(dilations)) + list(dilations)
  # when an explicit output_shape is requested, derive output_padding as the gap to the size computed here
  if output_shape and not output_padding:
    out_sh = [st*(xs-1) + (ks-1)*di+1 if n < 2 else st*(xs-1) + (ks-1)*di+1 - pads[n-2] - pads[n-1] for n, (st, xs, ks, di) in enumerate(zip(strides_, X.shape, kernel_shape, dilations_))]
    output_padding = [os - rs for os, rs in zip(output_shape, out_sh[-len(output_shape):])]
  return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads if pads is not None else 0, output_padding=output_padding)
# Reimplemented here because you need legacy RNG for passing ONNX tests.
def Dropout(data: Tensor, ratio=0.5, training_mode=False, seed=None):
  """ONNX Dropout: returns (output, mask), with the mask drawn from numpy's RandomState(seed)."""
  # ratio and tensor is passed in as Tensor with shape: ()
  if isinstance(ratio, Tensor) and not ratio.shape:
    ratio = safe_numpy(ratio)
  if isinstance(training_mode, Tensor) and not training_mode.shape:
    training_mode = safe_numpy(training_mode)
  if not training_mode:
    # if mask is requested as output it will contain all True's.
    return data, Tensor.ones(*data.shape, dtype=dtypes.bool)
  rng = np.random.RandomState(seed)
  if isinstance(ratio, Tensor):
    ratio = ratio.lazydata.realize().toCPU()[0]
  keep = rng.random(data.shape) >= ratio
  mask = Tensor(keep, requires_grad=False, device=data.device)
  scale = 1/(1.0 - ratio)
  return data * mask * scale, mask
def LRN(input: Tensor, size, alpha=1e-4, beta=0.75, bias=1.0):
  """ONNX LRN for 4-D input: divide by (bias + alpha * windowed mean of squares over channels) ** beta."""
  bs, c, iy, ix = input.shape
  squares = input.mul(input).reshape(bs,1,c,iy*ix)
  # pad the channel axis so a window of `size` is centred on every channel
  windowed = squares.pad2d((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1)
  denom = windowed.reshape(bs,c,iy,ix).mul(alpha).add(bias).pow(beta)
  return input / denom
def MeanVarianceNormalization(input: Tensor, axis=(0, 2, 3)):
  """ONNX MeanVarianceNormalization: (input - mean) / (std + 1e-9) over `axis`."""
  data_mean = input.mean(axis=axis, keepdim=True)
  mean_of_squares = (input**2).mean(axis=axis, keepdim=True)
  std = (mean_of_squares - data_mean**2).sqrt()
  centered = input - data_mean
  return centered / (std + 1e-9)
def NegativeLogLikelihoodLoss(input: Tensor, target: Tensor, weight=None, ignore_index=None, reduction="mean"):
  """ONNX NegativeLogLikelihoodLoss with optional class weights and ignore_index.

  reduction "mean" / "sum" return a reduced loss; any other value returns the
  unreduced loss (reshaped to the original target shape when input was not 3-D).
  """
  target = target.cast(dtypes.float32)
  N, C, i_shape = input.shape[0], input.shape[1], input.shape
  t_shape = target.shape
  # work on (N, C, -1) / (N, -1) so the remaining dims are handled as one
  if len(input.shape) != 3:
    input = input.reshape((N, C, -1))
    target = target.reshape((N, -1))
  if weight is not None:
    # pick, per target element, the weight of its class
    mask = target.unsqueeze(-1) == Tensor.arange(C).repeat((N, 1, 1))
    weight = (mask * weight).sum(axis=-1)
  if ignore_index is not None:
    # ignored targets get weight 0
    cond = target == ignore_index
    weight = cond.where(0, weight) if weight is not None else cond.where(Tensor.zeros(*target.shape), 1)
  # mask selecting, along the class axis, the entry of each target class
  mask = target[:, None, :] == Tensor.arange(C).reshape([1, C] + [1]*(len(input.shape) -2))
  loss = (-mask * input).sum(axis=1) * (1 if weight is None else weight)
  if reduction == "mean": return loss.mean() if weight is None else loss.sum() / weight.sum()
  elif reduction == "sum": return loss.sum()
  return loss.reshape(t_shape) if len(i_shape) != 3 else loss
def SoftmaxCrossEntropyLoss(scores: Tensor, labels: Tensor, weights=None, ignore_index=None, reduction="mean"):
  """ONNX SoftmaxCrossEntropyLoss: returns (loss, log_softmax of scores along axis 1)."""
  N, C, *s_dimensions = scores.shape
  # ignored labels are remapped to C+1, which matches no entry of arange(C) below
  if ignore_index is not None: labels = (labels == ignore_index).where(C+1, labels)
  mask = labels.unsqueeze(1) == Tensor.arange(C).reshape(1, C, *[1]*len(s_dimensions))
  y = scores.log_softmax(axis=1)
  # index the weights by label
  if weights is not None: weights = weights.__getitem__(tuple([labels, *[slice(None)]*(weights.ndim-1)]))
  loss = (mask * -y).sum(1) if weights is None else (mask * -y).sum(1) * weights
  # unweighted mean divides by the number of non-zero loss entries
  if reduction == "mean": loss = loss.sum() / (loss == 0).where(0, 1).sum() if weights is None else loss.sum() / weights.sum()
  elif reduction == "sum": loss = loss.sum()
  return loss, y
def ArrayFeatureExtractor(input: Tensor, indices: Tensor):
  """ArrayFeatureExtractor: index the last axis of input with `indices`, full slices elsewhere."""
  last = input.ndim-1
  selector = tuple([slice(None) if i != last else indices for i in range(input.ndim)])
  return input.__getitem__(selector)
def Gather(input: Tensor, indices: Tensor, axis=0):
  """ONNX Gather along `axis`, with a shrink+cat path for fewer than 9 indices and fancy indexing otherwise."""
  if indices.numel() >= 9:
    # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot
    return input.__getitem__(tuple([slice(None) if i != axis else indices for i in range(input.ndim)]))
  # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices
  input_sh = list(input.shape)
  ret_shape = input_sh[:axis] + list(indices.shape) + input_sh[axis+1:]
  if indices.ndim > 1: indices = indices.flatten()
  if indices.shape == ():
    indices = [int(safe_numpy(indices))]
  else:
    # negative indices count from the end of the gathered axis
    indices = [input_sh[axis]+int(x) if x<0 else int(x) for x in safe_numpy(indices)]
  # one shrink per index: the full range on every axis except a width-1 window on `axis`
  args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(input_sh)] for i in indices]
  head = input.shrink(arg=tuple(args[0]))
  tail = [input.shrink(arg=tuple(arg)) for arg in args[1:]]
  return head.cat(*tail, dim=axis).reshape(ret_shape)
def GatherElements(input: Tensor, indices: Tensor, axis):
  """ONNX GatherElements: Tensor.gather after shifting negative indices by the axis length."""
  # relu(-sign(indices)) is 1 exactly where an index is negative
  is_negative = indices.sign().contiguous().__neg__().contiguous().relu()
  indices = is_negative * input.shape[axis] + indices
  return input.gather(indices, axis)
def _round(x:Tensor, n:float, equidistant_case = "round_down") -> Tensor:
  """Round x around the threshold n, with a selectable rule for values exactly at the threshold.

  equidistant_case is one of "round_down", "round_up", "round_to_even".
  NOTE(review): any other value falls through every branch and returns None.
  """
  # elementwise logical AND of two 0/1 conditions
  def _and(cond1, cond2): return ((cond1 + cond2) == 2).where(1, 0)
  assert n <= 1, f"n:{n} shouldn't be larger than 1"
  # truncate through an int32 cast, then move the threshold n away from zero
  b = x.cast(dtypes.int32).contiguous().cast(x.dtype)
  b = (b >= 0).where(b+n, b-n)
  if equidistant_case == "round_down":
    return (x > b).where(b+1-n, b-n)
  elif equidistant_case == "round_up":
    return (x >= b).where(b+1-n, b-n)
  elif equidistant_case == "round_to_even":
    # ceil(x)/2 is an integer exactly when ceil(x) is even
    x_ceil_fraction = x.ceil()/2
    cond_ceil_even = x_ceil_fraction.ceil() == x_ceil_fraction
    # nudge values sitting exactly on the threshold upward when their ceil is even
    x = (_and(x == b, cond_ceil_even)).where(x+1-n, x)
    x = (x > b).where(b+1-n, b-n)
    return x
def Round(X:Tensor):
  """ONNX Round: _round with threshold 0.5 and the "round_to_even" tie rule."""
  rounded = _round(X, 0.5, "round_to_even")
  return rounded
def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None, coordinate_transformation_mode='half_pixel', cubic_coeff_a=-0.75, exclude_outside=0, extrapolation_value=0.0, keep_aspect_ratio_policy='stretch', mode='nearest', nearest_mode='round_prefer_floor'):
  """ONNX Resize over the last two axes of X, for mode "nearest" and "linear".

  The output size comes from `scales` or, failing that, from `sizes`
  (optionally adjusted by keep_aspect_ratio_policy). Output pixel coordinates
  are mapped back to input coordinates per coordinate_transformation_mode.
  In nearest mode the x coordinates are clipped against the width
  (X.shape[-1]) and the y coordinates against the height (X.shape[-2]).
  antialias, cubic_coeff_a, exclude_outside and extrapolation_value are
  accepted but unused.
  Raises Exception for mode "cubic" and ValueError for an unknown nearest_mode.
  """
  def _nearest_gather(X: Tensor, x_out, y_out): return X[:,:,y_out,:][:,:,:,x_out]
  def _nearest_mode(x_resized: Tensor, nearest_mode: str, x_len):
    if nearest_mode == "round_prefer_floor": ret = _round(x_resized, 0.5, "round_down")
    elif nearest_mode == "round_prefer_ceil": ret = _round(x_resized, 0.5, "round_up")
    elif nearest_mode == "floor": ret = x_resized.floor()
    elif nearest_mode == "ceil": ret = x_resized.ceil()
    else: raise ValueError(f"unsupported nearest_mode: {nearest_mode}")
    return ret.clip(0, x_len-1)
  def _coordinate_transformation(x_out, y_out, output_shape, scales_lol, roi=None):
    if coordinate_transformation_mode == "half_pixel":
      x_out = (x_out + 0.5)/Tensor(scales_lol[-1]) - 0.5 # TODO Tensor() because try (((Tensor([0,1,2,3,4,5])+0.5)/3.5 - 0.5)) with LLVM or METAL, inaccuracy.
      y_out = (y_out + 0.5)/Tensor(scales_lol[-2]) - 0.5
    elif coordinate_transformation_mode == "align_corners":
      x_out = x_out * (X.shape[-1] - 1) / (output_shape[-1] - 1)
      y_out = y_out * (X.shape[-2] - 1) / (output_shape[-2] - 1)
    elif coordinate_transformation_mode == "asymmetric":
      x_out = x_out/scales_lol[-1]
      y_out = y_out/scales_lol[-2]
    elif coordinate_transformation_mode == "half_pixel_symmetric":
      x_out = X.shape[-1] / 2 * (1 - int(output_shape[-1]) / output_shape[-1]) + (x_out + 0.5) / scales_lol[-1] - 0.5
      y_out = X.shape[-2] / 2 * (1 - int(output_shape[-2]) / output_shape[-2]) + (y_out + 0.5) / scales_lol[-2] - 0.5
    elif coordinate_transformation_mode == "pytorch_half_pixel":
      x_out = (x_out + 0.5)/scales_lol[-1] - 0.5 if output_shape[-1] > 1 else Tensor([0])
      y_out = (y_out + 0.5)/scales_lol[-2] - 0.5 if output_shape[-2] > 1 else Tensor([0])
    elif coordinate_transformation_mode == "tf_crop_and_resize":
      x_out = roi[-1][0] * (X.shape[-1] - 1) + x_out * ((roi[-1][1] - roi[-1][0]) * (X.shape[-1] - 1) / (output_shape[-1] - 1)) if output_shape[-1] > 1 else Tensor([0.5 * (roi[-1][0] + roi[-1][1]) * (X.shape[-1] - 1)])
      y_out = roi[-2][0] * (X.shape[-2] - 1) + y_out * ((roi[-2][1] - roi[-2][0]) * (X.shape[-2] - 1) / (output_shape[-2] - 1)) if output_shape[-2] > 1 else Tensor([0.5 * (roi[-2][0] + roi[-2][1]) * (X.shape[-2] - 1)])
    # keep the mapped coordinates inside the input
    return x_out.clip(0, X.shape[-1]-1), y_out.clip(0, X.shape[-2]-1)
  if roi is not None:
    roi = safe_numpy(roi)
    # ONNX roi is [starts..., ends...]; turn it into (start, end) pairs
    roi = [(st,ed) for st, ed in zip(roi[:len(roi)//2], roi[len(roi)//2:])]
    roi_ = [(1,1)] * 4
    if axes is not None:
      for a,r in zip(axes, roi):
        roi_[a] = r
      roi = roi_
  if scales is not None:
    scales = safe_numpy(scales).tolist()
    if axes is not None:
      # spread the given scales onto their axes; every other axis keeps scale 1
      scales_ = [1]*X.ndim
      for a,s in zip(axes, scales):
        scales_[a] = s
      scales = scales_
  elif sizes is not None:
    sizes = [int(i) for i in safe_numpy(sizes)]
    scales = []
    if axes is not None:
      sizes_ = [1]*X.ndim
      for a,s in zip(axes, sizes):
        sizes_[a] = s
        scales.append(s/X.shape[a])
      sizes = sizes_
    else: scales = [si/xs for xs, si in zip(X.shape, sizes)]
    if keep_aspect_ratio_policy == "not_larger":
      scale = min(scales)
      sizes = _round(Tensor(list(X.shape[-2:]))*scale, 0.5, "round_up")
      sizes = list(X.shape[:-2]) + [int(i) for i in safe_numpy(sizes)]
    elif keep_aspect_ratio_policy == "not_smaller":
      scale = max(scales)
      sizes = _round(Tensor(list(X.shape[-2:]))*scale, 0.5, "round_up")
      sizes = list(X.shape[:-2]) + [int(i) for i in safe_numpy(sizes)]
  output_shape = sizes if sizes else [math.floor(x*s) for x,s in zip(X.shape, scales)]
  # same as output_shape but without flooring; used by the linear path
  output_shape_ = sizes if sizes else [x*s for x,s in zip(X.shape, scales)]
  scales_lol = [os/xs for xs, os in zip(X.shape, output_shape)]
  x_out = Tensor.arange(output_shape[-1])
  y_out = Tensor.arange(output_shape[-2])
  if mode == "nearest":
    x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape, scales_lol, roi)
    x_out = _nearest_mode(x_out, nearest_mode, X.shape[-1])
    # y indexes the height axis, so it is bounded by X.shape[-2]
    y_out = _nearest_mode(y_out, nearest_mode, X.shape[-2])
    return _nearest_gather(X, x_out, y_out)
  elif mode == "linear":
    x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape_, scales, roi)
    ret = []
    # bilinear interpolation, one output pixel at a time
    for y in safe_numpy(y_out):
      for x in safe_numpy(x_out):
        x_floor, y_floor = int(x), int(y)
        y_shrink = (0, X.shape[2]) if X.shape[2] == 1 else (y_floor, y_floor+2) if y != y_floor else (y_floor, y_floor+1)
        x_shrink = (x_floor, x_floor+2) if x != x_floor else (x_floor, x_floor+1)
        shrink_args = ((0, X.shape[0]), (0, X.shape[1]), y_shrink, x_shrink)
        corners = safe_numpy(X.shrink(shrink_args))
        x1, x2, y1, y2 = x_floor, x_floor+1, y_floor, y_floor+1
        if x == x_floor and y == y_floor: # TODO https://en.wikipedia.org/wiki/Bilinear_interpolation#Weighted_mean maybe do weighted mean?
          ret.append(corners[0,0,0,0])
        elif x == x_floor:
          ret.append((corners[0,0,0,0] * (y2 - y) + corners[0,0,1,0] * (y - y1)) / (y2 - y1))
        elif y == y_floor:
          ret.append((corners[0,0,0,0] * (x2 - x) + corners[0,0,0,1] * (x - x1)) / (x2 - x1))
        else:
          ret.append((corners[0,0,0,0] * (x2 - x) * (y2 - y) + corners[0,0,0,1] * (x - x1) * (y2 - y) + corners[0,0,1,0] * (x2 - x) * (y - y1) + corners[0,0,1,1] * (x - x1) * (y - y1)) / ((x2 - x1) * (y2 - y1)))
    return Tensor(ret).reshape(output_shape)
  elif mode == "cubic":
    raise Exception("cubic interpolation is not implemented")
def CenterCropPad(input: Tensor, shape: Tensor, axes=None):
  """Crop and/or pad `input` around its center so the selected axes get the requested sizes.

  Args:
    input: tensor to crop or pad.
    shape: tensor holding the target size for every axis listed in `axes`.
    axes: axes that `shape` refers to; all axes of `input` when omitted or empty.

  Returns:
    `input` shrunk on the axes that are longer than requested and padded on the shorter ones.
  """
  if not axes: axes = list(range(input.ndim))
  shrink_arg = [(0, dim) for dim in input.shape]
  pad_arg = [(0, 0)] * input.ndim
  for target, axis in zip(safe_numpy(shape).tolist(), axes):
    target, dim = int(target), input.shape[axis]
    if target < dim:
      # centered window; when the surplus is odd the extra element is dropped at the end
      start = (dim - target) // 2
      shrink_arg[axis] = (start, start + target)
    elif target > dim:
      # centered padding; when the deficit is odd the extra element goes at the end
      before = (target - dim) // 2
      pad_arg[axis] = (before, target - dim - before)
  return input.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))
def OneHot(indices: Tensor, depth: Tensor, values: Tensor, axis=-1):
  """One-hot encode `indices` along a new axis of length `depth`.

  Args:
    indices: tensor of class indices; negative entries are shifted by `depth`.
    depth: tensor holding the number of classes (read out as a Python int).
    values: tensor indexed as `values[0]` (off value) and `values[1]` (on value).
    axis: position of the new axis in the output; negative values count from the end
      of the output shape.

  Returns:
    Tensor cast to `values.dtype` holding `values[1]` where the index matches and `values[0]` elsewhere.
  """
  depth = int(safe_numpy(depth).item())
  rank = len(indices.shape)
  indices = (indices < 0).where(indices+depth, indices)
  if axis < 0: axis += rank + 1
  ls, rs = indices.shape[0:axis], indices.shape[axis:rank]
  # the new axis must be inserted at `axis` so it lines up with the `depth` dimension of the arange
  classes = Tensor.arange(depth).reshape((1,) * len(ls) + (depth,) + (1,) * len(rs))
  cond = indices.unsqueeze(axis) == classes
  return cond.where(values[1], values[0]).cast(values.dtype)
def Erf(x: Tensor):
  """Approximate the error function elementwise.

  Evaluates a degree-5 polynomial in t = 1 / (1 + 0.3275911 * |x|), multiplies it by
  exp(-x^2) and restores the sign of the input (the coefficients match the
  Abramowitz-Stegun 7.1.26 approximation).
  """
  sgn, mag = x.sign(), x.abs()
  t = 1.0 / (1.0 + 0.3275911 * mag)
  # accumulate the terms in increasing power of t
  poly = 0.254829592 * t
  for power, coeff in enumerate((-0.284496736, 1.421413741, -1.453152027, 1.061405429), start=2):
    poly = poly + coeff * t ** power
  return sgn * (1.0 - poly * (-mag * mag).exp())
def Compress(inp: Tensor, condition: Tensor, axis=None):
  """Select the slices of `inp` along `axis` that `condition` picks.

  Args:
    inp: tensor to select from; flattened first when `axis` is None.
    condition: 1D mask, read out through safe_numpy and used to index an arange
      (presumably boolean -- confirm against the caller).
    axis: axis to select along; negative values count from the end.
  """
  if axis is None:
    inp, axis = inp.flatten(), 0
  if axis < 0: axis += inp.ndim
  # no boolean indexing in Tensor, so turn the mask into the positions it selects
  mask = safe_numpy(condition)
  picked = Tensor(np.arange(condition.shape[0])[mask])
  index = tuple(picked if i == axis else slice(None) for i in range(inp.ndim))
  return inp[index]
# ONNX TensorProto data types that EyeLike accepts for its `dtype` attribute
type_map = {TensorProto.DOUBLE: dtypes.double, TensorProto.FLOAT: dtypes.float32}
def EyeLike(x: Tensor, dtype=None, k=0):
  """Return a matrix shaped like the 2D tensor `x` with ones on diagonal `k`.

  Args:
    x: 2D tensor; only its shape and, when `dtype` is omitted, its dtype are used.
    dtype: TensorProto data type of the result, looked up in `type_map`.
    k: diagonal to populate: 0 is the main diagonal, positive values lie above it,
      negative values below it.

  Raises:
    KeyError: if `dtype` is given but is not a key of `type_map`.
  """
  dtype = x.dtype if dtype is None else type_map[dtype]
  rows, cols = x.shape[0], x.shape[1]
  # Cut the requested window out of a larger identity matrix: starting the window `k` rows down
  # (or `-k` columns right) moves the identity's main diagonal onto diagonal `k` of the window.
  # This handles square and rectangular shapes and both signs of `k` uniformly.
  size = max(rows, cols) + abs(k)
  row_start, col_start = (k, 0) if k >= 0 else (0, -k)
  window = ((row_start, row_start + rows), (col_start, col_start + cols))
  return Tensor.eye(dim=size, dtype=dtype).shrink(window)
def Upsample(X, scales, mode):
  """Forward `X`, `scales` and `mode` to Resize as keyword arguments."""
  return Resize(X=X, scales=scales, mode=mode)
# Needs work
def IsInf(x,detect_negative=1,detect_positive=1):
  """Return a bool tensor marking the infinite entries of `x`.

  `detect_positive` and `detect_negative` are used as multipliers on the +inf and -inf
  matches, so passing 0 for one of them disables that sign.
  """
  pos_hits = (x == float("inf"))*detect_positive
  neg_hits = (x == float("-inf"))*detect_negative
  hits = pos_hits + neg_hits + Tensor.zeros(*x.shape)
  return hits.cast(dtypes.bool)
# Needs work
def DequantizeLinear(x: Tensor, x_scale: Tensor, x_zero_point=0, axis=1):
  """Compute `(x - x_zero_point) * x_scale` with scale and zero point aligned to `axis` of `x`.

  Args:
    x: quantized tensor.
    x_scale: scale tensor; its dimensions are placed starting at `axis`.
    x_zero_point: Tensor reshaped the same way as `x_scale`, or a plain number used as is.
    axis: axis of `x` the scale applies to; negative values count from the end.
  """
  if axis < 0: axis += x.ndim
  # shape that puts x_scale's dims at `axis` and 1 everywhere else, so it broadcasts against x
  bshape = [1]*axis + list(x_scale.shape) + [1]*(x.ndim - axis - x_scale.ndim)
  scale = x_scale.reshape(*bshape)
  zero = x_zero_point.reshape(*bshape) if isinstance(x_zero_point, Tensor) else x_zero_point
  return (x - zero) * scale
# Needs work
def IsNaN(x):
  """Return a bool tensor that is True where `x` is NaN.

  Every ordered comparison against NaN is false, so NaN is the only value that is
  neither below 1 nor above 0. Only `<` is used, which the rest of this file already
  relies on. NOTE(review): assumes the backend compares floats per IEEE 754 -- confirm
  for kernels compiled with relaxed math.
  """
  # `-x < 0` is `x > 0`; the sum is 0 only when both comparisons fail
  ordered = (x < 1) + (-x < 0)
  return (ordered < 1).cast(dtypes.bool)
# **************** com.microsoft Ops ****************
def SkipLayerNormalization(input:Tensor, skip:Tensor, gamma, beta:Optional[Tensor]=None, bias:Optional[Tensor]=None, epsilon=None):
  """Layer-normalize `input + skip (+ bias)` and apply `gamma` (and `beta`).

  Args:
    input, skip: tensors that are summed before normalization.
    gamma: scale applied to the normalized sum.
    beta: optional shift added after scaling.
    bias: optional bias added to the sum before normalization.
    epsilon: layernorm epsilon; 1e-12 when None.

  Returns:
    `(output, None, None, summed_input)` where `summed_input` is the pre-normalization sum.
  """
  if epsilon is None: epsilon=1e-12
  x = input + skip
  # bias and beta are optional inputs: only apply the ones that were provided
  if bias is not None: x = x + bias
  out = x.layernorm(eps=epsilon) * gamma
  if beta is not None: out = out + beta
  return out, None, None, x
def FastGelu(x:Tensor, bias:Optional[Tensor]=None):
  """Apply the tanh-based GELU approximation to `x` (plus `bias` when given).

  Computes 0.5 * x * (1 + tanh(0.797885 * x + 0.035677 * x^3)).
  """
  # bias is an optional input: skip the add when it is absent
  if bias is not None: x = x + bias
  return 0.5 * x * (1 + (x * 0.797885 + 0.035677 * x ** 3).tanh())
def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Optional[Tensor]=None, word_embedding:Tensor=None, position_embedding:Tensor=None, segment_embedding:Optional[Tensor]=None, gamma=None, beta=None, mask:Optional[Tensor]=None, position_ids:Optional[Tensor]=None, epsilon=None, mask_index_type=None):
  """Sum word, position and (optional) segment embeddings and layer-normalize the result.

  Args:
    input_ids: token ids; the code indexes its shape as (dim0, sequence length).
    segment_ids, segment_embedding: optional pair; both must be given or both omitted.
    word_embedding, position_embedding: embedding tables, one row per id.
    gamma, beta: scale and shift applied after layernorm.
    mask, mask_index_type: not supported yet; `mask` must be None.
    position_ids: positions to look up; defaults to 0..sequence_length-1 for every row.
    epsilon: layernorm epsilon; 1e-12 when None.

  Returns:
    `(output, None, embedding_sum)`.
  """
  # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization
  assert (segment_ids is None) is (segment_embedding is None)
  assert (mask is None) is (mask_index_type is None)
  assert mask is None, "functionality not supported yet" # TODO
  input_shape = input_ids.shape
  seq_length = input_shape[1]
  compute_seg_emb = (segment_embedding is not None and segment_ids is not None)
  vocab_size, max_position_embeddings = word_embedding.shape[0], position_embedding.shape[0]

  def embedding(x:Tensor, vocab_size, weight:Tensor)->Tensor: # TODO from nn.Embedding. Could probably upstream this to Tensor
    # one-hot compare the ids against 0..vocab_size-1, then matmul with the table to pick rows
    vocab_counter = Tensor.arange(vocab_size, dtype=x.dtype, requires_grad=False).reshape(1, 1, vocab_size).expand(*x.shape, vocab_size)
    return (vocab_counter == x.unsqueeze(2).expand(*x.shape, vocab_size)) @ weight

  # bert embedding layer
  if epsilon is None: epsilon = 1e-12
  if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
  wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
  pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding)
  embedding_sum = wrd_embedding_res + pos_embedding_res
  # segment embeddings are optional: only add them when both the ids and the table were given
  if compute_seg_emb:
    embedding_sum = embedding_sum + embedding(segment_ids, segment_embedding.shape[0], segment_embedding)
  out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
  return out, None, embedding_sum
def Attention(input:Tensor, weights, bias:Optional[Tensor]=None, mask_index:Optional[Tensor]=None, past:Optional[Tensor]=None, relative_position_bias:Optional[Tensor]=None, past_sequence_length:Optional[Tensor]=None, do_rotary=None, mask_filter_value=None, num_heads=None, past_present_share_buffer=None, qkv_hidden_sizes=None, scale=None, unidirectional=None):
  """Multi-head self-attention over `input` with merged Q/K/V projection weights.

  Args:
    input: tensor to attend over; the projections are reshaped as (batch, sequence, heads, head size).
    weights: merged projection weights laid out as [Q | K | V] along the last axis.
    bias: optional merged projection bias with the same [Q | K | V] layout.
    mask_index: optional tensor added to the attention scores before the softmax.
    past: optional pair (`past[0]` keys, `past[1]` values) prepended along the sequence axis.
    num_heads: number of attention heads (required).
    qkv_hidden_sizes: optional (q, k, v) hidden sizes; k and v are read from it and q is
      assumed equal to k. When None, each of Q, K and V takes a third of `weights`.
    unidirectional: truthy selects the single fused projection path (gpt-style).
    relative_position_bias, past_sequence_length, do_rotary, mask_filter_value,
      past_present_share_buffer, scale: not supported yet and must be None.

  Returns:
    `(out, present)` where `present` stacks the keys and values that were used.
  """
  # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
  assert num_heads is not None # required
  assert (qkv_hidden_sizes is None and past is not None) or (qkv_hidden_sizes is not None)
  # identity checks instead of a chained `==`, which would invoke Tensor.__eq__ on tensor arguments
  unsupported = (relative_position_bias, do_rotary, past_sequence_length, mask_filter_value, past_present_share_buffer, scale)
  assert all(arg is None for arg in unsupported), "functionality not supported yet" # TODO strange params
  hidden_size, v_hidden_size = qkv_hidden_sizes[1:] if qkv_hidden_sizes is not None else 2*(weights.shape[1] // 3,)

  if unidirectional: # gpt-style
    assert hidden_size == v_hidden_size
    xqkv = input.linear(weights, bias)
    xq, xk, xv = [xqkv.slice([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)]
  else: # bert-style
    # Q and K are both `hidden_size` wide, so V starts at 2*hidden_size
    k_end = 2*hidden_size
    wq, wk, wv = weights[:,:hidden_size], weights[:,hidden_size:k_end], weights[:,k_end:]
    # the V bias is the trailing slice (not a single element); without a bias every projection gets None
    bq, bk, bv = (bias[:hidden_size], bias[hidden_size:k_end], bias[k_end:]) if bias is not None else (None, None, None)
    xq, xk, xv = [input.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))]
  xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2) for x in (xq, xk, xv)]

  if past is not None:
    xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2)
    present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0))

  def attn(query, key, value, attn_mask):
    query_length, key_length = query.shape[-2], key.shape[-2]
    cdim = max(query_length, key_length) + 1
    # scale by the query/key head size (equal to the value head size whenever hidden_size == v_hidden_size)
    attn_weights = query @ key.transpose(-1, -2) / math.sqrt(query.shape[-1])
    # This is where Tensor.scaled_dot_product_attention differs:
    # NOTE(review): the causal mask is applied regardless of `unidirectional` -- confirm this is intended
    causal_mask = Tensor.ones((cdim, cdim), requires_grad=False).cast(dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length].cast(dtypes.bool)
    scores = Tensor.where(causal_mask, attn_weights, -float("inf"))
    # mask_index is optional: only add it when it was provided
    if attn_mask is not None: scores = scores + attn_mask
    return scores.softmax(-1) @ value

  bsz, _, seq_len, _ = xq.shape
  out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
  return out, present
# **************** ai.onnx.preview.training Ops ****************
# TODO not entirely sure these optimizers are correct
def Adagrad(R, T, *inputs, decay_factor=0.0, epsilon=0.0, norm_coefficient=0.0):
  """Run one Adagrad step on every (X, G, H) group and update X and H in place.

  Args:
    R: learning-rate tensor; its first element is used.
    T: update-count tensor; its first element is used.
    *inputs: all X tensors, then all G (gradient) tensors, then all H (accumulated
      squared gradient) tensors.
    decay_factor: the learning rate is divided by `1 + T * decay_factor`.
    epsilon: added to sqrt(H) before dividing.
    norm_coefficient: `norm_coefficient * X` is added to the gradient.

  Returns:
    Tuple of all updated X tensors followed by all updated H tensors.
  """
  n_groups = len(inputs) // 3
  T, R = safe_numpy(T)[0], safe_numpy(R)[0]
  lr = R / (1 + T * decay_factor)
  new_xs, new_hs = [], []
  for i in range(n_groups):
    # a stride of n_groups picks the i-th X, G and H out of the flat input list
    X, G, H = inputs[i::n_groups]
    X.grad = norm_coefficient * X + G
    X.grad.requires_grad, H.requires_grad = False, False # TODO manually turning off requires_grad, see TODO under (domain == "ai.onnx.preview.training") in onnx.py
    H.assign(H.detach() + X.grad * X.grad).realize()
    X.assign(X.detach() - lr * X.grad / (H.sqrt() + epsilon))
    new_xs.append(X)
    new_hs.append(H)
  return tuple(new_xs + new_hs)
def Momentum(R, T, *inputs, alpha, beta, mode, norm_coefficient):
  """Run one momentum-SGD step on every (X, G, V) group and update X and V in place.

  Args:
    R: learning-rate tensor; its first element is used.
    T: update-count tensor; its first element is used. `beta` is replaced by 1 while T is 0.
    *inputs: all X tensors, then all G (gradient) tensors, then all V (momentum) tensors.
    alpha: decay applied to the previous momentum.
    beta: scale applied to the regularized gradient.
    mode: "standard" or "nesterov"; any other value leaves X unchanged.
    norm_coefficient: `norm_coefficient * X` is added to the gradient.

  Returns:
    Tuple of all updated X tensors followed by all updated V tensors.
  """
  groups = len(inputs) // 3
  grouped_inputs = [inputs[i::groups] for i in range(groups)]
  T, R = safe_numpy(T)[0], safe_numpy(R)[0]
  beta_adjusted = beta if T > 0 else 1
  new_xs, new_vs = [], []
  for X, G, V in grouped_inputs:
    X.grad = (norm_coefficient * X + G).realize()
    X.grad.requires_grad, V.requires_grad = False, False
    V.assign(alpha * V + beta_adjusted * X.grad).realize()
    if mode == "standard": X.assign(X.detach() - R * V).realize()
    # nesterov looks ahead along the decayed new momentum: gradient + alpha * V (a product, not a sum)
    elif mode == "nesterov": X.assign(X.detach() - R * (X.grad + alpha * V)).realize()
    new_xs.append(X)
    new_vs.append(V)
  return tuple(new_xs + new_vs)
# copied from tinygrad/nn/optim.py: LAMB with some edits
def Adam(R, T, *inputs, alpha=0.9, beta=0.999, epsilon=0.0, norm_coefficient=0.0, norm_coefficient_post=0.0):
  """Run one Adam step on every (X, G, V, H) group; V and H are updated in place.

  Args:
    R: learning-rate tensor; its first element is used.
    T: update-count tensor; its first element is used. Bias correction is applied only when T > 0.
    *inputs: all X tensors, then all G (gradient), all V (first moment) and all H
      (second moment) tensors.
    alpha: decay of the first moment.
    beta: decay of the second moment.
    epsilon: added to sqrt(H) before dividing.
    norm_coefficient: `norm_coefficient * X` is added to the gradient.
    norm_coefficient_post: the updated X is scaled by `1 - norm_coefficient_post`.

  Returns:
    Tuple of all new X tensors, then all V tensors, then all H tensors.
  """
  groups = len(inputs) // 4
  grouped_inputs = [inputs[i::groups] for i in range(groups)]
  T, R = safe_numpy(T)[0], safe_numpy(R)[0]
  # The ONNX definition folds the bias correction into the learning rate,
  # R * sqrt(1 - beta**T) / (1 - alpha**T), so epsilon is added to the uncorrected sqrt(H).
  bias_correction = float(math.sqrt(1.0 - beta**T) / (1.0 - alpha**T)) if T > 0 else 1.0
  new_xs, new_vs, new_hs = [], [], []
  for X, G, V, H in grouped_inputs:
    X.grad = (norm_coefficient * X + G).realize()
    V.requires_grad, H.requires_grad, X.grad.requires_grad = False, False, False
    V.assign(alpha * V + (1.0 - alpha) * X.grad).realize()
    H.assign(beta * H + (1.0 - beta) * (X.grad * X.grad)).realize()
    up = V / (H.sqrt() + epsilon)
    if T > 0: up = up * bias_correction
    X.assign(X.detach() - R * up).realize()
    X = (1 - norm_coefficient_post) * X
    new_xs.append(X)
    new_vs.append(V)
    new_hs.append(H)
  return tuple(new_xs + new_vs + new_hs)