diff --git a/tools/lib/file_helpers.py b/tools/lib/file_helpers.py
index 9e681634a9..a1595d43d7 100644
--- a/tools/lib/file_helpers.py
+++ b/tools/lib/file_helpers.py
@@ -1,6 +1,7 @@
 import os
 from atomicwrites import AtomicWriter
 
+
 def atomic_write_in_dir(path, **kwargs):
   """Creates an atomic writer using a temporary file in the same directory
      as the destination file.
@@ -8,6 +9,7 @@ def atomic_write_in_dir(path, **kwargs):
   writer = AtomicWriter(path, **kwargs)
   return writer._open(_get_fileobject_func(writer, os.path.dirname(path)))
 
+
 def _get_fileobject_func(writer, temp_dir):
   def _get_fileobject():
     file_obj = writer.get_fileobject(dir=temp_dir)
@@ -15,6 +17,7 @@ def _get_fileobject_func(writer, temp_dir):
     return file_obj
   return _get_fileobject
 
+
 def mkdirs_exists_ok(path):
   try:
     os.makedirs(path)
diff --git a/tools/lib/tests/__init__.py b/tools/lib/tests/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tools/lib/tests/test_caching.py b/tools/lib/tests/test_caching.py
new file mode 100644
index 0000000000..fef6e536cb
--- /dev/null
+++ b/tools/lib/tests/test_caching.py
@@ -0,0 +1,68 @@
+#!/usr/bin/env python3
+
+import os
+import unittest
+import shutil
+from tools.lib.url_file import URLFile, CACHE_DIR
+
+
+class TestFileDownload(unittest.TestCase):
+
+  def compare_loads(self, url, start=0, length=None):
+    """Compares a range read from the cached and non-cached versions"""
+    shutil.rmtree(CACHE_DIR, ignore_errors=True)  # start from an empty cache
+
+    file_cached = URLFile(url, cache=True)
+    file_downloaded = URLFile(url, cache=False)
+
+    file_cached.seek(start)
+    file_downloaded.seek(start)
+
+    self.assertEqual(file_cached.get_length(), file_downloaded.get_length())
+    self.assertLessEqual(length + start if length is not None else 0, file_downloaded.get_length())
+
+    response_cached = file_cached.read(ll=length)
+    response_downloaded = file_downloaded.read(ll=length)
+
+    self.assertEqual(response_cached, response_downloaded)
+
+    # Now test with the cache in place
+    file_cached = URLFile(url, cache=True)
+    file_cached.seek(start)
+    response_cached = file_cached.read(ll=length)
+
+    self.assertEqual(file_cached.get_length(), file_downloaded.get_length())
+    self.assertEqual(response_cached, response_downloaded)
+
+  def test_small_file(self):
+    # Make sure the environment does not force the cache on
+    os.environ["FILEREADER_CACHE"] = "0"
+    small_file_url = "https://raw.githubusercontent.com/commaai/openpilot/master/SAFETY.md"
+    # Use this URL if you want the large file to span more than one chunk
+    # large_file_url = "https://commadataci.blob.core.windows.net/openpilotci/0375fdf7b1ce594d/2019-06-13--08-32-25/3/fcamera.hevc"
+
+    # Load the full small file
+    self.compare_loads(small_file_url)
+
+    file_small = URLFile(small_file_url)
+    length = file_small.get_length()
+
+    self.compare_loads(small_file_url, length - 100, 100)
+    self.compare_loads(small_file_url, 50, 100)
+
+    # Load the small file 100 bytes at a time
+    for i in range(length // 100):
+      self.compare_loads(small_file_url, 100 * i, 100)
+
+  def test_large_file(self):
+    large_file_url = "https://commadataci.blob.core.windows.net/openpilotci/0375fdf7b1ce594d/2019-06-13--08-32-25/3/qlog.bz2"
+    # Load the last 100 bytes, then the full file
+    file_large = URLFile(large_file_url)
+    length = file_large.get_length()
+
+    self.compare_loads(large_file_url, length - 100, 100)
+    self.compare_loads(large_file_url)
+
+
+if __name__ == "__main__":
+  unittest.main()
diff --git a/tools/lib/tests/test_readers.py b/tools/lib/tests/test_readers.py
index cb6ba3e11a..1d8918ba54 100755
--- a/tools/lib/tests/test_readers.py
+++ b/tools/lib/tests/test_readers.py
@@ -8,7 +8,9 @@ import numpy as np
 from tools.lib.framereader import FrameReader
 from tools.lib.logreader import LogReader
 
+
 class TestReaders(unittest.TestCase):
+  @unittest.skip("skip for bandwidth reasons")
   def test_logreader(self):
     def _check_data(lr):
       hist = defaultdict(int)
@@ -29,6 +31,7 @@ class TestReaders(unittest.TestCase):
     lr_url = LogReader("https://github.com/commaai/comma2k19/blob/master/Example_1/b0c9d2329ad1606b%7C2018-08-02--08-34-47/40/raw_log.bz2?raw=true")
     _check_data(lr_url)
 
+  @unittest.skip("skip for bandwidth reasons")
   def test_framereader(self):
     def _check_data(f):
       self.assertEqual(f.frame_count, 1200)
diff --git a/tools/lib/url_file.py b/tools/lib/url_file.py
index 38637d90cc..e5f49d79a8 100644
--- a/tools/lib/url_file.py
+++ b/tools/lib/url_file.py
@@ -6,23 +6,41 @@ import tempfile
 import threading
 import urllib.parse
 import pycurl
+from hashlib import sha256
 from io import BytesIO
 from tenacity import retry, wait_random_exponential, stop_after_attempt
+from tools.lib.file_helpers import mkdirs_exists_ok, atomic_write_in_dir
+
+# Cache chunk size
+K = 1000
+CHUNK_SIZE = 1000 * K
+
+CACHE_DIR = "/tmp/comma_download_cache/"
+
+
+def hash_256(link):
+  hsh = sha256(link.split("?")[0].encode('utf-8')).hexdigest()
+  return hsh
 
 class URLFile(object):
   _tlocal = threading.local()
 
-  def __init__(self, url, debug=False):
+  def __init__(self, url, debug=False, cache=None):
     self._url = url
     self._pos = 0
+    self._length = None
     self._local_file = None
     self._debug = debug
+    # Downloads are forced by default; set FILEREADER_CACHE=1 (or pass cache=True) to enable the cache
+    self._force_download = not int(os.environ.get("FILEREADER_CACHE", "0"))
+    if cache is not None:
+      self._force_download = not cache
 
     try:
       self._curl = self._tlocal.curl
     except AttributeError:
       self._curl = self._tlocal.curl = pycurl.Curl()
+    mkdirs_exists_ok(CACHE_DIR)
 
   def __enter__(self):
     return self
@@ -34,9 +52,70 @@ class URLFile(object):
     self._local_file = None
 
   @retry(wait=wait_random_exponential(multiplier=1, max=5), stop=stop_after_attempt(3), reraise=True)
+  def get_length_online(self):
+    c = self._curl
+    c.reset()
+    c.setopt(pycurl.NOSIGNAL, 1)
+    c.setopt(pycurl.TIMEOUT_MS, 500000)
+    c.setopt(pycurl.FOLLOWLOCATION, True)
+    c.setopt(pycurl.URL, self._url)
+    c.setopt(c.NOBODY, 1)
+    c.perform()
+    length = int(c.getinfo(c.CONTENT_LENGTH_DOWNLOAD))
+    c.reset()
+    return length
+
+  def get_length(self):
+    if self._length is not None:
+      return self._length
+    file_length_path = os.path.join(CACHE_DIR, hash_256(self._url) + "_length")
+    if os.path.exists(file_length_path) and not self._force_download:
+      with open(file_length_path, "r") as file_length:
+        content = file_length.read()
+        self._length = int(content)
+        return self._length
+
+    self._length = self.get_length_online()
+    if not self._force_download:
+      with atomic_write_in_dir(file_length_path, mode="w") as file_length:
+        file_length.write(str(self._length))
+    return self._length
+
   def read(self, ll=None):
+    if self._force_download:
+      return self.read_aux(ll=ll)
+
+    file_begin = self._pos
+    file_end = self._pos + ll if ll is not None else self.get_length()
+    # We have to align with the chunks we store; position is the beginning of the latest chunk that starts at or before our file position
+    position = (file_begin // CHUNK_SIZE) * CHUNK_SIZE
+    response = b""
+    while True:
+      self._pos = position
+      chunk_number = self._pos // CHUNK_SIZE
+      file_name = hash_256(self._url) + "_" + str(chunk_number)
+      full_path = os.path.join(CACHE_DIR, str(file_name))
+      data = None
+      # If we don't have the chunk on disk yet, download and cache it
+      if not os.path.exists(full_path):
+        data = self.read_aux(ll=CHUNK_SIZE)
+        with atomic_write_in_dir(full_path, mode="wb") as new_cached_file:
+          new_cached_file.write(data)
+      else:
+        with open(full_path, "rb") as cached_file:
+          data = cached_file.read()
+
+      response += data[max(0, file_begin - position):min(CHUNK_SIZE, file_end - position)]
+
+      position += CHUNK_SIZE
+      if position >= file_end:
+        self._pos = file_end
+        return response
+
+  @retry(wait=wait_random_exponential(multiplier=1, max=5), stop=stop_after_attempt(3), reraise=True)
+  def read_aux(self, ll=None):
     if ll is None:
-      trange = 'bytes=%d-' % self._pos
+      trange = 'bytes=%d-%d' % (self._pos, self.get_length() - 1)
     else:
       trange = 'bytes=%d-%d' % (self._pos, self._pos + ll - 1)
@@ -74,7 +153,7 @@ class URLFile(object):
 
     response_code = c.getinfo(pycurl.RESPONSE_CODE)
     if response_code == 416:  # Requested Range Not Satisfiable
-      return ""
+      raise Exception("Error, range out of bounds {} ({}): {}".format(response_code, self._url, repr(dats.getvalue())[:500]))
     if response_code != 206 and response_code != 200:
       raise Exception("Error {} ({}): {}".format(response_code, self._url, repr(dats.getvalue())[:500]))
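
The chunk-alignment arithmetic in read() above can be sanity-checked in isolation. Below is a small standalone re-implementation of the same slicing (a sketch for illustration only, not part of the patch), using an in-memory dict as a stand-in for the on-disk chunk files:

CHUNK_SIZE = 1000 * 1000

def cached_read(chunks, file_begin, file_end):
  # chunks maps a chunk's start offset to its (up to) CHUNK_SIZE bytes of data
  position = (file_begin // CHUNK_SIZE) * CHUNK_SIZE  # align to the chunk start
  response = b""
  while True:
    data = chunks[position]
    # Trim the head of the first chunk and the tail of the last one,
    # exactly as read() does with max()/min()
    response += data[max(0, file_begin - position):min(CHUNK_SIZE, file_end - position)]
    position += CHUNK_SIZE
    if position >= file_end:
      return response

blob = bytes(range(256)) * 8000  # 2,048,000 bytes of test data
chunks = {p: blob[p:p + CHUNK_SIZE] for p in range(0, len(blob), CHUNK_SIZE)}
# A 100-byte read straddling the first chunk boundary touches two chunks
assert cached_read(chunks, 999_950, 1_000_050) == blob[999_950:1_000_050]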
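
For reference, a minimal usage sketch of the new cache path. The URL below is a placeholder; any server that answers HTTP range requests behaves the same way:

#!/usr/bin/env python3
from tools.lib.url_file import URLFile

url = "https://example.com/some_route/qlog.bz2"  # hypothetical URL

# cache=True overrides FILEREADER_CACHE; chunks land in /tmp/comma_download_cache/
# named <sha256 of the URL without its query string>_<chunk number>, with the
# total length stored next to them in a *_length file
f = URLFile(url, cache=True)
print(f.get_length())   # HEAD request on the first call, read from disk afterwards

f.seek(1_000_000)       # the start of the second chunk
first = f.read(ll=100)  # downloads and caches the whole 1 MB chunk

f2 = URLFile(url, cache=True)
f2.seek(1_000_000)
assert f2.read(ll=100) == first  # this read is served entirely from the cache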