You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
101 lines
4.1 KiB
101 lines
4.1 KiB
import heapq
|
|
from typing import Any
|
|
from collections import defaultdict
|
|
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat, multirange_str
|
|
from tinygrad.helpers import prod, getenv, TUPLE_ORDER
|
|
|
|
def linearize(sink:UOp) -> list[UOp]:
|
|
# this is a toposort with priority
|
|
lst = list(sink.toposort())
|
|
consumers: defaultdict[UOp, list[UOp]] = defaultdict(list)
|
|
in_degree:dict[UOp, int] = {}
|
|
out_degree:dict[UOp, int] = {}
|
|
priorities:dict[UOp, tuple[int, int, Any]] = {}
|
|
|
|
# get consumers and assign priorities
|
|
# NOTE: this requires the lst be locally toposorted
|
|
for u in reversed(lst):
|
|
for s in u.src: consumers[s].append(u)
|
|
in_degree[u] = len(u.src)
|
|
out_degree[u] = len(consumers[u])
|
|
|
|
# we place UOps with higher run_counts later
|
|
run_count = prod([int(r.vmax)+1 for r in u.ranges])
|
|
|
|
# simple priority override. this is all bottom up now, smaller numbers will be closer to the top
|
|
extra = None
|
|
match u.op:
|
|
# the order and placement of these defines is important
|
|
case Ops.DEFINE_GLOBAL: priority, extra = -20, u.arg
|
|
case Ops.DEFINE_VAR: priority, extra = -19, u.arg
|
|
case Ops.DEFINE_LOCAL: priority = -18
|
|
case Ops.DEFINE_REG: priority = -17
|
|
case Ops.CONST: priority = -10 # early consts
|
|
case Ops.LOAD: priority = -1 # place loads early
|
|
case Ops.STORE: priority = 1 # place stores late
|
|
case Ops.RANGE: priority = 5 # placing RANGE is good
|
|
case Ops.END: priority = -5 # placing END is bad
|
|
case _: priority = 0 # everything else has priority 0
|
|
priorities[u] = (run_count, priority, extra)
|
|
|
|
# number the uops in "ideal" order
|
|
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: priorities[x]+(x.tuplize if TUPLE_ORDER else ())))}
|
|
|
|
# then force them to be toposorted in as close to the ideal order as possible
|
|
heap = [(-nkey[sink], sink)]
|
|
newlst = []
|
|
while heap:
|
|
newlst.append(u:=heapq.heappop(heap)[1])
|
|
for v in u.src:
|
|
out_degree[v] -= 1
|
|
if out_degree[v] == 0: heapq.heappush(heap, (-nkey[v],v))
|
|
newlst = newlst[::-1]
|
|
|
|
if getenv("DEBUG_LINEARIZE"):
|
|
for i,u in enumerate(newlst):
|
|
print(f"{i:4d} {str(u.op):20s} {multirange_str(u.ranges, color=True, pad=10)} {priorities[u]}")
|
|
return newlst
|
|
|
|
class CFGContext:
|
|
def __init__(self, sink:UOp):
|
|
# there are 3 relationships between ranges:
|
|
# nested, meaning endrange y is a dependency of endrange x and range x is a dependency of endrange y
|
|
# dependent, meaning endrange y is a dependency of endrange x and range x is not a dependency of endrange y
|
|
# independent, endrange y is not a dependency of endrange x
|
|
# everything is nested inside the sink
|
|
deps: dict[UOp, dict[UOp, None]] = {}
|
|
nesting: dict[UOp, UOp] = {}
|
|
for u in sink.toposort():
|
|
# get the deps from the src
|
|
deps[u] = {}
|
|
for s in u.src: deps[u] |= deps[s]
|
|
|
|
if u.op in (Ops.END, Ops.SINK):
|
|
nesting |= {x:u for x in deps[u] if x.op is Ops.END and (u.op is Ops.SINK or u.src[1] in deps[x]) and x not in nesting}
|
|
if u.op in (Ops.RANGE, Ops.END): deps[u][u] = None
|
|
|
|
self.edges: dict[UOp, UOp] = {}
|
|
siblings: dict[UOp, list[UOp]] = {}
|
|
for k,vv in nesting.items(): siblings.setdefault(vv, []).append(k)
|
|
for k,v in siblings.items():
|
|
# ranges that have dependencies on other siblings need to be scheduled after them
|
|
order = sorted(v, key=lambda x: len([u for u in v if u in deps[x]]))
|
|
zipped = zip(order, order[1:]) if k.op is Ops.SINK else zip([k.src[1]] + order, order)
|
|
for x,y in zipped:
|
|
# TODO: this can happen! it causes infinite loop in shufflenet
|
|
assert y.src[1] not in x.backward_slice_with_self
|
|
self.edges[y.src[1]] = x
|
|
|
|
pm_add_control_flow = PatternMatcher([
|
|
(UPat(Ops.RANGE, name="x"), lambda ctx,x: x.replace(src=x.src+(y,)) if (y:=ctx.edges.get(x)) is not None else None),
|
|
])
|
|
|
|
def do_split_ends(e:UOp):
|
|
ret = e.src[0]
|
|
for r in sorted(UOp.sink(*e.src[1:]).ranges, key=lambda x: x.arg, reverse=True): ret = ret.end(r)
|
|
return ret
|
|
|
|
pm_split_ends = PatternMatcher([
|
|
# split the ends
|
|
(UPat(Ops.END, name="e"), do_split_ends),
|
|
]) |