From 8f1806602c4e4a70e073bf1abd8f61206864d652 Mon Sep 17 00:00:00 2001 From: Greg Hogan Date: Thu, 22 Jul 2021 12:08:56 -0700 Subject: [PATCH] FrameIterator that uses 1+ GB less RAM (#21687) --- tools/lib/filereader.py | 4 +- tools/lib/framereader.py | 160 +++++++++++++++------------------------ tools/lib/url_file.py | 7 +- 3 files changed, 66 insertions(+), 105 deletions(-) diff --git a/tools/lib/filereader.py b/tools/lib/filereader.py index 5c9b375bb2..3d4c46b220 100644 --- a/tools/lib/filereader.py +++ b/tools/lib/filereader.py @@ -1,8 +1,6 @@ from tools.lib.url_file import URLFile - def FileReader(fn, debug=False): if fn.startswith("http://") or fn.startswith("https://"): return URLFile(fn, debug=debug) - else: - return open(fn, "rb") + return open(fn, "rb") diff --git a/tools/lib/framereader.py b/tools/lib/framereader.py index 5cc3c16e93..b37e7b9c1a 100644 --- a/tools/lib/framereader.py +++ b/tools/lib/framereader.py @@ -2,7 +2,6 @@ import json import os import pickle -import queue import struct import subprocess import tempfile @@ -326,7 +325,8 @@ class RawFrameReader(BaseFrameReader): class VideoStreamDecompressor: - def __init__(self, vid_fmt, w, h, pix_fmt): + def __init__(self, fn, vid_fmt, w, h, pix_fmt): + self.fn = fn self.vid_fmt = vid_fmt self.w = w self.h = h @@ -339,74 +339,66 @@ class VideoStreamDecompressor: else: raise NotImplementedError - self.out_q = queue.Queue() + self.proc = None + self.t = threading.Thread(target=self.write_thread) + self.t.daemon = True + + def write_thread(self): + try: + with FileReader(self.fn) as f: + while True: + r = f.read(1024*1024) + if len(r) == 0: + break + self.proc.stdin.write(r) + finally: + self.proc.stdin.close() + def read(self): threads = os.getenv("FFMPEG_THREADS", "0") cuda = os.getenv("FFMPEG_CUDA", "0") == "1" - self.proc = subprocess.Popen( - ["ffmpeg", - "-threads", threads, - "-hwaccel", "none" if not cuda else "cuda", - "-c:v", "hevc", - # "-avioflags", "direct", - "-analyzeduration", "0", - "-probesize", "32", - "-flush_packets", "0", - # "-fflags", "nobuffer", - "-vsync", "0", - "-f", vid_fmt, - "-i", "pipe:0", - "-threads", threads, - "-f", "rawvideo", - "-pix_fmt", pix_fmt, - "pipe:1"], - stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=open("/dev/null", "wb")) + cmd = [ + "ffmpeg", + "-threads", threads, + "-hwaccel", "none" if not cuda else "cuda", + "-c:v", "hevc", + # "-avioflags", "direct", + "-analyzeduration", "0", + "-probesize", "32", + "-flush_packets", "0", + # "-fflags", "nobuffer", + "-vsync", "0", + "-f", self.vid_fmt, + "-i", "pipe:0", + "-threads", threads, + "-f", "rawvideo", + "-pix_fmt", self.pix_fmt, + "pipe:1" + ] + self.proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) + try: + self.t.start() - def read_thread(): while True: - r = self.proc.stdout.read(self.out_size) - if len(r) == 0: + dat = self.proc.stdout.read(self.out_size) + if len(dat) == 0: break - assert len(r) == self.out_size - self.out_q.put(r) - - self.t = threading.Thread(target=read_thread) - self.t.daemon = True - self.t.start() - - def __enter__(self): - return self - - def __exit__(self, *args): - self.close() - - def write(self, rawdat): - self.proc.stdin.write(rawdat) - self.proc.stdin.flush() - - def read(self): - dat = self.out_q.get(block=True) - - if self.pix_fmt == "rgb24": - ret = np.frombuffer(dat, dtype=np.uint8).reshape((self.h, self.w, 3)) - elif self.pix_fmt == "yuv420p": - ret = np.frombuffer(dat, dtype=np.uint8) - elif self.pix_fmt == "yuv444p": - ret = np.frombuffer(dat, dtype=np.uint8).reshape((3, self.h, self.w)) - else: - assert False - - return ret - - def eos(self): - self.proc.stdin.close() - - def close(self): - self.proc.stdin.close() - self.t.join() - self.proc.wait() - assert self.proc.wait() == 0 - + assert len(dat) == self.out_size + if self.pix_fmt == "rgb24": + ret = np.frombuffer(dat, dtype=np.uint8).reshape((self.h, self.w, 3)) + elif self.pix_fmt == "yuv420p": + ret = np.frombuffer(dat, dtype=np.uint8) + elif self.pix_fmt == "yuv444p": + ret = np.frombuffer(dat, dtype=np.uint8).reshape((3, self.h, self.w)) + else: + assert False + yield ret + + result_code = self.proc.wait() + assert result_code == 0, result_code + finally: + self.proc.kill() + self.t.join() class StreamGOPReader(GOPReader): def __init__(self, fn, frame_type, index_data): @@ -579,43 +571,9 @@ class StreamFrameReader(StreamGOPReader, GOPFrameReader): def GOPFrameIterator(gop_reader, pix_fmt): - # this is really ugly. ill think about how to refactor it when i can think good - - IN_FLIGHT_GOPS = 6 # should be enough that the stream decompressor starts returning data - - with VideoStreamDecompressor(gop_reader.vid_fmt, gop_reader.w, gop_reader.h, pix_fmt) as dec: - read_work = [] - - def readthing(): - # print read_work, dec.out_q.qsize() - outf = dec.read() - read_thing = read_work[0] - if read_thing[0] > 0: - read_thing[0] -= 1 - else: - assert read_thing[1] > 0 - yield outf - read_thing[1] -= 1 - - if read_thing[1] == 0: - read_work.pop(0) - - i = 0 - while i < gop_reader.frame_count: - frame_b, num_frames, skip_frames, gop_data = gop_reader.get_gop(i) - dec.write(gop_data) - i += num_frames - read_work.append([skip_frames, num_frames]) - - while len(read_work) >= IN_FLIGHT_GOPS: - for v in readthing(): - yield v - - dec.eos() - - while read_work: - for v in readthing(): - yield v + dec = VideoStreamDecompressor(gop_reader.fn, gop_reader.vid_fmt, gop_reader.w, gop_reader.h, pix_fmt) + for frame in dec.read(): + yield frame def FrameIterator(fn, pix_fmt, **kwargs): diff --git a/tools/lib/url_file.py b/tools/lib/url_file.py index 6d2be96d2c..d0dd837891 100644 --- a/tools/lib/url_file.py +++ b/tools/lib/url_file.py @@ -117,7 +117,12 @@ class URLFile(object): download_range = False headers = ["Connection: keep-alive"] if self._pos != 0 or ll is not None: - end = (self._pos + ll if ll is not None else self.get_length()) - 1 + if ll is None: + end = self.get_length() - 1 + else: + end = min(self._pos + ll, self.get_length()) - 1 + if self._pos >= end: + return b"" headers.append(f"Range: bytes={self._pos}-{end}") download_range = True