URLFile: add typing and internalize pool manager (#31466)

* URLFile: add typing and internalize pool manager

* cleanup
pull/31471/head
Greg Hogan 1 year ago committed by GitHub
parent 3cd0e5d43c
commit e59fe0014a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 57
      tools/lib/url_file.py

@@ -4,6 +4,7 @@ import socket
import time import time
from hashlib import sha256 from hashlib import sha256
from urllib3 import PoolManager, Retry from urllib3 import PoolManager, Retry
from urllib3.response import BaseHTTPResponse
from urllib3.util import Timeout from urllib3.util import Timeout
from openpilot.common.file_helpers import atomic_write_in_dir from openpilot.common.file_helpers import atomic_write_in_dir
@@ -14,7 +15,7 @@ CHUNK_SIZE = 1000 * K
logging.getLogger("urllib3").setLevel(logging.WARNING) logging.getLogger("urllib3").setLevel(logging.WARNING)
def hash_256(link): def hash_256(link: str) -> str:
hsh = str(sha256((link.split("?")[0]).encode('utf-8')).hexdigest()) hsh = str(sha256((link.split("?")[0]).encode('utf-8')).hexdigest())
return hsh return hsh
@@ -23,26 +24,26 @@ class URLFileException(Exception):
pass pass
def new_pool_manager() -> PoolManager: class URLFile:
socket_options = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),] _pool_manager: PoolManager|None = None
retries = Retry(total=5, backoff_factor=0.5, status_forcelist=[409, 429, 503, 504])
return PoolManager(num_pools=10, maxsize=100, socket_options=socket_options, retries=retries)
def set_pool_manager():
URLFile._pool_manager = new_pool_manager()
os.register_at_fork(after_in_child=set_pool_manager)
@staticmethod
def reset() -> None:
URLFile._pool_manager = None
class URLFile: @staticmethod
_pool_manager = new_pool_manager() def pool_manager() -> PoolManager:
if URLFile._pool_manager is None:
socket_options = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),]
retries = Retry(total=5, backoff_factor=0.5, status_forcelist=[409, 429, 503, 504])
URLFile._pool_manager = PoolManager(num_pools=10, maxsize=100, socket_options=socket_options, retries=retries)
return URLFile._pool_manager
def __init__(self, url, timeout=10, debug=False, cache=None): def __init__(self, url: str, timeout: int=10, debug: bool=False, cache: bool|None=None):
self._url = url self._url = url
self._timeout = Timeout(connect=timeout, read=timeout) self._timeout = Timeout(connect=timeout, read=timeout)
self._pos = 0 self._pos = 0
self._length = None self._length: int|None = None
self._local_file = None
self._debug = debug self._debug = debug
# True by default, false if FILEREADER_CACHE is defined, but can be overwritten by the cache input # True by default, false if FILEREADER_CACHE is defined, but can be overwritten by the cache input
self._force_download = not int(os.environ.get("FILEREADER_CACHE", "0")) self._force_download = not int(os.environ.get("FILEREADER_CACHE", "0"))
@@ -55,23 +56,20 @@ class URLFile:
def __enter__(self): def __enter__(self):
return self return self
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback) -> None:
if self._local_file is not None: pass
os.remove(self._local_file.name)
self._local_file.close()
self._local_file = None
def _request(self, method, url, headers=None): def _request(self, method: str, url: str, headers: dict[str, str]|None=None) -> BaseHTTPResponse:
return URLFile._pool_manager.request(method, url, timeout=self._timeout, headers=headers) return URLFile.pool_manager().request(method, url, timeout=self._timeout, headers=headers)
def get_length_online(self): def get_length_online(self) -> int:
response = self._request('HEAD', self._url) response = self._request('HEAD', self._url)
if not (200 <= response.status <= 299): if not (200 <= response.status <= 299):
return -1 return -1
length = response.headers.get('content-length', 0) length = response.headers.get('content-length', 0)
return int(length) return int(length)
def get_length(self): def get_length(self) -> int:
if self._length is not None: if self._length is not None:
return self._length return self._length
@@ -88,7 +86,7 @@ class URLFile:
file_length.write(str(self._length)) file_length.write(str(self._length))
return self._length return self._length
def read(self, ll=None): def read(self, ll: int|None=None) -> bytes:
if self._force_download: if self._force_download:
return self.read_aux(ll=ll) return self.read_aux(ll=ll)
@@ -120,7 +118,7 @@ class URLFile:
self._pos = file_end self._pos = file_end
return response return response
def read_aux(self, ll=None): def read_aux(self, ll: int|None=None) -> bytes:
download_range = False download_range = False
headers = {} headers = {}
if self._pos != 0 or ll is not None: if self._pos != 0 or ll is not None:
@@ -155,9 +153,12 @@ class URLFile:
self._pos += len(ret) self._pos += len(ret)
return ret return ret
def seek(self, pos): def seek(self, pos:int) -> None:
self._pos = pos self._pos = pos
@property @property
def name(self): def name(self) -> str:
return self._url return self._url
os.register_at_fork(after_in_child=URLFile.reset)

Loading…
Cancel
Save