openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

484 lines
26 KiB

import sys, atexit, pickle
from collections import defaultdict, deque
from dataclasses import dataclass
from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, buffers
from tinygrad.ops import can_pad, identity_element, resolve, view_left, merge_views
from tinygrad.codegen.symbolic import symbolic_simple
from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, unwrap, flatten, getenv, pluralize
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP
from tinygrad.dtype import ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View, strides_for_shape
from tinygrad.device import Buffer
from tinygrad.spec import type_verify, kernel_spec
# creation can recurse a lot
# NOTE: sys.setrecursionlimit is process-wide, so this raises the limit for everything running in the interpreter
sys.setrecursionlimit(10000)
# **** schedule simplifier
def simplify_stride0_reduce(reduce:UOp, x:UOp):
  """Collapse a REDUCE_AXIS whose reduced axes all have stride 0 in the last view of the input.

  Every element along those axes is the same value, so one slice is kept (shrink to size 1) and the
  reduction is folded in arithmetically: ADD multiplies by the reduced element count, MUL raises to that
  power, MAX passes through. Returns None when the rule does not apply."""
  tracker = unwrap(x.st)
  axes = reduce.arg[1]
  # masked views are not handled (NOTE: can be relaxed if not masked on stride 0 axis)
  if any(view.mask is not None for view in tracker.views): return None
  # every reduced axis must have stride 0 (NOTE: can do partial)
  last_strides = tracker.views[-1].strides
  if not all(last_strides[axis] == 0 for axis in axes) or not all_int(x.shape): return None
  count = prod(x.shape[axis] for axis in axes)
  first = x.shrink(tuple((0,1) if axis in axes else (0,dim) for axis,dim in enumerate(x.shape)))
  op = reduce.arg[0]
  if op == Ops.ADD: return first*count
  if op == Ops.MUL: return first.pow(count)
  if op == Ops.MAX: return first  # NOTE: Ops.MAX is passthrough
  return None
def split_reduceop(reduce:UOp, x:UOp):
  """Split one large REDUCE_AXIS into two chained reduces so the first phase has more global outputs.

  Returns None when SPLIT_REDUCEOP is off, the input shape is symbolic, the input/output size ratio is below
  REDUCEOP_SPLIT_THRESHOLD, or no reduced axis has a usable split factor."""
  if not SPLIT_REDUCEOP or not all_int(x.shape): return None
  out_size = prod(reduce.shape)
  if (prod(x.shape)//out_size) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return None
  # if there are few globals, make some reduces into globals by splitting into two kernels
  # cap output buffer to 2**22 (REDUCEOP_SPLIT_SIZE): heuristic number of global outputs to achieve max occupancy
  # with enough locals+upcasts for gemm; ~2**10 should be enough if GROUP is used.
  # 256 split maximum should be "negligible reduce" for low prod(reduce.shape), 8 split minimum.
  strides = unwrap(x.st).real_strides(ignore_valid=True)
  largest = min(256, 2**getenv("REDUCEOP_SPLIT_SIZE", 22)//out_size)
  # first (axis, factor) pair, trying reduced axes in order and factors from largest down to 8
  choice = next(((axis, factor) for axis in reduce.arg[1] for factor in range(largest, 7, -1)
                 if x.shape[axis] % factor == 0 and strides[axis] != 0), None)
  if choice is None: return None
  dim_to_split, divisor = choice
  splitted_shape = x.shape[:dim_to_split] + (divisor, x.shape[dim_to_split]//divisor) + x.shape[dim_to_split+1:]
  # the split factor is moved to the end to provide maximum locality for the second phase reduce
  order = tuple(d for d in range(len(splitted_shape)) if d != dim_to_split) + (dim_to_split,)
  splitted = x.reshape(splitted_shape).permute(order)
  if DEBUG >= 3: print(f"split {divisor}: {x.shape} -> {splitted.shape} -> {reduce.shape}")
  # reduce original axes, then reduce the trailing split axis, then restore the original output shape
  first_phase = splitted.r(*reduce.arg)
  return first_phase.r(reduce.arg[0], (len(reduce.shape),)).reshape(reduce.shape)
# tensor-graph simplifier: symbolic_simple plus scheduling-specific cleanup rules
# (applied to the whole tensor graph at the start of create_schedule_with_vars). Rule order matters.
sym = symbolic_simple+PatternMatcher([
  # UOp with size 0 is zero (skipped when the base is already CONST 0, so the rule does not re-fire)
  (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
    and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
  # DETACH and CONTIGUOUS_BACKWARD are NOOPs here
  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
  # reduce of size 0 is the identity element (only when the reduce output itself is non-empty)
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
   lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
  # reduce on stride 0 is collapsed
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_stride0_reduce),
  # split_reduceop
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
  # COPY(CONST) creates a new CONST on the destination device
  (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.arg)),
  # no COPY to same device, except clone (arg is True)
  (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
   lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
  # remove cast to image when it's already a contiguous image
  (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"))),)),
   lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
  # make things that can't be images not images: element count differs from the image shape,
  # or no unit-stride axis has a size divisible by 4
  (UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: u.replace(dtype=dt.base) if isinstance(dt:=u.dtype,ImageDType)
   and (prod(u.shape) != prod(dt.shape) or not any(u.shape[x]%4 == 0 for x in u.st.unit_stride_axes())) else None),
  # remove contiguous if we can just view the buffer
  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
   lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
  # contiguous/buffer/copy is already contiguous
  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY)),)), lambda root: root.src[0]),
  # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK (arg is (size, offset of the first view))
  (UPat((Ops.BITCAST, Ops.CONTIGUOUS), src=(UPat.var("x"),), name="t"),
   lambda x,t: UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (t.size, x.st.views[0].offset)).reshape(t.shape) if x.device.startswith("DISK") else None),
  # remove CONST/BIND/VIEW from SINK: CONST/BIND sources are dropped, the rest are replaced by their base and deduped
  (UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src)
    if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST,Ops.BIND}))) != x.src else None),
])
# support for using a contiguous permuted view instead of the parent view if one exists
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
  """Remember that `src.base` can be read back from CONTIGUOUS `contig` through the inverse of src's view.

  The mapping is only recorded when the view is invertible. Always returns None (the matched node is not rewritten)."""
  inverse = unwrap(src.st).invert(src.base.shape)
  if inverse is None: return None
  ctx[src.base] = contig.view(inverse)
  return None
# ctx is a dict shared across the rewrite: CONTIGUOUS(VIEW) pairs record their parent in it (found_contiguous),
# and ALU ops whose sources were recorded read from the contiguous copy instead
replace_contiguous = PatternMatcher([
  (UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, name="src"),), name="contig"), found_contiguous),
  # swap recorded parents for their contiguous-backed view; no change if nothing was recorded
  (UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
])
# reorder view
reorder_view = PatternMatcher([
  # put CAST to smaller (or equal) dtype before EXPAND, when the view has more elements than it reads
  # (skipped for BUFFER bases when CAST_AFTER_EXPAND is set)
  (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm"),)), lambda cast,vm: vm.base.cast(cast.dtype).view(vm.st)
    if (not getenv("CAST_AFTER_EXPAND") or vm.base.op is not Ops.BUFFER) and cast.dtype.itemsize <= vm.dtype.itemsize
    and resolve(prod(vm.shape) > vm.st.real_size()) else None),
  # store a shrink before COPY, otherwise view after the COPY
  (UPat(Ops.COPY, src=(UPat(), UPat(Ops.VIEW, name="v")), name="copy"), lambda copy,v: v.contiguous().copy_to_device(copy.device) \
    if prod(v.shape) < prod(v.base.shape) else v.base.copy_to_device(copy.device, clone=copy.arg).view(v.st)),
  # put UnaryOps before EXPANDs
  (UPat(GroupOp.Unary, src=UPat(Ops.VIEW, src=(UPat.var("inp"),), name="v"), name="alu"),
   lambda inp,v,alu: inp.alu(alu.op).view(v.st) if resolve(prod(alu.shape) > v.st.real_size()) else None),
  # put CAST after expanding BUFFER (only with CAST_AFTER_EXPAND)
  (UPat(Ops.VIEW, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="v"), lambda x,v: x.view(x.st+v.st).cast(v.dtype) if getenv("CAST_AFTER_EXPAND")
    and x.base.op is Ops.BUFFER and resolve(prod(v.shape) > prod(x.shape)) else None),
])
# **** UOp realization
# ops that views are not pushed through by the swizzler, and that do_realize never marks for realization
DONT_PUSH_VIEWS = {Ops.BUFFER, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.ASSIGN, Ops.SINK, Ops.CONTIGUOUS, Ops.COPY}
@dataclass(frozen=True)
class GrouperContext:
  """State collected while grouping; the dataclass is frozen but its dicts are filled in place during the rewrite."""
  assigns: dict[UOp, UOp] # maps realized buffers to assigns (an ASSIGN's target buffer -> that ASSIGN)
  realizes: dict[UOp, None] # all the simplified tensor uops we realize (dict used as an ordered set)
  children: defaultdict[UOp, dict[UOp, None]] # children graph of tensor uops, keyed by the base of each source
def realize(ctx:GrouperContext, tr:UOp) -> None:
  """Mark `tr` as a UOp that gets its own buffer."""
  ctx.realizes[tr] = None

def realize_before_view(ctx:GrouperContext, view:UOp, tr:UOp) -> None:
  """Decide whether `tr` must be realized before `view` is applied on top of it.

  Realizes `tr` if the view is masked and padding `tr` is unsafe, or if the view expands `tr`
  (unless DONT_REALIZE_EXPAND). A single masked view that doesn't grow the element count is left fused."""
  st = unwrap(view.st)
  masked = any(v.mask is not None for v in st.views)
  # always realize unsafe pad ops before masked view
  if masked and not can_pad(tr, ctx.realizes, cache=dict()):
    realize(ctx, tr)
    return
  # fold simple pads
  if len(st.views) == 1:
    valid = st.views[-1].mask
    if valid is not None and all_int(tr.shape) and resolve(prod(tr.shape) >= prod([hi-lo for lo,hi in valid])): return
  # realize before expand
  if resolve(prod(tr.shape) < prod(st.shape)) and not DONT_REALIZE_EXPAND:
    realize(ctx, tr)
# marks tensor UOps that must always be realized; every rule only mutates ctx.realizes and returns None (graph unchanged)
do_realize = PatternMatcher([
  # always realize SINK parents
  (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.realizes.update((x, None) for x in s.src if x.op not in DONT_PUSH_VIEWS)),
  # always realize ASSIGN/CONTIGUOUS/GroupOp.Meta
  (UPat({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}, name="tr"), realize),
  # realize before expand or unsafe pad ops
  (UPat(Ops.VIEW, name="view", src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="tr"),)), realize_before_view),
  # realize before COPY
  (UPat(Ops.COPY, src=(UPat(), UPat(GroupOp.All-DONT_PUSH_VIEWS, name="tr"))), realize),
])
def append_uop(ctx:GrouperContext, u:UOp) -> None:
  """Record an ASSIGN under its target buffer and register `u` as a child of the base of each of its sources."""
  if u.op is Ops.ASSIGN:
    ctx.assigns[u.buf_uop] = u
  for parent in u.src:
    ctx.children[parent.base].update({u: None})
# feed every UOp except SINK and VIEW through append_uop to build the assigns/children maps
create_ctx = PatternMatcher([
  (UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}, name="u"), append_uop),
])
def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], realizes:dict[UOp, None],
                    reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
  """recursively search the uop for groupable children, realize the UOp if a child can't group

  Walks the children of `tr` starting from reduce `r`, composing the ShapeTracker `st` along the path.
  Every realized descendant reached is added to `group`. `r` itself is also added to `group` (which
  group_realizes treats as "force realize r") when that descendant is reached through a non-contiguous or
  differently-sized ShapeTracker, is already claimed by another reduce, when a child is another REDUCE_AXIS,
  or when a child reads `tr` through more than one ShapeTracker. `cache` dedups (uop, ShapeTracker) visits."""
  if (tr, st) in cache: return
  cache.setdefault((tr, st))
  rsize = unwrap(r.st).size
  if tr in realizes and tr is not r:
    # can only fuse contiguous
    # max one reduceop per kernel
    if not st.contiguous or st.size != rsize or tr in reduce_for_op: group.setdefault(r)
    return group.setdefault(tr)
  for tr_next in children[tr]:
    # max one reduceop per kernel
    if tr_next.op is Ops.REDUCE_AXIS: return group.setdefault(r)
    # can only fuse contiguous
    if len(st_childs:=dedup(unwrap(x.st) for x in tr_next.src if x.base == tr)) > 1: return group.setdefault(r)
    recursive_group(tr_next, st+st_childs[0], r, children, realizes, reduce_for_op, group, cache)
def group_realizes(sink:UOp) -> dict[UOp, None]:
  """Decide which UOps of the tensor graph `sink` get realized into their own buffer.

  First runs do_realize+create_ctx to collect the always-realized UOps, the ASSIGN targets and the children graph.
  Unless DONT_GROUP_REDUCES is set, every REDUCE_AXIS is then paired with the realized UOps it can fuse into; when
  no clean pairing exists the reduce (or a contiguous child it can be chased down to) is force-realized.
  Returns ctx.realizes (a dict used as an ordered set)."""
  # start by adding uops that always realize
  sink = graph_rewrite(sink, do_realize+create_ctx, ctx:=GrouperContext({}, {}, defaultdict(dict)))
  if DONT_GROUP_REDUCES: return ctx.realizes
  # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
  reduce_for_op: dict[UOp, UOp] = {}
  double_reduces: list[UOp] = []
  for r in sink.toposort:
    if r.op is not Ops.REDUCE_AXIS: continue
    if FUSE_CONV_BW and r.src[0].base.op is Ops.REDUCE_AXIS and r.src[0] is not r.src[0].base: double_reduces.append(r)
    if r in ctx.realizes: continue
    group: dict[UOp, None] = {}
    recursive_group(r, unwrap(r.st), r, ctx.children, ctx.realizes, reduce_for_op, group, cache={})
    # max one reduceop per kernel
    can_chase = all(tr not in reduce_for_op for tr in group)
    # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
    forced_realize = r in group
    # can only have one output
    if not forced_realize and len(group) > 1: forced_realize = True
    # can only fuse assign if no other assign_target is used in the kernel
    if not forced_realize and any(x.op is Ops.ASSIGN for x in group):
      parents = deque((r, *group))
      while parents and not forced_realize:
        p = parents.pop().base
        if (assign:=ctx.assigns.get(p)) is not None and assign not in group: forced_realize, can_chase = True, False
        if p in ctx.realizes: continue
        parents.extend(p.src)
    if forced_realize or not group:
      tr = r
      if can_chase:
        # can chase this down to contiguous children
        st = unwrap(tr.st)
        while len(ctx.children[tr]) == 1:
          tr_next = next(iter(ctx.children[tr]))
          st_childs = dedup(unwrap(s.st) for s in tr_next.src if s.base is tr)
          if len(st_childs) > 1: break
          if st.size != st_childs[0].size: break
          st = st + st_childs[0]
          if not st.contiguous or tr_next.op is Ops.REDUCE_AXIS: break
          tr = tr_next
        # don't cast to higher size before store (tr cannot be realized if forced_realize)
        if tr.op is Ops.CAST and tr.dtype.itemsize > tr.src[0].dtype.itemsize:
          tr = tr.src[0].base
      group = {tr: None}
      ctx.realizes[tr] = None
    reduce_for_op.update((tr, r) for tr in group)
    if FUSE_ARANGE and r.arg[0] is Ops.ADD and r.src[0].base.op is Ops.CONST:
      # maybe fuse arange with its children (only if some member of the group has children)
      if any(ctx.children[tr] for tr in group):
        for tr in group: del ctx.realizes[tr]
  # fuse double reduces with no other child
  for reduceop in double_reduces:
    top_reduce = reduceop.src[0].base
    # pop instead of del: the FUSE_ARANGE path above may already have removed this reduce from realizes
    if len(ctx.children[top_reduce]) == 1: ctx.realizes.pop(top_reduce, None)
  return ctx.realizes
# **** create kernels
@dataclass(frozen=True)
class Kernel:
  """Payload (arg) of an Ops.KERNEL UOp: the kernel's AST plus the tensor metadata it was built from."""
  ast: UOp
  metadata: tuple[Metadata, ...] = ()
  def __repr__(self):
    n_uops = len(list(self.ast.toposort))
    if self.ast.op is Ops.SINK: outputs = [s.op for s in self.ast.src]
    else: outputs = self.ast.op
    return f"<Kernel {n_uops} {outputs} {self.metadata}>"
@dataclass(frozen=True)
class KernelContext:
  """Context passed to the create_kernels rewrite."""
  realizes: dict[UOp, None]  # UOps chosen for realization (the result of group_realizes)
  ops_metadata: dict[UOp, Metadata]  # simplified UOp -> tensor metadata, copied onto the kernels built from it
def create_kernel(ctx:KernelContext, x:UOp, b:UOp):
  """Wrap `x` in a KERNEL that writes buffer `b` and return ASSIGN(target, kernel) reshaped to x's shape.

  The target is b's base buffer, or a BUFFER_VIEW of it (size, first-view offset) when `b` is smaller than its base."""
  meta = ctx.ops_metadata.get(x)
  kernel_arg = Kernel(x.sink(), (meta,) if meta else ())
  kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=kernel_arg)
  if b.size == b.base.size:
    target = b.base
  else:
    target = UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
  return UOp(Ops.ASSIGN, x.dtype, (target, kernel)).reshape(x.shape)
# sources that stay as KERNEL inputs instead of being pulled into the kernel body
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER}
def append_to_kernel(ctx:KernelContext, x:UOp):
  """Pull the non-realized sources of KERNEL `x` into it, replacing each one with its own sources.

  Metadata of every absorbed source is merged into the Kernel arg. Returns None once the sources stop changing."""
  next_srcs: list[UOp] = []
  merged_meta = dict.fromkeys(x.arg.metadata)
  for src in x.src:
    if src.op in DONT_PLACE_IN_KERNEL or src in ctx.realizes:
      next_srcs.append(src)
      continue
    next_srcs.extend(src.src)
    meta = ctx.ops_metadata.get(src)
    if meta is not None: merged_meta[meta] = None
  new_src = tuple(dedup(next_srcs))
  if new_src == x.src: return None
  return x.replace(src=new_src, arg=Kernel(x.arg.ast, tuple(merged_meta)))
# turns the tensor graph into ASSIGN(buffer, KERNEL) nodes; run bottom_up with a KernelContext
create_kernels = PatternMatcher([
  # always give assign/contiguous a kernel
  (UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel),
  (UPat(Ops.CONTIGUOUS, name="x"), lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype))),
  # create a buffer for COPY on the new device
  (UPat(Ops.COPY, src=(UPat(Ops.DEVICE, name="d"), UPat()), name="x"), lambda ctx,d,x: create_kernel(ctx, x, UOp.new_buffer(d.arg, x.size, x.dtype))),
  # otherwise check the context if we're realizing this UOp
  (UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"),
   lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype)) if x in ctx.realizes else None),
  # walk back the local graph until we reach a buffer/assign parent
  (UPat(Ops.KERNEL, name="x"), append_to_kernel),
  # remove downstream reshapes from SINK
  (UPat(Ops.SINK, name="x"), lambda x:x.replace(src=tuple(s.base for s in x.src)) if any(s.op is Ops.VIEW for s in x.src) else None),
])
# **** swizzler
def apply_swizzle(u:UOp) -> UOp:
  """Rewrite `u` with the view_left rules, with TRACK_MATCH_STATS disabled for this rewrite."""
  with Context(TRACK_MATCH_STATS=0):
    swizzled = graph_rewrite(u, view_left)
  return swizzled
def swizzle_reduceop(r:UOp, src:UOp, view:UOp):
  """Push a non-contiguous VIEW that sits on top of reduce `r` down onto the reduce's input `src`.

  The reduced axes of the input are permuted to the end. Each view of `view`'s ShapeTracker gets those axes
  appended (its strides and offset are multiplied by their product, the new axes get contiguous strides, and
  masks get full ranges on them). A new REDUCE_AXIS over the trailing axes is returned under a contiguous VIEW
  of `view`'s shape. Returns None when `view` is already contiguous."""
  if (st:=unwrap(view.st)).contiguous: return None
  input_st = ShapeTracker.from_shape(src.shape)
  # move the reduced axes to the end
  tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
  prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
  strides = strides_for_shape(rshape)
  nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
                    v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views]
  # create a new reduceop for the swizzled input
  new_input_st = tmp + ShapeTracker(tuple(nv))
  new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg)))
  return UOp(Ops.REDUCE_AXIS, r.dtype, (apply_swizzle(src.view(src.arg+new_input_st if src.op is Ops.VIEW else new_input_st)),),
             (r.arg[0], new_axis)).view(ShapeTracker.from_shape(st.shape))
def reduceop_view_right(src:UOp, v:UOp, r:UOp):
  """Move a contiguous, size-preserving VIEW `v` from the input of reduce `r` to after the reduce.

  The new reduce axes are the positions where src's shape and r's shape differ."""
  assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
  axes = tuple(i for i,(before,after) in enumerate(zip(src.shape, r.shape)) if before != after)
  reduced = src.r(r.arg[0], axes)
  return reduced.view(ShapeTracker.from_shape(r.shape))
def elementwise_view_right(root:UOp):
  """Push VIEWs on the sources of `root` past it, then reshape the result back to root's shape.

  Returns None when no source is a VIEW over a base outside DONT_PUSH_VIEWS."""
  swizzles = [x for x in root.src if x.op is Ops.VIEW and x.base.op not in DONT_PUSH_VIEWS]
  if not swizzles: return None
  assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
  # place view after applying the elementwise op
  base_st = ShapeTracker.from_shape(swizzles[0].base.shape)
  def _to_base_shape(x:UOp) -> UOp:
    # sources already in the base shape are used directly; the rest are re-viewed into it and swizzled
    if x.base.shape == base_st.shape: return x.base
    return apply_swizzle(x.view(x.arg+base_st) if x.op is Ops.VIEW else x.view(base_st))
  rebuilt = tuple(_to_base_shape(x) for x in root.src)
  # reshape to match downstream shapes
  return root.replace(src=rebuilt).reshape(root.shape)
# push VIEW to children (the inverse direction of view_left); used on kernel ASTs in fix_kernel_ast
view_right = merge_views+PatternMatcher([
  # push a non contiguous ShapeTracker through reduceop
  (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
  # apply view after reduceops
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="src"),), name="v"),), name="r"), reduceop_view_right),
  # apply view after elementwise ops
  (UPat(GroupOp.All-DONT_PUSH_VIEWS, name="root"), elementwise_view_right),
  # merge axes for double reduce (invert of SPLIT_REDUCEOP=1): same reduce op -> one reduce over both axis sets
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
   lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] == r2.arg[0] else None),
])
# **** unbind variables
def unbind_shapetracker(ctx:tuple[dict[Variable, int], tuple[UOp, ...]], x:UOp):
  """Simplify the ShapeTracker of VIEW `x`; if it uses BIND variables, unbind them and record their values in ctx[0].

  Returns `x` with the new ShapeTracker, or None when the ShapeTracker is unchanged."""
  st = unwrap(x.st).simplify()
  if any(var.op is Ops.BIND for var in st.vars()):
    st, bound_values = st.unbind()
    ctx[0].update(bound_values)
  if st != x.st: return x.replace(arg=st)
  return None
def unbind_variable(ctx:tuple[dict[Variable, int], tuple[UOp, ...]], var:UOp, val:UOp):
  """Record the bound constant of `var` in ctx[0] (keyed by var without sources) and return `var`."""
  unbound_key = var.replace(src=())
  ctx[0][unbound_key] = val.arg
  return var
# **** fix kernel AST
# ctx is (var_vals, bufs): a buffer's DEFINE_GLOBAL index is its position in bufs, and stores always target index 0
add_buffer_ops = PatternMatcher([
  # LOAD
  (UPat(Ops.BUFFER, name="x"), lambda ctx,x:UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx[1].index(x)), x.st.to_uop(), dtype=x.dtype)),
  # STORE (except for meta ops): a SINK of a single meta op is replaced by that op, with no store
  (UPat(Ops.SINK, src=(UPat(GroupOp.Meta, name="x"),)), lambda x:x),
  # partial assign can store to a non-contiguous ShapeTracker
  (UPat(Ops.SINK, src=(UPat(Ops.ASSIGN, name="x"),)),
   lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.src[0].base.size), (), 0), x.src[0].st.to_uop(), x.src[1]).sink()),
  # otherwise the store is contiguous
  (UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
   lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
  # VALID
  (UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"), lambda x,view: x.valid(view.arg)),
  # if the last child is a VIEW we merge the ShapeTrackers and store the base
  (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="x"),)))),
   lambda x,b,st: UOp.store(b, (st.arg+x.st).to_uop(), x)),
])
def check_load_st(glbl:UOp, view:UOp):
  """Validate a LOAD from DEFINE_GLOBAL 0, the buffer this kernel also stores to.

  Returns None when the load's ShapeTracker is safe to index. Raises RuntimeError when the kernel would read its
  own output buffer through a non-contiguous view (the augmented-assign case, e.g. `a += a.T`)."""
  if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
  # if it has a single view and it becomes contiguous when you shrink expanded (stride 0) axes, it's fine
  # (the loop variable is `stride`, so it no longer shadows the ShapeTracker `st`)
  if len(st.views) == 1:
    shrunk = st.shrink(tuple((0,1) if stride == 0 else (0,s) for s,stride in zip(st.shape, st.views[0].strides)))
    if shrunk.contiguous: return
  # if it has a single view and it's equal when you shrink a contig, it's fine
  if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
  # otherwise, it's not fine
  raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                     +colored("   - a += a.T\n", "red")+colored("   + a += a.T.contiguous()", "green"))
# final cleanup of a kernel AST; ctx is (var_vals, bufs) as in add_buffer_ops, and bound variable values go into ctx[0]
fix_kernel_ops = PatternMatcher([
  # remove CONTIGUOUS/DEVICE from kernel AST
  (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
  (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
  # BIND in shapetracker becomes DEFINE_VAR
  (UPat(Ops.VIEW, name="x"), unbind_shapetracker),
  (UPat(Ops.BIND, src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),
  # no ImageDType after load
  (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
  # if this kernel also assigns to the loaded buffer, ensure we can index it correctly
  (UPat(Ops.LOAD, src=(UPat.var("glbl"), UPat.var("view"))), check_load_st),
])
def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp:
  """Turn the AST stored on KERNEL `k` into a lowerable kernel AST and return `k` with the new Kernel arg.

  Outputs of upstream ASSIGN kernels are replaced by views of their buffers, views are pushed to the edges,
  DEFINE_GLOBAL loads/stores are added, and bound variables are unbound (their values are written into `var_vals`).
  Raises RuntimeError if a SINK kernel's buffers are on different devices. For a BUFFER_VIEW kernel it also
  registers the sub-buffer in `buffers` as a side effect."""
  assert k.op is Ops.KERNEL, f"kernel isn't kernel, it's {k}"
  # substitute kernel sources for the target buffer + apply reshapes
  parents_rep: dict[UOp, UOp] = {}
  for s in k.src:
    if s.op is Ops.ASSIGN:
      for out in s.src[1].arg.ast.src: parents_rep[out] = s.buf_uop.view(unwrap(out.st))
  ast = k.arg.ast.substitute(parents_rep)
  # push views to edges
  ast = graph_rewrite(graph_rewrite(ast, view_left), view_right)
  # add buffer ops + fix_kernel_ops
  ast = graph_rewrite(ast, merge_views+add_buffer_ops+fix_kernel_ops, ctx=(var_vals, bufs:=tuple(s.buf_uop for s in k.src)), bottom_up=True)
  if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
  # create subbuffer (TODO: this does not belong here)
  if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = (base:=bufs[1].buffer).view(ast.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
  return k.replace(arg=Kernel(ast, k.arg.metadata))
# process-replay captures: str(big_sink.key) -> pickled (big_sink, ContextVar._cache, [kernel ASTs])
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
if CAPTURE_PROCESS_REPLAY:
  @atexit.register
  def save_process_replay():
    # at interpreter exit, write every capture to the disk cache under "schedule_process_replay"
    for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
# **** schedule creation and toposort
@dataclass(frozen=True)
class ScheduleItem:
  """One kernel to run: its fixed AST, the Buffers it uses (in the order of the kernel's sources), and its metadata."""
  ast: UOp
  bufs: tuple[Buffer, ...]
  metadata: tuple[Metadata, ...]
@track_rewrites(name_fxn=lambda r: f"Schedule {pluralize('Kernel', len(r[0]))}"+(f" (with_{pluralize('Var', len(r[1]))})" if len(r[1]) != 0 else ""))
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
  """Schedule every tensor reachable from `big_sink`.

  Returns (schedule, var_vals, becomes_map): the ScheduleItems in BFS dependency order, the values of bound
  variables collected while fixing kernel ASTs, and a map from original tensor UOps to the buffer/const UOps
  they become. Raises RuntimeError when a dependency cycle is detected."""
  # merge_views + sym + reorder_view + replace_contiguous
  tensor_map = graph_rewrite_map(big_sink, merge_views+sym+reorder_view+replace_contiguous, ctx={})
  # display the cleaned up tensor graph
  if getenv("VIZ"): graph_rewrite(tensor_map[big_sink], PatternMatcher([]), name="View Tensor Graph")
  # get realizes
  sink = tensor_map[big_sink]
  realize_map = group_realizes(sink)
  # map tensor metadata to simplified ops
  ops_metadata = {v:k.metadata for k,v in tensor_map.items() if k.base.op not in {Ops.CONST, Ops.DEVICE} and isinstance(k.metadata, Metadata)}
  # merge_views + create_kernels
  kernel_map = graph_rewrite_map(sink, merge_views+create_kernels, ctx=KernelContext(realize_map, ops_metadata), bottom_up=True)
  sched_sink = kernel_map[sink]
  type_verify(list(sched_sink.toposort), kernel_spec)
  # map tensors to buffer/const, optionally apply a VIEW on top
  becomes_map: dict[UOp, UOp] = {}
  for k,v in tensor_map.items():
    # ASSIGN always becomes the target buffer
    if v.op is Ops.ASSIGN: becomes_map[k] = v.src[0]
    # if we created a new buffer for this tensor, map it to the assigned buffer
    elif (a:=kernel_map.get(v.base)) is not None and (a:=a.base).op is Ops.ASSIGN:
      becomes_map[k] = a.src[0] if a.src[0].st == v.st else a.src[0].view(unwrap(v.st))
    # tensors can also simplify to an existing buffer/const
    else:
      if k is v: continue
      if v.base.op is Ops.BUFFER: becomes_map[k] = v
      if v.base.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
  # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
  kernel_assign: dict[UOp, UOp] = {}
  assign_rep: dict[UOp, UOp] = {}
  for u in sched_sink.toposort:
    if u.op is not Ops.ASSIGN: continue
    kernel_assign[u.buf_uop] = u
    for s in u.src[1].src:
      if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
      if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort):
        raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
      # the ordering dependency is recorded by appending u to the sources of the earlier ASSIGN a
      assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
  if assign_rep:
    sched_sink = sched_sink.substitute(assign_rep)
    type_verify(list(sched_sink.toposort), kernel_spec)
  # display the final graph
  if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
  if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Memory Graph")
  # final toposort (bfs)
  # NOTE(review): in_degree only counts ASSIGN sources of each kernel (u.src[1].src); the extra ordering sources
  # appended to ASSIGNs via assign_rep above (a.src+(u,)) are not counted here -- confirm that ordering is still enforced
  children: dict[UOp, list[UOp]] = {}
  in_degree: dict[UOp, int] = {}
  for u in sched_sink.toposort:
    if u.op is not Ops.ASSIGN: continue
    in_degree[u] = 0
    for s in u.src[1].src:
      if s.op is not Ops.ASSIGN: continue
      children.setdefault(s, []).append(u)
      in_degree[u] += 1
  queue = deque(k for k,v in in_degree.items() if v == 0)
  schedule: list[ScheduleItem] = []
  var_vals: dict[Variable, int] = {}
  while queue:
    u = queue.popleft()
    # TODO: move this to create_kernels
    k = fix_kernel_ast(u.src[1], var_vals)
    schedule.append(ScheduleItem(k.arg.ast, tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata))
    for x in children.get(u, []):
      in_degree[x] -= 1
      if in_degree[x] == 0: queue.append(x)
  # confirm everything was scheduled correctly
  if len(schedule) != (kc:=len(in_degree)): raise RuntimeError(f"cycle detected in graph, created {kc} kernels but only scheduled {len(schedule)}")
  if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
  # capture process replay
  if CAPTURE_PROCESS_REPLAY:
    with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, [x.ast for x in schedule]))
  return schedule, var_vals, becomes_map