import functools, io, math
from typing import Union, Tuple, Optional, List, Any, cast
from tinygrad.tensor import Tensor, _broadcast_shape, ConstType
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.helpers import prod, flatten
from extra.onnx import dtype_parse, to_python_const
import numpy as np

# **************** Free Ops ****************

def Identity(x:Tensor): return x
# TODO: fix buffer_parse
def Add(x:Tensor, other:Tensor, broadcast=None, axis=None):
  return x + other if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + other).cast(x.dtype)
def Sub(x:Union[Tensor, Any], other:Tensor): return x - other # some test has input as int
def Less(x:Tensor, y:Tensor): return x < y
def LessOrEqual(x:Tensor, y:Tensor): return x <= y
def Greater(x:Tensor, y:Tensor): return x > y
def GreaterOrEqual(x:Tensor, y:Tensor): return x >= y
def Equal(x:Tensor, y:Tensor): return x == y
def BitwiseNot(x:Tensor): return ~x
def BitwiseOr(x:Tensor, y:Tensor): return x | y
def BitwiseAnd(x:Tensor, y:Tensor): return x & y
def BitwiseXor(x:Tensor, y:Tensor): return x ^ y
def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0)
def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0)
def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0)
def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0)
# NOTE: does not support saturate
def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to))
def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype)

# **************** Simple Ops ****************

# https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_div.py
def Div(x:Tensor, other:Tensor): return (x/other).cast(x.dtype)

def Constant(sparse_value:Optional[Tensor]=None, value:Optional[Tensor]=None, value_float:Optional[float]=None,
             value_floats:Optional[List[float]]=None, value_int:Optional[int]=None, value_ints:Optional[List[int]]=None,
             value_string:Optional[str]=None, value_strings:Optional[List[str]]=None):
  if value is not None: return value
  if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
  if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
  if value_int is not None: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False)
  if value_ints is not None: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False)
  if value_string is not None or value_strings is not None or sparse_value is not None:
    raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value')

def HardSigmoid(x:Tensor, alpha:float=0.2, beta:float=0.5): return (alpha*x + beta).clip(0, 1)
def Gelu(x:Tensor, approximate:Optional[str]=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf())

# TODO: fix this
def PRelu(X:Tensor, slope:Tensor):
  # HACK: OnnxBackendPyTorchConvertedModelTest has a weird slope that is [0.25, 0.25, 0.25] for any X.shape
  slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope
  return (X > 0).where(X, X * slope)

def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leakyrelu(alpha)
def ThresholdedRelu(X:Tensor, alpha:float=1.0): return (X > alpha).where(X, 0)
def Softmax_1(x:Tensor, axis:int=1): return x.softmax(axis)
def Softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis)
Softmax = {1: Softmax_1, 13: Softmax_13} # Softmax default axis changed in opset 13
def LogSoftmax(x:Tensor, axis:int=-1): return x.log_softmax(axis)
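# NOTE: `Softmax` above is an example of opset-versioned dispatch: the dict maps an opset version to the
# implementation introduced in that opset (the default axis changed from 1 to -1 in opset 13), and the ONNX
# runner is presumably expected to pick the entry matching the model's opset. Illustrative usage:
#   Softmax[13](Tensor([[1.0, 2.0, 3.0]]))  # softmax over the last axis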
def Clip(x:Tensor, min:Optional[Tensor]=None, max:Optional[Tensor]=None):
  return x.clip(float('-inf') if min is None else min, float('inf') if max is None else max).cast(x.dtype)

def _axes(axes, noop_with_empty_axes): return axes or ([] if noop_with_empty_axes else None)
def ReduceMax(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMin(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSum(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMean(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSumSquare(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return ReduceSum(data.square(), axes, keepdims, noop_with_empty_axes)
def ReduceProd(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return data.prod(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceL1(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return ReduceSum(data.abs(), axes, keepdims, noop_with_empty_axes)
def ReduceL2(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return ReduceSumSquare(data, axes, keepdims, noop_with_empty_axes).sqrt()
def ReduceLogSum(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log()
def ReduceLogSumExp(data:Tensor, axes:Optional[List[int]]=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log()

def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True)
def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True)
def OptionalHasElement(x:Optional[Tensor]=None): return Tensor(x is not None and x.numel() > 0)
def OptionalGetElement(x:Optional[Tensor]=None): return x if x is not None else Tensor([])

def Tile(x:Tensor, repeats:List[int]): return x.repeat(repeats)
def Range(start:Union[float, int], limit:Union[float, int], delta:Union[float, int]): return Tensor.arange(start=start, stop=limit, step=delta)
def Shape(data:Tensor, end:Optional[int]=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64)
def Size(data:Tensor): return prod(data if isinstance(data, list) else data.shape)
def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1)
def Reshape(data:Tensor, shape:List[int], allowzero:int=0):
  return data.reshape([x if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(shape)])
def Expand(x:Tensor, shape:List[int]): return x.expand(_broadcast_shape(x.shape, tuple(shape)))
def Shrink(x:Tensor, bias:float=0.0, lambd:float=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias)
def And(x:Tensor, y:Tensor): return (x==y).where(x, False)
def Or(x:Tensor, y:Tensor): return (x==y).where(x, True)
def Not(x:Tensor): return x.logical_not()
def Trilu(x:Tensor, k:int=0, upper:int=1): return x.triu(k) if upper else x.tril(k)
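# NOTE on the boolean ops above: for bool inputs, And(x, y) == (x==y).where(x, False) because the result can
# only be True when both inputs agree and x is True; Or uses the mirrored trick with a True fallback.
# NOTE on _axes: ONNX distinguishes "no axes given" (reduce over everything, so None is passed through) from
# noop_with_empty_axes=1 with no axes, where the empty list is intended to leave the data unreduced.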
def Slice(data:Tensor, starts:List[int], ends:List[int], axes:Optional[List[int]]=None, steps:Optional[List[int]]=None):
  axes = axes or list(range(data.ndim))
  steps = steps or [1]*data.ndim
  slices = [slice(0,x,1) for x in data.shape]
  for i, axis in enumerate(axes): slices[axis] = slice(starts[i], ends[i], steps[i])
  return data[tuple(slices)]

# TODO: add test for when axes is None
def Squeeze(data:Tensor, axes:Optional[List[int]]=None):
  return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data)
def Unsqueeze(data:Tensor, axes:List[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data)
def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float()

def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0):
  if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64)
  return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64)
def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0):
  return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)

def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis)
def Transpose(x:Tensor, perm:Optional[List[int]]=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm)

def ConstantOfShape(shape:List[int], value:Optional[Tensor]=None):
  if value is None: value = Tensor(0, dtype=dtypes.float32)
  return Tensor.ones(*shape, dtype=value.dtype) * (value if shape != [0] else 1)

# **************** Complex Ops ****************

def Gemm(A:Tensor, B:Tensor, C:Optional[Tensor]=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0):
  ret = alpha * (A.transpose(transA) @ B.transpose(transB))
  if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1]))
  return ret

def Einsum(*Inputs:List[Tensor], equation:str): return Tensor.einsum(equation, Inputs)

def CumSum(X:Tensor, axis:int, exclusive:int=0, reverse:int=0):
  axis = X._resolve_dim(axis)
  if reverse: X = X.flip(axis)
  if exclusive:
    X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\
         .shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim)))
  return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis)

# TODO: this is copied from tinygrad/nn/__init__.py
# spatial is from opset 7 and has since been removed
def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05,
                       momentum:float=0.9, training_mode:int=0, spatial=1, is_test=0):
  if training_mode:
    x_detached = X.detach()
    current_mean = x_detached.mean(axis=(0,2,3))
    y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
    current_var = (y*y).mean(axis=(0,2,3))
    current_invstd = current_var.add(epsilon).rsqrt()
    running_mean = input_mean * momentum + current_mean * (1 - momentum)
    running_var = input_var * momentum + current_var * (1 - momentum)
    return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var
  invstd = (input_var + epsilon).rsqrt()
  return X.batchnorm(scale, B, input_mean, invstd)

def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05):
  axis = tuple(range(2, x.ndim))
  mean = x.mean(axis=axis, keepdim=True)
  invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
  return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))
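# NOTE: InstanceNormalization above is the usual (x - mean) / sqrt(var + eps) * scale + bias with statistics
# taken per sample and per channel (all spatial axes); BatchNormalization in training mode additionally returns
# the updated running statistics, running = input * momentum + current * (1 - momentum).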
def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1):
  assert stash_type == 1, "only float32 is supported"
  axis = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
  mean = x.mean(axis=axis, keepdim=True)
  return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt()

def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05):
  return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)

# onnx: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
# numpy.pad: ((x1_begin, x1_end), (x2_begin, x2_end), ...)
def _format_padding(onnx_pads, ndims=None, axes=None):
  if ndims and len(onnx_pads)//2 != ndims: onnx_pads = onnx_pads * ndims # for OnnxBackendPyTorchConvertedModelTest the len(onnx_pads) == 2
  if ndims is None: ndims = len(onnx_pads) // 2
  if axes is None: axes = list(range(ndims))
  num_axes = len(axes)
  np_pads = [(0,0)] * ndims
  for i in range(num_axes):
    np_pads[axes[i]] = (onnx_pads[i], onnx_pads[i + num_axes])
  return np_pads

def _padded(X:Tensor, pads=None, auto_pad="NOTSET", axes=None, constant_value=0., strides=None, kernel_shape=None, dilations=None, ceil_mode=0):
  if auto_pad != "NOTSET": pads = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
  elif ceil_mode:
    if strides is not None: strides = [strides]*len(kernel_shape) if isinstance(strides, int) else strides if strides else [1]*len(kernel_shape)
    if dilations is not None: dilations = [1]*len(kernel_shape) if dilations == 1 else dilations
    out_spatial_shape = [math.ceil((sh - dil * (ker-1)-1)/st + 1) if ceil_mode else math.floor((sh - dil * (ker-1)-1)/st + 1)
                         for sh, st, ker, dil in zip(X.shape[-len(kernel_shape):], strides, kernel_shape, dilations)]
    pad_shape = [(osh-1)*st+((ks-1)*dil+1)-ish
                 for osh, st, ks, dil, ish in zip(out_spatial_shape, strides, kernel_shape, dilations, X.shape[-len(kernel_shape):])]
    pad_shape = [[sh//2, sh-sh//2] for sh in pad_shape]
    # ceil_mode case follows NOTE in https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
    # so if any kernels start in the right padded region, we decrease right pads to omit that kernel. Only omitting 1 kernel for now.
    pad_shape = [[start,end-rpad] if (rpad := ks + st%(st-(((start+xs)%st)))) <= end else [start,end]
                 for (start,end), ks, st, xs in zip(pad_shape, kernel_shape, strides, X.shape[-len(kernel_shape):])]
    pad_shape = flatten(pad_shape)
    pads = pad_shape[::2] + pad_shape[1::2]
  if pads is None: return X
  pads = _format_padding(pads, ndims=len(X.shape), axes=axes)
  return X.pad(tuple(pads), value=constant_value)

def _auto_pad(X:Tensor, auto_pad, strides, kernel_shape, dilations):
  strides = [strides]*len(kernel_shape) if isinstance(strides, int) else strides if strides else [1]*len(kernel_shape)
  dilations = [1]*len(kernel_shape) if dilations == 1 else dilations
  if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
    pad_shape = [(math.ceil(sh/st)-1)*st+((ks-1)*di+1)-sh for sh, st, ks, di in zip(X.shape[-len(kernel_shape):], strides, kernel_shape, dilations)]
    pad_shape = flatten([[sh//2, sh-sh//2] for sh in pad_shape])
    return pad_shape[::2] + pad_shape[1::2] if auto_pad == "SAME_UPPER" else pad_shape[1::2] + pad_shape[::2]
  raise NotImplementedError(f"auto_pad={auto_pad} not implemented")
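# Illustrative example of the layout conversion done by _format_padding (values made up): for two axes,
#   _format_padding([1, 2, 3, 4]) -> [(1, 3), (2, 4)]
# i.e. ONNX's [begin_axis0, begin_axis1, end_axis0, end_axis1] becomes (begin, end) pairs per axis;
# passing `axes` leaves the unlisted axes at (0, 0).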
# (x1_begin, x2_begin, ..., x1_end, x2_end, ...) -> (..., x2_begin, x2_end, x1_begin, x1_end)
def _onnx_pads_to_pad2d_pads(pads): return flatten(reversed(list((pB, pE) for pB, pE in zip(pads, pads[len(pads)//2:]))))

def Pad(x:Tensor, pads:List[int], constant_value:Optional[ConstType]=None, axes:Optional[List[int]]=None, mode:str="constant", value=0):
  value, axes = constant_value or value or 0, axes or list(range(x.ndim))
  real_pads = [0] * (x.ndim*2)
  for i,axis in enumerate(axes): real_pads[axis%x.ndim], real_pads[axis%x.ndim+x.ndim] = pads[i], pads[i+len(axes)]
  return x.pad(padding=_onnx_pads_to_pad2d_pads(real_pads), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value)

def AveragePool(X:Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, count_include_pad=0, dilations=1, pads=None, strides=1):
  pixel_axes = tuple(range(2, X.ndim))
  ret = _padded(X, pads, auto_pad, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations, ceil_mode=ceil_mode)
  ret = ret.avg_pool2d(kernel_shape, stride=strides, dilation=dilations)
  if count_include_pad: return ret
  div = _padded(Tensor.ones(X.shape), pads, auto_pad, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations,
                ceil_mode=ceil_mode).avg_pool2d(kernel_shape, stride=strides, dilation=dilations)
  return ret / div

def MaxPool(X:Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, dilations=1, pads=None, storage_order=0, strides=1):
  pixel_axes = tuple(range(2, X.ndim))
  ret = _padded(X, pads, auto_pad, constant_value=-math.inf, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations, ceil_mode=ceil_mode)
  ret = ret.max_pool2d(kernel_shape, stride=strides, dilation=dilations).cast(X.dtype)
  ret_len, X_len = ret.numel(), X.numel()
  indices = ((ret.flatten().unsqueeze(1).expand(ret_len, X_len) == X.flatten().unsqueeze(0).expand(ret_len, X_len)) * \
    Tensor.arange(X_len, dtype=dtypes.int64).unsqueeze(0).expand(ret_len, X_len)).sum(1).reshape(ret.shape)
  if storage_order: indices = indices.transpose(-2, -1)
  return ret, indices

def MaxUnpool(xT:Tensor, xI:Tensor, outshape:Optional[Tensor]=None, kernel_shape=None, pads=None, strides=None):
  out_sh = [(ks//2)*2 + st * inps for inps, st, ks in zip(xI.shape, strides, kernel_shape)]
  outlength = prod(out_sh)
  xI = xI.flatten().unsqueeze(1).expand(None, outlength)
  arange = Tensor.arange(outlength, requires_grad=False).reshape(1, outlength).expand(xI.shape)
  xT = xT.flatten().unsqueeze(1).expand(None, outlength)
  ret = ((xI == arange) * xT).sum(0).reshape([1, 1] + out_sh)
  if outshape is not None and outshape != ret.shape:
    diff = [outshape[2] - ret.shape[2], outshape[3] - ret.shape[3]]
    pad_args = [diff[0]//2, diff[1]//2, diff[0]-diff[0]//2, diff[1]-diff[1]//2]
    ret = ret.pad((pad_args[1], pad_args[3], pad_args[0], pad_args[2]))
  return ret

def Conv(X:Tensor, W:Tensor, B:Optional[Tensor]=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, strides=1):
  if auto_pad != "NOTSET":
    padding = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
  else:
    # reorder ONNX pads into tinygrad's per-axis (begin, end) layout, last spatial axis first
    padding = [p for ps in zip(pads[:len(pads)//2][::-1], pads[len(pads)//2:][::-1]) for p in ps] if pads is not None else 0
  return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=padding)
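# Illustrative example of the pad reordering used by Conv and _onnx_pads_to_pad2d_pads: for a 2D input with
# ONNX pads [top, left, bottom, right], the padding handed to tinygrad is [left, right, top, bottom]
# (begin/end pairs per axis, starting from the last spatial axis).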
def ConvTranspose(X:Tensor, W:Tensor, B:Optional[Tensor]=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None,
                  output_shape=None, output_padding=0, strides=1):
  if kernel_shape is None: kernel_shape = W.shape[2:]
  if isinstance(strides, int): strides = [strides]*(W.ndim-2)
  if isinstance(dilations, int): dilations = [dilations]*(W.ndim-2)
  if isinstance(output_padding, int): output_padding = [output_padding]*(W.ndim-2)
  out_sh = [st*(xs-1) + (ks-1)*di+1 if n < 2 else st*(xs-1) + (ks-1)*di+1 - pads[n-2] - pads[n-1]
            for n, (st, xs, ks, di) in enumerate(zip(strides, X.shape[2:], kernel_shape, dilations))] if output_shape is not None or auto_pad != "NOTSET" else []
  if pads is None:
    if output_shape is None: output_shape = [xs*st for xs, st in zip(X.shape[2:], strides)]
    if auto_pad == "NOTSET": pads = [0,0] * (X.ndim - 2)
    else:
      total_padding = [st*(ish-1) + pad + ((ks-1)*dil+1)-osh
                       for st, ish, pad, ks, dil, osh in zip(strides, X.shape[2:], output_padding, kernel_shape, dilations, output_shape)]
      pad_shape = flatten([[sh//2, sh-sh//2] for sh in total_padding])
      pads = pad_shape[::2] + pad_shape[1::2] if auto_pad == "SAME_UPPER" else pad_shape[1::2] + pad_shape[::2]
  else:
    if output_shape is None:
      output_shape = [st*(xs-1) + (ks-1)*di+1 if n < 2 else st*(xs-1) + (ks-1)*di+1 - pads[n-2] - pads[n-1]
                      for n, (st, xs, ks, di) in enumerate(zip(strides, X.shape[2:], kernel_shape, dilations))]
  if out_sh: output_padding = [os - rs for os, rs in zip(output_shape, out_sh)]
  return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads if pads is not None else 0, output_padding=output_padding)

def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"):
  return X.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)" if mode=="CRD" else "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=blocksize, w1=blocksize)
def SpaceToDepth(X:Tensor, blocksize:int): return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize)
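# NOTE: the two DepthToSpace modes only differ in how the channel dimension is unpacked: DCR treats it as
# (blocksize, blocksize, C'), CRD as (C', blocksize, blocksize); SpaceToDepth above is the inverse of the DCR layout.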
# Reimplemented here because you need legacy RNG for passing ONNX tests.
def Dropout(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:Optional[int]=None):
  if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool) # if the mask is requested as an output it will contain all Trues
  mask = Tensor(np.random.RandomState(seed).random(cast(Tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device)
  return data * mask * (1/(1.0 - ratio)), mask

def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0):
  pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1)
  return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta)

def MeanVarianceNormalization(x:Tensor, axis:List[int]=[0,2,3]):
  return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9)

def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Optional[Tensor]=None, ignore_index:Optional[int]=None, reduction:str="mean"):
  return x.nll_loss(target, weight, ignore_index, reduction)
def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Optional[Tensor]=None, ignore_index:Optional[int]=None, reduction:str="mean"):
  log_probs = scores.log_softmax(1)
  return log_probs.nll_loss(labels, weights, ignore_index, reduction), log_probs

def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices]

def Gather(x:Tensor, indices:Tensor, axis:int=0):
  if indices.numel() < 9: # NOTE: fewer kernels for small indices, but the kernel count grows with the size of indices
    x_sh = list(x.shape)
    ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:]
    if indices.ndim > 1: indices = indices.flatten()
    indices = [to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in to_python_const(indices)]
    args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices]
    return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
  # NOTE: faster gather with a fixed number of kernels, but exceeds the kernel limit for openpilot
  return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])]

def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated

def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0):
  if batch_dims == 0: return x[tuple(i.squeeze(-1) for i in indices.split(1, -1))]
  x_shape, i_shape = x.shape, indices.shape
  b = math.prod(x.shape[dim] for dim in range(batch_dims)) # NOTE: each batched dim of both input and indices are equal
  x = x.reshape(b, *x.shape[batch_dims:])
  indices = indices.reshape(b, *indices.shape[batch_dims:])
  b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1])
  ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))]
  return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:])

def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Optional[str]=None):
  assert updates.shape == indices.shape[:-1] + x.shape[indices.shape[-1]:]
  x = x.contiguous()
  for idx, u in zip(indices.split(1, 0), updates.split(1, 0)):
    idx = tuple(i.squeeze(-1) for i in idx.squeeze(0).split(1, -1))
    u = u.squeeze(0)
    if reduction is None: x[idx] = u
    elif reduction == "add": x[idx] += u
    elif reduction == "mul": x[idx] *= u
    else: raise NotImplementedError("reduction doesn't support max or min")
  return x

def ScatterElements(x:Tensor, indices:Tensor, updates:Tensor, axis=0, reduction:Optional[str]=None):
  indices = (indices < 0).where(x.shape[axis], 0) + indices
  return x.scatter(axis, indices, updates, reduction)

def GatherElements(x:Tensor, indices:Tensor, axis:int):
  indices = (indices < 0).where(x.shape[axis], 0) + indices
  return x.gather(axis, indices)
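# NOTE: the `(indices < 0).where(x.shape[axis], 0) + indices` pattern above maps negative indices to their
# positive equivalents, e.g. along an axis of size 4, index -1 becomes 3 while non-negative indices pass through.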
def Resize(X:Tensor, roi:Optional[List[float]]=None, scales:Optional[List[float]]=None, sizes:Optional[List[int]]=None, antialias:int=0,
           axes:Optional[List[int]]=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0,
           extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'):
  def _apply_nearest_mode(index:Tensor, input_dim, mode:str):
    if mode == "round_prefer_floor": index = (index - 0.5).ceil()
    elif mode == "round_prefer_ceil": index = (index + 0.5).floor()
    elif mode in ["floor", "ceil"]: index = getattr(index, mode)()
    else: raise ValueError(f"invalid {nearest_mode=}")
    return index.cast(dtypes.int32).clip(0, input_dim-1)
  def _apply_transformation(index:Tensor, input_dim, scale_dim, roi_dim, sizes_frac, mode):
    # TODO: needs more testing, not confident in this
    # NOTE: their reference implementation differs from the implementation in their reference docs
    # https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_resize.py
    # https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize
    output_dim = scale_dim * input_dim
    if mode == "half_pixel": index = (index + 0.5) / scale_dim - 0.5
    elif mode == "align_corners": index = index * (input_dim - 1) / (output_dim - 1) if output_dim != 1 else Tensor([0])
    elif mode == "asymmetric": index = index / scale_dim
    elif mode == "pytorch_half_pixel": index = (index + 0.5) / scale_dim - 0.5 if output_dim != 1 else Tensor([-0.5])
    elif mode == "half_pixel_symmetric": index = input_dim / 2 * (1 - int(output_dim) / sizes_frac) + (index + 0.5) / scale_dim - 0.5
    elif mode == "tf_crop_and_resize": index = roi_dim[0] * (input_dim - 1) + index * ((roi_dim[1] - roi_dim[0]) * (input_dim - 1) / (output_dim - 1))  # noqa: E501
    else: raise ValueError(f"invalid {coordinate_transformation_mode=}")
    return index.clip(0, input_dim-1)

  scales, sizes = (None if scales is None else scales[-2:]), (None if sizes is None else sizes[-2:])
  # we pre-permute the axes and permute back after resize
  axes, input_shape = (axes or list(range(X.ndim))), X.shape[2:]
  perm = [a for a in range(len(X.shape)) if a not in axes] + list(axes)
  X = X.permute(*perm)

  if sizes is not None:
    if keep_aspect_ratio_policy in ["not_larger", "not_smaller"]:
      scale_fxn = min if keep_aspect_ratio_policy == "not_larger" else max
      scales = scale_fxn([sizes[i] / input_shape[i] for i in range(X.ndim-2) if i+2 in axes])
      sizes = [int((scales * input_shape[i]) + 0.5) if i+2 in axes else input_shape[i] for i in range(X.ndim-2)]
    else:
      scales = [sizes[-2] / X.size(-2), sizes[-1] / X.size(-1)]
  else:
    sizes = [int(sc*sh) for sc, sh in zip(scales, input_shape)]
  scales = [scales] * 2 if not isinstance(scales, list) else scales
  roi = [[st, ed] for st, ed in zip(roi, roi[len(roi)//2:])] if isinstance(roi, list) else [None] * (X.ndim-2)

  # NOTE: this transformation makes it so that we can't just call Tensor.interpolate
  # in Tensor.interpolate, we use indexes without any transformation
  indexes = []
  for shape, size, scale, region in zip(input_shape, sizes, scales, roi):
    indexes.append(_apply_transformation(Tensor.arange(size), shape, scale, region, shape * scale, coordinate_transformation_mode))

  if mode == "nearest":
    indexes = [_apply_nearest_mode(index, shape, nearest_mode) for (index, shape) in zip(indexes, input_shape)]
    X = X[(..., *Tensor.meshgrid(*indexes))]
  if mode == "linear":
    expand = list(X.shape)
    for i in range(-len(sizes), 0):
      reshape, index = [1] * X.ndim, indexes[i]
      reshape[i] = expand[i] = sizes[i]
      low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
      X = X.gather(i, low).lerp(X.gather(i, high), perc)
  if mode == "cubic": raise NotImplementedError("cubic interpolation is not implemented")
  return X.permute(*[perm.index(i) for i in range(len(perm))]) if perm else X
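# Illustrative example of the "half_pixel" coordinate transform used above: when upscaling a length-4 axis by 2x,
# output index 1 maps to input coordinate (1 + 0.5) / 2 - 0.5 = 0.25, which linear mode then lerps between
# input pixels 0 and 1 (floor/ceil) with weight 0.25.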
def CenterCropPad(t:Tensor, shape:List[int], axes:Optional[List[int]]=None):
  shrink_arg = [None] * t.ndim
  pad_arg = [None] * t.ndim
  for s, x in zip(shape, axes or range(t.ndim)):
    tx = t.shape[x]
    if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2)
    elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2)
  return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))

def OneHot(indices:Tensor, depth:Union[int, float], values:Tensor, axis:int=-1):
  # depth is a scalar or a rank-1 tensor containing exactly one element
  depth = int(depth)
  indices = (indices < 0).where(indices+depth, indices)
  return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0])

def Compress(inp:Tensor, condition:List[bool], axis:Optional[int]=None):
  if axis is None:
    inp = inp.flatten()
    axis = 0
  if axis < 0: axis += inp.ndim
  con = Tensor(np.arange(len(condition))[condition]) # no boolean indexing in Tensor
  return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))]

def EyeLike(x:Tensor, dtype:Optional[int]=None, k:int=0):
  ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype)
  return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.size(0)-k) for d in x.shape))

def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) # deprecated

def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Union[Tensor, int]=0, axis:int=1, block_size:int=0):
  if axis < 0: axis += x.ndim
  if not isinstance(x_zero_point, Tensor): x_zero_point = Tensor(x_zero_point)
  if block_size: x_zer, x_sc = x_zero_point.repeat_interleave(block_size, axis), x_scale.repeat_interleave(block_size, axis)
  else:
    shape = (*[1]*axis, *x_scale.shape, *[1]*(x.ndim - axis - x_scale.ndim))
    x_sc, x_zer = x_scale.reshape(shape), x_zero_point.reshape(shape)
  return ((x.float() - x_zer) * x_sc).cast(x_scale.dtype)

# copied from https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_image_decoder.py
def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"):
  try: import PIL.Image
  except ImportError as e: raise ImportError("Pillow must be installed to use the reference implementation of the ImageDecoder operator") from e
  img = PIL.Image.open(io.BytesIO(encoded_stream))
  if pixel_format == "BGR": return Tensor(np.array(img))[:, :, ::-1]
  if pixel_format == "RGB": return Tensor(np.array(img))
  if pixel_format == "Grayscale": return Tensor(np.array(img.convert("L"))).unsqueeze(-1) # (H, W) to (H, W, 1)
  raise ValueError(f"pixel_format={pixel_format!r} is not supported.")

def AffineGrid(theta:Tensor, size:Tensor, align_corners:int=0):
  N, _, *spatial_dims = size
  def generate_grid(steps):
    return Tensor.linspace(-1, 1, steps, device=theta.device) if align_corners else Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device)
  grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims))
  base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1)
  base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1)
  return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1)
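# Illustrative example of the dequantization formula above, y = (x - zero_point) * scale:
#   uint8 inputs [0, 128, 255] with scale=0.1 and zero_point=128 dequantize to [-12.8, 0.0, 12.7].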
# **************** com.microsoft Ops ****************

def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Optional[Tensor]=None, bias:Optional[Tensor]=None, epsilon:float=1e-12):
  x = x + skip + bias
  return x.layernorm(eps=epsilon) * gamma + beta, None, None, x

def FastGelu(x:Tensor, bias:Optional[Tensor]=None):
  # this is the tanh approximation
  return (x + bias).gelu() if bias is not None else x.gelu()

def EmbedLayerNormalization(input_ids:Tensor, segment_ids:Optional[Tensor]=None, word_embedding:Tensor=None, position_embedding:Tensor=None,
                            segment_embedding:Optional[Tensor]=None, gamma=None, beta=None, mask:Optional[Tensor]=None,
                            position_ids:Optional[Tensor]=None, epsilon=1e-12, mask_index_type=0):
  # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization
  assert (segment_ids is None) is (segment_embedding is None)
  assert mask is None and not mask_index_type, "functionality not supported yet" # TODO
  input_shape = input_ids.shape
  seq_length = input_shape[1]
  compute_seg_emb = (segment_embedding is not None and segment_ids is not None)
  vocab_size, max_position_embeddings = word_embedding.shape[0], position_embedding.shape[0]
  type_vocab_size = (segment_embedding.shape[0] if compute_seg_emb else None)

  def embedding(x:Tensor, vocab_size, weight:Tensor) -> Tensor:
    return x.unsqueeze(-1).expand(*x.shape, vocab_size)._one_hot_along_dim(vocab_size) @ weight

  # bert embedding layer
  if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
  wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
  pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding)
  seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None

  embedding_sum = wrd_embedding_res + pos_embedding_res
  if seg_embedding_res is not None: embedding_sum = embedding_sum + seg_embedding_res
  out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
  return out, None, embedding_sum
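# NOTE: embedding() above implements the table lookup as a one-hot matrix multiply: each id is expanded into a
# one-hot row of length vocab_size and multiplied against the embedding weight, so no gather kernel is needed.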
def Attention(x:Tensor, weights, bias:Optional[Tensor]=None, mask_index:Optional[Tensor]=None, past:Optional[Tensor]=None,
              relative_position_bias:Optional[Tensor]=None, past_sequence_length:Optional[Tensor]=None, do_rotary:Optional[int]=None,
              mask_filter_value:Optional[float]=None, num_heads:Optional[int]=None, past_present_share_buffer:Optional[int]=None,
              qkv_hidden_sizes:Optional[List[int]]=None, scale:Optional[float]=None, unidirectional:Optional[int]=None):
  # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
  assert num_heads is not None # required
  assert (qkv_hidden_sizes is None and past is not None) or (qkv_hidden_sizes is not None)
  assert relative_position_bias==do_rotary==past_sequence_length==mask_filter_value==past_present_share_buffer==scale==None, \
    "functionality not supported yet" # TODO strange params
  hidden_size, v_hidden_size = qkv_hidden_sizes[1:] if qkv_hidden_sizes is not None else 2*(weights.shape[1] // 3,)

  if unidirectional: # gpt-style
    assert hidden_size == v_hidden_size
    xqkv = x.linear(weights, bias)
    xq, xk, xv = [xqkv.shrink([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)]
  else: # bert-style
    wq, wk, wv = weights[:,:hidden_size], weights[:,hidden_size:hidden_size+v_hidden_size], weights[:,hidden_size+v_hidden_size:]
    bq, bk, bv = (bias[:hidden_size], bias[hidden_size:hidden_size+v_hidden_size], bias[hidden_size+v_hidden_size:]) if bias is not None else (None, None, None)
    xq, xk, xv = [x.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))]
  xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2) for x in (xq, xk, xv)]

  if past is not None:
    xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2)
    present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0))

  def attn(query, key, value, attn_mask):
    query_length, key_length = query.shape[-2], key.shape[-2]
    cdim = max(query_length, key_length) + 1
    attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1])
    # This is where Tensor.scaled_dot_product_attention differs:
    causal_mask = Tensor.ones((cdim, cdim), requires_grad=False, dtype=dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length]
    masked = Tensor.where(causal_mask, attn_weights, -math.inf)
    if attn_mask is not None: masked = masked + attn_mask
    return masked.softmax(-1) @ value

  bsz, _, seq_len, _ = xq.shape
  out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
  return out, present if past is not None else out

# **************** ai.onnx.preview.training Ops ****************

# NOTE: onnx test coverage only covers `T==0` cases, so for all `T>0` this isn't tested
# NOTE: onnx training ops don't actually need the optimizer state; all the ops work in a functional way, but we can still reuse optim.py code
from tinygrad.nn.optim import Adam as TinyAdam
from tinygrad.nn.optim import SGD

def onnx_training(input_group_size):
  def _decorator(func):
    def __wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs):
      old_training = Tensor.training
      Tensor.training = True
      R = R.detach()
      groups = len(inputs) // input_group_size
      ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))]
      Tensor.training = old_training
      return tuple(flatten(zip(*ret)))
    return __wrapper
  return _decorator

@onnx_training(3)
def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:float=0.0, norm_coefficient:float=0.0):
  X, G, H = (i.detach() for i in inputs)
  grad = norm_coefficient * X + G
  H.assign(H + grad.square())
  up = grad / (H.sqrt() + epsilon)
  r = R / (1 + T * decay_factor)
  X.assign(X.detach() - r * up)
  return [X, H]

@onnx_training(4)
def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, epsilon:float=0.0, norm_coefficient:float=0.0, norm_coefficient_post:float=0.0):
  X, G, V, H = inputs
  G, V, H = G.detach(), V.detach(), H.detach() # TODO: we shouldn't need these detaches
  X.grad = norm_coefficient * X.detach() + G
  opt = TinyAdam([X], b1=alpha, b2=beta, eps=epsilon)
  opt.m, opt.v, opt.lr = [V], [H], R
  # need a no-op for m_hat and v_hat if T == 0
  if T == 0: opt.b1_t, opt.b2_t = opt.b1_t.zeros_like(), opt.b2_t.zeros_like()
  else:
    # `T-1` since it's applied again at the start of `_step`
    opt.b1_t = Tensor([alpha**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False)
    opt.b2_t = Tensor([beta**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False)
  opt.step()
  X = (1 - norm_coefficient_post) * X
  return [X, V, H]

@onnx_training(3)
def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str, norm_coefficient:float):
  X, G, V = inputs
  G, V = G.detach(), V.detach()
  X.grad = (norm_coefficient * X.detach() + G) * (beta if T > 0 else 1)
  opt = SGD([X], momentum=alpha, nesterov=(mode=="nesterov"))
  opt.b, opt.lr = [V], R
  opt.step()
  return [X, V]
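# Illustrative note on the onnx_training decorator: ONNX training ops take their inputs grouped per kind,
# e.g. Adagrad over two tensors receives (X1, X2, G1, G2, H1, H2); `inputs[i::groups]` regroups this into
# per-tensor calls func(R, T, X1, G1, H1) and func(R, T, X2, G2, H2), and the flattened zip of the results
# restores the (X1', X2', H1', H2') output ordering that ONNX expects.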