You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							78 lines
						
					
					
						
							2.0 KiB
						
					
					
				
			
		
		
	
	
							78 lines
						
					
					
						
							2.0 KiB
						
					
					
				import os
 | 
						|
import numpy as np
 | 
						|
import random
 | 
						|
 | 
						|
class SamplingBuffer():
 | 
						|
  def __init__(self, fn, size, write=False):
 | 
						|
    self._fn = fn
 | 
						|
    self._write = write
 | 
						|
    if self._write:
 | 
						|
      self._f = open(self._fn, "ab")
 | 
						|
    else:
 | 
						|
      self._f = open(self._fn, "rb")
 | 
						|
    self._size = size
 | 
						|
    self._refresh()
 | 
						|
 | 
						|
  def _refresh(self):
 | 
						|
    self.cnt = os.path.getsize(self._fn) / self._size
 | 
						|
 | 
						|
  @property
 | 
						|
  def count(self):
 | 
						|
    self._refresh()
 | 
						|
    return self.cnt
 | 
						|
 | 
						|
  def _fetch_one(self, x):
 | 
						|
    assert self._write == False
 | 
						|
    self._f.seek(x*self._size)
 | 
						|
    return self._f.read(self._size)
 | 
						|
 | 
						|
  def sample(self, count, indices = None):
 | 
						|
    if indices == None:
 | 
						|
      cnt = self.count
 | 
						|
      assert cnt != 0
 | 
						|
      indices = map(lambda x: random.randint(0, cnt-1), range(count))
 | 
						|
    return map(self._fetch_one, indices)
 | 
						|
 | 
						|
  def write(self, dat):
 | 
						|
    assert self._write == True
 | 
						|
    assert (len(dat) % self._size) == 0
 | 
						|
    self._f.write(dat)
 | 
						|
    self._f.flush()
 | 
						|
 | 
						|
class NumpySamplingBuffer():
 | 
						|
  def __init__(self, fn, size, dtype, write=False):
 | 
						|
    self._size = size
 | 
						|
    self._dtype = dtype
 | 
						|
    self._buf = SamplingBuffer(fn, len(np.zeros(size, dtype=dtype).tobytes()), write)
 | 
						|
 | 
						|
  @property
 | 
						|
  def count(self):
 | 
						|
    return self._buf.count
 | 
						|
 | 
						|
  def write(self, dat):
 | 
						|
    self._buf.write(dat.tobytes())
 | 
						|
 | 
						|
  def sample(self, count, indices = None):
 | 
						|
    return np.fromstring(''.join(self._buf.sample(count, indices)), dtype=self._dtype).reshape([count]+list(self._size))
 | 
						|
 | 
						|
# TODO: n IOPS needed where n is the Multi
 | 
						|
class MultiNumpySamplingBuffer():
 | 
						|
  def __init__(self, fn, npa, write=False):
 | 
						|
    self._bufs = []
 | 
						|
    for i,n in enumerate(npa):
 | 
						|
      self._bufs.append(NumpySamplingBuffer(fn + ("_%d" % i), n[0], n[1], write))
 | 
						|
 | 
						|
  def write(self, dat):
 | 
						|
    for b,x in zip(self._bufs, dat):
 | 
						|
      b.write(x)
 | 
						|
 | 
						|
  @property
 | 
						|
  def count(self):
 | 
						|
    return min(map(lambda x: x.count, self._bufs))
 | 
						|
 | 
						|
  def sample(self, count):
 | 
						|
    cnt = self.count
 | 
						|
    assert cnt != 0
 | 
						|
    indices = map(lambda x: random.randint(0, cnt-1), range(count))
 | 
						|
    return map(lambda x: x.sample(count, indices), self._bufs)
 | 
						|
 | 
						|
 |