import functools, io, math
from typing import Union, Tuple, Optional, List, Any, cast
from tinygrad.tensor import Tensor, _broadcast_shape, ConstType
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.helpers import prod, flatten
from extra.onnx import dtype_parse, to_python_const
import numpy as np

# **************** Free Ops ****************

def Identity(x:Tensor): return x
# TODO: fix buffer_parse
def Add(x:Tensor, other:Tensor, broadcast=None, axis=None): return x + other if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + other).cast(x.dtype)
def Sub(x:Union[Tensor, Any], other:Tensor): return x - other # some test has input as int
def Less(x:Tensor,y:Tensor): return x < y
def LessOrEqual(x:Tensor,y:Tensor): return x <= y
def Greater(x:Tensor,y:Tensor): return x > y
def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y
def Equal(x:Tensor,y:Tensor): return x == y
def BitwiseNot(x:Tensor): return ~x
def BitwiseOr(x:Tensor, y:Tensor): return x | y
def BitwiseAnd(x:Tensor, y:Tensor): return x & y
def BitwiseXor(x:Tensor, y:Tensor): return x ^ y
def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0)
def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0)
def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0)
def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0)
# NOTE: does not support saturate
def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to))
def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype)

# **************** Simple Ops ****************

# https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_div.py
def Div(x:Tensor, other:Tensor): return (x/other).cast(x.dtype)

def Constant(sparse_value:Optional[Tensor]=None, value:Optional[Tensor]=None, value_float:Optional[float]=None,
             value_floats:Optional[List[float]]=None, value_int:Optional[int]=None, value_ints:Optional[List[int]]=None,
             value_string:Optional[str]=None, value_strings:Optional[List[str]]=None):
  if value is not None: return value
  if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
  if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
  if value_int is not None: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False)
  if value_ints is not None: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False)
  if value_string is not None or value_strings is not None or sparse_value is not None:
    raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value')

def HardSigmoid(x:Tensor, alpha:float=0.2, beta:float=0.5): return (alpha*x + beta).clip(0, 1)
def Gelu(x:Tensor, approximate:Optional[str]=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf())
# TODO: fix this
def PRelu(X:Tensor, slope:Tensor):
  slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
  return (X > 0).where(X, X * slope)
def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leakyrelu(alpha)
def ThresholdedRelu(X:Tensor, alpha:float=1.0): return (X > alpha).where(X, 0)
def Softmax_1(x:Tensor, axis:int=1): return x.softmax(axis)
def Softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis)
Softmax = {1: Softmax_1, 13: Softmax_13} # Softmax default axis changed
def LogSoftmax(x: Tensor, axis:int=-1): return x.log_softmax(axis)
def Clip(x: Tensor, min:Optional[Tensor]=None, max:Optional[Tensor]=None):
  return x.clip(float('-inf') if min is None else min, float('inf') if max is None else max).cast(x.dtype)

def _axes(axes, noop_with_empty_axes): return axes or ([] if noop_with_empty_axes else None)
def ReduceMax(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMin(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSum(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMean(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSumSquare(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return ReduceSum(data.square(), axes, keepdims, noop_with_empty_axes)
def ReduceProd(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return data.prod(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceL1(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return ReduceSum(data.abs(), axes, keepdims, noop_with_empty_axes)
def ReduceL2(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return ReduceSumSquare(data, axes, keepdims, noop_with_empty_axes).sqrt()
def ReduceLogSum(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log()
def ReduceLogSumExp(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log()

def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True)
def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True)
def OptionalHasElement(x:Optional[Tensor]=None): return Tensor(x is not None and x.numel() > 0)
def OptionalGetElement(x:Optional[Tensor]=None): return x if x is not None else Tensor([])

def Tile(x:Tensor, repeats:List[int]): return x.repeat(repeats)
def Range(start:Union[float, int], limit:Union[float, int], delta:Union[float, int]): return Tensor.arange(start=start, stop=limit, step=delta)
def Shape(data:Tensor, end:Optional[int]=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64)
def Size(data:Tensor): return prod(data if isinstance(data, list) else data.shape)
def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1)
def Reshape(data:Tensor, shape:List[int], allowzero:int=0):
  return data.reshape([x if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(shape)])
def Expand(x:Tensor, shape:List[int]): return x.expand(_broadcast_shape(x.shape, tuple(shape)))
def Shrink(x:Tensor, bias:float=0.0, lambd:float=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias)
def And(x:Tensor, y:Tensor): return (x==y).where(x, False)
def Or(x:Tensor, y:Tensor): return (x==y).where(x, True)
def Not(x:Tensor): return x.logical_not()

def Trilu(x:Tensor, k:int=0, upper:int=1): return x.triu(k) if upper else x.tril(k)

def Slice(data:Tensor, starts:List[int], ends:List[int], axes:Optional[List[int]]=None, steps:Optional[List[int]]=None):
  axes = axes or list(range(data.ndim))
  steps = steps or [1]*data.ndim
  slices = [slice(0,x,1) for x in data.shape]
  for i, axis in enumerate(axes): slices[axis] = slice(starts[i], ends[i], steps[i])
  return data[tuple(slices)]

# TODO: add test for when axes is None
def Squeeze(data:Tensor, axes:Optional[List[int]]=None):
  return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data)
def Unsqueeze(data:Tensor, axes:List[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data)

def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float()

def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0):
  if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64)
  return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64)
def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0): return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)

def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis)
def Transpose(x:Tensor, perm:Optional[List[int]]=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm)

def ConstantOfShape(shape:List[int], value:Optional[Tensor]=None):
  if value is None: value = Tensor(0, dtype=dtypes.float32)
  return Tensor.ones(*shape, dtype=value.dtype) * (value if shape != [0] else 1)

# **************** Complex Ops ****************

def Gemm(A:Tensor, B:Tensor, C:Optional[Tensor]=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0):
  ret = alpha * (A.transpose(transA) @ B.transpose(transB))
  if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1]))
  return ret

def Einsum(*Inputs:List[Tensor], equation:str): return Tensor.einsum(equation, Inputs)

def CumSum(X:Tensor, axis:int, exclusive:int=0, reverse:int=0):
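  # e.g. X=[1., 2., 3.]: plain cumsum gives [1., 3., 6.]; exclusive=1 shifts it to [0., 1., 3.]; reverse=1 gives [6., 5., 3.]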
  axis = X._resolve_dim(axis)
  if reverse: X = X.flip(axis)
  if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\
                     .shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim)))
  return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis)

# TODO: this is copied from tinygrad/nn/__init__.py
# spatial is from opset 7 and has since been removed
def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9,
                       training_mode:int=0, spatial=1, is_test=0):
  if training_mode:
    x_detached = X.detach()
    current_mean = x_detached.mean(axis=(0,2,3))
    y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
    current_var = (y*y).mean(axis=(0,2,3))
    current_invstd = current_var.add(epsilon).rsqrt()

    running_mean = input_mean * momentum + current_mean * (1 - momentum)
    running_var = input_var * momentum + current_var * (1 - momentum)

    return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var
  invstd = (input_var + epsilon).rsqrt()
  return X.batchnorm(scale, B, input_mean, invstd)

def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05):
  axis = tuple(range(2, x.ndim))
  mean = x.mean(axis=axis, keepdim=True)
  invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
  return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))

def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1):
  assert stash_type == 1, "only float32 is supported"
  axis = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
  mean = x.mean(axis=axis, keepdim=True)
  return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt()

def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05):
  return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)

# onnx: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
# numpy.pad: ((x1_begin, x1_end), (x2_begin, x2_end), ...)
def _format_padding(onnx_pads, ndims=None, axes=None):
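  # e.g. onnx_pads=[0, 1, 2, 3] with ndims=2 -> [(0, 2), (1, 3)]; axes (if given) selects which dims the pad pairs land on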
  if ndims and len(onnx_pads)//2 != ndims: onnx_pads = onnx_pads * ndims # for OnnxBackendPyTorchConvertedModelTest the len(onnx_pads) == 2
  if ndims is None: ndims = len(onnx_pads) // 2
  if axes is None: axes = list(range(ndims))
  num_axes = len(axes)
  np_pads = [(0,0)] * ndims
  for i in range(num_axes):
    np_pads[axes[i]] = (onnx_pads[i], onnx_pads[i + num_axes])
  return np_pads

def _padded(X: Tensor, pads=None, auto_pad="NOTSET", axes=None, constant_value=0., strides=None, kernel_shape=None, dilations=None, ceil_mode=0):
  if auto_pad != "NOTSET": pads = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
  elif ceil_mode:
    if strides is not None: strides = [strides]*len(kernel_shape) if isinstance(strides, int) else strides if strides else [1]*len(kernel_shape)
    if dilations is not None: dilations = [1]*len(kernel_shape) if dilations == 1 else dilations
    out_spatial_shape = [math.ceil((sh - dil * (ker-1)-1)/st + 1) if ceil_mode else math.floor((sh - dil * (ker-1)-1)/st + 1) for sh, st, ker, dil in zip(X.shape[-len(kernel_shape):], strides, kernel_shape, dilations)]
    pad_shape = [(osh-1)*st+((ks-1)*dil+1)-ish for osh, st, ks, dil, ish in zip(out_spatial_shape, strides, kernel_shape, dilations, X.shape[-len(kernel_shape):])]
    pad_shape = [[sh//2, sh-sh//2] for sh in pad_shape]
    # ceil_mode case follows NOTE in https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
    # so if any kernel starts in the right padded region, we decrease the right pads to omit that kernel. Only omitting 1 kernel now.
    pad_shape = [[start,end-rpad] if (rpad := ks + st%(st-(((start+xs)%st)))) <= end else [start,end]
                 for (start,end), ks, st, xs in zip(pad_shape, kernel_shape, strides, X.shape[-len(kernel_shape):])]
    pad_shape = flatten(pad_shape)
    pads = pad_shape[::2] + pad_shape[1::2]
  if pads is None: return X
  pads = _format_padding(pads, ndims=len(X.shape), axes=axes)
  return X.pad(tuple(pads), value=constant_value)

def _auto_pad(X: Tensor, auto_pad, strides, kernel_shape, dilations):
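  # e.g. one spatial dim with sh=5, st=2, ks=2, di=1: total pad is 1, so SAME_UPPER gives pads [0, 1] (extra at the end) and SAME_LOWER gives [1, 0]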
  strides = [strides]*len(kernel_shape) if isinstance(strides, int) else strides if strides else [1]*len(kernel_shape)
  dilations = [1]*len(kernel_shape) if dilations == 1 else dilations
  if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
    pad_shape = [(math.ceil(sh/st)-1)*st+((ks-1)*di+1)-sh for sh, st, ks, di in zip(X.shape[-len(kernel_shape):], strides, kernel_shape, dilations)]
    pad_shape = flatten([[sh//2, sh-sh//2] for sh in pad_shape])
    return pad_shape[::2] + pad_shape[1::2] if auto_pad == "SAME_UPPER" else pad_shape[1::2] + pad_shape[::2]
  raise NotImplementedError(f"auto_pad={auto_pad} not implemented")

# (x1_begin, x2_begin, ..., x1_end, x2_end, ...) -> (..., x2_start, x2_end, x1_start, x1_end)
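# e.g. pads=[1, 2, 3, 4] (x1 begin/end = 1/3, x2 begin/end = 2/4) -> [2, 4, 1, 3]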
def _onnx_pads_to_pad2d_pads(pads): return flatten(reversed(list((pB, pE) for pB, pE in zip(pads, pads[len(pads)//2:]))))

def Pad(x:Tensor, pads:List[int], constant_value:Optional[ConstType]=None, axes:Optional[List[int]]=None, mode:str="constant", value=0):
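  # pads only lists begin/end values for the given axes; real_pads expands them to [begin_0, ..., begin_{ndim-1}, end_0, ..., end_{ndim-1}] over all dims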
  value, axes = constant_value or value or 0, axes or list(range(x.ndim))
  real_pads = [0] * (x.ndim*2)
  for i,axis in enumerate(axes): real_pads[axis%x.ndim], real_pads[axis%x.ndim+x.ndim] = pads[i], pads[i+len(axes)]
  return x.pad(padding=_onnx_pads_to_pad2d_pads(real_pads), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value)

def AveragePool(X: Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, count_include_pad=0, dilations=1, pads=None, strides=1):
  pixel_axes = tuple(range(2, X.ndim))
  ret = _padded(X, pads, auto_pad, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations, ceil_mode=ceil_mode)
  ret = ret.avg_pool2d(kernel_shape, stride=strides, dilation=dilations)
  if count_include_pad: return ret
  div = _padded(Tensor.ones(X.shape), pads, auto_pad, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations, ceil_mode=ceil_mode).avg_pool2d(kernel_shape, stride=strides, dilation=dilations)
  return ret / div

def MaxPool(X: Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, dilations=1, pads=None, storage_order=0, strides=1):
  pixel_axes = tuple(range(2, X.ndim))
  ret = _padded(X, pads, auto_pad, constant_value=-math.inf, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations, ceil_mode=ceil_mode)
  ret = ret.max_pool2d(kernel_shape, stride=strides, dilation=dilations).cast(X.dtype)
  ret_len, X_len = ret.numel(), X.numel()
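  # recover the argmax indices by matching each pooled value against every flattened input position and summing the matching
  # arange entries; this assumes a pooled max value appears only once in X (duplicates would sum their positions)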
  indices = ((ret.flatten().unsqueeze(1).expand(ret_len, X_len) == X.flatten().unsqueeze(0).expand(ret_len, X_len)) * \
             Tensor.arange(X_len, dtype=dtypes.int64).unsqueeze(0).expand(ret_len, X_len)).sum(1).reshape(ret.shape)
  if storage_order: indices = indices.transpose(-2, -1)
  return ret, indices

def MaxUnpool(xT: Tensor, xI: Tensor, outshape: Optional[Tensor]=None, kernel_shape=None, pads=None, strides=None):
  out_sh = [(ks//2)*2 + st * inps for inps, st, ks in zip(xI.shape, strides, kernel_shape)]
  outlength = prod(out_sh)
  xI = xI.flatten().unsqueeze(1).expand(None, outlength)
  arange = Tensor.arange(outlength, requires_grad=False).reshape(1, outlength).expand(xI.shape)
  xT = xT.flatten().unsqueeze(1).expand(None, outlength)
  ret = ((xI == arange) * xT).sum(0).reshape([1, 1] + out_sh)
  if outshape is not None and outshape != ret.shape:
    diff = [outshape[2] - ret.shape[2], outshape[3] - ret.shape[3]]
    pad_args = [diff[0]//2, diff[1]//2, diff[0]-diff[0]//2, diff[1]-diff[1]//2]
    ret = ret.pad((pad_args[1], pad_args[3], pad_args[0], pad_args[2]))
  return ret

def Conv(X: Tensor, W: Tensor, B:Optional[Tensor]=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, strides=1):
  if auto_pad != "NOTSET":
    padding = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
  else:
    # reorder padding
    padding = [p for ps in zip(pads[:len(pads)//2][::-1], pads[len(pads)//2:][::-1]) for p in ps] if pads is not None else 0
  return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=padding)

def ConvTranspose(X: Tensor, W: Tensor, B:Optional[Tensor]=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, output_shape=None, output_padding=0, strides=1):
  if kernel_shape is None: kernel_shape = W.shape[2:]
  if isinstance(strides, int): strides = [strides]*(W.ndim-2)
  if isinstance(dilations, int): dilations = [dilations]*(W.ndim-2)
  if isinstance(output_padding, int): output_padding = [output_padding]*(W.ndim-2)
  out_sh = [st*(xs-1) + (ks-1)*di+1 if n < 2 else st*(xs-1) + (ks-1)*di+1 - pads[n-2] - pads[n-1] for n, (st, xs, ks, di) in enumerate(zip(strides, X.shape[2:], kernel_shape, dilations))] if output_shape is not None or auto_pad != "NOTSET" else []
  if pads is None:
    if output_shape is None: output_shape = [xs*st for xs, st in zip(X.shape[2:], strides)]
    if auto_pad == "NOTSET": pads = [0,0] * (X.ndim - 2)
    else:
      total_padding = [st*(ish-1) + pad + ((ks-1)*dil+1)-osh for st, ish, pad, ks, dil, osh in zip(strides, X.shape[2:], output_padding, kernel_shape, dilations, output_shape)]
      pad_shape = flatten([[sh//2, sh-sh//2] for sh in total_padding])
      pads = pad_shape[::2] + pad_shape[1::2] if auto_pad == "SAME_UPPER" else pad_shape[1::2] + pad_shape[::2]
  else:
    if output_shape is None: output_shape = [st*(xs-1) + (ks-1)*di+1 if n < 2 else st*(xs-1) + (ks-1)*di+1 - pads[n-2] - pads[n-1] for n, (st, xs, ks, di) in enumerate(zip(strides, X.shape[2:], kernel_shape, dilations))]
  if out_sh: output_padding = [os - rs for os, rs in zip(output_shape, out_sh)]
  return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads if pads is not None else 0, output_padding=output_padding)

def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"):
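  # DCR (default) splits the channel dim as (h1 w1 c), i.e. depth blocks outermost; CRD splits it as (c h1 w1); both move the h1/w1 factors into the spatial dims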
  return X.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)" if mode=="CRD" else "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=blocksize, w1=blocksize)
def SpaceToDepth(X:Tensor, blocksize:int):
  return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize)

# Reimplemented here because you need legacy RNG for passing ONNX tests.
def Dropout(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:Optional[int]=None):
  if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's.
  mask = Tensor(np.random.RandomState(seed).random(cast(Tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device)
  return data * mask * (1/(1.0 - ratio)), mask

def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0):
  pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1)
  return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta)

def MeanVarianceNormalization(x:Tensor, axis:List[int]=[0,2,3]):
  return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9)

def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Optional[Tensor]=None, ignore_index:Optional[int]=None, reduction:str="mean"):
  return x.nll_loss(target, weight, ignore_index, reduction)

def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Optional[Tensor]=None, ignore_index:Optional[int]=None, reduction:str="mean"):
  log_probs = scores.log_softmax(1)
  return log_probs.nll_loss(labels, weights, ignore_index, reduction), log_probs

def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices]

def Gather(x:Tensor, indices:Tensor, axis:int=0):
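  # e.g. x.shape=(3, 4), indices=[2, 0], axis=0: shrink out row 2 and row 0 separately, cat them along axis 0, then reshape to (2, 4)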
  if indices.numel() < 9: # NOTE fewer kernels for smaller indices, but kernel count increases with the size of indices
    x_sh = list(x.shape)
    ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:]
    if indices.ndim > 1: indices = indices.flatten()
    indices = [to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in to_python_const(indices)]
    args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices]
    return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
  # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot
  return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])]
def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated

def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0):
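  # e.g. batch_dims=0, x.shape=(2, 2), indices=[[0, 1], [1, 0]] -> Tensor([x[0, 1], x[1, 0]])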
  if batch_dims == 0: return x[tuple(i.squeeze(-1) for i in indices.split(1, -1))]
  x_shape, i_shape = x.shape, indices.shape
  b = math.prod(x.shape[dim] for dim in range(batch_dims))
  # NOTE: each batched dim of both input and indices are equal
  x = x.reshape(b, *x.shape[batch_dims:])
  indices = indices.reshape(b, *indices.shape[batch_dims:])
  b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1])
  ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))]
  return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:])
def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Optional[str]=None):
  assert updates.shape == indices.shape[:-1] + x.shape[indices.shape[-1]:]
  x = x.contiguous()
  for idx, u in zip(indices.split(1, 0), updates.split(1, 0)):
    idx = tuple(i.squeeze(-1) for i in idx.squeeze(0).split(1, -1))
    u = u.squeeze(0)
    if reduction is None: x[idx] = u
    elif reduction == "add": x[idx] += u
    elif reduction == "mul": x[idx] *= u
    else: raise NotImplementedError("reduction doesn't support max or min")
  return x

def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Optional[str]=None):
  indices = (indices < 0).where(x.shape[axis], 0) + indices
  return x.scatter(axis, indices, updates, reduction)
def GatherElements(x:Tensor, indices:Tensor, axis:int):
  indices = (indices < 0).where(x.shape[axis], 0) + indices
  return x.gather(axis, indices)

def Resize(X:Tensor, roi:Optional[List[float]]=None, scales:Optional[List[float]]=None, sizes:Optional[List[int]]=None, antialias:int=0,
           axes:Optional[List[int]]=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0,
           extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'):
  def _apply_nearest_mode(index: Tensor, input_dim, mode: str):
    if mode == "round_prefer_floor": index = (index - 0.5).ceil()
    elif mode == "round_prefer_ceil": index = (index + 0.5).floor()
    elif mode in ["floor", "ceil"]: index = getattr(index, mode)()
    else: raise ValueError(f"invalid {nearest_mode=}")
    return index.cast(dtypes.int32).clip(0, input_dim-1)
  def _apply_transformation(index: Tensor, input_dim, scale_dim, roi_dim, sizes_frac, mode):
    # TODO: needs more testing, not confident in this
    # NOTE: their reference implementation differs from the implementation in their reference docs
    # https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_resize.py
    # https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize
    output_dim = scale_dim * input_dim
    if mode == "half_pixel": index = (index + 0.5) / scale_dim - 0.5
    elif mode == "align_corners": index = index * (input_dim - 1) / (output_dim - 1) if output_dim != 1 else Tensor([0])
    elif mode == "asymmetric": index = index / scale_dim
    elif mode == "pytorch_half_pixel": index = (index + 0.5) / scale_dim - 0.5 if output_dim != 1 else Tensor([-0.5])
    elif mode == "half_pixel_symmetric": index = input_dim / 2 * (1 - int(output_dim) / sizes_frac) + (index + 0.5) / scale_dim - 0.5
    elif mode == "tf_crop_and_resize": index = roi_dim[0] * (input_dim - 1) + index * ((roi_dim[1] - roi_dim[0]) * (input_dim - 1) / (output_dim - 1)) # noqa: E501
    else: raise ValueError(f"invalid {coordinate_transformation_mode=}")
    return index.clip(0, input_dim-1)

  scales, sizes = (None if scales is None else scales[-2:]), (None if sizes is None else sizes[-2:])
  # we pre-permute the axes and permute back after resize
  axes, input_shape = (axes or list(range(X.ndim))), X.shape[2:]
  perm = [a for a in range(len(X.shape)) if a not in axes] + list(axes)
  X = X.permute(*perm)

  if sizes is not None:
    if keep_aspect_ratio_policy in ["not_larger", "not_smaller"]:
      scale_fxn = min if keep_aspect_ratio_policy == "not_larger" else max
      scales = scale_fxn([sizes[i] / input_shape[i] for i in range(X.ndim-2) if i+2 in axes])
      sizes = [int((scales * input_shape[i]) + 0.5) if i+2 in axes else input_shape[i] for i in range(X.ndim-2)]
    else: scales = [sizes[-2] / X.size(-2), sizes[-1] / X.size(-1)]
  else: sizes = [int(sc*sh) for sc, sh in zip(scales, input_shape)]
  scales = [scales] * 2 if not isinstance(scales, list) else scales
  roi = [[st, ed] for st, ed in zip(roi, roi[len(roi)//2:])] if isinstance(roi, list) else [None] * (X.ndim-2)

  # NOTE: this transformation makes it so that we can't just call Tensor.interpolate
  # in Tensor.interpolate, we use indexes without any transformation
  indexes = []
  for shape, size, scale, region in zip(input_shape, sizes, scales, roi):
    indexes.append(_apply_transformation(Tensor.arange(size), shape, scale, region, shape * scale, coordinate_transformation_mode))

  if mode == "nearest":
    indexes = [_apply_nearest_mode(index, shape, nearest_mode) for (index, shape) in zip(indexes, input_shape)]
    X = X[(..., *Tensor.meshgrid(*indexes))]
  if mode == "linear":
    expand = list(X.shape)
    for i in range(-len(sizes), 0):
      reshape, index = [1] * X.ndim, indexes[i]
      reshape[i] = expand[i] = sizes[i]
      low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
      X = X.gather(i, low).lerp(X.gather(i, high), perc)
  if mode == "cubic": raise NotImplementedError("cubic interpolation is not implemented")
  return X.permute(*[perm.index(i) for i in range(len(perm))]) if perm else X

def CenterCropPad(t:Tensor, shape:List[int], axes:Optional[List[int]]=None):
  shrink_arg = [None] * t.ndim
  pad_arg = [None] * t.ndim
  for s, x in zip(shape, axes or range(t.ndim)):
    tx = t.shape[x]
    if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2)
    elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2)
  return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))

def OneHot(indices:Tensor, depth:Union[int, float], values:Tensor, axis:int=-1):
  # Scalar or Rank 1 tensor containing exactly one element
  depth = int(depth)
  indices = (indices < 0).where(indices+depth, indices)
  return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0])

def Compress(inp:Tensor, condition:List[bool], axis:Optional[int]=None):
  if axis is None:
    inp = inp.flatten()
    axis = 0
  if axis < 0: axis += inp.ndim
  con = Tensor(np.arange(len(condition))[condition]) # no boolean indexing in Tensor
  return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))]

def EyeLike(x:Tensor, dtype:Optional[int]=None, k:int=0):
  ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype)
  return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.size(0)-k) for d in x.shape))

def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) # deprecated

def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Union[Tensor, int] = 0, axis:int=1, block_size:int=0):
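  # per-axis: x_scale/x_zero_point are reshaped to broadcast along axis; blocked (block_size > 0): they are repeat_interleave'd block_size times along axis instead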
  if axis < 0: axis += x.ndim
  if not isinstance(x_zero_point, Tensor): x_zero_point = Tensor(x_zero_point)
  if block_size: x_zer, x_sc = x_zero_point.repeat_interleave(block_size, axis), x_scale.repeat_interleave(block_size, axis)
  else:
    shape = (*[1]*axis, *x_scale.shape, *[1]*(x.ndim - axis - x_scale.ndim))
    x_sc, x_zer = x_scale.reshape(shape), x_zero_point.reshape(shape)
  return ((x.float() - x_zer) * x_sc).cast(x_scale.dtype)

# copied from https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_image_decoder.py
def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"):
  try: import PIL.Image
  except ImportError as e: raise ImportError("Pillow must be installed to use the reference implementation of the ImageDecoder operator") from e
  img = PIL.Image.open(io.BytesIO(encoded_stream))
  if pixel_format == "BGR": return Tensor(np.array(img))[:, :, ::-1]
  if pixel_format == "RGB": return Tensor(np.array(img))
  if pixel_format == "Grayscale": return Tensor(np.array(img.convert("L"))).unsqueeze(-1) # (H, W) to (H, W, 1)
  raise ValueError(f"pixel_format={pixel_format!r} is not supported.")

def AffineGrid(theta:Tensor, size:Tensor, align_corners:int=0):
  N, _, *spatial_dims = size
  def generate_grid(steps):
    return Tensor.linspace(-1, 1, steps, device=theta.device) if align_corners else Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device)
  grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims))
  base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1)
  base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1)
  return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1)

# **************** com.microsoft Ops ****************

def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Optional[Tensor]=None, bias:Optional[Tensor]=None, epsilon:float=1e-12):
  x = x + skip + bias
  return x.layernorm(eps=epsilon) * gamma + beta, None, None, x

def FastGelu(x:Tensor, bias:Optional[Tensor]=None):
  # this is tanh approximated
  return (x + bias).gelu() if bias is not None else x.gelu()

def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Optional[Tensor]=None, word_embedding:Tensor=None, position_embedding:Tensor=None,
                            segment_embedding:Optional[Tensor]=None, gamma=None, beta=None, mask:Optional[Tensor]=None,
                            position_ids:Optional[Tensor]=None, epsilon=1e-12, mask_index_type=0):
  # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization
  assert (segment_ids is None) is (segment_embedding is None)
  assert mask is None and not mask_index_type, "functionality not supported yet" # TODO
  input_shape = input_ids.shape
  seq_length = input_shape[1]
  compute_seg_emb = (segment_embedding is not None and segment_ids is not None)
  vocab_size, max_position_embeddings = word_embedding.shape[0], position_embedding.shape[0]
  type_vocab_size = (segment_embedding.shape[0] if compute_seg_emb else None)

  def embedding(x:Tensor, vocab_size, weight:Tensor) -> Tensor:
    return x.unsqueeze(-1).expand(*x.shape, vocab_size)._one_hot_along_dim(vocab_size) @ weight

  # bert embedding layer
  if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
  wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
  pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding)
  seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None

  embedding_sum = wrd_embedding_res + pos_embedding_res
  if seg_embedding_res is not None: embedding_sum = embedding_sum + seg_embedding_res
  out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
  return out, None, embedding_sum

def Attention(x:Tensor, weights, bias:Optional[Tensor]=None, mask_index:Optional[Tensor]=None, past:Optional[Tensor]=None,
              relative_position_bias:Optional[Tensor]=None, past_sequence_length:Optional[Tensor]=None, do_rotary:Optional[int]=None,
              mask_filter_value:Optional[float]=None, num_heads:Optional[int]=None, past_present_share_buffer:Optional[int]=None,
              qkv_hidden_sizes:Optional[List[int]]=None, scale:Optional[float]=None, unidirectional:Optional[int]=None):
  # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
  assert num_heads is not None # required
  assert (qkv_hidden_sizes is None and past is not None) or (qkv_hidden_sizes is not None)
  assert relative_position_bias==do_rotary==past_sequence_length==mask_filter_value==past_present_share_buffer==scale==None, \
    "functionality not supported yet" # TODO strange params
  hidden_size, v_hidden_size = qkv_hidden_sizes[1:] if qkv_hidden_sizes is not None else 2*(weights.shape[1] // 3,)

  if unidirectional: # gpt-style
    assert hidden_size == v_hidden_size
    xqkv = x.linear(weights, bias)
    xq, xk, xv = [xqkv.shrink([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)]
  else: # bert-style
    wq, wk, wv = weights[:,:hidden_size], weights[:,hidden_size:hidden_size+v_hidden_size], weights[:,hidden_size+v_hidden_size:]
    bq, bk, bv = (bias[:hidden_size], bias[hidden_size:hidden_size+v_hidden_size], bias[hidden_size+v_hidden_size:]) if bias is not None else None
    xq, xk, xv = [x.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))]
  xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2) for x in (xq, xk, xv)]

  if past is not None:
    xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2)
    present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0))

  def attn(query, key, value, attn_mask):
    query_length, key_length = query.shape[-2], key.shape[-2]
    cdim = max(query_length, key_length) + 1
    attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1])
    # This is where Tensor.scaled_dot_product_attention differs:
    causal_mask = Tensor.ones((cdim, cdim), requires_grad=False, dtype=dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length]
    masked = Tensor.where(causal_mask, attn_weights, -math.inf)
    if attn_mask is not None: masked = masked + attn_mask
    return masked.softmax(-1) @ value

  bsz, _, seq_len, _ = xq.shape
  out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
  return out, present if past is not None else out

# **************** ai.onnx.preview.training Ops ****************
# NOTE: onnx test coverage only covers `T==0` cases, so for all `T>0` this isn't tested
# NOTE: onnx training ops actually don't need the state for optim, all the ops work in a functional way, but we still can reuse optim.py code

from tinygrad.nn.optim import Adam as TinyAdam
from tinygrad.nn.optim import SGD

def onnx_training(input_group_size):
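  # ONNX passes the per-weight tensors grouped by kind, e.g. (X1, X2, G1, G2, H1, H2) for two weights with input_group_size=3;
  # inputs[i::groups] below regroups them per weight as (X1, G1, H1) and (X2, G2, H2), and the final zip/flatten
  # restores the grouped-by-kind ordering for the outputs.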
  def _decorator(func):
    def __wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs):
      old_training = Tensor.training
      Tensor.training = True
      R = R.detach()
      groups = len(inputs) // input_group_size
      ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))]
      Tensor.training = old_training
      return tuple(flatten(zip(*ret)))
    return __wrapper
  return _decorator

@onnx_training(3)
def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:float=0.0, norm_coefficient:float=0.0):
  X, G, H = (i.detach() for i in inputs)
  grad = norm_coefficient * X + G
  H.assign(H + grad.square())
  up = grad / (H.sqrt() + epsilon)
  r = R / (1 + T * decay_factor)
  X.assign(X.detach() - r * up)
  return [X, H]

@onnx_training(4)
def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, epsilon:float=0.0, norm_coefficient:float=0.0,
         norm_coefficient_post:float=0.0):
  X, G, V, H = inputs
  G, V, H = G.detach(), V.detach(), H.detach() # TODO we shouldn't need these detaches
  X.grad = norm_coefficient * X.detach() + G
  opt = TinyAdam([X], b1=alpha, b2=beta, eps=epsilon)
  opt.m, opt.v, opt.lr = [V], [H], R
  # need no-op for m_hat and v_hat if T == 0
  if T == 0: opt.b1_t, opt.b2_t = opt.b1_t.zeros_like(), opt.b2_t.zeros_like()
  else:
    # `T-1` since it's applied again at the start of `_step`
    opt.b1_t = Tensor([alpha**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False)
    opt.b2_t = Tensor([beta**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False)
  opt.step()
  X = (1 - norm_coefficient_post) * X
  return [X, V, H]

@onnx_training(3)
def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str, norm_coefficient:float):
  X, G, V = inputs
  G, V = G.detach(), V.detach()
  X.grad = (norm_coefficient * X.detach() + G) * (beta if T > 0 else 1)
  opt = SGD([X], momentum=alpha, nesterov=(mode=="nesterov"))
  opt.b, opt.lr = [V], R
  opt.step()
  return [X, V]