You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
41 lines
2.3 KiB
41 lines
2.3 KiB
import os, mmap
|
|
from typing import Optional
|
|
from typing import Callable, Dict, Tuple
|
|
from tinygrad.helpers import prod, DType
|
|
from tinygrad.runtime.lib import RawBufferMapped
|
|
from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps, BufferOps
|
|
|
|
class RawDiskBuffer(RawBufferMapped):
  """A buffer whose bytes live in a memory-mapped file instead of RAM.

  Views derived from one buffer (via cast / reshape / shrink / as_strided) share a
  single 3-element list ``[file, mmap, refcount]`` stored in ``self._buf``. Each
  view increments the refcount when created and decrements it when deleted. The
  backing file object is closed once the count reaches zero.
  """

  def __init__(self, size, dtype:DType, device:Optional[str]=None, buf=None, shape=None, offset=0):  # pylint: disable=super-init-not-called
    """Create a new disk mapping, or a new view onto an existing one.

    Args:
      size: number of elements in this view.
      dtype: element dtype; the view spans ``size * dtype.itemsize`` bytes.
      device: path of the backing file. When given, the file is opened with
        mode "a+b" (created if missing), grown to at least
        ``size * dtype.itemsize`` bytes, and mapped. Takes precedence over ``buf``.
      buf: the shared ``[file, mmap, refcount]`` list of an existing view; used
        only when ``device`` is None, and its refcount is incremented.
      shape: logical shape of the view; defaults to ``(size,)``.
      offset: byte offset of this view into the mapping.

    Raises:
      AssertionError: if neither ``device`` nor ``buf`` is given.
      OSError / ValueError: propagated from open/ftruncate/mmap. The opened file
        is closed before re-raising.
    """
    self.shape = (size, ) if shape is None else shape
    self.offset = offset  # this is an offset in bytes
    assert device is not None or buf is not None, "disk tensor needs a path or a buf"
    if device is not None:
      nbytes = size * dtype.itemsize
      f = open(device, "a+b")
      try:
        # grow (never shrink) the file so the whole mapping is backed by real bytes;
        # fstat on the open fd avoids re-resolving the path after opening it
        if os.fstat(f.fileno()).st_size < nbytes: os.ftruncate(f.fileno(), nbytes)
        # NOTE(review): with nbytes == 0, mmap maps the entire existing file (and fails
        # on an empty one) — confirm callers never request a zero-sized disk buffer
        buf = [f, mmap.mmap(f.fileno(), nbytes), 1]
      except BaseException:
        # don't leak the file handle when growing or mapping the file fails
        f.close()
        raise
    else:
      buf[2] += 1
    # NOTE: we don't call super since disk tensors don't use RAM
    self.size, self.dtype, self._buf = size, dtype, buf

  def __del__(self):
    """Release this view's reference; close the backing file when none remain."""
    # _buf is absent if __init__ raised before assigning it; nothing to release then
    shared = getattr(self, "_buf", None)
    if shared is None: return
    shared[2] -= 1
    # NOTE: only the file object is closed here. The mmap is left for garbage
    # collection, since memoryviews returned by _buffer() may still reference it.
    if shared[2] == 0: shared[0].close()

  def cast(self, arg:Tuple[DType, bool]):
    """Return a view of the same bytes reinterpreted as dtype ``arg[0]``.

    ``arg[1]`` is not used. The element count (``self.size``) is kept, so if the
    itemsizes differ the new view spans a different number of bytes.
    NOTE(review): presumably callers only cast between equal-itemsize dtypes — confirm.
    """
    return RawDiskBuffer(self.size, arg[0], buf=self._buf, shape=self.shape, offset=self.offset)

  def reshape(self, arg):
    """Return a view of the same bytes with logical shape ``arg``."""
    return RawDiskBuffer(self.size, self.dtype, buf=self._buf, shape=arg, offset=self.offset)

  def shrink(self, arg):
    """Return a view restricted to rows ``arg[0][0]:arg[0][1]`` of the first dimension.

    Every other dimension must be kept whole: ``arg[i] == (0, shape[i])`` for i >= 1.
    """
    assert arg[1:] == tuple((0, x) for x in self.shape[1:]), f"can only slice the first dim of disk tensor {arg}"
    start, end = arg[0]
    row_elems = prod(self.shape[1:])  # elements per index of the first dimension
    rows = end - start
    byte_offset = start * row_elems * self.dtype.itemsize
    return RawDiskBuffer(rows * row_elems, self.dtype, buf=self._buf, offset=self.offset+byte_offset, shape=(rows,)+self.shape[1:])

  def as_strided(self, arg):
    """Return a view with shape ``arg[0]`` starting ``arg[2]`` elements past this view.

    The strides in ``arg[1]`` are not used: the result is laid out as a plain run of
    ``prod(arg[0])`` elements. NOTE(review): presumably only contiguous views reach
    here — confirm.
    """
    return RawDiskBuffer(prod(arg[0]), self.dtype, buf=self._buf, offset=self.offset+arg[2]*self.dtype.itemsize, shape=arg[0])

  def _buffer(self):
    """Return a zero-copy memoryview over this view's bytes in the mapping."""
    return memoryview(self._buf[1])[self.offset:self.offset+self.size*self.dtype.itemsize]

  def readinto(self, buf):
    """Read from the backing file, starting at this view's byte offset, into ``buf``.

    This goes through the file object rather than the mmap. The amount read is
    determined by ``buf`` (via the file's ``readinto``), not by ``self.size``.
    """
    self._buf[0].seek(self.offset)
    self._buf[0].readinto(buf)
|
|
|
|
# Op -> handler table for the interpreted disk backend.
# Buffer loads and NOOPs pass the RawDiskBuffer through unchanged. Casts and
# strided views are delegated to RawDiskBuffer, which builds new views sharing
# the same file mapping.
disk_fxn_for_op: Dict[Op, Callable] = {
  BufferOps.MEM: lambda x: x,
  UnaryOps.NOOP: lambda x: x,
  UnaryOps.CAST: RawDiskBuffer.cast,
  MovementOps.AS_STRIDED: RawDiskBuffer.as_strided,
}

# The disk backend: RawDiskBuffer is driven by the table above. from_underlying is
# the identity, so interpreter values are handed back unchanged.
DiskBuffer = Interpreted(RawDiskBuffer, disk_fxn_for_op, from_underlying=lambda x: x)