# HCQGraph: prebuilt hardware-queue graph runner for tinygrad HCQ devices.
# (Repository web-UI listing artifacts removed from this header.)
import collections, time
|
|
from typing import Any, cast, Optional
|
|
from tinygrad.helpers import round_up, PROFILE
|
|
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator
|
|
from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent
|
|
from tinygrad.dtype import dtypes
|
|
from tinygrad.ops import UOp, Variable
|
|
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
|
|
from tinygrad.engine.jit import MultiGraphRunner
|
|
|
|
class HCQGraph(MultiGraphRunner):
  """Graph runner that replays a JIT cache through per-device HCQ hardware queues.

  All kernel launches and buffer transfers are encoded ONCE at construction time into compute/copy
  hardware queues. Each ``__call__`` then only patches variables (input buffer addresses, per-device
  timeline values, the kickoff counter) and resubmits the prebuilt queues, avoiding per-launch CPU work.
  """

  def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
    super().__init__(jit_cache, input_rawbuffers, var_vals)
    # Every HCQ device touched by any buffer of any jit item participates in this graph.
    self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))

    # Replace input buffers with variables.
    # hcq_bufs[j][i] is the low-level HCQBuffer backing buffer i of jit item j; graph inputs are
    # swapped for 64-bit address variables so their addresses can be rebound on every __call__.
    self.hcq_bufs = [[cast(Buffer, x)._buf for x in ji.bufs] for ji in jit_cache]
    self.input_replace_to_var: dict[tuple[int, int], Variable] = {}

    for (j,i), input_idx in self.input_replace.items():
      x = self.input_replace_to_var.setdefault((j,i), UOp.variable(f"input_{input_idx}", 0, 0xffffffffffffffff, dtype=dtypes.uint64))
      self.hcq_bufs[j][i] = HCQBuffer(x, self.hcq_bufs[j][i].size, texture_info=self.hcq_bufs[j][i].texture_info) # Create fake buffer with variable

    # Allocate kernel args.
    # One CPU-accessible arena per device, sized as the sum of each kernel's 16-byte-aligned arg block.
    kernargs_size: dict[Compiled, int] = collections.defaultdict(int)
    for ji in jit_cache:
      if not isinstance(ji.prg, CompiledRunner): continue
      kernargs_size[ji.prg.dev] += round_up(ji.prg._prg.kernargs_alloc_size, 16)
    self.kernargs_bufs: dict[Compiled, HCQBuffer] = {dev:dev.allocator._alloc(sz, BufferSpec(cpu_access=True)) for dev,sz in kernargs_size.items()}

    # Fill initial arguments.
    self.ji_args: dict[int, HCQArgsState] = {}

    # Bump-allocate each kernel's slice of its device's kernargs arena and write the initial arg values.
    kargs_alloc: dict[Compiled, BumpAllocator] = {dev:BumpAllocator(buf.size, start=cast(int, buf.va_addr)) for dev,buf in self.kernargs_bufs.items()}
    for j,ji in enumerate(jit_cache):
      if not isinstance(ji.prg, CompiledRunner): continue

      self.ji_args[j] = ji.prg._prg.fill_kernargs(self.hcq_bufs[j], ji.prg.p.vars, kargs_alloc[ji.prg.dev].alloc(ji.prg._prg.kernargs_alloc_size, 16))

    # Schedule Dependencies.
    # There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any
    # graph-related tasks. This synchronization uses a global timeline signal per device. Within the graph, the compute queue coordinates with
    # global operations and sets a kickoff signal. Any queue accessing a buffer from another device waits for this signal from the device’s
    # compute queue to ensure exclusive access. The compute queue signals the completion of the graph, synchronizing with the device's copy queue.
    # ji_schedule[j] = (device, queue, sync_signals, deps, out_signal, signal_value-or-None) for jit item j.
    self.ji_schedule: dict[int, tuple[HCQCompiled, HWQueue, list, list, HCQSignal, Optional[int]]] = {}

    self.comp_queues: dict[HCQCompiled, HWQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
    self.copy_queues: dict[HCQCompiled, HWQueue] = {} # lazy allocation

    # One completion signal per device plus a host-written "CPU" kickoff signal (allocated on devices[0]).
    self.signals: dict[Any, HCQSignal] = {**{dev: dev.signal_t(value=0) for dev in self.devices}, **{"CPU": self.devices[0].signal_t(value=0)}}
    self.kickoff_value: int = 0
    self.kickoff_var = UOp.variable("kickoff_var", 0, 0xffffffff, dtype=dtypes.uint32)

    # When profiling allocate 2 signals for each jit item to measure speed. The jth jit item has signals at 2*j and 2*j+1.
    # TODO: This logic might allocate a few extra signals...
    # NOTE(review): "prog_graph_deps" looks like a typo for "prof_graph_deps"; the name is used
    # consistently here and in collect_timestamps, so it is left unchanged.
    self.prof_signals: list[HCQSignal] = [self.devices[0].signal_t() for i in range(len(jit_cache) * 2)] if PROFILE else []
    self.prog_graph_deps: list[list[int]] = []
    self.prof_graph_entries: list[ProfileGraphEntry] = []

    # last_j: last jit item index enqueued on each queue. queue_access[q1][q2]: highest signal value of
    # q2 that q1 is already known to have waited on. dev_access[q]: devices safe to touch from queue q.
    last_j: dict[HWQueue, Optional[int]] = collections.defaultdict(lambda: None)
    queue_access: dict[HWQueue, dict[HWQueue, Optional[int]]] = collections.defaultdict(lambda: collections.defaultdict(lambda: None))
    dev_access: dict[HWQueue, set[HCQCompiled]] = collections.defaultdict(set)

    # A compute queue can always access its own device.
    for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)

    for j,ji in enumerate(jit_cache):
      # Kernels run on their program's device; transfers are enqueued on the SOURCE buffer's device (bufs[1]).
      enqueue_dev: HCQCompiled = ji.prg.dev if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore

      if is_exec_prg:
        enqueue_queue = self.comp_queues[enqueue_dev]
      else:
        assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
        enqueue_queue = self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())

      # Each queue signals (j+1) of its out_signal after finishing item j.
      out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))

      # Get dependencies based on input and output buffers.
      # For a transfer the destination (buffer 0) is the written resource.
      rdeps = self._access_resources(ji.bufs, ji.prg.p.outs if is_exec_prg else [0], (enqueue_queue, j + 1)) #type:ignore

      # Update dependencies to include previous kernel in queue. This is required for timeline signals.
      opt_deps, deps = [], rdeps + ([(enqueue_queue, prev_ji + 1)] if (prev_ji:=last_j[enqueue_queue]) is not None else [])

      # Optimize dependencies by removing redundant ones. Remove waiting for the value of the queue which is known to be already
      # synced with the current queue. Processing in descending signal-value order keeps only the newest wait per queue.
      for dep_queue, dep_val in sorted(deps, key=lambda x: x[1], reverse=True):
        if (qa:=queue_access[enqueue_queue][dep_queue]) is None or qa < dep_val:
          opt_deps.append((self.signals[dep_queue], dep_val))
          queue_access[enqueue_queue][dep_queue] = dep_val

      # Ensure device is ready for use in current context: the graph has initialized the device and it's safe to operate on it within this graph.
      for dep_queue, _ in opt_deps: dev_access[enqueue_queue].update(dev_access[dep_queue])
      sync_signals = [(self.signals[d], self.kickoff_var) for b in ji.bufs if (d:=Device[cast(Buffer, b).device]) not in dev_access[enqueue_queue]]
      dev_access[enqueue_queue].update(cast(HCQCompiled, Device[cast(Buffer, b).device]) for b in ji.bufs)

      # Remove self-dependency for compute and copy queues.
      # For compute, in case of NV, optimize when only 1 same-queue dependency exists, since NV chains 2+ executions in this case,
      # eliminating dependency need.
      dname = enqueue_dev.device.split(":", 1)[0]
      can_opt = dname in {"AMD", "QCOM"} or (dname == "NV" and len(sync_signals) == 0 and len(opt_deps) == 1 and id(opt_deps[0][0]) == id(out_signal))
      if can_opt or isinstance(ji.prg, BufferXfer): opt_deps = [x for x in opt_deps if id(x[0]) != id(out_signal)]

      # Enable necessary signals in the schedule by setting the signal value.
      # Item (val-1) must now actually signal value val, since someone waits on it.
      for sig, val in opt_deps: self.ji_schedule[val - 1] = self.ji_schedule[val - 1][:5] + (val,)
      # Copy items always signal (j+1); compute items start unsignaled and are enabled lazily above.
      self.ji_schedule[j] = (enqueue_dev, enqueue_queue, sync_signals, opt_deps[::-1], out_signal, None if is_exec_prg else (j + 1))

      # Collect profile information if profiling is enabled.
      if PROFILE:
        # When execution are chained, we can reuse the end timestamp from the previous command as the start timestamp for the current command.
        sig_st = prev_ji * 2 + 1 if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None else j * 2

        # Description based on the command.
        prof_ji_desc = ji.prg._prg.name if is_exec_prg else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore

        self.prof_graph_entries.append(ProfileGraphEntry(enqueue_dev.device, prof_ji_desc, sig_st, j * 2 + 1, is_copy=not is_exec_prg))
        self.prog_graph_deps.append([d - 1 for _, d in rdeps])

      last_j[enqueue_queue] = j

    # Check which signals are used in the profile graph.
    self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(self.prof_signals))]

    # Build hardware queues.
    # copy_to_devs[dev]: source devices whose copy queues dev's compute queue must wait on at graph end.
    self.copy_to_devs: dict[HCQCompiled, set[HCQCompiled]] = {dev: set() for dev in self.devices}

    # Create variable timeline signals for each device.
    # Both the signal address and the timeline value are variables, rebound from the live device state on every __call__.
    timeline_sigaddrs = {dev: UOp.variable(f"timeline_sig_{dev.device_id}", 0, 0xffffffffffffffff, dtype=dtypes.uint64) for dev in self.devices}
    self.virt_timeline_vals = {dev: UOp.variable(f"timeline_var_{dev.device_id}", 0, 0xffffffff, dtype=dtypes.uint32) for dev in self.devices}
    self.virt_timeline_signals = {dev: dev.signal_t(base_addr=timeline_sigaddrs[dev], timeline_for_device=dev) for dev in self.devices}

    # Graph prologue per device: barrier, wait for the device timeline and the CPU kickoff, then announce kickoff on the device signal.
    for dev in self.devices:
      self.comp_queues[dev].memory_barrier().wait(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev]) \
                           .wait(self.signals['CPU'], self.kickoff_var).signal(self.signals[dev], self.kickoff_var)

    for j,ji in enumerate(jit_cache):
      enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]

      for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val)

      # Encode waits and start profile timestamp (if needed).
      if PROFILE and self.prof_signal_is_used[j * 2]: enqueue_queue.timestamp(self.prof_signals[j * 2])

      # Encode main commands based on ji type.
      if isinstance(ji.prg, CompiledRunner):
        enqueue_queue.exec(ji.prg._prg, self.ji_args[j], tuple(ji.prg.p.global_size or (1,1,1)), tuple(ji.prg.p.local_size or (1,1,1)))
      elif isinstance(ji.prg, BufferXfer):
        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
        # Map the destination into the source device's address space so the copy engine can reach it.
        cast(HCQAllocator, Device[src.device].allocator).map(dest._buf)

        enqueue_queue.copy(self.hcq_bufs[j][0].va_addr, self.hcq_bufs[j][1].va_addr, dest.nbytes)
        self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device]))

      # Encode finish profile timestamp (if needed).
      if PROFILE and self.prof_signal_is_used[j * 2 + 1]: enqueue_queue.timestamp(self.prof_signals[j * 2 + 1])

      if signal_val is not None: enqueue_queue.signal(signal, signal_val)

    # Graph epilogue per device: wait for all relevant copy queues, then bump the device timeline signal.
    for dev in self.devices:
      for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
        if dep_dev in self.copy_queues: self.comp_queues[dev].wait(self.signals[(copy_q:=self.copy_queues[dep_dev])], cast(int, last_j[copy_q]) + 1)

      self.comp_queues[dev].signal(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev] + 1).bind(dev)
      if dev in self.copy_queues: self.copy_queues[dev].bind(dev)

    # last_timeline[dev]: (signal, value) to wait on before the graph may be reused or torn down.
    self.last_timeline: dict[HCQCompiled, tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
    self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals]

  def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> Optional[float]:
    """Resubmit the prebuilt graph with fresh inputs/vars. Returns elapsed seconds if ``wait`` else None."""
    # Wait and restore signals
    self.kickoff_value += 1
    for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
    for sig in self.queue_signals_to_reset: sig.value = 0
    # Releasing the new kickoff value lets the prologue waits encoded in __init__ proceed.
    self.signals['CPU'].value = self.kickoff_value

    # Timestamps from the previous run are harvested before this run overwrites the profile signals.
    if PROFILE and self.kickoff_value > 1: self.collect_timestamps()

    # Bind the virtual timeline variables/addresses to each device's CURRENT timeline state.
    hcq_var_vals = {self.kickoff_var: self.kickoff_value, **var_vals,
                    **{var: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()},
                    **{sig.base_addr: dev.timeline_signal.base_addr for dev, sig in self.virt_timeline_signals.items()}}

    # Update rawbuffers
    for (j,i),input_idx in self.input_replace.items(): hcq_var_vals[self.input_replace_to_var.get((j,i))] = input_rawbuffers[input_idx]._buf.va_addr

    for dev in self.devices:
      self.comp_queues[dev].submit(dev, hcq_var_vals)
      if (copy_queue:=self.copy_queues.get(dev, None)) is not None: copy_queue.submit(dev, hcq_var_vals)

      self.last_timeline[dev] = (dev.timeline_signal, dev.timeline_value)
      dev.timeline_value += 1

    if wait:
      st = time.perf_counter()
      for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
      return time.perf_counter() - st
    return None

  def collect_timestamps(self):
    """Publish one ProfileGraphEvent with the timestamps recorded during the last run."""
    # NOTE: Append to any device is fine...
    self.devices[0].profile_events += [ProfileGraphEvent(self.prof_graph_entries, self.prog_graph_deps, [s.timestamp for s in self.prof_signals])]

  def __del__(self):
    # Drain outstanding work before freeing resources the queues may still reference.
    for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])

    if PROFILE and self.kickoff_value >= 1: self.collect_timestamps()

    for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferSpec(cpu_access=True))
|
|
|