35 lines
1.6 KiB
35 lines
1.6 KiB
import os, mmap
|
|
try: import _posixshmem # type: ignore
|
|
except Exception: pass
|
|
from typing import Callable, Dict
|
|
from tinygrad.helpers import DType, OSX
|
|
from tinygrad.runtime.lib import RawBufferMapped
|
|
from tinygrad.ops import Interpreted, Op, UnaryOps, MovementOps, BufferOps
|
|
|
|
# Module-level cache of shared-memory mappings, keyed by the cache_id part of a
# "name,cache_id" device string; RawShmBuffer reuses an entry instead of re-mapping.
SHM_CACHE: Dict[str, mmap.mmap] = dict()
|
|
class RawShmBuffer(RawBufferMapped):
  """Buffer whose storage is a shared-memory mmap.

  `device` is either "name" or "name,cache_id":
    * name     -- on OSX the mapping is backed by the regular file /tmp/shm_{name}; otherwise it is
                  the shared-memory object returned by _posixshmem.shm_open(name), which must already
                  exist (it is opened with O_RDWR only, no O_CREAT).
    * cache_id -- optional; the mmap is stored in SHM_CACHE under this id and reused by any later
                  RawShmBuffer with the same id. Cached mappings are never closed by __del__.
  """
  def __init__(self, size, dtype:DType, device:str):
    """Map `size * dtype.itemsize` bytes for `device`, or reuse the SHM_CACHE entry for its cache_id.

    Args:
      size: number of elements.
      dtype: element type; only `dtype.itemsize` is used here.
      device: "name" or "name,cache_id" (text after a second comma is ignored).

    Errors from shm_open / os.open / os.ftruncate / mmap propagate unchanged; the file
    descriptor is closed on every path.
    """
    parts = device.split(",")
    device, self.cache_id = parts[0], (parts[1] if len(parts) > 1 else None)
    nbytes = size * dtype.itemsize

    if self.cache_id is not None and self.cache_id in SHM_CACHE:
      # NOTE(review): the cached mapping is reused as-is; nothing checks that it is at least nbytes long.
      shm = SHM_CACHE[self.cache_id]
    else:
      if OSX:
        # Backed by a regular file. Opened without O_TRUNC so data another process already wrote to the
        # shared file is preserved; ftruncate only resizes it (any growth is zero-filled).
        fd = os.open(f"/tmp/shm_{device}", os.O_RDWR | os.O_CREAT, 0o666)
        try:
          os.ftruncate(fd, nbytes)
          shm = mmap.mmap(fd, nbytes, flags=mmap.MAP_SHARED)
        finally:
          os.close(fd)  # the mapping stays valid after its descriptor is closed
      else:
        fd = _posixshmem.shm_open(device, os.O_RDWR, 0o600)
        try:
          # TODO: these flags are somewhat platform specific, but python doesn't expose the ones we need
          # NOTE(review): presumably MAP_LOCKED (0x2000) | MAP_POPULATE (0x8000) as on x86 Linux -- confirm.
          shm = mmap.mmap(fd, nbytes, flags=mmap.MAP_SHARED | 0x2000 | 0x008000)
          shm.madvise(mmap.MADV_HUGEPAGE) # type: ignore
        finally:
          os.close(fd)  # close even if mmap/madvise raise, so the descriptor is not leaked
      if self.cache_id is not None: SHM_CACHE[self.cache_id] = shm

    super().__init__(size, dtype, shm)

  def __del__(self):
    # Only mappings this buffer owns are closed; cached ones are shared through SHM_CACHE.
    # getattr defaults cover a partially constructed object (e.g. shm_open or mmap raised in __init__).
    if getattr(self, "cache_id", None) is not None: return
    buf = getattr(self, "_buf", None)
    if buf is not None: buf.close()

  def _buffer(self):
    # Zero-copy view of the mapping. self._buf is not assigned in this class; presumably the
    # RawBufferMapped constructor stores the `shm` passed to it there -- confirm.
    return memoryview(self._buf)
|
|
|
|
# TODO: is this wrong?
# NOTE(review): every op below just returns its first argument, so RESHAPE and AS_STRIDED neither
# move nor restride data on this backend -- confirm that is intended.
def _shm_identity(x): return x
def _shm_identity_ignore_arg(x, _): return x

shm_fxn_for_op: Dict[Op, Callable] = {
  BufferOps.MEM: _shm_identity,
  UnaryOps.NOOP: _shm_identity,
  MovementOps.RESHAPE: _shm_identity_ignore_arg,
  MovementOps.AS_STRIDED: _shm_identity_ignore_arg,
}

# Interpreted backend over RawShmBuffer; underlying values are passed through unchanged.
ShmBuffer = Interpreted(RawShmBuffer, shm_fxn_for_op, from_underlying=_shm_identity)
|
|
|