from typing import cast
from dataclasses import dataclass, field
from collections import deque, defaultdict
from tinygrad.uop.ops import UOp, Variable, Ops, buffers
from tinygrad.device import Device, Buffer, MultiBuffer
from tinygrad.helpers import Metadata, all_same

# **** ScheduleItem return type

@dataclass(frozen=True)
class ScheduleItem:
  """One entry of the linearized schedule: a kernel AST together with the buffers it runs on."""
  # the AST taken from the KERNEL UOp's arg
  ast: UOp
  # one Buffer per non-BIND source of the kernel, in source order
  bufs: tuple[Buffer, ...]
  # metadata copied from the KERNEL UOp's arg
  metadata: tuple[Metadata, ...] = ()
  # variables pinned for this item only (used below to bind '_device_num' per device)
  fixedvars: dict[Variable, int] = field(default_factory=dict)

# **** schedule linearizer

def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int]]:
  """Linearize the kernels reachable from `sched_sink` into an ordered list of ScheduleItems.

  Every ASSIGN found in `sched_sink.toposort()` contributes the kernel in its `src[1]`. A kernel
  is emitted only after every kernel it consumes from (through ASSIGN sources, directly or inside
  MSELECT/MSTACK) has been emitted. BIND sources are unbound and collected into the returned
  variable mapping.

  Returns:
    (schedule, var_vals): the ordered ScheduleItems and the Variable -> int values gathered from
    BIND sources.

  Raises:
    RuntimeError: if a kernel source is not an ASSIGN, MSELECT, MSTACK, BUFFER or BIND.
    AssertionError: if one Variable is bound to two different values, if an MSELECT/MSTACK element
      is neither a BUFFER nor an ASSIGN, if a BUFFER_VIEW base is not a plain Buffer, or if a
      kernel mixes MultiBuffer and non-MultiBuffer operands.
  """
  # construct the KERNEL children graph based on assigns
  children: defaultdict[UOp, list[UOp]] = defaultdict(list)
  in_degree: dict[UOp, int] = {}
  var_vals: dict[Variable, int] = {}

  def _add_edge(producer: UOp, consumer: UOp) -> None:
    # record that `consumer` has to wait for `producer`
    children[producer].append(consumer)
    in_degree[consumer] += 1

  for node in sched_sink.toposort():
    # anything that's not an ASSIGN doesn't write a kernel, so we can skip
    if node.op is not Ops.ASSIGN: continue
    kernel = node.src[1]
    in_degree.setdefault(kernel, 0)
    for src in kernel.src:
      # a BUFFER is already realized, nothing to do here
      if src.op is Ops.BUFFER: continue
      if src.op is Ops.ASSIGN:
        _add_edge(src.src[1], kernel)
      elif src.op in (Ops.MSELECT, Ops.MSTACK):
        for part in src.src:
          # look through one level of MSELECT to reach the underlying BUFFER/ASSIGN
          target = part.src[0] if part.op is Ops.MSELECT else part
          if target.op is Ops.BUFFER: continue
          assert target.op is Ops.ASSIGN
          _add_edge(target.src[1], kernel)
      elif src.op is Ops.BIND:
        var, val = src.unbind()
        assert var not in var_vals or var_vals[var] == val, f"bind mismatch on {var}, {var_vals[var]} != {val}"
        var_vals[var] = val
      else:
        raise RuntimeError(f"input to kernel must be ASSIGN or BUFFER, not {src.op}")

  # linearize KERNEL UOps into ScheduleItems in BFS order

  def _heuristic(k: UOp) -> int:
    # bucket key: 1000 for a COPY whose sources sit on devices with differing group_id, else 0
    if k.arg.ast.op is not Ops.COPY: return 0
    group_ids = [Device[cast(Buffer, s.buf_uop.buffer).device].group_id for s in k.src]
    return 0 if all_same(group_ids) else 1000

  last_heuristic: int = 0
  queues: defaultdict[int, deque[UOp]] = defaultdict(deque)
  last_queue: deque[UOp] = deque()

  # seed the buckets with every kernel that has no pending producer
  for kernel, pending in in_degree.items():
    if pending == 0: queues[_heuristic(kernel)].append(kernel)

  schedule: list[ScheduleItem] = []
  while last_queue or any(queues.values()):
    if not last_queue:
      # current bucket is drained: switch to the non-empty bucket whose key is closest to the last one
      # (min keeps the first candidate on ties, i.e. bucket insertion order)
      candidates = [(key, queue) for key, queue in queues.items() if queue]
      last_heuristic, last_queue = min(candidates, key=lambda entry: abs(entry[0]-last_heuristic))
    kernel = last_queue.popleft()
    kernel_arg = kernel.arg
    ast = kernel_arg.ast

    # create subbuffers if needed
    if ast.op is Ops.BUFFER_VIEW:
      base = kernel.src[1].buf_uop.buffer
      assert isinstance(base, Buffer), "base can't be MultiBuffer"
      # ast.arg[1] is scaled by the base itemsize before being handed to view()
      buffers[kernel.src[0]] = base.view(kernel.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)

    # BIND sources carry variables, not buffers
    ubufs = tuple(s.buf_uop.buffer for s in kernel.src if s.op is not Ops.BIND)
    if not any(isinstance(b, MultiBuffer) for b in ubufs):
      # ONE -> ONE
      schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), kernel_arg.metadata))
    else:
      assert all(isinstance(b, MultiBuffer) for b in ubufs), "kernel must all be multibuffer"
      dnums = [v for v in ast.variables() if v.arg[0] == '_device_num']
      multi = cast(tuple[MultiBuffer, ...], ubufs)
      # one ScheduleItem per device position, pinning '_device_num' to that position when the AST uses it
      for dev_idx, dev_bufs in enumerate(zip(*[mb.bufs for mb in multi])):
        fixedvars = {dnums[0]: dev_idx} if dnums else {}
        schedule.append(ScheduleItem(ast, dev_bufs, kernel_arg.metadata, fixedvars))

    # release consumers whose last pending producer was just scheduled
    for child in children[kernel]:
      remaining = in_degree[child] - 1
      in_degree[child] = remaining
      if remaining == 0: queues[_heuristic(child)].append(child)

  return schedule, var_vals