You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
51 lines
3.2 KiB
51 lines
3.2 KiB
1 month ago
|
from collections import defaultdict
|
||
|
from tinygrad.engine.schedule import ScheduleItem
|
||
|
from tinygrad.device import Device, Buffer
|
||
|
from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG
|
||
|
from tinygrad.ops import Ops
|
||
|
|
||
|
# **************** memory planning ****************
|
||
|
|
||
|
def _internal_memory_planner(buffers:list[list[Buffer]|tuple[Buffer, ...]], noopt_buffers:set[Buffer]|None=None,
                             debug_prefix:str="") -> dict[Buffer, Buffer]:
  """Plan storage reuse so that buffers whose lifetimes do not overlap can share memory.

  `buffers[i]` is the list of buffers used by step i. A buffer's lifetime is the inclusive range of step indices from the
  first to the last step that uses it; lifetimes are tracked per `buf.base`, so views count toward their base.
  Requests are served largest-first. Each buffer takes the first free segment whose idle range covers its lifetime, drawn
  from an earlier buffer with the same compatibility key. If no such segment exists, the buffer keeps its own storage.

  A buffer is never planned if it is already allocated, if it has `lb_refcount > 0`, or if its base is in `noopt_buffers`.

  Args:
    buffers: per-step buffer lists.
    noopt_buffers: optional collection of base buffers that must be excluded from planning.
    debug_prefix: prefix for the summary line printed when DEBUG >= 1.

  Returns:
    A dict mapping each planned buffer (bases and views) to the buffer it should use instead.
    Empty when NO_MEMORY_PLANNER is set.
  """
  if NO_MEMORY_PLANNER: return {}

  def _skip(buf:Buffer) -> bool:
    # Single filter shared by both passes so they can never disagree about which buffers are planned.
    return buf.is_allocated() or buf.lb_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers)

  # Lifetime of each base buffer: index of the first and of the last step that touches it.
  first_appearance: dict[Buffer, int] = {}
  last_appearance: dict[Buffer, int] = {}
  for i, step_bufs in enumerate(buffers):
    for buf in step_bufs:
      if _skip(buf): continue
      first_appearance.setdefault(buf.base, i)
      last_appearance[buf.base] = i

  # Free segments per compatibility key: (start, end, buffer whose storage is idle over steps start..end, inclusive).
  # If the device allocator has no `offset` attribute, nbytes is part of the key, so only equal-sized buffers can share.
  free_segs: dict[tuple, list[tuple[int, int, Buffer]]] = defaultdict(list)

  def find_replace_buffer(buf:Buffer, st:int, en:int) -> Buffer:
    """Return the storage `buf` should use over steps st..en (inclusive) and mark that range as occupied."""
    key = (buf.device, buf.dtype, buf.options) + ((buf.nbytes,) if not hasattr(Device[buf.device].allocator, "offset") else tuple())
    segs = free_segs[key]
    # Locate the first free segment fully covering [st, en] by index, then pop it.
    # This avoids mutating the list inside the generator that is iterating it.
    idx = next((j for j, (sst, sen, _) in enumerate(segs) if sst <= st and en <= sen), None)
    # No fit: the buffer keeps its own storage, which is then idle over the whole schedule except [st, en].
    seg_st, seg_en, seg_buf = segs.pop(idx) if idx is not None else (0, len(buffers) - 1, buf)
    # Put the unused parts of the segment on either side of [st, en] back on the free list (left piece first).
    if st - 1 >= seg_st: segs.append((seg_st, st - 1, seg_buf))
    if seg_en >= en + 1: segs.append((en + 1, seg_en, seg_buf))
    # Requests arrive largest-first, so seg_buf is never smaller than buf.
    # A smaller buf becomes a sub-buffer built on seg_buf.
    return seg_buf if seg_buf.nbytes == buf.nbytes else Buffer(buf.device, buf.size, buf.dtype, base=seg_buf)

  # Sort by size in descending order so the largest buffers are placed first.
  # The sort is stable, so equal sizes keep first-appearance order.
  buffer_requests = sorted(((first_appearance[b], last_appearance[b], b) for b in first_appearance), key=lambda x: -x[2].nbytes)
  assigned = {buf:find_replace_buffer(buf, st, en) for st, en, buf in buffer_requests}

  # Second pass: map every planned buffer, including views, onto its planned storage.
  for step_bufs in buffers:
    for buf in step_bufs:
      if _skip(buf): continue
      if buf._base is not None:
        # Rebuild the view at its original offset, on the root of whatever its base was assigned to.
        assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=assigned.get(buf.base, buf.base).base, offset=buf.offset)
      else:
        assigned[buf] = assigned.get(buf, buf)

  if DEBUG >= 1:
    ak = dedup(x for x in assigned if x._base is None)
    av = dedup(x for x in assigned.values() if x._base is None)
    if len(ak) != len(av):
      print(debug_prefix+f"memory reduced from {sum(x.nbytes for x in ak)/1e6:.2f} MB -> {sum(x.nbytes for x in av)/1e6:.2f} MB,",
            f"{len(ak)} -> {len(av)} bufs")
  return assigned
|
||
|
|
||
|
def memory_planner(schedule:list[ScheduleItem]) -> list[ScheduleItem]:
  """Return a copy of `schedule` whose buffers are redirected to the storage chosen by `_internal_memory_planner`.

  Every buffer used by an item whose AST root op is not `Ops.SINK` (e.g. transfers) is excluded from planning.
  This keeps such buffers independent and preserves parallelism in graphs.
  """
  per_item_bufs = [si.bufs for si in schedule]
  pinned = {b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs}
  remap = _internal_memory_planner(per_item_bufs, noopt_buffers=pinned)

  planned: list[ScheduleItem] = []
  for si in schedule:
    # Buffers the planner did not touch map to themselves.
    new_bufs = tuple(remap.get(b, b) for b in si.bufs)
    planned.append(ScheduleItem(si.ast, new_bufs, si.metadata, si.assign_preloads))
  return planned
|