You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							78 lines
						
					
					
						
							2.0 KiB
						
					
					
				
			
		
		
	
	
							78 lines
						
					
					
						
							2.0 KiB
						
					
					
				| import os
 | |
| import numpy as np
 | |
| import random
 | |
| 
 | |
class SamplingBuffer():
  """File of fixed-size binary records with append and random-read access.

  An instance is either a writer (file opened in append mode) or a reader
  (file opened read-only); the mode is fixed at construction time.

  Args:
    fn: path of the backing file.
    size: length of one record in bytes.
    write: open for appending when True, for reading otherwise.
  """

  def __init__(self, fn, size, write=False):
    self._fn = fn
    self._write = write
    self._size = size
    # Append mode keeps records that are already in the file.
    self._f = open(self._fn, "ab" if self._write else "rb")
    self._refresh()

  def _refresh(self):
    """Recompute the number of complete records from the file size on disk."""
    # Floor division: the count must be an int (random.randint rejects floats
    # on recent Python 3), and a trailing partial record is not counted.
    self.cnt = os.path.getsize(self._fn) // self._size

  @property
  def count(self):
    """Number of complete records currently in the file (re-read on each access)."""
    self._refresh()
    return self.cnt

  def _fetch_one(self, x):
    """Return the raw bytes of record number `x` (reader mode only)."""
    assert self._write == False
    self._f.seek(x * self._size)
    return self._f.read(self._size)

  def sample(self, count, indices=None):
    """Return a list of raw records.

    Args:
      count: how many records to draw at random; only used when `indices`
        is None.
      indices: optional iterable of record numbers to fetch instead of
        drawing random ones.

    Returns:
      A list of byte strings, one per fetched record.

    Raises:
      AssertionError: if random sampling is requested on an empty file, or
        the buffer was opened for writing.
    """
    # `is None` rather than `== None`: an ndarray of indices would otherwise
    # be compared element-wise.
    if indices is None:
      cnt = self.count
      assert cnt != 0
      indices = [random.randint(0, cnt - 1) for _ in range(count)]
    # Build a real list so the result can be indexed and iterated repeatedly
    # (a bare map() is a one-shot iterator on Python 3).
    return [self._fetch_one(i) for i in indices]

  def write(self, dat):
    """Append `dat`, whose length must be a whole number of records, and flush."""
    assert self._write == True
    assert (len(dat) % self._size) == 0
    self._f.write(dat)
    # Flush so readers that size the file on disk see the new records.
    self._f.flush()

  def close(self):
    """Close the backing file; safe to call more than once."""
    self._f.close()

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    self.close()
    return False
 | |
| 
 | |
class NumpySamplingBuffer():
  """Sampling buffer whose records are NumPy arrays of one fixed shape and dtype.

  Args:
    fn: path of the backing file.
    size: shape of a single record, as a sequence of ints.
    dtype: NumPy dtype of the stored elements.
    write: open for appending when True, for reading otherwise.
  """

  def __init__(self, fn, size, dtype, write=False):
    self._size = size
    self._dtype = dtype
    # One record on disk is the raw byte image of an array of this shape/dtype.
    record_bytes = np.zeros(size, dtype=dtype).nbytes
    self._buf = SamplingBuffer(fn, record_bytes, write)

  @property
  def count(self):
    """Number of records reported by the underlying buffer."""
    return self._buf.count

  def write(self, dat):
    """Append the raw bytes of array `dat` to the underlying buffer."""
    self._buf.write(dat.tobytes())

  def sample(self, count, indices=None):
    """Return `count` records stacked into one array.

    Args:
      count: number of records; also the leading dimension of the result.
      indices: optional record numbers passed through to the underlying
        buffer instead of random sampling.

    Returns:
      A writable array of shape `[count] + list(size)`.
    """
    # Records are bytes, so join with a bytes separator (''.join fails on
    # Python 3).
    raw = b''.join(self._buf.sample(count, indices))
    # frombuffer replaces the deprecated binary-mode np.fromstring; the copy
    # keeps the result writable, as fromstring's result was.
    flat = np.frombuffer(raw, dtype=self._dtype)
    return flat.reshape([count] + list(self._size)).copy()
 | |
| 
 | |
| # TODO: n IOPS needed where n is the Multi
 | |
class MultiNumpySamplingBuffer():
  """Group of NumpySamplingBuffers that are written and sampled in lockstep.

  Sub-buffer number i is stored in the file `fn + "_%d" % i`.

  Args:
    fn: common path prefix of the backing files.
    npa: sequence of `(size, dtype)` pairs, one per sub-buffer.
    write: open for appending when True, for reading otherwise.
  """

  def __init__(self, fn, npa, write=False):
    self._bufs = [
      NumpySamplingBuffer(fn + ("_%d" % i), spec[0], spec[1], write)
      for i, spec in enumerate(npa)
    ]

  def write(self, dat):
    """Write each element of `dat` to the sub-buffer at the same position."""
    for buf, arr in zip(self._bufs, dat):
      buf.write(arr)

  @property
  def count(self):
    """Smallest record count across the sub-buffers."""
    return min(buf.count for buf in self._bufs)

  def sample(self, count):
    """Sample the same `count` random record numbers from every sub-buffer.

    Returns:
      A list with one sampled array per sub-buffer, in sub-buffer order.

    Raises:
      AssertionError: if any sub-buffer is empty.
    """
    # int() keeps randint working even if a sub-buffer reports a float count.
    cnt = int(self.count)
    assert cnt != 0
    # Materialise the indices once: a lazy map() would be exhausted by the
    # first sub-buffer, leaving the others with nothing to fetch.
    indices = [random.randint(0, cnt - 1) for _ in range(count)]
    return [buf.sample(count, indices) for buf in self._bufs]
 | |
| 
 | |
| 
 |