import collections, time
from typing import Any, cast, Optional
from tinygrad.helpers import round_up, PROFILE
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator
from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Variable
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
from tinygrad.engine.jit import MultiGraphRunner

class HCQGraph(MultiGraphRunner):
  def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
    super().__init__(jit_cache, input_rawbuffers, var_vals)
    self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))

    # Replace input buffers with variables.
    self.hcq_bufs = [[cast(Buffer, x)._buf for x in ji.bufs] for ji in jit_cache]
    self.input_replace_to_var: dict[tuple[int, int], Variable] = {}
    for (j,i), input_idx in self.input_replace.items():
      x = self.input_replace_to_var.setdefault((j,i), UOp.variable(f"input_{input_idx}", 0, 0xffffffffffffffff, dtype=dtypes.uint64))
      self.hcq_bufs[j][i] = HCQBuffer(x, self.hcq_bufs[j][i].size, texture_info=self.hcq_bufs[j][i].texture_info) # Create fake buffer with variable

    # Allocate kernel args.
    kernargs_size: dict[Compiled, int] = collections.defaultdict(int)
    for ji in jit_cache:
      if not isinstance(ji.prg, CompiledRunner): continue
      kernargs_size[ji.prg.dev] += round_up(ji.prg._prg.kernargs_alloc_size, 16)
    self.kernargs_bufs: dict[Compiled, HCQBuffer] = {dev:dev.allocator._alloc(sz, BufferSpec(cpu_access=True)) for dev,sz in kernargs_size.items()}

    # Fill initial arguments.
    self.ji_args: dict[int, HCQArgsState] = {}
    kargs_alloc: dict[Compiled, BumpAllocator] = {dev:BumpAllocator(buf.size, start=cast(int, buf.va_addr)) for dev,buf in self.kernargs_bufs.items()}
    for j,ji in enumerate(jit_cache):
      if not isinstance(ji.prg, CompiledRunner): continue
      self.ji_args[j] = ji.prg._prg.fill_kernargs(self.hcq_bufs[j], ji.prg.p.vars, kargs_alloc[ji.prg.dev].alloc(ji.prg._prg.kernargs_alloc_size, 16))

    # Schedule Dependencies.
    # There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any
    # graph-related tasks. This synchronization uses a global timeline signal per device. Within the graph, the compute queue coordinates with
    # global operations and sets a kickoff signal. Any queue accessing a buffer from another device waits for this signal from the device's
    # compute queue to ensure exclusive access. The compute queue signals the completion of the graph, synchronizing with the device's copy queue.
    self.ji_schedule: dict[int, tuple[HCQCompiled, HWQueue, list, list, HCQSignal, Optional[int]]] = {}

    self.comp_queues: dict[HCQCompiled, HWQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
    self.copy_queues: dict[HCQCompiled, HWQueue] = {} # lazy allocation

    self.signals: dict[Any, HCQSignal] = {**{dev: dev.signal_t(value=0) for dev in self.devices}, **{"CPU": self.devices[0].signal_t(value=0)}}
    self.kickoff_value: int = 0
    self.kickoff_var = UOp.variable("kickoff_var", 0, 0xffffffff, dtype=dtypes.uint32)
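
    # Illustrative sketch of the sync protocol described above (comments only; device names D0/D1 are hypothetical).
    # For two devices where a D1 queue reads a buffer resident on D0, the encoded streams look roughly like:
    #   D0 compute: memory_barrier -> wait(timeline_D0) -> wait(CPU, kickoff) -> signal(signals[D0], kickoff) -> ...
    #   D1 queue:   wait(signals[D0], kickoff) -> <access D0 buffer> -> signal(out_signal, j+1) -> ...
    # i.e. no queue touches another device's memory before that device's compute queue has published the kickoff signal.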

    # When profiling, allocate 2 signals for each jit item to measure speed. The jth jit item has signals at 2*j and 2*j+1.
    # TODO: This logic might allocate a few extra signals...
    self.prof_signals: list[HCQSignal] = [self.devices[0].signal_t() for _ in range(len(jit_cache) * 2)] if PROFILE else []
    self.prog_graph_deps: list[list[int]] = []
    self.prof_graph_entries: list[ProfileGraphEntry] = []

    last_j: dict[HWQueue, Optional[int]] = collections.defaultdict(lambda: None)
    queue_access: dict[HWQueue, dict[HWQueue, Optional[int]]] = collections.defaultdict(lambda: collections.defaultdict(lambda: None))
    dev_access: dict[HWQueue, set[HCQCompiled]] = collections.defaultdict(set)

    for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)

    for j,ji in enumerate(jit_cache):
      enqueue_dev: HCQCompiled = ji.prg.dev if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore

      if is_exec_prg: enqueue_queue = self.comp_queues[enqueue_dev]
      else:
        assert enqueue_dev.hw_copy_queue_t is not None, "device must implement a copy queue"
        enqueue_queue = self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())

      out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))

      # Get dependencies based on input and output buffers.
      rdeps = self._access_resources(ji.bufs, ji.prg.p.outs if is_exec_prg else [0], (enqueue_queue, j + 1)) #type:ignore

      # Update dependencies to include the previous kernel in this queue. This is required for timeline signals.
      opt_deps, deps = [], rdeps + ([(enqueue_queue, prev_ji + 1)] if (prev_ji:=last_j[enqueue_queue]) is not None else [])

      # Optimize dependencies by removing redundant ones: skip waiting on a queue value that is already known
      # to be synced with the current queue.
      for dep_queue, dep_val in sorted(deps, key=lambda x: x[1], reverse=True):
        if (qa:=queue_access[enqueue_queue][dep_queue]) is None or qa < dep_val:
          opt_deps.append((self.signals[dep_queue], dep_val))
          queue_access[enqueue_queue][dep_queue] = dep_val

      # Ensure the device is ready for use in the current context: the graph has initialized it and it's safe to operate on within this graph.
      for dep_queue, _ in opt_deps: dev_access[enqueue_queue].update(dev_access[dep_queue])
      sync_signals = [(self.signals[d], self.kickoff_var) for b in ji.bufs if (d:=Device[cast(Buffer, b).device]) not in dev_access[enqueue_queue]]
      dev_access[enqueue_queue].update(cast(HCQCompiled, Device[cast(Buffer, b).device]) for b in ji.bufs)

      # Remove self-dependency for compute and copy queues.
      # For compute on NV, also optimize when only one same-queue dependency exists, since NV chains 2+ executions in this case,
      # eliminating the need for the dependency.
      dname = enqueue_dev.device.split(":", 1)[0]
      can_opt = dname in {"AMD", "QCOM"} or (dname == "NV" and len(sync_signals) == 0 and len(opt_deps) == 1 and id(opt_deps[0][0]) == id(out_signal))
      if can_opt or isinstance(ji.prg, BufferXfer): opt_deps = [x for x in opt_deps if id(x[0]) != id(out_signal)]

      # Enable necessary signals in the schedule by setting the signal value.
      for sig, val in opt_deps: self.ji_schedule[val - 1] = self.ji_schedule[val - 1][:5] + (val,)
      self.ji_schedule[j] = (enqueue_dev, enqueue_queue, sync_signals, opt_deps[::-1], out_signal, None if is_exec_prg else (j + 1))

      # Collect profile information if profiling is enabled.
      if PROFILE:
        # When executions are chained, we can reuse the end timestamp of the previous command as the start timestamp of the current one.
        sig_st = prev_ji * 2 + 1 if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None else j * 2
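
        # Hedged example: if item 3 directly follows item 2 on the same queue with no remaining deps (opt_deps == []),
        # item 3 takes sig_st = 2*2+1 = 5, reusing item 2's end timestamp instead of its own start signal at 3*2 = 6.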

        # Description based on the command.
        prof_ji_desc = ji.prg._prg.name if is_exec_prg else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore

        self.prof_graph_entries.append(ProfileGraphEntry(enqueue_dev.device, prof_ji_desc, sig_st, j * 2 + 1, is_copy=not is_exec_prg))
        self.prog_graph_deps.append([d - 1 for _, d in rdeps])

      last_j[enqueue_queue] = j

    # Check which signals are used in the profile graph.
    self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(self.prof_signals))]

    # Build hardware queues.
    self.copy_to_devs: dict[HCQCompiled, set[HCQCompiled]] = {dev: set() for dev in self.devices}

    # Create variable timeline signals for each device.
    timeline_sigaddrs = {dev: UOp.variable(f"timeline_sig_{dev.device_id}", 0, 0xffffffffffffffff, dtype=dtypes.uint64) for dev in self.devices}
    self.virt_timeline_vals = {dev: UOp.variable(f"timeline_var_{dev.device_id}", 0, 0xffffffff, dtype=dtypes.uint32) for dev in self.devices}
    self.virt_timeline_signals = {dev: dev.signal_t(base_addr=timeline_sigaddrs[dev], timeline_for_device=dev) for dev in self.devices}

    for dev in self.devices:
      self.comp_queues[dev].memory_barrier().wait(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev]) \
                           .wait(self.signals['CPU'], self.kickoff_var).signal(self.signals[dev], self.kickoff_var)

    for j,ji in enumerate(jit_cache):
      enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]

      # Encode waits and start profile timestamp (if needed).
      for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val)
      if PROFILE and self.prof_signal_is_used[j * 2]: enqueue_queue.timestamp(self.prof_signals[j * 2])

      # Encode main commands based on ji type.
      if isinstance(ji.prg, CompiledRunner):
        enqueue_queue.exec(ji.prg._prg, self.ji_args[j], tuple(ji.prg.p.global_size or (1,1,1)), tuple(ji.prg.p.local_size or (1,1,1)))
      elif isinstance(ji.prg, BufferXfer):
        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
        cast(HCQAllocator, Device[src.device].allocator).map(dest._buf)
        enqueue_queue.copy(self.hcq_bufs[j][0].va_addr, self.hcq_bufs[j][1].va_addr, dest.nbytes)
        self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device]))
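
      # For reference, each jit item's encoded stream is (bracketed steps are conditional):
      #   wait(sync_signals) -> wait(deps) -> [timestamp(2*j)] -> exec/copy -> [timestamp(2*j+1)] -> [signal(out_signal, j+1)]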

      # Encode finish profile timestamp (if needed).
      if PROFILE and self.prof_signal_is_used[j * 2 + 1]: enqueue_queue.timestamp(self.prof_signals[j * 2 + 1])

      if signal_val is not None: enqueue_queue.signal(signal, signal_val)

    for dev in self.devices:
      for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
        if dep_dev in self.copy_queues:
          self.comp_queues[dev].wait(self.signals[(copy_q:=self.copy_queues[dep_dev])], cast(int, last_j[copy_q]) + 1)

      self.comp_queues[dev].signal(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev] + 1).bind(dev)
      if dev in self.copy_queues: self.copy_queues[dev].bind(dev)

    self.last_timeline: dict[HCQCompiled, tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
    self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals]

  def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> Optional[float]:
    # Wait and restore signals.
    self.kickoff_value += 1
    for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
    for sig in self.queue_signals_to_reset: sig.value = 0
    self.signals['CPU'].value = self.kickoff_value

    if PROFILE and self.kickoff_value > 1: self.collect_timestamps()

    hcq_var_vals = {self.kickoff_var: self.kickoff_value, **var_vals,
                    **{var: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()},
                    **{sig.base_addr: dev.timeline_signal.base_addr for dev, sig in self.virt_timeline_signals.items()}}

    # Update rawbuffers.
    for (j,i),input_idx in self.input_replace.items(): hcq_var_vals[self.input_replace_to_var[(j,i)]] = input_rawbuffers[input_idx]._buf.va_addr

    for dev in self.devices:
      self.comp_queues[dev].submit(dev, hcq_var_vals)
      if (copy_queue:=self.copy_queues.get(dev, None)) is not None: copy_queue.submit(dev, hcq_var_vals)

      self.last_timeline[dev] = (dev.timeline_signal, dev.timeline_value)
      dev.timeline_value += 1

    if wait:
      st = time.perf_counter()
      for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
      return time.perf_counter() - st
    return None

  def collect_timestamps(self):
    # NOTE: Appending to any device is fine...
    self.devices[0].profile_events += [ProfileGraphEvent(self.prof_graph_entries, self.prog_graph_deps, [s.timestamp for s in self.prof_signals])]

  def __del__(self):
    for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])

    if PROFILE and self.kickoff_value >= 1: self.collect_timestamps()

    for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferSpec(cpu_access=True))
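
# Minimal usage sketch (hedged; the variable names here are illustrative). Normally the tinygrad JIT constructs this
# graph itself when every device in a captured jit_cache is HCQ-backed (e.g. AMD/NV/QCOM):
#   graph = HCQGraph(jit_cache, input_rawbuffers, var_vals)
#   elapsed = graph(input_rawbuffers, var_vals, wait=True)  # replays the captured commands; returns seconds when wait=True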