#!/usr/bin/env python3
import os
import sys
import json
import bz2
import tempfile
import requests
import subprocess
import urllib.parse
from aenum import Enum
import capnp
import numpy as np

from tools.lib.exceptions import DataUnreadableError
try:
  from xx.chffr.lib.filereader import FileReader
except ImportError:
  from tools.lib.filereader import FileReader
from cereal import log as capnp_log

# Two directory levels above the file backing `cereal.log` (presumably the
# openpilot checkout root -- confirm); index_log() looks for 'phonelibs' under it.
OP_PATH = os.path.dirname(os.path.dirname(capnp_log.__file__))

def index_log(fn):
  """Return the event start offsets of the capnp log file `fn`.

  Builds the native `index_log` helper (in the `index_log` directory next to
  this file) with make, runs it on `fn` and decodes its stdout as a numpy array
  of np.uint64. event_read_multiple_bytes treats these values as the byte
  offset at which each event begins.

  Raises:
    DataUnreadableError: the helper exited non-zero (corrupted/truncated log).
    subprocess.CalledProcessError: the make step itself failed.
  """
  here = os.path.dirname(os.path.realpath(__file__))
  tool_dir = os.path.join(here, "index_log")

  # (Re)build the helper on every call; make output is discarded.
  build_cmd = ["make", "PHONELIBS=" + os.path.join(OP_PATH, 'phonelibs')]
  subprocess.check_call(build_cmd, cwd=tool_dir, stdout=subprocess.DEVNULL)

  tool = os.path.join(tool_dir, "index_log")
  try:
    raw = subprocess.check_output([tool, fn, "-"])
  except subprocess.CalledProcessError:
    raise DataUnreadableError("%s capnp is corrupted/truncated" % fn)
  else:
    return np.frombuffer(raw, dtype=np.uint64)

def event_read_multiple_bytes(dat):
  """Split raw (uncompressed) log bytes into a list of capnp Event messages.

  index_log() only accepts a filename, so `dat` is first written to a temporary
  file. Each event is parsed from the slice between its start offset and the
  next one; the last event runs to the end of `dat`.
  """
  with tempfile.NamedTemporaryFile() as tmp:
    tmp.write(dat)
    tmp.flush()
    starts = index_log(tmp.name)

  # append len(dat) so every start offset has a matching end offset
  bounds = np.append(starts, np.uint64(len(dat)))

  return [capnp_log.Event.from_bytes(dat[lo:hi])
          for lo, hi in zip(bounds[:-1], bounds[1:])]


# this is an iterator itself, and uses private variables from LogReader
class MultiLogIterator(object):
  """Iterate over the messages of several logs as one continuous stream.

  Args:
    log_paths: list of log paths; ``None`` entries are skipped. seek() assumes
      entry ``i`` holds minute ``i`` of the route.
    wraparound: if True, iteration restarts at the first available log after
      the last message of the last one (the iterator never ends); if False,
      StopIteration is raised once the final message has been returned.

  LogReaders are created lazily, the first time a log is visited.
  """

  def __init__(self, log_paths, wraparound=True):
    self._log_paths = log_paths
    self._wraparound = wraparound

    # NOTE: raises StopIteration if every entry of log_paths is None.
    self._first_log_idx = next(i for i in range(len(log_paths)) if log_paths[i] is not None)
    self._current_log = self._first_log_idx
    self._idx = 0
    self._log_readers = [None]*len(log_paths)
    # timestamp of the very first message; tell() is relative to it
    self.start_time = self._log_reader(self._first_log_idx)._ts[0]

  def _log_reader(self, i):
    """Return the LogReader for log `i`, opening it on first use (None for a None path)."""
    if self._log_readers[i] is None and self._log_paths[i] is not None:
      log_path = self._log_paths[i]
      print("LogReader:%s" % log_path)
      self._log_readers[i] = LogReader(log_path)

    return self._log_readers[i]

  def __iter__(self):
    return self

  def _exhausted(self):
    # _inc() parks _current_log one past the last log when wraparound is off
    return self._current_log >= len(self._log_readers)

  def _inc(self):
    """Advance the cursor by one message.

    Returns True when the cursor moved past the last message of the last log
    (it then points at the first log if wraparound is on, or is parked past
    the end otherwise), False for an ordinary step.
    """
    lr = self._log_reader(self._current_log)
    if self._idx < len(lr._ents)-1:
      self._idx += 1
      return False

    self._idx = 0
    n_logs = len(self._log_readers)
    # next log that has a path, or n_logs when there is none left
    self._current_log = next(i for i in range(self._current_log + 1, n_logs + 1) if i == n_logs or self._log_paths[i] is not None)
    if self._current_log == n_logs:
      if self._wraparound:
        self._current_log = self._first_log_idx
      return True
    return False

  def __next__(self):
    # Check for exhaustion *before* reading: signalling the end from _inc()
    # after the read would discard the final message.
    if self._exhausted():
      raise StopIteration
    ret = self._log_reader(self._current_log)._ents[self._idx]
    self._inc()
    return ret

  def tell(self):
    """Return seconds from the first message of the first log to the current one.

    Assumes `_ts` values are in nanoseconds (hence the 1e-9 factor).
    """
    return (self._log_reader(self._current_log)._ts[self._idx] - self.start_time) * 1e-9

  def seek(self, ts):
    """Move the cursor to the first message with tell() >= `ts`.

    The scan starts at the beginning of log ``int(ts/60)``. Returns False
    (cursor untouched) if that index is negative, out of range or has no path;
    True otherwise. If `ts` lies beyond the last message, the cursor ends up
    where iterating past the end would put it: at the first log with
    wraparound, exhausted without.
    """
    minute = int(ts/60)
    # a negative index would silently select a log from the end of log_paths
    if minute < 0 or minute >= len(self._log_paths) or self._log_paths[minute] is None:
      return False

    self._current_log = minute

    # HACK: O(n) linear scan from the start of that log
    self._idx = 0
    while self.tell() < ts:
      if self._inc():
        # crossed the end of the last log: stop instead of looping forever
        # (wraparound) or letting StopIteration escape (no wraparound)
        break
    return True


class LogReader(object):
  """Read every capnp Event from a single log file.

  Args:
    fn: path or URL of the log (anything FileReader accepts). The extension of
      the URL path selects decoding: no extension means an uncompressed log
      (old rlogs weren't bz2 compressed), ".bz2" means bz2-compressed.
    canonicalize: accepted for API compatibility; unused here.
    only_union_types: if True, iteration skips events whose ``which()`` raises
      a capnp KjException.

  Raises:
    ValueError: unsupported extension (a subclass of the previously raised
      generic Exception, so existing ``except Exception`` handlers still work).
  """

  def __init__(self, fn, canonicalize=True, only_union_types=False):
    _, ext = os.path.splitext(urllib.parse.urlparse(fn).path)
    # Validate before reading: fn may be a URL, so the read can be expensive.
    if ext not in ("", ".bz2"):
      raise ValueError(f"unknown extension {ext}")

    with FileReader(fn) as f:
      dat = f.read()

    if ext == ".bz2":
      dat = bz2.decompress(dat)
    ents = event_read_multiple_bytes(dat)

    self._ts = [x.logMonoTime for x in ents]
    self.data_version = None
    self._only_union_types = only_union_types
    self._ents = ents

  def __iter__(self):
    for ent in self._ents:
      if self._only_union_types:
        # keep the try minimal so exceptions raised at the yield are not caught
        try:
          ent.which()
        except capnp.lib.capnp.KjException:
          continue
      yield ent

if __name__ == "__main__":
  # Print every message of the log whose path/URL is given as the first argument.
  for event in LogReader(sys.argv[1]):
    print(event)