url_file: fix non-200 files being cached (#30886)

* fix + test

* fix unclosed

* easier to read

Co-authored-by: Shane Smiskol <shane@smiskol.com>

* fix that

---------

Co-authored-by: Shane Smiskol <shane@smiskol.com>
old-commit-hash: fba521ecc6
chrysler-long2
Justin Newberry 1 year ago committed by GitHub
parent f82d7f453f
commit 7aecd2f91d
  1. 79
      tools/lib/tests/test_caching.py
  2. 2
      tools/lib/url_file.py

@ -1,15 +1,58 @@
#!/usr/bin/env python3
from functools import wraps
import http.server
import os
import threading
import time
import unittest
from pathlib import Path
from parameterized import parameterized
from unittest import mock
from openpilot.system.hardware.hw import Paths
from openpilot.tools.lib.url_file import URLFile
class CachingTestRequestHandler(http.server.BaseHTTPRequestHandler):
FILE_EXISTS = True
def do_GET(self):
if self.FILE_EXISTS:
self.send_response(200, b'1234')
else:
self.send_response(404)
self.end_headers()
def do_HEAD(self):
if self.FILE_EXISTS:
self.send_response(200)
self.send_header("Content-Length", "4")
else:
self.send_response(404)
self.end_headers()
class CachingTestServer(threading.Thread):
def run(self):
self.server = http.server.HTTPServer(("127.0.0.1", 0), CachingTestRequestHandler)
self.port = self.server.server_port
self.server.serve_forever()
def stop(self):
self.server.server_close()
self.server.shutdown()
def with_caching_server(func):
@wraps(func)
def wrapper(*args, **kwargs):
server = CachingTestServer()
server.start()
time.sleep(0.25) # wait for server to get it's port
try:
func(*args, **kwargs, port=server.port)
finally:
server.stop()
return wrapper
class TestFileDownload(unittest.TestCase):
def compare_loads(self, url, start=0, length=None):
@ -66,32 +109,20 @@ class TestFileDownload(unittest.TestCase):
self.compare_loads(large_file_url)
@parameterized.expand([(True, ), (False, )])
def test_recover_from_missing_file(self, cache_enabled):
@with_caching_server
def test_recover_from_missing_file(self, cache_enabled, port):
os.environ["FILEREADER_CACHE"] = "1" if cache_enabled else "0"
file_url = "http://localhost:5001/test.png"
file_url = f"http://localhost:{port}/test.png"
file_exists = False
CachingTestRequestHandler.FILE_EXISTS = False
length = URLFile(file_url).get_length()
self.assertEqual(length, -1)
def get_length_online_mock(self):
if file_exists:
return 4
return -1
CachingTestRequestHandler.FILE_EXISTS = True
length = URLFile(file_url).get_length()
self.assertEqual(length, 4)
patch_length = mock.patch.object(URLFile, "get_length_online", get_length_online_mock)
patch_length.start()
try:
length = URLFile(file_url).get_length()
self.assertEqual(length, -1)
file_exists = True
length = URLFile(file_url).get_length()
self.assertEqual(length, 4)
finally:
tempfile_length = Path(Paths.download_cache_root()) / "ba2119904385654cb0105a2da174875f8e7648db175f202ecae6d6428b0e838f_length"
if tempfile_length.exists():
tempfile_length.unlink()
patch_length.stop()
if __name__ == "__main__":

@ -57,6 +57,8 @@ class URLFile:
def get_length_online(self):
timeout = Timeout(connect=50.0, read=500.0)
response = self._http_client.request('HEAD', self._url, timeout=timeout, preload_content=False)
if not (200 <= response.status <= 299):
return -1
length = response.headers.get('content-length', 0)
return int(length)

Loading…
Cancel
Save