url_file: fix non-200 files being cached (#30886)

* fix + test

* fix unclosed

* easier to read

Co-authored-by: Shane Smiskol <shane@smiskol.com>

* fix that

---------

Co-authored-by: Shane Smiskol <shane@smiskol.com>
pull/30888/head
Justin Newberry 1 year ago committed by GitHub
parent 9d7f618bc5
commit fba521ecc6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 79
      tools/lib/tests/test_caching.py
  2. 2
      tools/lib/url_file.py

@ -1,15 +1,58 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from functools import wraps
import http.server
import os import os
import threading
import time
import unittest import unittest
from pathlib import Path
from parameterized import parameterized from parameterized import parameterized
from unittest import mock
from openpilot.system.hardware.hw import Paths
from openpilot.tools.lib.url_file import URLFile from openpilot.tools.lib.url_file import URLFile
class CachingTestRequestHandler(http.server.BaseHTTPRequestHandler):
FILE_EXISTS = True
def do_GET(self):
if self.FILE_EXISTS:
self.send_response(200, b'1234')
else:
self.send_response(404)
self.end_headers()
def do_HEAD(self):
if self.FILE_EXISTS:
self.send_response(200)
self.send_header("Content-Length", "4")
else:
self.send_response(404)
self.end_headers()
class CachingTestServer(threading.Thread):
def run(self):
self.server = http.server.HTTPServer(("127.0.0.1", 0), CachingTestRequestHandler)
self.port = self.server.server_port
self.server.serve_forever()
def stop(self):
self.server.server_close()
self.server.shutdown()
def with_caching_server(func):
@wraps(func)
def wrapper(*args, **kwargs):
server = CachingTestServer()
server.start()
time.sleep(0.25) # wait for server to get it's port
try:
func(*args, **kwargs, port=server.port)
finally:
server.stop()
return wrapper
class TestFileDownload(unittest.TestCase): class TestFileDownload(unittest.TestCase):
def compare_loads(self, url, start=0, length=None): def compare_loads(self, url, start=0, length=None):
@ -66,32 +109,20 @@ class TestFileDownload(unittest.TestCase):
self.compare_loads(large_file_url) self.compare_loads(large_file_url)
@parameterized.expand([(True, ), (False, )]) @parameterized.expand([(True, ), (False, )])
def test_recover_from_missing_file(self, cache_enabled): @with_caching_server
def test_recover_from_missing_file(self, cache_enabled, port):
os.environ["FILEREADER_CACHE"] = "1" if cache_enabled else "0" os.environ["FILEREADER_CACHE"] = "1" if cache_enabled else "0"
file_url = "http://localhost:5001/test.png" file_url = f"http://localhost:{port}/test.png"
file_exists = False CachingTestRequestHandler.FILE_EXISTS = False
length = URLFile(file_url).get_length()
self.assertEqual(length, -1)
def get_length_online_mock(self): CachingTestRequestHandler.FILE_EXISTS = True
if file_exists: length = URLFile(file_url).get_length()
return 4 self.assertEqual(length, 4)
return -1
patch_length = mock.patch.object(URLFile, "get_length_online", get_length_online_mock)
patch_length.start()
try:
length = URLFile(file_url).get_length()
self.assertEqual(length, -1)
file_exists = True
length = URLFile(file_url).get_length()
self.assertEqual(length, 4)
finally:
tempfile_length = Path(Paths.download_cache_root()) / "ba2119904385654cb0105a2da174875f8e7648db175f202ecae6d6428b0e838f_length"
if tempfile_length.exists():
tempfile_length.unlink()
patch_length.stop()
if __name__ == "__main__": if __name__ == "__main__":

@ -57,6 +57,8 @@ class URLFile:
def get_length_online(self): def get_length_online(self):
timeout = Timeout(connect=50.0, read=500.0) timeout = Timeout(connect=50.0, read=500.0)
response = self._http_client.request('HEAD', self._url, timeout=timeout, preload_content=False) response = self._http_client.request('HEAD', self._url, timeout=timeout, preload_content=False)
if not (200 <= response.status <= 299):
return -1
length = response.headers.get('content-length', 0) length = response.headers.get('content-length', 0)
return int(length) return int(length)

Loading…
Cancel
Save