You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
829 lines
50 KiB
829 lines
50 KiB
from typing import Any, Sequence, cast, Literal, Callable
|
|
import dataclasses, functools, io, math, types
|
|
from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr
|
|
from tinygrad.helpers import getenv, DEBUG, all_same, prod, flatten, make_tuple, argsort
|
|
from tinygrad.dtype import DType, ConstType, dtypes, ImageDType
|
|
from tinygrad.device import is_dtype_supported
|
|
|
|
# ***** protobuf parsing ******
|
|
from onnx import AttributeProto, ModelProto, TensorProto, TypeProto, helper
|
|
import numpy as np
|
|
|
|
def dtype_parse(onnx_dtype: int) -> DType:
  """Map an ONNX TensorProto dtype enum to a tinygrad DType.

  FLOAT16 is intentionally mapped to float32 (no true half here), and a dtype
  the current device can't handle falls back to the default float. Raises
  NotImplementedError for dtypes tinygrad has no equivalent for.
  """
  _unsupported = {
    TensorProto.UNDEFINED, TensorProto.STRING, TensorProto.COMPLEX64, TensorProto.COMPLEX128, TensorProto.FLOAT8E4M3FN, TensorProto.FLOAT8E4M3FNUZ,
    TensorProto.FLOAT8E5M2, TensorProto.FLOAT8E5M2FNUZ, TensorProto.UINT4, TensorProto.INT4,
  }
  if onnx_dtype in _unsupported: raise NotImplementedError(f"onnx dtype {TensorProto.DataType.Name(onnx_dtype)} is not supported")
  _supported: dict[int, DType] = {
    TensorProto.FLOAT: dtypes.float32, TensorProto.UINT8: dtypes.uint8, TensorProto.INT8: dtypes.int8,
    TensorProto.UINT16: dtypes.uint16, TensorProto.INT16: dtypes.int16, TensorProto.INT32: dtypes.int32, TensorProto.INT64: dtypes.int64,
    TensorProto.BOOL: dtypes.bool, TensorProto.FLOAT16: dtypes.float32, TensorProto.DOUBLE: dtypes.double, TensorProto.UINT32: dtypes.uint32,
    TensorProto.UINT64: dtypes.uint64, TensorProto.BFLOAT16: dtypes.bfloat16,
  }
  dt = _supported[onnx_dtype]
  return dt if is_dtype_supported(dt) else dtypes.float
|
|
|
|
def attribute_parse(onnx_attribute: AttributeProto):
  """Convert one ONNX node attribute into a python value.

  Numbers and strings become python scalars/tuples; TENSOR attributes become
  tinygrad Tensors. Graph/sparse/type attributes raise NotImplementedError.
  """
  _parsers: dict[AttributeProto.AttributeType, Callable[[AttributeProto], Any]] = {
    AttributeProto.FLOAT: lambda a: float(a.f), AttributeProto.INT: lambda a: int(a.i),
    AttributeProto.STRING: lambda a: a.s.decode("utf-8"), AttributeProto.TENSOR: lambda a: buffer_parse(a.t),
    AttributeProto.FLOATS: lambda a: tuple(float(x) for x in a.floats), AttributeProto.INTS: lambda a: tuple(int(x) for x in a.ints),
    AttributeProto.STRINGS: lambda a: tuple(x.decode("utf-8") for x in a.strings),
  }
  if onnx_attribute.type not in _parsers:
    raise NotImplementedError(f"attribute with type {AttributeProto.AttributeType.Name(onnx_attribute.type)} is not supported")
  return _parsers[onnx_attribute.type](onnx_attribute)
|
|
|
|
def buffer_parse(onnx_tensor: TensorProto) -> Tensor:
  # Convert a TensorProto (initializer/attribute tensor) into a Tensor.
  # Tries the typed repeated data fields first, then falls back to raw_data.
  if onnx_tensor.string_data: raise NotImplementedError("Parsing for buffer with string data is not implemented.")
  dtype, shape = dtype_parse(onnx_tensor.data_type), tuple(onnx_tensor.dims)
  # whichever typed repeated field is populated (depends on data_type) wins
  if data := list(onnx_tensor.float_data) or list(onnx_tensor.int32_data) or list(onnx_tensor.int64_data) or list(onnx_tensor.double_data) or \
             list(onnx_tensor.uint64_data):
    if len(data) == 1: return Tensor(data[0], dtype=dtype).reshape(shape)  # scalar fast path, no realize needed
    return Tensor(data, dtype=dtype).reshape(shape).realize()
  if onnx_tensor.HasField("raw_data"):
    # decode raw bytes through numpy; .copy() so the Tensor owns its memory
    np_buffer = np.frombuffer(onnx_tensor.raw_data, dtype=helper.tensor_dtype_to_np_dtype(onnx_tensor.data_type)).copy().reshape(shape)
    if np_buffer.size == 1: return Tensor(np_buffer.item(), dtype=dtype).reshape(shape)
    return Tensor(np_buffer, dtype=dtype)
  return Tensor(None)  # no data present at all
|
|
|
|
def type_parse(onnx_type: TypeProto):
  """Parse a TypeProto into an OnnxValue, unwrapping optional and sequence
  wrappers down to the underlying tensor type."""
  t = onnx_type
  if t.HasField("map_type") or t.HasField("sparse_tensor_type") or t.HasField("opaque_type"):
    raise NotImplementedError("parsing for map_type, sparse_tensor_type and opaque_type are not implemented")
  is_optional = t.HasField("optional_type")
  if is_optional: t = t.optional_type.elem_type
  is_sequence = t.HasField("sequence_type")
  if is_sequence: t = t.sequence_type.elem_type
  if not t.HasField("tensor_type"):
    raise RuntimeError(f"TypeProto was not parsed properly: {onnx_type=}")
  # symbolic dims keep their name (str), fixed dims are ints
  dims = tuple(d.dim_param or d.dim_value for d in t.tensor_type.shape.dim)
  return OnnxValue(dims, dtype_parse(t.tensor_type.elem_type), is_optional, is_sequence)
|
|
|
|
# ***** onnx spec *****
@dataclasses.dataclass(frozen=True)
class OnnxValue:
  # Parsed TypeProto spec for one graph input/output.
  shape: tuple[str|int, ...]  # ints are fixed dims; strs are symbolic dim names bound at runtime
  dtype: DType
  is_optional: bool  # wrapped in optional_type
  is_sequence: bool  # wrapped in sequence_type
|
|
|
|
@dataclasses.dataclass(frozen=True)
class OnnxNode:
  # One node of the ONNX graph in parsed form.
  num: int  # position in the graph's node list
  op: str  # op_type, e.g. "Conv"
  inputs: tuple[str, ...]  # input value names ("" = unconnected)
  outputs: tuple[str, ...]  # output value names
  opts: dict[str, Any]  # attributes parsed by attribute_parse
|
|
|
|
# ***** python const *****
# Ops whose inputs at these positional indices must be materialized into plain
# python values (shapes, axes, pad amounts, ...) before dispatch instead of
# being passed through as Tensors.
required_input_python_consts: dict[str, tuple[int, ...]] = {
  "Tile": (1,), "Range": (0,1,2), "Expand": (1,), "Reshape": (1,), "Squeeze": (1,), "Unsqueeze": (1,), "Trilu": (1,), "ConstantOfShape": (0,),
  "CumSum": (1,), "TopK": (1,), "Pad": (1,2,3), "MaxUnpool": (2,), "Dropout": (1,2), "CenterCropPad": (1,), "OneHot": (1,), "Compress": (1,),
  "ImageDecoder": (0,), "AffineGrid": (1,), "Resize": (1,2,3), "Upsample": (1,), "Split": (1,), "Slice": (1,2,3,4),
  **{"Reduce"+r: (1,) for r in ("Max", "Min", "Sum", "Mean", "SumSquare", "Prod", "L1", "L2", "LogSum", "LogSumExp")},
  **{optim: (1,) for optim in ("Adam", "Adagrad", "Momentum")}
}
|
|
|
|
# running count of lru_cache misses already reported; used by to_python_const for debug logging
cache_misses = 0
@functools.lru_cache(None)
def _cached_to_python_const(t:Tensor):
  # Materialize a Tensor into a python value; cached per Tensor object
  # (lru_cache keys on the Tensor itself).
  if t.dtype is dtypes.uint8: return t.data().tobytes()  # uint8 buffers come back as raw bytes
  if 0 in t.shape: return []  # empty tensor -> empty list
  return t.tolist()
|
|
|
|
# Tensor -> python value cache for parameters
def to_python_const(t:Any, op:str, idx:int) -> list[ConstType]|ConstType|bytes:
  # Materialize input `idx` of `op` into a python value when that op requires it
  # (see required_input_python_consts); anything else passes through untouched.
  if idx not in required_input_python_consts.get(op, ()) or not isinstance(t, Tensor): return t
  global cache_misses
  ret = _cached_to_python_const(t)
  if (info := _cached_to_python_const.cache_info()).misses > cache_misses and DEBUG >= 3:
    print(f"Cache miss for {t}")
    cache_misses = info.misses  # NOTE(review): only updated when DEBUG >= 3 — indentation inferred from the scrape, confirm
  return ret
|
|
|
|
# ***** runner ******
# DEBUGONNX: per-node logging verbosity; ONNXLIMIT: stop after node number N (-1 = run all nodes)
debug = int(getenv("DEBUGONNX", "0"))
limit = int(getenv("ONNXLIMIT", "-1"))
|
|
class OnnxRunner:
  """Executes an ONNX ModelProto node-by-node with tinygrad Tensors."""
  def __init__(self, model: ModelProto):
    # parse model protobuf
    # training-domain nodes flip the Tensor.training/no_grad globals for the run;
    # the previous values are saved and restored at the end of __call__
    self.is_training = any(n.domain in {"ai.onnx.training", "ai.onnx.preview.training"} for n in model.graph.node)
    self.old_training, self.old_no_grad = Tensor.training, Tensor.no_grad
    Tensor.training = True if self.is_training else False
    Tensor.no_grad = False if self.is_training else True
    # "" maps to None so unconnected optional inputs resolve cleanly
    self.graph_values = {"": None, **{x.name:buffer_parse(x) for x in model.graph.initializer}}
    # only inputs not already covered by initializers must come from the caller
    self.graph_inputs = {x.name:type_parse(x.type) for x in model.graph.input if x.name not in self.graph_values}
    self.graph_outputs = tuple(x.name for x in model.graph.output)
    self.graph_nodes = tuple(OnnxNode(num, n.op_type, tuple(n.input), tuple(n.output), {x.name:attribute_parse(x) for x in n.attribute})
                             for num,n in enumerate(model.graph.node))
    self.opset_version = model.opset_import[0].version
    # symbolic dim name -> concrete size, bound lazily from the first input that uses it
    self.variable_dims: dict[str, int] = {}

    # op name -> implementation (or {opset_version: implementation}); defined elsewhere in this file
    self.onnx_ops = onnx_ops

  def _parse_input(self, name: str, value: Any, spec: OnnxValue):
    # Validate and convert one user-provided input against its OnnxValue spec.
    if spec.is_optional and value is None: return None
    # TODO: need true float16 for dtype checking
    if spec.is_sequence:
      if not isinstance(value, Sequence): raise RuntimeError(f"{name} received {value}, expected a sequence type")
      sequence = [Tensor(v, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(v, Tensor) else v for v in value]
      if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"Shapes for {name} sequence must be homogeneous")
      return sequence
    tensor = Tensor(value, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(value, Tensor) else value
    for dim, (onnx_dim, user_dim_input) in enumerate(zip(spec.shape, tensor.shape, strict=True)):
      if isinstance(onnx_dim, str):
        # symbolic dim: bind on first sight, then require consistency across inputs
        onnx_dim = self.variable_dims[onnx_dim] if onnx_dim in self.variable_dims else self.variable_dims.setdefault(onnx_dim, int(user_dim_input))
      if user_dim_input != onnx_dim: raise RuntimeError(f"{name} has mismatch on {dim=}. Expected {onnx_dim}, received {user_dim_input}.")
    return tensor

  def _dispatch_op(self, op, inps, opts):
    # Look up and call the implementation for `op`. Dict-valued entries are
    # versioned by opset: keep overwriting while k <= opset_version, i.e. pick
    # the newest implementation the model's opset allows.
    if op in self.onnx_ops:
      fxn = self.onnx_ops[op]
      if isinstance(fxn, dict):
        for k in sorted(fxn.keys()):
          if k <= self.opset_version: real_fxn = fxn[k]
      else: real_fxn = fxn
      return real_fxn(*inps, **opts)
    raise NotImplementedError(f"{op=} not supported")

  def __call__(self, inputs:dict[str, Any], debug=debug):
    # Run the graph: `inputs` maps graph input names to values.
    # Returns {output name: value}; restores Tensor.training/no_grad on exit.
    for name, input_spec in self.graph_inputs.items():
      if name not in inputs: raise RuntimeError(f"Please provide input data for {name}")
      self.graph_values[name] = self._parse_input(name, inputs[name], input_spec)

    for node in self.graph_nodes:
      # materialize the inputs this op needs as python consts (shapes, axes, ...)
      inps = [to_python_const(self.graph_values[name], node.op, i) for i,name in enumerate(node.inputs)]
      opts = node.opts

      # provide additional opts
      if node.op == "Split" and 'num_outputs' not in opts: opts['num_outputs'] = len(node.outputs)
      if node.op == "Gradient": opts['intermediate_tensors'] = self.graph_values

      if debug >= 1: print(f"{node.num}: op '{node.op}' opt {opts}")
      if debug >= 2 and node.inputs: print("\tinputs:\n" + "\n".join(f"\t\t{x} - {i!r}" for x,i in zip(node.inputs, inps)))
      ret = self._dispatch_op(node.op, inps, opts)
      ret = ret if isinstance(ret, tuple) else (ret,)
      if debug >= 2: print("\toutputs:\n" + "\n".join(f"\t\t{x} - {o!r}" for x,o in zip(node.outputs, ret)))

      # ops may return more values than the node consumes; extras are dropped
      self.graph_values.update(dict(zip(node.outputs, ret[:len(node.outputs)], strict=True)))

      # ONNXLIMIT early exit: stop after this node and return its outputs
      if node.num == limit:
        Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad
        return {name:self.graph_values[name] for name in node.outputs}
    Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad
    return {name:self.graph_values[name] for name in self.graph_outputs}
|
|
|
|
####################
|
|
##### ONNX OPS #####
|
|
####################
|
|
def get_onnx_ops():
|
|
# ***** helper functions *****
|
|
def _axes(axes, noop_with_empty_axes): return axes or ([] if noop_with_empty_axes else None)
|
|
|
|
# (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...)
def _onnx_pads_to_tiny_pads(pads):
  """Convert ONNX begin/end-per-axis pad layout into tinygrad's flat
  (lo, hi)-pair layout with the last axis first."""
  half = len(pads) // 2
  pairs = list(zip(pads[:half], pads[half:]))
  return tuple(v for pair in reversed(pairs) for v in pair)
|
|
|
|
AUTO_PAD_OPTIONS = Literal["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"]
|
|
# (padding_height, padding_width) -> (padding_top, padding_left, padding_bottom, padding_right)
|
|
def _auto_pad(pads, auto_pad: AUTO_PAD_OPTIONS):
|
|
if auto_pad == "SAME_UPPER": return [pads[i]//2 for i in range(len(pads))] + [pads[i]-pads[i]//2 for i in range(len(pads))]
|
|
return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))]
|
|
|
|
def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS):
  # Compute tinygrad-format padding for pool/conv ops from ONNX pads (p_),
  # kernel (k_), dilations (d_), strides (s_) and the auto_pad policy.
  if auto_pad == "VALID": return [0]*(len(k_)*2)  # no padding at all
  # normalize scalars/short lists to per-axis tuples; i_ = spatial input shape
  i_, (s_,d_,p_) = x.shape[-len(k_):], (make_tuple(x, len(k_)*2) for x in (s_, d_, p_))
  if auto_pad == "NOTSET": return _onnx_pads_to_tiny_pads(p_ if len(p_)==len(k_)*2 else p_*2)
  # SAME_UPPER/SAME_LOWER: output is ceil(input/stride); solve for total pad per axis
  o_ = [((i - (1 if auto_pad in ("SAME_UPPER", "SAME_LOWER") else k)) // s + 1) for i,k,s in zip(i_, k_, s_)]
  return _onnx_pads_to_tiny_pads(_auto_pad([(o-1)*s+k-i for o,i,k,s in zip(o_, i_, k_, s_)], auto_pad))
|
|
|
|
def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtypes.min(dtype), dtypes.max(dtype)).cast(dtype)  # saturate into dtype's range before casting
|
|
|
|
def _prepare_quantize(x:Tensor, scale:Tensor, zero_point:Tensor|int, axis=1, block_size=0):
  # Broadcast quantization params (scale/zero_point) so they line up with x.
  if axis < 0: axis += x.ndim
  # https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_quantize_linear.py#L31
  def reshape(val:Tensor):
    if val.numel() == 1: return val  # per-tensor param broadcasts as-is
    if block_size == 0: return val.reshape([val.shape[0] if dim == axis else 1 for dim in range(x.ndim)])  # per-axis param
    return val.repeat_interleave(block_size, axis)  # blocked quantization: expand each block's param along axis
  return (reshape(scale), reshape(zero_point) if isinstance(zero_point, Tensor) else zero_point)
|
|
|
|
def _op_integer(op, inputs:list[Tensor], zero_points:list[Tensor], **opts):
  """Run `op` in integer domain: subtract each input's zero point (in int32)
  first, then call the op on the centered inputs."""
  centered = [inp.int() - zp for inp, zp in zip(inputs, zero_points)]
  return op(*centered, **opts)
|
|
|
|
def _qlinearop_quantized(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts):
  # op execution is done in quantized int
  out = _op_integer(op, inputs, zero_points, **opts)
  assert dtypes.is_int(out.dtype), "quantized op should've done math in int"
  # rescale from the product of input scales to the output scale, then re-bias to the output zero point
  out_quantized = (out * prod(scales) / out_scale).round() + out_zero_point
  return _clamp_cast(out_quantized, out_zero_point.dtype)
|
|
|
|
def _qlinearop_float(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts):
  # op execution is done in float32
  # dequantize each input, run the op in float, then requantize to the output zero point's dtype
  dequantized = [(inp.int() - zp) * scale for inp, zp, scale in zip(inputs, zero_points, scales)]
  result = op(*dequantized, **opts)
  assert dtypes.is_float(result.dtype), "op should've done math in float"
  requantized = (result / out_scale).round() + out_zero_point
  return _clamp_cast(requantized, out_zero_point.dtype)
|
|
|
|
def _onnx_training(input_group_size):
  # Decorator for ONNX training ops (Adam/Adagrad/Momentum). The wrapped op
  # receives groups*input_group_size tensors laid out role-major (all params,
  # then all grads, ...); `func` is applied per group and results are returned
  # re-flattened in the same role-major order.
  def __decorator(func):
    def ___wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs):
      R = R.detach()  # R is the learning-rate tensor; no gradient through it
      groups = len(inputs) // input_group_size
      # inputs[i::groups] picks element i of each role (param_i, grad_i, ...)
      ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))]
      return tuple(flatten(zip(*ret)))  # transpose back to role-major, then flatten
    return ___wrapper
  return __decorator
|
|
|
|
# ***** Property/Graph Ops *****
def Identity(x:Tensor): return x  # pass-through
|
|
def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:float|None=None, value_floats:list[float]|None=None,
             value_int:int|None=None, value_ints:list[int]|None=None, value_string:str|None=None, value_strings:list[str]|None=None):
  """ONNX Constant: return whichever single attribute was provided as a Tensor.

  Raises NotImplementedError for the unimplemented attribute kinds
  (value_string, value_strings, sparse_value); returns None if nothing is set.
  """
  if value is not None: return value
  if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
  if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
  if value_int is not None: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False)
  if value_ints is not None: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False)
  # BUGFIX: the old condition `A or B and C` parsed as `A or (B and C)` by
  # precedence, so providing only value_strings or only sparse_value silently
  # returned None instead of raising
  if value_string is not None or value_strings is not None or sparse_value is not None:
    raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value')
|
|
|
|
def Range(start:float|int, limit:float|int, delta:float|int): return Tensor.arange(start=start, stop=limit, step=delta)  # [start, limit) stepping by delta
|
|
|
|
def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"):
  """Decode an encoded image (JPEG/PNG/...) into a uint8 HWC Tensor via Pillow.

  Raises ImportError when Pillow is missing and ValueError for an unknown
  pixel_format.
  """
  try: import PIL.Image
  except ImportError as e: raise ImportError("Pillow must be installed for the ImageDecoder operator") from e
  img = PIL.Image.open(io.BytesIO(encoded_stream))
  # BUGFIX: PIL's img.size is (width, height) while tobytes() is row-major, so
  # the buffer must be reshaped (height, width, channels) — the old
  # reshape(*img.size, 3) transposed non-square images
  width, height = img.size
  if pixel_format == "BGR": return Tensor(img.tobytes(), dtype=dtypes.uint8).reshape(height, width, 3).flip(-1)
  if pixel_format == "RGB": return Tensor(img.tobytes(), dtype=dtypes.uint8).reshape(height, width, 3)
  if pixel_format == "Grayscale": return Tensor(img.convert("L").tobytes(), dtype=dtypes.uint8).reshape(height, width, 1)
  raise ValueError(f"pixel_format={pixel_format!r} is not supported.")
|
|
|
|
def EyeLike(x:Tensor, dtype:int|None=None, k:int=0):
  # Identity-like matrix matching x's 2D shape; k is the diagonal offset.
  ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype)
  # non-square: pad the square eye out to x's shape, shifting by k along the longer axis
  return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape))
|
|
|
|
def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0)  # bool tensor: present and non-empty
def OptionalGetElement(x:Tensor|None=None): return x if x is not None else Tensor([])
def ConstantOfShape(shape:list[int], value:Tensor|None=None):
  # Fill `shape` with the scalar `value` tensor; defaults to float32 zero.
  if value is None: value = Tensor(0, dtype=dtypes.float32)
  if shape == [0]: return Tensor([], dtype=value.dtype)  # special-case: empty result
  return value.expand(shape)
|
|
|
|
def Size(data:Tensor): return data.numel()  # total element count
def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64)  # shape slice as int64 tensor
|
|
|
|
# ***** Unary Ops (math) *****
def Not(x:Tensor): return x.logical_not()
def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None): return x if min is None and max is None else x.clip(min, max)  # no-op when both bounds absent
|
|
|
|
# ***** Unary Ops (activation) *****
def Softmax_1(x:Tensor, axis:int=1): return x.softmax(axis)
def Softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis)
Softmax = {1:Softmax_1, 13:Softmax_13}  # opset-versioned: default axis changed from 1 to -1 in opset 13
def HardSigmoid(x:Tensor, alpha:float=0.2, beta:float=0.5): return (alpha*x + beta).clip(0, 1)
# exact erf formulation unless the tanh approximation is explicitly requested
def Gelu(x:Tensor, approximate:str|None=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf())
def BiasGelu(x: Tensor, bias: Tensor, approximate: str | None = None) -> Tensor: return Gelu(x + bias, approximate)
def FastGelu(x:Tensor, bias:Tensor|None=None):
  # this is tanh approximated
  return (x + bias).gelu() if bias is not None else x.gelu()
# TODO: fix this
def PRelu(X:Tensor, slope:Tensor):
  slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope  # HACK: only partial broadcast handling of slope
  return (X > 0).where(X, X * slope)
def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leaky_relu(alpha)
def ThresholdedRelu(X:Tensor, alpha:float=1.0): return (X > alpha).where(X, 0)
def LogSoftmax(x: Tensor, axis:int=-1): return x.log_softmax(axis)
def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float()
|
|
|
|
# ***** Unary Ops (broadcasted) *****
def Add(x:Tensor,y:Tensor, broadcast=None, axis=None): return x + y if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + y).cast(x.dtype)  # keep x's dtype for non-float results
def Sub(x:Tensor|int,y:Tensor): return x - y # some test has input as int
def Div(x:Tensor,y:Tensor): return x.div(y, rounding_mode='trunc' if dtypes.is_int(x.dtype) else None)  # integer division truncates
def Less(x:Tensor,y:Tensor): return x < y
def LessOrEqual(x:Tensor,y:Tensor): return x <= y
def Greater(x:Tensor,y:Tensor): return x > y
def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y
def Equal(x:Tensor,y:Tensor): return x == y
# logical and/or via where: where x==y both operands agree (result is x);
# otherwise the result is fixed (False for and, True for or)
def And(x:Tensor,y:Tensor): return (x==y).where(x, False)
def Or(x:Tensor,y:Tensor): return (x==y).where(x, True)
def Xor(x:Tensor,y:Tensor): return x.bool().bitwise_xor(y.bool())
def BitwiseAnd(x:Tensor,y:Tensor): return x & y
def BitwiseOr(x:Tensor,y:Tensor): return x | y
def BitwiseXor(x:Tensor,y:Tensor): return x ^ y
def BitwiseNot(x:Tensor): return ~x
def Mod(x:Tensor,y:Tensor,fmod=0):
  # fmod=1 follows C fmod (remainder keeps the dividend's sign); otherwise python-style %
  if fmod: return x - x.div(y, rounding_mode="trunc") * y
  return x % y
|
|
|
|
# ***** Casting Ops *****
# TODO: saturate
def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to))  # `saturate` attribute is currently ignored
def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype)
|
|
|
|
# ***** Reduce Ops *****
# variadic elementwise reductions over multiple inputs
def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0)
def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0)
def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0)
def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0)
# Reduce* ops: `axes` semantics are normalized by _axes (empty axes either
# no-op or reduce over everything, depending on noop_with_empty_axes)
def ReduceMax(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMin(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMean(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSumSquare(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return ReduceSum(data.square(), axes, keepdims, noop_with_empty_axes)
def ReduceProd(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return data.prod(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceL1(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return ReduceSum(data.abs(), axes, keepdims, noop_with_empty_axes)
def ReduceL2(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return ReduceSumSquare(data, axes, keepdims, noop_with_empty_axes).sqrt()
def ReduceLogSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log()
def ReduceLogSumExp(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log()
def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0):
  # select_last_index: flip along axis, take argmax, then mirror the index back
  if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64)
  return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64)
def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0):
  # argmin == argmax of the negated input
  return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)
|
|
|
|
# ***** Movement Ops *****
def Reshape(data:Tensor, shape:list[int], allowzero:int=0):
  # a 0 in shape copies the input dim at that position unless allowzero is set
  return data.reshape([x if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(shape)])
def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1)  # collapse dims before/after `axis` into a 2D tensor
def Expand(x:Tensor, shape:list[int]): return x.expand(_broadcast_shape(x.shape, tuple(shape)))
def Shrink(x:Tensor, bias:float=0.0, lambd:float=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias)  # zero inside [-lambd, lambd]
def Transpose(x:Tensor, perm:list[int]|None=None): return x.permute(order=perm or list(range(x.ndim)[::-1]))  # default: reverse all dims

# TODO: add test for when axes is None
def Squeeze(data:Tensor, axes:list[int]|None=None):
  # squeeze the highest axes first so earlier removals don't shift later indices
  return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data)
def Unsqueeze(data:Tensor, axes:list[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data)
|
|
|
|
def Tile(x:Tensor, repeats:list[int]): return x.repeat(repeats)
def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis)
def Slice(data:Tensor, starts:list[int], ends:list[int], axes:list[int]|None=None, steps:list[int]|None=None):
  # defaults: slice every axis with step 1; axes not listed keep their full range
  axes = axes or list(range(data.ndim))
  steps = steps or [1]*data.ndim
  slices = [slice(0,x,1) for x in data.shape]
  for i, axis in enumerate(axes): slices[axis] = slice(starts[i], ends[i], steps[i])
  return data[tuple(slices)]
|
|
|
|
def Split(data:Tensor, split:list[int]|None=None, num_outputs:int=0, axis:int=0):
  # When explicit split sizes aren't given, divide as evenly as possible —
  # the first (sz % num_outputs) chunks get one extra element.
  sz = data.shape[axis]
  if split is None: split = [sz // num_outputs + (1 if i < sz % num_outputs else 0) for i in range(num_outputs)]
  return data.split(split, axis)
|
|
|
|
def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[int]|None=None,
        mode:Literal["constant", "reflect", "edge", "wrap"]="constant", value=0):
  # `pads` lists begin amounts for each listed axis, then end amounts.
  value = constant_value or value  # NOTE(review): a constant_value of 0 falls through to `value` — works because both default to 0
  axes = axes or list(range(x.ndim))
  real_pads = [0] * (x.ndim*2)
  # scatter the per-axis begin/end amounts into full ONNX pad layout
  for i,axis in enumerate(axes): real_pads[axis%x.ndim], real_pads[axis%x.ndim+x.ndim] = pads[i], pads[i+len(axes)]
  # translate ONNX mode names to tinygrad's ("edge"->"replicate", "wrap"->"circular")
  return x.pad(padding=_onnx_pads_to_tiny_pads(real_pads), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value)
|
|
|
|
def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None):
  # Per axis: center-crop when the target is smaller, center-pad when larger.
  shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim
  pad_arg:list[None|tuple[int,int]] = [None] * t.ndim
  for s, x in zip(shape, axes or range(t.ndim)):
    tx = t.shape[x]
    if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2)
    elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2)  # odd difference: extra element goes to the end
  return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))
|
|
|
|
# ***** Processing Ops *****
def AveragePool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, count_include_pad:int=0,
                dilations:list[int]|int=1, pads:list[int]|int=0, strides:list[int]|int=1):
  # padding resolved via the shared ONNX auto_pad logic
  return X.avg_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad),
                      ceil_mode=ceil_mode, count_include_pad=count_include_pad)
|
|
|
|
def MaxPool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, dilations:list[int]|int=1, pads:list[int]|int=0,
            storage_order:int=0, strides:list[int]|int=1):
  # Returns both pooled values and indices (int64 per ONNX);
  # storage_order=1 transposes the last two index dims (column-major order).
  pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad)
  ret, idx = X.max_pool2d(kernel_shape, strides, dilations, pads, ceil_mode=ceil_mode, return_indices=True)
  return ret, idx.transpose(-2, -1).cast(dtypes.int64) if storage_order else idx.cast(dtypes.int64)
|
|
|
|
def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
         kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
  # kernel_shape defaults to the weight tensor's spatial dims
  return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations,
                  padding=_resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad))
|
|
|
|
def ConvTranspose(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
                  kernel_shape:list[int]|None=None, pads:list[int]|None=None, output_shape:list[int]|None=None, output_padding:list[int]|int=0,
                  strides:list[int]|int=1):
  # Per ONNX: output_shape (if given) determines padding; otherwise explicit
  # pads are used; otherwise padding is derived from the auto_pad policy.
  input_shape, kernel_shape = X.shape[2:], (kernel_shape or W.shape[2:])
  strides, dilations, output_padding = (make_tuple(x, len(input_shape)) for x in (strides, dilations, output_padding))
  if output_shape is not None: # we pad according to output_shape
    pads = _auto_pad([s*(i-1) + op + ((k-1)*d+1) - os for s,i,op,k,d,os in
                      zip(strides, input_shape, output_padding, kernel_shape, dilations, output_shape)], auto_pad)
  if pads is None: # we generate pads
    output_shape = output_shape or [X.shape[i+2] * strides[i] for i in range(len(strides))]
    pads = [strides[i]*(input_shape[i]-1) + output_padding[i] + ((kernel_shape[i]-1)*dilations[i]+1)-output_shape[i] for i in range(len(input_shape))]
    pads = _auto_pad(pads, auto_pad) if auto_pad != "NOTSET" else [0] * len(input_shape) * 2
  pads = _onnx_pads_to_tiny_pads(pads)
  return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads, output_padding=output_padding)
|
|
|
|
def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
  # xT: pooled values, xI: indices from MaxPool; outshape optionally fixes the output size
  return Tensor.max_unpool2d(xT, xI, kernel_shape, strides, 1, pads, outshape if outshape is None else tuple(outshape))
|
|
|
|
def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True)  # reduce over all spatial dims
def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True)
|
|
|
|
def Gemm(A:Tensor, B:Tensor, C:Tensor|None=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0):
  # Y = alpha * A' @ B' + beta * C, with optional transposes and the legacy
  # unidirectional-broadcast reshape of C when broadcast is set.
  ret = alpha * (A.transpose(transA) @ B.transpose(transB))
  if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1]))
  return ret
|
|
|
|
def Einsum(*Inputs:Tensor, equation:str): return Tensor.einsum(equation, *Inputs)  # each variadic input is a Tensor
|
|
|
|
def CumSum(X:Tensor, axis:int|list, exclusive:int=0, reverse:int=0):
  # axis may arrive as a single-element list from the graph
  axis = X._resolve_dim(axis[0] if isinstance(axis, list) else axis)
  if reverse: X = X.flip(axis)
  # exclusive: shift one step along axis by padding a zero in front and dropping the last element
  if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\
                     .shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim)))
  return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis)
|
|
|
|
def Trilu(x:Tensor, k:int=0, upper:int=1): return x.triu(k) if upper else x.tril(k)  # k is the diagonal offset
|
|
|
|
def Resize(X:Tensor, roi:list[float]|None=None, scales:list[float]|None=None, sizes:list[int]|None=None, antialias:int=0,
           axes:list[int]|None=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0,
           extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'):
  # Resize along the trailing (or explicitly given) axes, by `scales` or to
  # explicit `sizes`. Only nearest and linear modes are implemented.
  def _apply_nearest_mode(index: Tensor, input_dim, mode: str):
    # round fractional source indices to integers per nearest_mode, clipped in-bounds
    if mode == "round_prefer_floor": index = (index - 0.5).ceil()
    elif mode == "round_prefer_ceil": index = (index + 0.5).floor()
    elif mode in ["floor", "ceil"]: index = getattr(index, mode)()
    else: raise ValueError(f"invalid {nearest_mode=}")
    return index.cast(dtypes.int32).clip(0, input_dim-1)
  def _apply_transformation(index: Tensor, input_dim, scale_dim, mode):
    # map output-space indices to input-space coordinates per coordinate_transformation_mode
    # TODO: needs more testing, not confident in this
    # NOTE: their reference implementation differ from the implementation in their reference docs
    # https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_resize.py
    # https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize
    output_dim = scale_dim * input_dim
    if mode == "half_pixel": index = (index + 0.5) / scale_dim - 0.5
    elif mode == "align_corners": index = index * (input_dim - 1) / (output_dim - 1) if output_dim != 1 else Tensor([0])
    elif mode == "asymmetric": index = index / scale_dim
    elif mode == "pytorch_half_pixel": index = (index + 0.5) / scale_dim - 0.5 if output_dim != 1 else Tensor([-0.5])
    elif mode == "half_pixel_symmetric": index = input_dim / 2 * (1 - int(output_dim) / output_dim) + (index + 0.5) / scale_dim - 0.5
    else: raise NotImplementedError(f"invalid {coordinate_transformation_mode=}")
    return index.clip(0, input_dim-1)

  # scales/sizes cover only the resized dims; drop entries for leading batch/channel dims
  scales, sizes = (None if scales is None else scales[2-(X.ndim-len(scales)):]), (None if sizes is None else sizes[2-(X.ndim-len(sizes)):])
  # we pre permute the axes and permute back after resize
  axes, input_shape, = (axes or list(range(X.ndim))), cast(tuple[int, ...], X.shape[2:]),
  perm = [a for a in range(len(X.shape)) if a not in axes] + list(axes)
  X = X.permute(*perm)

  if sizes is not None:
    if keep_aspect_ratio_policy in ["not_larger", "not_smaller"]:
      # one uniform scale chosen to respect the aspect-ratio policy
      scale_fxn = min if keep_aspect_ratio_policy == "not_larger" else max
      scales = [scale_fxn([sizes[i] / input_shape[i] for i in range(len(input_shape)) if i+2 in axes])] * 2
      sizes = [int((scales[0] * input_shape[i]) + 0.5) if i+2 in axes else input_shape[i] for i in range(X.ndim-2)]
    else:
      scales = [size / input_shape for size, input_shape in zip(sizes, input_shape)]
  else:
    sizes = [int(sc*sh) for sc, sh in zip(scales, input_shape)]

  # NOTE: this transformation makes it so that we can't just call Tensor.interpolate
  # in Tensor.interpolate, we use indexes without any transformation
  indexes = []
  for shape, size, scale in zip(input_shape, sizes, scales):
    indexes.append(_apply_transformation(Tensor.arange(size), shape, scale, coordinate_transformation_mode))

  if mode == "nearest":
    indexes = [_apply_nearest_mode(index, shape, nearest_mode) for (index, shape) in zip(indexes, input_shape)]
    X = X[(..., *Tensor.meshgrid(*indexes))]
  if mode == "linear":
    # per resized axis: gather the floor/ceil neighbors and lerp by the fractional part
    expand = list(X.shape)
    for i in range(-len(sizes), 0):
      reshape, index = [1] * X.ndim, indexes[i]
      reshape[i] = expand[i] = sizes[i]
      low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor().int(), index.ceil().int(), index - index.floor())]
      X = X.gather(i, low).lerp(X.gather(i, high), perc)
  if mode == "cubic": raise NotImplementedError("cubic interpolation is not implemented")
  return X.permute(*argsort(perm)) if perm else X
|
|
def Upsample(X, scales, mode):  # deprecated
  """Deprecated alias for Resize: scale X by per-axis `scales` using `mode`."""
  return Resize(X=X, scales=scales, mode=mode)
|
|
|
|
def TopK(X:Tensor, K:int|list[int], axis:int=-1, largest:int=1, sorted:int=1):
  """Return the top-K values and their indices (as int64) along `axis`."""
  # K may arrive from the ONNX graph as a single-element list
  k = K[0] if isinstance(K, list) else K
  values, indices = X.topk(k, axis, largest, sorted)
  return values, indices.cast(dtypes.int64)
|
|
|
|
# ***** Neural Network Ops *****
|
|
def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9,
                       training_mode:int=0, spatial=1, is_test=0):
  """ONNX BatchNormalization.

  Inference mode returns only the normalized tensor; training mode additionally
  returns updated running_mean / running_var blended with `momentum`.
  NOTE(review): training-mode reductions over axes (0,2,3) assume 4-D NCHW input — confirm for other ranks.
  """
  if training_mode:
    x_detached = X.detach()  # batch statistics must not receive gradients
    current_mean = x_detached.mean(axis=(0,2,3))
    y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
    current_var = (y*y).mean(axis=(0,2,3))  # biased (population) variance
    current_invstd = current_var.add(epsilon).rsqrt()

    # ONNX running-stat update: old * momentum + new * (1 - momentum)
    running_mean = input_mean * momentum + current_mean * (1 - momentum)
    running_var = input_var * momentum + current_var * (1 - momentum)

    return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var
  return X.batchnorm(scale, B, input_mean, (input_var + epsilon).rsqrt())
|
|
def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05):
  """Normalize over groups of channels, then apply per-channel scale and bias."""
  normed = x.reshape(x.shape[0], num_groups, -1).layernorm(eps=epsilon).reshape(x.shape)
  # broadcast scale/bias over the batch and all spatial dims
  bshape = (1, -1) + (1,) * (x.ndim - 2)
  return normed * scale.reshape(*bshape) + bias.reshape(*bshape)
|
|
def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05):
  """Instance norm is group norm with one group per channel."""
  num_channels = x.shape[1]
  return GroupNormalization(x, scale, bias, num_groups=num_channels, epsilon=epsilon)
|
|
def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1):
  """ONNX LayerNormalization: returns (normalized * scale + bias, Mean, InvStdDev)."""
  assert stash_type == 1, "only float32 is supported"
  # normalize over all axes from `axis` (negatives wrap) through the last axis
  axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
  mean = x.mean(axis=axes, keepdim=True)
  # third output is InvStdDev = 1/sqrt(var + eps), kept with reduced dims
  return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt()
|
|
def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Tensor|None=None, bias:Tensor|None=None, epsilon:float=1e-12):
  """Fused residual add + LayerNorm (ORT contrib op).

  Returns (output, None, None, pre-normalization sum) to match the op's four outputs.
  """
  pre_norm = x + skip
  if bias is not None: pre_norm = pre_norm + bias
  out = pre_norm.layernorm(eps=epsilon) * gamma
  return (out + beta if beta is not None else out), None, None, pre_norm
|
|
def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embedding:Tensor, position_embedding:Tensor,
                            segment_embedding:Tensor, gamma=None, beta=None, mask:Tensor|None=None,
                            position_ids:Tensor|None=None, epsilon=1e-12, mask_index_type=0):
  # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization
  """BERT-style fused embedding (word + position + optional segment) followed by LayerNorm.

  Returns (normalized output, None placeholder for mask_index, pre-norm embedding sum).
  """
  assert (segment_ids is None) is (segment_embedding is None)
  assert mask is None and not mask_index_type, "functionality not supported yet" # TODO
  input_shape = input_ids.shape  # assumes (batch, seq_len) — TODO confirm
  seq_length = input_shape[1]
  compute_seg_emb = (segment_embedding is not None and segment_ids is not None)
  vocab_size, max_position_embeddings = word_embedding.shape[0], position_embedding.shape[0]
  type_vocab_size = (segment_embedding.shape[0] if compute_seg_emb else None)

  def embedding(x:Tensor, vocab_size, weight:Tensor) -> Tensor:
    # one-hot matmul lookup instead of gather
    return x.unsqueeze(-1).expand(*x.shape, vocab_size)._one_hot_along_dim(vocab_size) @ weight

  # bert embedding layer
  if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
  wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
  pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding)
  seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None

  embedding_sum = wrd_embedding_res + pos_embedding_res
  if seg_embedding_res is not None: embedding_sum = embedding_sum + seg_embedding_res
  out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
  return out, None, embedding_sum
|
|
def MeanVarianceNormalization(x:Tensor, axis:Sequence[int]=(0,2,3)):
  """Normalize x to zero mean and unit variance over `axis`.

  FIX: default changed from the mutable list `[0,2,3]` to an immutable tuple to
  avoid the shared-mutable-default pitfall; callers passing a list still work.
  """
  # 1e-9 guards against division by zero when the slice is constant
  return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9)
|
|
|
|
def OneHot(indices:Tensor, depth:float|int|list, values:Tensor, axis:int=-1):
  """One-hot encode `indices` into `depth` classes along `axis`; off/on values are values[0]/values[1]."""
  # depth arrives as a scalar or a rank-1 container with exactly one element
  if isinstance(depth, list): depth = depth[0]
  num_classes = int(depth)
  idx = indices.int()
  idx = (idx < 0).where(idx + num_classes, idx)  # wrap negative class indices
  hot = idx.unsqueeze(axis)._one_hot_along_dim(num_classes, dim=axis)
  return hot.where(values[1], values[0])
|
|
|
|
def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"):
  """Move `blocksize`x`blocksize` channel blocks into the spatial dims (CRD or DCR layout)."""
  if mode == "CRD":
    pattern = "b (c h1 w1) h w -> b c (h h1) (w w1)"
  else:
    pattern = "b (h1 w1 c) h w -> b c (h h1) (w w1)"
  return X.rearrange(pattern, h1=blocksize, w1=blocksize)
|
|
def SpaceToDepth(X:Tensor, blocksize:int):
  """Move `blocksize`x`blocksize` spatial blocks into the channel dim (inverse of DepthToSpace DCR)."""
  pattern = "b c (h h1) (w w1) -> b (h1 w1 c) h w"
  return X.rearrange(pattern, h1=blocksize, w1=blocksize)
|
|
|
|
# Reimplemented here because you need legacy RNG for passing ONNX tests.
def Dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None):
  """Dropout using numpy's legacy RandomState so masks match the ONNX reference tests.

  Returns (output, mask); in inference mode data passes through untouched.
  """
  if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's.
  # mask is generated on the host with numpy (not Tensor.rand) to reproduce ONNX's reference RNG stream
  mask = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device)
  # inverted dropout: scale kept activations by 1/(1-ratio)
  return data * mask * (1/(1.0 - ratio)), mask
|
|
# 6 with 'is_test' needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx
def Dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return Dropout_7(data, ratio, training_mode=not is_test)
# opset-versioned dispatch table: key is the opset version introducing each signature
Dropout = {6:Dropout_6, 7:Dropout_7}
|
|
|
|
def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0):
  """Local Response Normalization across channels, implemented as an average pool over x**2."""
  # fold channels into a pseudo-spatial dim, pad (size-1)//2 before / size//2 after so the
  # `size`-wide window is centered on each channel, then mean-pool over it
  pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1)
  return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta)
|
|
|
|
def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
  """Thin wrapper over Tensor.nll_loss with ONNX's argument order."""
  loss = x.nll_loss(target, weight, ignore_index, reduction)
  return loss
|
|
def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
  """Cross-entropy via log-softmax + NLL; also returns the log-probabilities."""
  log_probs = scores.log_softmax(1)  # class axis is dim 1
  loss = log_probs.nll_loss(labels, weights, ignore_index, reduction)
  return loss, log_probs
|
|
|
|
def AffineGrid(theta:Tensor, size:list[int], align_corners:int=0):
  """ONNX AffineGrid: build a sampling grid (N, *spatial_dims, ndim) from affine matrices `theta`."""
  N, _, *spatial_dims = size  # size is (N, C, D1..Dk); channel count is unused for the grid
  def generate_grid(steps):
    # align_corners places samples at [-1, 1] exactly; otherwise at pixel centers
    return Tensor.linspace(-1, 1, steps, device=theta.device) if align_corners else Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device)
  grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims))
  # reversed so the fastest-varying coordinate comes first; append homogeneous 1s for the affine translation
  base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1)
  base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1)
  return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1)
|
|
|
|
def Attention(x:Tensor, weights:Tensor, bias:Tensor|None=None, mask_index:Tensor|None=None, past:Tensor|None=None, attention_bias:Tensor|None=None,
              past_sequence_length:Tensor|None=None, do_rotary:int=0, mask_filter_value:float=-10000.0, num_heads:int|None=None,
              past_present_share_buffer:int|None=None, qkv_hidden_sizes:list[int]|None=None, rotary_embedding_dim:int|None=None,
              scale:float|None=None, unidirectional:int=0):
  """com.microsoft.Attention: multi-head attention with a fused QKV projection.

  Returns (output, present) where `present` stacks the (possibly past-extended) K and V,
  or None when no `past` was given.
  """
  assert not do_rotary and not attention_bias, "TODO"
  if qkv_hidden_sizes is None: qkv_hidden_sizes = [weights.shape[1] // 3] * 3  # equal Q/K/V widths by default
  qkv = x.linear(weights, bias)  # single fused projection, then split into Q, K, V
  q, k, v = qkv.split(qkv_hidden_sizes, dim=2)

  batch_size, seq_len, _ = x.shape
  q_head_size, k_head_size, v_head_size = (sz // num_heads for sz in qkv_hidden_sizes)
  # (B, S, H*hs) -> (B, H, S, hs)
  q, k, v = (x.reshape(batch_size, seq_len, num_heads, hsz).transpose(1, 2) for x, hsz in zip((q, k, v), (q_head_size, k_head_size, v_head_size)))

  present = None
  if past is not None:
    # past[0]/past[1] are prior K/V; concatenate along the sequence axis (dim 2)
    k, v = past[0].cat(k, dim=2), past[1].cat(v, dim=2)
    present = k.stack(v)

  if scale is None: scale = 1.0 / math.sqrt(q_head_size)
  attn_scores = q @ k.transpose(-1, -2) * scale

  if mask_index is not None:
    assert 4 >= mask_index.ndim >= 1, f"{mask_index.ndim=}"
    if mask_index.ndim != 1: mask = mask_index.bool()  # already a broadcastable boolean mask
    else:
      # 1-D mask_index encodes lengths: (B,) right-padded lengths, or (2B,) end then start positions
      if mask_index.shape[0] == batch_size:
        mask = Tensor.arange(attn_scores.shape[-1], requires_grad=False, device=mask_index.device).unsqueeze(0) < mask_index.unsqueeze(1)
      elif mask_index.shape[0] == 2*batch_size:
        end_positions = mask_index[:batch_size]
        start_positions = mask_index[batch_size:]
        arange = Tensor.arange(seq_len).unsqueeze(0)
        mask = (arange < end_positions.unsqueeze(1)) & (arange >= start_positions.unsqueeze(1))
      else: raise NotImplementedError("mask_index with shape (3 * batch_size + 2) is not implemented")
    while mask.ndim < 4: mask = mask.unsqueeze(1)  # broadcast mask over heads / query positions
    attn_scores = mask.where(attn_scores, mask_filter_value)

  if unidirectional:
    # causal: each query may only attend to itself and earlier keys
    causal_mask = Tensor.ones((seq_len, seq_len), dtype=dtypes.bool).tril()
    attn_scores = causal_mask.where(attn_scores, mask_filter_value)

  output = attn_scores.softmax(-1) @ v
  output = output.transpose(1, 2).reshape(batch_size, seq_len, -1)  # (B, H, S, hs) -> (B, S, H*hs)
  return output, present
|
|
|
|
# ***** Indexing Ops *****
|
|
def ArrayFeatureExtractor(x:Tensor, indices:Tensor):
  """onnx-ml ArrayFeatureExtractor: gather along the last axis."""
  return x[..., indices]
|
|
|
|
def Gather(x:Tensor, indices:Tensor, axis:int=0):
  """ONNX Gather: take slices of x along `axis` at `indices`.

  Small index sets are unrolled into shrink+cat (fewer kernels); larger ones
  fall back to fancy indexing.
  """
  if indices.numel() < 9: # NOTE fewer kernels for smaller indices but kernel number increases depending on size of indices
    x_sh = list(x.shape)
    ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:]
    if indices.ndim > 1: indices = indices.flatten()
    # pull the small index set back to python so the shrink args are static
    indices = [_cached_to_python_const(indices)] if indices.shape == () else _cached_to_python_const(indices)
    indices = [x_sh[axis]+x if x<0 else x for x in indices]  # wrap negative indices
    # one shrink per index: full range on every axis except a 1-wide window on `axis`
    args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore
    return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
  # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot
  return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])]
|
|
def Scatter(*args, **kwargs):
  """Deprecated alias kept for old opsets; forwards everything to ScatterElements."""
  return ScatterElements(*args, **kwargs)
|
|
|
|
def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0):
  """ONNX GatherND: the last axis of `indices` holds coordinate tuples into x."""
  if batch_dims == 0: return x[tuple(i.squeeze(-1) for i in indices.split(1, -1))]
  x_shape, i_shape = x.shape, indices.shape
  b = math.prod(x.shape[dim] for dim in range(batch_dims))
  # NOTE: each batched dim of both input and indices are equal
  # collapse all batch dims into one, then index with an explicit batch-id tensor
  x = x.reshape(b, *x.shape[batch_dims:])
  indices = indices.reshape(b, *indices.shape[batch_dims:])
  b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1])
  ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))]
  # restore original batch dims, keep the gathered middle dims and remaining data dims
  return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:])
|
|
def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'):
  """ONNX ScatterND: write `updates` into a copy of x at coordinate tuples from `indices`."""
  assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):]
  x = x.contiguous()  # setitem below needs a realizable buffer
  # sequential writes over the outermost index dim; on duplicate indices later writes win
  for index, u in zip(indices.split(1, 0), updates.split(1, 0)):
    i = tuple(idx.squeeze(-1) for idx in index.squeeze(0).split(1, -1))
    u = u.squeeze(0)
    if reduction == "none": x[i] = u
    elif reduction == "add": x[i] += u
    elif reduction == "mul": x[i] *= u
    else: raise NotImplementedError("reduction doesn't support max or min")
  return x
|
|
|
|
def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul", "min", "max"]="none"):
  """Scatter `updates` into x at per-element `indices` along `axis`, optionally reducing."""
  indices = indices + (indices < 0).where(x.shape[axis], 0)  # wrap negative indices
  if reduction == "none": return x.scatter(axis, indices, updates)
  reduce_map = {"add": "sum", "mul": "prod", "min": "amin", "max": "amax"}
  return x.scatter_reduce(axis, indices, updates, reduce_map.get(reduction))
|
|
def GatherElements(x:Tensor, indices:Tensor, axis:int):
  """Gather elements of x along `axis` at per-element `indices` (negatives wrap)."""
  wrapped = indices + (indices < 0).where(x.shape[axis], 0)
  return x.gather(axis, wrapped)
|
|
|
|
def Compress(inp:Tensor, condition:list[bool], axis:int|None=None):
  """Select the slices of `inp` along `axis` where `condition` is True; flatten first when axis is None."""
  if axis is None:
    inp, axis = inp.flatten(), 0
  elif axis < 0:
    axis += inp.ndim
  # resolve the boolean condition to integer indices in python
  keep = Tensor([i for i, cond in enumerate(condition) if cond])
  selector = tuple(keep if dim == axis else slice(None) for dim in range(inp.ndim))
  return inp[selector]
|
|
|
|
# ***** Quantization Ops *****
|
|
def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1):
  """ONNX QuantizeLinear: y = clamp(round(x / y_scale) + y_zero_point) cast to the output dtype."""
  # output dtype priority: zero-point tensor dtype > explicit output_dtype attribute > uint8 default
  out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8
  y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size)
  if out_dtype == dtypes.uchar:
    # this appears to work in practice, at least for uchar out_dtype. it folds with the quantize stuff
    # 0.4999999 + .int() truncation emulates rounding without a separate round kernel — TODO confirm parity with .round() at exact halves
    ret = _clamp_cast((x / y_scale + 0.4999999 + y_zero_point).int(), out_dtype)
  else:
    ret = _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype)
  # you need both NHWC=1 DONT_GROUP_REDUCES=1 for this to work
  if getenv("NHWC") and len(ret.shape) == 4: return ret.permute(0,2,3,1).contiguous().permute(0,3,1,2)
  return ret.contiguous()
|
|
|
|
def DynamicQuantizeLinear(x: Tensor):
  """Compute scale/zero-point from x's value range, then quantize x to uint8.

  Returns (y, scale, zero_point) per the ONNX DynamicQuantizeLinear spec.
  """
  # only support uint8
  qmin, qmax = dtypes.min(dtypes.uint8), dtypes.max(dtypes.uint8)
  # range spans [min(x, 0), max(x, 0)] so zero is always exactly representable
  scale = (x.max().maximum(0) + ((-x).max()).maximum(0)) / (qmax - qmin)
  zero_point = _clamp_cast((qmin - x.min() / scale).round(), dtypes.uint8)
  y = _clamp_cast((x / scale).round() + zero_point, dtypes.uint8)
  return y, scale, zero_point
|
|
|
|
def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0):
  """ONNX DequantizeLinear: (x - x_zero_point) * x_scale, returned in x_scale's dtype."""
  WEIGHT_SHIFT = 4
  if getenv("NHWC") and len(x.shape) == 4 and x.shape[2:] == (1,1) and x.shape[1]%WEIGHT_SHIFT == 0:
    # DSP swizzle memory
    # round-trip through a permuted contiguous copy to lay 1x1-conv weights out in groups of 4 channels
    x = x.reshape(x.shape[0], x.shape[1]//WEIGHT_SHIFT, WEIGHT_SHIFT).permute(1,0,2).contiguous().permute(1,0,2).reshape(x.shape)
  x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size)
  # .int() widens before the subtraction to avoid int8/uint8 overflow
  return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype)
|
|
|
|
def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int, w:Tensor, w_scale:Tensor, w_zero_point:Tensor|int, y_scale:Tensor,
                y_zero_point: Tensor|int, B:Tensor|None=None, **opts):
  """Quantized convolution: runs Conv via the quantized-op helper and requantizes to y_scale/y_zero_point."""
  conv_opts = {"B": B, **opts}
  return _qlinearop_quantized(Conv, [x, w], [x_zero_point, w_zero_point], [x_scale, w_scale], y_scale, y_zero_point, **conv_opts)
|
|
|
|
def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor|int, b:Tensor, b_scale:Tensor, b_zero_point:Tensor|int, y_scale:Tensor,
                  y_zero_point:Tensor|int) -> Tensor:
  """Quantized matmul via the quantized-op helper, requantized to y_scale/y_zero_point."""
  operands, zero_points, scales = [a, b], [a_zero_point, b_zero_point], [a_scale, b_scale]
  return _qlinearop_quantized(Tensor.matmul, operands, zero_points, scales, y_scale, y_zero_point)
|
|
|
|
def QLinearAdd(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, c_scale:Tensor, c_zero_point:Tensor):
  """Quantized elementwise add via the float-path helper, requantized to c_scale/c_zero_point."""
  operands, zero_points, scales = [a, b], [a_zero_point, b_zero_point], [a_scale, b_scale]
  return _qlinearop_float(Tensor.add, operands, zero_points, scales, c_scale, c_zero_point)
|
|
|
|
def QLinearMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, c_scale:Tensor, c_zero_point:Tensor):
  """Quantized elementwise mul via the quantized-op helper, requantized to c_scale/c_zero_point."""
  operands, zero_points, scales = [a, b], [a_zero_point, b_zero_point], [a_scale, b_scale]
  return _qlinearop_quantized(Tensor.mul, operands, zero_points, scales, c_scale, c_zero_point)
|
|
|
|
def QLinearGlobalAveragePool(X:Tensor, x_scale:Tensor, x_zero_point:Tensor, y_scale:Tensor, y_zero_point:Tensor, channels_last:int):
  """Quantized GlobalAveragePool; only the NCHW layout is implemented."""
  assert channels_last == 0, "TODO NHWC"
  quant_args = ([X], [x_zero_point], [x_scale])
  return _qlinearop_float(GlobalAveragePool, *quant_args, y_scale, y_zero_point)
|
|
|
|
def ConvInteger(x: Tensor, w: Tensor, x_zero_point: Tensor | int = 0, w_zero_point: Tensor | int = 0, B: Tensor | None = None, **opts) -> Tensor:
  """Integer convolution: Conv on zero-point-shifted inputs via the integer-op helper."""
  conv_opts = {"B": B, **opts}
  return _op_integer(Conv, [x, w], [x_zero_point, w_zero_point], **conv_opts)
|
|
|
|
def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor | int = 0, b_zero_point: Tensor | int = 0) -> Tensor:
  """Integer matmul with per-input zero points, delegated to the integer-op helper."""
  operands, zero_points = [A, B], [a_zero_point, b_zero_point]
  return _op_integer(Tensor.matmul, operands, zero_points)
|
|
|
|
# ***** Training Ops *****
|
|
# NOTE: onnx test coverage only covers `T==0` cases, so for all `T>0` this isn't tested
|
|
# NOTE: onnx training ops actually don't need the state for optim, all the ops work in a functional way, but we still can reuse optim.py code
|
|
@_onnx_training(3)
def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:float=0.0, norm_coefficient:float=0.0):
  """ONNX training Adagrad step for one (X, G, H) triple.

  R is the base learning rate, T the step count; returns updated [X, H].
  """
  X, G, H = (i.detach() for i in inputs)
  grad = norm_coefficient * X + G  # gradient with L2 regularization folded in
  H.assign(H + grad.square())  # accumulate squared gradients
  up = grad / (H.sqrt() + epsilon)
  r = R / (1 + T * decay_factor)  # step-count learning-rate decay
  X.assign(X.detach() - r * up)
  return [X, H]
|
|
|
|
@_onnx_training(4)
def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, epsilon:float=0.0, norm_coefficient:float=0.0,
         norm_coefficient_post:float=0.0):
  """ONNX training Adam step for one (X, G, V, H) quadruple; returns [X, V, H].

  Reuses tinygrad's Adam optimizer by injecting the moment buffers and step counters.
  """
  from tinygrad.nn.optim import Adam as TinyAdam
  X, G, V, H = inputs
  G, V, H = G.detach(), V.detach(), H.detach() # TODO we shouldn't need these detaches
  X.grad = norm_coefficient * X.detach() + G  # L2 regularization folded into the gradient
  opt = TinyAdam([X], b1=alpha, b2=beta, eps=epsilon)
  opt.m, opt.v, opt.lr = [V], [H], R  # inject external optimizer state
  # need no-op for m_hat and v_hat if T == 0
  if T == 0: opt.b1_t, opt.b2_t = opt.b1_t.zeros_like(), opt.b2_t.zeros_like()
  else:
    # `T-1` since it's applied again at the start of `_step`
    opt.b1_t = Tensor([alpha**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False)
    opt.b2_t = Tensor([beta**(T-1)], dtype=dtypes.float32, device=X.device, requires_grad=False)
  opt.step()
  X = (1 - norm_coefficient_post) * X  # post-update weight shrinkage
  return [X, V, H]
|
|
|
|
@_onnx_training(3)
def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str, norm_coefficient:float):
  """ONNX training Momentum step (standard or nesterov) for one (X, G, V) triple; returns [X, V]."""
  from tinygrad.nn.optim import SGD
  X, G, V = inputs
  G, V = G.detach(), V.detach()
  # beta scales the regularized gradient only after the first step (T > 0)
  X.grad = (norm_coefficient * X.detach() + G) * (beta if T > 0 else 1)
  opt = SGD([X], momentum=alpha, nesterov=(mode=="nesterov"))
  opt.b, opt.lr = [V], R  # reuse tinygrad's SGD with injected momentum buffer and learning rate
  opt.step()
  return [X, V]
|
|
|
|
def Gradient(*inputs:Tensor, y:str, intermediate_tensors:dict[str, Tensor], **_):
  """Backpropagate from graph output `y` and return each input's gradient."""
  intermediate_tensors[y].backward()
  return tuple(t.grad for t in inputs)
|
|
|
|
return {
|
|
# Tensor ops
|
|
**{op: getattr(Tensor, op.lower()) for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan",
|
|
"Asin", "Acos", "Atan", "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh",
|
|
"Tanh", "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Round", "Erf")},
|
|
# Implemented ops
|
|
**{name:obj for name,obj in locals().items() if isinstance(obj, types.FunctionType) and not name.startswith("_") and name[0].isupper()},
|
|
# Version ops
|
|
**{name:obj for name,obj in locals().items() if isinstance(obj, dict)},
|
|
}
|
|
|
|
# materialize the op dispatch table once at import time
onnx_ops = get_onnx_ops()
|
|
|