from __future__ import annotations
import collections, heapq
from dataclasses import dataclass
from tinygrad.ops import UOp, Ops, PatternMatcher, UPat, graph_rewrite, GroupOp
from tinygrad.spec import type_verify
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.helpers import dedup, flatten, partition

# linearizer: converts a UOp graph (rooted at a SINK) into a flat, ordered list of UOps.
# it works by growing BLOCK UOps upward from the SINK, closing loops/ifs with BLOCKEND,
# and duplicating shared parents with BLOCKFORK, until a single BLOCK remains.

# ops that never get placed inside a BLOCK: top-level definitions, constants, and the block ops themselves
DONT_PLACE_IN_BLOCK = {Ops.NAME, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST, *GroupOp.Block}

def disp(y:UOp) -> str:
  """Short display string for a block-context UOp. Only used for BasicBlock.__repr__ debugging output."""
  if y.op is Ops.BLOCKSTART: return "w"+disp(y.src[0])  # "w" prefix marks a BLOCKSTART wrapper around the real RANGE/IF
  if y.op is Ops.IF: return f'IF{id(y)}'
  if y.op is Ops.RANGE: return str(y.arg)
  return ""

@dataclass(frozen=True)
class BasicBlock:
  """The arg payload of BLOCK/BLOCKEND/BLOCKFORK UOps: a run of ops plus the loop/if context it lives in."""
  # ctx: the RANGE/IF (and BLOCKSTART) UOps this block is nested inside, kept sorted by tuplize
  ctx: tuple[UOp, ...]
  # lst: the ops contained in this block, already in execution order
  lst: tuple[UOp, ...]
  # end: set only on BLOCKEND blocks, the RANGE/IF UOp that this block closes
  end: UOp|None = None
  def __lt__(self, o:BasicBlock):
    # deterministic ordering via the UOps' tuplize keys (needed since UOps aren't directly comparable)
    return tuple(x.tuplize for x in self.ctx+self.lst) < tuple(x.tuplize for x in o.ctx+o.lst)
  def __repr__(self):
    return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+\
           f"{[disp(y) for y in self.ctx]} {len(self.lst)}" + "\n" + '\n'.join([str(x.op) for x in self.lst])

def append_to_block(ctx:tuple[dict[UOp, tuple[UOp, ...]], dict[UOp, list[UOp]]], x:UOp):
  """Rewrite rule: grow BLOCK `x` upward by absorbing parents whose children are all already inside it.

  ctx is a pair of (block_ctxs: UOp -> its block context tuple, children: UOp -> list of consumers),
  both precomputed in linearize_uop. Parents in a *different* context are wrapped in a fresh BLOCK
  (plus BLOCKENDs for any context levels being closed). Returns None when nothing can be absorbed,
  which terminates this rewrite.
  """
  block_ctxs, children = ctx
  in_this_block = set(x.arg.lst)

  # collections to build
  new_srcs: list[UOp] = []                          # srcs of the rewritten block
  to_append: list[UOp] = []                         # ops being absorbed into this block
  old_blocks: dict[tuple[UOp, ...], UOp] = {}       # existing sibling BLOCK parents, keyed by their ctx
  new_blocks: dict[tuple[UOp, ...], list[UOp]] = {} # ops to place in new BLOCKs, grouped by ctx

  seen_u = set()  # x.src may contain duplicates; process each unique parent once
  for u in x.src:
    if u.op is Ops.BLOCK:
      if u not in seen_u:
        # merge sibling blocks. NOTE: blocks must only have one output source
        assert u.arg.ctx not in old_blocks, "sibling should never have been created"
        old_blocks[u.arg.ctx] = u
    elif u.op not in DONT_PLACE_IN_BLOCK and set(children[u]).issubset(in_this_block):
      if u not in seen_u:
        # if it can go in blocks and all its children are in the block, we add it to the block
        if (block_ctx:=block_ctxs[u]) == x.arg.ctx:
          # if it's the same context, we place the UOp in this block and append the parents to its srcs
          new_srcs.extend(u.src)
          to_append.append(u)
        else:
          # if it's a different context, we create a new block with this UOp
          new_blocks.setdefault(block_ctx, []).append(u)
    else:
      # otherwise, we keep it in the srcs
      new_srcs.append(u)
    seen_u.add(u)
  if len(to_append) == 0 and len(new_blocks) == 0: return None  # fixed point: nothing absorbed

  for rng,lst in new_blocks.items():
    srcs = flatten(y.src for y in lst)
    if (old_block:=old_blocks.pop(rng, None)) is not None:
      # NOTE: order shouldn't matter here
      srcs.extend(old_block.src)
      lst.extend(old_block.arg.lst)
    new_block = UOp(Ops.BLOCK, dtypes.void, tuple(dedup(srcs)), BasicBlock(rng, tuple(lst)))
    lrng = list(rng)
    # close (innermost-first) every context level this new block has that x doesn't, emitting BLOCKENDs
    for r in rng[::-1]:
      if r not in x.arg.ctx and r.op is not Ops.BLOCKSTART:
        lrng.remove(r)
        new_block = UOp(Ops.BLOCKEND, src=(new_block,),
                        arg=BasicBlock(tuple(lrng), (UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,)),), r))
    new_srcs.append(new_block)
  return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(list(old_blocks.values())+new_srcs)), BasicBlock(x.arg.ctx, tuple(to_append)+x.arg.lst))

# seed: wrap the SINK in a BLOCK (plus a NAME op if the sink carries a kernel name), then grow blocks upward
make_basic_blocks = PatternMatcher([
  (UPat(Ops.SINK, name="x"),
   lambda x: UOp(Ops.BLOCK, src=x.src+((UOp(Ops.NAME, arg=x.arg.name),) if x.arg is not None else ()), arg=BasicBlock((), (x,)))),
  (UPat(Ops.BLOCK, name="x"), append_to_block),
])

def block_merge(ctx, x:UOp):
  """Rewrite rule: merge BLOCK/BLOCKEND `x` with its parent blocks.

  ctx is children here. A BLOCKEND that already contains all children of its range and has found
  the parent block holding the matching BLOCKSTART collapses into a plain BLOCK. Otherwise, parents
  that are same-context BLOCKs (or fully-consumed BLOCKFORKs) are inlined. Returns None at fixed point.
  """
  if x.op is Ops.BLOCKEND:
    # if it's a BLOCKEND, see if we are done with placement. if all the children of the range are in here
    in_this_block = set(x.arg.lst)
    if len([y for y in ctx[x.arg.end] if y not in in_this_block]) == 0:
      # find the parent block that has the BLOCKSTART in the ctx
      parent_blocks = [y for y in x.src if y.op is Ops.BLOCK and UOp(Ops.BLOCKSTART, src=(x.arg.end,)) in y.arg.ctx]
      assert len(parent_blocks) <= 1, "should never have two parent blocks"
      if len(parent_blocks) == 1:
        parent_block = parent_blocks[0]
        # range needs DEFINE_ACC to be before the range (never in DEFINE_ACC for if)
        early_ops, late_ops = partition(x.arg.lst, lambda y: y.op is Ops.DEFINE_ACC and x.arg.end in y.src)
        # NOTE: we have to add a barrier at the start if barrier is used in the range
        # (the `x.op is Ops.BLOCKEND` re-check is always true on this path)
        if x.op is Ops.BLOCKEND and any(y.op is Ops.BARRIER for y in late_ops) and late_ops[-1].op is Ops.ENDRANGE:
          late_ops = [UOp(Ops.BARRIER)] + late_ops
        return UOp(Ops.BLOCK, dtypes.void, tuple(y for y in x.src if y is not parent_block)+parent_block.src,
                   BasicBlock(tuple(y for y in x.arg.ctx if y is not x.arg.end), tuple(early_ops)+parent_block.arg.lst+tuple(late_ops)))

  new_srcs: list[UOp] = []
  to_append: list[UOp] = []
  new_ctx = x.arg.ctx
  placed = set()  # BLOCKFORKs already inlined (they appear multiple times in x.src)
  for u in x.src:
    if u.op is Ops.BLOCK and (tuple(u.arg.ctx) == tuple(x.arg.ctx) or (x.arg.end is not None and x.arg.end in u.arg.ctx)):
      # NOTE: this can't appear in srcs twice or it would be a BLOCKFORK
      new_ctx += tuple(y for y in u.arg.ctx if y not in x.arg.ctx)
      new_srcs.extend(u.src)
      to_append.extend(u.arg.lst)
    elif u.op is Ops.BLOCKFORK and x.src.count(u) == u.arg:
      # block fork appears # of times in srcs
      if u not in placed:
        new_srcs.extend(u.src)
        placed.add(u)
    else:
      # keep it in srcs
      new_srcs.append(u)
  if len(to_append) == 0 and len(placed) == 0: return None  # fixed point: nothing merged
  return UOp(x.op, dtypes.void, tuple(new_srcs),
             BasicBlock(tuple(dedup(sorted(new_ctx, key=lambda x: x.tuplize))), tuple(to_append)+x.arg.lst, x.arg.end))

pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), block_merge),])

def block_finalize(block:UOp):
  """Rewrite rule: fold the final BLOCK's remaining src-less parents (defines/consts) into its list.

  Fires once on the single surviving BLOCK; returns a src-free BLOCK whose lst is the full program
  (with the trailing SINK stripped). Returns None if the block already has no srcs.
  """
  if len(block.src) == 0: return None
  _uops = sorted(dedup(block.src), key=lambda x: x.tuplize)
  # at this point only leaf ops (no srcs, not block ops) may remain outside the block
  assert all(len(x.src) == 0 and x.op not in {Ops.BLOCK, Ops.BLOCKSTART, Ops.BLOCKEND, Ops.BLOCKFORK} for x in _uops)
  _uops += block.arg.lst
  # strip the SINK
  assert _uops[-1].op is Ops.SINK, "doesn't end with SINK"
  return UOp(Ops.BLOCK, arg=BasicBlock((), tuple(_uops[:-1])))

pm_block_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="block"), block_finalize)])

# NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
def block_reorder(in_block:UOp):
  """Reorder the ops inside a BLOCK via a priority-driven topological sort (Kahn-style, heap-backed).

  Dependencies are only those internal to the block; negative priorities float LOADs (and BARRIERs,
  even earlier) toward the front, and a parent's priority is at most the min of its children's to
  prevent priority inversion. Returns a BLOCK with the same ops in the new order.
  """
  in_this_block = set(in_block.arg.lst)
  local_children: collections.defaultdict[UOp, list[UOp]] = collections.defaultdict(list)
  in_degree: collections.defaultdict[UOp, int] = collections.defaultdict(int)
  priorities:dict[UOp, int] = {}

  # get local children and assign priorities
  # (reversed order guarantees every child's priority exists before its parent's is computed)
  for u in reversed(in_block.arg.lst):
    for s in u.src:
      if s in in_this_block:
        local_children[s].append(u)
        in_degree[u] += 1
    # put loads in the beginning of the block and prevent priority inversion. hack for BARRIER grouping too
    priority = [0] + [priorities[x] for x in local_children[u]]
    if u.op is Ops.LOAD: priority.append(-1000)
    if u.op is Ops.BARRIER: priority.append(-1500)
    priorities[u] = min(priority)

  # placement queue, ordered by (priority, tuplize) for determinism
  queue:list[tuple[int, tuple, UOp]] = []
  def push(u:UOp): heapq.heappush(queue, (priorities[u], u.tuplize, u))

  # place the first ones that don't have deps
  for u in in_block.arg.lst:
    if u not in in_degree: push(u)

  newlst = []
  while queue:
    _,_,x = heapq.heappop(queue)
    newlst.append(x)
    for u in local_children[x]:
      in_degree[u] -= 1
      if in_degree[u] == 0: push(u)  # all block-internal deps placed; ready to schedule

  assert len(newlst) == len(in_block.arg.lst), f"len mismatch {len(newlst)} != {len(in_block.arg.lst)}"
  return in_block.replace(arg=BasicBlock(in_block.arg.ctx, tuple(newlst)))

def upsettingly_promote_blockend(be:UOp):
  """Rewrite rule: force the ctx of a stuck BLOCKEND's BLOCK parents up to the BLOCKEND's ctx,
  so a subsequent pm_block_merge pass can merge them. Returns None if nothing changed."""
  new_srcs = tuple(b.replace(arg=BasicBlock(be.arg.ctx, b.arg.lst)) if b.op is Ops.BLOCK else b for b in be.src)
  return be.replace(src=new_srcs) if be.src != new_srcs else None
pm_force_upcast_block = PatternMatcher([(UPat(Ops.BLOCKEND, name="be"), upsettingly_promote_blockend)])

def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
  """Linearize a UOp graph into an ordered list of UOps.

  sink: the SINK UOp rooting the graph.
  skip_check: skip the type_verify sanity check (defaults to skipping under python -O).
  Returns the flat op list. Raises AssertionError if sink is not a SINK or verification fails.
  """
  assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"

  # get children and all block contexts
  # a UOp's block context is the set of RANGE/IF ops it must execute inside, propagated from its srcs
  temp_block_ctxs: dict[UOp, list[UOp]] = {}
  children: dict[UOp, list[UOp]] = {}
  for u in sink.toposort:
    this_block_ctx: list[UOp] = []
    for s in u.src:
      # save children
      children.setdefault(s, []).append(u)
      # compute block ctx
      if s.op in {Ops.RANGE, Ops.IF}: this_block_ctx.append(s)
      # don't flow (fully) through assign and store
      elif s.op is Ops.STORE:
        # ugh, deal with non-reduce locals. probably wrong
        if isinstance(s.src[0].dtype, PtrDType) and s.src[0].dtype.local:
          idx_context, store_context = temp_block_ctxs[s.src[0]], temp_block_ctxs[s]
          this_block_ctx += [x for x in store_context if x not in idx_context and x.op is Ops.RANGE]
      elif s.op is Ops.ASSIGN:
        # flow though assign, but remove the ranges used in the assign
        assert s.src[0].op is Ops.DEFINE_ACC
        this_block_ctx += [x for x in temp_block_ctxs[s.src[1]] if x not in s.src[0].src[1:]]
      else:
        # flow though everything else
        this_block_ctx += temp_block_ctxs[s]
    temp_block_ctxs[u] = sorted(dedup(this_block_ctx), key=lambda x: x.tuplize)

  # make final block_ctxs, add BLOCKSTART to block_ctxs for IF and RANGE
  block_ctxs: dict[UOp, tuple[UOp, ...]] = {}
  for u in sink.toposort:
    block_ctxs[u] = ((UOp(Ops.BLOCKSTART, src=(u,)),) + tuple(temp_block_ctxs[u])) if u.op in {Ops.IF, Ops.RANGE} else tuple(temp_block_ctxs[u])

  # TODO: there's probably a clever way to remove this while loop
  while 1:
    # grow blocks as far as possible, then break any remaining multi-consumer parents with BLOCKFORK and retry
    sink = graph_rewrite(sink, make_basic_blocks, ctx=(block_ctxs, children))

    # add BLOCKFORK (slow!)
    block_parent_count = collections.Counter(flatten([x.src for x in sink.toposort if x.op is Ops.BLOCK]))
    non_block_parents = set(flatten([x.src for x in sink.toposort if x.op is not Ops.BLOCK]))
    # a fork wraps the shared op in its own BLOCK; arg=child_count records how many consumers must claim it
    forks = {u:UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCK, src=u.src, arg=BasicBlock(block_ctxs[u], (u,))),), arg=child_count)
      for u,child_count in block_parent_count.items() if u.op not in DONT_PLACE_IN_BLOCK and child_count > 1 and u not in non_block_parents}

    if not len(forks): break
    sink = sink.substitute(forks)

  # combine matching BLOCKENDS, the keys of this dictionary are the RANGE UOps, values are the BLOCKENDs
  blockends_to_arg: dict[UOp, list[UOp]] = {}
  for be in sink.toposort:
    if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be)
  new_forks = {}
  for k,v in blockends_to_arg.items():
    # NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
    if len(v) > 1:
      out = UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCKEND, src=tuple(flatten(x.src for x in v)),
                                        arg=BasicBlock(tuple(dedup(flatten([y.arg.ctx for y in v]))), v[0].arg.lst, k)),), arg=len(v))
      for u in v: new_forks[u] = out
  sink = sink.substitute(new_forks)

  # reorder ops in block for speed
  sink = sink.substitute({u:newu for u in sink.toposort if u.op is Ops.BLOCK and (newu:=block_reorder(u)) is not u})

  # final rewrite to merge all blocks into one
  sink = graph_rewrite(sink, pm_block_merge, ctx=children)

  # if there's BLOCKENDs left in the graph, we might have to merge. TODO: is there a better way to handle this?
  while (newsink := graph_rewrite(sink, pm_force_upcast_block)) is not sink:
    sink = graph_rewrite(newsink, pm_block_merge, ctx=children, name="bad_merge")

  # there should just be one block left, with a few parents with 0 srcs (now done in a rewriter)
  sink = graph_rewrite(sink, pm_block_finalize)

  # sanity checks (NOTE: these can cause things to be skipped in BEAM)
  if not skip_check: type_verify(sink.arg.lst)

  # return the list. TODO: refactor to return the UOp
  return list(sink.arg.lst)