# pylint: skip-file
import json
import os
import pickle
import struct
import subprocess
import tempfile
import threading
from enum import IntEnum
from functools import wraps
import numpy as np
from lru import LRU
import _io
from openpilot.tools.lib.cache import cache_path_for_file_path
from openpilot.tools.lib.exceptions import DataUnreadableError
from openpilot.common.file_helpers import atomic_write_in_dir
from openpilot.tools.lib.filereader import FileReader
HEVC_SLICE_B = 0
HEVC_SLICE_P = 1
HEVC_SLICE_I = 2
class GOPReader:
  def get_gop(self, num):
    # returns (start_frame_num, num_frames, frames_to_skip, gop_data)
    raise NotImplementedError
class DoNothingContextManager:
  def __enter__(self):
    return self
  def __exit__(self, *x):
    pass
class FrameType(IntEnum):
  raw = 1
  h265_stream = 2
def fingerprint_video(fn):
  with FileReader(fn) as f:
    header = f.read(4)
  if len(header) == 0:
    raise DataUnreadableError(f"{fn} is empty")
  elif header == b"\x00\xc0\x12\x00":
    return FrameType.raw
  elif header == b"\x00\x00\x00\x01":
    if 'hevc' in fn:
      return FrameType.h265_stream
    else:
      raise NotImplementedError(fn)
  else:
    raise NotImplementedError(fn)
def ffprobe(fn, fmt=None):
  cmd = ["ffprobe",
         "-v", "quiet",
         "-print_format", "json",
         "-show_format", "-show_streams"]
  if fmt:
    cmd += ["-f", fmt]
  cmd += [fn]
  try:
    ffprobe_output = subprocess.check_output(cmd)
  except subprocess.CalledProcessError as e:
    raise DataUnreadableError(fn) from e
  return json.loads(ffprobe_output)
def vidindex(fn, typ):
  vidindex_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "vidindex")
  vidindex = os.path.join(vidindex_dir, "vidindex")
  subprocess.check_call(["make"], cwd=vidindex_dir, stdout=subprocess.DEVNULL)
  with tempfile.NamedTemporaryFile() as prefix_f, \
       tempfile.NamedTemporaryFile() as index_f:
    try:
      subprocess.check_call([vidindex, typ, fn, prefix_f.name, index_f.name])
    except subprocess.CalledProcessError as e:
      raise DataUnreadableError(f"vidindex failed on file {fn}") from e
    with open(index_f.name, "rb") as f:
      index = f.read()
    with open(prefix_f.name, "rb") as f:
      prefix = f.read()
  index = np.frombuffer(index, np.uint32).reshape(-1, 2)
  assert index[-1, 0] == 0xFFFFFFFF
  assert index[-1, 1] == os.path.getsize(fn)
  return index, prefix
def cache_fn(func):
  @wraps(func)
  def cache_inner(fn, *args, **kwargs):
    if kwargs.pop('no_cache', None):
      cache_path = None
    else:
      cache_prefix = kwargs.pop('cache_prefix', None)
      cache_path = cache_path_for_file_path(fn, cache_prefix)
    if cache_path and os.path.exists(cache_path):
      with open(cache_path, "rb") as cache_file:
        cache_value = pickle.load(cache_file)
    else:
      cache_value = func(fn, *args, **kwargs)
      if cache_path:
        with atomic_write_in_dir(cache_path, mode="wb", overwrite=True) as cache_file:
          pickle.dump(cache_value, cache_file, -1)
    return cache_value
  return cache_inner
@cache_fn
def index_stream(fn, typ):
  assert typ in ("hevc", )
  with FileReader(fn) as f:
    assert os.path.exists(f.name), fn
    index, prefix = vidindex(f.name, typ)
    probe = ffprobe(f.name, typ)
  return {
    'index': index,
    'global_prefix': prefix,
    'probe': probe
  }
def index_videos(camera_paths, cache_prefix=None):
  """Requires that paths in camera_paths are contiguous and of the same type."""
  if len(camera_paths) < 1:
    raise ValueError("must provide at least one video to index")
  frame_type = fingerprint_video(camera_paths[0])
  for fn in camera_paths:
    index_video(fn, frame_type, cache_prefix)
def index_video(fn, frame_type=None, cache_prefix=None):
  cache_path = cache_path_for_file_path(fn, cache_prefix)
  if os.path.exists(cache_path):
    return
  if frame_type is None:
    frame_type = fingerprint_video(fn[0])
  if frame_type == FrameType.h265_stream:
    index_stream(fn, "hevc", cache_prefix=cache_prefix)
  else:
    raise NotImplementedError("Only h265 supported")
def get_video_index(fn, frame_type, cache_prefix=None):
  cache_path = cache_path_for_file_path(fn, cache_prefix)
  if not os.path.exists(cache_path):
    index_video(fn, frame_type, cache_prefix)
  if not os.path.exists(cache_path):
    return None
  with open(cache_path, "rb") as cache_file:
    return pickle.load(cache_file)
def read_file_check_size(f, sz, cookie):
  buff = bytearray(sz)
  bytes_read = f.readinto(buff)
  assert bytes_read == sz, (bytes_read, sz)
  return buff
def rgb24toyuv(rgb):
  yuv_from_rgb = np.array([[ 0.299     ,  0.587     ,  0.114      ],
                           [-0.14714119, -0.28886916,  0.43601035 ],
                           [ 0.61497538, -0.51496512, -0.10001026 ]])
  img = np.dot(rgb.reshape(-1, 3), yuv_from_rgb.T).reshape(rgb.shape)
  ys = img[:, :, 0]
  us = (img[::2, ::2, 1] + img[1::2, ::2, 1] + img[::2, 1::2, 1] + img[1::2, 1::2, 1]) / 4 + 128
  vs = (img[::2, ::2, 2] + img[1::2, ::2, 2] + img[::2, 1::2, 2] + img[1::2, 1::2, 2]) / 4 + 128
  return ys, us, vs
def rgb24toyuv420(rgb):
  ys, us, vs = rgb24toyuv(rgb)
  y_len = rgb.shape[0] * rgb.shape[1]
  uv_len = y_len // 4
  yuv420 = np.empty(y_len + 2 * uv_len, dtype=rgb.dtype)
  yuv420[:y_len] = ys.reshape(-1)
  yuv420[y_len:y_len + uv_len] = us.reshape(-1)
  yuv420[y_len + uv_len:y_len + 2 * uv_len] = vs.reshape(-1)
  return yuv420.clip(0, 255).astype('uint8')
def rgb24tonv12(rgb):
  ys, us, vs = rgb24toyuv(rgb)
  y_len = rgb.shape[0] * rgb.shape[1]
  uv_len = y_len // 4
  nv12 = np.empty(y_len + 2 * uv_len, dtype=rgb.dtype)
  nv12[:y_len] = ys.reshape(-1)
  nv12[y_len::2] = us.reshape(-1)
  nv12[y_len+1::2] = vs.reshape(-1)
  return nv12.clip(0, 255).astype('uint8')
def decompress_video_data(rawdat, vid_fmt, w, h, pix_fmt):
  # using a tempfile is much faster than proc.communicate for some reason
  with tempfile.TemporaryFile() as tmpf:
    tmpf.write(rawdat)
    tmpf.seek(0)
    threads = os.getenv("FFMPEG_THREADS", "0")
    cuda = os.getenv("FFMPEG_CUDA", "0") == "1"
    args = ["ffmpeg",
            "-threads", threads,
            "-hwaccel", "none" if not cuda else "cuda",
            "-c:v", "hevc",
            "-vsync", "0",
            "-f", vid_fmt,
            "-flags2", "showall",
            "-i", "pipe:0",
            "-threads", threads,
            "-f", "rawvideo",
            "-pix_fmt", pix_fmt,
            "pipe:1"]
    with subprocess.Popen(args, stdin=tmpf, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) as proc:
      # dat = proc.communicate()[0]
      dat = proc.stdout.read()
      if proc.wait() != 0:
        raise DataUnreadableError("ffmpeg failed")
  if pix_fmt == "rgb24":
    ret = np.frombuffer(dat, dtype=np.uint8).reshape(-1, h, w, 3)
  elif pix_fmt == "nv12":
    ret = np.frombuffer(dat, dtype=np.uint8).reshape(-1, (h*w*3//2))
  elif pix_fmt == "yuv420p":
    ret = np.frombuffer(dat, dtype=np.uint8).reshape(-1, (h*w*3//2))
  elif pix_fmt == "yuv444p":
    ret = np.frombuffer(dat, dtype=np.uint8).reshape(-1, 3, h, w)
  else:
    raise NotImplementedError
  return ret
class BaseFrameReader:
  # properties: frame_type, frame_count, w, h
  def __enter__(self):
    return self
  def __exit__(self, *args):
    self.close()
  def close(self):
    pass
  def get(self, num, count=1, pix_fmt="yuv420p"):
    raise NotImplementedError
def FrameReader(fn, cache_prefix=None, readahead=False, readbehind=False, index_data=None):
  frame_type = fingerprint_video(fn)
  if frame_type == FrameType.raw:
    return RawFrameReader(fn)
  elif frame_type in (FrameType.h265_stream,):
    if not index_data:
      index_data = get_video_index(fn, frame_type, cache_prefix)
    return StreamFrameReader(fn, frame_type, index_data, readahead=readahead, readbehind=readbehind)
  else:
    raise NotImplementedError(frame_type)
class RawData:
  def __init__(self, f):
    self.f = _io.FileIO(f, 'rb')
    self.lenn = struct.unpack("I", self.f.read(4))[0]
    self.count = os.path.getsize(f) / (self.lenn+4)
  def read(self, i):
    self.f.seek((self.lenn+4)*i + 4)
    return self.f.read(self.lenn)
class RawFrameReader(BaseFrameReader):
  def __init__(self, fn):
    # raw camera
    self.fn = fn
    self.frame_type = FrameType.raw
    self.rawfile = RawData(self.fn)
    self.frame_count = self.rawfile.count
    self.w, self.h = 640, 480
  def load_and_debayer(self, img):
    img = np.frombuffer(img, dtype='uint8').reshape(960, 1280)
    cimg = np.dstack([img[0::2, 1::2], ((img[0::2, 0::2].astype("uint16") + img[1::2, 1::2].astype("uint16")) >> 1).astype("uint8"), img[1::2, 0::2]])
    return cimg
  def get(self, num, count=1, pix_fmt="yuv420p"):
    assert self.frame_count is not None
    assert num+count <= self.frame_count
    if pix_fmt not in ("nv12", "yuv420p", "rgb24"):
      raise ValueError(f"Unsupported pixel format {pix_fmt!r}")
    app = []
    for i in range(num, num+count):
      dat = self.rawfile.read(i)
      rgb_dat = self.load_and_debayer(dat)
      if pix_fmt == "rgb24":
        app.append(rgb_dat)
      elif pix_fmt == "nv12":
        app.append(rgb24tonv12(rgb_dat))
      elif pix_fmt == "yuv420p":
        app.append(rgb24toyuv420(rgb_dat))
      else:
        raise NotImplementedError
    return app
class VideoStreamDecompressor:
  def __init__(self, fn, vid_fmt, w, h, pix_fmt):
    self.fn = fn
    self.vid_fmt = vid_fmt
    self.w = w
    self.h = h
    self.pix_fmt = pix_fmt
    if pix_fmt in ("nv12", "yuv420p"):
      self.out_size = w*h*3//2  # yuv420p
    elif pix_fmt in ("rgb24", "yuv444p"):
      self.out_size = w*h*3
    else:
      raise NotImplementedError
    self.proc = None
    self.t = threading.Thread(target=self.write_thread)
    self.t.daemon = True
  def write_thread(self):
    try:
      with FileReader(self.fn) as f:
        while True:
          r = f.read(1024*1024)
          if len(r) == 0:
            break
          self.proc.stdin.write(r)
    except BrokenPipeError:
      pass
    finally:
      self.proc.stdin.close()
  def read(self):
    threads = os.getenv("FFMPEG_THREADS", "0")
    cuda = os.getenv("FFMPEG_CUDA", "0") == "1"
    cmd = [
      "ffmpeg",
      "-threads", threads,
      "-hwaccel", "none" if not cuda else "cuda",
      "-c:v", "hevc",
      # "-avioflags", "direct",
      "-analyzeduration", "0",
      "-probesize", "32",
      "-flush_packets", "0",
      # "-fflags", "nobuffer",
      "-vsync", "0",
      "-f", self.vid_fmt,
      "-i", "pipe:0",
      "-threads", threads,
      "-f", "rawvideo",
      "-pix_fmt", self.pix_fmt,
      "pipe:1"
    ]
    self.proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
    try:
      self.t.start()
      while True:
        dat = self.proc.stdout.read(self.out_size)
        if len(dat) == 0:
          break
        assert len(dat) == self.out_size
        if self.pix_fmt == "rgb24":
          ret = np.frombuffer(dat, dtype=np.uint8).reshape((self.h, self.w, 3))
        elif self.pix_fmt == "yuv420p":
          ret = np.frombuffer(dat, dtype=np.uint8)
        elif self.pix_fmt == "nv12":
          ret = np.frombuffer(dat, dtype=np.uint8)
        elif self.pix_fmt == "yuv444p":
          ret = np.frombuffer(dat, dtype=np.uint8).reshape((3, self.h, self.w))
        else:
          raise RuntimeError(f"unknown pix_fmt: {self.pix_fmt}")
        yield ret
      result_code = self.proc.wait()
      assert result_code == 0, result_code
    finally:
      self.proc.kill()
      self.t.join()
class StreamGOPReader(GOPReader):
  def __init__(self, fn, frame_type, index_data):
    assert frame_type == FrameType.h265_stream
    self.fn = fn
    self.frame_type = frame_type
    self.frame_count = None
    self.w, self.h = None, None
    self.prefix = None
    self.index = None
    self.index = index_data['index']
    self.prefix = index_data['global_prefix']
    probe = index_data['probe']
    self.prefix_frame_data = None
    self.num_prefix_frames = 0
    self.vid_fmt = "hevc"
    i = 0
    while i < self.index.shape[0] and self.index[i, 0] != HEVC_SLICE_I:
      i += 1
    self.first_iframe = i
    assert self.first_iframe == 0
    self.frame_count = len(self.index) - 1
    self.w = probe['streams'][0]['width']
    self.h = probe['streams'][0]['height']
  def _lookup_gop(self, num):
    frame_b = num
    while frame_b > 0 and self.index[frame_b, 0] != HEVC_SLICE_I:
      frame_b -= 1
    frame_e = num + 1
    while frame_e < (len(self.index) - 1) and self.index[frame_e, 0] != HEVC_SLICE_I:
      frame_e += 1
    offset_b = self.index[frame_b, 1]
    offset_e = self.index[frame_e, 1]
    return (frame_b, frame_e, offset_b, offset_e)
  def get_gop(self, num):
    frame_b, frame_e, offset_b, offset_e = self._lookup_gop(num)
    assert frame_b <= num < frame_e
    num_frames = frame_e - frame_b
    with FileReader(self.fn) as f:
      f.seek(offset_b)
      rawdat = f.read(offset_e - offset_b)
      if num < self.first_iframe:
        assert self.prefix_frame_data
        rawdat = self.prefix_frame_data + rawdat
      rawdat = self.prefix + rawdat
    skip_frames = 0
    if num < self.first_iframe:
      skip_frames = self.num_prefix_frames
    return frame_b, num_frames, skip_frames, rawdat
class GOPFrameReader(BaseFrameReader):
  #FrameReader with caching and readahead for formats that are group-of-picture based
  def __init__(self, readahead=False, readbehind=False):
    self.open_ = True
    self.readahead = readahead
    self.readbehind = readbehind
    self.frame_cache = LRU(64)
    if self.readahead:
      self.cache_lock = threading.RLock()
      self.readahead_last = None
      self.readahead_len = 30
      self.readahead_c = threading.Condition()
      self.readahead_thread = threading.Thread(target=self._readahead_thread)
      self.readahead_thread.daemon = True
      self.readahead_thread.start()
    else:
      self.cache_lock = DoNothingContextManager()
  def close(self):
    if not self.open_:
      return
    self.open_ = False
    if self.readahead:
      self.readahead_c.acquire()
      self.readahead_c.notify()
      self.readahead_c.release()
      self.readahead_thread.join()
  def _readahead_thread(self):
    while True:
      self.readahead_c.acquire()
      try:
        if not self.open_:
          break
        self.readahead_c.wait()
      finally:
        self.readahead_c.release()
      if not self.open_:
        break
      assert self.readahead_last
      num, pix_fmt = self.readahead_last
      if self.readbehind:
        for k in range(num - 1, max(0, num - self.readahead_len), -1):
          self._get_one(k, pix_fmt)
      else:
        for k in range(num, min(self.frame_count, num + self.readahead_len)):
          self._get_one(k, pix_fmt)
  def _get_one(self, num, pix_fmt):
    assert num < self.frame_count
    if (num, pix_fmt) in self.frame_cache:
      return self.frame_cache[(num, pix_fmt)]
    with self.cache_lock:
      if (num, pix_fmt) in self.frame_cache:
        return self.frame_cache[(num, pix_fmt)]
      frame_b, num_frames, skip_frames, rawdat = self.get_gop(num)
      ret = decompress_video_data(rawdat, self.vid_fmt, self.w, self.h, pix_fmt)
      ret = ret[skip_frames:]
      assert ret.shape[0] == num_frames
      for i in range(ret.shape[0]):
        self.frame_cache[(frame_b+i, pix_fmt)] = ret[i]
      return self.frame_cache[(num, pix_fmt)]
  def get(self, num, count=1, pix_fmt="yuv420p"):
    assert self.frame_count is not None
    if num + count > self.frame_count:
      raise ValueError(f"{num + count} > {self.frame_count}")
    if pix_fmt not in ("nv12", "yuv420p", "rgb24", "yuv444p"):
      raise ValueError(f"Unsupported pixel format {pix_fmt!r}")
    ret = [self._get_one(num + i, pix_fmt) for i in range(count)]
    if self.readahead:
      self.readahead_last = (num+count, pix_fmt)
      self.readahead_c.acquire()
      self.readahead_c.notify()
      self.readahead_c.release()
    return ret
class StreamFrameReader(StreamGOPReader, GOPFrameReader):
  def __init__(self, fn, frame_type, index_data, readahead=False, readbehind=False):
    StreamGOPReader.__init__(self, fn, frame_type, index_data)
    GOPFrameReader.__init__(self, readahead, readbehind)
def GOPFrameIterator(gop_reader, pix_fmt):
  dec = VideoStreamDecompressor(gop_reader.fn, gop_reader.vid_fmt, gop_reader.w, gop_reader.h, pix_fmt)
  yield from dec.read()
def FrameIterator(fn, pix_fmt, **kwargs):
  fr = FrameReader(fn, **kwargs)
  if isinstance(fr, GOPReader):
    yield from GOPFrameIterator(fr, pix_fmt)
  else:
    for i in range(fr.frame_count):
      yield fr.get(i, pix_fmt=pix_fmt)[0]