from __future__ import annotations from typing import Any, Optional, Set, Union, Tuple, Callable, cast, TYPE_CHECKING, Type, DefaultDict, Literal, get_args import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref from enum import auto, IntEnum, Enum from dataclasses import dataclass, field from collections import defaultdict from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, _METADATA, flatten from tinygrad.helpers import PICKLE_BUFFERS, SPLIT_REDUCEOP, DEBUG if TYPE_CHECKING: from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.device import Buffer # wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses class FastEnum(IntEnum): def __str__(self): return Enum.__str__(self) @staticmethod def _generate_next_value_(_, __, ___, last_values): return 1 + max([0, *last_values, *[max(c) for c in FastEnum.__subclasses__()]]) class SimpleMathTrait: # required to implement def alu(self:T, arg:Ops, *src) -> T: raise NotImplementedError def const_like(self:T, b:ConstLike) -> T: raise NotImplementedError # great functions you get! def ufix(self, x): return self.const_like(x) if not isinstance(x, MathTrait) else x def _binop(self, op, x, reverse): return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x)) def logical_not(self): return self.ne(True) def neg(self): if (dtype:=getattr(self, 'dtype')) is None: raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}") return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1) def add(self, x, reverse=False): return self._binop(Ops.ADD, x, reverse) def mul(self, x, reverse=False): return self._binop(Ops.MUL, x, reverse) def bitwise_and(self, x, reverse=False): return self._binop(Ops.AND, x, reverse) def bitwise_or(self, x, reverse=False): return self._binop(Ops.OR, x, reverse) def xor(self, x, reverse=False): return self._binop(Ops.XOR, x, reverse) def idiv(self, x, reverse=False): return self._binop(Ops.IDIV, x, reverse) def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x)) def div(self, x, reverse=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP)) def __neg__(self): return self.neg() def __add__(self, x): return self.add(x) def __sub__(self, x): return self.sub(x) def __mul__(self, x): return self.mul(x) def __truediv__(self, x): return self.div(x) def __floordiv__(self, x): return self.idiv(x) def __and__(self, x): return self.bitwise_and(x) def __or__(self, x): return self.bitwise_or(x) def __xor__(self, x): return self.xor(x) def __radd__(self, x): return self.add(x, True) def __rsub__(self, x): return self.sub(x, True) def __rmul__(self, x): return self.mul(x, True) def __rtruediv__(self, x): return self.div(x, True) def __rfloordiv__(self, x): return self.idiv(x, True) def __rand__(self, x): return self.bitwise_and(x, True) def __ror__(self, x): return self.bitwise_or(x, True) def __rxor__(self, x): return self.xor(x, True) def __lt__(self, x): return self.alu(Ops.CMPLT, self.ufix(x)) def __gt__(self, x): return self.ufix(x).alu(Ops.CMPLT, self) def __ge__(self, x): return (self < x).logical_not() def __le__(self, x): return (self > x).logical_not() def ne(self, x): return self.alu(Ops.CMPNE, self.ufix(x)) def eq(self, x): return self.ne(x).logical_not() def __ne__(self, x): return self.ne(x) # NOTE: __eq__ isn't overridden, and means the same thing as is by default class MathTrait(SimpleMathTrait): # TODO: move to Tensor when new backward is done def lshift(self, x, reverse=False): return self._binop(Ops.SHL, x, reverse) def rshift(self, x, reverse=False): return self._binop(Ops.SHR, x, reverse) def __lshift__(self, x): return self.lshift(x) def __rshift__(self, x): return self.rshift(x) def __rlshift__(self, x): return self.lshift(x, True) def __rrshift__(self, x): return self.rshift(x, True) # not in Tensor def __mod__(self, x): return self.alu(Ops.MOD, self.ufix(x)) def __rmod__(self, x): return self.ufix(x).alu(Ops.MOD, self) def maximum(self, x): return self.alu(Ops.MAX, self.ufix(x)) def minimum(self, x): return -(-self).maximum(-x) def where(self, x, y): return self.alu(Ops.WHERE, x, x.ufix(y)) def threefry(self, seed): return self.alu(Ops.THREEFRY, seed) def reciprocal(self): return self.alu(Ops.RECIP) def sqrt(self): return self.alu(Ops.SQRT) def sin(self): return self.alu(Ops.SIN) def log2(self): return self.alu(Ops.LOG2) def exp2(self): return self.alu(Ops.EXP2) # the order of these Ops controls the order of the toposort class Ops(FastEnum): # uops that aren't rendered SINK = auto(); CONTIGUOUS = auto(); DETACH = auto(); PRELOAD = auto() # noqa: E702 # MetaOps COPY = auto(); EMPTY = auto(); BUFFER_VIEW = auto() # noqa: E702 # blocks in linearizer BLOCK = auto(); BLOCKSTART = auto(); BLOCKFORK = auto(); BLOCKEND = auto() # noqa: E702 # movement ops! RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto() # noqa: E702 # misc ops UNROLL = auto(); CONTRACT = auto() # noqa: E702 VIEW = auto(); DEFINE_GLOBAL = auto(); BUFFER = auto() # noqa: E702 DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702 VALID = auto(); SPECIAL = auto(); NOOP = auto() # noqa: E702 # reduce REDUCE_AXIS = auto() # helper ops GEP = auto(); VECTORIZE = auto() # noqa: E702 # UnaryOps CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702 # load/store before math LOAD = auto(); STORE = auto() # noqa: E702 # early INDEX INDEX = auto() # math ops WMMA = auto() # BinaryOps ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702 SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto() # noqa: E702 # TernaryOps WHERE = auto(); MULACC = auto() # noqa: E702 # assignment ops ASSIGN = auto() BIND = auto() # control flow ops BARRIER = auto(); RANGE = auto(); IF = auto(); ENDRANGE = auto(); ENDIF = auto() # noqa: E702 # consts last! VCONST = auto(); CONST = auto() # noqa: E702 # device DEVICE = auto() class GroupOp: Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG} Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB, Ops.FDIV} Ternary = {Ops.WHERE, Ops.MULACC} ALU = set.union(Unary, Binary, Ternary) Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE} Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.STRIDE} # meta ops Meta = {Ops.COPY, Ops.EMPTY, Ops.BUFFER_VIEW} Buffer = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID} Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART} # BinaryOps that can be flipped Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.XOR, Ops.AND, Ops.OR} # BinaryOps where f(f(a,b),c) = f(a,f(b,c)) Associative = {Ops.ADD, Ops.MUL, Ops.AND, Ops.OR} # BinaryOps that satisfy f(x,x)=x see https://en.wikipedia.org/wiki/Idempotence Idempotent = {Ops.OR, Ops.AND, Ops.MAX} # do not preserve f(0) = 0 UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV} # some BUFFER ops can be processed with only a view view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"} # https://en.wikipedia.org/wiki/Identity_element def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt) def can_pad(u:UOp, edges:dict[UOp, UOp], visisted:set[UOp]) -> bool: if u.op in GroupOp.UnsafePad: return False if (len(u.src) == 2 and u.src[0] in edges) or u in visisted: return True visisted.add(u) return all(can_pad(x.base, edges, visisted) for x in u.src) # With True as the default, this matches the old symbolic behavior def resolve(x, default:bool=True): if not isinstance(x, UOp): return bool(x) assert x.dtype is dtypes.bool, "UOp in resolve must be bool" # NOTE: generating the text for the exception is expensive, so we do this return bool(sx.vmin) if (sx:=x.simplify()).vmin == sx.vmax else default # smax/smin are replacements for max/min that preserve symbolic def _suop(lst, uop_fxn, python_fxn): uops, nums = partition(lst, lambda x: isinstance(x, UOp)) return ssimplify(functools.reduce(uop_fxn, uops + ([python_fxn(nums)] if nums else []))) def smax(*lst): return _suop(argfix(*lst), UOp.maximum, max) def smin(*lst): return _suop(argfix(*lst), UOp.minimum, min) def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop def sym_infer(uop: Union[UOp, int], var_vals: dict[UOp, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop # used for UOp and UPat def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str: def dfs(x:Any, cache:dict): for s in srcfn(x) or []: cache.setdefault(s, [len(cache), 0, False])[1] += 1 if cache[s][1] == 1: dfs(s, cache) if cache is None: dfs(x, cache:={}) if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d} x{cx[0]}" cx[2], srcs = True, ('None' if srcfn(x) is None else ''.join(f'\n{pretty_print(s, rep, srcfn, cache, d+2)},' for s in srcfn(x))) return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs class UOpMetaClass(type): ucache:dict[Tuple, weakref.ReferenceType[UOp]] = {} def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, _buffer=None): if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None and (ret:=wret()) is not None: return ret UOpMetaClass.ucache[key] = weakref.ref(created:=super().__call__(*key)) # NOTE: this will soon be set by Tensor once we remove function.py if (metadata:=_METADATA.get()) is not None: all_metadata[created] = metadata return created # some uops map to other stuff buffers:weakref.WeakKeyDictionary[UOp, Buffer] = weakref.WeakKeyDictionary() # this maps BUFFER uops to their device Buffers all_metadata:weakref.WeakKeyDictionary[UOp, Metadata] = weakref.WeakKeyDictionary() forced_realize:weakref.WeakSet[UOp] = weakref.WeakSet() # NOTE: this should be frozen, but frozen is slower @dataclass(eq=False, slots=True) class UOp(MathTrait, metaclass=UOpMetaClass): op:Ops dtype:DType = dtypes.void src:tuple[UOp, ...] = tuple() arg:Any = None def __del__(self): if self.op is Ops.BUFFER and (buffer:=buffers.get(self)) is not None: buffer.ref(-1) if (k:=(self.op, self.dtype, self.src, self.arg)) in UOpMetaClass.ucache: del UOpMetaClass.ucache[k] def __reduce__(self): args = [self.op, self.dtype, self.src, self.arg] if (_device_buffer:=self.realized) is not None and PICKLE_BUFFERS: args.extend([_device_buffer]) return UOp, tuple(args) def replace(self, **kwargs) -> UOp: new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src), kwargs.pop("arg", self.arg)) assert len(kwargs) == 0, f"unused kwargs in replace {list(kwargs)}" if (self.op, self.dtype, self.src, self.arg) == new_args: return self return UOp(*new_args) @functools.cached_property def key(self) -> bytes: return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest() def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))") def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else self.arg @property def toposort(self) -> dict[UOp, None]: def _toposort(u:UOp, cache:dict[UOp, dict[UOp, None]]): if (cret:=cache.get(u)) is not None: return cret nodes: dict[UOp, None] = {} # NOTE: this is a lot faster than the comprehension in parents for parent in u.src: nodes.update(_toposort(parent, cache)) nodes[u] = None cache[u] = nodes return nodes return _toposort(self, cache={}) @functools.cached_property def tuplize(self:UOp) -> tuple[int, Any, Optional[DType], Tuple]: return (self.op.value, self.arg, self.dtype, tuple(x.tuplize for x in self.src)) # *** uop shape stuff *** @property def has_st(self) -> bool: return self.op not in {Ops.DEFINE_LOCAL, Ops.DEFINE_GLOBAL, Ops.BUFFER, Ops.CONST, Ops.DEFINE_VAR} @functools.cached_property def st(self) -> Optional[ShapeTracker]: if self.op is Ops.VIEW: return self.arg if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg) # buffer ops can have a non contiguous shapetracker if self.op in GroupOp.Buffer and len(src_sts:=[unwrap(x.st) for x in self.src if x.op is Ops.VIEW]) != 0: return src_sts[0] if len(src_sts:=[x.st for x in self.src if x.st is not None]) == 0: return None assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}" # all other ops have a contiguous shapetracker from tinygrad.shape.shapetracker import ShapeTracker return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg) if self.op in (Ops.REDUCE_AXIS, Ops.WMMA) else src_sts[0].shape) @functools.cached_property def full_shape(self) -> tuple[sint, ...]: return self.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st])) @property def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape @property def size(self) -> int: return self.arg[-1] if self.op is Ops.BUFFER else unwrap(self.st).size # *** uop evaluation *** def simplify(self): with Context(TRACK_MATCH_STATS=0): return graph_rewrite(self, symbolic) def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret def _eval(self, dtype, expected_type:Type[T]) -> T: assert self.dtype in dtype, f"eval with wrong dtype {self}" vmin, vmax = (simple_self:=self.simplify())._min_max if vmin != vmax: raise ValueError(f"eval failed to be a single number, range is {vmin} to {vmax} in {simple_self.render()}") assert isinstance(vmin, expected_type), f"vmin is wrong dtype {type(vmin)} != {expected_type}" return vmin def __bool__(self): return self._eval((dtypes.bool,), bool) def __int__(self): return self._eval(dtypes.ints, int) def __float__(self): return self._eval(dtypes.floats, float) def substitute(self, dvars:dict[UOp, UOp]): with Context(TRACK_MATCH_STATS=0): return graph_rewrite(self, _substitute, dvars, bottom_up=True) # *** uop syntactic sugar *** @property def st_arg(self) -> ShapeTracker: assert self.op in GroupOp.Buffer, f"st_arg called on {self.op}" ret = self.src[0 if self.op is Ops.VALID else 1] assert ret.op is Ops.VIEW, f"st_arg trying to return {ret}" return ret.arg @property def const_arg(self) -> ConstType: match self.base.op: case Ops.CONST: ret = self.base.arg case Ops.VIEW: ret = self.base.src[1].const_arg case op: raise AssertionError(f"const_arg called on {op}") assert isinstance(ret, get_args(ConstType)), f"const_arg trying to return {ret}" return ret @property def axis_arg(self) -> tuple[int, ...]: assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}" ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7] assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}" return ret def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs) def detach(self): return UOp(Ops.DETACH, self.dtype, (self,)) def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx)) def const_like(self, b:ConstLike): if self._device is not None: return UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b) return UOp.const(self.dtype, b) if self.st is None else UOp.const_with_shape(self.dtype, b, self.shape) def broadcast(self, count:int): assert self.dtype.count == 1 if count == 1: return self return UOp(Ops.VECTORIZE, self.dtype.vec(count), (self,)*count) def cast(self, dtype:DType, bitcast=False, allow_buffer_view=True): if self.dtype == dtype: return self # TODO: move this to the scheduler if bitcast: return self.bitcast(dtype, allow_buffer_view) if self._device is not None and self._device.startswith("DISK"): raise RuntimeError("CAST isn't supported on DISK") if getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self is not self.base: # NOTE: we have to apply the movementops here, we can't use VIEW (yet) # TODO: move this to the scheduler ret = self.base.cast(dtype, bitcast) op_arg = [] mop = self while mop is not self.base: op_arg.append((mop.op, mop.arg)) mop = mop.src[0] for op,arg in reversed(op_arg): ret = UOp(op, ret.dtype, (ret,), arg) return ret return UOp(Ops.CAST, dtype, (self,)) def bitcast(self, dtype:DType, allow_buffer_view=True): if self.can_view() and allow_buffer_view: if self.dtype.itemsize == dtype.itemsize: output_shape = self.shape else: if not self.device.startswith("DISK") or not all_int(self.shape): raise RuntimeError(f"shape changing bitcast not supported on {self}") # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html if (self.shape[-1]*self.dtype.itemsize) % dtype.itemsize != 0: raise RuntimeError("unsupported size in bitcast") output_shape = self.shape[:-1]+((self.shape[-1]*self.dtype.itemsize) // dtype.itemsize,) return UOp.metaop(Ops.BUFFER_VIEW, output_shape, dtype, self.device, None, (self,)) return UOp(Ops.BITCAST, dtype, (self,)) def gep(self, i:Union[tuple[int, ...], int]): if isinstance(i, int): # NOTE: these are just shortcuts to not have to create and fold later if self.op is Ops.VECTORIZE: return self.src[i] if self.op is Ops.VCONST: return UOp.const(self.dtype.scalar(), self.arg[i]) if self.op is Ops.CONST: return UOp.const(self.dtype.scalar(), self.arg) i = (i,) if (self.dtype.vcount == len(i) and i == tuple(range(len(i)))) or self.dtype == dtypes.void: return self return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i) def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, src=(self,)+src, **kwargs) def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs) def alu(self, arg, *src:UOp): out_dtype = (self, *src)[-1].dtype if arg in {Ops.CMPLT, Ops.CMPNE}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool return UOp(arg, out_dtype, (self,)+src) @staticmethod def const(dtype:DType, b:ConstLike): if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype)) @staticmethod def range(dtype:DType, start:sint, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(start), sint_to_uop(end)), arg=idx) def _reduce_op(self, op:Ops, axis:tuple[int, ...]): axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)])) return self if len(axis) == 0 else UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis)) def r(self, op:Ops, axis:tuple[int, ...]) -> UOp: new_shape = unwrap(self.st).reduce(axis) # TODO: can we split symbolic shape if the reduce axis is not symbolic? if not SPLIT_REDUCEOP or not all_int(self.shape) or (0 in self.shape) or \ prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return self._reduce_op(op, axis) # if there are few globals, make some reduces into globals by splitting into two kernels # cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm # ~2**10 should be enough if GROUP is used # 256 split maximum should be "negligible reduce" for low prod(new_shape), 8 split minimum. # split is moved to the end to provide maximum locality for the second phase reduce. self_real_strides = unwrap(self.st).real_strides(ignore_valid=True) split_candidates = [(i, x) for i in axis for x in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(new_shape)),8-1,-1) if self.shape[i] % x == 0 and self_real_strides[i] != 0] if not split_candidates: return self._reduce_op(op, axis) dim_to_split, divisor = split_candidates[0] splitted_shape = self.shape[:dim_to_split] + (divisor,) + (self.shape[dim_to_split]//divisor,) + self.shape[dim_to_split+1:] splitted = self.reshape(splitted_shape).permute(tuple([x for x in range(len(splitted_shape)) if x != dim_to_split]+[dim_to_split])) if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}") return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x), None if self.st is None or self.st.contiguous else self.st) def contiguous(self, allow_buffer_view=True): if not unwrap(self.st).contiguous or self.size != self.base.size or self.is_unrealized_const(): if allow_buffer_view and self.can_view(): return self.metaop(Ops.BUFFER_VIEW, self.shape, self.dtype, self.device, None, (self,)) return self.alu(Ops.CONTIGUOUS) forced_realize.add(self.base) return self # *** from LazyBuffer *** @staticmethod def const_with_shape(dtype:DType, val:ConstLike, shape:tuple[sint,...]) -> UOp: from tinygrad.shape.shapetracker import ShapeTracker return UOp(Ops.VALID, dtypes.bool, (ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)).where(UOp.const(dtype, val), 0) @staticmethod def metaop(op:Ops, shape:tuple[sint, ...], dtype:DType, device:str, arg=None, src:tuple[UOp, ...]=()) -> UOp: from tinygrad.shape.shapetracker import ShapeTracker if op is Ops.CONST: # NOTE: we embed device on CONST with a fake BUFFER uop fake = UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device),), (-1, 1)) # NOTE: BIND stays BIND, UOp.const unbinds here const_uop = arg if isinstance(arg, UOp) else UOp.const(dtype, unwrap(arg)) return UOp(Ops.VIEW, dtype, (fake, const_uop), ShapeTracker.from_shape(())).reshape((1,)*len(shape)).expand(shape) # otherwise it's a contiguous st return UOp(Ops.VIEW, dtype, (UOp.new_buffer(device, (st:=ShapeTracker.from_shape(shape)).size, dtype), UOp(op, dtype, src, arg)), st) def copy_to_device(self, device:str, force=False, clone:bool=False) -> UOp: # no COPY if self.device == device and not clone: return self # TODO: hack const metaop early here, fix this in multi if self.is_unrealized_const(): return UOp.metaop(Ops.CONST, (), self.dtype, device, self.const_arg).view(unwrap(self.st)) # if it's a shrink, do the shrink before the copy with CONTIGUOUS if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device) # copy the base and apply the shapetracker on the new device if not unwrap((src:=self.base).st).contiguous: raise RuntimeError(f"can only copy contiguous {self}") return UOp.metaop(Ops.COPY, src.shape, src.dtype, device, (device, clone), (src,)).view(unwrap(self.st)) def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True) def is_unrealized_const(self): return (s:=self.base).op is Ops.VIEW and len(s.src) == 2 and s.realized is None and s.src[1].op is Ops.CONST def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in unwrap(self.st).views) def can_view(self): return (self.st is not None and self._device is not None and self.st.consecutive and not self.is_unrealized_const() and not isinstance(self.dtype, ImageDType) and self.device.split(":")[0] in view_supported_devices) @property def lbs(self): return [self] @property def metadata(self): return all_metadata.get(self, None) @property def forced_realize(self): return self in forced_realize # *** danger zone *** # CAUTION: MUTABILITY! def become(self, u:UOp): del UOpMetaClass.ucache[(self.op, self.dtype, self.src, self.arg)] self.op, self.dtype, self.src, self.arg = u.op, u.dtype, u.src, u.arg # *** uop movement ops *** @property def base(self) -> UOp: if self.op in GroupOp.Movement: return self.src[0].base return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self def view(self, new_st:ShapeTracker) -> UOp: if self.st is None: return UOp(Ops.VIEW, self.dtype.base if not isinstance(self.dtype, ImageDType) else self.dtype, (self,), new_st) ret = UOp(Ops.VIEW, self.dtype, (self.base,), new_st) # instant folding rules if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)): return ret.const_like(0) if new_st.contiguous and self.base.shape == new_st.shape: return self.base return ret def _mop(self, op:Ops, arg): ret = UOp(op, self.dtype, (self,), arg) if self.st == ret.st: return self # ignore NOOPs, also check ret.st return ret def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg) def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg) def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg) def permute(self, arg:tuple[sint, ...]): return self._mop(Ops.PERMUTE, arg) def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg) def stride(self, arg:tuple[sint, ...]): return self._mop(Ops.STRIDE, arg) # *** uop Buffer stuff *** buffer_num = itertools.count(0) @staticmethod def new_buffer(device:str, size:int, dtype:DType) -> UOp: return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device),), (next(UOp.buffer_num), size)) @property def device(self) -> str: return unwrap(self._device) @functools.cached_property def _device(self) -> Optional[str]: if self.op is Ops.DEVICE: return self.arg # TODO: why does this fail? #if self.op is Ops.COPY: return self.arg[0] return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None @property def buf_uop(self) -> UOp: if self.op is Ops.BUFFER: return self assert self.base.op in {*GroupOp.Buffer, Ops.ASSIGN, Ops.VIEW}, f"buf_uop called on {self.op}" return self.src[0].buf_uop def buf_uop_view(self) -> UOp: return self.buf_uop.view(unwrap(self.st)) @property def buffer(self) -> Buffer: if self.base.realized is not None: return self.base.realized if (ret:=buffers.get(self)) is not None: return ret if self.op is Ops.VIEW: assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous" return self.src[0].buffer assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}" from tinygrad.device import Buffer buffers[self] = ret = Buffer(self.device, self.size, self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base) return ret @property def realized(self) -> Optional[Buffer]: if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is Ops.BUFFER: return buffers[self.src[0]] return None @property def is_realized(self) -> bool: return self.base.realized is not None # *** uop Variable stuff *** @staticmethod def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.int): assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}" return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val)) @property def expr(self): assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR" return self.arg[0] def bind(self, val:int): assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR" assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]" return UOp(Ops.BIND, self.dtype, (self, self.const_like(val))) def unbind(self) -> tuple[Variable, int]: assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}" return self.src[0], self.src[1].arg @property def val(self) -> int: return self.unbind()[1] def vars(self) -> set[UOp]: bound_vars = set([x for x in self.toposort if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR]) bound_var_base = set(x.src[0] for x in bound_vars) all_vars = set([x for x in self.toposort if x.op is Ops.DEFINE_VAR]) return bound_vars.union(set([x for x in all_vars if x not in bound_var_base])) def variables(self) -> list[Variable]: st_vars: list[set[Variable]] = [x.st_arg.vars() for x in self.toposort if x.op in GroupOp.Buffer] return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg) # *** uop symbolic stuff *** def const_factor(self) -> int: """largest known int that divides self""" if self.op is Ops.CONST: return self.arg if self.op is Ops.VCONST: return math.gcd(*self.arg) if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor()) if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1 return 1 def divides(self, v) -> UOp|None: if v==1: return self if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None if self.op is Ops.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None if self.op is Ops.MUL: if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1] if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1 return None # generic None if we aren't sure @property def vmin(self) -> ConstType: return self._min_max[0] @property def vmax(self) -> ConstType: return self._min_max[1] @functools.cached_property def _min_max(self) -> tuple[ConstType, ConstType]: if self.op in GroupOp.Binary and not dtypes.is_float(self.dtype): (s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals) # SHL/SHR on consts only if self.op is Ops.SHL and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] << t[2], t[1] << t[2] if self.op is Ops.SHR and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] >> t[2], t[1] >> t[2] if self.op is Ops.MOD and s1_vmin > 0: return 0, s1_vmax-1 if self.op is Ops.IDIV: if s1_vmin == s1_vmax: # min/max are equal in a CONST if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin) # don't know exact bounds, but know the sign if (s0_vmax <= 0 and s1_vmin < 0) or (s0_vmin >= 0 and s1_vmin > 0): return 0, dtypes.max(self.dtype) if (s0_vmax <= 0 and s1_vmin > 0) or (s0_vmin >= 0 and s1_vmin < 0): return dtypes.min(self.dtype), 0 if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax) if self.op is Ops.CMPLT: return (s0_vmax str: ret = graph_rewrite(self.simplify() if simplify else self, renderer) return ret.arg if ret.op is Ops.NOOP else str(ret) @dataclass(frozen=True) class KernelInfo: local_dims: int = 0 # number of local dimensions (this is remapping RANGE to SPECIAL) upcasted: int = 0 # count that are upcasted (this is remapping RANGE to UNROLL) dont_use_locals: bool = False # don't use local indexing # ***** ops in python ***** def safe_exp2(x): try: return 2 ** x except OverflowError: return math.inf python_alu: dict[Ops, Callable] = { Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2, Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x), Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt, Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max, Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0, Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z} def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True): if dtype.count > 1: return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)]) alu = python_alu[op](*operands) return truncate.get(dtype, lambda x: x)(alu) if truncate_output else alu # ***** uop helpers ***** def print_uops(uops:list[UOp]): for i,u in enumerate(uops): formatted_parents = [(uops.index(x) if x.op is not Ops.CONST else f"{x.arg}") if x in uops else "--" for x in u.src] print(f"{i:4d} {str(u.op):20s}: {str(u.dtype):30s} " f"{str(formatted_parents):32s} {u.arg}") # ***** pattern matcher ***** def get_location() -> tuple[str, int]: frm = sys._getframe(1) # find the real frame in the file that has the UPat, TODO: is there a better way to do this? while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "uopgraph.py", "schedule.py", "lowerer.py", "cstyle.py", "linearize.py"}: frm = frm.f_back return frm.f_code.co_filename, frm.f_lineno @functools.lru_cache(None) def lines(fn) -> list[str]: with open(fn) as f: return f.readlines() class UPat(MathTrait): __slots__ = ("op", "dtype", "arg", "name", "src") def __init__(self, op:Optional[Union[Ops, tuple[Ops, ...], set[Ops]]]=None, dtype:Optional[Union[DType, tuple[DType, ...]]]=None, src:Optional[Union[tuple[UPat, ...], list[UPat], UPat]]=None, arg:Any=None, name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[set[Ops]]=None): assert op is None or isinstance(op, Ops) or isinstance(op, tuple) or isinstance(op, set), "op must be Ops or tuple of Ops" self.op: Optional[tuple[Ops, ...]] = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op) self.dtype: Optional[tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject self.src: Any = None assert self.name != "ctx", "UPat can't be named ctx" # try all permutations if it's a list if isinstance(src, list): self.src = list(itertools.permutations(src)) if not all_same(src) else [src] # only one if it's a tuple elif isinstance(src, tuple): self.src = [src] # repeat if it's a UPat elif isinstance(src, UPat): self.src = [itertools.repeat(src)] self.allowed_len: int = -1 if allow_any_len or isinstance(src, UPat) or src is None else len(src) self.location = location or get_location() if custom_early_reject is not None: self.early_reject = custom_early_reject else: upat_match = [src] if isinstance(src, UPat) else ([] if src is None else self.src[0]) self.early_reject = {pp.op[0] for pp in upat_match if pp.op is not None and len(pp.op) == 1} def named(self, name:str): return UPat(self.op, self.dtype, self._in_src, self.arg, name, self.allowed_len == -1, self.custom_early_reject) @staticmethod def any(*src): return UPatAny(src=src) @staticmethod @functools.lru_cache(None) def var(name:Optional[str]=None, dtype:Optional[Union[DType, tuple[DType, ...]]]=None): return UPat(dtype=dtype, name=name) @staticmethod @functools.lru_cache(None) def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True): return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name) @staticmethod def const(dtype:Optional[Union[DType, tuple[DType, ...]]], b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b) # copied from UOp def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx)) def view(self, st=None, **kwargs): return UPat(Ops.VIEW, self.dtype, (self,), st, **kwargs) def cast(self, dtype=None): return UPat(Ops.CAST, dtype, (self,)) def bitcast(self, dtype=None): return UPat(Ops.BITCAST, dtype, (self,)) def gep(self, i:int): return UPat(Ops.GEP, None, (self,), (i,)) def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs) def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, dtypes.void, (self,)+src, **kwargs) def assign(self, x:UPat): return UPat(Ops.ASSIGN, self.dtype, (self,x)) def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b)) def alu(self, op:Ops, *src:UPat): asrc = (self,)+src return UPat(op, dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc) def printable(self:UPat) -> str: try: return lines(self.location[0])[self.location[1]-1].strip() except FileNotFoundError: return "" def __repr__(self): def rep(x): form = "UPat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=%s)" return form % (None if x.op is None else ('(%s)'%', '.join(map(str, x.op))), x.arg, repr(x.name), set(x.dtype) if x.dtype else None, x.allowed_len == 0, "[%s]" if x.src and len(x.src)>1 else "(%s)") return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0]) def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]: if (self.op is not None and uop.op not in self.op) or \ (self.name is not None and store.setdefault(self.name, uop) is not uop) or \ (self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \ (self.arg is not None and self.arg != uop.arg) or \ (self.allowed_len != -1 and len(uop.src) != self.allowed_len): return [] if self.src is None: return [store] res: list[dict[str, UOp]] = [] for vp in self.src: stores, new_stores = [store.copy()], [] for uu, vv in zip(uop.src, vp): for s in stores: new_stores.extend(vv.match(uu, s)) stores, new_stores = new_stores, [] res.extend(stores) return res class UPatAny(UPat): def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]: matches = [x.match(uop, store.copy()) for x in self.src[0]] return flatten([x for x in matches if x is not None]) def deconstruct_function(fxn:Callable) -> Tuple: new_globals = {k:v for k,v in fxn.__globals__.items() if k in fxn.__code__.co_names} for co in fxn.__code__.co_consts: if isinstance(co, types.CodeType): new_globals.update({k:v for k,v in fxn.__globals__.items() if k in co.co_names}) # NOTE: optional round trip through pickle! assert fxn.__closure__ is None, "closures are not supported in pattern matchers" ret = fxn.__code__, new_globals, fxn.__name__, fxn.__defaults__ return pickle.loads(pickle.dumps(ret)) if getenv("TEST_PICKLE") else ret class PatternMatcher: def __init__(self, patterns:list[tuple[UPat, Callable]]): self.patterns = patterns # NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher! self.pdict: dict[Ops, list[tuple[UPat, Callable, Set, bool]]] = {} # uop is required, arg is optional for p,fxn in self.patterns: assert p.op is not None tuple_fxn = fxn if isinstance(fxn, tuple) else deconstruct_function(fxn) real_fxn = types.FunctionType(*tuple_fxn) for uop in p.op: self.pdict.setdefault(uop, []).append((p, real_fxn, p.early_reject, 'ctx' in inspect.signature(real_fxn).parameters)) def __reduce__(self): return PatternMatcher, ([(x,deconstruct_function(fxn) if fxn.__name__ == "" else fxn) for x,fxn in self.patterns],) @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns) def rewrite(self, uop:UOp, ctx=None) -> UOp|None: ler = {u.op for u in uop.src} for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []): if not early_reject.issubset(ler): continue for match in p.match(uop, {}): if (ret:=(fxn(ctx=ctx, **match) if has_ctx else fxn(**match))) is not None: return ret return None # *** tracking pattern matcher *** TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0) match_stats:dict[UPat, list[Union[int, float]]] = dict() @dataclass(frozen=True) class TrackedGraphRewrite: loc: tuple[str, int] # location that called graph_rewrite sink: bytes # sanpshot of the graph_rewrite input sink matches: list[tuple[bytes, Optional[bytes], Optional[UPat], float]] = field(default_factory=list) # before+after snapshot of all the matches tracked_keys:list[Any] = [] tracked_ctxs:list[list[TrackedGraphRewrite]] = [] _name_cnt:dict[str, int] = {} def track_rewrites(named=False): def _decorator(func): def __wrapper(self, *args, **kwargs): if TRACK_MATCH_STATS >= 2: if named: _name_cnt[func.__name__] = _name_cnt.get(func.__name__, 0)+1 tracked_keys.append(f"{func.__name__}_{_name_cnt[func.__name__]}" if named else self) tracked_ctxs.append([]) return func(self, *args, **kwargs) return __wrapper return _decorator class TrackedPatternMatcher(PatternMatcher): def rewrite(self, uop:UOp, ctx=None) -> UOp|None: ret = None ler = {u.op for u in uop.src} for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []): if p not in match_stats: match_stats[p] = [0,0,0.0,0.0] st = time.perf_counter() if not early_reject.issubset(ler): match_stats[p][2] += time.perf_counter()-st continue match_stats[p][1] += 1 for match in p.match(uop, {}): if (ret:=(fxn(ctx=ctx, **match) if has_ctx else fxn(**match))) is not None: match_stats[p][0] += 1 match_stats[p][3] += (et:=time.perf_counter()-st) if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable()) if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and len(tracked_ctxs) != 0: with Context(PICKLE_BUFFERS=0): tracked_ctxs[-1][-1].matches.append((pickle.dumps(uop), pickle.dumps(ret), p, et)) return ret # NOTE: if it returns None, we keep trying to match match_stats[p][2] += time.perf_counter()-st if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0: with Context(PICKLE_BUFFERS=0): tracked_ctxs[-1][-1].matches.append((pickle.dumps(uop), None, None, 0)) return None if TRACK_MATCH_STATS: PatternMatcher = TrackedPatternMatcher # type: ignore import atexit @atexit.register def print_match_stats(): if TRACK_MATCH_STATS >= 2: with open(fn:=temp("rewrites.pkl"), "wb") as f: print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}") pickle.dump((tracked_keys, tracked_ctxs), f) launch_viz("VIZ", temp("rewrites.pkl")) if getenv("PRINT_MATCH_STATS", 1): ret = [0,0,0.0,0.0] for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]): loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}" if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {(v[2]+v[3])*1000.:9.2f} ms -- {loc_str:15s}", k.printable()) ret = [x+y for x,y in zip(ret, v)] print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {(ret[2]+ret[3])*1000.:9.2f} ms -- TOTAL") def launch_viz(env_str:str, data:str): os.environ[env_str] = "0" os.environ[f"{env_str}_DATA"] = data if not int(os.getenv("VIZ", "0")) and not int(os.getenv("PROFILE", "0")): args = ['--kernels', getenv("VIZ_DATA", "")] if getenv("VIZ_DATA", "") else [] args += ['--profile', getenv("PROFILE_DATA", "")] if getenv("PROFILE_DATA", "") else [] os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), ".", "viz", "serve.py")] + args) # *** simple graph rewrite engine *** class RewriteContext: def __init__(self, pm, ctx): self.pm: PatternMatcher = pm self.ctx = ctx self.replace: dict[UOp, UOp] = {} def rewrite(self, n:UOp) -> UOp: if (rn := self.replace.get(n)) is not None: return rn new_src = tuple(map(self.rewrite, n.src)) new_n = self.pm.rewrite(n, self.ctx) if new_src == n.src else UOp(n.op, n.dtype, new_src, n.arg) self.replace[n] = ret = n if new_n is None else self.rewrite(new_n) return ret def bottom_up_rewrite(self, n:UOp) -> UOp: if (rn := self.replace.get(n)) is not None: return rn new_n: UOp|None = n while new_n is not None: last_n, new_n = new_n, self.pm.rewrite(new_n, self.ctx) new_src = tuple(map(self.bottom_up_rewrite, last_n.src)) self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg)) return ret def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> UOp: if TRACK_MATCH_STATS >= 2 and not bottom_up and len(tracked_ctxs) != 0: # TODO: make viz work with bottom_up=True with Context(PICKLE_BUFFERS=0): tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), pickle.dumps(sink))) return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).rewrite(sink) # ***** uop type spec ***** # this is the matcher for the final rendered UOps # matcher functions returns True or False (or None to not match) spec = PatternMatcher([ (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local), (UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local), (UPat(Ops.DEFINE_ACC, src=(UPat.var("c"),), name="x", allow_any_len=True), lambda x,c: all(y.op is Ops.RANGE for y in x.src[1:]) and c.dtype == x.dtype), (UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)), (UPat(Ops.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype and isinstance(rng.arg, int)), (UPat(Ops.SPECIAL, src=()), lambda: True), # TODO: confirm the args of both of these are shapetrackers (UPat(Ops.VIEW, dtypes.void, src=()), lambda: True), (UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: src.op is not Ops.STORE and x.dtype == src.dtype), (UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True), (UPat(Ops.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))), # early LOAD has a (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True), (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat(Ops.STORE))), lambda: True), # early STORE has a (UPat(Ops.STORE, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat())), lambda: True), # **** new style load/store **** # INDEX is used in new style load/store (UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat())), lambda: True), # LOAD takes a (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True), (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True), (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(name="alt"), UPat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype), # STORE takes a (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat())), lambda: True), (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(dtype=dtypes.bool))), lambda: True), (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(Ops.IF))), lambda: True), # most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE (UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype), (UPat((Ops.CMPLT, Ops.CMPNE), dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y"))), lambda x,y: x.dtype == y.dtype), # and SHL/SHR, the shift distance can be an int (UPat((Ops.SHL, Ops.SHR), src=(UPat(name="x"), UPat(name="y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)), (UPat(Ops.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False), (UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)), (UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True), (UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True), # all WMMA has 3 args, (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat())), lambda: True), (UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)), (UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)), # if has a (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True), (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True), (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True), (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}), (UPat(Ops.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()), (UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)), (UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None), (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local # NOTE: for testing, we let sinks be anything #(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True), (UPat(Ops.SINK, dtypes.void), lambda: True), (UPat(Ops.NOOP), lambda: True), # PTX LOAD/STORE (UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True), ]) def type_verify(uops:list[UOp]): for i,u in enumerate(uops): if not spec.rewrite(u): print_uops(uops) raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}") # *** most of symbolic lives here now *** def split_uop(x:UOp, sep:Ops): if x.op is sep: for s in x.src: yield from split_uop(s, sep) else: yield x def div_and_mod_folding(x: UOp, c: int, which: Literal[Ops.MOD, Ops.IDIV], split_rem: bool=False) -> UOp|None: # simplify x // c or x % c, None means no change, c must be > 0 assert c > 0 if x.dtype.count > 1: return None # simple cancel div/mod case if (q:=x.vmin//c) == (x.vmax//c): if which is Ops.MOD: return x - q*c return x.const_like(q) svars, factors, quotients, remainders, gcd, div, const, offset, something_changed = [], [], [], [], c, 1, 0, 0, False for u in split_uop(x, Ops.ADD): if u.op is Ops.MOD and which is Ops.MOD and u.src[1].op is Ops.CONST and u.src[1].arg%c == 0: u = u.src[0] something_changed = True v: UOp = u.divides(f:=u.const_factor()) q, r = divmod(f, c) if r==0 or ((which is Ops.MOD or split_rem or u.op is Ops.CONST) and r!=f): something_changed = True offset += r*v.vmin if u.op is Ops.CONST: const += f else: # div is the smallest common divisor of all terms if f > 1 and c % f == 0 and (div == 1 or div > f): div = f gcd = math.gcd(r, gcd) factors.append(f); svars.append(v); quotients.append(q); remainders.append(r) # noqa: E702 lbound = ubound = offset = offset % c # we can fold if the expression has only one non-constant term and this term can only take on two values if len(svars)==1 and (v:=svars[0]).vmax-v.vmin == 1: r = (offset+remainders[0])%c - offset%c offset -= r * v.vmin if which is Ops.MOD: return r*v + offset return (factors[0]-r)//c * v + (const-offset)//c # a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c # within a mod we can freely subtract multiples of c, we use this to see if a is congruent to an expression whose vmin/vmax are between 0 and c for (r, v) in zip(remainders, svars): if r > c//2: if (lbound := lbound + (r:=r-c) * (v.vmax-v.vmin)) < 0: break elif (ubound := ubound + r * (v.vmax-v.vmin)) >= c: break offset -= r * v.vmin # determine what the new offset would be else: # vmin/vmax of the remainder is between 0 and c, we can remove the mod/div remainders = [min(r, r-c, key=abs) for r in remainders] if which is Ops.MOD: return functools.reduce(operator.add, [r*v for r,v in zip(remainders,svars)], x.const_like(offset)) return functools.reduce(operator.add, [(f-r)//c * v for f,r,v in zip(factors, remainders,svars)], x.const_like((const-offset)//c)) if gcd != 1: something_changed = True if not something_changed: if which is Ops.IDIV and (1 < div < c) and (newx:=div_and_mod_folding(x, div, Ops.IDIV)) is not None: return newx//(c//div) return None quo, rem = x.const_like(const//c), x.const_like((const%c)//gcd) for q,r,f,v in zip(quotients, remainders, factors, svars): if which is Ops.IDIV and (not split_rem) and r!=0: rem += f//gcd * v else: rem += r//gcd * v quo += q * v if which is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd return rem//(c//gcd)+quo def lt_folding(x:UOp, c:int) -> UOp|None: p, np = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1) if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d: return cast(UOp, functools.reduce(operator.add, np).divides(d))<(c//d) return None def fold_unrolled_divs(divs:UOp): # div pattern in unrolled arange # example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x add_chain, denominator, seen_const, ans = list(split_uop(divs, Ops.ADD)), None, [], None for u in add_chain: if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None if denominator is None: denominator = u.src[1].arg if denominator != u.src[1].arg: return None # assumed CONST is the last of an ADD if (s0:=u.src[0]).op is Ops.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST: seen_const.append(s0.src[1].arg) s0 = s0.src[0] else: seen_const.append(0) if ans is None: ans = s0 if ans is not s0: return None if denominator is None: return None # the first (denominator-len(seen_const)) terms may have been folded to 0 already for i in range(denominator-len(seen_const)): if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i) return ans if ans is not None and sorted(seen_const)==list(range(denominator)) else None def canonicalize_simplex(X:UOp) -> UOp|None: # (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints. # returns x0 + x1 + ... in such case, or None if not changed, ret = False, [] for u in split_uop(X, Ops.ADD): # assumed the const is the last src of MUL if u.op is Ops.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0: changed = True u = u.src[0] if not (u.op in GroupOp.Irreducible and u.vmin >= 0): return None ret.append(u) return functools.reduce(operator.add, ret) if changed else None def is_increasing(f:UOp) -> bool: # is f a monotonically increasing function regards its input if f.op in GroupOp.Irreducible: return True if f.op is Ops.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1]) if f.op in (Ops.MUL, Ops.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0]) return False # False if not sure def parse_valid(valid:UOp) -> tuple[UOp, bool, int]: # if it's X <= c, returns X, True, c # if it's X >= c, returns X, False, c # (X < c).ne(True) -> X >= c if valid.op is Ops.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \ (s0:=valid.src[0]).op is Ops.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg # X < c -> X <= c-1 if valid.op is Ops.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1 raise ValueError(f"not able to parse {valid=}") def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None: # return None if valid is always False, otherwise the simplified uop (might be the same as input) # first, parse valid into {expr: (lower_bound, upper_bound)} bounds:DefaultDict[UOp, list[Optional[ConstType]]] = defaultdict(lambda: [None, None]) for stmt in split_uop(valid, Ops.AND): try: expr, is_upper, c = parse_valid(stmt) except ValueError: return uop # give up if we cannot parse the valid bounds[expr][int(is_upper)] = c # simplify uop given that valid is True for expr,v in bounds.items(): # some expr has lower bound > upper bound -> valid is an empty set and we return None if v[0] is not None and v[1] is not None and v[0] > v[1]: return None # every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop candidates = [] if expr.op is Ops.ADD and v[0] == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)): # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)]) # try checking the whole clause if expr in uop.toposort: candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))]) for candidate in candidates: # if every branch in candidate gives the same simplified uop, we can rewrite the uop newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate] if uop.op is Ops.VECTORIZE and len(uop.src) == 2: if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1])) if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1])) elif all_same(newuops): uop = newuops[0] return uop def _valid_priority(v: UOp, valids:list[UOp]): # we want valid that's in other valids' parents to be first, so it's more likely the other valids get simplified try: return sum(-1 if parse_valid(v)[0] in other.toposort else 0 for other in valids) except ValueError: return 0 def simplify_valid(valid:UOp) -> UOp|None: ret:list[UOp] = [] something_changed = False valids = list(split_uop(valid, Ops.AND)) for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)): ret.append(newstmt if ret and (newstmt:=uop_given_valid(functools.reduce(operator.and_, ret), stmt)) is not None else stmt) if ret[-1] is not stmt: something_changed = True return functools.reduce(operator.and_, ret) if something_changed else None def max_var_const(x:UOp, c1:UOp, c2:UOp): if x.vmin >= 0: return x*c1 if c1.arg >= c2.arg else x*c2 if x.vmax <= 0: return x*c2 if c1.arg >= c2.arg else x*c1 def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x symbolic_simple = PatternMatcher([ # ** self folding ** (UPat.var("x") + 0, lambda x: x), # x+0 -> x (UPat.var("x") * 1, lambda x: x), # x*1 -> x (UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1 (UPat.var("x") // 1, lambda x: x), # x//1 -> x (UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x (UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1 ((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x ((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed) (UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x ((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"), lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3 (UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c), (UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x), (UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x), (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x), (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x), # ** zero folding ** (UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False (UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x != x -> False (only ints) # x*0 -> 0 or 0*x -> 0 # if x is nan or inf it should render the nan value. # NOTE: this can be wrong for loaded NaN (UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)), # ** constant folding ** (UPat(GroupOp.ALU, name="a", src=UPat((Ops.VCONST, Ops.CONST))), lambda a: a.const_like(exec_alu(a.op, a.dtype, [x.arg for x in a.src], False))), # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly (UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y), (UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y), (UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y', dtype=dtypes.bool)), lambda x,y: x|y), # *** cast *** (UPat(Ops.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)), (UPat(Ops.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None), ]) symbolic = symbolic_simple+PatternMatcher([ # ** COMMUTATIVE flipping ** (UPat(GroupOp.Commutative, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None), # group like ((UPat.var("x") + UPat.var("y")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: (x+x*c)+y), # ** boolean algebra ** (UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x # ** combine terms ** (UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1) (UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1) (UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2 ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3) (-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c # a conditional with the same results either way is a noop, also fold const conditionals (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val), (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1), # alu of two where with same conds can combine, only do if true branch or false branch is const (UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \ lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None), # ALU min==max -> CONST (slow!) (UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None), # max folding (UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None), # TODO: why does this rule break beautiful_mnist? #((UPat.var("x")+UPat.var("z")).maximum(UPat.var("y")+UPat.var("z")), lambda x,y,z: x.maximum(y) + z), ((UPat.var("x")*UPat.cvar("c1")).maximum(UPat.var("x")*UPat.cvar("c2")), max_var_const), # ** two stage ALU folding ** *((UPat.var("x").alu(op, UPat.cvar("c1")).alu(op, UPat.cvar("c2")).named("f"), lambda f,x,c1,c2: x.alu(f.op,c1.alu(f.op,c2))) for op in GroupOp.Associative), ((UPat.cvar("c0") + UPat.var("x")) < UPat.cvar("c1"), lambda x,c0,c1: x<(c1-c0)), # c0 + x < c1 -> x < c1 - c0 ((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2) # ** lt ** # c0*x 0 and c1.arg > 0 else None), # c0*x 0 else None), # ** move add/mul consts to end (NOTE: this is still happening before constant folding) ** (UPat(Ops.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1), (UPat(Ops.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1), # *** rules from symbolic *** # unrolled arange div folding (UPat(Ops.ADD, name="divs", src=[UPat(), UPat(Ops.IDIV)]), fold_unrolled_divs), # generic lt folding (UPat.var("x", dtypes.sints) 0 # not x < 1 -> X > 0 ((UPat.var("x", dtypes.ints)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None), # ** div ** # div folding ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)), # (x//c+a)//d -> (x+a*c)//(c*d) (UPat.var("x", dtypes.sints) // UPat.cvar("c", vec=False), lambda x,c: div_and_mod_folding(x,c.arg,Ops.IDIV) if 0 < c.arg else None), # ** mod ** # mod folding (UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: div_and_mod_folding(x,c.arg,Ops.MOD) if 0 < c.arg else None), ]) symbolic_flat = symbolic+PatternMatcher([ # ** combine terms (opinionated) ** (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue ((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c), ]) _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))]) # for debug syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<", Ops.SHR: ">>", Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"} renderer = PatternMatcher([ (UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])), (UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg}")), (UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))), (UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]), (UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")), (UPat(Ops.MAX, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")), (UPat(Ops.MULACC, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")), (UPat(Ops.WHERE, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")), (UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.op]}{x.src[1].arg})")), ]) # *** what was symbolic.py *** sint = Union[int, UOp] Variable = UOp ConstLike = Union[ConstType, Variable, tuple[ConstType, ...]] # *** uop swizzling *** merge_views = PatternMatcher([(UPat(Ops.VIEW, name="s0").view(name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))]) # push VIEW to loads view_left = merge_views+PatternMatcher([ # VIEW before elementwise ops (UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s if not s.has_st else s.view(v.st) if s is s.base else s.base.view(s.st+v.st) for s in e.src))), # early merge VIEW buffer ops (UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))), ])