import os, random, pickle, queue
from typing import List
from pathlib import Path
from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count

import numpy as np
from tinygrad import dtypes, Tensor
from tinygrad.helpers import getenv, prod, Context, round_up, tqdm

### ResNet

class MyQueue:
  def __init__(self, multiple_readers=True, multiple_writers=True):
    self._reader, self._writer = connection.Pipe(duplex=False)
    self._rlock = Lock() if multiple_readers else None
    self._wlock = Lock() if multiple_writers else None
  def get(self):
    if self._rlock: self._rlock.acquire()
    ret = pickle.loads(self._reader.recv_bytes())
    if self._rlock: self._rlock.release()
    return ret
  def put(self, obj):
    if self._wlock: self._wlock.acquire()
    self._writer.send_bytes(pickle.dumps(obj))
    if self._wlock: self._wlock.release()

def shuffled_indices(n, seed=None):
  rng = random.Random(seed)
  indices = {}
  for i in range(n-1, -1, -1):
    j = rng.randint(0, i)
    if i not in indices: indices[i] = i
    if j not in indices: indices[j] = j
    indices[i], indices[j] = indices[j], indices[i]
    yield indices[i]
    del indices[i]

def loader_process(q_in, q_out, X:Tensor, seed):
  import signal
  signal.signal(signal.SIGINT, lambda _, __: exit(0))

  from extra.datasets.imagenet import center_crop, preprocess_train
  from PIL import Image

  with Context(DEBUG=0):
    while (_recv := q_in.get()) is not None:
      idx, fn, val = _recv
      if fn is not None:
        img = Image.open(fn)
        img = img.convert('RGB') if img.mode != "RGB" else img

        if val:
          # eval: 76.08%, load in 0m7.366s (0m5.301s with simd)
          # sudo apt-get install libjpeg-dev
          # CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
          img = center_crop(img)
          img = np.array(img)
        else:
          # reseed rng for determinism
          if seed is not None:
            np.random.seed(seed * 2 ** 10 + idx)
            random.seed(seed * 2 ** 10 + idx)
          img = preprocess_train(img)
      else:
        # pad data with training mean
        img = np.tile(np.array([[[123.68, 116.78, 103.94]]], dtype=np.uint8), (224, 224, 1))

      # broken out
      #img_tensor = Tensor(img.tobytes(), device='CPU')
      #storage_tensor = X[idx].contiguous().realize().lazydata.base.realized
      #storage_tensor._copyin(img_tensor.numpy())

      # faster
      X[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()

      # ideal
      #X[idx].assign(img.tobytes())   # NOTE: this is slow!
      q_out.put(idx)
    q_out.put(None)

def batch_load_resnet(batch_size=64, val=False, shuffle=True, seed=None, pad_first_batch=False):
  from extra.datasets.imagenet import get_train_files, get_val_files
  files = get_val_files() if val else get_train_files()
  from extra.datasets.imagenet import get_imagenet_categories
  cir = get_imagenet_categories()

  if pad_first_batch:
    FIRST_BATCH_PAD = round_up(len(files), batch_size) - len(files)
  else:
    FIRST_BATCH_PAD = 0
  file_count = FIRST_BATCH_PAD + len(files)
  BATCH_COUNT = min(32, file_count // batch_size)

  def _gen():
    for _ in range(FIRST_BATCH_PAD): yield -1
    yield from shuffled_indices(len(files), seed=seed) if shuffle else iter(range(len(files)))
  gen = iter(_gen())

  def enqueue_batch(num):
    for idx in range(num*batch_size, (num+1)*batch_size):
      fidx = next(gen)
      if fidx != -1:
        fn = files[fidx]
        q_in.put((idx, fn, val))
        Y[idx] = cir[fn.split("/")[-2]]
      else:
        # padding
        q_in.put((idx, None, val))
        Y[idx] = -1

  shutdown = False
  class Cookie:
    def __init__(self, num): self.num = num
    def __del__(self):
      if not shutdown:
        try: enqueue_batch(self.num)
        except StopIteration: pass

  gotten = [0]*BATCH_COUNT
  def receive_batch():
    while 1:
      num = q_out.get()//batch_size
      gotten[num] += 1
      if gotten[num] == batch_size: break
    gotten[num] = 0
    return X[num*batch_size:(num+1)*batch_size], Y[num*batch_size:(num+1)*batch_size], Cookie(num)

  #q_in, q_out = MyQueue(multiple_writers=False), MyQueue(multiple_readers=False)
  q_in, q_out = Queue(), Queue()

  sz = (batch_size*BATCH_COUNT, 224, 224, 3)
  if os.path.exists("/dev/shm/resnet_X"): os.unlink("/dev/shm/resnet_X")
  shm = shared_memory.SharedMemory(name="resnet_X", create=True, size=prod(sz))
  procs = []

  try:
    # disk:shm is slower
    #X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:shm:{shm.name}")
    X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:/dev/shm/resnet_X")
    Y = [None] * (batch_size*BATCH_COUNT)

    for _ in range(cpu_count()):
      p = Process(target=loader_process, args=(q_in, q_out, X, seed))
      p.daemon = True
      p.start()
      procs.append(p)

    for bn in range(BATCH_COUNT): enqueue_batch(bn)

    # NOTE: this is batch aligned, last ones are ignored unless pad_first_batch is True
    for _ in range(0, file_count//batch_size): yield receive_batch()
  finally:
    shutdown = True
    # empty queues
    for _ in procs: q_in.put(None)
    q_in.close()
    for _ in procs:
      while q_out.get() is not None: pass
    q_out.close()
    # shutdown processes
    for p in procs: p.join()
    shm.close()
    try:
      shm.unlink()
    except FileNotFoundError:
      # happens with BENCHMARK set
      pass

### BERT

def process_batch_bert(data: List[dict]) -> dict[str, Tensor]:
  return {
    "input_ids": Tensor(np.concatenate([s["input_ids"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
    "input_mask": Tensor(np.concatenate([s["input_mask"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
    "segment_ids": Tensor(np.concatenate([s["segment_ids"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
    "masked_lm_positions": Tensor(np.concatenate([s["masked_lm_positions"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
    "masked_lm_ids": Tensor(np.concatenate([s["masked_lm_ids"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
    "masked_lm_weights": Tensor(np.concatenate([s["masked_lm_weights"] for s in data], axis=0), dtype=dtypes.float32, device="CPU"),
    "next_sentence_labels": Tensor(np.concatenate([s["next_sentence_labels"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
  }

def load_file(file: str):
  with open(file, "rb") as f:
    return pickle.load(f)

class InterleavedDataset:
  def __init__(self, files:List[str], cycle_length:int):
    self.dataset = files
    self.cycle_length = cycle_length
    self.queues = [queue.Queue() for _ in range(self.cycle_length)]
    for i in range(len(self.queues)): self.queues[i].queue.extend(load_file(self.dataset.pop(0)))
    self.queue_pointer = len(self.queues) - 1

  def get(self):
    # Round-robin across queues
    try:
      self.advance()
      return self.queues[self.queue_pointer].get_nowait()
    except queue.Empty:
      self.fill(self.queue_pointer)
      return self.get()

  def advance(self):
    self.queue_pointer = (self.queue_pointer + 1) % self.cycle_length

  def fill(self, queue_index: int):
    try:
      file = self.dataset.pop(0)
    except IndexError:
      return
    self.queues[queue_index].queue.extend(load_file(file))

# Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 394
def batch_load_train_bert(BS:int):
  from extra.datasets.wikipedia import get_wiki_train_files
  fs = sorted(get_wiki_train_files())
  train_files = []
  while fs: # TF shuffle
    random.shuffle(fs)
    train_files.append(fs.pop(0))

  cycle_length = min(getenv("NUM_CPU_THREADS", min(os.cpu_count(), 8)), len(train_files))
  assert cycle_length > 0, "cycle_length must be greater than 0"

  dataset = InterleavedDataset(train_files, cycle_length)
  while True:
    yield process_batch_bert([dataset.get() for _ in range(BS)])

# Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 416
def batch_load_val_bert(BS:int):
  file =  getenv("BASEDIR", Path(__file__).parent.parents[1] / "extra" / "datasets" / "wiki") / "eval.pkl"
  dataset = load_file(file)
  idx = 0
  while True:
    start_idx = (idx * BS) % len(dataset)
    end_idx = ((idx + 1) * BS) % len(dataset)
    if start_idx < end_idx:
      yield process_batch_bert(dataset[start_idx:end_idx])
    else:  # wrap around the end to the beginning of the dataset
      yield process_batch_bert(dataset[start_idx:] + dataset[:end_idx])
    idx += 1

### UNET3D

def load_unet3d_data(preprocessed_dataset_dir, seed, queue_in, queue_out, X:Tensor, Y:Tensor):
  from extra.datasets.kits19 import rand_balanced_crop, rand_flip, random_brightness_augmentation, gaussian_noise

  while (data := queue_in.get()) is not None:
    idx, fn, val = data
    case_name = os.path.basename(fn).split("_x.npy")[0]
    x, y = np.load(preprocessed_dataset_dir / f"{case_name}_x.npy"), np.load(preprocessed_dataset_dir / f"{case_name}_y.npy")

    if not val:
      if seed is not None:
        np.random.seed(seed)
        random.seed(seed)

      x, y = rand_balanced_crop(x, y)
      x, y = rand_flip(x, y)
      x, y = x.astype(np.float32), y.astype(np.uint8)
      x = random_brightness_augmentation(x)
      x = gaussian_noise(x)

    X[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = x.tobytes()
    Y[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = y.tobytes()

    queue_out.put(idx)
  queue_out.put(None)

def batch_load_unet3d(preprocessed_dataset_dir:Path, batch_size:int=6, val:bool=False, shuffle:bool=True, seed=None):
  assert preprocessed_dataset_dir is not None, "run preprocess_data on kits19"

  files = sorted(list(preprocessed_dataset_dir.glob("*_x.npy")))
  file_indices = list(range(len(files)))
  batch_count = min(32, len(files) // batch_size)

  queue_in, queue_out = Queue(), Queue()
  procs, data_out_count = [], [0] * batch_count
  shm_name_x, shm_name_y = "unet3d_x", "unet3d_y"
  sz = (batch_size * batch_count, 1, 128, 128, 128)
  if os.path.exists(f"/dev/shm/{shm_name_x}"): os.unlink(f"/dev/shm/{shm_name_x}")
  if os.path.exists(f"/dev/shm/{shm_name_y}"): os.unlink(f"/dev/shm/{shm_name_y}")
  shm_x = shared_memory.SharedMemory(name=shm_name_x, create=True, size=prod(sz))
  shm_y = shared_memory.SharedMemory(name=shm_name_y, create=True, size=prod(sz))

  shutdown = False
  class Cookie:
    def __init__(self, bc):
      self.bc = bc
    def __del__(self):
      if not shutdown:
        try: enqueue_batch(self.bc)
        except StopIteration: pass

  def enqueue_batch(bc):
    for idx in range(bc * batch_size, (bc+1) * batch_size):
      fn = files[next(ds_iter)]
      queue_in.put((idx, fn, val))

  def shuffle_indices(file_indices, seed=None):
    rng = random.Random(seed)
    rng.shuffle(file_indices)

  if shuffle: shuffle_indices(file_indices, seed=seed)
  ds_iter = iter(file_indices)

  try:
    X = Tensor.empty(*sz, dtype=dtypes.float32, device=f"disk:/dev/shm/{shm_name_x}")
    Y = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:/dev/shm/{shm_name_y}")

    for _ in range(cpu_count()):
      proc = Process(target=load_unet3d_data, args=(preprocessed_dataset_dir, seed, queue_in, queue_out, X, Y))
      proc.daemon = True
      proc.start()
      
      procs.append(proc)

    for bc in range(batch_count):
      enqueue_batch(bc)

    for _ in range(len(files) // batch_size):
      while True:
        bc = queue_out.get() // batch_size
        data_out_count[bc] += 1
        if data_out_count[bc] == batch_size: break

      data_out_count[bc] = 0
      yield X[bc * batch_size:(bc + 1) * batch_size], Y[bc * batch_size:(bc + 1) * batch_size], Cookie(bc)
  finally:
    shutdown = True

    for _ in procs: queue_in.put(None)
    queue_in.close()

    for _ in procs:
      while queue_out.get() is not None: pass
    queue_out.close()

    # shutdown processes
    for proc in procs: proc.join()

    shm_x.close()
    shm_y.close()
    try:
      shm_x.unlink()
      shm_y.unlink()
    except FileNotFoundError:
      # happens with BENCHMARK set
      pass

### RetinaNet

def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue,
                        imgs:Tensor, boxes:Tensor, labels:Tensor, matches:Tensor|None=None,
                        anchors:Tensor|None=None, seed:int|None=None):
  from extra.datasets.openimages import image_load, random_horizontal_flip, resize
  from examples.mlperf.helpers import box_iou, find_matches, generate_anchors
  import torch

  while (data:=queue_in.get()) is not None:
    idx, img, tgt = data
    img = image_load(base_dir, img["subset"], img["file_name"])

    if val:
      img = resize(img)[0]
    else:
      if seed is not None:
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)

      img, tgt = random_horizontal_flip(img, tgt)
      img, tgt, _ = resize(img, tgt=tgt)
      match_quality_matrix = box_iou(tgt["boxes"], (anchor := np.concatenate(generate_anchors((800, 800)))))
      match_idxs = find_matches(match_quality_matrix, allow_low_quality_matches=True)
      clipped_match_idxs = np.clip(match_idxs, 0, None)
      clipped_boxes, clipped_labels = tgt["boxes"][clipped_match_idxs], tgt["labels"][clipped_match_idxs]

      boxes[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = clipped_boxes.tobytes()
      labels[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = clipped_labels.tobytes()
      matches[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = match_idxs.tobytes()
      anchors[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = anchor.tobytes()

    imgs[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()

    queue_out.put(idx)
  queue_out.put(None)

def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, shuffle:bool=True, seed:int|None=None):
  def _enqueue_batch(bc):
    from extra.datasets.openimages import prepare_target
    for idx in range(bc * batch_size, (bc+1) * batch_size):
      img = dataset.loadImgs(next(dataset_iter))[0]
      ann = dataset.loadAnns(dataset.getAnnIds(img_id:=img["id"]))
      tgt = prepare_target(ann, img_id, (img["height"], img["width"]))

      if img_ids is not None:
        img_ids[idx] = img_id

      if img_sizes is not None:
        img_sizes[idx] = tgt["image_size"]

      queue_in.put((idx, img, tgt))

  def _setup_shared_mem(shm_name:str, size:tuple[int, ...], dtype:dtypes) -> tuple[shared_memory.SharedMemory, Tensor]:
    if os.path.exists(f"/dev/shm/{shm_name}"): os.unlink(f"/dev/shm/{shm_name}")
    shm = shared_memory.SharedMemory(name=shm_name, create=True, size=prod(size))
    shm_tensor = Tensor.empty(*size, dtype=dtype, device=f"disk:/dev/shm/{shm_name}")
    return shm, shm_tensor

  image_ids = sorted(dataset.imgs.keys())
  batch_count = min(32, len(image_ids) // batch_size)

  queue_in, queue_out = Queue(), Queue()
  procs, data_out_count = [], [0] * batch_count

  shm_imgs, imgs = _setup_shared_mem("retinanet_imgs", (batch_size * batch_count, 800, 800, 3), dtypes.uint8)

  if val:
    boxes, labels, matches, anchors = None, None, None, None
    img_ids, img_sizes = [None] * (batch_size * batch_count), [None] * (batch_size * batch_count)
  else:
    img_ids, img_sizes = None, None
    shm_boxes, boxes = _setup_shared_mem("retinanet_boxes", (batch_size * batch_count, 120087, 4), dtypes.float32)
    shm_labels, labels = _setup_shared_mem("retinanet_labels", (batch_size * batch_count, 120087), dtypes.int64)
    shm_matches, matches = _setup_shared_mem("retinanet_matches", (batch_size * batch_count, 120087), dtypes.int64)
    shm_anchors, anchors = _setup_shared_mem("retinanet_anchors", (batch_size * batch_count, 120087, 4), dtypes.float64)

  shutdown = False
  class Cookie:
    def __init__(self, bc):
      self.bc = bc
    def __del__(self):
      if not shutdown:
        try: _enqueue_batch(self.bc)
        except StopIteration: pass

  def shuffle_indices(indices, seed):
    rng = random.Random(seed)
    rng.shuffle(indices)

  if shuffle: shuffle_indices(image_ids, seed=seed)
  dataset_iter = iter(image_ids)

  try:
    for _ in range(cpu_count()):
      proc = Process(
        target=load_retinanet_data,
        args=(base_dir, val, queue_in, queue_out, imgs, boxes, labels),
        kwargs={"matches": matches, "anchors": anchors, "seed": seed}
      )
      proc.daemon = True
      proc.start()
      procs.append(proc)

    for bc in range(batch_count):
      _enqueue_batch(bc)

    for _ in range(len(image_ids) // batch_size):
      while True:
        bc = queue_out.get() // batch_size
        data_out_count[bc] += 1
        if data_out_count[bc] == batch_size: break

      data_out_count[bc] = 0

      if val:
        yield (imgs[bc * batch_size:(bc + 1) * batch_size],
               img_ids[bc * batch_size:(bc + 1) * batch_size],
               img_sizes[bc * batch_size:(bc + 1) * batch_size],
               Cookie(bc))
      else:
        yield (imgs[bc * batch_size:(bc + 1) * batch_size],
               boxes[bc * batch_size:(bc + 1) * batch_size],
               labels[bc * batch_size:(bc + 1) * batch_size],
               matches[bc * batch_size:(bc + 1) * batch_size],
               anchors[bc * batch_size:(bc + 1) * batch_size],
               Cookie(bc))
  finally:
    shutdown = True

    for _ in procs: queue_in.put(None)
    queue_in.close()

    for _ in procs:
      while queue_out.get() is not None: pass
    queue_out.close()

    # shutdown processes
    for proc in procs: proc.join()

    shm_imgs.close()

    if not val:
      shm_boxes.close()
      shm_labels.close()
      shm_matches.close()
      shm_anchors.close()

    try:
      shm_imgs.unlink()

      if not val:
        shm_boxes.unlink()
        shm_labels.unlink()
        shm_matches.unlink()
        shm_anchors.unlink()
    except FileNotFoundError:
      # happens with BENCHMARK set
      pass

if __name__ == "__main__":
  def load_unet3d(val):
    assert not val, "validation set is not supported due to different sizes on inputs"

    from extra.datasets.kits19 import get_train_files, get_val_files, preprocess_dataset, TRAIN_PREPROCESSED_DIR, VAL_PREPROCESSED_DIR
    preprocessed_dir = VAL_PREPROCESSED_DIR if val else TRAIN_PREPROCESSED_DIR
    files = get_val_files() if val else get_train_files()

    if not preprocessed_dir.exists(): preprocess_dataset(files, preprocessed_dir, val)
    with tqdm(total=len(files)) as pbar:
      for x, _, _ in batch_load_unet3d(preprocessed_dir, val=val):
        pbar.update(x.shape[0])

  def load_resnet(val):
    from extra.datasets.imagenet import get_train_files, get_val_files
    files = get_val_files() if val else get_train_files()
    with tqdm(total=len(files)) as pbar:
      for x,y,c in batch_load_resnet(val=val):
        pbar.update(x.shape[0])

  def load_retinanet(val):
    from extra.datasets.openimages import BASEDIR, download_dataset
    from pycocotools.coco import COCO
    dataset = COCO(download_dataset(base_dir:=getenv("BASE_DIR", BASEDIR), "validation" if val else "train"))
    with tqdm(total=len(dataset.imgs.keys())) as pbar:
      for x in batch_load_retinanet(dataset, val, base_dir):
        pbar.update(x[0].shape[0])

  load_fn_name = f"load_{getenv('MODEL', 'resnet')}"
  if load_fn_name in globals():
    globals()[load_fn_name](getenv("VAL", 1))