You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							173 lines
						
					
					
						
							5.9 KiB
						
					
					
				
			
		
		
	
	
							173 lines
						
					
					
						
							5.9 KiB
						
					
					
				| import os
 | |
| import subprocess
 | |
| import json
 | |
| from collections.abc import Iterator
 | |
| from collections import OrderedDict
 | |
| 
 | |
| import numpy as np
 | |
| from openpilot.tools.lib.filereader import FileReader, resolve_name
 | |
| from openpilot.tools.lib.exceptions import DataUnreadableError
 | |
| from openpilot.tools.lib.vidindex import hevc_index
 | |
| 
 | |
| 
 | |
| HEVC_SLICE_B = 0
 | |
| HEVC_SLICE_P = 1
 | |
| HEVC_SLICE_I = 2
 | |
| 
 | |
| 
 | |
| class LRUCache:
 | |
|   def __init__(self, capacity: int):
 | |
|     self._cache: OrderedDict = OrderedDict()
 | |
|     self.capacity = capacity
 | |
| 
 | |
|   def __getitem__(self, key):
 | |
|     self._cache.move_to_end(key)
 | |
|     return self._cache[key]
 | |
| 
 | |
|   def __setitem__(self, key, value):
 | |
|     self._cache[key] = value
 | |
|     if len(self._cache) > self.capacity:
 | |
|         self._cache.popitem(last=False)
 | |
| 
 | |
|   def __contains__(self, key):
 | |
|     return key in self._cache
 | |
| 
 | |
| 
 | |
| def assert_hvec(fn: str) -> None:
 | |
|   with FileReader(fn) as f:
 | |
|     header = f.read(4)
 | |
|   if len(header) == 0:
 | |
|     raise DataUnreadableError(f"{fn} is empty")
 | |
|   elif header == b"\x00\x00\x00\x01":
 | |
|     if 'hevc' not in fn:
 | |
|       raise NotImplementedError(fn)
 | |
| 
 | |
| def decompress_video_data(rawdat, w, h, pix_fmt="rgb24", vid_fmt='hevc') -> np.ndarray:
 | |
|   threads = os.getenv("FFMPEG_THREADS", "0")
 | |
|   args = ["ffmpeg", "-v", "quiet",
 | |
|           "-threads", threads,
 | |
|           "-c:v", "hevc",
 | |
|           "-vsync", "0",
 | |
|           "-f", vid_fmt,
 | |
|           "-flags2", "showall",
 | |
|           "-i", "-",
 | |
|           "-f", "rawvideo",
 | |
|           "-pix_fmt", pix_fmt,
 | |
|           "-"]
 | |
|   dat = subprocess.check_output(args, input=rawdat)
 | |
| 
 | |
|   ret: np.ndarray
 | |
|   if pix_fmt == "rgb24":
 | |
|     ret = np.frombuffer(dat, dtype=np.uint8).reshape(-1, h, w, 3)
 | |
|   elif pix_fmt in ["nv12", "yuv420p"]:
 | |
|     ret = np.frombuffer(dat, dtype=np.uint8).reshape(-1, (h*w*3//2))
 | |
|   else:
 | |
|     raise NotImplementedError(f"Unsupported pixel format: {pix_fmt}")
 | |
|   return ret
 | |
| 
 | |
| def ffprobe(fn, fmt=None):
 | |
|   fn = resolve_name(fn)
 | |
|   cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", "-show_format", "-show_streams"]
 | |
|   if fmt:
 | |
|     cmd += ["-f", fmt]
 | |
|   cmd += ["-i", "-"]
 | |
| 
 | |
|   try:
 | |
|     with FileReader(fn) as f:
 | |
|       ffprobe_output = subprocess.check_output(cmd, input=f.read(4096))
 | |
|   except subprocess.CalledProcessError as e:
 | |
|     raise DataUnreadableError(fn) from e
 | |
|   return json.loads(ffprobe_output)
 | |
| 
 | |
| def get_index_data(fn: str, index_data: dict|None = None):
 | |
|   if index_data is None:
 | |
|     index_data = get_video_index(fn)
 | |
|     if index_data is None:
 | |
|       raise DataUnreadableError(f"Failed to index {fn!r}")
 | |
|   stream = index_data["probe"]["streams"][0]
 | |
|   return index_data["index"], index_data["global_prefix"], stream["width"], stream["height"]
 | |
| 
 | |
| def get_video_index(fn):
 | |
|   assert_hvec(fn)
 | |
|   frame_types, dat_len, prefix = hevc_index(fn)
 | |
|   index = np.array(frame_types + [(0xFFFFFFFF, dat_len)], dtype=np.uint32)
 | |
|   probe = ffprobe(fn, "hevc")
 | |
|   return {
 | |
|     'index': index,
 | |
|     'global_prefix': prefix,
 | |
|     'probe': probe
 | |
|   }
 | |
| 
 | |
| 
 | |
| class FfmpegDecoder:
 | |
|   def __init__(self, fn: str, index_data: dict|None = None,
 | |
|                pix_fmt: str = "rgb24"):
 | |
|     self.fn = fn
 | |
|     self.index, self.prefix, self.w, self.h = get_index_data(fn, index_data)
 | |
|     self.frame_count = len(self.index) - 1          # sentinel row at the end
 | |
|     self.iframes = np.where(self.index[:, 0] == HEVC_SLICE_I)[0]
 | |
|     self.pix_fmt = pix_fmt
 | |
| 
 | |
|   def _gop_bounds(self, frame_idx: int):
 | |
|     f_b = frame_idx
 | |
|     while f_b > 0 and self.index[f_b, 0] != HEVC_SLICE_I:
 | |
|       f_b -= 1
 | |
|     f_e = frame_idx + 1
 | |
|     while f_e < self.frame_count and self.index[f_e, 0] != HEVC_SLICE_I:
 | |
|       f_e += 1
 | |
|     return f_b, f_e, self.index[f_b, 1], self.index[f_e, 1]
 | |
| 
 | |
|   def _decode_gop(self, raw: bytes) -> Iterator[np.ndarray]:
 | |
|     yield from decompress_video_data(raw, self.w, self.h, self.pix_fmt)
 | |
| 
 | |
|   def get_gop_start(self, frame_idx: int):
 | |
|     return self.iframes[np.searchsorted(self.iframes, frame_idx, side="right") - 1]
 | |
| 
 | |
|   def get_iterator(self, start_fidx: int = 0, end_fidx: int|None = None,
 | |
|                    frame_skip: int = 1) -> Iterator[tuple[int, np.ndarray]]:
 | |
|     end_fidx = end_fidx or self.frame_count
 | |
|     fidx = start_fidx
 | |
|     while fidx < end_fidx:
 | |
|       f_b, f_e, off_b, off_e = self._gop_bounds(fidx)
 | |
|       with FileReader(self.fn) as f:
 | |
|         f.seek(off_b)
 | |
|         raw = self.prefix + f.read(off_e - off_b)
 | |
|       # number of frames to discard inside this GOP before the wanted one
 | |
|       for i, frm in enumerate(decompress_video_data(raw, self.w, self.h, self.pix_fmt)):
 | |
|         fidx = f_b + i
 | |
|         if fidx >= end_fidx:
 | |
|           return
 | |
|         elif fidx >= start_fidx and (fidx - start_fidx) % frame_skip == 0:
 | |
|           yield fidx, frm
 | |
|       fidx += 1
 | |
| 
 | |
| def FrameIterator(fn: str, index_data: dict|None=None,
 | |
|                         pix_fmt: str = "rgb24",
 | |
|                         start_fidx:int=0, end_fidx=None, frame_skip:int=1) -> Iterator[np.ndarray]:
 | |
|   dec = FfmpegDecoder(fn, pix_fmt=pix_fmt, index_data=index_data)
 | |
|   for _, frame in dec.get_iterator(start_fidx=start_fidx, end_fidx=end_fidx, frame_skip=frame_skip):
 | |
|     yield frame
 | |
| 
 | |
| class FrameReader:
 | |
|   def __init__(self, fn: str, index_data: dict|None = None,
 | |
|                cache_size: int = 30, pix_fmt: str = "rgb24"):
 | |
|     self.decoder = FfmpegDecoder(fn, index_data, pix_fmt)
 | |
|     self.iframes = self.decoder.iframes
 | |
|     self._cache: LRUCache = LRUCache(cache_size)
 | |
|     self.w, self.h, self.frame_count, = self.decoder.w, self.decoder.h, self.decoder.frame_count
 | |
|     self.pix_fmt = pix_fmt
 | |
| 
 | |
|     self.it: Iterator[tuple[int, np.ndarray]] | None = None
 | |
|     self.fidx = -1
 | |
| 
 | |
|   def get(self, fidx:int):
 | |
|     if fidx in self._cache:  # If frame is cached, return it
 | |
|       return self._cache[fidx]
 | |
|     read_start = self.decoder.get_gop_start(fidx)
 | |
|     if not self.it or fidx < self.fidx or read_start != self.decoder.get_gop_start(self.fidx):  # If the frame is in a different GOP, reset the iterator
 | |
|       self.it = self.decoder.get_iterator(read_start)
 | |
|       self.fidx = -1
 | |
|     while self.fidx < fidx:
 | |
|       self.fidx, frame = next(self.it)
 | |
|       self._cache[self.fidx] = frame
 | |
|     return self._cache[fidx]
 | |
| 
 |