from typing import cast from dataclasses import dataclass, field from collections import deque, defaultdict from tinygrad.ops import UOp, Variable, Ops, UPat, PatternMatcher, graph_rewrite, buffers from tinygrad.device import Buffer, MultiBuffer from tinygrad.helpers import Metadata, unwrap, merge_dicts # **** ScheduleItem return type @dataclass(frozen=True) class ScheduleItem: ast: UOp bufs: tuple[Buffer, ...] metadata: tuple[Metadata, ...] = () fixedvars: dict[Variable, int] = field(default_factory=dict) # **** unbind Variables def unbind_view(ctx:list[dict[Variable, int]], x:UOp): st = unwrap(x.st).simplify() if any(x.op is Ops.BIND for x in st.vars()): st, var_vals = st.unbind() ctx.append(var_vals) return x.replace(arg=st) if st != x.st else None def unbind_bind(ctx:list[dict[Variable, int]], x:UOp): var, val = x.unbind() ctx.append({var.replace(src=()):val}) return var pm_unbind = PatternMatcher([ (UPat(Ops.VIEW, name="x"), unbind_view), (UPat(Ops.BIND, name="x"), unbind_bind), ]) # **** schedule linearizer def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]: # construct the KERNEL children graph based on assigns children: defaultdict[UOp, list[UOp]] = defaultdict(list) in_degree: dict[UOp, int] = {} for u in (toposort:=sched_sink.toposort()): if u.op is not Ops.ASSIGN: continue k = u.src[1] in_degree.setdefault(k, 0) for s in k.src: if s.op is not Ops.ASSIGN: continue children[s.src[1]].append(k) in_degree[k] += 1 # linearize KERNEL UOps into ScheduleItems in BFS order queue = deque(k for k,v in in_degree.items() if v == 0) schedule: list[ScheduleItem] = [] var_vals: dict[Variable, int] = {} while queue: k = queue.popleft() # unbind var_vals from the kernel local_var_vals: list[dict[Variable, int]] = [] ast = graph_rewrite(k.arg.ast, pm_unbind, ctx=local_var_vals, name="unbind vars") var_vals = merge_dicts([var_vals, *local_var_vals]) # create subbuffers if needed if ast.op is Ops.BUFFER_VIEW: base = k.src[1].buf_uop.buffer assert isinstance(base, Buffer), "base can't be MultiBuffer" buffers[k.src[0]] = base.view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize) ubufs = tuple(s.buf_uop.buffer for s in k.src) if any(isinstance(x, MultiBuffer) for x in ubufs): if ast.op is Ops.COPY: if isinstance(ubufs[1], MultiBuffer) and ast.arg is None: # src is multiple buffers, none selected if isinstance(ubufs[0], MultiBuffer): # COPY ALL -> ALL for b1,b2 in zip(ubufs[0].bufs, ubufs[1].bufs): schedule.append(ScheduleItem(ast, (b1, b2), k.arg.metadata)) else: # COPY ANY -> ONE. Currently we just select the first schedule.append(ScheduleItem(ast, (ubufs[0], ubufs[1].bufs[0]), k.arg.metadata)) else: src_buf = ubufs[1].bufs[ast.arg] if isinstance(ubufs[1], MultiBuffer) else ubufs[1] if isinstance(ubufs[0], MultiBuffer): # COPY ONE -> ALL (BROADCAST) for b in ubufs[0].bufs: schedule.append(ScheduleItem(ast, (b, src_buf), k.arg.metadata)) else: schedule.append(ScheduleItem(ast, (ubufs[0], src_buf), k.arg.metadata)) # COPY ONE -> ONE else: assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer" dnums = [x for x in ast.variables() if x.arg[0] == '_device_num'] for i,bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])): schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0]:i} if len(dnums) else {})) else: schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata)) for x in children[k]: in_degree[x] -= 1 if in_degree[x] == 0: queue.append(x) # map ASSIGN to BUFFER after ScheduleItems are constructed becomes_map = {u:u.buf_uop for u in toposort if u.op is Ops.ASSIGN} assert all(u.op in {Ops.BUFFER, Ops.BUFFER_VIEW} for u in becomes_map.values()), f"Schedule didn't end with BUFFER {becomes_map.values()}" return schedule, var_vals, becomes_map