from __future__ import annotations
from typing import Any, Callable, cast, TYPE_CHECKING, Type, Sequence, Iterable
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref, collections
from dataclasses import dataclass
from enum import Enum, auto
from tinygrad.uop import Ops, GroupOp
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType, AddrSpace
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CI
from tinygrad.helpers import strip_parens, colored, ansilen, printable
if TYPE_CHECKING:
from tinygrad.device import Buffer, MultiBuffer
class AxisType(Enum):
def __repr__(self): return str(self)
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
THREAD = auto()
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1}
# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
# With True as the default, this matches the old symbolic behavior
def resolve(x:UOp|bool, default:bool=True):
if isinstance(x, bool): return x
assert x.dtype == dtypes.bool, "UOp in resolve must be bool"
# NOTE: generating the text for the exception is expensive, so we do this
return bool(sx.vmin) if (sx:=x.simplify()).vmin == sx.vmax else default
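# Example (rough sketch, hypothetical values): resolve() answers boolean questions about symbolic
# UOps from their vmin/vmax bounds, falling back to the caller-chosen default when undecidable.
#   i = UOp.variable("i", 0, 10)
#   resolve(i < 16)         # True: i.vmax is 10, so the comparison is provably true
#   resolve(i < 5)          # True (the default): can't be decided from the bounds alone
#   resolve(i < 5, False)   # False: same undecidable case, different default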
# smax/smin are replacements for max/min that preserve symbolic
def _suop(lst, uop_fxn, python_fxn):
uops, nums = partition(lst, lambda x: isinstance(x, UOp))
return ssimplify(functools.reduce(uop_fxn, uops + ([python_fxn(nums)] if nums else [])))
def smax(*lst) -> sint: return _suop(argfix(*lst), UOp.maximum, max)
def smin(*lst) -> sint: return _suop(argfix(*lst), UOp.minimum, min)
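# Example (rough sketch): smax/smin accept ints and UOps mixed together and stay symbolic only
# when the simplifier can't prove an ordering from the bounds.
#   i = UOp.variable("i", 0, 10)
#   smax(3, 7)     # -> 7, plain python int
#   smax(i, 20)    # typically folds to 20, since i.vmax is 10
#   smax(i, 5)     # -> a MAX UOp, the result can't be resolved statically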
def srender(x:sint) -> str: return x.render() if isinstance(x, UOp) else str(x)
def ssimplify(uop:sint): return uop.ssimplify() if isinstance(uop, UOp) else uop
def sym_infer(uop: UOp|int, var_vals: dict[str, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
def range_str(u:UOp, color=False) -> str:
ret = '_'.join([str(x) if x >= 0 else "m"+str(-x) for x in u.arg[0:-1]])
return colored(ret, axis_colors[u.arg[-1]]) if color else ret
def multirange_str(rngs:Iterable[UOp], color=False, pad=None) -> str:
ret = ','.join([range_str(x, color=color) for x in sorted(rngs, key=lambda x: x.arg)])
if pad is not None: ret += " " * (pad-ansilen(ret))
return ret
def consumer_map_from_toposort(lst:Iterable[UOp]):
ret: dict[UOp, dict[UOp, None]] = {}
for u in lst:
ret[u] = {}
for s in u.src: ret[s][u] = None
return ret
# used for UOp and UPat
def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str:
def dfs(x:Any, cache:dict):
for s in srcfn(x) or []:
cache.setdefault(s, [len(cache), 0, False])[1] += 1
if cache[s][1] == 1: dfs(s, cache)
if cache is None: dfs(x, cache:={})
if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d} x{cx[0]}"
cx[2], srcs = True, ('None' if srcfn(x) is None else ''.join(f'\n{pretty_print(s, rep, srcfn, cache, d+2)},' for s in srcfn(x)))
return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
class UOpMetaClass(type):
ucache:dict[tuple, weakref.ReferenceType[UOp]] = {}
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, tag:Any=None,
metadata:tuple[Metadata,...]|None=None, _buffer:Buffer|None=None):
if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg, tag), None)) is not None and (ret:=wret()) is not None: return ret
UOpMetaClass.ucache[key] = weakref.ref(created:=super().__call__(*key))
if metadata is not None: all_metadata[created] = metadata
# NOTE: this value is set by pickle when pickling a realized tensor
if _buffer is not None:
assert op is Ops.BUFFER, f"trying to set Buffer {_buffer} for {op}"
buffers[created] = _buffer
if SPEC > 1:
from tinygrad.uop.spec import full_spec, test_pyrender
if SPEC > 2: test_pyrender(created)
with Context(IGNORE_OOB=1): ret = full_spec.rewrite(created)
if cast(bool|None, ret) is not True: raise RuntimeError(f"SPEC ISSUE {ret}: {created}")
return created
# some uops map to other stuff
buffers:weakref.WeakKeyDictionary[UOp, Buffer|MultiBuffer] = weakref.WeakKeyDictionary() # this maps BUFFER uops to their device Buffers
all_metadata:weakref.WeakKeyDictionary[UOp, tuple[Metadata, ...]] = weakref.WeakKeyDictionary() # TODO: should this be here?
# recursive_property replaces functools.cached_property in recursive UOp functions to prevent RecursionError
_NOT_FOUND = object()
class recursive_property(property):
def __init__(self, fxn):
self.fxn = fxn
self.nm = "_RECURSIVE_PROPERTY_"+fxn.__name__
self.__doc__ = fxn.__doc__
def __get__(self, x:UOp|None, owner=None):
if x is None: return self
if (val:=x.__dict__.get(self.nm, _NOT_FOUND)) is _NOT_FOUND:
for s in x.toposort(lambda z: not hasattr(z, self.nm)):
s.__dict__[self.nm] = val = self.fxn(s)
return val
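# Example (sketch): use recursive_property like functools.cached_property on UOp methods that
# recurse into src; the value is filled in bottom-up over a gated toposort, so very deep graphs
# don't hit python's recursion limit.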
# we import this late so we can use resolve/smax in mixins
from tinygrad.mixin import OpMixin
# NOTE: this should be frozen, but frozen is slower
@dataclass(eq=False, slots=True)
class UOp(OpMixin, metaclass=UOpMetaClass):
op:Ops
dtype:DType = dtypes.void
src:tuple[UOp, ...] = tuple()
arg:Any = None
tag:Any = None
def __del__(self):
if Ops is not None and self.op is Ops.BUFFER and (buffer:=buffers.get(self)) is not None: buffer.ref(-1)
try: del UOpMetaClass.ucache[(self.op, self.dtype, self.src, self.arg, self.tag)]
except AttributeError: pass
def __reduce__(self):
args = [self.op, self.dtype, self.src, self.arg, self.tag, self.metadata]
if self.op is Ops.BUFFER and self.realized is not None and PICKLE_BUFFERS: args.append(self.realized)
return UOp, tuple(args)
def replace(self, **kwargs) -> UOp:
new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src),
kwargs.pop("arg", self.arg), kwargs.pop("tag", self.tag))
assert len(kwargs) == 0, f"unused kwargs in replace {list(kwargs)}"
if (self.op, self.dtype, self.src, self.arg, self.tag) == new_args: return self
return UOp(*new_args)
def rtag(self, tag=True): return self.replace(tag=tag)
@functools.cached_property
def key(self) -> bytes:
return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}{x.tagstr()}, src=(%s))")
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else repr(self.arg)
def tagstr(self): return f", tag={self.tag}" if self.tag is not None else ""
def f(self, op, **kwargs): return UOp(op, dtype=kwargs.pop("dtype", self.dtype), src=(self,), **kwargs)
@functools.cached_property
def backward_slice(self:UOp) -> dict[UOp, None]:
res: dict[UOp, None] = self.toposort()
res.pop(self)
return res
@property
def backward_slice_with_self(self:UOp) -> dict[UOp, None]: return {self:None, **self.backward_slice}
def op_in_backward_slice_with_self(self, *ops:Ops): return any(x.op in ops for x in self.backward_slice_with_self)
def toposort(self, gate:Callable|None=None) -> dict[UOp, None]:
ret: dict[UOp, None] = {}
stack: list[tuple[UOp, bool]] = [(self, False)] # each stack entry is (node, visited_flag)
while stack:
node, visited = stack.pop()
if node in ret: continue
if not visited:
if gate is None or gate(node):
stack.append((node, True)) # push node back on stack to process after its srcs
for s in reversed(node.src): stack.append((s, False)) # push srcs on the stack
else: ret[node] = None # second time i'm seeing this node, add it to returned toposort
return ret
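# Example (rough sketch, hypothetical names): every node appears after all of its sources,
# so the node toposort() was called on comes last.
#   a = UOp.const(dtypes.int, 1); b = UOp.const(dtypes.int, 2)
#   c = a + b
#   list(c.toposort())   # roughly [CONST(1), CONST(2), ADD], with c last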
# returns map of UOps to their consumers in the graph rooted by self
def get_consumer_map(self) -> dict[UOp, dict[UOp, None]]: return consumer_map_from_toposort(self.toposort())
def reverse_toposort(self, consumer_map) -> dict[UOp, None]:
ret: dict[UOp, None] = {}
stack: list[tuple[UOp, bool]] = [(x, False) for x in consumer_map if len(x.src) == 0]
while stack:
node, visited = stack.pop()
if node in ret: continue
if not visited:
stack.append((node, True)) # push node back on stack to process after its srcs
for s in consumer_map[node]: stack.append((s, False)) # push srcs on the stack
else: ret[node] = None # second time i'm seeing this node, add it to returned toposort
return ret
@functools.cached_property
def tuplize(self:UOp) -> tuple:
return (self.op.value, self.arg, self.dtype,)+tuple([x.tuplize for x in self.src])
@property
def ptrdtype(self) -> PtrDType:
if not isinstance(self.dtype, PtrDType): raise RuntimeError(f"ptrdtype called on UOp with type {self.dtype}")
return self.dtype
# *** uop shape stuff ***
@recursive_property
def _shape(self) -> tuple[sint, ...]|None:
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.PRECAST | Ops.CONTRACT:
return None
case Ops.INDEX:
# non pointer index doesn't have a shape
if not isinstance(self.dtype, PtrDType): return None
# fully indexed doesn't have a shape. TODO: remove this
if self.src[0]._shape is None or len(self.src[1:]) == len(self.src[0].shape): return None
# pointer index
return self.src[0].shape[len(self.src[1:]):]
# some ops init the shape
case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND: return () if self._device is not None else None
case Ops.BUFFER: return (self.arg,)
case Ops.BUFFER_VIEW: return (self.arg[0],)
case Ops.BUFFERIZE: return tuple([int(r.vmax+1) for r in self.src[1:]])
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
# passthrough ops
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
return self.src[0]._shape
# ops with custom handling
case Ops.KERNEL: return self.arg.ast._shape
# TODO: disallow shape changing bitcast
case Ops.BITCAST:
ps = self.src[0]._shape
if ps is None: return None
if (output_sz:=self.dtype.itemsize) != (input_sz:=self.src[0].dtype.itemsize): return ps[:-1]+(ssimplify((ps[-1]*input_sz) // output_sz),)
return ps
# TODO: disallow reshape from nothing. tested by TestOpenClip.test_multigpu_clip_score
case Ops.RESHAPE:
if self.src[0]._shape is None: return self.marg
# movement ops change the shape. this is the logic from the old ShapeTracker
# NOTE: ssimplify is required because the shape needs to be canonical for broadcasting and same shape checking
if self.op in GroupOp.Movement.union({Ops.MULTI, Ops.REDUCE_AXIS, Ops.WMMA}):
ps = self.src[0]._shape
# TODO: WMMA is used for both axis WMMA and op WMMA. fix this and remove this hack. tested by BERT on AMD LLVM
if ps is None and self.op is Ops.WMMA: return None
if ps is None: raise RuntimeError(f"movement op {self.op} requires shape")
match self.op:
case Ops.RESHAPE:
if not all(x >= 0 for x in self.marg): raise ValueError(f"shape can't contain negative numbers {self.marg}")
if prod(ps) != prod(self.marg): raise ValueError(f"bad reshape: {ps} -> {self.marg}")
return self.marg
case Ops.EXPAND:
if len(ps) != len(self.marg) or not all(s==ns or (s==1 and ns>=0) for s,ns in zip(ps, self.marg)):
raise ValueError(f"bad expand: {ps} -> {self.marg}")
return self.marg
case Ops.PERMUTE:
if sorted(self.marg) != list(range(len(ps))): raise ValueError(f"invalid permutation {self.marg} of len {len(ps)}")
return tuple(ps[i] for i in self.marg)
case Ops.PAD:
# TODO: why do i need resolve here?
if len(ps) != len(self.marg) or not all(resolve(b>=0) and resolve(e>=0) for b,e in self.marg): raise ValueError(f"invalid pad {self.marg}")
return tuple(ssimplify(s+b+e) for s,(b,e) in zip(ps, self.marg))
case Ops.SHRINK:
# TODO: why do i need resolve here?
if len(ps) != len(self.marg) or not all(resolve(0<=b) and resolve(b<=e) and resolve(e<=s) for s,(b,e) in zip(ps, self.marg)):
raise ValueError(f"invalid shrink {self.marg} for {ps}")
return tuple(ssimplify(e-s) for s,e in self.marg)
case Ops.FLIP:
if len(ps) != len(self.marg) or not all(isinstance(x, bool) for x in self.marg): raise ValueError(f"bad flip on {ps}, {self.marg}")
return ps
case Ops.MULTI: return tuple(s*len(self.device) if a == self.axis else s for a,s in enumerate(ps))
case Ops.REDUCE_AXIS | Ops.WMMA:
axis_arg = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
if not isinstance(axis_arg, tuple) or not all(isinstance(x, int) and x>=0 and x<len(ps) for x in axis_arg):
raise ValueError(f"invalid type for axis: {axis_arg}")
return tuple(1 if i in axis_arg else s for i,s in enumerate(ps))
# elementwise ops keep the shape the same. all inputs with shape must match
if self.op in (GroupOp.Elementwise-{Ops.BITCAST}).union({Ops.COPY, Ops.ASSIGN, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}):
# TODO: remove this hack for 3 op assign
input_shapes = [x._shape for x in (self.src[:2] if self.op is Ops.ASSIGN else self.src) if x._shape is not None]
if len(input_shapes) == 0: return None
if not all_same(input_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes}")
return input_shapes[0]
# all Ops must be explicitly handled
raise NotImplementedError(f"no shape handling for {self.op} with {self.dtype}")
@property
def shape(self) -> tuple[sint, ...]:
if (ret:=self._shape) is None: raise RuntimeError(f"shape requested, but {self.op} doesn't have a shape")
return ret
@property
def size(self) -> int: return prod([int(x.vmax) if isinstance(x, UOp) else x for x in self.shape])
@functools.cached_property
def ended_ranges(self):
if self.op in range_start: return self.src[range_start[self.op]:]
return ()
# determine what ranges this is in
@recursive_property
def _ranges(self) -> dict[UOp, None]:
ret: dict[UOp, None] = {}
for s in self.src: ret.update(s.ranges)
for er in self.ended_ranges:
if er.op is Ops.RANGE:
# if it's a single RANGE, we don't flow through it.
if er in ret: del ret[er]
else:
# if it's not a RANGE, we include all ranges in srcs.
# technically we shouldn't flow through these ranges either, but this is pre pm_add_control_flow so it's the same.
for s in er.ranges:
if s in ret: del ret[s]
return ret
@property
def ranges(self) -> dict[UOp, None]:
if self.op is Ops.RANGE: return {self:None} | self._ranges
return self._ranges
# *** uop evaluation ***
def simplify(self, tracked=False, full_symbolic=True):
# late import!
from tinygrad.uop.symbolic import symbolic, commutative
with Context(TRACK_MATCH_STATS=0 if not tracked else TRACK_MATCH_STATS.value):
return graph_rewrite(self, symbolic if full_symbolic else commutative, name="simplify")
def ssimplify(self) -> UOp|ConstType: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
def sintify(self) -> sint: return self.arg if self.op is Ops.CONST else self
def _eval(self, dtype, expected_type:Type[T]) -> T:
assert self.dtype in dtype, f"eval with wrong dtype {self}"
vmin, vmax = (simple_self:=self.simplify())._min_max
if vmin != vmax: raise ValueError(f"eval failed to be a single number, range is {vmin} to {vmax} in {simple_self.render()}")
assert isinstance(vmin, expected_type), f"vmin is wrong dtype {type(vmin)} != {expected_type}"
return vmin
def __bool__(self): return self._eval((dtypes.bool,), bool)
def __int__(self): return self._eval(dtypes.ints, int)
def __float__(self): return self._eval(dtypes.floats, float)
def substitute(self, dvars:dict[UOp, UOp], name:str|None=None, extra_pm:PatternMatcher|None=None):
dvars = {k:v for k,v in dvars.items() if k is not v}
if len(dvars) == 0: return self
with Context(TRACK_MATCH_STATS=(0 if name is None else TRACK_MATCH_STATS.value)):
return graph_rewrite(self, (extra_pm+_substitute) if extra_pm is not None else _substitute, dvars, bottom_up=True, name=name)
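# Example (rough sketch): substitute() swaps whole sub-expressions everywhere in the graph.
#   i = UOp.variable("i", 0, 10)
#   expr = i * 2 + 1
#   expr.substitute({i: UOp.const(dtypes.index, 4)}).ssimplify()   # -> 9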
# *** uop tracing stuff ***
@recursive_property
def trace_num(self):
num = next(ucount)
# KERNEL also has a UOp in the arg
arg = type(self.arg)(self.arg.ast.trace_num, self.arg.metadata) if self.op is Ops.KERNEL else self.arg
uop_fields[num] = (self.op, self.dtype, tuple(s.trace_num for s in self.src), arg, self.tag)+((self.metadata,) if TRACEMETA>=2 else ())
return num
# *** uop syntactic sugar ***
def sink(*srcs:UOp|None, **kwargs): # pylint: disable=no-self-argument
return UOp(Ops.SINK, dtypes.void, tuple([x for x in srcs if x is not None]), **kwargs)
def group(*srcs:UOp|None): # pylint: disable=no-self-argument
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
def vectorize(self, *srcs, **kwargs):
return UOp(Ops.VECTORIZE, self.dtype.vec(len(srcs)+1), (self,)+srcs, **kwargs)
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def __getitem__(self, idx):
idx = argfix(idx)
assert len(idx) == len(self.shape), f"__getitem__ shape mismatch, indexing {self.shape} with {len(idx)} args"
if len(slice_idx:=[i for i,x in enumerate(idx) if isinstance(x, slice)]):
perm = self.permute(tuple([i for i in range(self.ndim) if i not in slice_idx] + slice_idx))
return perm.index(*[UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in idx if not isinstance(x, slice)], ptr=True)
else:
return self.index(*[UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in idx])
def const_like(self, b:ConstLike):
# constants can optionally have a DEVICE source
return UOp.const(self.dtype, b, device=self._device, shape=self._shape)
def broadcast(self, count:int):
assert self.dtype.vcount == 1
if count == 1: return self
return UOp(Ops.VECTORIZE, self.dtype.vec(count), (self,)*count)
def cast(self, dtype:DType):
# TODO: we shouldn't have to check for dtype.count == 1 here, but CAST is misused in AMD LLVM
if dtype.count == 1 and dtype.count != self.dtype.count: dtype = dtype.vec(self.dtype.count)
if self.dtype == dtype: return self
return UOp(Ops.CAST, dtype, (self,))
def bitcast(self, dtype:DType): return UOp(Ops.BITCAST, dtype, (self,))
def gep(self, i:tuple[int, ...]|int):
if isinstance(i, tuple) and len(i) == 1: return self.gep(i[0])
if isinstance(i, int):
# NOTE: these are just shortcuts to not have to create and fold later
if self.op is Ops.VECTORIZE: return self.src[i]
if self.op is Ops.VCONST: return UOp.const(self.dtype.scalar(), self.arg[i])
if self.op is Ops.CONST: return UOp.const(self.dtype.scalar(), self.arg)
i = (i,)
return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
def store(self, src:UOp|ConstType, **kwargs):
return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self, UOp.const(self.dtype, src) if not isinstance(src, UOp) else src), **kwargs)
def end(self, *src:UOp):
if len(src) == 0: return self
return UOp(Ops.END, src=(self,)+src)
def after(self, *src:UOp, **kwargs): return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs)
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src)
def contract(self, *rngs:UOp):
assert all(x.arg[-1] == AxisType.UPCAST for x in rngs), "all contract ranges must be upcast"
return UOp(Ops.CONTRACT, dtype=self.dtype.vec(prod([x.vmax+1 for x in rngs])), src=(self,), arg=tuple((x.arg[0], x.vmax+1) for x in rngs))
def alu(self, op, *src:UOp, **kwargs):
out_dtype = (self, *src)[-1].dtype
if op in {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
return UOp(op, out_dtype, (self,)+src, **kwargs)
@staticmethod
def const(dtype:DType, b:ConstLike, device:str|tuple[str, ...]|None=None, shape:tuple[sint, ...]|None=None, src=None, unique:bool|int=False):
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
# NOTE: float('nan') != float('nan'), so we canonicalize here
if isinstance(b, float) and math.isnan(b): b = math.nan
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype), src=() if src is None else (src,))
if device is not None:
if unique or not isinstance(unique, bool): ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device), UOp.unique(None if unique is True else unique)))
else: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
elif unique or not isinstance(unique, bool): raise RuntimeError("unique consts only with DEVICE")
if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape)
return ret
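# Example (rough sketch, assumes a "CPU" device is available): bare consts are device-less and
# shape-less; passing device/shape attaches a DEVICE src and broadcasts via reshape+expand.
#   UOp.const(dtypes.float, 1.0)                              # CONST 1.0, no shape
#   UOp.const(dtypes.float, 1.0, device="CPU", shape=(2, 2))  # CONST expanded to shape (2, 2)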
@staticmethod
def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.index, src=(), **kwargs):
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs)
@staticmethod
def special(end:sint, name:str, dtype=dtypes.index): return UOp(Ops.SPECIAL, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=name)
def r(self, op:Ops, axis:tuple[int, ...]):
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis)) if len(axis) else self
@staticmethod
def invalid(count=1): return UOp(Ops.CONST, dtypes.index.vec(count), src=(), arg=Invalid)
def valid(self, cond): return self if cond.op is Ops.WHERE and cond.arg else cond.where(self, UOp.invalid(self.dtype.count))
def get_idx(self) -> UOp:
assert self.dtype.scalar() is dtypes.index, "Can only call get_idx on index dtype"
return self.src[1] if self.op is Ops.WHERE and self.src[2].arg is Invalid else self
def get_valid(self) -> UOp:
assert self.dtype.scalar() is dtypes.index, "Can only call get_valid on index dtype"
return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid)
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
def is_contiguous(self):
# TODO: this is is_realized
if self.op is Ops.RESHAPE: return self.src[0].is_contiguous()
return self.op is Ops.BUFFER
def contiguous(self, *args, **kwargs):
if self.op is Ops.CONTIGUOUS: return self
if self.is_contiguous(): return self
return UOp(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)
def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
def bufferize(self, *args, **kwargs): return UOp(Ops.BUFFERIZE, dtype=self.dtype, src=(self,)+args, **kwargs)
def allreduce(self, op, device:str|tuple[str, ...]|UOp):
assert isinstance(self.device, tuple), f"allreduce must be on tuple {self.device} isn't"
return UOp(Ops.ALLREDUCE, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), op)
def overflows(self, dtype:DType) -> bool: return self.vmin < dtype.min or dtype.max < self.vmax
# *** ShapeTracker helpers ***
def split_uop(self:UOp, sep:Ops):
if self.op is sep:
for s in self.src: yield from s.split_uop(sep)
else: yield self
# *** from MultiLazyBuffer ***
def multi(self, axis:int|None):
assert isinstance(self.device, tuple), f"multi device must be tuple, {self.device} isn't"
assert axis is not None, "multi None is no longer supported"
return UOp(Ops.MULTI, self.dtype, (self,), axis)
@property
def bounds(self):
if self.axis is None: raise RuntimeError("bounds is not defined when axis is None")
return tuple(itertools.pairwise(itertools.accumulate([self.src[0].shape[self.axis] for _ in self.device], initial=0)))
@functools.cached_property
def axis(self) -> int|None:
if self.op is Ops.MULTI: return self.arg
# NOTE: they all have to share an axis, we always choose [-1]
if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None
if len(self.src) == 0: return None
src_axis = self.src[0].axis
if self.op is Ops.REDUCE_AXIS: return None if src_axis is not None and src_axis in self.arg[1] else src_axis
if self.op is Ops.RESHAPE:
if src_axis is None: return None
arg_acc:list[sint] = list(itertools.accumulate(self.marg, operator.mul, initial=1))
# new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
# TODO: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
return len(arg_acc) - arg_acc[::-1].index(prod(self.src[0].shape[:src_axis])) - 1
if self.op is Ops.PERMUTE: return self.marg.index(src_axis) if src_axis is not None else None
return src_axis
def _unshard(self, axis:int) -> UOp:
bsz, dcount = self.shape[axis], len(self.device)
dnum = UOp.variable("_device_num", 0, dcount-1)
return self.pad(tuple((0,0) if a != axis else (bsz*dnum, bsz*(dcount-1) - bsz*dnum) for a in range(len(self.shape))))
def _shard(self, axis:int) -> UOp:
dcount = len(self.device)
dnum = UOp.variable("_device_num", 0, dcount-1)
if self.shape[axis] % dcount != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {dcount=}")
sz = self.shape[axis] // dcount
return self.shrink(tuple((0,s) if i != axis else (dnum*sz,dnum*sz+sz) for i,s in enumerate(self.shape)))
def shard(self, devices:tuple[str, ...], axis:int) -> UOp: return self.copy_to_device(devices)._shard(axis).multi(axis)
# *** from LazyBuffer ***
def copy_to_device(self, device:str|tuple[str, ...]|UOp, arg=None):
assert arg is None or isinstance(self.device, tuple)
inp = self if arg is None else UOp(Ops.MSELECT, self.dtype, src=(self,), arg=arg)
return UOp(Ops.COPY, self.dtype, (inp, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device))
def mselect(self, arg:int) -> UOp: return UOp(Ops.MSELECT, self.dtype, (self,), arg)
@property
def metadata(self) -> tuple[Metadata, ...]|None: return all_metadata.get(self, None)
# *** uop movement ops ***
@property
def base(self) -> UOp:
if self.op in GroupOp.Movement: return self.src[0].base
if self.op is Ops.MULTI: return self.src[0].base # MULTI is really a VIEW
return self
# like gep, but might return an integer
def sgep(self, i:int) -> sint:
match self.op:
case Ops.CONST: return self.arg
case Ops.VCONST: return self.arg[i]
case Ops.VECTORIZE: return self.src[i].sintify()
case _: raise RuntimeError(f"no sgep on {self.op}")
@functools.cached_property
def marg(self):
match self.op:
case Ops.RESHAPE | Ops.EXPAND: return tuple(self.src[1].sgep(i) for i in range(self.src[1].dtype.count))
case Ops.PAD | Ops.SHRINK: return tuple((self.src[1].sgep(i), self.src[2].sgep(i)) for i in range(self.src[1].dtype.count))
case Ops.PERMUTE | Ops.FLIP: return self.arg
case _: raise RuntimeError(f"{self.op} is not a MovementOp")
def _mop(self, op:Ops, arg, same_shape_noop:bool=False) -> UOp:
match op:
case Ops.RESHAPE | Ops.EXPAND: src_args = [arg]
case Ops.PAD | Ops.SHRINK: src_args = list(zip(*arg))
case Ops.PERMUTE | Ops.FLIP: src_args = []
case _: raise RuntimeError(f"{op} is not a MovementOp")
usrcs = []
for arg in src_args:
if len(arg) == 0: usrcs.append(UOp(Ops.VECTORIZE, dtypes.index.vec(0)))
elif all(isinstance(x, int) for x in arg): usrcs.append(UOp.const(dtypes.index.vec(len(arg)), arg))
else: usrcs.append(UOp(Ops.VECTORIZE, dtypes.index.vec(len(arg)), tuple(UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in arg)))
if len(usrcs) == 0: ret = UOp(op, self.dtype, (self,), arg)
else: ret = UOp(op, self.dtype, (self,)+UOp.sink(*usrcs).simplify().src)
# for all movement ops, we check shape property
if ret.shape == self.shape and same_shape_noop: return self
return ret
# in these four, if the shape doesn't change we can return self
def forced_reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=False)
#def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=True)
#def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg, same_shape_noop=True)
#def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, same_shape_noop=True)
def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg, same_shape_noop=True)
# in these two, we have custom logic to check if they are a no-op
#def permute(self, arg:tuple[int, ...]): return self._mop(Ops.PERMUTE, arg, same_shape_noop=False) if arg != tuple(range(len(self.shape))) else self
#def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg, same_shape_noop=False) if any(arg) and len(arg) == len(self.shape) else self
# *** uop UNIQUE ***
# TODO: use this in Buffer
unique_num = itertools.count(0)
@staticmethod
def unique(arg:int|None=None): return UOp(Ops.UNIQUE, arg=next(UOp.unique_num) if arg is None else arg)
# *** uop Buffer stuff ***
@staticmethod
def new_buffer(device:str|tuple[str, ...], size:int, dtype:DType, num=None):
return UOp(Ops.BUFFER, dtype, (UOp.unique(num), UOp(Ops.DEVICE, arg=device)), size)
@property
def device(self) -> str|tuple[str, ...]: return cast(str|tuple[str, ...], unwrap(self._device))
@recursive_property
def _device(self) -> str|tuple[str, ...]|None:
if self.op is Ops.DEVICE: return self.arg
if self.op is Ops.BUFFERIZE: return self.arg.device
if self.op is Ops.AFTER: return self.src[0]._device
if self.op is Ops.MSELECT:
assert isinstance(self.src[0].device, tuple), "mselect must be on tuple device"
return self.src[0].device[self.arg]
if self.op is Ops.MSTACK: return tuple(cast(str, x.device) for x in self.src)
if self.op in {Ops.COPY, Ops.BUFFER, Ops.ALLREDUCE}: return self.src[1].device
for x in self.src:
if x._device is not None: return x._device
return None
@property
def buf_uop(self) -> UOp:
if self.op is Ops.BUFFER: return self
if self.op is Ops.MSELECT: return self.src[0].buf_uop.mselect(self.arg)
if self.op is Ops.MSTACK: return UOp(Ops.MSTACK, self.dtype, src=tuple(x.buf_uop for x in self.src))
assert self.base.op is Ops.AFTER, f"must be AFTER {self.base.op}"
return self.base.src[0].buf_uop.base
def as_buf(self) -> UOp:
if self.op is Ops.MSELECT: return self.src[0].as_buf().mselect(self.arg)
if self.op is Ops.MSTACK: return UOp(Ops.MSTACK, self.dtype, src=tuple(x.as_buf() for x in self.src))
# TODO: this should be the only one of these. this is the one RANGEIFY uses
s = self
while len(s.src) and s.op not in {Ops.BUFFER, Ops.BUFFERIZE, Ops.MSTACK}: s = s.src[0]
return s
def buf_target(self) -> UOp:
# the buffer that's being loaded from or stored to
match self.op:
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return self
case Ops.AFTER | Ops.INDEX | Ops.STORE | Ops.LOAD: return self.src[0].buf_target()
case Ops.VECTORIZE:
assert all_same(self.src)
return self.src[0].buf_target()
case _: raise RuntimeError(f"buf_target called on non load/index/store {self.op}")
@property
def buffer(self) -> Buffer|MultiBuffer:
from tinygrad.device import Buffer, MultiBuffer
if self is not self.base:
assert self.op is Ops.RESHAPE, f"can only be RESHAPE {self}"
return self.src[0].buffer
if self.op is Ops.MSELECT:
ret = self.src[0].buffer
assert isinstance(ret, MultiBuffer)
return ret.bufs[self.arg]
if self.op is Ops.MSTACK:
ret = MultiBuffer.__new__(MultiBuffer)
ret.bufs = [cast(Buffer, x.buffer) for x in self.src]
assert all_same([x.size for x in ret.bufs]) and all_same([x.dtype for x in ret.bufs]), "multibuffers mismatch buffers"
return ret
assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
if (cret:=buffers.get(self)) is not None: return cret
rdtype = self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base
if isinstance(self.device, tuple): ret = MultiBuffer(self.device, self.size, rdtype).ref(1)
else: ret = Buffer(self.device, self.size, rdtype).ref(1)
buffers[self] = ret
return ret
@property
def realized(self) -> Buffer|MultiBuffer|None:
# NOTE: this is used by the JIT to determine which inputs we capture
return self.buffer if self.op in {Ops.BUFFER, Ops.MSTACK} and self.buffer.is_allocated() else None
@property
def is_realized(self) -> bool:
return all(x.base.realized is not None for x in self.base.src) if self.base.op is Ops.MULTI else self.base.realized is not None
# *** uop Variable stuff ***
@staticmethod
def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.index) -> UOp:
assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}"
return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
@property
def expr(self) -> str:
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
return self.arg[0]
def bind(self, val:int|UOp):
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
uval = self.const_like(val) if isinstance(val, int) else val
assert self.arg[1] <= uval.vmin and uval.vmax <= self.arg[2], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]"
return UOp(Ops.BIND, self.dtype, (self, uval))
def unbind(self) -> tuple[Variable, int]:
assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
return self.src[0], self.src[1].arg
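# Example (sketch): DEFINE_VAR carries (name, min, max); BIND pairs it with a concrete value.
#   v = UOp.variable("start_pos", 1, 128)
#   b = v.bind(32)
#   b.unbind()    # -> (v, 32)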
def unbind_all(self) -> tuple[UOp, dict[Variable, int]]:
ret:dict[Variable, int] = {}
return graph_rewrite(self, pm_unbind, ctx=ret), ret
@property
def val(self) -> int: return self.unbind()[1]
def vars(self) -> set[UOp]:
bound_vars = set([x for x in self.toposort() if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR])
bound_var_base = set(x.src[0] for x in bound_vars)
all_vars = set([x for x in self.toposort() if x.op is Ops.DEFINE_VAR])
return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
def variables(self) -> list[Variable]:
return sorted(set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
# *** uop symbolic stuff ***
def is_increasing(self:UOp) -> bool:
# is f a monotonically increasing function with regard to its input
if self.op in GroupOp.Irreducible: return True
if self.op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing()
if self.op in (Ops.MUL, Ops.IDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing()
return False # False if not sure
def const_factor(self) -> int:
"""largest known int that divides self"""
# TODO: for negatives it's not the largest
if self.op is Ops.CONST: return self.arg
if self.op is Ops.VCONST: return math.gcd(*self.arg)
if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
return 1
def divides(self, v:int) -> UOp|None:
if v==1: return self
if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
if self.op is Ops.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
if self.op is Ops.MUL:
if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
return None # generic None if we aren't sure
def pop_const(self, op=Ops.ADD) -> tuple[UOp, ConstType]:
return (self.src[0], self.src[1].arg) if self.op is op and self.src[1].op is Ops.CONST else (self, identity_element(op, self.dtype))
@staticmethod
def gcd(*uops: UOp) -> UOp:
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in uops])
count = functools.reduce(operator.and_, [collections.Counter(term.split_uop(Ops.MUL)) for term in terms])
return math.prod([*count.elements(), terms[0].const_like(math.gcd(*factors))]) # put the const at the top
def divide_exact(self, v:UOp) -> UOp|None:
if self is v: return self.const_like(1)
if v.op is Ops.CONST: return self.divides(v.arg)
if self.op is Ops.ADD: return None if (s0:=self.src[0].divide_exact(v)) is None or (s1:=self.src[1].divide_exact(v)) is None else s0+s1
if self.op is Ops.MUL:
(fac, const), (div_fac, div_const) = self.pop_const(Ops.MUL), v.pop_const(Ops.MUL)
new_count = collections.Counter(fac.split_uop(Ops.MUL))
new_count.subtract(div_fac.split_uop(Ops.MUL))
if const%div_const==0 and all(v>=0 for v in new_count.values()): return math.prod([*new_count.elements(), self.const_like(const//div_const)])
return None # generic None if we aren't sure
def sum(self:UOp, *uops:UOp) -> UOp: return functools.reduce(operator.or_ if self.dtype is dtypes.bool else operator.add, uops, self)
def prod(self:UOp, *uops:UOp) -> UOp: return functools.reduce(operator.and_ if self.dtype is dtypes.bool else operator.mul, uops, self)
@property
def vmin(self) -> ConstType: return self._min_max[0]
@property
def vmax(self) -> ConstType: return self._min_max[1]
@functools.cached_property
def _min_max(self) -> tuple[ConstType, ConstType]:
if self.op in GroupOp.Binary and not dtypes.is_float(self.dtype):
(s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max
if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
if self.op is Ops.SUB: return s0_vmin-s1_vmax, s0_vmax-s1_vmin
if self.op is Ops.AND and dtypes.is_int(self.dtype) and s1_vmin == s1_vmax >= 0 and s0_vmin >= 0: return min(0, s0_vmin), min(s0_vmax, s1_vmax)
if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
# SHL/SHR on consts only
if self.op is Ops.SHL and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] << t[2], t[1] << t[2]
if self.op is Ops.SHR and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] >> t[2], t[1] >> t[2]
if self.op is Ops.MOD:
if s1_vmin > 0: return (0, s1_vmax-1) if s0_vmin >= 0 else (-(s1_vmax-1), 0) if s0_vmax <= 0 else (-(s1_vmax-1), s1_vmax-1)
if s1_vmax < 0: return (0, -s1_vmin-1) if s0_vmin >= 0 else (-(-s1_vmin-1), 0) if s0_vmax <= 0 else (-(-s1_vmin-1), -s1_vmin-1)
if self.op is Ops.IDIV:
assert isinstance(s0_vmin, int) and isinstance(s0_vmax, int) and isinstance(s1_vmin, int) and isinstance(s1_vmax, int)
if s1_vmin*s1_vmax>0:
return min(vals:=(cdiv(s0_vmin, s1_vmin), cdiv(s0_vmin, s1_vmax), cdiv(s0_vmax, s1_vmin), cdiv(s0_vmax, s1_vmax))), max(vals)
if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax)
if self.op is Ops.CMPLT: return (s0_vmax<s1_vmin, s0_vmin<s1_vmax)
if self.op is Ops.CMPNE: return ((s0_vmax < s1_vmin) or (s1_vmax < s0_vmin), not (s0_vmin == s0_vmax == s1_vmin == s1_vmax))
if self.op is Ops.OR and self.dtype == dtypes.bool: return s0_vmin or s1_vmin, s0_vmax or s1_vmax
if self.op is Ops.AND and self.dtype == dtypes.bool: return s0_vmin and s1_vmin, s0_vmax and s1_vmax
# floats have NaN issues and we use explicit NaN in transcendental
if self.op is Ops.WHERE and dtypes.is_int(self.dtype): return min(self.src[1].vmin, self.src[2].vmin), max(self.src[1].vmax, self.src[2].vmax)
# NOTE: returned UOp is assumed to be CONST
if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
if self.op in (Ops.RANGE, Ops.SPECIAL): return 0, (self.src[0]-1).vmax
if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
if self.op in {Ops.UNROLL, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
if self.op is Ops.CONST and self.arg is not Invalid: return self.arg, self.arg
if self.op is Ops.VCONST and Invalid not in self.arg: return (min(self.arg), max(self.arg))
if self.op is Ops.GEP: return self.src[0]._min_max
# TODO: CAST to bool/unsigned is not monotone, still some case can be simplified
if self.op is Ops.CAST and self.dtype in dtypes.floats+dtypes.sints+(dtypes.index,):
return max(dtypes.min(self.dtype), self.src[0].vmin), min(self.src[0].vmax, dtypes.max(self.dtype))
return dtypes.min(self.dtype), dtypes.max(self.dtype)
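# Example (rough sketch): vmin/vmax propagate value bounds through integer ops.
#   i = UOp.variable("i", 0, 10)
#   (i * 2 + 1).vmin, (i * 2 + 1).vmax   # -> (1, 21)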
@functools.cached_property
def _sym_fxn(self):
sself = self.simplify()
varnames = tuple(x.arg[0] for x in sself.toposort() if x.op is Ops.DEFINE_VAR)
# TODO: sanitize varnames, or don't use naked eval while staying fast
return eval("lambda "+','.join(varnames)+": "+sself.render(pm=renderer_infer)), varnames # pylint: disable=eval-used
def sym_infer(self, var_vals:dict[str, int]):
fxn, varnames = self._sym_fxn
return fxn(**{k:v for k,v in var_vals.items() if k in varnames})
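# Example (rough sketch): sym_infer evaluates a symbolic expression given concrete variable values.
#   i = UOp.variable("i", 0, 10)
#   (i * 2 + 1).sym_infer({"i": 3})   # -> 7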
def render(self, simplify=True, pm:PatternMatcher|None=None) -> str:
ctx: dict[UOp, str] = {}
pm = renderer if pm is None else pm
for u in (s:=self.simplify() if simplify else self).toposort():
ctx[u] = cast(str, pm.rewrite(u, ctx=ctx))
return ctx[s]
def pyrender(self): return pyrender(self)
# *** uop high level syntactic sugar ***
@staticmethod
def placeholder(shape:tuple[int, ...], dtype:DType, slot:int, addrspace=AddrSpace.GLOBAL):
lookup = {AddrSpace.GLOBAL: Ops.DEFINE_GLOBAL, AddrSpace.LOCAL: Ops.DEFINE_LOCAL, AddrSpace.REG: Ops.DEFINE_REG}
ret = UOp(lookup[addrspace], dtype.ptr(prod(shape), addrspace), arg=slot)
if len(shape) > 1: ret = ret.reshape(shape)
return ret
def placeholder_like(self, slot:int):
assert all_int(self.shape), "no placeholder-like on symbolic shape"
return UOp.placeholder(self.shape, self.dtype, slot)
# set is store+end+after
def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]|list[UOp]=()) -> UOp:
return self.src[0].after(self.store(val).end(*argfix(end)))
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(srcs)]
contig_srcs = tuple(x.contiguous() for x in srcs)
kernel = UOp(Ops.KERNEL, src=tuple(x.base for x in contig_srcs), arg=Kernel(fxn(*placeholders), grad_fxn=grad_fxn))
return [s.after(kernel) for s in contig_srcs]
@dataclass(frozen=True)
class KernelInfo:
name: str = "test" # name of the kernel
axis_types: tuple[AxisType, ...] = tuple()
dont_use_locals: bool = False # don't use local indexing
applied_opts: tuple = tuple()
opts_to_apply: tuple|None = None
@property
def function_name(self): return to_function_name(self.name)
@dataclass(frozen=True)
class Kernel:
ast: UOp
metadata: tuple[Metadata, ...] = ()
grad_fxn: Callable|None = None
# ******** ops in python ********
def safe_exp2(x):
try: return 2 ** x
except OverflowError: return math.inf
def safe_pow(x, y):
try: return math.nan if isinstance(p:=pow(x, y), complex) else p
except ZeroDivisionError: return math.inf
except ValueError: return math.inf if x > 0 else -math.inf
python_alu: dict[Ops, Callable] = {
Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2,
Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIPROCAL: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow, Ops.TRUNC: math.trunc,
Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
Ops.MOD: cmod, Ops.IDIV: cdiv, Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z, Ops.CMPEQ: operator.eq}
def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
if dtype.count > 1:
return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)])
if dtype==dtypes.index and op in GroupOp.Binary and Invalid in operands: return Invalid
alu = python_alu[op](*operands)
return truncate.get(dtype, lambda x: x)(alu) if truncate_output else alu
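# Example (rough sketch): exec_alu runs an op on python values, truncating to the dtype by default.
#   exec_alu(Ops.ADD, dtypes.uint8, (200, 100))    # -> 44, wraps around at 256
#   exec_alu(Ops.MUL, dtypes.float32, (1.5, 2.0))  # -> 3.0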
# ***** uop helpers *****
def print_uops(uops:list[UOp]):
uops_index = {u:i for i,u in enumerate(uops)}
for i,u in enumerate(uops):
formatted_srcs = [(uops_index[x] if x.op is not Ops.CONST else f"{x.arg}") if x in uops else "--" for x in u.src]
print(f"{i:4d} {str(u.op):20s}: {multirange_str(u.ranges, color=True, pad=10)} {str(u.dtype):40s} " f"{str(formatted_srcs):32s} {u.arg}")
# ***** pattern matcher *****
def get_location() -> tuple[str, int]:
frm = sys._getframe(1)
# skip over ops.py/mathtraits.py (unless there's nothing but ops.py/mathtraits.py)
while pathlib.Path(frm.f_code.co_filename).name in ("ops.py", "mathtraits.py") and frm.f_back is not None and \
not frm.f_back.f_code.co_filename.startswith("<frozen"):
frm = frm.f_back
return frm.f_code.co_filename, frm.f_lineno
class UPat(OpMixin):
__slots__ = ("op", "dtype", "arg", "name", "src")
def __init__(self, op:Ops|tuple[Ops, ...]|set[Ops]|None=None, dtype:DType|tuple[DType, ...]|None=None,
src:tuple[UPat, ...]|list[UPat]|UPat|None=None, arg:Any=None,
name:str|None=None, allow_any_len:bool=False, custom_early_reject:set[Ops]|None=None, location=None):
assert op is None or isinstance(op, (Ops, tuple, set)), "op must be Ops or tuple of Ops"
self.op: tuple[Ops, ...]|None = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
self.dtype: tuple[DType, ...]|None = (dtype,) if isinstance(dtype, DType) else dtype
self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
self.src: Any = None
assert self.name != "ctx", "UPat can't be named ctx"
assert dtype is None or isinstance(dtype, DType) or all(isinstance(x, DType) for x in dtype), f"invalid dtype {dtype}"
# try all permutations if it's a list
if isinstance(src, list): self.src = list(itertools.permutations(src)) if not all_same(src) else [tuple(src)]
# only one if it's a tuple
elif isinstance(src, tuple): self.src = [src]
# repeat if it's a UPat
elif isinstance(src, UPat): self.src = [itertools.repeat(src)]
self.strict_length = not (allow_any_len or isinstance(src, UPat) or src is None)
self.required_len: int = 0 if isinstance(src, UPat) or src is None else len(src)
self.location = location or get_location()
if custom_early_reject is not None: self.early_reject = custom_early_reject
else:
upat_match = [src] if isinstance(src, UPat) else ([] if src is None else self.src[0])
self.early_reject = {pp.op[0] for pp in upat_match if pp.op is not None and len(pp.op) == 1}
def __reduce__(self):
return UPat, (self.op, self.dtype, self._in_src, self.arg, self.name, not self.strict_length, self.custom_early_reject, self.location)
def named(self, name:str): return UPat(self.op, self.dtype, self._in_src, self.arg, name, not self.strict_length, self.custom_early_reject)
@staticmethod
def any(*src): return UPatAny(src=src)
def or_casted(self, name:str|None=None): return UPat.any(self if name is None else self.named(name), UPat(Ops.CAST, name=name, src=(self,)))
def or_after(self, name:str|None=None):
return UPat.any(self if name is None else self.named(name), UPat(Ops.AFTER, name=name, src=(self,), allow_any_len=True))
@staticmethod
@functools.cache
def var(name:str|None=None, dtype:DType|tuple[DType, ...]|None=None): return UPat(dtype=dtype, name=name)
@staticmethod
@functools.cache
def cvar(name:str|None=None, dtype:DType|tuple[DType, ...]|None=None, vec=True, arg=None):
return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name, arg=arg)
@staticmethod
def const(dtype:DType|tuple[DType, ...]|None, b:ConstType|InvalidType): return UPat(Ops.CONST, dtype=dtype, arg=b)
# lil helper
def f(self, op, **kwargs): return UPat(op, src=(self,), **kwargs)
# copied from UOp
def sink(self, *srcs:UPat|None, **kwargs): return UPat(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def index(self, idx:UPat, valid:UPat|None=None, **kwargs):
return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx), **kwargs)
def cast(self, dtype=None, **kwargs): return UPat(Ops.CAST, dtype, (self,), **kwargs)
def bitcast(self, dtype=None): return UPat(Ops.BITCAST, dtype, (self,))
def gep(self, i:int|None=None, **kwargs): return UPat(Ops.GEP, None, (self,), (i,) if i is not None else None, **kwargs)
def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)
def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, self.dtype, (self,)+src, **kwargs)
def assign(self, x:UPat, **kwargs): return UPat(Ops.ASSIGN, self.dtype, (self,x), **kwargs)
def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.dtype, src=(self,)+src, **kwargs)
def broadcast(self, **kwargs): return UPat(Ops.VECTORIZE, self.dtype, src=self, **kwargs)
def contiguous(self, *args, **kwargs): return UPat(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)
def after(self, *src:UPat, **kwargs): return UPat(Ops.AFTER, self.dtype, (self,)+src, **kwargs)
def end(self, *src:UPat, **kwargs): return UPat(Ops.END, self.dtype, (self,)+src, **kwargs)
def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
def alu(self, op:Ops, *src:UPat):
asrc = (self,)+src
return UPat(op, dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
if (self.op is not None and uop.op not in self.op) or \
(self.name is not None and store.setdefault(self.name, uop) is not uop) or \
(self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \
(self.arg is not None and self.arg != uop.arg) or \
(len(uop.src) < self.required_len) or \
(self.strict_length and len(uop.src) != self.required_len): return []
if self.src is None: return [store]
res: list[dict[str, UOp]] = []
for vp in self.src:
stores, new_stores = [store.copy()], []
for uu, vv in zip(uop.src, vp):
for s in stores: new_stores.extend(vv.match(uu, s))
stores, new_stores = new_stores, []
res.extend(stores)
return res
class UPatAny(UPat):
def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
matches = [x.match(uop, store.copy()) for x in self.src[0]]
return flatten([x for x in matches if x is not None])
def deconstruct_function(fxn:Callable) -> tuple:
new_globals = {k:v for k,v in fxn.__globals__.items() if k in fxn.__code__.co_names}
for co in fxn.__code__.co_consts:
if isinstance(co, types.CodeType): new_globals.update({k:v for k,v in fxn.__globals__.items() if k in co.co_names})
# NOTE: optional round trip through pickle!
assert fxn.__closure__ is None, "closures are not supported in pattern matchers"
ret = fxn.__code__, new_globals, fxn.__name__, fxn.__defaults__
return pickle.loads(pickle.dumps(ret)) if getenv("TEST_PICKLE") else ret
@functools.cache
def upat_interpret(p:UPat, fxn:Callable) -> Callable:
real_fxn = types.FunctionType(*deconstruct_function(fxn))
if 'ctx' in inspect.signature(real_fxn).parameters:
def universal_match(uop, ctx):
for match in p.match(uop, {}):
if (ret:=real_fxn(ctx=ctx, **match)) is not None: return ret # pylint: disable=not-callable
return None
else:
def universal_match(uop, _):
for match in p.match(uop, {}):
if (ret:=real_fxn(**match)) is not None: return ret # pylint: disable=not-callable
return None
return universal_match
class PatternMatcher:
def __init__(self, patterns:Sequence[tuple[UPat, Callable|tuple]], compiled=bool(getenv("UPAT_COMPILE", 1))):
if compiled: from tinygrad.uop.upat import upat_compile
# if this comes from a pickle, we reconstruct the lambda functions here
self.patterns:list[tuple[UPat, Callable]] = [(p,types.FunctionType(*fxn) if isinstance(fxn, tuple) else fxn) for p,fxn in patterns]
# NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
self.pdict: dict[Ops, list[tuple[UPat, Callable, set]]] = {}
# uop is required, arg is optional
for p,fxn in self.patterns:
assert p.op is not None
if compiled and (match:=upat_compile(p, fxn)) is not None: pass # pylint: disable=E0606
else: match = upat_interpret(p, fxn)
for uop in p.op: self.pdict.setdefault(uop, []).append((p, match, p.early_reject))
def __reduce__(self): return PatternMatcher, ([(x,deconstruct_function(fxn) if fxn.__name__ == "<lambda>" else fxn) for x,fxn in self.patterns],)
@functools.cache # pylint: disable=method-cache-max-size-none
def __add__(self, more:PatternMatcher) -> PatternMatcher: return PatternMatcher(self.patterns+more.patterns)
def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
ler = {u.op for u in uop.src}
for _,match,early_reject in self.pdict.get(uop.op, []):
if not early_reject.issubset(ler): continue
if (ret:=match(uop, ctx)) is not None and ret is not uop: return ret
return None
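# Example (rough sketch, hypothetical pattern): a PatternMatcher maps UPat patterns to rewrite
# functions; rewrite() tries the patterns registered for uop.op and returns the first non-None
# replacement, or None if nothing matched.
#   pm = PatternMatcher([(UPat(Ops.ADD, src=[UPat.var("x"), UPat.cvar("c", arg=0)]), lambda x,c: x)])
#   pm.rewrite(some_uop)   # -> x if some_uop is x+0, else None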
# *** tracking pattern matcher ***
TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if VIZ else 0)
match_stats:dict[UPat, list[int|float]] = dict()
# TRACK_MATCH_STATS>=2 or VIZ=1 saves all matches
ucount = itertools.count()
uop_fields:dict[int, tuple] = {}
@dataclass(frozen=True)
class TrackedGraphRewrite:
loc:tuple[str, int] # location that called graph_rewrite
sink:int # the sink input to graph_rewrite
matches:list[tuple[int, int, tuple, float]] # before/after UOp, UPat location and time
name:str # name of the rewrite
depth:int # depth if it's a subrewrite
bottom_up:bool
tracked_keys:list[TracingKey] = []
tracked_ctxs:list[list[TrackedGraphRewrite]] = []
_name_cnt:dict[str, itertools.count] = {}
if getenv("CAPTURE_PROCESS_REPLAY"):
replay_capture: dict[str, bytes] = {}
import atexit
@atexit.register
def save_to_diskcache():
for k,v in replay_capture.items(): diskcache_put("process_replay", k, v, prepickled=True)
def add_trace_group(kt:TracingKey) -> None:
tracked_keys.append(kt)
tracked_ctxs.append([])
def track_rewrites(name:Callable[..., str|TracingKey]|bool=True, replay:bool=False):
def _decorator(func):
def __wrapper(*args, **kwargs):
fn = key = func.__name__
if TRACK_MATCH_STATS >= 2: add_trace_group(key:=TracingKey(n:=f"{fn} n{next(_name_cnt.setdefault(fn, itertools.count(1)))}", (n,)))
with cpu_profile(key, "TINY") as e:
ret = func(*args, **kwargs)
if TRACK_MATCH_STATS >= 2 and callable(name):
name_ret = name(*args, **kwargs, ret=ret)
assert isinstance(name_ret, (TracingKey, str)), f"name function returned {type(name_ret)}"
tracked_keys[-1] = k = TracingKey(n:=tracked_keys[-1].display_name.replace(fn, name_ret), (n,)) if isinstance(name_ret, str) else name_ret
e.name = TracingKey(k.display_name if isinstance(name_ret, str) else f"{fn} for {k.display_name}", k.keys)
if getenv("CAPTURE_PROCESS_REPLAY") and replay:
# find the unittest frame we're capturing in
frm = sys._getframe(1)
while (f_back:=frm.f_back) is not None and "unittest" not in f_back.f_code.co_filename: frm = f_back
loc = f"{frm.f_code.co_filename.split('/')[-1]}:{frm.f_lineno} {frm.f_code.co_name}"
# capture global context vars and all the args passed in
with Context(PICKLE_BUFFERS=0):
inputs = (fn, args, kwargs, ContextVar._cache)
replay_capture[hashlib.sha256(pickle.dumps(inputs)).hexdigest()] = pickle.dumps(inputs+(loc, ret))
return ret
return __wrapper
return _decorator
active_rewrites:list[TrackedGraphRewrite] = []
def profile_matches(fxn:Callable):
def wrap(*args, **kwargs):
name = str(kwargs.get("name", None) or fxn.__name__)
assert args and isinstance(args[0], UOp), f"invalid match tracing inputs for {name} with {args}"
if tracking:=(TRACK_MATCH_STATS >= 2):
loc = ((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno)
depth = len(active_rewrites)
if not tracked_ctxs: add_trace_group(TracingKey(f"default {fxn.__name__}"))
tracked_ctxs[-1].append(ctx:=TrackedGraphRewrite(loc, args[0].trace_num, [], name, depth, kwargs.get("bottom_up", False)))
active_rewrites.append(ctx)
with cpu_profile(name, "TINY", display=tracking):
ret = fxn(*args, **kwargs)
if tracking: active_rewrites.pop()
return ret
return wrap
class TrackedPatternMatcher(PatternMatcher):
def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
ret = None
ler = {u.op for u in uop.src}
for p,match,early_reject in self.pdict.get(uop.op, []):
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
st = time.perf_counter()
if not early_reject.issubset(ler):
match_stats[p][2] += time.perf_counter()-st
continue
match_stats[p][1] += 1
try: ret = match(uop, ctx)
except Exception:
if TRACK_MATCH_STATS >= 2 and active_rewrites:
active_rewrites[-1].matches.append((uop.trace_num, UOp(Ops.REWRITE_ERROR,src=uop.src,arg=str(sys.exc_info()[1])).trace_num,p.location,0))
raise
if ret is not None and ret is not uop:
match_stats[p][0] += 1
match_stats[p][3] += (et:=time.perf_counter()-st)
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", printable(p.location))
if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and active_rewrites:
active_rewrites[-1].matches.append((uop.trace_num, ret.trace_num, p.location, et))
return ret
match_stats[p][2] += time.perf_counter()-st
return None
@dataclass(frozen=True)
class RewriteTrace: keys:list[TracingKey]; rewrites:list[list[TrackedGraphRewrite]]; uop_fields:dict[int, tuple] # noqa: E702
if TRACK_MATCH_STATS or PROFILE:
PatternMatcher = TrackedPatternMatcher # type: ignore
import atexit
@atexit.register
def print_match_stats():
if TRACK_MATCH_STATS >= 2:
with open(fn:=temp("rewrites.pkl", append_user=True), "wb") as f:
print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}")
pickle.dump(RewriteTrace(tracked_keys, tracked_ctxs, uop_fields), f)
if VIZ: return launch_viz("VIZ", temp("rewrites.pkl", append_user=True))
if getenv("PRINT_MATCH_STATS", TRACK_MATCH_STATS.value):
ret = [0,0,0.0,0.0]
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]):
loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}"
if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {(v[2]+v[3])*1000.:9.2f} ms -- {loc_str:20s}", printable(k.location))
ret = [x+y for x,y in zip(ret, v)]
print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {(ret[2]+ret[3])*1000.:9.2f} ms -- TOTAL")
print(f"{len(match_stats)} rules, {sum(v[0] > 0 for v in match_stats.values())} matched once")
def launch_viz(env_str:str, data:str):
os.environ[env_str] = "0"
os.environ[f"{env_str}_DATA"] = data
if not int(os.getenv("VIZ", "0")) and not int(os.getenv("PROFILE", "0")) and not CI:
args = ['--kernels', getenv("VIZ_DATA", "")] if getenv("VIZ_DATA", "") else []
args += ['--profile', getenv("PROFILE_DATA", "")] if getenv("PROFILE_DATA", "") else []
viz_path = pathlib.Path(__file__).resolve().parent.parent / "viz" / "serve.py"
os.execv(sys.executable, [sys.executable, viz_path.as_posix()] + args)
# *** simple graph rewrite engine ***
with Context(SPEC=0): SENTINEL = UOp(Ops.SENTINEL)
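# SENTINEL is a cache-miss marker: the rewrite caches legitimately store None (no match), so .get() needs a distinct default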
class BottomUpGate(Exception): pass
class RewriteContext:
def __init__(self, pm, bpm, ctx=None):
self.pm: PatternMatcher|None = pm
self.pm_cache: dict[UOp, UOp|None] = {}
self.bpm: PatternMatcher|None = bpm
self.bpm_cache: dict[UOp, UOp|None] = {}
self.ctx = ctx
self.replace: dict[UOp, UOp] = {}
def cached_pm_rewrite(self, x:UOp):
if (ret:=self.pm_cache.get(x,SENTINEL)) is not SENTINEL: return ret
ret = self.pm_cache[x] = cast(PatternMatcher, self.pm).rewrite(x, self.ctx)
return ret
def cached_bpm_rewrite(self, x:UOp):
if (ret:=self.bpm_cache.get(x,SENTINEL)) is not SENTINEL: return ret
ret = self.bpm_cache[x] = cast(PatternMatcher, self.bpm).rewrite(x, self.ctx)
return ret
def unified_rewrite(self, root:UOp) -> UOp:
stack: collections.deque[tuple[UOp, int, UOp]] = collections.deque([(root, 0, root)])
on_stack = {root} # all UOps either on the stack or in self.replace, i.e. dont have to be placed again
REWRITE_STACK_LIMIT = getenv("REWRITE_STACK_LIMIT", 250000)
while stack:
if len(stack) > REWRITE_STACK_LIMIT: raise RuntimeError("infinite loop in graph_rewrite (stack too big)")
n, stage, new_n = stack.pop()
if n in self.replace: continue # skip any nodes we have seen
if stage == 0:
# if bottom up, we rewrite this node early. in both cases, we add its srcs to the stack
if self.bpm is not None:
# apply rewrite rules until a fixed point is reached. may return `uop` itself if PatternMatcher doesn't match
test_n: UOp|None = n
seen = set()
try:
while test_n is not None:
if test_n in seen: raise RuntimeError("infinite loop in fixed_point_rewrite")
seen.add(test_n)
new_n, test_n = test_n, self.cached_bpm_rewrite(test_n)
except BottomUpGate:
# if the bpm matching raised a gate, we are done with this node and dont continue down the srcs
self.replace[n] = unwrap(test_n)
continue
stack.append((n, 1, new_n))
for x in reversed(new_n.src):
if x in on_stack: continue
stack.append((x, 0, x))
on_stack.add(x)
elif stage == 1:
tmp = []
for x in new_n.src:
if (rx:=self.replace.get(x, SENTINEL)) is SENTINEL:
# if some new sources aren't ready, we try this again later. happens with on_stack, maybe should remove?
stack.appendleft((n, 1, new_n))
break
tmp.append(rx)
else:
# in stage 1, once all srcs are rewritten, rebuild (if changed) or run top-down rewrite
if (new_src:=tuple(tmp)) == new_n.src:
# if top down, do the rewrite. if no rewrite or bottom up, we are done rewriting this node so we add it to the dict
if self.pm is None or (new_src_n:=self.cached_pm_rewrite(new_n)) is None:
self.replace[n] = new_n
continue
else:
# if srcs changed from rewrites, construct a new UOp with the new srcs
new_src_n = UOp(new_n.op, new_n.dtype, new_src, new_n.arg, new_n.tag)
# trigger a rewrite of new_src_n, then after that rewrite is done, link it back to n
stack.append((n, 2, new_src_n))
stack.append((new_src_n, 0, new_src_n))
else:
# in stage 2, we link the result of new_n to the result of n
if (replaced_new_n:=self.replace.get(new_n, SENTINEL)) is SENTINEL:
# not ready, try the link later
stack.appendleft((n, 2, new_n))
else:
# otherwise we are done
self.replace[n] = replaced_new_n
return self.replace[root]
@profile_matches
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None) -> UOp:
rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx)
return rewrite_ctx.unified_rewrite(sink)
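# NOTE: illustrative sketch, not from the library docs: graph_rewrite applies pm to every node
# reachable from sink and returns the rewritten root, e.g.
#   pm_mul1 = PatternMatcher([(UPat.var("x")*1, lambda x: x)])
#   graph_rewrite(UOp.variable("a", 0, 10)*1 + 2, pm_mul1)   # -> a+2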
@profile_matches
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None,
input_map:dict[UOp, UOp]|None=None) -> dict[UOp, UOp]:
rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx)
new_map: dict[UOp, UOp] = {}
for k in (list(sink.toposort())[::-1] if bottom_up else sink.toposort()):
new_map[k] = v = rewrite_ctx.unified_rewrite(k)
if k is not v and k.metadata is not None: all_metadata[v] = tuple(dedup(all_metadata.get(v, ())))+k.metadata
if input_map is not None:
for k,v in input_map.items(): new_map[k] = new_map.get(v,v)
return new_map
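# NOTE: illustrative sketch (reusing pm_mul1 from the sketch above): graph_rewrite_map returns the
# mapping for every node of the input graph, not just the root:
#   m = graph_rewrite_map(sink, pm_mul1)   # m[sink] is the rewritten sink, m[x] the rewritten x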
def sint_to_uop(x:sint, dtype=dtypes.index) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x.cast(dtype)
def select_dtype(u): return (dtypes.long if u.overflows(dtypes.int32) else dtypes.int).vec(u.dtype.count)
pm_lower_index_dtype = PatternMatcher([
# There are no Unary ops at this point in symbolic, those are introduced later
(UPat(GroupOp.Binary, name="u", src=(UPat.var("x").cast(dtypes.index), UPat.var("y").cast(dtypes.index))), lambda u,x,y:
x.cast(dt:=least_upper_dtype(select_dtype(u), x.dtype, y.dtype)).alu(u.op, y.cast(dt)).cast(u.dtype)),
(UPat((Ops.CONST, Ops.VCONST), dtype=dtypes.index, name="u"), lambda u: u.replace(dtype=select_dtype(u)).cast(u.dtype) if u.arg!=Invalid else None),
(UPat(Ops.WHERE, dtypes.index, src=(UPat.var("cond"), UPat.var("x").cast(dtypes.index), UPat.var("y").cast(dtypes.index))), lambda cond,x,y:
cond.where(x.cast(dt:=least_upper_dtype(x.dtype, y.dtype)), y.cast(dt)).cast(dtypes.index)),
(UPat(Ops.RANGE, src=(UPat.var("end").cast(dtypes.index)), name="r"), lambda r,end: r.replace(dtype=end.dtype, src=(end,)).cast(dtypes.index)),
(UPat(Ops.VECTORIZE, src=UPat().cast(dtypes.index), name="v"),
lambda v: v.replace(dtype=(dt:=select_dtype(v)), src=tuple(s.src[0].cast(dt.scalar()) for s in v.src)).cast(dtypes.index)),
# special can only be int32
(UPat(Ops.SPECIAL, src=(UPat.var("var").cast(dtypes.index),), name="u"), lambda u,var: u.replace(dtype=dtypes.int, src=(var,)).cast(dtypes.index)),
(UPat(Ops.DEFINE_VAR, dtype=dtypes.index, name="u"), lambda u: u.replace(dtype=dtypes.int).cast(dtypes.index)),
(UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.index), UPat.cvar("val").cast(dtypes.index))), lambda var,val: var.bind(val).cast(dtypes.index)),
(UPat(Ops.CAST, src=(UPat(name="x").cast(dtypes.index),), name="c"), lambda x,c: x.cast(c.dtype)),
# lower Invalid
(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond, ptr=True)),
# remove hanging casts
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))),
lambda buf,idx,valid: buf.index(idx, valid, ptr=True)),
(UPat((Ops.STORE, Ops.LOAD), src=(UPat(), UPat(), UPat().cast(dtypes.index)), allow_any_len=True, name="s"),
lambda s: s.replace(src=s.src[:2]+tuple(u.src[0] for u in s.src[2:]))),
(UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"),
lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.index else s for s in n.src))),
])
def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]
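# NOTE: illustrative sketch: pm_lower_index_dtype replaces abstract dtypes.index arithmetic with
# concrete int32, promoting to int64 where the value range can overflow int32, e.g.
#   i = UOp.variable("i", 0, 2**40)
#   _index_to_concrete_int(i*2)   # the multiply is computed in dtypes.long since its vmax exceeds int32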
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
_remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
def do_unbind(ctx:dict[Variable, int], x:UOp):
v,i = x.unbind()
ctx[v] = i
return v
pm_unbind = PatternMatcher([(UPat(Ops.BIND, name="x"), do_unbind)])
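# NOTE: illustrative sketch: unbinding with pm_unbind strips BINDs and collects their values into ctx
#   v = UOp.variable("i", 0, 10)
#   var_vals = {}
#   graph_rewrite(v.bind(3), pm_unbind, ctx=var_vals)   # -> v, and var_vals == {v: 3}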
# for debug
syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<", Ops.SHR: ">>",
Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"}
# comparison operators are not in here because they are chained in python, not left-associative
precedence = {Ops.MUL:1, Ops.IDIV:1, Ops.MOD:1, Ops.ADD:2, Ops.SUB:2, Ops.SHL:3, Ops.SHR:3, Ops.AND:4, Ops.XOR:5, Ops.OR:6}
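# strip_binary_parens drops the parens around a child expression when the child binds at least as
# tightly as its parent (strictly tighter on the right side, since these operators associate to the left)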
def strip_binary_parens(x:UOp, left:str, right:str, code_for_op) -> str:
if x.op not in precedence: return code_for_op(left, right)
return code_for_op(strip_parens(left) if precedence.get(x.src[0].op,99)<=precedence[x.op] else left, strip_parens(right) if
precedence.get(x.src[1].op,99)<precedence[x.op] else right)
renderer = PatternMatcher([
(UPat((Ops.DEFINE_VAR,), name="x"), lambda x: x.arg[0]),
(UPat((Ops.SPECIAL), name="x"), lambda x: x.arg),
(UPat(Ops.RANGE, name="x"), lambda x: f"r{range_str(x)}"),
(UPat((Ops.CONST, Ops.VCONST), name="x"), lambda x: str(x.arg)),
(UPat(Ops.UNROLL, name="x"), lambda ctx,x: f"UNROLL({ctx[x.src[0]]}, {x.arg})"),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({str(x.dtype)[7:]})({ctx[x.src[0]]})"),
(UPat(Ops.BIND, name="x"), lambda ctx,x: ctx[x.src[0]]),
(UPat(Ops.NEG, name="x"), lambda ctx,x: f"(-{ctx[x.src[0]]})"),
(UPat(Ops.RECIPROCAL, name="x"), lambda ctx,x: f"(1/{ctx[x.src[0]]})"),
(UPat(Ops.MAX, name="x"), lambda ctx,x: f"max({ctx[x.src[0]]}, {ctx[x.src[1]]})"),
(UPat(Ops.MULACC, name="x"), lambda ctx,x: f"({ctx[x.src[0]]}*{ctx[x.src[1]]}+{ctx[x.src[2]]})"),
(UPat(Ops.WHERE, name="x"), lambda ctx,x: f"({ctx[x.src[1]]} if {ctx[x.src[0]]} else {ctx[x.src[2]]})"),
(UPat(set(syms.keys()), name="x"), lambda ctx,x: strip_binary_parens(x, ctx[x.src[0]], ctx[x.src[1]], lambda a,b: f"({a}{syms[x.op]}{b})")),
(UPat((Ops.INDEX, Ops.BUFFERIZE), name="x"), lambda x, ctx: ''.join([f"[{strip_parens(ctx[y])}]" for y in x.src[1:]])),
(UPat(Ops.VECTORIZE, name="x"),
lambda ctx,x: f"{{{','.join([ctx[y] for y in x.src])}}}" if not all_same(x.src) else f"{{{ctx[x.src[0]]}, ...}}"),
(UPat(GroupOp.All, name="x"), lambda x: str(x)),
])
renderer_infer = PatternMatcher([
(UPat(Ops.MOD, name="x"), lambda ctx,x: f"cmod({ctx[x.src[0]]}, {ctx[x.src[1]]})"),
(UPat(Ops.IDIV, name="x"), lambda ctx,x: f"cdiv({ctx[x.src[0]]}, {ctx[x.src[1]]})"),
]) + renderer
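# NOTE: illustrative sketch: these patterns turn symbolic UOps into readable debug strings;
# an expression like UOp.variable("x", 0, 10)*4+2 renders as something like "(x*4+2)"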
# *** pyrender ***
def srcs(ctx, src): return f"({ctx[src[0]]},)" if len(src) == 1 else f"({', '.join([ctx[x] for x in src])})"
def render_marg(ctx,x:UOp):
if x.op is Ops.PERMUTE: return str(x.marg)
if x.op is Ops.FLIP: return str(tuple([i for i,x in enumerate(x.marg) if x]))
pieces = []
if x.op in {Ops.RESHAPE, Ops.EXPAND}:
pieces = [f"{ctx[a] if isinstance(a, UOp) else str(a)}" for a in x.marg]
if x.op in {Ops.PAD, Ops.SHRINK}:
pieces = [f"({ctx[a[0]] if isinstance(a[0], UOp) else str(a[0])}, {ctx[a[1]] if isinstance(a[1], UOp) else str(a[1])})" for a in x.marg]
return f"({','.join(pieces)})" if len(pieces) != 1 else f"({pieces[0]},)"
sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER, Ops.THREEFRY,
Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER, Ops.ASSIGN, Ops.DETACH}
pm_pyrender_extra = PatternMatcher([
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"), UPat(Ops.UNIQUE, name="u")), name="x"),
lambda x,d,u: f"UOp.const({x.dtype}, {x.arg}, device={repr(d.arg)}, unique={u.arg})"),
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"),), name="x"), lambda x,d: f"UOp.const({x.dtype}, {x.arg}, device={repr(d.arg)})"),
(UPat(Ops.CONST, name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
(UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x:
f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.index else ''})"),
(UPat((Ops.CAST, Ops.BITCAST), name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({x.dtype})"),
(UPat(Ops.SPECIAL, src=(UPat(Ops.CONST),), name="x"), lambda x: f"UOp.special({x.src[0].arg}, {repr(x.arg)}, dtype={x.dtype})"),
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"), lambda x,u,d:
f"UOp.new_buffer({repr(d.arg)}, {x.size}, {x.dtype}, {u.arg})"),
(UPat(Ops.COPY, src=(UPat(name="x"), UPat(Ops.DEVICE, name="d"))), lambda ctx,x,d: f"{ctx[x]}.copy_to_device({repr(d.arg)})"),
(UPat(Ops.REDUCE_AXIS, name="r"), lambda ctx,r: f"{ctx[r.src[0]]}.r({r.arg[0]}, {r.arg[1]})"),
# NOTE: range has srcs sometimes after control flow
(UPat(Ops.RANGE, src=(UPat(Ops.CONST, name="c"),), allow_any_len=True, name="x"), lambda ctx,x,c:
"UOp.range("+', '.join([str(c.arg)] + [str(y) for y in x.arg])+
(f', src={srcs(ctx, x.src[1:])}' if len(x.src) > 1 else '')+(', dtype='+str(x.dtype) if x.dtype is not dtypes.index else '')+")"),
# TODO: index shouldn't mismatch dtype
(UPat(Ops.INDEX, src=(UPat(), UPat()), allow_any_len=True, name="x"), lambda ctx,x:
f"{ctx[x.src[0]]}.index({ctx[x.src[1]]}, "+(f"{ctx[x.src[2]]}, " if len(x.src) > 2 else "")+
(f"dtype={x.dtype})" if x.src[0].dtype != x.dtype else "ptr=True)") if x.src[0].dtype.base != x.dtype else None),
# TODO: fix forced_reshape
(UPat(Ops.RESHAPE, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.forced_reshape({render_marg(ctx,x)})" if x.src[0].shape == x.shape else None),
(UPat(GroupOp.Movement, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({render_marg(ctx,x)})"),
# NOTE: CMPNE doesn't work cause there's no __rne__
(UPat(set(syms.keys())-{Ops.SUB, Ops.CMPNE}, src=(UPat(Ops.CONST, name="y"), UPat(name="z")), name="x"),
lambda ctx,x,y,z: strip_binary_parens(x, str(y.arg), ctx[z], lambda a,b: f"({a}{syms[x.op]}{b})")),
# NOTE: sub doesn't work cause it's written as add/mul
(UPat(set(syms.keys())-{Ops.SUB}, src=(UPat(name="y"), UPat(Ops.CONST, name="z")), name="x"), lambda ctx,x,y,z:
strip_binary_parens(x, ctx[y], str(z.arg), lambda a,b: f"({a}{syms[x.op]}{b})")),
(UPat(set(syms.keys())-{Ops.SUB}, name="x"), lambda ctx,x:
strip_binary_parens(x, ctx[x.src[0]], ctx[x.src[1]], lambda a,b: f"({a}{syms[x.op]}{b})")),
(UPat(sugar, src=(), name="x"), lambda x: f"UOp.{x.op.name.lower()}("+', '.join(([f'arg={repr(x.arg)}'] if x.arg is not None else []))+")"),
(UPat(sugar, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}("+', '.join([ctx[y] for y in x.src[1:]] + \
([f'arg={repr(x.arg)}'] if x.arg is not None else []))+")"),
])
# NOTE: you can remove pm_pyrender_extra and it'll still be correct
pm_pyrender = pm_pyrender_extra+PatternMatcher([
(UPat(Ops.KERNEL, name="u"), lambda ctx,u: f"UOp(Ops.KERNEL, src={srcs(ctx,u.src)}, arg=Kernel({ctx[u.arg.ast]}(), {u.arg.metadata}))"),
(UPat(GroupOp.All, name="u"), lambda ctx,u: f"UOp({u.op}, {u.dtype}, {srcs(ctx,u.src)}"+(f", {repr(u.arg)})" if u.arg is not None else ")")),
])
def pyrender(ast:UOp) -> str:
lst = list(ast.toposort())
cmap = consumer_map_from_toposort(lst)
not_rendered = {Ops.CONST, Ops.VCONST, Ops.DEVICE}
always_rendered = {Ops.DEFINE_GLOBAL, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.VECTORIZE,
Ops.BUFFER, Ops.COPY, Ops.KERNEL, Ops.WHERE, Ops.END, Ops.ASSIGN}
to_render: set[UOp] = {ast}
for u in lst:
if u.op in {Ops.SINK}:
for s in u.src: to_render.add(s)
if u.op is Ops.STORE: to_render.add(u.src[1])
if u.op in {Ops.REDUCE, Ops.REDUCE_AXIS}: to_render.add(u.src[0])
if u.op in not_rendered: continue
# checking the consumers is not enough, you have to make sure it's not used twice by the one consumer
if len(cmap[u]) == 1 and len([x for x in list(cmap[u].keys())[0].src if x is u]) == 1 and u.op not in always_rendered: continue
to_render.add(u)
kernels: dict[UOp, tuple[str, str]] = {}
r: dict[UOp, str] = {}
ret: dict[str, str] = {}
for i,u in enumerate(lst):
if u.op is Ops.KERNEL:
if u.arg.ast not in kernels:
kernels[u.arg.ast] = (f"k{len(kernels)}", f"def k{len(kernels)}():\n " + pyrender(u.arg.ast).replace('\n', '\n ') + "\n return ast\n\n")
r[u.arg.ast] = kernels[u.arg.ast][0]
ren = cast(str, pm_pyrender.rewrite(u, ctx=r))
assert isinstance(ren, str)
if u.tag is not None: ren += f".rtag({repr(u.tag)})"
if u not in to_render: r[u] = ren
else:
r[u] = f"c{i}" if u is not lst[-1] else "ast"
ret[r[u]] = ren
return ''.join([v[1] for v in kernels.values()]) + '\n'.join([f"{k} = {strip_parens(v)}" for k,v in ret.items()])
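# NOTE: illustrative sketch: pyrender emits Python source that reconstructs the graph when exec'd,
# with shared subexpressions bound to c<i> names and the final UOp bound to `ast`, e.g.
#   print(pyrender((UOp.variable("x", 0, 10)+1).sink()))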
# *** what was symbolic.py ***
sint = int|UOp
Variable = UOp
ConstLike = ConstType|InvalidType|Variable|tuple[ConstType|InvalidType, ...]