from dataclasses import dataclass
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve
from tinygrad.uop.ops import track_rewrites, _substitute
from tinygrad.uop.spec import type_verify, tensor_uop_spec
from tinygrad.uop.symbolic import symbolic_simple
from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
from tinygrad.dtype import ImageDType
from tinygrad.schedule.multi import multi_pm
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
from tinygrad.codegen.opt.swizzler import merge_views, apply_swizzle, swizzle_reduceop

# creation can recurse a lot
import sys
sys.setrecursionlimit(10000)

# **** schedule simplifier

def simplify_stride0_reduce(reduce:UOp, x:UOp):
  # must be unmasked (NOTE: can be relaxed if not masked on stride 0 axis)
  if any(v.mask is not None for v in unwrap(x.st).views): return None
  # must have all stride 0 in the relevant axis (NOTE: can do partial)
  if not all(unwrap(x.st).views[-1].strides[axis] == 0 for axis in reduce.arg[1]) or not all_int(x.shape): return None
  prshape = prod(x.shape[i] for i in reduce.arg[1])
  ret = x.shrink(tuple((0,s) if i not in reduce.arg[1] else (0,1) for i,s in enumerate(x.shape)))
  match reduce.arg[0]:
    case Ops.ADD: return ret*prshape
    case Ops.MUL: return ret.pow(prshape)
    case Ops.MAX: return ret  # NOTE: Ops.MAX is passthrough

def split_reduceop(reduce:UOp, x:UOp):
  if not SPLIT_REDUCEOP or not all_int(x.shape) or (prod(x.shape)//prod(reduce.shape)) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return None
  # if there are few globals, make some reduces into globals by splitting into two kernels
  # pick a split axis: it must divide evenly and must not be stride 0 (broadcast); the divisor is capped so the
  # intermediate buffer stays small (REDUCEOP_SPLIT_SIZE and the 256/8 bounds are heuristic knobs)
  real_strides = unwrap(x.st).real_strides(ignore_valid=True)
  if not (split_candidates:=[(i,d) for i in reduce.arg[1] for d in range(min(256, 2**getenv("REDUCEOP_SPLIT_SIZE", 22)//prod(reduce.shape)), 8-1, -1)
                             if x.shape[i]%d == 0 and real_strides[i] != 0]): return None
  dim_to_split, divisor = split_candidates[0]
  # move the split axis to the end to provide maximum locality for the second phase reduce
  splitted_shape = x.shape[:dim_to_split]+(divisor,)+(x.shape[dim_to_split]//divisor,)+x.shape[dim_to_split+1:]
  splitted = x.reshape(splitted_shape).permute(tuple([d for d in range(len(splitted_shape)) if d != dim_to_split]+[dim_to_split]))
  if DEBUG >= 3: print(f"split {divisor}: {x.shape} -> {splitted.shape} -> {reduce.shape}")
  # reduce original axes, then split
  return splitted.r(*reduce.arg).r(reduce.arg[0], (len(reduce.shape),)).reshape(reduce.shape)

def copy_reorder_view(copy:UOp, view:UOp, base:UOp):
  if prod(view.shape) < prod(base.shape): return view.contiguous().copy_to_device(copy.device)
  return base.copy_to_device(copy.device).view(view.arg)
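
# A worked example of simplify_stride0_reduce above (illustrative comment only, not executed here):
# a reduce over a broadcast (stride 0) axis never reads distinct data, so it collapses to a shrink
# plus a scalar op. e.g. a (1,4) row expanded to (3,4) and summed over axis 0:
#   x.shape == (3, 4), strides == (0, 1), reduce.arg == (Ops.ADD, (0,))
#   prshape = prod(x.shape[i] for i in (0,)) == 3
#   ret     = x.shrink(((0,1), (0,4)))   # shape (1, 4), no data movement
#   result  = ret * 3                    # summing 3 identical rows == row * 3
# MUL becomes ret.pow(3) and MAX passes ret through, matching the match statement above.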

sym = symbolic_simple+PatternMatcher([
  # UOp with size 0 is zero
  (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 else None),
  # DETACH and CONTIGUOUS_BACKWARD are NOOPs here
  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
  # reduce of size 0 is the identity element
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
   lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
  # reduce on stride 0 is collapsed
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_stride0_reduce),
  # split_reduceop
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
  # COPY(CONST) creates a new CONST on the destination device
  (UPat(Ops.COPY, name="root", src=(UPat.cvar("x"), UPat(Ops.DEVICE))), lambda root,x: root.const_like(x.arg)),
  # non device changing COPY is a NOOP
  (UPat(Ops.COPY, name="c", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda c,x: x if c.device == x.device else None),
  # store a shrink before COPY, otherwise view after the COPY
  (UPat(Ops.COPY, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"), UPat(Ops.DEVICE)), name="copy"), copy_reorder_view),
  # remove cast to image when it's already a contiguous image
  (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)),
   lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
  # CAST before masking constants
  (UPat.cvar("x").view().cast(name="c"), lambda x,c: x.cast(c.dtype).view(c.src[0].arg)),
  # make things that can't be images not images
  (UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: u.replace(dtype=dt.base) if isinstance(dt:=u.dtype,ImageDType)
   and (prod(u.shape) != prod(dt.shape) or not any(u.shape[x]%4 == 0 for x in u.st.unit_stride_axes())) else None),
  # remove contiguous if we can just view the buffer
  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
   lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
  # contiguous/buffer/copy/assign is already contiguous
  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY, Ops.ASSIGN)),)), lambda root: root.src[0]),
  # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
  (UPat((Ops.BITCAST, Ops.CONTIGUOUS), src=(UPat.var("x"),), name="t"),
   lambda x,t: UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (t.size, x.st.views[0].offset)).reshape(t.shape)
   if isinstance(x.device, str) and x.device.startswith("DISK") else None),
  # double ASSIGN to same target is one ASSIGN
  (UPat(Ops.ASSIGN, src=(UPat.var("t"), UPat(Ops.ASSIGN, src=(UPat.var("t"), UPat.var("x"))))), lambda x,t: t.assign(x.contiguous())),
  # ASSIGN to unrealized replaces the UOp
  (UPat(Ops.ASSIGN, src=(UPat.var("t"), UPat.var("x"))), lambda x,t: x.contiguous() if t.base.op not in {Ops.BUFFER, Ops.BUFFER_VIEW}
   and not (t.base.op is Ops.MSTACK and all(x.op is Ops.BUFFER for x in t.base.src)) else None),
  # put CAST to smaller dtype before EXPAND
  (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm"),)), lambda cast,vm: vm.base.cast(cast.dtype).view(vm.st)
   if cast.dtype.itemsize <= vm.dtype.itemsize and resolve(prod(vm.shape) > vm.st.real_size()) else None),
  # put UnaryOps before EXPANDs, if it can fuse with the input
  (UPat(GroupOp.Unary, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="inp"),), name="v"),), name="alu"),
   lambda inp,v,alu: inp.alu(alu.op).view(v.st) if resolve(prod(alu.shape) > v.st.real_size()) else None),
])

# support for using a contiguous permuted view instead of the parent view if one exists
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
  if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)

replace_contiguous = PatternMatcher([
  (UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, name="src"),), name="contig"), found_contiguous),
  (UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
])
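
# Illustrative sketch of replace_contiguous (comment only; the name inverse_st below is illustrative):
# when CONTIGUOUS is taken of an invertible view (e.g. a pure permute), other ALU consumers of the same
# base can read the realized contiguous buffer through the inverse view instead of recomputing the base:
#   base   : shape (2, 3)
#   contig = base.permute((1, 0)).contiguous()   # realized as a (3, 2) buffer
#   ctx[base] = contig.view(inverse_st)          # inverse_st from ShapeTracker.invert, same values
# the second rule then rewrites any ALU src that still points at base to use ctx[base].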

# **** create kernels

@dataclass(frozen=True)
class Kernel:
  ast: UOp
  metadata: tuple[Metadata, ...] = ()
  def __repr__(self):
    ast_rep = f"SINK{tuple(s.op for s in self.ast.src)}" if self.ast.op is Ops.SINK else repr(self.ast.op)
    return f"<Kernel {len(list(self.ast.toposort()))} {ast_rep} {self.metadata}>"

def create_kernel(x:UOp, b:UOp|None=None):
  if b is None: b = UOp.new_buffer(x.device, x.size, x.dtype)
  kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), m if (m:=x.metadata) else ()))
  buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
  return buffer.assign(kernel).reshape(x.shape)

DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.MULTI, Ops.BIND}
def append_to_kernel(x:UOp):
  new_srcs: list[UOp] = []
  metadata = x.arg.metadata
  for s in x.src:
    if s.op in DONT_PLACE_IN_KERNEL: new_srcs.append(s)
    else:
      new_srcs.extend(s.src)
      # NOTE: because const and device are shared UOps they don't change metadata
      # NOTE: if it's a reshape after ASSIGN we're not fusing that parent kernel
      if s.base.op not in {Ops.CONST, Ops.DEVICE} and (not (s.op is Ops.RESHAPE and s.base.op is Ops.ASSIGN)) and (m:=s.metadata): metadata += m
  if (new_src:=tuple(dedup(new_srcs))) != x.src: return x.replace(src=new_src, arg=Kernel(x.arg.ast, tuple(dedup(metadata))))

create_kernels = PatternMatcher([
  # always give assign/contiguous a kernel
  (UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel),
  (UPat(Ops.CONTIGUOUS, name="x"), create_kernel),
  # walk back the local graph until we reach a realized source
  (UPat(Ops.KERNEL, name="x"), append_to_kernel),
  # push RESHAPE through MSELECT
  (UPat(Ops.MSELECT, src=(UPat(Ops.RESHAPE, name="r"),), name="ms"), lambda ms,r: r.src[0].mselect(ms.arg).reshape(r.arg)),
  # push RESHAPE through MSTACK
  (UPat(Ops.MSTACK, src=UPat(Ops.RESHAPE), name="ms"), lambda ms: UOp(Ops.MSTACK, ms.dtype, tuple(x.src[0] for x in ms.src)).reshape(ms.src[0].arg)),
])
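
# Rough shape of the graph after create_kernels (an illustrative sketch, not executed here):
# every realized output becomes an ASSIGN of a KERNEL into a BUFFER, with the original shape restored by a RESHAPE:
#   before:  CONTIGUOUS(<elementwise tree>)
#   after:   RESHAPE(ASSIGN(BUFFER, KERNEL(BUFFER, <srcs of the tree>)))   # Kernel.ast holds the SINK of the tree
# append_to_kernel then repeatedly folds KERNEL srcs into the kernel until only realized ops
# (DONT_PLACE_IN_KERNEL) remain as its inputs.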

# **** fix kernel AST

def unbind_view(x:UOp):
  if any(x.op is Ops.BIND for x in x.arg.vars()): return x.replace(arg=x.arg.unbind()[0])
  return None

replace_buffers = PatternMatcher([
  # replace ASSIGN with the target BUFFER
  (UPat(Ops.ASSIGN, src=(UPat((Ops.BUFFER, Ops.LOAD)), UPat(Ops.KERNEL)), name="assign", allow_any_len=True), lambda assign: assign.src[0]),
  # HACK: select the 0 branch of MSTACK (the device is wrong after this, is that okay?)
  (UPat(Ops.MSTACK, name="x"), lambda x: x.src[0]),
  # LOAD
  (UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).load()),
  # no SINK for meta ops
  (UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
  # STORE (except for meta ops)
  (UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda ctx,sink:
   UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i).view(s.st), s) for i,x in enumerate(sink.src)], arg=sink.arg)),
  # remove CONTIGUOUS/DEVICE from kernel AST
  (UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
  (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
  # passthrough ASSIGN (but let MSTACK process first)
  (UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.MSTACK}), UPat()), name="x"), lambda x: x.src[1]),
  # remove any BINDs from VIEWS
  (UPat(Ops.VIEW, src=(UPat(), UPat((Ops.BIND, Ops.DEFINE_VAR))), allow_any_len=True, name="x"), lambda x: x.replace(src=x.src[0:1])),
  # remove any BINDs from DEFINE_VARs
  (UPat(Ops.BIND, name="x"), lambda x: x.src[0]),
  # remove BINDs from ShapeTrackers
  (UPat(Ops.VIEW, name="x"), unbind_view),
])

def fix_kernel_ast(k:UOp) -> UOp|None:
  if k.arg.ast.op in GroupOp.Meta or all(s.op is Ops.STORE for s in k.arg.ast.src): return None
  # replace buffer with define_global + add load/store last
  bufs = []
  for s in k.src:
    if s.op is Ops.BIND: continue
    s = s.buf_uop
    # traverse back through MSELECT and MSTACK. HACK: 0 branch of MSTACK only
    while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
    bufs.append(s)
  # replace global memory ops with the BUFFER they write to
  # NOTE: merge_views is needed to unbind the reshapes
  ast = graph_rewrite(k.arg.ast, merge_views+replace_buffers, bufs, bottom_up=True, name="replace buffers")
  if ast.op is Ops.SINK and not all_same([x.device for x in k.src if x.op is not Ops.BIND]):
    raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
  return k.replace(arg=Kernel(ast, k.arg.metadata))

create_ast = PatternMatcher([
  (UPat(Ops.KERNEL, name="k"), fix_kernel_ast),
  (UPat(Ops.DEFINE_VAR, src=(UPat(),), allow_any_len=True, name="x"), lambda x: x.replace(src=())),
])

# ** add metadata of KERNEL outputs

def append_metadata(root:UOp, k:UOp):
  if not root.metadata or (new_metadata:=tuple(dedup(k.arg.metadata+root.metadata))) == k.arg.metadata: return None
  return root.replace(src=(root.src[0], k.replace(arg=Kernel(k.arg.ast, new_metadata)))+root.src[2:])

replace_metadata = PatternMatcher([(UPat(Ops.ASSIGN, src=(UPat(), UPat(Ops.KERNEL, name="k")), name="root", allow_any_len=True), append_metadata),])
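
# Note on FUSE (an illustrative summary of pm_fuse/do_fusion below, not a separate mechanism):
# u.fuse() wraps a UOp in Ops.FUSE, a marker that the wrapped subgraph should be computed inside its
# consumer's kernel instead of being realized to memory first. pm_fuse pushes the marker toward the
# sources: it swizzles views through REDUCE_AXIS and elementwise ops, tags fused reduces for the
# grouper, and inserts a CONTIGUOUS under unsafe pads where fusion would change the result.
#   e.g. with FUSE_ARANGE=1, fuse_arange marks the REDUCE over a CONST that implements arange, so the
#   index computation is regenerated inside the consuming kernel rather than stored to a buffer.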

pm_fuse = PatternMatcher([
  # FUSE on CONTIGUOUS removes FUSE
  (UPat(Ops.CONTIGUOUS, name="c").fuse(), lambda c: c),
  # FUSE triggers swizzle on reduceop
  (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").or_casted(),), name="view").fuse(),
   lambda r,src,view: ret.cast(view.dtype) if (ret:=swizzle_reduceop(r, src, view, fuse=True)) is not None else None),
  # FUSE on reduce (without view) adds fuse marker to grouper
  (UPat(Ops.REDUCE_AXIS, name="r").fuse(), lambda r: r.replace(src=(r.src[0].fuse(),), arg=r.arg+(True,)) if len(r.arg) == 2 else None),
  # remove FUSE and insert CONTIGUOUS if it's an unsafe pad
  (UPat(Ops.VIEW, src=(UPat(GroupOp.UnsafePad, name="alu"),), name="view").fuse(),
   lambda alu, view: alu.contiguous().view(view.st) if any(v.mask is not None for v in view.st.views) else None),
  # FUSE elementwise.
  (UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST}, name="alu"),), name="view").fuse(),
   lambda alu, view: alu.replace(src=tuple(apply_swizzle(x.view(view.arg)).fuse() for x in alu.src))),
  # push FUSE through to srcs
  (UPat(Ops.FUSE, name="x"), lambda x: x.src[0].replace(src=tuple(y.fuse() for y in x.src[0].src))),
])

def do_fusion(x:UOp):
  found_contiguous = {}
  def gate_contiguous(x):
    if is_contiguous:=(x.op is Ops.CONTIGUOUS): found_contiguous[x] = x.replace(src=(UOp(Ops.VIEW, arg=x.st), UOp.unique()))
    return not is_contiguous
  x.toposort(gate=gate_contiguous)
  del gate_contiguous
  return graph_rewrite(x.substitute(found_contiguous), pm_fuse, name="local fusion").substitute({v:k for k,v in found_contiguous.items()})

def fuse_arange(root:UOp):
  # skip if root is arange
  if not FUSE_ARANGE or root.src[0].base.op is Ops.CONST: return None
  # gather all local aranges (including any fused ones)
  local_arange: list[UOp] = []
  def gate_reduce(u):
    if u.op is Ops.REDUCE_AXIS and u.src[0].base.op is Ops.CONST: local_arange.append(u)
    return u.op not in {*ALWAYS_CONTIGUOUS, Ops.REDUCE_AXIS} or u is root
  toposort = root.toposort(gate=gate_reduce)
  if not local_arange: return None
  # fuse the nearest expand child of arange
  local_children: dict[UOp, list[UOp]] = {}
  for u in toposort:
    for s in u.src: local_children.setdefault(s, []).append(u)
  fuse_rep: dict[UOp, UOp] = {}
  for r in local_arange:
    # skip if already fused
    if len(r.arg) > 2: continue
    q = list(local_children[r])
    while q:
      u = q.pop()
      if not (curr_children:=local_children.get(u, [])): continue
      for child in curr_children:
        other_paths = {s for s in child.toposort() if s.op in {Ops.REDUCE_AXIS, Ops.BUFFER} and s not in {root, r}}
        fuse_rep[child] = child.replace(src=tuple(s.fuse() if s is u else s for s in child.src))
        if other_paths: break
      else: q.extend(curr_children)
  return root.substitute(fuse_rep, name="fuse_arange") if fuse_rep else None

do_fuse = PatternMatcher([
  (UPat(Ops.FUSE, name="x"), do_fusion),
  (UPat(Ops.REDUCE_AXIS, name="root"), fuse_arange),
])

add_contiguous = PatternMatcher([(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"),
                                  lambda ctx,x: x.replace(tag=1).contiguous() if x in ctx and x.tag is None else None)])

# TODO: get this from the device through GrouperOpts
DEVICE_MAX_BUFS = {"METAL":32, "WEBGPU":8}

def limit_bufs(root:UOp):
  # check if backend has a buffer limit
  device = root.device if isinstance(root.device, str) else root.device[0].split(":")[0]
  if not (MAX_BUFS:=getenv("MAX_KERNEL_BUFFERS", DEVICE_MAX_BUFS.get(device, 0))): return None
  # count number of unique buffers flowing into this op
  bufs: set[UOp] = set()
  def gate_input(u:UOp):
    if (is_load:=(u.op in {Ops.BUFFER, Ops.CONTIGUOUS, Ops.ASSIGN, Ops.MSTACK})): bufs.add(u)
    return not is_load
  root.toposort(gate=gate_input)
  # NOTE: this -1 is for the output buffer
  if len(bufs)>=MAX_BUFS-1:
    return root.replace(src=tuple(s if s.base in bufs else s.replace(tag=1).contiguous() for s in root.src))

def view_add_srcs(x:UOp):
  if len(avars:=x.arg.vars()) and len(x.src) == 1: return x.replace(src=x.src+tuple(avars))
  return None

finalize_contiguous = PatternMatcher([
  # if an op takes more than one input, check combined LOADs don't exceed device limits
  (UPat(set.union(GroupOp.Binary, GroupOp.Ternary), name="root"), limit_bufs),
  # merge contiguous
  (UPat(Ops.CONTIGUOUS, src=(UPat(Ops.CONTIGUOUS),), name="x"), lambda x: x.src[0]),
  # simplify views
  (UPat(Ops.VIEW, src=(UPat.var('x')), name="v"), lambda x,v: x.view(new_st) if (new_st:=v.arg.simplify()) != v.arg else None),
  # vars to views srcs
  (UPat(Ops.VIEW, name="x"), view_add_srcs),
])
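
# Illustrative example of limit_bufs above (comment only): METAL caps a kernel at 32 bound buffers and
# WEBGPU at 8, overridable with MAX_KERNEL_BUFFERS=n. if a binary/ternary op already sees MAX_BUFS-1
# unique input buffers (the -1 reserves the output buffer), every src whose base is not one of those
# buffers gets a CONTIGUOUS inserted, so it is realized first and the work splits into smaller kernels
# that stay under the device limit.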

remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])

@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}")
def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
  """
  Function to transform the Tensor UOp graph into a version with Ops.KERNEL

  Args:
    sink: The Ops.SINK rooting the Tensor graph.

  Returns:
    Map transforming each UOp in the sink to the Ops.KERNEL graph.
  """

  # multi + merge_views + simplify
  tensor_map = graph_rewrite_map(sink, multi_pm+do_fuse+merge_views+sym+replace_contiguous, ctx={}, name="merge_views")

  # display the cleaned up tensor graph
  if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Tensor Graph")

  # insert contiguous in places determined by the realize map
  realize_map = group_realizes(tensor_map[sink])
  tensor_map = graph_rewrite_map(tensor_map[sink], add_contiguous, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="add_contiguous")
  tensor_map = graph_rewrite_map(tensor_map[sink], finalize_contiguous+remove_tags, input_map=tensor_map, name="finalize_contiguous")

  # group into kernels (this is context-free)
  tensor_map = graph_rewrite_map(tensor_map[sink], create_kernels, input_map=tensor_map, name="create_kernels")

  # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
  kernel_assign: dict[UOp, UOp] = {}
  assign_rep: dict[UOp, UOp] = {}
  for u in tensor_map[sink].toposort():
    if u.op is not Ops.ASSIGN: continue
    kernel_assign[u.buf_uop] = u
    for s in u.src[1].src:
      # TODO: this is probably broken for MSELECT/MSTACK
      if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
      if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort()):
        raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
      assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
  if assign_rep:
    tensor_map = graph_rewrite_map(tensor_map[sink], _substitute, ctx=assign_rep, bottom_up=True, input_map=tensor_map, name="fix_assign")

  # finally, create the AST for kernels
  tensor_map = graph_rewrite_map(tensor_map[sink], create_ast+replace_metadata, bottom_up=True, input_map=tensor_map, name="create_ast")

  # display the final graph
  sched_sink = tensor_map[sink]
  if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")

  # verify Kernels match the spec
  if __debug__: type_verify(list(sched_sink.toposort()), tensor_uop_spec)
  return tensor_map
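
# Example usage (a sketch under assumptions, commented out on purpose): the Tensor layer builds a SINK
# over the UOps it wants realized and calls get_kernelize_map; the returned map rewrites each realized
# tensor UOp into BUFFER.assign(KERNEL(...)). this assumes the Tensor wrapper exposes its UOp as `.uop`:
#   from tinygrad import Tensor
#   t = (Tensor.ones(4, 4).contiguous() + 1).sum(axis=0)
#   big_sink = UOp.sink(t.uop)
#   becomes_map = get_kernelize_map(big_sink)
#   kernels = [u for u in becomes_map[big_sink].toposort() if u.op is Ops.KERNEL]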