from typing import cast
from collections import defaultdict
from tinygrad.engine.schedule import ScheduleItem
from tinygrad.device import Device, Buffer
from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG, round_up
from tinygrad.ops import Ops
from tinygrad.dtype import dtypes, ImageDType
from tinygrad.runtime.support.allocator import TLSFAllocator

# **************** memory planning ****************

def _internal_memory_planner(buffers:list[list[Buffer]], noopt_buffers=None, ignore_checks=False, debug_prefix="") -> dict[Buffer, Buffer]:
  if NO_MEMORY_PLANNER: return {}

  # Record the first and last schedule step in which each buffer base appears. Only buffers that
  # are not yet allocated and not externally referenced are candidates for planning (unless ignore_checks).
  first_appearance, last_appearance, buf_to_opt = {}, {}, set()
  for i,u in enumerate(buffers):
    for buf in u:
      should_skip = buf.is_allocated() or buf.base.is_allocated() or buf.lb_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers)
      if not ignore_checks and should_skip: continue
      if buf.base not in first_appearance: first_appearance[buf.base] = i
      last_appearance[buf.base] = i
      buf_to_opt.add(buf)

  # Sort buffer operations in timeline order. Two events: buffer is allocated or buffer is freed.
  # At the same timestep, frees sort before allocations, so a dying buffer's space is reusable immediately.
  buffer_requests = sorted([((first_appearance[buf], True), buf) for buf in first_appearance.keys()] + \
                           [((last_appearance[buf] + 1, False), buf) for buf in first_appearance.keys()], key=lambda x: x[0])

  # Try to suballocate from a shared buffer managed by global_planner using TLSFAllocator.
  # Also track buffer replacements for buffers that do not support suballocation.
  buffer_replace:dict[Buffer, tuple[Buffer|None, int|None]] = {}
  reuse_buffers:dict[tuple, list[Buffer]] = defaultdict(list)
  global_planner:dict[str, tuple[int, TLSFAllocator]] = defaultdict(lambda: (0, TLSFAllocator(1 << 44, block_size=0x1000, lv2_cnt=32)))
  for (_, is_open_ev), buf in buffer_requests:
    # Check if suballocation is possible for the given buffer and device.
    if hasattr(Device[buf.device].allocator, "_offset") and not isinstance(buf.dtype, ImageDType):
      if is_open_ev: buffer_replace[buf] = (None, global_planner[buf.device][1].alloc(round_up(buf.nbytes, 0x1000)))
      else:
        global_planner[buf.device][1].free(cast(int, buffer_replace[buf][1]))
        global_planner[buf.device] = (max(global_planner[buf.device][0], cast(int, buffer_replace[buf][1]) + buf.nbytes), global_planner[buf.device][1])
    else:
      # Fallback: recycle whole Buffer objects between identically shaped requests.
      key = (buf.device, buf.dtype, buf.options, buf.nbytes)
      if is_open_ev: buffer_replace[buf] = (reuse_buffers[key].pop(), None) if key in reuse_buffers and len(reuse_buffers[key]) > 0 else (buf, None)
      else: reuse_buffers[key].append(cast(Buffer, buffer_replace[buf][0]))

  # Allocate global buffers based on the memory planner.
  global_buffers = {dev: Buffer(dev, round_up(sz, 0x1000), dtypes.int8) for dev, (sz, _) in global_planner.items()}
  buffer_resolve:dict[Buffer, tuple[Buffer, int|None]] = {buf: (base or global_buffers[buf.device], off) for buf,(base,off) in buffer_replace.items()}

  # Assign buffers. First, assign full buffers (not sub-buffers).
  assigned:dict[Buffer, Buffer] = {}
  for buf, (base, off) in buffer_resolve.items():
    if buf != base:
      assigned[buf] = base if off is None else Buffer(buf.device, buf.size, buf.dtype, base=base, offset=off)

  # Now assign sub-buffers: views follow wherever their base was placed.
  for buf in buf_to_opt:
    if buf._base is not None:
      assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=(pbuf:=assigned.get(buf.base, buf.base)).base, offset=pbuf.offset+buf.offset)

  if DEBUG >= 1:
    ak = dedup(x for x in assigned.keys() if x._base is None)
    av = dedup(x for x in assigned.values() if x._base is None) + list(global_buffers.values())
    omem, nmem = sum([x.nbytes for x in ak])/1e6, sum([x.nbytes for x in av])/1e6
    if omem != nmem: print(f"{debug_prefix}memory reduced from {omem:.2f} MB -> {nmem:.2f} MB, {len(ak)} -> {len(av)} bufs")
  return assigned
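
# Illustrative example (hypothetical sizes, not from a real schedule): two 0x1000-byte
# intermediates with disjoint lifetimes on a device whose allocator supports _offset can be
# handed the same offset in the shared arena, so the per-device global buffer needs 0x1000
# bytes instead of 0x2000. On a device without _offset support, the second intermediate would
# instead reuse the first Buffer object directly via its (device, dtype, options, nbytes) key.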
def memory_planner(schedule:list[ScheduleItem]) -> list[ScheduleItem]:
  # Exclude buffers involved in load ops (e.g. transfers) to preserve parallelism in graphs.
  assigned = _internal_memory_planner([list(si.bufs) for si in schedule],
                                      noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs})
  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule]
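
# A minimal, self-contained sketch of the fallback reuse path above, runnable on its own.
# Names, steps, and sizes are invented for illustration; "size" stands in for the full
# (device, dtype, options, nbytes) key, and the event sweep mirrors buffer_requests.
if __name__ == "__main__":
  lifetimes = {"A": (0, 1, 0x1000), "B": (2, 3, 0x1000), "C": (2, 4, 0x2000)}  # name: (first, last, nbytes)
  events = sorted([(first, True, name) for name, (first, _, _) in lifetimes.items()] +
                  [(last + 1, False, name) for name, (_, last, _) in lifetimes.items()])
  free_by_size, placement = defaultdict(list), {}
  for _, is_open_ev, name in events:
    size = lifetimes[name][2]
    if is_open_ev: placement[name] = free_by_size[size].pop() if free_by_size[size] else name
    else: free_by_size[size].append(placement[name])
  print(placement)  # {'A': 'A', 'B': 'A', 'C': 'C'}: B reuses A's storage, C needs its own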