You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
118 lines
5.8 KiB
118 lines
5.8 KiB
from tinygrad.uop.ops import Ops, UOp, resolve, can_pad, GroupOp, UPat, PatternMatcher, graph_rewrite
|
|
from tinygrad.helpers import all_int, prod, unwrap, dedup, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, FUSE_CONV_BW
|
|
from tinygrad.shape.shapetracker import ShapeTracker
|
|
|
|
# Ops the grouper never marks for realization on behalf of a consumer: they are skipped when realizing SINK parents
# and COPY/MSELECT/MSTACK parents, and a VIEW directly over one of them is not considered by realize_before_view.
ALWAYS_CONTIGUOUS: set[Ops] = {
  Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
  Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.GBARRIER,
}
|
|
|
|
# **** Grouper: decides which UOps get realized
|
|
|
|
def realize(ctx:dict[UOp, None], tr:UOp) -> None:
  """Add `tr` to the realize set `ctx` (a dict used as an insertion-ordered set; values are always None)."""
  ctx[tr] = None
|
|
|
|
def realize_parents(ctx:dict[UOp, None], rb:UOp) -> None:
  """Add every direct source of `rb` to the realize set `ctx`, except sources whose op is in ALWAYS_CONTIGUOUS.

  Sources are added in `rb.src` order. Always returns None.
  """
  ctx.update((src, None) for src in rb.src if src.op not in ALWAYS_CONTIGUOUS)
|
|
|
|
def realize_before_view(ctx:dict[UOp, None], view:UOp, tr:UOp) -> None:
  """Decide whether `tr` must be realized before being read through the VIEW `view`; records it in `ctx` if so.

  `tr` is realized when:
    * the view is masked and `can_pad(tr, ctx)` says padding `tr` is unsafe, or
    * the view has more elements than `tr` (an expand) and DONT_REALIZE_EXPAND is not set.
  A single masked view whose valid region has no more elements than `tr` ("simple pad") is left fused.
  Always returns None.
  """
  st = unwrap(view.st)
  masked = any(v.mask is not None for v in st.views)
  # always realize unsafe pad ops before masked view (can_pad is only consulted when a mask exists)
  if masked and not can_pad(tr, ctx):
    realize(ctx, tr)
    return
  # fold simple pads: one masked view whose valid area fits inside tr
  if len(st.views) == 1:
    valid = st.views[-1].mask
    if valid is not None and all_int(tr.shape) and resolve(prod(tr.shape) >= prod([end-start for start,end in valid])): return
  # realize before expand
  if resolve(prod(tr.shape) < prod(st.shape)) and not DONT_REALIZE_EXPAND: realize(ctx, tr)
|
|
|
|
def realize_sink_parents(ctx:dict[UOp, None], s:UOp) -> None:
  """Add the base of each SINK source to `ctx`, except bases whose op is in ALWAYS_CONTIGUOUS."""
  ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)

# Realization rules applied by group_realizes through graph_rewrite(..., ctx=realizes).
# Every callback only records UOps into ctx and returns None.
do_realize = PatternMatcher([
  # always realize SINK parents
  (UPat(Ops.SINK, name="s"), realize_sink_parents),
  # always realize ASSIGN/CONTIGUOUS/GroupOp.Meta
  (UPat({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}, name="tr"), realize),
  # realize before expand or unsafe pad ops (only for a VIEW whose source is not ALWAYS_CONTIGUOUS)
  (UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), name="view"), realize_before_view),
  # realize parents of COPY, MSELECT, MSTACK
  (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_parents),
])
|
|
|
|
def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:dict[UOp, dict[UOp, None]], realizes:dict[UOp, None],
                    reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
  """Walk forward from `tr` through `children`, collecting into `group` the realized uops the reduce `r` reaches.

  `st` is the ShapeTracker accumulated along the path from `r` to `tr` (the walk starts with `unwrap(r.st)` and
  composes each child's view onto it). When the walk finds fusion impossible, `r` itself is put in `group`. That
  happens on a non-contiguous or size-mismatched path into a realized uop, a realized uop already paired with
  another reduce, a REDUCE_AXIS child, or a child reading `tr` through more than one distinct ShapeTracker.
  group_realizes treats `r in group` as "force realize". `cache` records (tr, st) pairs already visited.
  """
  visit_key = (tr, st)
  if visit_key in cache: return
  cache[visit_key] = None
  reduce_size = unwrap(r.st).size
  if tr in realizes and tr is not r:
    # can only fuse contiguous; max one reduceop per kernel
    if not st.contiguous or st.size != reduce_size or tr in reduce_for_op: group.setdefault(r)
    group.setdefault(tr)
    return
  for child in children.get(tr, {}):
    # max one reduceop per kernel
    if child.op is Ops.REDUCE_AXIS:
      group.setdefault(r)
      return
    # can only fuse contiguous: all of child's reads of tr must use the same view
    child_sts = dedup(unwrap(x.st) for x in child.src if x.base == tr)
    if len(child_sts) > 1:
      group.setdefault(r)
      return
    recursive_group(child, st+child_sts[0], r, children, realizes, reduce_for_op, group, cache)
|
|
|
|
def group_realizes(sink:UOp) -> dict[UOp, None]:
  """Return the UOps under `sink` that must be realized, as an insertion-ordered dict used as a set.

  First applies the `do_realize` rules. Unless DONT_GROUP_REDUCES is set, it then pairs each REDUCE_AXIS with the
  realized uop it can fuse into (see recursive_group). When no clean pairing exists, it realizes the reduce itself,
  or a contiguous single-consumer descendant reached by chasing children.
  """
  # start by adding uops that always realize
  realizes: dict[UOp, None] = {}
  sink = graph_rewrite(sink, do_realize, ctx=realizes, name="do_realize")
  if DONT_GROUP_REDUCES: return realizes

  # construct children graph (only for bases): children[base] holds the non-VIEW, non-SINK uops that read it
  children: dict[UOp, dict[UOp, None]] = {}
  assigns: dict[UOp, None] = {}  # buf_uop of every ASSIGN in the graph
  for u in (toposort:=sink.toposort()):
    if u.op in {Ops.VIEW, Ops.SINK}: continue
    if u.op is Ops.ASSIGN: assigns[u.buf_uop] = None
    for s in u.src: children.setdefault(s.base, {})[u] = None

  # find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
  reduce_for_op: dict[UOp, UOp] = {}  # grouped uop -> the reduce it was paired with
  double_reduces: list[UOp] = []
  for r in toposort:
    if r.op is not Ops.REDUCE_AXIS: continue
    # reduces whose arg carries a True third element are skipped entirely here
    if len(r.arg) == 3 and r.arg[2] is True: continue
    # with FUSE_CONV_BW, remember reduces that read another reduce through a VIEW (handled after the loop)
    if FUSE_CONV_BW and r.src[0].base.op is Ops.REDUCE_AXIS and r.src[0] is not r.src[0].base: double_reduces.append(r)
    if r in realizes: continue
    group: dict[UOp, None] = {}
    recursive_group(r, unwrap(r.st), r, children, realizes, reduce_for_op, group, cache={})
    # max one reduceop per kernel
    can_chase = all(tr not in reduce_for_op for tr in group)
    # don't chase if a REDUCE_AXIS reached from r (traversal gated on uops not yet realized) reduces a CONST base
    for u in r.toposort(gate=lambda u: u not in realizes):
      if u.op is Ops.REDUCE_AXIS and u.src[0].base.op is Ops.CONST:
        can_chase = False
        break
    # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
    forced_realize = r in group
    # can only have one output
    if not forced_realize and len(group) > 1: forced_realize = True
    # can only fuse assign if no other assign_target is used in the kernel
    if not forced_realize and (assign_targets:={x.buf_uop for x in group if x.op is Ops.ASSIGN}):
      # walk parents of r and the group, stopping below realized uops; `seen` avoids re-walking shared subgraphs
      parents = [r, *group]
      seen: set[UOp] = set()
      while parents and not forced_realize:
        p = parents.pop().base
        if p in seen: continue
        seen.add(p)
        if p.op is Ops.BUFFER and p in assigns and p not in assign_targets: forced_realize, can_chase = True, False
        if p in realizes: continue
        parents.extend(p.src)
    if forced_realize or not group:
      tr = r
      if can_chase:
        # can chase this down to contiguous children
        st = unwrap(tr.st)
        while len(lst:=children.get(tr, {})) == 1:
          tr_next = next(iter(lst))
          st_childs = dedup(unwrap(s.st) for s in tr_next.src if s.base is tr)
          if len(st_childs) > 1: break
          if st.size != st_childs[0].size: break
          st = st + st_childs[0]
          if not st.contiguous or tr_next.op is Ops.REDUCE_AXIS: break
          tr = tr_next
        # don't cast to higher size before store (tr cannot be realized if forced_realize)
        if tr.op is Ops.CAST and tr.dtype.itemsize > tr.src[0].dtype.itemsize:
          tr = tr.src[0].base
      group = {tr: None}
      realizes[tr] = None
    reduce_for_op.update((tr, r) for tr in group)

  # fuse double reduces with no other child
  for reduceop in double_reduces:
    top_reduce = reduceop.src[0].base
    # pop instead of del: top_reduce may never have been added to realizes (e.g. it hit the arg[2] skip above)
    if len(children.get(top_reduce, {})) == 1: realizes.pop(top_reduce, None)
  return realizes
|
|
|