diff --git a/selfdrive/athena/athenad.py b/selfdrive/athena/athenad.py index 5b351ca0f5..d1cc4cea83 100755 --- a/selfdrive/athena/athenad.py +++ b/selfdrive/athena/athenad.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +from __future__ import annotations + import base64 import bz2 import hashlib @@ -14,14 +16,15 @@ import sys import tempfile import threading import time -from collections import namedtuple +from dataclasses import asdict, dataclass, replace from datetime import datetime from functools import partial -from typing import Any, Dict +from queue import Queue +from typing import BinaryIO, Callable, Dict, List, Optional, Set, Union, cast import requests from jsonrpc import JSONRPCResponseManager, dispatcher -from websocket import (ABNF, WebSocketException, WebSocketTimeoutException, +from websocket import (ABNF, WebSocket, WebSocketException, WebSocketTimeoutException, create_connection) import cereal.messaging as messaging @@ -54,19 +57,54 @@ WS_FRAME_SIZE = 4096 NetworkType = log.DeviceState.NetworkType +UploadFileDict = Dict[str, Union[str, int, float, bool]] +UploadItemDict = Dict[str, Union[str, bool, int, float, Dict[str, str]]] + +UploadFilesToUrlResponse = Dict[str, Union[int, List[UploadItemDict], List[str]]] + + +@dataclass +class UploadFile: + fn: str + url: str + headers: Dict[str, str] + allow_cellular: bool + + @classmethod + def from_dict(cls, d: Dict) -> UploadFile: + return cls(d.get("fn", ""), d.get("url", ""), d.get("headers", {}), d.get("allow_cellular", False)) + + +@dataclass +class UploadItem: + path: str + url: str + headers: Dict[str, str] + created_at: int + id: Optional[str] + retry_count: int = 0 + current: bool = False + progress: float = 0 + allow_cellular: bool = False + + @classmethod + def from_dict(cls, d: Dict) -> UploadItem: + return cls(d["path"], d["url"], d["headers"], d["created_at"], d["id"], d["retry_count"], d["current"], + d["progress"], d["allow_cellular"]) + + dispatcher["echo"] = lambda s: s -recv_queue: Any = queue.Queue() -send_queue: Any = queue.Queue() -upload_queue: Any = queue.Queue() -low_priority_send_queue: Any = queue.Queue() -log_recv_queue: Any = queue.Queue() -cancelled_uploads: Any = set() -UploadItem = namedtuple('UploadItem', ['path', 'url', 'headers', 'created_at', 'id', 'retry_count', 'current', 'progress', 'allow_cellular'], defaults=(0, False, 0, False)) +recv_queue: Queue[str] = queue.Queue() +send_queue: Queue[str] = queue.Queue() +upload_queue: Queue[UploadItem] = queue.Queue() +low_priority_send_queue: Queue[str] = queue.Queue() +log_recv_queue: Queue[str] = queue.Queue() +cancelled_uploads: Set[str] = set() -cur_upload_items: Dict[int, Any] = {} +cur_upload_items: Dict[int, Optional[UploadItem]] = {} -def strip_bz2_extension(fn): +def strip_bz2_extension(fn: str) -> str: if fn.endswith('.bz2'): return fn[:-4] return fn @@ -76,29 +114,30 @@ class AbortTransferException(Exception): pass -class UploadQueueCache(): +class UploadQueueCache: params = Params() @staticmethod - def initialize(upload_queue): + def initialize(upload_queue: Queue[UploadItem]) -> None: try: upload_queue_json = UploadQueueCache.params.get("AthenadUploadQueue") if upload_queue_json is not None: for item in json.loads(upload_queue_json): - upload_queue.put(UploadItem(**item)) + upload_queue.put(UploadItem.from_dict(item)) except Exception: cloudlog.exception("athena.UploadQueueCache.initialize.exception") @staticmethod - def cache(upload_queue): + def cache(upload_queue: Queue[UploadItem]) -> None: try: - items = [i._asdict() for i in upload_queue.queue if i.id not in cancelled_uploads] + queue: List[Optional[UploadItem]] = list(upload_queue.queue) + items = [asdict(i) for i in queue if i is not None and (i.id not in cancelled_uploads)] UploadQueueCache.params.put("AthenadUploadQueue", json.dumps(items)) except Exception: cloudlog.exception("athena.UploadQueueCache.cache.exception") -def handle_long_poll(ws): +def handle_long_poll(ws: WebSocket) -> None: end_event = threading.Event() threads = [ @@ -126,7 +165,7 @@ def handle_long_poll(ws): thread.join() -def jsonrpc_handler(end_event): +def jsonrpc_handler(end_event: threading.Event) -> None: dispatcher["startLocalProxy"] = partial(startLocalProxy, end_event) while not end_event.is_set(): try: @@ -147,11 +186,12 @@ def jsonrpc_handler(end_event): def retry_upload(tid: int, end_event: threading.Event, increase_count: bool = True) -> None: - if cur_upload_items[tid].retry_count < MAX_RETRY_COUNT: - item = cur_upload_items[tid] + item = cur_upload_items[tid] + if item is not None and item.retry_count < MAX_RETRY_COUNT: new_retry_count = item.retry_count + 1 if increase_count else item.retry_count - item = item._replace( + item = replace( + item, retry_count=new_retry_count, progress=0, current=False @@ -175,44 +215,44 @@ def upload_handler(end_event: threading.Event) -> None: cur_upload_items[tid] = None try: - cur_upload_items[tid] = upload_queue.get(timeout=1)._replace(current=True) + cur_upload_items[tid] = item = replace(upload_queue.get(timeout=1), current=True) - if cur_upload_items[tid].id in cancelled_uploads: - cancelled_uploads.remove(cur_upload_items[tid].id) + if item.id in cancelled_uploads: + cancelled_uploads.remove(item.id) continue # Remove item if too old - age = datetime.now() - datetime.fromtimestamp(cur_upload_items[tid].created_at / 1000) + age = datetime.now() - datetime.fromtimestamp(item.created_at / 1000) if age.total_seconds() > MAX_AGE: - cloudlog.event("athena.upload_handler.expired", item=cur_upload_items[tid], error=True) + cloudlog.event("athena.upload_handler.expired", item=item, error=True) continue # Check if uploading over metered connection is allowed sm.update(0) metered = sm['deviceState'].networkMetered network_type = sm['deviceState'].networkType.raw - if metered and (not cur_upload_items[tid].allow_cellular): + if metered and (not item.allow_cellular): retry_upload(tid, end_event, False) continue try: - def cb(sz, cur): + def cb(sz: int, cur: int) -> None: # Abort transfer if connection changed to metered after starting upload sm.update(0) metered = sm['deviceState'].networkMetered - if metered and (not cur_upload_items[tid].allow_cellular): + if metered and (not item.allow_cellular): raise AbortTransferException - cur_upload_items[tid] = cur_upload_items[tid]._replace(progress=cur / sz if sz else 1) + cur_upload_items[tid] = replace(item, progress=cur / sz if sz else 1) - fn = cur_upload_items[tid].path + fn = item.path try: sz = os.path.getsize(fn) except OSError: sz = -1 - cloudlog.event("athena.upload_handler.upload_start", fn=fn, sz=sz, network_type=network_type, metered=metered, retry_count=cur_upload_items[tid].retry_count) - response = _do_upload(cur_upload_items[tid], cb) + cloudlog.event("athena.upload_handler.upload_start", fn=fn, sz=sz, network_type=network_type, metered=metered, retry_count=item.retry_count) + response = _do_upload(item, cb) if response.status_code not in (200, 201, 401, 403, 412): cloudlog.event("athena.upload_handler.retry", status_code=response.status_code, fn=fn, sz=sz, network_type=network_type, metered=metered) @@ -234,7 +274,7 @@ def upload_handler(end_event: threading.Event) -> None: cloudlog.exception("athena.upload_handler.exception") -def _do_upload(upload_item, callback=None): +def _do_upload(upload_item: UploadItem, callback: Optional[Callable] = None) -> requests.Response: path = upload_item.path compress = False @@ -244,27 +284,25 @@ def _do_upload(upload_item, callback=None): compress = True with open(path, "rb") as f: + data: BinaryIO if compress: cloudlog.event("athena.upload_handler.compress", fn=path, fn_orig=upload_item.path) - data = bz2.compress(f.read()) - size = len(data) - data = io.BytesIO(data) + compressed = bz2.compress(f.read()) + size = len(compressed) + data = io.BytesIO(compressed) else: size = os.fstat(f.fileno()).st_size data = f - if callback: - data = CallbackReader(data, callback, size) - return requests.put(upload_item.url, - data=data, + data=CallbackReader(data, callback, size) if callback else data, headers={**upload_item.headers, 'Content-Length': str(size)}, timeout=30) # security: user should be able to request any message from their car @dispatcher.add_method -def getMessage(service=None, timeout=1000): +def getMessage(service: str, timeout: int = 1000) -> Dict: if service is None or service not in service_list: raise Exception("invalid service") @@ -274,7 +312,8 @@ def getMessage(service=None, timeout=1000): if ret is None: raise TimeoutError - return ret.to_dict() + # this is because capnp._DynamicStructReader doesn't have typing information + return cast(Dict, ret.to_dict()) @dispatcher.add_method @@ -288,7 +327,7 @@ def getVersion() -> Dict[str, str]: @dispatcher.add_method -def setNavDestination(latitude=0, longitude=0, place_name=None, place_details=None): +def setNavDestination(latitude: int = 0, longitude: int = 0, place_name: Optional[str] = None, place_details: Optional[str] = None) -> Dict[str, int]: destination = { "latitude": latitude, "longitude": longitude, @@ -300,8 +339,8 @@ def setNavDestination(latitude=0, longitude=0, place_name=None, place_details=No return {"success": 1} -def scan_dir(path, prefix): - files = list() +def scan_dir(path: str, prefix: str) -> List[str]: + files = [] # only walk directories that match the prefix # (glob and friends traverse entire dir tree) with os.scandir(path) as i: @@ -320,18 +359,18 @@ def scan_dir(path, prefix): return files @dispatcher.add_method -def listDataDirectory(prefix=''): +def listDataDirectory(prefix='') -> List[str]: return scan_dir(ROOT, prefix) @dispatcher.add_method -def reboot(): +def reboot() -> Dict[str, int]: sock = messaging.sub_sock("deviceState", timeout=1000) ret = messaging.recv_one(sock) if ret is None or ret.deviceState.started: raise Exception("Reboot unavailable") - def do_reboot(): + def do_reboot() -> None: time.sleep(2) HARDWARE.reboot() @@ -341,50 +380,53 @@ def reboot(): @dispatcher.add_method -def uploadFileToUrl(fn, url, headers): - return uploadFilesToUrls([{ +def uploadFileToUrl(fn: str, url: str, headers: Dict[str, str]) -> UploadFilesToUrlResponse: + # this is because mypy doesn't understand that the decorator doesn't change the return type + response: UploadFilesToUrlResponse = uploadFilesToUrls([{ "fn": fn, "url": url, "headers": headers, }]) + return response @dispatcher.add_method -def uploadFilesToUrls(files_data): - items = [] - failed = [] - for file in files_data: - fn = file.get('fn', '') - if len(fn) == 0 or fn[0] == '/' or '..' in fn or 'url' not in file: - failed.append(fn) +def uploadFilesToUrls(files_data: List[UploadFileDict]) -> UploadFilesToUrlResponse: + files = map(UploadFile.from_dict, files_data) + + items: List[UploadItemDict] = [] + failed: List[str] = [] + for file in files: + if len(file.fn) == 0 or file.fn[0] == '/' or '..' in file.fn or len(file.url) == 0: + failed.append(file.fn) continue - path = os.path.join(ROOT, fn) + path = os.path.join(ROOT, file.fn) if not os.path.exists(path) and not os.path.exists(strip_bz2_extension(path)): - failed.append(fn) + failed.append(file.fn) continue # Skip item if already in queue - url = file['url'].split('?')[0] + url = file.url.split('?')[0] if any(url == item['url'].split('?')[0] for item in listUploadQueue()): continue item = UploadItem( path=path, - url=file['url'], - headers=file.get('headers', {}), + url=file.url, + headers=file.headers, created_at=int(time.time() * 1000), id=None, - allow_cellular=file.get('allow_cellular', False), + allow_cellular=file.allow_cellular, ) upload_id = hashlib.sha1(str(item).encode()).hexdigest() - item = item._replace(id=upload_id) + item = replace(item, id=upload_id) upload_queue.put_nowait(item) - items.append(item._asdict()) + items.append(asdict(item)) UploadQueueCache.cache(upload_queue) - resp = {"enqueued": len(items), "items": items} + resp: UploadFilesToUrlResponse = {"enqueued": len(items), "items": items} if failed: resp["failed"] = failed @@ -392,32 +434,32 @@ def uploadFilesToUrls(files_data): @dispatcher.add_method -def listUploadQueue(): +def listUploadQueue() -> List[UploadItemDict]: items = list(upload_queue.queue) + list(cur_upload_items.values()) - return [i._asdict() for i in items if (i is not None) and (i.id not in cancelled_uploads)] + return [asdict(i) for i in items if (i is not None) and (i.id not in cancelled_uploads)] @dispatcher.add_method -def cancelUpload(upload_id): +def cancelUpload(upload_id: Union[str, List[str]]) -> Dict[str, Union[int, str]]: if not isinstance(upload_id, list): upload_id = [upload_id] uploading_ids = {item.id for item in list(upload_queue.queue)} cancelled_ids = uploading_ids.intersection(upload_id) if len(cancelled_ids) == 0: - return 404 + return {"success": 0, "error": "not found"} cancelled_uploads.update(cancelled_ids) return {"success": 1} @dispatcher.add_method -def primeActivated(activated): +def primeActivated(activated: bool) -> Dict[str, int]: return {"success": 1} @dispatcher.add_method -def setBandwithLimit(upload_speed_kbps, download_speed_kbps): +def setBandwithLimit(upload_speed_kbps: int, download_speed_kbps: int) -> Dict[str, Union[int, str]]: if not AGNOS: return {"success": 0, "error": "only supported on AGNOS"} @@ -428,7 +470,7 @@ def setBandwithLimit(upload_speed_kbps, download_speed_kbps): return {"success": 0, "error": "failed to set limit", "stdout": e.stdout, "stderr": e.stderr} -def startLocalProxy(global_end_event, remote_ws_uri, local_port): +def startLocalProxy(global_end_event: threading.Event, remote_ws_uri: str, local_port: int) -> Dict[str, int]: try: if local_port not in LOCAL_PORT_WHITELIST: raise Exception("Requested local port not whitelisted") @@ -462,7 +504,7 @@ def startLocalProxy(global_end_event, remote_ws_uri, local_port): @dispatcher.add_method -def getPublicKey(): +def getPublicKey() -> Optional[str]: if not os.path.isfile(PERSIST + '/comma/id_rsa.pub'): return None @@ -471,7 +513,7 @@ def getPublicKey(): @dispatcher.add_method -def getSshAuthorizedKeys(): +def getSshAuthorizedKeys() -> str: return Params().get("GithubSshKeys", encoding='utf8') or '' @@ -486,7 +528,7 @@ def getNetworkType(): @dispatcher.add_method -def getNetworkMetered(): +def getNetworkMetered() -> bool: network_type = HARDWARE.get_network_type() return HARDWARE.get_network_metered(network_type) @@ -497,7 +539,7 @@ def getNetworks(): @dispatcher.add_method -def takeSnapshot(): +def takeSnapshot() -> Optional[Union[str, Dict[str, str]]]: from system.camerad.snapshot.snapshot import jpeg_write, snapshot ret = snapshot() if ret is not None: @@ -514,16 +556,19 @@ def takeSnapshot(): raise Exception("not available while camerad is started") -def get_logs_to_send_sorted(): +def get_logs_to_send_sorted() -> List[str]: # TODO: scan once then use inotify to detect file creation/deletion curr_time = int(time.time()) logs = [] for log_entry in os.listdir(SWAGLOG_DIR): log_path = os.path.join(SWAGLOG_DIR, log_entry) + time_sent = 0 try: - time_sent = int.from_bytes(getxattr(log_path, LOG_ATTR_NAME), sys.byteorder) + value = getxattr(log_path, LOG_ATTR_NAME) + if value is not None: + time_sent = int.from_bytes(value, sys.byteorder) except (ValueError, TypeError): - time_sent = 0 + pass # assume send failed and we lost the response if sent more than one hour ago if not time_sent or curr_time - time_sent > 3600: logs.append(log_entry) @@ -531,7 +576,7 @@ def get_logs_to_send_sorted(): return sorted(logs)[:-1] -def log_handler(end_event): +def log_handler(end_event: threading.Event) -> None: if PC: return @@ -593,7 +638,7 @@ def log_handler(end_event): cloudlog.exception("athena.log_handler.exception") -def stat_handler(end_event): +def stat_handler(end_event: threading.Event) -> None: while not end_event.is_set(): last_scan = 0 curr_scan = sec_since_boot() @@ -619,7 +664,7 @@ def stat_handler(end_event): time.sleep(0.1) -def ws_proxy_recv(ws, local_sock, ssock, end_event, global_end_event): +def ws_proxy_recv(ws: WebSocket, local_sock: socket.socket, ssock: socket.socket, end_event: threading.Event, global_end_event: threading.Event) -> None: while not (end_event.is_set() or global_end_event.is_set()): try: data = ws.recv() @@ -638,7 +683,7 @@ def ws_proxy_recv(ws, local_sock, ssock, end_event, global_end_event): end_event.set() -def ws_proxy_send(ws, local_sock, signal_sock, end_event): +def ws_proxy_send(ws: WebSocket, local_sock: socket.socket, signal_sock: socket.socket, end_event: threading.Event) -> None: while not end_event.is_set(): try: r, _, _ = select.select((local_sock, signal_sock), (), ()) @@ -663,7 +708,7 @@ def ws_proxy_send(ws, local_sock, signal_sock, end_event): cloudlog.debug("athena.ws_proxy_send done closing sockets") -def ws_recv(ws, end_event): +def ws_recv(ws: WebSocket, end_event: threading.Event) -> None: last_ping = int(sec_since_boot() * 1e9) while not end_event.is_set(): try: @@ -685,7 +730,7 @@ def ws_recv(ws, end_event): end_event.set() -def ws_send(ws, end_event): +def ws_send(ws: WebSocket, end_event: threading.Event) -> None: while not end_event.is_set(): try: try: @@ -704,7 +749,7 @@ def ws_send(ws, end_event): end_event.set() -def backoff(retries): +def backoff(retries: int) -> int: return random.randrange(0, min(128, int(2 ** retries))) diff --git a/selfdrive/athena/tests/test_athenad.py b/selfdrive/athena/tests/test_athenad.py index 5e86a2e821..128fde0319 100755 --- a/selfdrive/athena/tests/test_athenad.py +++ b/selfdrive/athena/tests/test_athenad.py @@ -8,6 +8,7 @@ import time import threading import queue import unittest +from dataclasses import asdict, replace from datetime import datetime, timedelta from typing import Optional @@ -226,7 +227,7 @@ class TestAthenadMethods(unittest.TestCase): """When an upload times out or fails to connect it should be placed back in the queue""" fn = self._create_file('qlog.bz2') item = athenad.UploadItem(path=fn, url="http://localhost:44444/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='', allow_cellular=True) - item_no_retry = item._replace(retry_count=MAX_RETRY_COUNT) + item_no_retry = replace(item, retry_count=MAX_RETRY_COUNT) end_event = threading.Event() thread = threading.Thread(target=athenad.upload_handler, args=(end_event,)) @@ -296,7 +297,7 @@ class TestAthenadMethods(unittest.TestCase): self.assertEqual(len(items), 0) @with_http_server - def test_listUploadQueueCurrent(self, host): + def test_listUploadQueueCurrent(self, host: str): fn = self._create_file('qlog.bz2') item = athenad.UploadItem(path=fn, url=f"{host}/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='', allow_cellular=True) @@ -321,7 +322,7 @@ class TestAthenadMethods(unittest.TestCase): items = dispatcher["listUploadQueue"]() self.assertEqual(len(items), 1) - self.assertDictEqual(items[0], item._asdict()) + self.assertDictEqual(items[0], asdict(item)) self.assertFalse(items[0]['current']) athenad.cancelled_uploads.add(item.id) @@ -346,7 +347,7 @@ class TestAthenadMethods(unittest.TestCase): athenad.UploadQueueCache.initialize(athenad.upload_queue) self.assertEqual(athenad.upload_queue.qsize(), 1) - self.assertDictEqual(athenad.upload_queue.queue[-1]._asdict(), item1._asdict()) + self.assertDictEqual(asdict(athenad.upload_queue.queue[-1]), asdict(item1)) @mock.patch('selfdrive.athena.athenad.create_connection') def test_startLocalProxy(self, mock_create_connection): @@ -417,5 +418,6 @@ class TestAthenadMethods(unittest.TestCase): sl = athenad.get_logs_to_send_sorted() self.assertListEqual(sl, fl[:-1]) + if __name__ == '__main__': unittest.main()