diff --git a/tools/lib/url_file.py b/tools/lib/url_file.py index 5c6f187eee..be9c815c93 100644 --- a/tools/lib/url_file.py +++ b/tools/lib/url_file.py @@ -6,6 +6,7 @@ from hashlib import sha256 from urllib3 import PoolManager from urllib3.util import Timeout from tenacity import retry, wait_random_exponential, stop_after_attempt +from typing import Optional from openpilot.common.file_helpers import atomic_write_in_dir from openpilot.system.hardware.hw import Paths @@ -25,9 +26,12 @@ class URLFileException(Exception): class URLFile: - _tlocal = threading.local() + _pid: Optional[int] = None + _pool_manager: Optional[PoolManager] = None + _pool_manager_lock = threading.Lock() def __init__(self, url, debug=False, cache=None): + self._pool_manager = None self._url = url self._pos = 0 self._length = None @@ -41,11 +45,6 @@ class URLFile: if not self._force_download: os.makedirs(Paths.download_cache_root(), exist_ok=True) - try: - self._http_client = URLFile._tlocal.http_client - except AttributeError: - self._http_client = URLFile._tlocal.http_client = PoolManager() - def __enter__(self): return self @@ -55,10 +54,20 @@ class URLFile: self._local_file.close() self._local_file = None + def _http_client(self) -> PoolManager: + if self._pool_manager is None: + pid = os.getpid() + with URLFile._pool_manager_lock: + if URLFile._pid != pid or URLFile._pool_manager is None: # unsafe to share after fork + URLFile._pid = pid + URLFile._pool_manager = PoolManager(num_pools=10, maxsize=10) + self._pool_manager = URLFile._pool_manager + return self._pool_manager + @retry(wait=wait_random_exponential(multiplier=1, max=5), stop=stop_after_attempt(3), reraise=True) def get_length_online(self): timeout = Timeout(connect=50.0, read=500.0) - response = self._http_client.request('HEAD', self._url, timeout=timeout, preload_content=False) + response = self._http_client().request('HEAD', self._url, timeout=timeout, preload_content=False) if not (200 <= response.status <= 299): return -1 length = response.headers.get('content-length', 0) @@ -131,7 +140,7 @@ class URLFile: t1 = time.time() timeout = Timeout(connect=50.0, read=500.0) - response = self._http_client.request('GET', self._url, timeout=timeout, preload_content=False, headers=headers) + response = self._http_client().request('GET', self._url, timeout=timeout, preload_content=False, headers=headers) ret = response.data if self._debug: