url_file: fix non-200 files being cached (#30886)

* fix + test

* fix unclosed

* easier to read

Co-authored-by: Shane Smiskol <shane@smiskol.com>

* fix that

---------

Co-authored-by: Shane Smiskol <shane@smiskol.com>
pull/30888/head
Justin Newberry 1 year ago committed by GitHub
parent 9d7f618bc5
commit fba521ecc6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 73
      tools/lib/tests/test_caching.py
  2. 2
      tools/lib/url_file.py

@ -1,15 +1,58 @@
#!/usr/bin/env python3
from functools import wraps
import http.server
import os
import threading
import time
import unittest
from pathlib import Path
from parameterized import parameterized
from unittest import mock
from openpilot.system.hardware.hw import Paths
from openpilot.tools.lib.url_file import URLFile
class CachingTestRequestHandler(http.server.BaseHTTPRequestHandler):
FILE_EXISTS = True
def do_GET(self):
if self.FILE_EXISTS:
self.send_response(200, b'1234')
else:
self.send_response(404)
self.end_headers()
def do_HEAD(self):
if self.FILE_EXISTS:
self.send_response(200)
self.send_header("Content-Length", "4")
else:
self.send_response(404)
self.end_headers()
class CachingTestServer(threading.Thread):
def run(self):
self.server = http.server.HTTPServer(("127.0.0.1", 0), CachingTestRequestHandler)
self.port = self.server.server_port
self.server.serve_forever()
def stop(self):
self.server.server_close()
self.server.shutdown()
def with_caching_server(func):
@wraps(func)
def wrapper(*args, **kwargs):
server = CachingTestServer()
server.start()
time.sleep(0.25) # wait for server to get it's port
try:
func(*args, **kwargs, port=server.port)
finally:
server.stop()
return wrapper
class TestFileDownload(unittest.TestCase):
def compare_loads(self, url, start=0, length=None):
@ -66,32 +109,20 @@ class TestFileDownload(unittest.TestCase):
self.compare_loads(large_file_url)
@parameterized.expand([(True, ), (False, )])
def test_recover_from_missing_file(self, cache_enabled):
@with_caching_server
def test_recover_from_missing_file(self, cache_enabled, port):
os.environ["FILEREADER_CACHE"] = "1" if cache_enabled else "0"
file_url = "http://localhost:5001/test.png"
file_exists = False
file_url = f"http://localhost:{port}/test.png"
def get_length_online_mock(self):
if file_exists:
return 4
return -1
patch_length = mock.patch.object(URLFile, "get_length_online", get_length_online_mock)
patch_length.start()
try:
CachingTestRequestHandler.FILE_EXISTS = False
length = URLFile(file_url).get_length()
self.assertEqual(length, -1)
file_exists = True
CachingTestRequestHandler.FILE_EXISTS = True
length = URLFile(file_url).get_length()
self.assertEqual(length, 4)
finally:
tempfile_length = Path(Paths.download_cache_root()) / "ba2119904385654cb0105a2da174875f8e7648db175f202ecae6d6428b0e838f_length"
if tempfile_length.exists():
tempfile_length.unlink()
patch_length.stop()
if __name__ == "__main__":

@ -57,6 +57,8 @@ class URLFile:
def get_length_online(self):
timeout = Timeout(connect=50.0, read=500.0)
response = self._http_client.request('HEAD', self._url, timeout=timeout, preload_content=False)
if not (200 <= response.status <= 299):
return -1
length = response.headers.get('content-length', 0)
return int(length)

Loading…
Cancel
Save