You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
245 lines
10 KiB
245 lines
10 KiB
from __future__ import annotations
|
|
import heapq
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass, replace
|
|
from tinygrad.ops import UOp, Ops, PatternMatcher, UPat, GroupOp
|
|
from tinygrad.helpers import dedup, partition, all_same, flatten
|
|
from tinygrad.spec import type_verify
|
|
|
|
# NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
|
|
def block_reorder(lst:list[UOp]) -> list[UOp]:
  """Return the UOps of `lst` re-sequenced into a priority-guided topological order.

  Only edges between members of `lst` are considered. `lst` must already be locally toposorted
  (each UOp after the members it uses), because priorities flow from users back to their sources.
  """
  members = set(lst)
  users: defaultdict[UOp, list[UOp]] = defaultdict(list)
  pending: dict[UOp, int] = {}
  prio: dict[UOp, int] = {}

  # walk users-first so that every local user already has a priority when its source is reached
  for node in lst[::-1]:
    local_srcs = [s for s in node.src if s in members]
    for s in local_srcs: users[s].append(node)
    pending[node] = len(local_srcs)
    # a UOp is never less urgent than any of its local users (prevents priority inversion)
    best = min([0] + [prio[c] for c in users[node]])
    # loads go to the start of the block, barriers are pulled even earlier (hack for BARRIER grouping)
    if node.op is Ops.LOAD: best = min(best, -1000)
    if node.op is Ops.BARRIER: best = min(best, -1500)
    prio[node] = best

  # rank of every UOp in the "ideal" order: by priority first, then by tuplize
  ideal = sorted(lst, key=lambda x: (prio[x],)+x.tuplize)
  nkey = {node: rank for rank, node in enumerate(ideal)}

  # emit a real toposort, always picking the ready UOp with the lowest ideal rank
  ready = [(nkey[node], node) for node in lst if pending[node] == 0]
  heapq.heapify(ready)
  ordered: list[UOp] = []
  while ready:
    _, node = heapq.heappop(ready)
    ordered.append(node)
    for child in users[node]:
      pending[child] -= 1
      if pending[child] == 0: heapq.heappush(ready, (nkey[child], child))

  assert len(ordered) == len(lst), f"len mismatch {len(ordered)} != {len(lst)}"
  return ordered
|
|
|
|
# ***** basic block *****
|
|
|
|
def disp(y:UOp) -> str:
  """Short debug label for `y`: `IF<id(y)>` for an IF, `str(y.arg)` for a RANGE, `<NONE>` for anything else."""
  if y.op is Ops.RANGE: return str(y.arg)
  return f'IF{id(y)}' if y.op is Ops.IF else "<NONE>"
|
|
|
|
@dataclass(frozen=True, eq=False)
class BasicBlock2:
  """Immutable payload stored in the `arg` of the block UOps built in this file.

  `eq=False` keeps identity-based equality and hashing; ordering is explicitly forbidden.
  """
  # the UOps held by the block, in order
  lst: tuple[UOp, ...]
  # context of the block itself
  ctx: tuple[UOp, ...] = ()
  # optional UOp shown first in the repr; add_blockends sets it to the RANGE/IF being closed
  end: UOp|None = None
  # counter printed as `f<cnt>` in the repr
  cnt: int = 0
  # when not None, returned by last_ctx() instead of ctx
  child_ctx: tuple[UOp, ...]|None = None

  def __lt__(self, _:BasicBlock2): raise RuntimeError("no comparing basic blocks")

  def __repr__(self):
    # header: "[<end> ]f<cnt> <ctx labels> <child ctx labels or -> <len>", then one op per line
    end_part = "" if self.end is None else f"{disp(self.end)} "
    ctx_part = [disp(y) for y in self.ctx]
    child_part = "-" if self.child_ctx is None else [disp(y) for y in self.child_ctx]
    header = f"{end_part}f{self.cnt} {ctx_part} {child_part} {len(self.lst)}"
    return header + "\n" + "\n".join(str(x.op) for x in self.lst)

  def last_ctx(self):
    """Return `child_ctx` when it is set (even if empty), otherwise `ctx`."""
    if self.child_ctx is None: return self.ctx
    return self.child_ctx
|
|
|
|
def _sort_ctx(inp):
  """Deduplicate `inp` and return its members as a tuple sorted by their `tuplize` key."""
  unique = dedup(inp)
  return tuple(sorted(unique, key=lambda u: u.tuplize))
|
|
|
|
# ***** block context *****
|
|
|
|
@dataclass
class BlockContext:
  """Per-UOp bookkeeping computed once from the sink and consulted while blocks are built."""
  # number of times each UOp appears in the src of another UOp
  child_count: dict[UOp, int]
  # context each UOp itself lives in
  block_ctxs: dict[UOp, tuple[UOp, ...]]
  # context handed to the users of a UOp, only present when it differs from the UOp's own
  child_ctxs: dict[UOp, tuple[UOp, ...]]

  def last_ctx(self, u):
    """Context seen by users of `u`: its child ctx if one was recorded, else its block ctx."""
    child = self.child_ctxs.get(u)
    return self.block_ctxs[u] if child is None else child

  @staticmethod
  def from_sink(sink:UOp) -> BlockContext:
    """Walk `sink.toposort()` and fill in child counts and contexts for every UOp reached."""
    ret = BlockContext({}, {}, {})
    for u in sink.toposort():
      ret.child_count[u] = 0

      # a UOp's context is the union of what its sources pass on
      inherited: list[UOp] = []
      for parent in u.src:
        # NOTE: a parent listed several times in src is counted once per occurrence
        ret.child_count[parent] += 1
        inherited.extend(ret.last_ctx(parent))
      ret.block_ctxs[u] = _sort_ctx(inherited)

      if u.op in (Ops.RANGE, Ops.IF):
        # RANGE/IF open a scope: their users additionally live inside u
        ret.child_ctxs[u] = _sort_ctx(ret.block_ctxs[u] + (u,))
      elif u.op is Ops.STORE:
        # STORE closes scopes: by default nothing is passed on
        passed_on: tuple[UOp, ...] = ()
        # ugh, deal with non-reduce locals. probably wrong
        if any(p.op is Ops.DEFINE_LOCAL for p in u.src[0].toposort()):
          idx_context, store_context = ret.last_ctx(u.src[0]), ret.last_ctx(u.src[1])
          # keep the RANGEs that only the stored value (not the index) depends on
          passed_on = tuple(y for y in store_context if y not in idx_context and y.op is Ops.RANGE)
        ret.child_ctxs[u] = passed_on
      elif u.op is Ops.ASSIGN:
        assert u.src[0].op is Ops.DEFINE_ACC
        # drop whatever the DEFINE_ACC lists after its first source
        closed = u.src[0].src[1:]
        ret.child_ctxs[u] = tuple(y for y in ret.last_ctx(u.src[1]) if y not in closed)
    return ret
|
|
|
|
# ***** make blocks *****
|
|
|
|
# ops that are never put inside a block's lst; finalize emits them ahead of the block contents
DONT_PLACE_IN_BLOCK = {
  Ops.DEFINE_GLOBAL,
  Ops.DEFINE_LOCAL,
  Ops.DEFINE_VAR,
  Ops.SPECIAL,
  Ops.CONST,
}
|
|
|
|
def add_blockends(base_block:UOp, new_ctx:tuple[UOp, ...], current_ctx:tuple[UOp, ...], cnt:int=1) -> UOp:
  """Wrap `base_block` in one BLOCKEND per member of `new_ctx` that is missing from `current_ctx`.

  Members are closed starting from the last one in `new_ctx`. Each BLOCKEND holds a single ENDIF
  (for an IF) or ENDRANGE (otherwise), lists the wrapped UOp `cnt` times as its src, and carries the
  context that remains after the closed member is removed. Returns `base_block` unchanged when
  nothing needs closing.
  """
  to_close = [z for z in new_ctx if z not in current_ctx]
  wrapped, remaining = base_block, new_ctx
  for r in reversed(to_close):
    remaining = tuple(z for z in remaining if z is not r)
    closer = UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,))
    wrapped = UOp(Ops.BLOCKEND, src=(wrapped,)*cnt, arg=BasicBlock2((closer,), remaining, end=r, cnt=cnt))
  return wrapped
|
|
|
|
def make_block_bottom_up(ctx:BlockContext, x:UOp):
  """Build the BLOCK UOp for `x`, absorbing every source that can share its block.

  A source is absorbed once all of its uses have been reached from inside this block and its block
  context equals the current one. Sources with a different context seed new BLOCKSTARTs, and
  everything else is left as a source of the returned BLOCK. Sources from outer contexts are
  wrapped in BLOCKENDs by add_blockends.
  """
  if x.op is Ops.BLOCKSTART:
    # a BLOCKSTART carries (block ctx, child ctx) in arg and its seed UOps in src
    current_ctx, child_ctx = x.arg
    lst, child_count = list(x.src), 1
  else:
    current_ctx = ctx.block_ctxs[x]
    child_count = ctx.child_count[x]
    child_ctx = ctx.child_ctxs.get(x)
    lst = [x]

  # visits so far of each UOp that hasn't been placed; whatever is left becomes a src of the BLOCK
  visits: defaultdict[UOp, int] = defaultdict(int)
  # UOps that start other blocks, grouped by (block ctx, child ctx)
  seeds_by_ctx = defaultdict(list)

  # FIFO walk over the sources
  # NOTE: a UOp may be queued several times (once per use), that's what gets counted
  queue = list(flatten(y.src[::-1] for y in lst))
  while queue:
    u = queue.pop(0)
    all_uses_seen = u.op not in DONT_PLACE_IN_BLOCK and ctx.child_count[u] == visits[u]+1
    if not all_uses_seen:
      # some uses are still unseen (or it's DONT_PLACE_IN_BLOCK): just record this visit
      visits[u] += 1
      continue
    u_ctx = ctx.block_ctxs[u]
    if u_ctx == current_ctx:
      # same context: it joins this block and its own sources get queued
      lst.append(u)
      queue.extend(u.src[::-1])
    else:
      # different context: it will seed another block
      seeds_by_ctx[(u_ctx, ctx.child_ctxs.get(u))].append(u)
    # placed, so it must not show up as a leftover source
    del visits[u]

  # leftovers become sources, once per visit
  srcs = []
  for u, cnt in visits.items():
    srcs.extend([add_blockends(u, ctx.block_ctxs[u], current_ctx, cnt=cnt)]*cnt)

  # one BLOCKSTART per distinct context pair, with blockends as needed
  for (new_ctx, new_child_ctx), seeds in seeds_by_ctx.items():
    start = UOp(Ops.BLOCKSTART, src=tuple(seeds), arg=(new_ctx, new_child_ctx))
    srcs.append(add_blockends(start, new_ctx, current_ctx))

  # lst was collected users-first; flip it before reordering
  ordered = block_reorder(lst[::-1])
  bb = BasicBlock2(tuple(ordered), ctx=current_ctx, cnt=child_count, child_ctx=child_ctx)
  return UOp(Ops.BLOCK, src=tuple(srcs), arg=bb)
|
|
|
|
# run make_block_bottom_up on every op except DONT_PLACE_IN_BLOCK ones and existing BLOCK/BLOCKEND
block_create = PatternMatcher([
  (UPat(GroupOp.All - (DONT_PLACE_IN_BLOCK | {Ops.BLOCK, Ops.BLOCKEND}), name="x"), make_block_bottom_up),
])
|
|
|
|
# ***** blockend merging ****
|
|
|
|
def merge_blockends(sink:UOp) -> UOp|None:
  """Fuse all BLOCKENDs that close the same UOp into a single BLOCKEND.

  Only acts on the BLOCK whose last UOp is the SINK. Returns the substituted graph, or None when
  `sink` is any other BLOCK or no two BLOCKENDs share an `end`.
  """
  if sink.arg.lst[-1].op is not Ops.SINK: return None

  # group the BLOCKENDs by the UOp they close (BasicBlock2.end)
  by_end: dict[UOp, list[UOp]] = {}
  for node in sink.toposort():
    if node.op is not Ops.BLOCKEND: continue
    by_end.setdefault(node.arg.end, []).append(node)

  replacements = {}
  for end_uop, group in by_end.items():
    # NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
    if len(group) <= 1: continue
    # the merged block keeps the first lst, unions the contexts and sums the counts
    merged_ctx = _sort_ctx(flatten([y.arg.ctx for y in group]))
    merged_arg = BasicBlock2(group[0].arg.lst, merged_ctx, end_uop, cnt=sum(y.arg.cnt for y in group))
    merged = UOp(Ops.BLOCKEND, src=tuple(flatten([y.src for y in group])), arg=merged_arg)
    # NOTE: merged_arg.ctx != old.arg.ctx can cause problems here
    for old in group: replacements[old] = merged

  if not replacements: return None
  return sink.substitute(replacements)
|
|
|
|
# offer every BLOCK to merge_blockends (it returns None unless the block ends in the SINK)
pm_blockend_merge = PatternMatcher([
  (UPat(Ops.BLOCK, name="sink"), merge_blockends),
])
|
|
|
|
# ***** block merging ****
|
|
|
|
def merge_block(x:UOp):
  """Fold source BLOCKs into `x` (a BLOCK or BLOCKEND) when all of their uses are sources of `x`.

  A source BLOCK is a candidate if it has the same ctx as a BLOCK `x`, or has the `end` of a
  BLOCKEND `x` in its ctx. It is merged only if it appears in `x.src` exactly `arg.cnt` times.
  Returns the rebuilt UOp, or None when nothing can be merged.
  """
  def _is_candidate(y:UOp) -> bool:
    if y.op is not Ops.BLOCK: return False
    if x.op is Ops.BLOCK: return x.arg.ctx == y.arg.ctx
    if x.op is Ops.BLOCKEND: return x.arg.end in y.arg.ctx
    return False

  kept_srcs, merged_in = [], []
  seen: defaultdict[UOp, int] = defaultdict(int)
  for y in x.src:
    if _is_candidate(y): seen[y] += 1
    else: kept_srcs.append(y)

  # a candidate with uses elsewhere stays a source, once per occurrence
  for blk, times in seen.items():
    if times == blk.arg.cnt: merged_in.append(blk)
    else: kept_srcs.extend([blk]*times)
  if not merged_in: return None

  # merged contents go in front of x's own, merged sources in front of the kept ones
  new_lst = tuple(flatten([y.arg.lst for y in merged_in])) + x.arg.lst
  new_src = tuple(flatten([y.src for y in merged_in]) + kept_srcs)
  return UOp(x.op, src=new_src, arg=replace(x.arg, lst=new_lst))
|
|
|
|
def remove_blockend(x:UOp):
  """Turn a finished BLOCKEND into a plain BLOCK by splicing in the block that opened its scope.

  Returns None while any source BLOCK/BLOCKEND still has `x.arg.end` in its ctx, and also when no
  source BLOCK has `x.arg.end` in its child_ctx.
  """
  end = x.arg.end

  # if there's any remaining blocks that need to go in this BLOCKEND, we don't remove it
  for y in x.src:
    if y.op in {Ops.BLOCK, Ops.BLOCKEND} and end in y.arg.ctx: return None

  # the parent is the source BLOCK whose child_ctx introduces `end`
  parent_blocks = [y for y in x.src if y.op is Ops.BLOCK and y.arg.child_ctx is not None and end in y.arg.child_ctx]
  assert all_same(parent_blocks), f"should never have two parent blocks (has {len(parent_blocks)})"
  if not parent_blocks: return None

  parent_block = parent_blocks[0]
  assert len(parent_blocks) == parent_block.arg.cnt

  # range needs DEFINE_ACC to be before the range (never in DEFINE_ACC for if)
  early_ops, late_ops = partition(x.arg.lst, lambda y: y.op is Ops.DEFINE_ACC and end in y.src)
  # NOTE: we have to add a barrier at the start if barrier is used in the range
  if x.op is Ops.BLOCKEND and any(y.op is Ops.BARRIER for y in late_ops) and late_ops[-1].op is Ops.ENDRANGE:
    late_ops = [UOp(Ops.BARRIER)] + late_ops

  # early ops, then the parent's contents, then the rest; `end` leaves the context
  new_lst = tuple(early_ops) + parent_block.arg.lst + tuple(late_ops)
  new_ctx = tuple(y for y in x.arg.ctx if y is not end)
  new_src = tuple(y for y in x.src if y is not parent_block) + parent_block.src
  return UOp(Ops.BLOCK, src=new_src, arg=BasicBlock2(new_lst, new_ctx, cnt=x.arg.cnt))
|
|
|
|
# merge_block is offered both BLOCKs and BLOCKENDs, remove_blockend only BLOCKENDs
block_merge = PatternMatcher([
  (UPat((Ops.BLOCK, Ops.BLOCKEND), name="x"), merge_block),
  (UPat(Ops.BLOCKEND, name="x"), remove_blockend),
])
|
|
|
|
# ****** finalize ******
|
|
|
|
def finalize(sink:UOp) -> UOp:
  """Flatten the last remaining BLOCK into a BLOCKFINAL holding the complete UOp list.

  Raises RuntimeError("linearize failure") unless `sink` is a BLOCK whose sources are all
  DONT_PLACE_IN_BLOCK ops.
  """
  is_block = sink.op is Ops.BLOCK
  if not is_block or any(x.op not in DONT_PLACE_IN_BLOCK for x in sink.src):
    raise RuntimeError("linearize failure")

  # the sources go first (deduplicated, ordered by tuplize), then the block contents
  early = sorted(dedup(sink.src), key=lambda x: x.tuplize)
  lst = early + list(sink.arg.lst)

  if __debug__: type_verify(lst)

  return UOp(Ops.BLOCKFINAL, arg=BasicBlock2(tuple(lst)))
|
|
|
|
# apply finalize to BLOCK UOps
pm_finalize = PatternMatcher([
  (UPat(Ops.BLOCK, name="sink"), finalize),
])
|
|
|