import os
import numpy as np
import random

class SamplingBuffer():
  """File of fixed-size binary records supporting append and random-index reads.

  The instance is opened either for appending records (``write=True``) or for
  reading them back by index (``write=False``); the mode is fixed at
  construction.

  Args:
    fn: path of the backing file.
    size: size of one record in bytes.
    write: open the file in append mode if True, read-only otherwise.
  """
  def __init__(self, fn, size, write=False):
    self._fn = fn
    self._write = write
    self._f = open(self._fn, "ab" if write else "rb")
    self._size = size
    self._refresh()

  def _refresh(self):
    # Floor division: only complete records are counted, and the result stays
    # an int (true division gives a float on Python 3, which random.randint
    # and index arithmetic reject / mishandle).
    self.cnt = os.path.getsize(self._fn) // self._size

  @property
  def count(self):
    """Number of complete records in the file, re-read from disk on each access."""
    self._refresh()
    return self.cnt

  def _fetch_one(self, x):
    """Return the raw bytes of record ``x``; only valid in read mode."""
    assert not self._write
    self._f.seek(x*self._size)
    return self._f.read(self._size)

  def sample(self, count, indices=None):
    """Fetch records, either at random or at explicit indices.

    Args:
      count: number of records to draw uniformly at random (with
        replacement); unused when ``indices`` is given.
      indices: optional iterable of record indices to fetch instead.

    Returns:
      A list of ``bytes``, one per fetched record.

    Raises:
      AssertionError: if sampling randomly from an empty file.
    """
    # `is None` rather than `== None`: the latter is elementwise for numpy
    # index arrays and then fails in the boolean context.
    if indices is None:
      cnt = self.count
      assert cnt != 0
      indices = [random.randint(0, cnt-1) for _ in range(count)]
    # A list (not a lazy Py3 `map`) so the result can be iterated repeatedly.
    return [self._fetch_one(i) for i in indices]

  def write(self, dat):
    """Append one or more whole records and flush them to the OS.

    Raises:
      AssertionError: if not opened in write mode, or if ``len(dat)`` is not
        a multiple of the record size.
    """
    assert self._write
    assert (len(dat) % self._size) == 0
    self._f.write(dat)
    self._f.flush()

  def close(self):
    """Close the backing file handle."""
    self._f.close()

  def __enter__(self):
    return self

  def __exit__(self, *exc_info):
    self.close()

class NumpySamplingBuffer():
  """Stores fixed-shape numpy arrays as records of a SamplingBuffer.

  Args:
    fn: path of the backing file.
    size: shape of one record -- anything ``np.zeros`` accepts (an int or a
      tuple of ints).
    dtype: numpy dtype of the stored elements.
    write: open for appending if True, read-only otherwise.
  """
  def __init__(self, fn, size, dtype, write=False):
    self._size = size
    self._dtype = dtype
    template = np.zeros(size, dtype=dtype)
    # Record shape normalised to a tuple, so an int `size` works in `sample`.
    self._shape = template.shape
    self._buf = SamplingBuffer(fn, template.nbytes, write)

  @property
  def count(self):
    """Number of complete records in the backing file."""
    return self._buf.count

  def write(self, dat):
    """Append ``dat`` (one or more records) to the file.

    ``dat`` is cast to the buffer's dtype first; writing raw bytes of a
    different dtype (e.g. float64 into a float32 buffer) would otherwise be
    stored silently and misread on sampling.
    """
    self._buf.write(np.asarray(dat, dtype=self._dtype).tobytes())

  def sample(self, count, indices=None):
    """Return sampled records stacked into one array.

    Args:
      count: number of records to draw at random (see SamplingBuffer.sample).
      indices: optional explicit record indices.

    Returns:
      A writable array of shape ``(n,) + record_shape``, where ``n`` is
      ``count`` for random sampling or the number of ``indices`` given.
    """
    # bytearray.join accepts the bytes chunks (str.join does not on Py3) and
    # yields a writable buffer, so frombuffer returns a writable array as
    # the deprecated np.fromstring did.
    raw = bytearray().join(self._buf.sample(count, indices))
    arr = np.frombuffer(raw, dtype=self._dtype)
    return arr.reshape((-1,) + self._shape)

# TODO: each sampled record costs n reads (IOPS), where n is the number of sub-buffers in the Multi buffer.
class MultiNumpySamplingBuffer():
  """Several NumpySamplingBuffers that are written and sampled in lockstep.

  Sub-buffer ``i`` lives in file ``fn + "_i"``. Sampling draws one set of
  record indices and fetches those same indices from every sub-buffer.

  Args:
    fn: base path; sub-buffer ``i`` uses ``fn + "_%d" % i``.
    npa: sequence of ``(size, dtype)`` pairs, one per sub-buffer.
    write: open for appending if True, read-only otherwise.
  """
  def __init__(self, fn, npa, write=False):
    self._bufs = [NumpySamplingBuffer(fn + ("_%d" % i), n[0], n[1], write)
                  for i, n in enumerate(npa)]

  def write(self, dat):
    """Write ``dat[i]`` to sub-buffer ``i``.

    NOTE(review): `zip` stops at the shorter input, so a length mismatch
    silently leaves some sub-buffers unwritten.
    """
    for b, x in zip(self._bufs, dat):
      b.write(x)

  @property
  def count(self):
    """Smallest record count across sub-buffers (0 if there are none)."""
    return min((b.count for b in self._bufs), default=0)

  def sample(self, count, indices=None):
    """Fetch the same record indices from every sub-buffer.

    Args:
      count: number of records to draw at random (with replacement).
      indices: optional explicit record indices to use instead.

    Returns:
      A list with one array per sub-buffer.

    Raises:
      AssertionError: if sampling randomly while some sub-buffer is empty.
    """
    if indices is None:
      cnt = self.count
      assert cnt != 0
      indices = [random.randint(0, cnt-1) for _ in range(count)]
    else:
      indices = list(indices)
    # Must be a list: the same indices are reused for every sub-buffer, and a
    # one-shot iterator (Py3 `map`) would be exhausted after the first one.
    return [b.sample(count, indices) for b in self._bufs]