|  |  |  | #!/usr/bin/env python3
 | 
					
						
							|  |  |  | from functools import wraps
 | 
					
						
							|  |  |  | import http.server
 | 
					
						
							|  |  |  | import os
 | 
					
						
							|  |  |  | import threading
 | 
					
						
							|  |  |  | import time
 | 
					
						
							|  |  |  | import unittest
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from parameterized import parameterized
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from openpilot.tools.lib.url_file import URLFile
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class CachingTestRequestHandler(http.server.BaseHTTPRequestHandler):
 | 
					
						
							|  |  |  |   FILE_EXISTS = True
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   def do_GET(self):
 | 
					
						
							|  |  |  |     if self.FILE_EXISTS:
 | 
					
						
							|  |  |  |       self.send_response(200, b'1234')
 | 
					
						
							|  |  |  |     else:
 | 
					
						
							|  |  |  |       self.send_response(404)
 | 
					
						
							|  |  |  |     self.end_headers()
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   def do_HEAD(self):
 | 
					
						
							|  |  |  |     if self.FILE_EXISTS:
 | 
					
						
							|  |  |  |       self.send_response(200)
 | 
					
						
							|  |  |  |       self.send_header("Content-Length", "4")
 | 
					
						
							|  |  |  |     else:
 | 
					
						
							|  |  |  |       self.send_response(404)
 | 
					
						
							|  |  |  |     self.end_headers()
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class CachingTestServer(threading.Thread):
 | 
					
						
							|  |  |  |   def run(self):
 | 
					
						
							|  |  |  |     self.server = http.server.HTTPServer(("127.0.0.1", 0), CachingTestRequestHandler)
 | 
					
						
							|  |  |  |     self.port = self.server.server_port
 | 
					
						
							|  |  |  |     self.server.serve_forever()
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   def stop(self):
 | 
					
						
							|  |  |  |     self.server.server_close()
 | 
					
						
							|  |  |  |     self.server.shutdown()
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def with_caching_server(func):
 | 
					
						
							|  |  |  |   @wraps(func)
 | 
					
						
							|  |  |  |   def wrapper(*args, **kwargs):
 | 
					
						
							|  |  |  |     server = CachingTestServer()
 | 
					
						
							|  |  |  |     server.start()
 | 
					
						
							|  |  |  |     time.sleep(0.25) # wait for server to get it's port
 | 
					
						
							|  |  |  |     try:
 | 
					
						
							|  |  |  |       func(*args, **kwargs, port=server.port)
 | 
					
						
							|  |  |  |     finally:
 | 
					
						
							|  |  |  |       server.stop()
 | 
					
						
							|  |  |  |   return wrapper
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class TestFileDownload(unittest.TestCase):
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   def compare_loads(self, url, start=0, length=None):
 | 
					
						
							|  |  |  |     """Compares range between cached and non cached version"""
 | 
					
						
							|  |  |  |     file_cached = URLFile(url, cache=True)
 | 
					
						
							|  |  |  |     file_downloaded = URLFile(url, cache=False)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     file_cached.seek(start)
 | 
					
						
							|  |  |  |     file_downloaded.seek(start)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     self.assertEqual(file_cached.get_length(), file_downloaded.get_length())
 | 
					
						
							|  |  |  |     self.assertLessEqual(length + start if length is not None else 0, file_downloaded.get_length())
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     response_cached = file_cached.read(ll=length)
 | 
					
						
							|  |  |  |     response_downloaded = file_downloaded.read(ll=length)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     self.assertEqual(response_cached, response_downloaded)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Now test with cache in place
 | 
					
						
							|  |  |  |     file_cached = URLFile(url, cache=True)
 | 
					
						
							|  |  |  |     file_cached.seek(start)
 | 
					
						
							|  |  |  |     response_cached = file_cached.read(ll=length)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     self.assertEqual(file_cached.get_length(), file_downloaded.get_length())
 | 
					
						
							|  |  |  |     self.assertEqual(response_cached, response_downloaded)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   def test_small_file(self):
 | 
					
						
							|  |  |  |     # Make sure we don't force cache
 | 
					
						
							|  |  |  |     os.environ["FILEREADER_CACHE"] = "0"
 | 
					
						
							|  |  |  |     small_file_url = "https://raw.githubusercontent.com/commaai/openpilot/master/docs/SAFETY.md"
 | 
					
						
							|  |  |  |     #  If you want large file to be larger than a chunk
 | 
					
						
							|  |  |  |     #  large_file_url = "https://commadataci.blob.core.windows.net/openpilotci/0375fdf7b1ce594d/2019-06-13--08-32-25/3/fcamera.hevc"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     #  Load full small file
 | 
					
						
							|  |  |  |     self.compare_loads(small_file_url)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     file_small = URLFile(small_file_url)
 | 
					
						
							|  |  |  |     length = file_small.get_length()
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     self.compare_loads(small_file_url, length - 100, 100)
 | 
					
						
							|  |  |  |     self.compare_loads(small_file_url, 50, 100)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     #  Load small file 100 bytes at a time
 | 
					
						
							|  |  |  |     for i in range(length // 100):
 | 
					
						
							|  |  |  |       self.compare_loads(small_file_url, 100 * i, 100)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   def test_large_file(self):
 | 
					
						
							|  |  |  |     large_file_url = "https://commadataci.blob.core.windows.net/openpilotci/0375fdf7b1ce594d/2019-06-13--08-32-25/3/qlog.bz2"
 | 
					
						
							|  |  |  |     #  Load the end 100 bytes of both files
 | 
					
						
							|  |  |  |     file_large = URLFile(large_file_url)
 | 
					
						
							|  |  |  |     length = file_large.get_length()
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     self.compare_loads(large_file_url, length - 100, 100)
 | 
					
						
							|  |  |  |     self.compare_loads(large_file_url)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   @parameterized.expand([(True, ), (False, )])
 | 
					
						
							|  |  |  |   @with_caching_server
 | 
					
						
							|  |  |  |   def test_recover_from_missing_file(self, cache_enabled, port):
 | 
					
						
							|  |  |  |     os.environ["FILEREADER_CACHE"] = "1" if cache_enabled else "0"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     file_url = f"http://localhost:{port}/test.png"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     CachingTestRequestHandler.FILE_EXISTS = False
 | 
					
						
							|  |  |  |     length = URLFile(file_url).get_length()
 | 
					
						
							|  |  |  |     self.assertEqual(length, -1)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     CachingTestRequestHandler.FILE_EXISTS = True
 | 
					
						
							|  |  |  |     length = URLFile(file_url).get_length()
 | 
					
						
							|  |  |  |     self.assertEqual(length, 4)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == "__main__":
 | 
					
						
							|  |  |  |   unittest.main()
 |