You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
79 lines
2.0 KiB
79 lines
2.0 KiB
5 years ago
|
import os
|
||
|
import numpy as np
|
||
|
import random
|
||
|
|
||
|
class SamplingBuffer():
    """File-backed store of fixed-size binary records with random sampling.

    The file named ``fn`` is treated as a flat sequence of records, each
    exactly ``size`` bytes long. An instance is either a writer (appends
    records) or a reader (fetches records by index); the role is chosen at
    construction time and enforced with assertions.

    Args:
        fn: Path of the backing file.
        size: Length of one record in bytes.
        write: If true, open the file for binary append (creating it when
            missing); otherwise open it for binary reading.
    """

    def __init__(self, fn, size, write=False):
        self._fn = fn
        self._write = write
        self._size = size
        self._f = open(self._fn, "ab" if self._write else "rb")
        try:
            self._refresh()
        except Exception:
            # Do not leak the handle if the record count cannot be computed
            # (e.g. size == 0).
            self._f.close()
            raise

    def _refresh(self):
        """Recompute ``self.cnt`` from the current size of the file on disk."""
        # Floor division keeps the count an int on Python 3 (random.randrange
        # and file.seek reject floats) and ignores a partial trailing record.
        self.cnt = os.path.getsize(self._fn) // self._size

    @property
    def count(self):
        """Number of complete records currently stored in the backing file."""
        self._refresh()
        return self.cnt

    def close(self):
        """Close the backing file. Calling it more than once is harmless."""
        self._f.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def _fetch_one(self, x):
        """Return the raw bytes of record number ``x`` (reader only)."""
        assert not self._write, "buffer was opened for writing"
        self._f.seek(x * self._size)
        return self._f.read(self._size)

    def sample(self, count, indices=None):
        """Return a list of raw records.

        Args:
            count: Number of records to draw when ``indices`` is not given.
            indices: Optional iterable of record numbers to fetch. When
                omitted, ``count`` indices are drawn uniformly at random
                (with replacement) from the records currently on disk.

        Returns:
            A list of byte strings, one per requested index, in order.
        """
        # Identity check: `== None` is ambiguous for numpy index arrays.
        if indices is None:
            available = self.count
            assert available != 0, "cannot sample from an empty buffer"
            indices = [random.randrange(available) for _ in range(count)]
        # Materialise eagerly so callers get a reusable sequence rather than
        # a one-shot iterator.
        return [self._fetch_one(i) for i in indices]

    def write(self, dat):
        """Append ``dat`` (a whole number of records) and flush it to disk."""
        assert self._write, "buffer was opened read-only"
        assert (len(dat) % self._size) == 0, "data is not a whole number of records"
        self._f.write(dat)
        # Flush so that readers computing the count from the file size see
        # the new records immediately.
        self._f.flush()
|
||
|
|
||
|
class NumpySamplingBuffer():
    """Sampling buffer whose records are numpy arrays of a fixed shape/dtype.

    Each record is stored as the raw bytes of one array of shape ``size`` and
    type ``dtype`` inside a ``SamplingBuffer`` backed by the file ``fn``.

    Args:
        fn: Path of the backing file.
        size: Shape of a single record (a tuple/list of ints, or a single
            int for one-dimensional records).
        dtype: numpy dtype of the stored elements.
        write: Passed through to ``SamplingBuffer``; true opens the file for
            appending, false for reading.
    """

    def __init__(self, fn, size, dtype, write=False):
        self._size = size
        self._dtype = dtype
        # np.zeros accepts a bare int as a shape, so accept it here as well
        # and normalise to a tuple for the reshape in sample().
        try:
            self._shape = tuple(size)
        except TypeError:
            self._shape = (size,)
        # Byte length of one record; nbytes avoids serialising the array.
        record_bytes = np.zeros(size, dtype=dtype).nbytes
        self._buf = SamplingBuffer(fn, record_bytes, write)

    @property
    def count(self):
        """Number of complete records in the underlying buffer."""
        return self._buf.count

    def write(self, dat):
        """Append the raw bytes of array ``dat`` to the underlying buffer."""
        self._buf.write(dat.tobytes())

    def sample(self, count, indices=None):
        """Return ``count`` records stacked into one array.

        Args:
            count: Number of records in the result; must equal the number of
                entries in ``indices`` when that is supplied.
            indices: Optional record numbers, forwarded to the underlying
                buffer, which draws random ones when this is ``None``.

        Returns:
            A writable array of shape ``(count,) + record shape``.
        """
        # Records are bytes, so they must be joined with a bytes separator.
        raw = b"".join(self._buf.sample(count, indices))
        flat = np.frombuffer(raw, dtype=self._dtype)
        # frombuffer yields a read-only view over `raw`; copy so the caller
        # receives an independent, writable array.
        return flat.reshape((count,) + self._shape).copy()
|
||
|
|
||
|
# TODO: sampling one record costs n IOPS, where n is the number of buffers in the MultiNumpySamplingBuffer
|
||
|
class MultiNumpySamplingBuffer():
    """Group of ``NumpySamplingBuffer``s that are written and sampled together.

    One buffer is created per entry of ``npa``; buffer ``i`` is backed by the
    file ``fn + "_<i>"``. Sampling uses the same record indices for every
    buffer, so the i-th row of each returned array belongs to the same write.

    Args:
        fn: Common file-name prefix for the backing files.
        npa: Sequence of entries whose first item is the record shape and
            whose second item is the dtype of the corresponding buffer.
        write: Passed through to every ``NumpySamplingBuffer``.
    """

    def __init__(self, fn, npa, write=False):
        self._bufs = [
            NumpySamplingBuffer(fn + ("_%d" % i), spec[0], spec[1], write)
            for i, spec in enumerate(npa)
        ]

    def write(self, dat):
        """Write ``dat[i]`` to buffer ``i`` for each configured buffer."""
        for buf, arr in zip(self._bufs, dat):
            buf.write(arr)

    @property
    def count(self):
        """Smallest record count across the buffers.

        Using the minimum guarantees every index below it is valid in all
        of the backing files.
        """
        return min(buf.count for buf in self._bufs)

    def sample(self, count):
        """Draw ``count`` random record indices and fetch them from all buffers.

        Returns:
            A list with one array per buffer, each holding the records at
            the same randomly chosen indices.
        """
        available = self.count
        assert available != 0
        # Must be a concrete list: the same indices are consumed once per
        # buffer, and a lazy iterator would be exhausted after the first.
        indices = [random.randrange(available) for _ in range(count)]
        return [buf.sample(count, indices) for buf in self._bufs]
|
||
|
|