import sys, atexit, pickle from collections import defaultdict, deque from dataclasses import dataclass from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, buffers from tinygrad.ops import can_pad, identity_element, resolve, view_left, merge_views from tinygrad.codegen.symbolic import symbolic_simple from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, unwrap, flatten, getenv, pluralize from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP from tinygrad.dtype import ImageDType from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View, strides_for_shape from tinygrad.device import Buffer from tinygrad.spec import type_verify, kernel_spec # creation can recurse a lot sys.setrecursionlimit(10000) # **** schedule simplifier def simplify_stride0_reduce(reduce:UOp, x:UOp): # must be unmasked (NOTE: can be relaxed if not masked on stride 0 axis) if any(v.mask is not None for v in unwrap(x.st).views): return None # must have all stride 0 in the relevant axis (NOTE: can do partial) if not all(unwrap(x.st).views[-1].strides[axis] == 0 for axis in reduce.arg[1]) or not all_int(x.shape): return None prshape = prod(x.shape[i] for i in reduce.arg[1]) ret = x.shrink(tuple((0,s) if i not in reduce.arg[1] else (0,1) for i,s in enumerate(x.shape))) match reduce.arg[0]: case Ops.ADD: return ret*prshape case Ops.MUL: return ret.pow(prshape) case Ops.MAX: return ret # NOTE: Ops.MAX is passthrough def split_reduceop(reduce:UOp, x:UOp): if not SPLIT_REDUCEOP or not all_int(x.shape) or (prod(x.shape)//prod(reduce.shape))= 3: print(f"split {divisor}: {x.shape} -> {splitted.shape} -> {reduce.shape}") # reduce original axes, then split return splitted.r(*reduce.arg).r(reduce.arg[0], (len(reduce.shape),)).reshape(reduce.shape) sym = 
symbolic_simple+PatternMatcher([ # UOp with size 0 is zero (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \ and not (root.base.op is Ops.CONST and root.base.arg == 0) else None), # DETACH and CONTIGUOUS_BACKWARD are NOOPs here (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]), # reduce of size 0 is the identity element (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None), # reduce on stride 0 is collapsed (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_stride0_reduce), # split_reduceop (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop), # COPY(CONST) creates a new CONST on the destination device (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.arg)), # no COPY to same device, except clone (arg is True) (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"), lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None), # remove cast to image when it's already a contiguous image (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"))),)), lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None), # make things that can't be images not images (UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: u.replace(dtype=dt.base) if isinstance(dt:=u.dtype,ImageDType) and (prod(u.shape) != prod(dt.shape) or not any(u.shape[x]%4 == 0 for x in u.st.unit_stride_axes())) else None), # remove contiguous if we can just view the buffer (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)), lambda root,view,buf: view if 
view.st.contiguous and view.size == buf.size else None), # contiguous/buffer/copy is already contiguous (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY)),)), lambda root: root.src[0]), # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK (UPat((Ops.BITCAST, Ops.CONTIGUOUS), src=(UPat.var("x"),), name="t"), lambda x,t: UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (t.size, x.st.views[0].offset)).reshape(t.shape) if x.device.startswith("DISK") else None), # remove CONST/BIND/VIEW from SINK (UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src) if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST,Ops.BIND}))) != x.src else None), ]) # support for using a contiguous permuted view instead of the parent view if one exists def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp): if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti) replace_contiguous = PatternMatcher([ (UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, name="src"),), name="contig"), found_contiguous), (UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None), ]) # reorder view reorder_view = PatternMatcher([ # put CAST to smaller dtype before EXPAND (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm"),)), lambda cast,vm: vm.base.cast(cast.dtype).view(vm.st) if (not getenv("CAST_AFTER_EXPAND") or vm.base.op is not Ops.BUFFER) and cast.dtype.itemsize <= vm.dtype.itemsize and resolve(prod(vm.shape) > vm.st.real_size()) else None), # store a shrink before COPY, otherwise view after the COPY (UPat(Ops.COPY, src=(UPat(), UPat(Ops.VIEW, name="v")), name="copy"), lambda copy,v: v.contiguous().copy_to_device(copy.device) \ if prod(v.shape) < prod(v.base.shape) else v.base.copy_to_device(copy.device, clone=copy.arg).view(v.st)), # put UnaryOps before EXPANDs (UPat(GroupOp.Unary, src=UPat(Ops.VIEW, src=(UPat.var("inp"),), 
name="v"), name="alu"), lambda inp,v,alu: inp.alu(alu.op).view(v.st) if resolve(prod(alu.shape) > v.st.real_size()) else None), # put CAST after expanding BUFFER (UPat(Ops.VIEW, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="v"), lambda x,v: x.view(x.st+v.st).cast(v.dtype) if getenv("CAST_AFTER_EXPAND") and x.base.op is Ops.BUFFER and resolve(prod(v.shape) > prod(x.shape)) else None), ]) # **** UOp realization DONT_PUSH_VIEWS = {Ops.BUFFER, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.ASSIGN, Ops.SINK, Ops.CONTIGUOUS, Ops.COPY} @dataclass(frozen=True) class GrouperContext: assigns: dict[UOp, UOp] # maps realized buffers to assigns realizes: dict[UOp, None] # all the simplified tensor uops we realize children: defaultdict[UOp, dict[UOp, None]] # children graph of tensor uops def realize(ctx:GrouperContext, tr:UOp) -> None: ctx.realizes[tr] = None def realize_before_view(ctx:GrouperContext, view:UOp, tr:UOp) -> None: st = unwrap(view.st) # awlays realize unsafe pad ops before masked view if any(v.mask is not None for v in st.views) and not can_pad(tr, ctx.realizes, cache=dict()): return realize(ctx, tr) # fold simple pads if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(tr.shape) and resolve(prod(tr.shape) >= prod([y-x for x,y in m])): return # realize before expand if resolve(prod(tr.shape) < prod(st.shape)) and not DONT_REALIZE_EXPAND: return realize(ctx, tr) do_realize = PatternMatcher([ # always realize SINK parents (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.realizes.update((x, None) for x in s.src if x.op not in DONT_PUSH_VIEWS)), # always realize ASSIGN/CONTIGUOUS/GroupOp.Meta (UPat({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}, name="tr"), realize), # realize before expand or unsafe pad ops (UPat(Ops.VIEW, name="view", src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="tr"),)), realize_before_view), # realize before COPY (UPat(Ops.COPY, src=(UPat(), UPat(GroupOp.All-DONT_PUSH_VIEWS, name="tr"))), realize), ]) def append_uop(ctx:GrouperContext, 
u:UOp) -> None: if u.op is Ops.ASSIGN: ctx.assigns[u.buf_uop] = u for s in u.src: ctx.children[s.base][u] = None create_ctx = PatternMatcher([(UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}, name="u"), append_uop)]) def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], realizes:dict[UOp, None], reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None: """recursively search the uop for groupable children, realize the UOp if a child can't group""" if (tr, st) in cache: return cache.setdefault((tr, st)) rsize = unwrap(r.st).size if tr in realizes and tr is not r: # can only fuse contiguous # max one reduceop per kernel if not st.contiguous or st.size != rsize or tr in reduce_for_op: group.setdefault(r) return group.setdefault(tr) for tr_next in children[tr]: # max one reduceop per kernel if tr_next.op is Ops.REDUCE_AXIS: return group.setdefault(r) # can only fuse contiguous if len(st_childs:=dedup(unwrap(x.st) for x in tr_next.src if x.base == tr)) > 1: return group.setdefault(r) recursive_group(tr_next, st+st_childs[0], r, children, realizes, reduce_for_op, group, cache) def group_realizes(sink:UOp) -> dict[UOp, None]: # start by adding uops that always realize sink = graph_rewrite(sink, do_realize+create_ctx, ctx:=GrouperContext({}, {}, defaultdict(dict))) if DONT_GROUP_REDUCES: return ctx.realizes # find all reduces, and pair them to a elementwise op. 
if they can't be cleanly paired, force realize the reduce (or a contig child) reduce_for_op: dict[UOp, UOp] = {} double_reduces: list[UOp] = [] for r in sink.toposort: if r.op is not Ops.REDUCE_AXIS: continue if FUSE_CONV_BW and r.src[0].base.op is Ops.REDUCE_AXIS and r.src[0] is not r.src[0].base: double_reduces.append(r) if r in ctx.realizes: continue group: dict[UOp, None] = {} recursive_group(r, unwrap(r.st), r, ctx.children, ctx.realizes, reduce_for_op, group, cache={}) # max one reduceop per kernel can_chase = all(tr not in reduce_for_op for tr in group) # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs forced_realize = r in group # can only have one output if not forced_realize and len(group) > 1: forced_realize = True # can only fuse assign if no other assign_target is used in the kernel if not forced_realize and any(x.op is Ops.ASSIGN for x in group): parents = deque((r, *group)) while parents and not forced_realize: p = parents.pop().base if (assign:=ctx.assigns.get(p)) is not None and assign not in group: forced_realize, can_chase = True, False if p in ctx.realizes: continue parents.extend(p.src) if forced_realize or not group: tr = r if can_chase: # can chase this down to contiguous children st = unwrap(tr.st) while len(ctx.children[tr]) == 1: tr_next = next(iter(ctx.children[tr])) st_childs = dedup(unwrap(s.st) for s in tr_next.src if s.base is tr) if len(st_childs) > 1: break if st.size != st_childs[0].size: break st = st + st_childs[0] if not st.contiguous or tr_next.op is Ops.REDUCE_AXIS: break tr = tr_next # don't cast to higher size before store (tr cannot be realized if forced_realize) if tr.op is Ops.CAST and tr.dtype.itemsize > tr.src[0].dtype.itemsize: tr = tr.src[0].base group = {tr: None} ctx.realizes[tr] = None reduce_for_op.update((tr, r) for tr in group) if FUSE_ARANGE and r.arg[0] is Ops.ADD and r.src[0].base.op is Ops.CONST: # maybe fuse arange with its children if 
len(flatten(ctx.children[tr] for tr in group)) != 0: for tr in group: del ctx.realizes[tr] # fuse double reduces with no other child for reduceop in double_reduces: top_reduce = reduceop.src[0].base if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce] return ctx.realizes # **** create kernels @dataclass(frozen=True) class Kernel: ast: UOp metadata: tuple[Metadata, ...] = () def __repr__(self): return f"" @dataclass(frozen=True) class KernelContext: realizes: dict[UOp, None] ops_metadata: dict[UOp, Metadata] def create_kernel(ctx:KernelContext, x:UOp, b:UOp): kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), (m,) if (m:=ctx.ops_metadata.get(x)) else ())) buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset)) return UOp(Ops.ASSIGN, x.dtype, (buffer, kernel)).reshape(x.shape) DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER} def append_to_kernel(ctx:KernelContext, x:UOp): new_srcs: list[UOp] = [] metadata = dict.fromkeys(x.arg.metadata) for s in x.src: if s.op in DONT_PLACE_IN_KERNEL or s in ctx.realizes: new_srcs.append(s) else: new_srcs.extend(s.src) if (m:=ctx.ops_metadata.get(s)) is not None: metadata[m] = None if (new_src:=tuple(dedup(new_srcs))) != x.src: return x.replace(src=new_src, arg=Kernel(x.arg.ast, tuple(metadata))) create_kernels = PatternMatcher([ # always give assign/contiguous a kernel (UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel), (UPat(Ops.CONTIGUOUS, name="x"), lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype))), # create a buffer for COPY on the new device (UPat(Ops.COPY, src=(UPat(Ops.DEVICE, name="d"), UPat()), name="x"), lambda ctx,d,x: create_kernel(ctx, x, UOp.new_buffer(d.arg, x.size, x.dtype))), # otherwise check the context if we're realizing this UOp (UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"), lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, 
x.dtype)) if x in ctx.realizes else None), # walk back the local graph until we reach a buffer/assign parent (UPat(Ops.KERNEL, name="x"), append_to_kernel), # remove downstream reshapes from SINK (UPat(Ops.SINK, name="x"), lambda x:x.replace(src=tuple(s.base for s in x.src)) if any(s.op is Ops.VIEW for s in x.src) else None), ]) # **** swizzler def apply_swizzle(u:UOp) -> UOp: with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left) def swizzle_reduceop(r:UOp, src:UOp, view:UOp): if (st:=unwrap(view.st)).contiguous: return None input_st = ShapeTracker.from_shape(src.shape) tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg) prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):]) strides = strides_for_shape(rshape) nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides, v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views] # create a new reduceop for the swizzled input new_input_st = tmp + ShapeTracker(tuple(nv)) new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg))) return UOp(Ops.REDUCE_AXIS, r.dtype, (apply_swizzle(src.view(src.arg+new_input_st if src.op is Ops.VIEW else new_input_st)),), (r.arg[0], new_axis)).view(ShapeTracker.from_shape(st.shape)) def reduceop_view_right(src:UOp, v:UOp, r:UOp): assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}" return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u)).view(ShapeTracker.from_shape(r.shape)) def elementwise_view_right(root:UOp): if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW and x.base.op not in DONT_PUSH_VIEWS]): return None assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}" # place view after applying the elementwise op new_st = ShapeTracker.from_shape(swizzles[0].base.shape) new_src = [x.base if 
x.base.shape==new_st.shape else apply_swizzle(x.view(x.arg+new_st) if x.op is Ops.VIEW else x.view(new_st)) for x in root.src] # reshape to match downstream shapes return root.replace(src=tuple(new_src)).reshape(root.shape) # push VIEW to children view_right = merge_views+PatternMatcher([ # push a non contiguous ShapeTracker through reduceop (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop), # apply view after reduceops (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="src"),), name="v"),), name="r"), reduceop_view_right), # apply view after elementwise ops (UPat(GroupOp.All-DONT_PUSH_VIEWS, name="root"), elementwise_view_right), # merge axes for double reduce (invert of SPLIT_REDUCEOP=1) (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"), lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] == r2.arg[0] else None), ]) # **** unbind variables def unbind_shapetracker(ctx:tuple[dict[Variable, int], tuple[UOp, ...]], x:UOp): st = unwrap(x.st).simplify() if any(x.op is Ops.BIND for x in st.vars()): st, var_vals = st.unbind() ctx[0].update(var_vals) return x.replace(arg=st) if st != x.st else None def unbind_variable(ctx:tuple[dict[Variable, int], tuple[UOp, ...]], var:UOp, val:UOp): ctx[0][var.replace(src=())] = val.arg return var # **** fix kernel AST add_buffer_ops = PatternMatcher([ # LOAD (UPat(Ops.BUFFER, name="x"), lambda ctx,x:UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx[1].index(x)), x.st.to_uop(), dtype=x.dtype)), # STORE (except for meta ops) (UPat(Ops.SINK, src=(UPat(GroupOp.Meta, name="x"),)), lambda x:x), # partial assign can store to a non-contiguous ShapeTracker (UPat(Ops.SINK, src=(UPat(Ops.ASSIGN, name="x"),)), lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.src[0].base.size), (), 0), x.src[0].st.to_uop(), x.src[1]).sink()), # otherwise the store is contiguous (UPat(Ops.SINK, 
src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)), lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()), # VALID (UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"), lambda x,view: x.valid(view.arg)), # if the last child is a VIEW we merge the ShapeTrackers and store the base (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="x"),)))), lambda x,b,st: UOp.store(b, (st.arg+x.st).to_uop(), x)), ]) def check_load_st(glbl:UOp, view:UOp): if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return # if it has a single view and it's equal when you shrink a contig, it's fine if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return # otherwise, it's not fine raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n" +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green")) fix_kernel_ops = PatternMatcher([ # remove CONTIGUOUS/DEVICE from kernel AST (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x), (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())), # BIND in shapetracker becomes DEFINE_VAR (UPat(Ops.VIEW, name="x"), unbind_shapetracker), (UPat(Ops.BIND, src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable), # no ImageDType after load (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None), # if this kernel also assigns to the loaded buffer, ensure we can index it correctly (UPat(Ops.LOAD, 
src=(UPat.var("glbl"), UPat.var("view"))), check_load_st), ]) def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp: assert k.op is Ops.KERNEL, f"kernel isn't kernel, it's {k}" # substitute kernel sources for the target buffer + apply reshapes parents_rep: dict[UOp, UOp] = {} for s in k.src: if s.op is Ops.ASSIGN: for out in s.src[1].arg.ast.src: parents_rep[out] = s.buf_uop.view(unwrap(out.st)) ast = k.arg.ast.substitute(parents_rep) # push views to edges ast = graph_rewrite(graph_rewrite(ast, view_left), view_right) # add buffer ops + fix_kernel_ops ast = graph_rewrite(ast, merge_views+add_buffer_ops+fix_kernel_ops, ctx=(var_vals, bufs:=tuple(s.buf_uop for s in k.src)), bottom_up=True) if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}") # create subbuffer (TODO: this does not belong here) if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = (base:=bufs[1].buffer).view(ast.size, ast.dtype, ast.arg[1]*base.dtype.itemsize) return k.replace(arg=Kernel(ast, k.arg.metadata)) PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {} if CAPTURE_PROCESS_REPLAY: @atexit.register def save_process_replay(): for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True) # **** schedule creation and toposort @dataclass(frozen=True) class ScheduleItem: ast: UOp bufs: tuple[Buffer, ...] metadata: tuple[Metadata, ...] 
@track_rewrites(name_fxn=lambda r: f"Schedule {pluralize('Kernel', len(r[0]))}"+(f" (with_{pluralize('Var', len(r[1]))})" if len(r[1]) != 0 else ""))
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
  """
  Lower the tensor graph rooted at `big_sink` into an ordered list of ScheduleItems.

  Returns a 3-tuple:
    - the ScheduleItems, in breadth-first dependency order over the kernel ASSIGNs
    - var_vals, the variable values collected while fixing up each kernel AST
    - becomes_map, mapping original tensor UOps to the buffer (optionally viewed) or const they now refer to

  Raises RuntimeError when a cycle between kernels and assigns is detected.
  """
  # simplify the tensor graph: merge_views + sym + reorder_view + replace_contiguous
  tensor_map = graph_rewrite_map(big_sink, merge_views+sym+reorder_view+replace_contiguous, ctx={})
  sink = tensor_map[big_sink]
  # display the cleaned up tensor graph
  if getenv("VIZ"): graph_rewrite(sink, PatternMatcher([]), name="View Tensor Graph")

  # decide which UOps get realized into their own buffers
  realize_map = group_realizes(sink)

  # carry tensor metadata over to the simplified UOps (CONST/DEVICE-based tensors carry none)
  ops_metadata: dict[UOp, Metadata] = {}
  for tensor_uop, simple_uop in tensor_map.items():
    if tensor_uop.base.op in {Ops.CONST, Ops.DEVICE} or not isinstance(tensor_uop.metadata, Metadata): continue
    ops_metadata[simple_uop] = tensor_uop.metadata

  # build the kernel graph: merge_views + create_kernels
  kernel_map = graph_rewrite_map(sink, merge_views+create_kernels, ctx=KernelContext(realize_map, ops_metadata), bottom_up=True)
  sched_sink = kernel_map[sink]
  type_verify(list(sched_sink.toposort), kernel_spec)

  # work out what each original tensor should point at once the schedule runs
  becomes_map: dict[UOp, UOp] = {}
  for tensor_uop, simple_uop in tensor_map.items():
    if simple_uop.op is Ops.ASSIGN:
      # an ASSIGN always resolves to its target buffer
      becomes_map[tensor_uop] = simple_uop.src[0]
      continue
    if (kernel_out := kernel_map.get(simple_uop.base)) is not None and (kernel_out := kernel_out.base).op is Ops.ASSIGN:
      # a fresh buffer was created for this tensor: point at it, re-applying the view when shapes differ
      target = kernel_out.src[0]
      becomes_map[tensor_uop] = target if target.st == simple_uop.st else target.view(unwrap(simple_uop.st))
      continue
    # otherwise the tensor may have simplified to an existing buffer/const
    if tensor_uop is simple_uop: continue
    if simple_uop.base.op is Ops.BUFFER: becomes_map[tensor_uop] = simple_uop
    if simple_uop.base.op is Ops.CONST and all_int(simple_uop.shape): becomes_map[tensor_uop] = simple_uop

  # when a kernel reads a buffer that a later kernel assigns to, make that later assign depend on the reader
  kernel_assign: dict[UOp, UOp] = {}
  assign_rep: dict[UOp, UOp] = {}
  for u in sched_sink.toposort:
    if u.op is not Ops.ASSIGN: continue
    kernel_assign[u.buf_uop] = u
    for s in u.src[1].src:
      if s.op is not Ops.BUFFER or s is u.buf_uop: continue
      if (prior_assign := kernel_assign.get(s)) is None: continue
      if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort):
        raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
      rewired = prior_assign.replace(src=prior_assign.src+(u,))
      assign_rep[prior_assign] = rewired
      kernel_assign[s] = rewired
  if assign_rep:
    sched_sink = sched_sink.substitute(assign_rep)
    type_verify(list(sched_sink.toposort), kernel_spec)

  # display the final graph
  if getenv("VIZ"):
    for graph_name in ("View Kernel Graph", "View Memory Graph"): graph_rewrite(sched_sink, PatternMatcher([]), name=graph_name)

  # final toposort (bfs): count ASSIGN dependencies per ASSIGN and index their dependents
  children: dict[UOp, list[UOp]] = {}
  in_degree: dict[UOp, int] = {}
  for u in sched_sink.toposort:
    if u.op is not Ops.ASSIGN: continue
    assign_deps = [s for s in u.src[1].src if s.op is Ops.ASSIGN]
    in_degree[u] = len(assign_deps)
    for dep in assign_deps: children.setdefault(dep, []).append(u)

  ready = deque(node for node, degree in in_degree.items() if degree == 0)
  schedule: list[ScheduleItem] = []
  var_vals: dict[Variable, int] = {}
  while ready:
    assign = ready.popleft()
    # TODO: move this to create_kernels
    kernel = fix_kernel_ast(assign.src[1], var_vals)
    schedule.append(ScheduleItem(kernel.arg.ast, tuple(s.buf_uop.buffer for s in kernel.src), kernel.arg.metadata))
    for child in children.get(assign, []):
      in_degree[child] -= 1
      if in_degree[child] == 0: ready.append(child)

  # confirm everything was scheduled correctly
  kc = len(in_degree)
  if len(schedule) != kc: raise RuntimeError(f"cycle detected in graph, created {kc} kernels but only scheduled {len(schedule)}")
  if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")

  # capture process replay
  if CAPTURE_PROCESS_REPLAY:
    with Context(PICKLE_BUFFERS=0):
      PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, [x.ast for x in schedule]))
  return schedule, var_vals, becomes_map