openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

550 lines
30 KiB

import sys, atexit, functools, pickle
from collections import defaultdict, deque
from dataclasses import dataclass, field
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views
from tinygrad.ops import identity_element, buffers, exec_alu
from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, ContextVar
from tinygrad.dtype import ConstType, ImageDType, dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View, strides_for_shape
from tinygrad.device import Buffer
# creation can recurse a lot, so the interpreter recursion limit is raised at import time
sys.setrecursionlimit(10000)
# per-device buffer count limit, looked up by the device of a kernel's first buffer;
# schedule_uop raises once a kernel has at least this many buffers
BUF_LIMIT = {"METAL":32}
# **** ScheduleItem return type
@dataclass(frozen=True)
class ScheduleItem:
  """An AST together with the buffers, metadata and assign preloads it was scheduled with."""
  ast: UOp
  bufs: tuple[Buffer, ...]
  metadata: tuple[Metadata, ...]
  assign_preloads: tuple[UOp, ...]

  @functools.cached_property
  def output_idxs(self) -> tuple[int, ...]:
    """Indices into `bufs` that are outputs: src[0].arg of every SINK source, or just 0 for a non-SINK ast."""
    if self.ast.op is not Ops.SINK: return (0,)
    return tuple(store.src[0].arg for store in self.ast.src)

  @property
  def outputs(self) -> tuple[Buffer, ...]:
    """Read/write or write only buffers in the schedule."""
    return tuple(buf for pos, buf in enumerate(self.bufs) if pos in self.output_idxs)

  @property
  def inputs(self) -> tuple[Buffer, ...]:
    """Read only buffers in the schedule."""
    return tuple(buf for pos, buf in enumerate(self.bufs) if pos not in self.output_idxs)
# **** Schedule context and big graph
@dataclass(frozen=True)
class ScheduleContext:
  """State collected while one schedule is created; the containers below are mutated in place."""
  # this maps BUFFER uops of this schedule to the tensor uop
  tensor_uops: dict[UOp, list[UOp]] = field(default_factory=dict)
  # this maps a BIND's DEFINE_VAR to its value
  var_vals: dict[Variable, int] = field(default_factory=dict)
  # this holds all the BUFFER uops we ASSIGN to in this schedule
  assigns: set[UOp] = field(default_factory=set)
  # this holds all the BUFFER uops we mutate in this schedule
  realizes: dict[UOp, UOp] = field(default_factory=dict)
  # this maps BUFFER uops the actual op
  allbufs: dict[UOp, UOp] = field(default_factory=dict)
  # this maps fused ops to Metadata
  ops_metadata: dict[UOp, Metadata] = field(default_factory=dict)
  # this maps roots to places they are made contiguous
  contiguous: dict[UOp, UOp] = field(default_factory=dict)
  # for every BUFFER uop, the BUFFER uops of the scheduled ops that read it (a dict is used as an ordered set)
  children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
def to_uop(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
  """Convert a tensor uop into its (early) bufferized form, memoized in `cache`.

  Shapeless or already realized uops are returned as they are, and a view is rebuilt on top of its
  converted base. Every other uop becomes VIEW(buffer, op) and is recorded in ctx.tensor_uops.
  """
  cached = cache.get(buf)
  if cached is not None: return cached
  # shapeless op is passthrough, realized is passthrough
  if buf.st is None or buf.base.is_realized: return buf
  # view is passthrough
  if buf is not buf.base:
    viewed = to_uop(buf.base, ctx, cache).view(buf.st)
    cache[buf] = viewed
    return viewed
  # make things that can't be images not images
  dtype = buf.dtype
  if isinstance(dtype, ImageDType):
    if prod(buf.shape) != prod(dtype.shape) or not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes()):
      if DEBUG >= 2: print(f"forcing image {dtype} with shape {buf.shape} to {dtype.base}")
      dtype = buf.dtype.base
  # meta ops and assign already have a target buffer, otherwise we create a new one
  if buf.op in {Ops.ASSIGN, Ops.VIEW}:
    buf_uop = buf.buf_uop
  else:
    buf_uop = UOp.new_buffer(buf.device, buf.size, dtype)
  # convert the sources after the target buffer has been chosen
  if buf.op is Ops.VIEW:
    inner = buf.src[1]
    op = inner.replace(src=tuple(to_uop(s, ctx, cache) for s in inner.src))
  else:
    op = buf.replace(dtype=dtype.base, src=tuple(to_uop(s, ctx, cache) for s in buf.src))
  # track the underlying tensor uop for this op
  ctx.tensor_uops[buf_uop] = [buf]
  # (early) bufferize
  stored = op.alu(Ops.CONTIGUOUS) if buf.forced_realize else op
  bufferized = UOp(Ops.VIEW, dtype.base, (buf_uop, stored), buf.st)
  cache[buf] = bufferized
  return bufferized
# **** AST graph rewrite
# ** movement ops
def apply_swizzle(u:UOp) -> UOp:
  """Rewrite `u` with the view_left rules while TRACK_MATCH_STATS is set to 0."""
  with Context(TRACK_MATCH_STATS=0):
    return graph_rewrite(u, view_left)
def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
  """Rebuild the reduce `r` of `src` for the view `st` that sits on top of it.

  Returns a reduce over a re-viewed `src`, followed by a plain view of `st.shape`.
  """
  reduce_axes = r.axis_arg
  input_st = ShapeTracker.from_shape(unwrap(src.st).shape)
  # move the reduced axes to the end of the input
  kept_axes = tuple(i for i in range(len(input_st.shape)) if i not in reduce_axes)
  tmp = input_st.permute(kept_axes+reduce_axes)
  rshape = tmp.shape[-len(reduce_axes):]
  prshape = prod(rshape)
  strides = strides_for_shape(rshape)
  # extend every view of st with the reduced dims
  nv = []
  for v in st.views:
    new_mask = v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None
    nv.append(View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides, v.offset*prshape, new_mask))
  # update input_st and axis
  new_input_st = tmp + ShapeTracker(tuple(nv))
  new_axis = tuple(range(len(st.shape), len(st.shape) + len(reduce_axes)))
  swizzled_src = apply_swizzle(src.view(new_input_st))
  return swizzled_src.r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
def reduceop_view_right(r:UOp, v:UOp, src:UOp) -> UOp:
  """REDUCE(src.view()) -> REDUCE(src).view()

  Raises AssertionError when the view `v` is not contiguous or its size differs from `src`.
  """
  swizzle_st = unwrap(v.st)
  if not swizzle_st.contiguous or v.size != src.size:
    raise AssertionError(f"can't push {v} down through {src}")
  output_shape = swizzle_st.reduce(r.axis_arg)
  # reduce every axis of src whose size differs from the output shape
  new_axis = tuple(i for i,(s,u) in enumerate(zip(src.shape, output_shape)) if s != u)
  return src.r(r.arg[0], new_axis).view(ShapeTracker.from_shape(output_shape))
def elementwise_view_right(root:UOp) -> UOp|None:
  """ALU(src.view()) -> ALU(src).view()

  Returns None when no source of `root` is a view of another uop.
  """
  swizzles = [x for x in root.src if x.base is not x]
  if not swizzles: return None
  assert all(x.base.st is not None for x in swizzles), f"found shapeless VIEW src in {root}"
  assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
  # push the swizzle from src to root
  output_swizzle = swizzles[0]
  new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape)

  def unswizzle(x:UOp) -> UOp:
    # shapeless sources stay, swizzled sources lose their view, the rest are viewed with the new input shape
    if not x.has_st: return x
    if x in swizzles: return x.src[0]
    return apply_swizzle(x.view(new_input_st))

  ret = root.replace(src=tuple(unswizzle(x) for x in root.src))
  # update the ASSIGN offset to match the new shape
  if ret.op is Ops.ASSIGN and ret.arg is not None: ret = ret.replace(arg=ret.arg+new_input_st,)
  if ret.op is Ops.STORE: return ret
  return ret.view(ShapeTracker.from_shape(output_swizzle.shape))
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
  """Collapse a reduce of a reduce into a single reduce over the axes of both."""
  assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
  inner = first_reduce.src[0]
  assert not any(x.op is Ops.REDUCE_AXIS for x in inner.toposort), "can't merge more than two reduceops at a time"
  merged_axes = root.axis_arg+first_reduce.axis_arg
  return inner.r(first_reduce.arg[0], merged_axes)
# push VIEW to stores
view_right = merge_views+PatternMatcher([
  # ASSIGN with offset swizzles STORE
  (
    UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.ASSIGN, name="a"))),
    lambda a,b,st: apply_swizzle(UOp.store(b, st, a.replace(arg=None)).view(a.arg)) if a.arg is not None else None,
  ),
  # REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view()
  (
    UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"),
    lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st),
  ),
  # REDUCE(src.view()) -> REDUCE(src).view()
  (
    UPat(Ops.REDUCE_AXIS, src=(UPat.var("src").view(name="v"),), name="r"),
    reduceop_view_right,
  ),
  # ALU(src.view()) -> ALU(src).view()
  (
    UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"),
    elementwise_view_right,
  ),
  # double reduce op collapses to a single reduce op
  (
    UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"),
    merge_double_reduce,
  ),
])
# ** ScheduleItem context builder
@dataclass(frozen=True)
class ScheduleItemContext:
  """State used while one SINK of STOREs is lowered into a ScheduleItem."""
  # the first four fields are taken from the ScheduleContext by schedule_uop
  tensor_uops: dict[UOp, list[UOp]]
  ops_metadata: dict[UOp, Metadata]
  assigns: set[UOp]
  var_vals: dict[Variable, int]
  # maps the buf_uop of every source of the sink to that source's src[2]
  sinked: dict[UOp, UOp]
  # ShapeTrackers recorded by _append_st_vars
  sts: set[ShapeTracker] = field(default_factory=set)
  # BUFFER uops in the order _append_buf saw them
  bufs: list[UOp] = field(default_factory=list)
  # Metadata collected by add_metadata
  metadata: set[Metadata] = field(default_factory=set)
  # maps a BUFFER uop to the LOAD/PRELOAD uops recorded for it
  assign_adj: dict[UOp, list[UOp]] = field(default_factory=dict)
def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
  """Simplify and unbind the ShapeTracker of `x`, storing the unbound variable values in the context.

  Returns the resulting ShapeTracker as a uop when it differs from the one on `x`, otherwise None.
  """
  st = unwrap(x.st)
  # already recorded, nothing to rewrite
  if st in ctx.sts: return None
  st, var_vals = st.simplify().unbind()
  ctx.var_vals.update(var_vals)
  ctx.sts.add(st)
  if st != x.st: return st.to_uop()
  return None
def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
  """Record BUFFER `x` and return a DEFINE_GLOBAL whose arg is the position of `x` in ctx.bufs."""
  assert x.arg[0] != -1, "fake -1 BUFFERS should not make it here"
  buf_idx = len(ctx.bufs)
  ctx.bufs.append(x)
  return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.arg[1]), (), buf_idx)
append_bufs = PatternMatcher([
  # every BUFFER is replaced through _append_buf
  (UPat(Ops.BUFFER, name="x"), _append_buf),
])
def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
  """Record PRELOAD `x` of buffer `b` and turn it into a LOAD.

  Raises RuntimeError when the uops recorded for `b` do not all have the same op.
  """
  adj_loads = ctx.assign_adj.setdefault(b, [])
  adj_loads.append(x)
  if not all_same([load.op for load in adj_loads]):
    raise RuntimeError(f"Detected cycle when fusing {adj_loads}. Can only fuse PRELOAD or LOAD of {b}")
  return x.replace(op=Ops.LOAD)
check_preload = PatternMatcher([
  (UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),
])

to_si = PatternMatcher([
  (UPat(Ops.VIEW, name="x"), _append_st_vars),
  (
    UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)),
    lambda b,x: x.replace(src=(b, *x.src)),
  ),
  # don't need contiguous or assign anymore
  (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
  (UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
])

# collects the Metadata registered for any op into ctx.metadata
add_metadata = PatternMatcher([
  (
    UPat(tuple(Ops), name="x"),
    lambda ctx,x: ctx.metadata.add(m) if (m:=ctx.ops_metadata.get(x)) is not None else None,
  ),
])

# records every LOAD of a buffer that is assigned to in this schedule
add_assign_adjacents = PatternMatcher([
  (
    UPat.load(UPat.var("b"), UPat(), name="x"),
    lambda ctx,b,x: ctx.assign_adj.setdefault(b, []).append(x) if b in ctx.assigns else None,
  ),
])

# late folding for multi output kernels
multioutput = PatternMatcher([
  (UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.sinked.get(b)),
])
def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
  """Lower `pre`, a SINK of STOREs, into a ScheduleItem.

  Raises RuntimeError when the buffer count limit of the device is reached, or when the self operand
  of an augmented assign is neither contiguous nor an allowed masked view.
  """
  # create the ast context
  sinked = {x.buf_uop:x.src[2] for x in pre.src}
  si_ctx = ScheduleItemContext(ctx.tensor_uops, ctx.ops_metadata, ctx.assigns, ctx.var_vals, sinked)
  ctx_rules = add_metadata if len(si_ctx.assigns) == 0 else add_metadata+add_assign_adjacents
  if len(si_ctx.sinked) != 1: ctx_rules = multioutput+ctx_rules
  sink = graph_rewrite(pre, ctx_rules, si_ctx)
  # do movement ops
  sink = graph_rewrite(sink, view_left)
  sink = graph_rewrite(sink, view_right)
  # convert to AST
  ast_rules = to_si+check_preload if len(si_ctx.assigns) != 0 else to_si
  sink = graph_rewrite(sink, ast_rules, si_ctx)
  sink = graph_rewrite(sink, append_bufs, si_ctx)
  # assert buffer count limit
  device = si_ctx.bufs[0].device
  limit = BUF_LIMIT.get(device)
  if limit is not None and len(si_ctx.bufs) >= limit:
    if DEBUG >= 3: print(sink)
    raise RuntimeError(f"Kernel for {si_ctx.metadata} exceeded the {limit} buffer count limit for {device} with {len(si_ctx.bufs)} buffers.")

  # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
  def self_operand_ok(x:UOp) -> bool:
    s = x.st_arg
    if s.contiguous: return True
    if len(s.views) != 1: return False
    m = s.views[0].mask
    return m is not None and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)

  for ubuf,ops in si_ctx.assign_adj.items():
    # only buffers written by this kernel are checked
    if si_ctx.sinked.get(ubuf) is None: continue
    if not all(self_operand_ok(x) for x in ops):
      raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                         +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
  # can only schedule once
  for buf_uop in si_ctx.sinked:
    for luop in si_ctx.tensor_uops[buf_uop]:
      luop.become(buf_uop.view(unwrap(luop.st)))
  # capture process replay
  if getenv("RUN_PROCESS_REPLAY"):
    ctx_vars = {k:v.value for k,v in ContextVar._cache.items()}
    PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, si_ctx.assigns, ctx_vars, sink))
  # zero-sized buffers are left out of the item; preloaded assign targets are passed along
  item_bufs = tuple(u.buffer for u in si_ctx.bufs if u.size != 0)
  assign_preloads = tuple(ubuf for ubuf,ops in si_ctx.assign_adj.items() if any(x.op is Ops.PRELOAD for x in ops))
  return ScheduleItem(sink, item_bufs, tuple(si_ctx.metadata), assign_preloads)
# pickled captures keyed by str(pre.key), filled in by schedule_uop
PROCESS_REPLAY_CAPTURE: dict[str, bytes] = {}
if getenv("RUN_PROCESS_REPLAY"):
  @atexit.register
  def save_process_replay() -> None:
    """Write every captured entry to the disk cache when the interpreter exits."""
    for key, pickled in PROCESS_REPLAY_CAPTURE.items():
      diskcache_put("schedule_process_replay", key, pickled, prepickled=True)
# **** Schedule grouping
def is_scheduled(u:UOp) -> bool:
  """True when `u` is a VIEW with exactly two sources."""
  if u.op is not Ops.VIEW: return False
  return len(u.src) == 2
def uval(u:UOp) -> UOp:
  """Return the op held by the scheduled VIEW `u`.

  A CONTIGUOUS op is skipped in favor of its source, unless the base of that source is itself a
  VIEW with two sources.
  """
  assert is_scheduled(u), f"must be a scheduled op {u}"
  r = u.src[1]
  if r.op is not Ops.CONTIGUOUS: return r
  inner_base = r.src[0].base
  if inner_base.op is Ops.VIEW and len(inner_base.src) == 2: return r
  return r.src[0]
def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, dict[UOp, None]], allbufs:dict[UOp, UOp], realizes:dict[UOp, UOp],
                    reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
  """recursively search the uop for groupable children, realize the UOp if a child can't group"""
  visit_key = (tr, st)
  if visit_key in cache: return
  cache[visit_key] = None
  rsize = unwrap(allbufs[r].st).size
  if tr in realizes and tr is not r:
    # can only fuse contiguous
    # max one reduceop per kernel
    if not st.contiguous or st.size != rsize or tr in reduce_for_op:
      group.setdefault(r)
    return group.setdefault(tr)
  for tr_next in children[tr]:
    tr_next_uop = uval(allbufs[tr_next]).base
    # max one reduceop per kernel
    if tr_next_uop.op is Ops.REDUCE_AXIS: return group.setdefault(r)
    # can only fuse contiguous
    st_childs = dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)
    if len(st_childs) > 1: return group.setdefault(r)
    recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache)
def get_isolated_children(r:UOp, reduce_for_op:dict[UOp, UOp], children:defaultdict[UOp, dict[UOp, None]], allbufs:dict[UOp, UOp],
                          realizes:dict[UOp, UOp], group:dict[UOp, None]) -> dict[UOp, None]:
  """Extend `group` with descendants that can cleanly group.

  Returns an empty dict as soon as a REDUCE_AXIS is found among the parents of `group` (not followed
  through `r`).
  """
  rc_parents = deque(group)
  seen = set()
  while rc_parents:
    p = uval(allbufs[rc_parents.pop()])
    if p in seen: continue
    seen.add(p)
    # max one reduceop per kernel
    if p.op is Ops.REDUCE_AXIS: return {}
    rc_parents.extend(x.base.buf_uop for x in p.src if is_scheduled(x.base) and x.base.buf_uop is not r)
  # search descendants of the reduceop that can cleanly group
  descendants: dict[UOp, None] = {}
  for tr in group:
    recursive_group(tr, unwrap(allbufs[tr].st), tr, children, allbufs, realizes, reduce_for_op, descendants, cache={})
  # the descendants are dropped when any of them is already in the group
  extra = {} if any(tr in group for tr in descendants) else descendants
  return merge_dicts([group, extra])
def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]:
  """search the big graph for all the reduceops that need to realize, sometimes group/fuse the reduceop

  ctx.realizes is updated in place; the return value lists the realized BUFFER uops grouped per kernel.
  """
  # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
  reduce_for_op: dict[UOp, UOp] = {}
  reduce_of_const: list[UOp] = []
  double_reduces: list[UOp] = []
  for r, r_view in ctx.allbufs.items():
    r_uop = uval(r_view)
    if r_uop.op is not Ops.REDUCE_AXIS: continue
    if FUSE_CONV_BW:
      x = r_uop.src[0]
      if uval(x.base).op is r_uop.op and x.base is not x: double_reduces.append(r)
    if r in ctx.realizes: continue
    group: dict[UOp, None] = {}
    recursive_group(r, unwrap(r_uop.st), r, ctx.children, ctx.allbufs, ctx.realizes, reduce_for_op, group, cache={})
    # max one reduceop per kernel
    can_chase = all(tr not in reduce_for_op for tr in group)
    # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
    forced_realize = r in group
    if not forced_realize and len(group) > 1:
      group = get_isolated_children(r, reduce_for_op, ctx.children, ctx.allbufs, ctx.realizes, group)
    # can only fuse assign if no other assign_target is used in the kernel
    if not forced_realize and any(x in ctx.assigns for x in group):
      parents = deque((r, *group))
      while parents and not forced_realize:
        p = parents.pop()
        p_view = ctx.allbufs.get(p)
        if p_view is None: continue
        p_uop = uval(p_view)
        if p_uop.op is Ops.ASSIGN and p not in group:
          forced_realize, can_chase = True, False
        if p in ctx.realizes: continue
        parents.extend([x.base.src[0] for x in p_uop.src if x.base.op is Ops.VIEW and len(x.base.src) != 0])
    if forced_realize or not group:
      tr = r
      if can_chase:
        # can chase this down to contiguous children
        st = unwrap(r_uop.st)
        while len(ctx.children[tr]) == 1:
          tr_next = next(iter(ctx.children[tr]))
          tr_next_uop = uval(ctx.allbufs[tr_next])
          st_childs = dedup([unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop is tr])
          if len(st_childs) > 1: break
          if st.size != st_childs[0].size: break
          st = st + st_childs[0]
          if not st.contiguous or tr_next_uop.op is Ops.REDUCE_AXIS: break
          tr = tr_next
        # don't cast to higher size before store (tr cannot be realized if forced_realize)
        # NOTE(review): nesting under `if can_chase` was reconstructed from flattened source - confirm
        tr_uop = uval(ctx.allbufs[tr])
        if tr_uop.op is Ops.CAST and tr_uop.dtype.base.itemsize > tr_uop.src[0].dtype.base.itemsize:
          tr = tr_uop.src[0].base.buf_uop
      group = {tr: None}
      ctx.realizes[tr] = tr
    reduce_for_op.update((tr, r) for tr in group)
    if FUSE_ARANGE and r_uop.arg[0] is Ops.ADD and uval(r_uop.src[0].base).op is Ops.CONST:
      reduce_of_const.append(r)
  # fuse double reduces with no other child
  for reduceop in double_reduces:
    top_reduce = uval(ctx.allbufs[reduceop]).src[0].base.buf_uop
    if len(ctx.children[top_reduce]) == 1:
      del ctx.realizes[top_reduce]
  # maybe fuse arange with its children
  for rbuf in reduce_of_const:
    group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf}
    if any(luop.forced_realize for tr in group for luop in ctx.tensor_uops[tr]): continue
    kernel_children = {c for tr in group for c in ctx.children[tr] if uval(ctx.allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}}
    if len(kernel_children) == 0: continue
    for tr in group:
      del ctx.realizes[tr]
  # group BUFFER uops into kernels
  output_groups: defaultdict[UOp, list[UOp]] = defaultdict(list)
  for ubuf in ctx.realizes:
    output_groups[reduce_for_op.get(ubuf, ubuf)].append(ubuf)
  return list(output_groups.values())
# **** Schedule creation and BFS toposort
# ** ops in the big graph can either be pre-realized or scheduled (fused/realized)
class UPatScheduled(UPat):
  """Pattern for a VIEW named "base" over a BUFFER named "b" and an op named "to_store".

  Positional and keyword arguments describe the stored op; a `name` keyword replaces "to_store".
  """
  def __init__(self, *args, **kwargs):
    buffer_pat = UPat(Ops.BUFFER, name="b")
    to_store_kwargs = {"name":"to_store", **kwargs}
    super().__init__(Ops.VIEW, name="base", src=(buffer_pat, UPat(*args, **to_store_kwargs)))
# ** this is schedule level const folding
def _as_const(u:UOp, val:ConstType) -> UOp:
  """Replace the scheduled op `u` with a const of value `val`, kept on u's buffer and expanded to u's shape."""
  assert is_scheduled(u), f"must be scheduled to fold {u}"
  base = ShapeTracker.from_shape(())
  st = base.reshape((1,)*len(u.shape)).expand(u.shape)
  const_view = UOp(Ops.VIEW, u.dtype, (u.buf_uop, UOp.const(u.dtype, val)), base)
  return const_view.view(st)
def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
  """Fold an ADD, MUL or MAX reduce of an unrealized unmasked const into a const; returns None otherwise."""
  # remove reduce on unmasked const
  if not (all_int(x.shape) and x.is_unrealized_unmasked_const()): return None
  prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])
  ret = x.const_arg
  alu_op = reduce.arg[0]
  if alu_op == Ops.ADD: ret *= prshape
  elif alu_op == Ops.MUL: ret **= prshape
  elif alu_op == Ops.MAX: pass # NOTE: Ops.MAX is passthrough
  else: return None
  return UOp.const(reduce.dtype, ret)
def simplify_alu(alu:UOp):
  """Evaluate `alu` into a const when every source is an unrealized unmasked const; returns None otherwise."""
  for s in alu.src:
    if not s.is_unrealized_unmasked_const(): return None
  # this needs to have a VIEW next (it has to, right?)
  folded = exec_alu(alu.op, alu.dtype, [s.const_arg for s in alu.src])
  return UOp.const(alu.dtype, folded)
def simplify_binop(binop:UOp, x:UOp, y:UOp):
  """Fold x+0, x*1, x*0 and x//1 when one operand is an unrealized unmasked const; returns None otherwise."""
  if all_int(x.shape) and x.is_unrealized_unmasked_const():
    other, const = y, x
  elif all_int(y.shape) and y.is_unrealized_unmasked_const():
    # only a const on the right hand side folds a division by one
    if binop.op is Ops.IDIV and y.const_arg == 1: return x
    other, const = x, y
  else:
    return None
  if binop.op is Ops.ADD:
    return other if const.const_arg == 0 else None
  if binop.op is Ops.MUL:
    if const.const_arg == 1: return other
    if const.const_arg == 0: return UOp.const(binop.dtype, 0)
  return None
def found_contiguous(ctx:ScheduleContext, contig:UOp, base:UOp, b:UOp):
  """Remember in ctx.contiguous that `base` is a contiguous version of a view of some VIEW.

  Nothing is recorded unless the source of `contig` is a VIEW of a VIEW whose ShapeTracker can be inverted.
  """
  src_view = contig.src[0]
  if src_view.op is not Ops.VIEW or not len(src_view.src): return
  old_base = src_view.src[0]
  if old_base.op is not Ops.VIEW: return
  sti = unwrap(src_view.st).invert(old_base.shape)
  if sti is not None: ctx.contiguous[old_base] = base.view(sti)
def replace_contiguous(ctx:ScheduleContext, alu:UOp):
  """Swap every source of `alu` that has an entry in ctx.contiguous; returns None when nothing changes."""
  new_src = tuple(repl if (repl:=ctx.contiguous.get(s, None)) is not None else s for s in alu.src)
  if new_src != alu.src: return alu.replace(src=new_src)
  return None
ops_folding = PatternMatcher([
  # op with size 0 is zero
  (
    UPatScheduled(),
    lambda b,to_store,base: _as_const(base, 0) if base.size == 0 else None,
  ),
  # DETACH is a NOOP here
  (
    UPat(Ops.DETACH, name="detach"),
    lambda detach: detach.src[0],
  ),
  # elementwise const folding
  (UPat(GroupOp.ALU, name="alu"), simplify_alu),
  (UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, name="binop", src=(UPat.var("x"), UPat.var("y"))), simplify_binop),
  (
    UPat(Ops.CAST, src=(UPat.var("x"),), name="cast"),
    lambda x,cast: UOp.const(cast.dtype, x.const_arg) if all_int(x.shape) and x.is_unrealized_unmasked_const() else None,
  ),
  # reduce of size 0 is the identity element
  (
    UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
    lambda reduce,x:UOp.const(reduce.dtype, identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None,
  ),
  # reduce of const is collapsed (TODO: make this a generic rule for stride0)
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_reduceop),
  # CONST doesn't need COPY
  (
    UPat(Ops.COPY, src=(UPat.var("x"),)),
    lambda x:x if x.is_unrealized_const() else None,
  ),
  # no double COPY
  (
    UPat(Ops.COPY, src=(UPat(Ops.VIEW, src=(UPat(), UPat(Ops.COPY, name="base")),))),
    lambda base: base,
  ),
  # no COPY to same device, except clone (arg is True)
  (
    UPatScheduled(Ops.COPY, src=UPat(Ops.VIEW, name="copyin"), name="copy"),
    lambda base,b,copyin,copy: copyin if base.device == copy.device and copy.arg[1] is not True else None,
  ),
  # support for using a contiguous permuted view instead of the parent view if one exists
  (UPatScheduled(Ops.CONTIGUOUS, name="contig"), found_contiguous),
  (UPat(GroupOp.ALU, name="alu"), replace_contiguous),
])
# ** buffer merging
def merge(ctx:ScheduleContext, v1:UOp, b1:UOp, v2:UOp, b2:UOp) -> UOp:
  """Merge buffer `b2` into `b1` and return the inner view `v1`.

  Both views must have the same ShapeTracker, otherwise the assertion fails.
  """
  assert v1.st is not None and v2.st is not None and v1.st == v2.st, f"implicit movementop {v1.st} {v2.st}"
  # if b2 is realized also realize b1
  if b2 in ctx.realizes:
    ctx.realizes[b1] = b1
    del ctx.realizes[b2]
  # ops referring to b2 now ref to b1
  b1_uops = ctx.tensor_uops[b1]
  b1_uops += ctx.tensor_uops[b2]
  ctx.tensor_uops[b1] = b1_uops
  del ctx.tensor_uops[b2]
  # merge
  return v1
def merge_realized(ctx:ScheduleContext, v1:UOp, b1:UOp, v2:UOp, b2:UOp):
  """Make every tensor uop tracked for `b1` or `b2` become a view of `b1`, then return `v1`."""
  # early become
  tracked = ctx.tensor_uops.get(b1, [])+ctx.tensor_uops.get(b2, [])
  for luop in tracked:
    luop.become(b1.view(unwrap(luop.st)))
  return v1
merge_bufs = PatternMatcher([
  # merge base
  (
    UPat(Ops.VIEW, name="v2", src=(UPat(Ops.BUFFER, name="b2"), UPat(Ops.VIEW, name="v1", src=(UPat(Ops.BUFFER, name="b1"), UPat())))),
    merge,
  ),
  (
    UPat(Ops.VIEW, name="v2", src=(UPat(Ops.BUFFER, name="b2"), UPat(Ops.VIEW, name="v1", src=(UPat(Ops.BUFFER, name="b1"),)))),
    merge_realized,
  ),
  # collapse view
  (
    UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat())).view(name="mv"))),
    lambda mv:mv,
  ),
  (
    UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).view(name="mv"))),
    lambda mv:mv,
  ),
])
# ** this decides which ops get realized
def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None:
  """Mark buffer `b` as realized with `to_store`, unless `to_store` is a CONST or a BIND."""
  if to_store.op in {Ops.CONST, Ops.BIND}: return
  ctx.realizes[b] = to_store
def realize_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
  """Realize `src` before `view` when the view expands it or pads it in a way can_pad rejects."""
  if src.st is None: return None
  st = unwrap(view.st)
  # fold simple pads
  if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
    if not can_pad(src, ctx.realizes, set()): realize(ctx, b, src)
    return None
  # early realize before expand
  if resolve(prod(src.shape) < prod(st.shape)):
    realize(ctx, b, src)
    return None
  # otherwise safety check pads
  if not (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set())):
    realize(ctx, b, src)
  return None
def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kwargs) -> UOp|None:
  """Drop the realize of a cast whose source buffer `xb` has an image dtype and view the source directly.

  Returns None unless both `b` and `xb` are realized and the source op is not a meta op.
  """
  if not isinstance(xb.dtype, ImageDType): return None
  if b not in ctx.realizes or xb not in ctx.realizes: return None
  if uval(to_cast).op in GroupOp.Meta: return None
  del ctx.realizes[b]
  return to_cast.view(unwrap(view.st))
def init_big_graph(sink:UOp) -> UOp|None:
  """Reduce the sources of `sink` to the bases of scheduled non-CONST ops.

  Returns None when that changes nothing, a NOOP when no source is left, else the new sink.
  """
  new_src = tuple(x.base for x in sink.src if is_scheduled(x.base) and x.base.src[1].op is not Ops.CONST)
  if new_src == sink.src: return None
  if len(new_src) == 0: return UOp(Ops.NOOP)
  return UOp.sink(*new_src)
do_realize = PatternMatcher([
  # always realize sinked ops
  (UPat(Ops.SINK, name="sink"), init_big_graph),
  # always realize meta ops
  (UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}), realize),
  # realize before expand or unsafe pad ops
  (UPatScheduled(name="src").view(name="view"), realize_view),
  # don't realize image to image casts
  (
    UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="to_cast"),), dtype=dtypes.float).view(name="view"),
    fold_img_cast,
  ),
  # realize before COPY or BUFFER_VIEW
  (
    UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)),
    realize,
  ),
  # ASSIGN only needs the buffer
  (
    UPat(Ops.ASSIGN, src=(UPat(Ops.VIEW, name="dest"), UPat.var("src")), name="x"),
    lambda dest,src,x: x.replace(src=(dest.base.buf_uop, src)),
  ),
])
# ** this breaks down realized ops into STOREs and rewrites the ops to LOADs
def generate_valid(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
  """Build a const with the shape of `base` from a CONST or BIND; the unbound value of a BIND goes into ctx.var_vals."""
  if to_store.op is Ops.BIND:
    ctx.var_vals.update([to_store.unbind()])
    const_val = to_store
  else:
    const_val = to_store.arg
  return UOp.const_with_shape(base.dtype, const_val, unwrap(base.st).shape)
def append_realize(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
  """Put a STORE of `to_store` into ctx.realizes[b] and return a LOAD of `b` with the ShapeTracker of `base`."""
  store_st = ShapeTracker.from_shape(base.shape).to_uop()
  ctx.realizes[b] = UOp.store(b, store_st, append_op(ctx, b, to_store))
  return UOp(Ops.LOAD, base.dtype, (b, unwrap(base.st).to_uop()))
def append_op(ctx:ScheduleContext, b:UOp, to_store:UOp) -> UOp:
  """Copy the metadata of the first tensor uop tracked for `b` onto `to_store`, then return `to_store`."""
  # TODO: metadata post merge
  m = ctx.tensor_uops[b][0].metadata
  if m is not None:
    ctx.ops_metadata[to_store] = m
  return to_store
break_sched = PatternMatcher([
  # consts are always fused and generated
  (UPatScheduled({Ops.CONST, Ops.BIND}), generate_valid),
  # view of realized buffer just loads
  (
    UPat(Ops.BUFFER, name="b").view(name="v"),
    lambda ctx,b,v: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, v.st.to_uop())),
  ),
  # all other views either fold or realize with a store
  (
    UPatScheduled(),
    lambda ctx,b,to_store,base: append_realize(ctx, b, to_store, base) if b in ctx.realizes else append_op(ctx, b, to_store),
  ),
])
# **** Schedule context builder
def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
  """Register the scheduled `view` of `buf_uop` in the context: allbufs, assigns and children."""
  ctx.allbufs[buf_uop] = view
  op = uval(view)
  if op.op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
  # this buffer is a child of every scheduled source
  for x in op.src:
    if not is_scheduled(x.base): continue
    ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
  # BUFFER_VIEW overrides the underlying buffer
  if op.op is Ops.BUFFER_VIEW:
    x = op.src[0]
    buffers[buf_uop] = x.base.buffer.view(view.size, view.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
  # NOTE(review): placed at function level (every scheduled buffer), reconstructed from flattened source - confirm
  buf_uop.buffer.ref(1)
# registers every VIEW of a BUFFER and an op through append_uop
create_ctx = PatternMatcher([
  (UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop),
])

# every movement op becomes a view of its base
remove_movement_ops = PatternMatcher([
  (UPat(GroupOp.Movement, name="x"), lambda x: x.base.view(unwrap(x.st))),
])
@track_rewrites(named=True)
def create_schedule_with_vars(outs:list[UOp]) -> tuple[list[ScheduleItem], dict[Variable, int]]:
  """Create the ordered list of ScheduleItems for `outs`, plus the variable values collected on the way.

  Raises RuntimeError when not every prescheduled item could be ordered (cycle in the graph).
  """
  outs = dedup(x.base for x in outs if x.base.realized is None and x.base.op is not Ops.CONST)
  if len(outs) == 0: return [], {}
  # create the big graph
  ctx = ScheduleContext()
  cache: dict[UOp, UOp] = {}
  # to_uop is removing (many) of the movement ops
  big_graph = UOp.sink(*(to_uop(x, ctx, cache) for x in outs))
  for u in big_graph.src:
    ctx.realizes[u.buf_uop] = u
  big_graph = graph_rewrite(big_graph, remove_movement_ops+ops_folding+do_realize, ctx)
  big_graph = graph_rewrite(big_graph, merge_bufs, ctx)
  # create the scheduler context
  graph_rewrite(big_graph, create_ctx, ctx)
  # group realizes into kernels
  store_groups = group_realizes(ctx)
  graph_rewrite(big_graph, break_sched, ctx)
  # preschedule realize groups
  prescheduled: list[ScheduleItem] = []
  for store_uops in store_groups:
    stores = [ctx.realizes[u] for u in store_uops if ctx.realizes[u].op is Ops.STORE]
    if len(stores) != 0:
      prescheduled.append(schedule_uop(UOp.sink(*stores), ctx))
  # do BFS
  schedule_targets = {out:si for si in prescheduled for out in si.outputs}
  graph: defaultdict[ScheduleItem, list[ScheduleItem]] = defaultdict(list)
  in_degree: defaultdict[ScheduleItem, int] = defaultdict(int)
  for si in prescheduled:
    # realize outputs before a parent is assigned to
    parents_assigns = dedup(xsi for x in si.assign_preloads if (xsi:=schedule_targets.get(x.buffer)) and xsi is not si)
    for assign in parents_assigns:
      graph[si].append(assign)
      in_degree[assign] += 1
    # realize outputs after all parents are realized
    scheduled_parents = dedup(xsi for x in si.inputs if (xsi:=schedule_targets.get(x)) is not None and xsi not in parents_assigns)
    for parent in scheduled_parents:
      graph[parent].append(si)
      in_degree[si] += 1
  queue = deque(si for si in prescheduled if in_degree[si] == 0)
  schedule: list[ScheduleItem] = []
  while queue:
    si = queue.popleft()
    schedule.append(si)
    for child in graph[si]:
      in_degree[child] -= 1
      if in_degree[child] == 0: queue.append(child)
  # confirm everything was scheduled correctly
  groups = len(prescheduled)
  if len(schedule) != groups:
    raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}")
  if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
  return schedule, ctx.var_vals
def create_schedule(outs:list[UOp]) -> list[ScheduleItem]:
  """Schedule `outs` and return only the ScheduleItems; asserts that no variable values were collected."""
  result = create_schedule_with_vars(outs)
  schedule, var_vals = result
  assert len(var_vals) == 0
  return schedule