from typing import TypeVar, Generic, Callable, Union, cast, Optional, Any
import functools, collections
from tinygrad.tensor import Tensor
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition, unwrap
from tinygrad.device import Buffer, Compiled, Device
from tinygrad.dtype import DType
from tinygrad.ops import UOp, Variable, sym_infer, Ops
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates
from tinygrad.engine.memory import _internal_memory_planner
from tinygrad.nn.state import get_parameters
from dataclasses import dataclass
from weakref import WeakKeyDictionary

class GraphException(Exception): pass

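# Walk the captured ExecItems and, where the device supports graphs, group consecutive kernels into batches
# run by a GraphRunner. Items that can't be graphed (or batches that fail to graph) are passed through unchanged.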
def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], max_batch_size=0) -> list[ExecItem]:
  # Split JIT cache into batches for faster graph execution.
  # This allows the accelerator to run some batches while subsequent graphs are still being updated.
  graphed_jit_cache: list[ExecItem] = []
  current_batch: list[ExecItem] = []
  current_device: Optional[Compiled] = None

  def flush_batch():
    nonlocal current_batch, current_device, max_batch_size
    try:
      if current_device is None: raise GraphException("no device for graph")
      if len(current_batch) <= 1 and not getenv("GRAPH_ONE_KERNEL"): raise GraphException("only one kernel doesn't graph")
      graph_runner = current_device.graph(current_batch, input_rawbuffers, var_vals)
      # clear jit inputs to allow their memory to be freed/reused
      for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
      graphed_jit_cache.append(ExecItem(graph_runner, cast(list[Optional[Buffer]], input_rawbuffers)))
      max_batch_size *= 2
      if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
    except GraphException as e:
      graphed_jit_cache.extend(current_batch)
      if DEBUG >= 2: print(f"JIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_device}: {e}")
    current_batch = []
    current_device = None

  for ji in jit_cache:
    if isinstance(ji.prg, ViewOp): continue
    ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
    if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.dev
    elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
      ji_graph_dev = Device[ji.bufs[0].device]

    graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None
    can_be_graphed = ji_graph_dev and ji_graph_dev.graph
    can_share_graph = (ji_graph_dev == current_device or (isinstance(graph_class, type) and issubclass(graph_class, MultiGraphRunner)) and
                       type(ji_graph_dev) is type(current_device))
    can_extend_graph_batch = can_be_graphed and (max_batch_size == 0 or len(current_batch) < max_batch_size) and can_share_graph
    if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()

    if can_be_graphed: current_batch.append(ji)
    else: graphed_jit_cache.append(ji)

    current_device = ji_graph_dev

  if len(current_batch) > 0: flush_batch()
  return graphed_jit_cache

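# Map (jit_cache index, buffer index) -> position in input_rawbuffers, so fresh inputs can be swapped in per call.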
def get_input_replace(jit_cache: list[ExecItem], input_rawbuffers:list[Buffer]) -> dict[tuple[int, int], int]:
  input_replace: dict[tuple[int, int], int] = {}
  for j,ji in enumerate(jit_cache):
    for i,a in enumerate(ji.bufs):
      if a in input_rawbuffers:
        input_replace[(j,i)] = input_rawbuffers.index(a)
  return input_replace

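# Base class for backend graph executors. It precomputes where input buffers, variable values, and symbolic
# launch dimensions must be patched into the captured kernels before each replay.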
class GraphRunner(Runner):
  def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
    self.jit_cache = jit_cache # NOTE: this is not used, but you have to keep these objects alive for the Graph
    self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
    self.var_vals_replace:dict[int, list[int]] = {}
    self.launch_dims_replace:dict[int, tuple[Optional[int], Optional[int]]] = {}
    self.launch_dims_base:dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {}

    def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)

    self.vars = sorted(var_vals.keys(), key=lambda v: v.expr)
    self.symbolic_dims = dedup([tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.local_size) and is_sym_dim(d)] +
                               [tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.global_size) and is_sym_dim(d)])
    def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None

    estimates = Estimates()
    for j,ji in enumerate(jit_cache):
      estimates += ji.prg.estimates
      if isinstance(ji.prg, CompiledRunner):
        if ji.prg.p.vars: self.var_vals_replace[j] = [self.vars.index(v) for v in ji.prg.p.vars]

        global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
        if global_dim_idx is not None or local_dim_idx is not None:
          self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
          assert ji.prg.p.global_size is not None and ji.prg.p.local_size is not None
          self.launch_dims_base[j] = (tuple(ji.prg.p.global_size), tuple(ji.prg.p.local_size))

    # used in MultiGraphRunner. the ints are id() of _bufs
    self.w_dependency_map: dict[int, Any] = {}
    self.r_dependency_map: dict[int, list[Any]] = collections.defaultdict(list)

    super().__init__(colored(f"<batched {len(jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0], estimates.simplify())

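  # yields (jit_cache index, argument position, new value) for every bound Variable a kernel needs updated this run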
  def updated_vars(self, var_vals: dict[Variable, int]):
    vals = [var_vals[v] for v in self.vars]
    for j, vidxs in self.var_vals_replace.items():
      for i, v in enumerate(vidxs): yield j, i, vals[v]

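  # yields (jit_cache index, global_size, local_size) with symbolic dimensions resolved against var_vals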
  def updated_launch_dims(self, var_vals: dict[Variable, int]):
    dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
    for j, (gl, lc) in self.launch_dims_replace.items():
      yield j, (dims[gl] if gl is not None else self.launch_dims_base[j][0]), (dims[lc] if lc is not None else self.launch_dims_base[j][1])

  def _access_resources(self, rawbufs:list[Buffer], write:list[int], new_dependency:Any):
    # To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
    # whether for write or read operations. A resource can be accessed by either a single writer or multiple readers.
    wait_nodes = []

    for i,rawbuf in enumerate(rawbufs):
      if id(rawbuf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf.base._buf)])
      if i in write:
        if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf)))

    for i,rawbuf in enumerate(rawbufs):
      if i in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
      else: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency)

    return list({id(x):x for x in wait_nodes}.values())

# a marker for your graph supporting multiple devices of the same type
class MultiGraphRunner(GraphRunner): pass

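# Grow `depends` with every buffer written by an ExecItem that touches a buffer already in the set.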
def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]):
  for ei in jit_cache:
    if any(b in depends for b in ei.bufs):
      if isinstance(ei.prg, CompiledRunner):
        depends.update(cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins)
      if isinstance(ei.prg, (BufferCopy, BufferXfer)):
        depends.add(cast(Buffer, ei.bufs[0]))

ReturnType = TypeVar('ReturnType')
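# Everything needed to replay a capture: the returned object, the kernel list, and the metadata used to validate
# and rebind inputs on later calls. This is also what gets pickled when a TinyJit is saved.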
@dataclass
class CapturedJit(Generic[ReturnType]):
  ret: Any # includes the Tensors or any other returned object
  jit_cache: list[ExecItem]
  input_replace: dict[tuple[int, int], int]
  extra_view_inputs: list[tuple[int, int, str, int, DType]]
  expected_names: list[Union[int, str]]
  expected_st_vars_dtype_device: list[tuple[ShapeTracker, tuple[Variable, ...], DType, str]]

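  # pickling runs optimize_weights first, then serializes the dataclass fields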
  def __reduce__(self):
    # TODO: free_intermediates here?
    self.optimize_weights()
    return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
                            self.expected_names, self.expected_st_vars_dtype_device)

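  # _jit_cache/_input_replace start as the captured versions; the first __call__ may swap in graphed versions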
  def __post_init__(self):
    self._jit_cache: list[ExecItem] = self.jit_cache
    self._input_replace: dict[tuple[int, int], int] = self.input_replace
    self._first_run = True
    self._clear_inputs()

  def _clear_inputs(self):
    for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None

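  # Deallocate buffers written by the captured kernels (intermediates and outputs); they are re-allocated and
  # the graph is rebuilt on the next call.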
  def free_intermediates(self):
    depends: set[Buffer|None] = set([None])
    update_depends(depends, self.jit_cache)
    for b in depends:
      if b is not None:
        b.deallocate()
        if b._base is not None and b._base.allocated_views == 0: b._base.deallocate()
    self.__post_init__() # reset the graph state

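  # Re-plan memory for all non-parameter buffers in the cache and copy any live data into the new assignments.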
  def optimize_weights(self):
    blacklist = [t.lazydata.buffer for t in get_parameters(self.ret)]
    asgn = _internal_memory_planner([[b for item in self.jit_cache for b in item.bufs if b is not None and b not in blacklist]], ignore_checks=True)
    self.jit_cache = [ExecItem(item.prg, [asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache]
    for old, new in asgn.items():
      if old.is_allocated(): new.ensure_allocated().copyin(old.as_buffer())
    self.__post_init__()

  # jit exec
  def __call__(self, input_buffers:list[Buffer], var_vals:dict[Variable, int]) -> ReturnType:
    # assign inputs
    for idx, offset, device, size, dtype in self.extra_view_inputs:
      input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
    for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]

    # Condense the items into a graph executor.
    if self._first_run:
      # allocate intermediates if freed
      for ji in self.jit_cache:
        for b in ji.bufs:
          if b is not None: b.ensure_allocated()
      # create graph if needed
      if JIT < 2:
        self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, max_batch_size=getenv("JIT_BATCH_SIZE", 32))
        self._input_replace = get_input_replace(self._jit_cache, input_buffers)
      self._first_run = False

    if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
    for ei in self._jit_cache: ei.run(var_vals, jit=True)
    self._clear_inputs()
    return self.ret

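# Extract the Tensor arguments (positional and keyword), realize them if needed, and return their realized
# buffers plus the unbound shape/var/dtype/device info used to validate later calls.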
def _prepare_jit_inputs(args, kwargs):
  input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
  names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
  if len(unrealized_tensors := [x for x in tensors if not x.lazydata.is_realized]): Tensor.realize(*unrealized_tensors)
  # TODO: should we be unpacking multi here?
  lbs: list[UOp] = flatten([t.lazydata.src if t.lazydata.op is Ops.MULTI else [t.lazydata] for t in tensors])
  input_buffers: list[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
  assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
  st_varval_dtype_device = [(*unwrap(lb.st).unbind(), lb.dtype, lb.device) for lb in lbs]
  var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
  st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varval_dtype_device]
  return input_buffers, var_vals, names, st_vars_dtype_device

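# Decorator/wrapper that captures the kernels launched by `fxn` and replays them on later calls. The call count
# drives the phases: call 0 runs eagerly, call 1 captures, calls >= 2 replay the capture.
# A minimal usage sketch (the `net` model and the shapes here are hypothetical):
#   @TinyJit
#   def step(x: Tensor) -> Tensor: return net(x).realize()
#   for _ in range(3): out = step(Tensor.rand(1, 32))  # the third call onward replays the captured kernels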
class TinyJit(Generic[ReturnType]):
  def __init__(self, fxn:Optional[Callable[..., ReturnType]], captured:Optional[CapturedJit]=None, prune=False):
    assert fxn or captured, "need either a function or a CapturedJit"
    self.fxn = fxn
    self.captured: Optional[CapturedJit] = captured
    self.cnt: int = 2 if self.fxn is None else 0
    self.prune = prune

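  # add() is invoked for every ExecItem executed while this jit is in `capturing`; add_buffer clones buffers that
  # are not yet allocated and not referenced by a lazybuffer, reusing earlier clones via _buffer_replace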
  def add_buffer(self, b:Buffer) -> Buffer:
    if found:=self._buffer_replace.get(b, None): return found
    if b.is_allocated() or b.lb_refcount > 0: return b
    if b._base is not None:
      self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.add_buffer(b._base), offset=b.offset)
    else:
      self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
    return ret

  def add(self, ei:ExecItem):
    self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None]))

  def reset(self):
    assert self.fxn is not None, "can't reset without function"
    self.cnt = 0
    self.captured = None

  def __reduce__(self):
    assert self.captured is not None, "can't pickle an uncaptured JIT"
    return self.__class__, (None, self.captured)

  # keep legacy code working
  @property
  def jit_cache(self) -> list[ExecItem]: return self.captured._jit_cache if self.captured is not None else []
  @property
  def input_replace(self) -> dict[tuple[int, int], int]: return self.captured._input_replace if self.captured is not None else {}

  def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods

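  # cnt == 0 (or JIT disabled): run fxn eagerly. cnt == 1: run fxn while capturing ExecItems, then prune/plan and
  # build the CapturedJit. cnt >= 2: validate the inputs against the capture and replay it.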
  def __call__(self, *args, **kwargs) -> ReturnType:
    input_buffers, var_vals, names, st_vars_dtype_device = _prepare_jit_inputs(args, kwargs)
    if not JIT or self.cnt == 0:
      # jit ignore
      assert self.fxn is not None
      with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
        ret = self.fxn(*args, **kwargs)
        if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
    elif self.cnt == 1:
      # jit capture
      assert self.fxn is not None
      if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
      self._jit_cache: list[ExecItem] = []
      self._buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
      # TODO: should we always disable the memory planner here? it must be off for prune
      with Context(BEAM=getenv("JITBEAM", BEAM.value), NO_MEMORY_PLANNER=int(self.prune)):
        capturing.append(self)
        try:
          ret = self.fxn(*args, **kwargs)
          if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
        except Exception as e: raise e
        finally: capturing.clear()
      jit_cache = self._jit_cache
      del self._buffer_replace, self._jit_cache
      assert len(jit_cache), "didn't JIT anything!"
      if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_buffers)} inputs")

      # track inputs that are views of buffers
      # TODO: eventually expected_buffers should live in ExecItem
      extra_view_inputs: list[tuple[int, int, str, int, DType]] = []
      for item in jit_cache:
        for b in item.bufs:
          if b is not None and b._base is not None and b._base in input_buffers:
            input_buffers.append(b)
            extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype))

      # prune independent kernels (optional)
      if self.prune:
        depends = set(input_buffers)
        update_depends(depends, jit_cache)
        pruned, onetime = partition(jit_cache,
                                    lambda ei: not isinstance(ei.prg, CompiledRunner) or any(ei.bufs[out] in depends for out in ei.prg.p.outs))
        if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
        # run the onetime kernels here
        for ei in onetime:
          for b in ei.bufs: cast(Buffer, b).ensure_allocated()
          ei.run(var_vals, jit=True)
        jit_cache = pruned

      # memory planning (optional)
      # Exclude buffers involved in transfer ops to preserve parallelism.
      noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, BufferXfer) for b in ji.bufs}
      assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
      jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache]

      input_replace = get_input_replace(jit_cache, input_buffers)
      if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")

      # set this for next run
      self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
    elif self.cnt >= 2:
      # jit exec
      assert self.captured is not None
      assert self.captured.expected_names == names, f"args mismatch in JIT: {self.captured.expected_names=} != {names}"
      assert self.captured.expected_st_vars_dtype_device == st_vars_dtype_device, \
        f"args mismatch in JIT: {self.captured.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}"
      ret = self.captured(input_buffers, var_vals)

    self.cnt += 1
    return ret