from typing import TypeVar, Generic, Callable, Union, cast, Optional, Any
import functools, collections
from tinygrad.tensor import Tensor
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition, unwrap
from tinygrad.device import Buffer, Compiled, Device
from tinygrad.dtype import DType
from tinygrad.ops import UOp, Variable, sym_infer, Ops
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates
from tinygrad.engine.memory import _internal_memory_planner
from tinygrad.nn.state import get_parameters
from dataclasses import dataclass
from weakref import WeakKeyDictionary

class GraphException(Exception): pass

def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], max_batch_size=0) -> list[ExecItem]:
  # Split JIT cache into batches for faster graph execution.
  # This allows the accelerator to run some batches while subsequent graphs are still being updated.
  graphed_jit_cache: list[ExecItem] = []
  current_batch: list[ExecItem] = []
  current_device: Optional[Compiled] = None

  def flush_batch():
    nonlocal current_batch, current_device, max_batch_size
    try:
      if current_device is None: raise GraphException("no device for graph")
      if len(current_batch) <= 1 and not getenv("GRAPH_ONE_KERNEL"): raise GraphException("only one kernel doesn't graph")
      graph_runner = current_device.graph(current_batch, input_rawbuffers, var_vals)
      # clear jit inputs to allow their memory to be freed/reused
      for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
      graphed_jit_cache.append(ExecItem(graph_runner, cast(list[Optional[Buffer]], input_rawbuffers)))
      max_batch_size *= 2
      if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
    except GraphException as e:
      graphed_jit_cache.extend(current_batch)
      if DEBUG >= 2: print(f"JIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_device}: {e}")
    current_batch = []
    current_device = None

  for ji in jit_cache:
    if isinstance(ji.prg, ViewOp): continue
    ji_graph_dev: Optional[Compiled] = None  # device on which the ji will be graphed. Not graphed if None.
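    # pick the graph device for this ExecItem: CompiledRunner kernels graph on their own device, while buffer
    # transfers are only graphed on backends whose graph implementation supports them (CUDA/NV/AMD)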
    if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.dev
    elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
      ji_graph_dev = Device[ji.bufs[0].device]

    graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None
    can_be_graphed = ji_graph_dev and ji_graph_dev.graph
    can_share_graph = (ji_graph_dev == current_device or
                       (isinstance(graph_class, type) and issubclass(graph_class, MultiGraphRunner)) and type(ji_graph_dev) is type(current_device))
    can_extend_graph_batch = can_be_graphed and (max_batch_size == 0 or len(current_batch) < max_batch_size) and can_share_graph
    if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()

    if can_be_graphed: current_batch.append(ji)
    else: graphed_jit_cache.append(ji)

    current_device = ji_graph_dev

  if len(current_batch) > 0: flush_batch()
  return graphed_jit_cache

def get_input_replace(jit_cache: list[ExecItem], input_rawbuffers:list[Buffer]) -> dict[tuple[int, int], int]:
  input_replace: dict[tuple[int, int], int] = {}
  for j,ji in enumerate(jit_cache):
    for i,a in enumerate(ji.bufs):
      if a in input_rawbuffers:
        input_replace[(j,i)] = input_rawbuffers.index(a)
  return input_replace

class GraphRunner(Runner):
  def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
    self.jit_cache = jit_cache  # NOTE: this is not used, but you have to keep these objects alive for the Graph
    self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
    self.var_vals_replace:dict[int, list[int]] = {}
    self.launch_dims_replace:dict[int, tuple[Optional[int], Optional[int]]] = {}
    self.launch_dims_base:dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {}

    def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)

    self.vars = sorted(var_vals.keys(), key=lambda v: v.expr)
    self.symbolic_dims = dedup([tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.local_size) and is_sym_dim(d)] +
                               [tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.global_size) and is_sym_dim(d)])
    def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None

    estimates = Estimates()
    for j,ji in enumerate(jit_cache):
      estimates += ji.prg.estimates
      if isinstance(ji.prg, CompiledRunner):
        if ji.prg.p.vars: self.var_vals_replace[j] = [self.vars.index(v) for v in ji.prg.p.vars]

        global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
        if global_dim_idx is not None or local_dim_idx is not None:
          self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
          assert ji.prg.p.global_size is not None and ji.prg.p.local_size is not None
          self.launch_dims_base[j] = (tuple(ji.prg.p.global_size), tuple(ji.prg.p.local_size))

    # used in MultiGraphRunner. the ints are id() of _bufs
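    # w_dependency_map tracks the last node to write each buffer, r_dependency_map tracks all readers since that write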
    self.w_dependency_map: dict[int, Any] = {}
    self.r_dependency_map: dict[int, list[Any]] = collections.defaultdict(list)

    super().__init__(colored(f"<batched {len(jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0], estimates.simplify())

  def updated_vars(self, var_vals: dict[Variable, int]):
    vals = [var_vals[v] for v in self.vars]
    for j, vidxs in self.var_vals_replace.items():
      for i, v in enumerate(vidxs): yield j, i, vals[v]

  def updated_launch_dims(self, var_vals: dict[Variable, int]):
    dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
    for j, (gl, lc) in self.launch_dims_replace.items():
      yield j, (dims[gl] if gl is not None else self.launch_dims_base[j][0]), (dims[lc] if lc is not None else self.launch_dims_base[j][1])

  def _access_resources(self, rawbufs:list[Buffer], write:list[int], new_dependency:Any):
    # To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
    # whether for write or read operations. A resource can be accessed by either a single writer or multiple readers.
    wait_nodes = []

    for i,rawbuf in enumerate(rawbufs):
      if id(rawbuf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf.base._buf)])
      if i in write:
        if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf)))

    for i,rawbuf in enumerate(rawbufs):
      if i in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
      else: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency)

    return list({id(x):x for x in wait_nodes}.values())

# a marker for your graph supporting multiple devices of the same type
class MultiGraphRunner(GraphRunner): pass

def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]):
  for ei in jit_cache:
    if any(b in depends for b in ei.bufs):
      if isinstance(ei.prg, CompiledRunner):
        depends.update(cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins)
      if isinstance(ei.prg, (BufferCopy, BufferXfer)):
        depends.add(cast(Buffer, ei.bufs[0]))

ReturnType = TypeVar('ReturnType')

@dataclass
class CapturedJit(Generic[ReturnType]):
  ret: Any  # includes the Tensors or any other returned object
  jit_cache: list[ExecItem]
  input_replace: dict[tuple[int, int], int]
  extra_view_inputs: list[tuple[int, int, str, int, DType]]
  expected_names: list[Union[int, str]]
  expected_st_vars_dtype_device: list[tuple[ShapeTracker, tuple[Variable, ...], DType, str]]

  def __reduce__(self):
    # TODO: free_intermediates here?
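    # optimize_weights() sends every buffer not owned by a returned parameter through the memory planner and copies over any allocated data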
    self.optimize_weights()
    return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
                            self.expected_names, self.expected_st_vars_dtype_device)

  def __post_init__(self):
    self._jit_cache: list[ExecItem] = self.jit_cache
    self._input_replace: dict[tuple[int, int], int] = self.input_replace
    self._first_run = True
    self._clear_inputs()

  def _clear_inputs(self):
    for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None

  def free_intermediates(self):
    depends: set[Buffer|None] = set([None])
    update_depends(depends, self.jit_cache)
    for b in depends:
      if b is not None:
        b.deallocate()
        if b._base is not None and b._base.allocated_views == 0: b._base.deallocate()
    self.__post_init__()  # reset the graph state

  def optimize_weights(self):
    blacklist = [t.lazydata.buffer for t in get_parameters(self.ret)]
    asgn = _internal_memory_planner([[b for item in self.jit_cache for b in item.bufs if b is not None and b not in blacklist]], ignore_checks=True)
    self.jit_cache = [ExecItem(item.prg, [asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache]
    for old, new in asgn.items():
      if old.is_allocated(): new.ensure_allocated().copyin(old.as_buffer())
    self.__post_init__()

  # jit exec
  def __call__(self, input_buffers:list[Buffer], var_vals:dict[Variable, int]) -> ReturnType:
    # assign inputs
    for idx, offset, device, size, dtype in self.extra_view_inputs:
      input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
    for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]

    # Condense the items into a graph executor.
    if self._first_run:
      # allocate intermediates if freed
      for ji in self.jit_cache:
        for b in ji.bufs:
          if b is not None: b.ensure_allocated()
      # create graph if needed
      if JIT < 2:
        self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, max_batch_size=getenv("JIT_BATCH_SIZE", 32))
        self._input_replace = get_input_replace(self._jit_cache, input_buffers)
      self._first_run = False

    if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
    for ei in self._jit_cache: ei.run(var_vals, jit=True)
    self._clear_inputs()
    return self.ret

def _prepare_jit_inputs(args, kwargs):
  input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
  names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
  if len(unrealized_tensors := [x for x in tensors if not x.lazydata.is_realized]): Tensor.realize(*unrealized_tensors)
  # TODO: should we be unpacking multi here?
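  # a tensor with a MULTI lazydata holds one UOp per shard; flatten so each shard's realized Buffer becomes its own JIT input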
  lbs: list[UOp] = flatten([t.lazydata.src if t.lazydata.op is Ops.MULTI else [t.lazydata] for t in tensors])
  input_buffers: list[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
  assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
  st_varval_dtype_device = [(*unwrap(lb.st).unbind(), lb.dtype, lb.device) for lb in lbs]
  var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
  st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varval_dtype_device]
  return input_buffers, var_vals, names, st_vars_dtype_device

class TinyJit(Generic[ReturnType]):
  def __init__(self, fxn:Optional[Callable[..., ReturnType]], captured:Optional[CapturedJit]=None, prune=False):
    assert fxn or captured, "need either a function or a CapturedJit"
    self.fxn = fxn
    self.captured: Optional[CapturedJit] = captured
    self.cnt: int = 2 if self.fxn is None else 0
    self.prune = prune

  def add_buffer(self, b:Buffer) -> Buffer:
    if found:=self._buffer_replace.get(b, None): return found
    if b.is_allocated() or b.lb_refcount > 0: return b
    if b._base is not None:
      self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.add_buffer(b._base), offset=b.offset)
    else:
      self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
    return ret

  def add(self, ei:ExecItem):
    self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None]))

  def reset(self):
    assert self.fxn is not None, "can't reset without function"
    self.cnt = 0
    self.captured = None

  def __reduce__(self):
    assert self.captured is not None, "can't pickle an uncaptured JIT"
    return self.__class__, (None, self.captured)

  # keep legacy code working
  @property
  def jit_cache(self) -> list[ExecItem]: return self.captured._jit_cache if self.captured is not None else []
  @property
  def input_replace(self) -> dict[tuple[int, int], int]: return self.captured._input_replace if self.captured is not None else {}

  def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)  # add support for instance methods

  def __call__(self, *args, **kwargs) -> ReturnType:
    input_buffers, var_vals, names, st_vars_dtype_device = _prepare_jit_inputs(args, kwargs)
    if not JIT or self.cnt == 0:
      # jit ignore
      assert self.fxn is not None
      with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
        ret = self.fxn(*args, **kwargs)
        if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
    elif self.cnt == 1:
      # jit capture
      assert self.fxn is not None
      if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
      self._jit_cache: list[ExecItem] = []
      self._buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
      # TODO: should we always disable the memory planner here? it must be off for prune
      with Context(BEAM=getenv("JITBEAM", BEAM.value), NO_MEMORY_PLANNER=int(self.prune)):
        capturing.append(self)
        try:
          ret = self.fxn(*args, **kwargs)
          if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
        except Exception as e: raise e
        finally: capturing.clear()
      jit_cache = self._jit_cache
      del self._buffer_replace, self._jit_cache
      assert len(jit_cache), "didn't JIT anything!"
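      # post-capture work: register inputs that are views of other buffers, optionally prune kernels that don't
      # feed an output, run the pruned-away "onetime" kernels once, then memory-plan the remaining buffers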
      if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_buffers)} inputs")

      # track inputs that are views of buffers
      # TODO: eventually expected_buffers should live in ExecItem
      extra_view_inputs: list[tuple[int, int, str, int, DType]] = []
      for item in jit_cache:
        for b in item.bufs:
          if b is not None and b._base is not None and b._base in input_buffers:
            input_buffers.append(b)
            extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype))

      # prune independent kernels (optional)
      if self.prune:
        depends = set(input_buffers)
        update_depends(depends, jit_cache)
        pruned, onetime = partition(jit_cache,
                                    lambda ei: not isinstance(ei.prg, CompiledRunner) or any(ei.bufs[out] in depends for out in ei.prg.p.outs))
        if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
        # run the onetime kernels here
        for ei in onetime:
          for b in ei.bufs: cast(Buffer, b).ensure_allocated()
          ei.run(var_vals, jit=True)
        jit_cache = pruned

      # memory planning (optional)
      # Exclude buffers involved in transfer ops to preserve parallelism.
      noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, BufferXfer) for b in ji.bufs}
      assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
      jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache]

      input_replace = get_input_replace(jit_cache, input_buffers)
      if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")

      # set this for next run
      self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
    elif self.cnt >= 2:
      # jit exec
      assert self.captured is not None
      assert self.captured.expected_names == names, f"args mismatch in JIT: {self.captured.expected_names=} != {names}"
      assert self.captured.expected_st_vars_dtype_device == st_vars_dtype_device, \
        f"args mismatch in JIT: {self.captured.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}"
      ret = self.captured(input_buffers, var_vals)

    self.cnt += 1
    return ret
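
# ---------------------------------------------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of this module in tinygrad): TinyJit wraps a function of Tensors,
# runs it eagerly on the first call, captures the kernels on the second, and replays the captured (optionally graphed)
# ExecItems on every call after that. The toy `step` function below is a hypothetical example, not library code.
if __name__ == "__main__":
  @TinyJit
  def step(x: Tensor) -> Tensor:
    # realize() inside the jitted function so its kernels are scheduled (and thus captured) during the call
    return (x * 2 + 1).realize()

  for i in range(4):
    # call 0 runs without the JIT, call 1 captures the ExecItems, calls 2+ replay them with the new input buffer
    out = step(Tensor.randn(8, 8))
    print(f"call {i}: sum={out.numpy().sum():.3f}")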