athenad: more types (#25877)

* add typing hints

* missed these

* revert functional changes and changes to uploader

* remove

* try any

* add types to test code

* try dataclass instead

* mypy needs this

* comments

* remove Any type

* remove unused method

* cleanup

Co-authored-by: Adeeb Shihadeh <adeebshihadeh@gmail.com>
old-commit-hash: e1c739f709
taco
Cameron Clough 2 years ago committed by GitHub
parent 2bbf68c63b
commit c5cbd60a34
  1. 221
      selfdrive/athena/athenad.py
  2. 10
      selfdrive/athena/tests/test_athenad.py

@ -1,4 +1,6 @@
#!/usr/bin/env python3
from __future__ import annotations
import base64
import bz2
import hashlib
@ -14,14 +16,15 @@ import sys
import tempfile
import threading
import time
from collections import namedtuple
from dataclasses import asdict, dataclass, replace
from datetime import datetime
from functools import partial
from typing import Any, Dict
from queue import Queue
from typing import BinaryIO, Callable, Dict, List, Optional, Set, Union, cast
import requests
from jsonrpc import JSONRPCResponseManager, dispatcher
from websocket import (ABNF, WebSocketException, WebSocketTimeoutException,
from websocket import (ABNF, WebSocket, WebSocketException, WebSocketTimeoutException,
create_connection)
import cereal.messaging as messaging
@ -54,19 +57,54 @@ WS_FRAME_SIZE = 4096
NetworkType = log.DeviceState.NetworkType
UploadFileDict = Dict[str, Union[str, int, float, bool]]
UploadItemDict = Dict[str, Union[str, bool, int, float, Dict[str, str]]]
UploadFilesToUrlResponse = Dict[str, Union[int, List[UploadItemDict], List[str]]]
@dataclass
class UploadFile:
fn: str
url: str
headers: Dict[str, str]
allow_cellular: bool
@classmethod
def from_dict(cls, d: Dict) -> UploadFile:
return cls(d.get("fn", ""), d.get("url", ""), d.get("headers", {}), d.get("allow_cellular", False))
@dataclass
class UploadItem:
path: str
url: str
headers: Dict[str, str]
created_at: int
id: Optional[str]
retry_count: int = 0
current: bool = False
progress: float = 0
allow_cellular: bool = False
@classmethod
def from_dict(cls, d: Dict) -> UploadItem:
return cls(d["path"], d["url"], d["headers"], d["created_at"], d["id"], d["retry_count"], d["current"],
d["progress"], d["allow_cellular"])
dispatcher["echo"] = lambda s: s
recv_queue: Any = queue.Queue()
send_queue: Any = queue.Queue()
upload_queue: Any = queue.Queue()
low_priority_send_queue: Any = queue.Queue()
log_recv_queue: Any = queue.Queue()
cancelled_uploads: Any = set()
UploadItem = namedtuple('UploadItem', ['path', 'url', 'headers', 'created_at', 'id', 'retry_count', 'current', 'progress', 'allow_cellular'], defaults=(0, False, 0, False))
recv_queue: Queue[str] = queue.Queue()
send_queue: Queue[str] = queue.Queue()
upload_queue: Queue[UploadItem] = queue.Queue()
low_priority_send_queue: Queue[str] = queue.Queue()
log_recv_queue: Queue[str] = queue.Queue()
cancelled_uploads: Set[str] = set()
cur_upload_items: Dict[int, Any] = {}
cur_upload_items: Dict[int, Optional[UploadItem]] = {}
def strip_bz2_extension(fn):
def strip_bz2_extension(fn: str) -> str:
if fn.endswith('.bz2'):
return fn[:-4]
return fn
@ -76,29 +114,30 @@ class AbortTransferException(Exception):
pass
class UploadQueueCache():
class UploadQueueCache:
params = Params()
@staticmethod
def initialize(upload_queue):
def initialize(upload_queue: Queue[UploadItem]) -> None:
try:
upload_queue_json = UploadQueueCache.params.get("AthenadUploadQueue")
if upload_queue_json is not None:
for item in json.loads(upload_queue_json):
upload_queue.put(UploadItem(**item))
upload_queue.put(UploadItem.from_dict(item))
except Exception:
cloudlog.exception("athena.UploadQueueCache.initialize.exception")
@staticmethod
def cache(upload_queue):
def cache(upload_queue: Queue[UploadItem]) -> None:
try:
items = [i._asdict() for i in upload_queue.queue if i.id not in cancelled_uploads]
queue: List[Optional[UploadItem]] = list(upload_queue.queue)
items = [asdict(i) for i in queue if i is not None and (i.id not in cancelled_uploads)]
UploadQueueCache.params.put("AthenadUploadQueue", json.dumps(items))
except Exception:
cloudlog.exception("athena.UploadQueueCache.cache.exception")
def handle_long_poll(ws):
def handle_long_poll(ws: WebSocket) -> None:
end_event = threading.Event()
threads = [
@ -126,7 +165,7 @@ def handle_long_poll(ws):
thread.join()
def jsonrpc_handler(end_event):
def jsonrpc_handler(end_event: threading.Event) -> None:
dispatcher["startLocalProxy"] = partial(startLocalProxy, end_event)
while not end_event.is_set():
try:
@ -147,11 +186,12 @@ def jsonrpc_handler(end_event):
def retry_upload(tid: int, end_event: threading.Event, increase_count: bool = True) -> None:
if cur_upload_items[tid].retry_count < MAX_RETRY_COUNT:
item = cur_upload_items[tid]
if item is not None and item.retry_count < MAX_RETRY_COUNT:
new_retry_count = item.retry_count + 1 if increase_count else item.retry_count
item = item._replace(
item = replace(
item,
retry_count=new_retry_count,
progress=0,
current=False
@ -175,44 +215,44 @@ def upload_handler(end_event: threading.Event) -> None:
cur_upload_items[tid] = None
try:
cur_upload_items[tid] = upload_queue.get(timeout=1)._replace(current=True)
cur_upload_items[tid] = item = replace(upload_queue.get(timeout=1), current=True)
if cur_upload_items[tid].id in cancelled_uploads:
cancelled_uploads.remove(cur_upload_items[tid].id)
if item.id in cancelled_uploads:
cancelled_uploads.remove(item.id)
continue
# Remove item if too old
age = datetime.now() - datetime.fromtimestamp(cur_upload_items[tid].created_at / 1000)
age = datetime.now() - datetime.fromtimestamp(item.created_at / 1000)
if age.total_seconds() > MAX_AGE:
cloudlog.event("athena.upload_handler.expired", item=cur_upload_items[tid], error=True)
cloudlog.event("athena.upload_handler.expired", item=item, error=True)
continue
# Check if uploading over metered connection is allowed
sm.update(0)
metered = sm['deviceState'].networkMetered
network_type = sm['deviceState'].networkType.raw
if metered and (not cur_upload_items[tid].allow_cellular):
if metered and (not item.allow_cellular):
retry_upload(tid, end_event, False)
continue
try:
def cb(sz, cur):
def cb(sz: int, cur: int) -> None:
# Abort transfer if connection changed to metered after starting upload
sm.update(0)
metered = sm['deviceState'].networkMetered
if metered and (not cur_upload_items[tid].allow_cellular):
if metered and (not item.allow_cellular):
raise AbortTransferException
cur_upload_items[tid] = cur_upload_items[tid]._replace(progress=cur / sz if sz else 1)
cur_upload_items[tid] = replace(item, progress=cur / sz if sz else 1)
fn = cur_upload_items[tid].path
fn = item.path
try:
sz = os.path.getsize(fn)
except OSError:
sz = -1
cloudlog.event("athena.upload_handler.upload_start", fn=fn, sz=sz, network_type=network_type, metered=metered, retry_count=cur_upload_items[tid].retry_count)
response = _do_upload(cur_upload_items[tid], cb)
cloudlog.event("athena.upload_handler.upload_start", fn=fn, sz=sz, network_type=network_type, metered=metered, retry_count=item.retry_count)
response = _do_upload(item, cb)
if response.status_code not in (200, 201, 401, 403, 412):
cloudlog.event("athena.upload_handler.retry", status_code=response.status_code, fn=fn, sz=sz, network_type=network_type, metered=metered)
@ -234,7 +274,7 @@ def upload_handler(end_event: threading.Event) -> None:
cloudlog.exception("athena.upload_handler.exception")
def _do_upload(upload_item, callback=None):
def _do_upload(upload_item: UploadItem, callback: Optional[Callable] = None) -> requests.Response:
path = upload_item.path
compress = False
@ -244,27 +284,25 @@ def _do_upload(upload_item, callback=None):
compress = True
with open(path, "rb") as f:
data: BinaryIO
if compress:
cloudlog.event("athena.upload_handler.compress", fn=path, fn_orig=upload_item.path)
data = bz2.compress(f.read())
size = len(data)
data = io.BytesIO(data)
compressed = bz2.compress(f.read())
size = len(compressed)
data = io.BytesIO(compressed)
else:
size = os.fstat(f.fileno()).st_size
data = f
if callback:
data = CallbackReader(data, callback, size)
return requests.put(upload_item.url,
data=data,
data=CallbackReader(data, callback, size) if callback else data,
headers={**upload_item.headers, 'Content-Length': str(size)},
timeout=30)
# security: user should be able to request any message from their car
@dispatcher.add_method
def getMessage(service=None, timeout=1000):
def getMessage(service: str, timeout: int = 1000) -> Dict:
if service is None or service not in service_list:
raise Exception("invalid service")
@ -274,7 +312,8 @@ def getMessage(service=None, timeout=1000):
if ret is None:
raise TimeoutError
return ret.to_dict()
# this is because capnp._DynamicStructReader doesn't have typing information
return cast(Dict, ret.to_dict())
@dispatcher.add_method
@ -288,7 +327,7 @@ def getVersion() -> Dict[str, str]:
@dispatcher.add_method
def setNavDestination(latitude=0, longitude=0, place_name=None, place_details=None):
def setNavDestination(latitude: int = 0, longitude: int = 0, place_name: Optional[str] = None, place_details: Optional[str] = None) -> Dict[str, int]:
destination = {
"latitude": latitude,
"longitude": longitude,
@ -300,8 +339,8 @@ def setNavDestination(latitude=0, longitude=0, place_name=None, place_details=No
return {"success": 1}
def scan_dir(path, prefix):
files = list()
def scan_dir(path: str, prefix: str) -> List[str]:
files = []
# only walk directories that match the prefix
# (glob and friends traverse entire dir tree)
with os.scandir(path) as i:
@ -320,18 +359,18 @@ def scan_dir(path, prefix):
return files
@dispatcher.add_method
def listDataDirectory(prefix=''):
def listDataDirectory(prefix='') -> List[str]:
return scan_dir(ROOT, prefix)
@dispatcher.add_method
def reboot():
def reboot() -> Dict[str, int]:
sock = messaging.sub_sock("deviceState", timeout=1000)
ret = messaging.recv_one(sock)
if ret is None or ret.deviceState.started:
raise Exception("Reboot unavailable")
def do_reboot():
def do_reboot() -> None:
time.sleep(2)
HARDWARE.reboot()
@ -341,50 +380,53 @@ def reboot():
@dispatcher.add_method
def uploadFileToUrl(fn, url, headers):
return uploadFilesToUrls([{
def uploadFileToUrl(fn: str, url: str, headers: Dict[str, str]) -> UploadFilesToUrlResponse:
# this is because mypy doesn't understand that the decorator doesn't change the return type
response: UploadFilesToUrlResponse = uploadFilesToUrls([{
"fn": fn,
"url": url,
"headers": headers,
}])
return response
@dispatcher.add_method
def uploadFilesToUrls(files_data):
items = []
failed = []
for file in files_data:
fn = file.get('fn', '')
if len(fn) == 0 or fn[0] == '/' or '..' in fn or 'url' not in file:
failed.append(fn)
def uploadFilesToUrls(files_data: List[UploadFileDict]) -> UploadFilesToUrlResponse:
files = map(UploadFile.from_dict, files_data)
items: List[UploadItemDict] = []
failed: List[str] = []
for file in files:
if len(file.fn) == 0 or file.fn[0] == '/' or '..' in file.fn or len(file.url) == 0:
failed.append(file.fn)
continue
path = os.path.join(ROOT, fn)
path = os.path.join(ROOT, file.fn)
if not os.path.exists(path) and not os.path.exists(strip_bz2_extension(path)):
failed.append(fn)
failed.append(file.fn)
continue
# Skip item if already in queue
url = file['url'].split('?')[0]
url = file.url.split('?')[0]
if any(url == item['url'].split('?')[0] for item in listUploadQueue()):
continue
item = UploadItem(
path=path,
url=file['url'],
headers=file.get('headers', {}),
url=file.url,
headers=file.headers,
created_at=int(time.time() * 1000),
id=None,
allow_cellular=file.get('allow_cellular', False),
allow_cellular=file.allow_cellular,
)
upload_id = hashlib.sha1(str(item).encode()).hexdigest()
item = item._replace(id=upload_id)
item = replace(item, id=upload_id)
upload_queue.put_nowait(item)
items.append(item._asdict())
items.append(asdict(item))
UploadQueueCache.cache(upload_queue)
resp = {"enqueued": len(items), "items": items}
resp: UploadFilesToUrlResponse = {"enqueued": len(items), "items": items}
if failed:
resp["failed"] = failed
@ -392,32 +434,32 @@ def uploadFilesToUrls(files_data):
@dispatcher.add_method
def listUploadQueue():
def listUploadQueue() -> List[UploadItemDict]:
items = list(upload_queue.queue) + list(cur_upload_items.values())
return [i._asdict() for i in items if (i is not None) and (i.id not in cancelled_uploads)]
return [asdict(i) for i in items if (i is not None) and (i.id not in cancelled_uploads)]
@dispatcher.add_method
def cancelUpload(upload_id):
def cancelUpload(upload_id: Union[str, List[str]]) -> Dict[str, Union[int, str]]:
if not isinstance(upload_id, list):
upload_id = [upload_id]
uploading_ids = {item.id for item in list(upload_queue.queue)}
cancelled_ids = uploading_ids.intersection(upload_id)
if len(cancelled_ids) == 0:
return 404
return {"success": 0, "error": "not found"}
cancelled_uploads.update(cancelled_ids)
return {"success": 1}
@dispatcher.add_method
def primeActivated(activated):
def primeActivated(activated: bool) -> Dict[str, int]:
return {"success": 1}
@dispatcher.add_method
def setBandwithLimit(upload_speed_kbps, download_speed_kbps):
def setBandwithLimit(upload_speed_kbps: int, download_speed_kbps: int) -> Dict[str, Union[int, str]]:
if not AGNOS:
return {"success": 0, "error": "only supported on AGNOS"}
@ -428,7 +470,7 @@ def setBandwithLimit(upload_speed_kbps, download_speed_kbps):
return {"success": 0, "error": "failed to set limit", "stdout": e.stdout, "stderr": e.stderr}
def startLocalProxy(global_end_event, remote_ws_uri, local_port):
def startLocalProxy(global_end_event: threading.Event, remote_ws_uri: str, local_port: int) -> Dict[str, int]:
try:
if local_port not in LOCAL_PORT_WHITELIST:
raise Exception("Requested local port not whitelisted")
@ -462,7 +504,7 @@ def startLocalProxy(global_end_event, remote_ws_uri, local_port):
@dispatcher.add_method
def getPublicKey():
def getPublicKey() -> Optional[str]:
if not os.path.isfile(PERSIST + '/comma/id_rsa.pub'):
return None
@ -471,7 +513,7 @@ def getPublicKey():
@dispatcher.add_method
def getSshAuthorizedKeys():
def getSshAuthorizedKeys() -> str:
return Params().get("GithubSshKeys", encoding='utf8') or ''
@ -486,7 +528,7 @@ def getNetworkType():
@dispatcher.add_method
def getNetworkMetered():
def getNetworkMetered() -> bool:
network_type = HARDWARE.get_network_type()
return HARDWARE.get_network_metered(network_type)
@ -497,7 +539,7 @@ def getNetworks():
@dispatcher.add_method
def takeSnapshot():
def takeSnapshot() -> Optional[Union[str, Dict[str, str]]]:
from system.camerad.snapshot.snapshot import jpeg_write, snapshot
ret = snapshot()
if ret is not None:
@ -514,16 +556,19 @@ def takeSnapshot():
raise Exception("not available while camerad is started")
def get_logs_to_send_sorted():
def get_logs_to_send_sorted() -> List[str]:
# TODO: scan once then use inotify to detect file creation/deletion
curr_time = int(time.time())
logs = []
for log_entry in os.listdir(SWAGLOG_DIR):
log_path = os.path.join(SWAGLOG_DIR, log_entry)
time_sent = 0
try:
time_sent = int.from_bytes(getxattr(log_path, LOG_ATTR_NAME), sys.byteorder)
value = getxattr(log_path, LOG_ATTR_NAME)
if value is not None:
time_sent = int.from_bytes(value, sys.byteorder)
except (ValueError, TypeError):
time_sent = 0
pass
# assume send failed and we lost the response if sent more than one hour ago
if not time_sent or curr_time - time_sent > 3600:
logs.append(log_entry)
@ -531,7 +576,7 @@ def get_logs_to_send_sorted():
return sorted(logs)[:-1]
def log_handler(end_event):
def log_handler(end_event: threading.Event) -> None:
if PC:
return
@ -593,7 +638,7 @@ def log_handler(end_event):
cloudlog.exception("athena.log_handler.exception")
def stat_handler(end_event):
def stat_handler(end_event: threading.Event) -> None:
while not end_event.is_set():
last_scan = 0
curr_scan = sec_since_boot()
@ -619,7 +664,7 @@ def stat_handler(end_event):
time.sleep(0.1)
def ws_proxy_recv(ws, local_sock, ssock, end_event, global_end_event):
def ws_proxy_recv(ws: WebSocket, local_sock: socket.socket, ssock: socket.socket, end_event: threading.Event, global_end_event: threading.Event) -> None:
while not (end_event.is_set() or global_end_event.is_set()):
try:
data = ws.recv()
@ -638,7 +683,7 @@ def ws_proxy_recv(ws, local_sock, ssock, end_event, global_end_event):
end_event.set()
def ws_proxy_send(ws, local_sock, signal_sock, end_event):
def ws_proxy_send(ws: WebSocket, local_sock: socket.socket, signal_sock: socket.socket, end_event: threading.Event) -> None:
while not end_event.is_set():
try:
r, _, _ = select.select((local_sock, signal_sock), (), ())
@ -663,7 +708,7 @@ def ws_proxy_send(ws, local_sock, signal_sock, end_event):
cloudlog.debug("athena.ws_proxy_send done closing sockets")
def ws_recv(ws, end_event):
def ws_recv(ws: WebSocket, end_event: threading.Event) -> None:
last_ping = int(sec_since_boot() * 1e9)
while not end_event.is_set():
try:
@ -685,7 +730,7 @@ def ws_recv(ws, end_event):
end_event.set()
def ws_send(ws, end_event):
def ws_send(ws: WebSocket, end_event: threading.Event) -> None:
while not end_event.is_set():
try:
try:
@ -704,7 +749,7 @@ def ws_send(ws, end_event):
end_event.set()
def backoff(retries):
def backoff(retries: int) -> int:
return random.randrange(0, min(128, int(2 ** retries)))

@ -8,6 +8,7 @@ import time
import threading
import queue
import unittest
from dataclasses import asdict, replace
from datetime import datetime, timedelta
from typing import Optional
@ -226,7 +227,7 @@ class TestAthenadMethods(unittest.TestCase):
"""When an upload times out or fails to connect it should be placed back in the queue"""
fn = self._create_file('qlog.bz2')
item = athenad.UploadItem(path=fn, url="http://localhost:44444/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='', allow_cellular=True)
item_no_retry = item._replace(retry_count=MAX_RETRY_COUNT)
item_no_retry = replace(item, retry_count=MAX_RETRY_COUNT)
end_event = threading.Event()
thread = threading.Thread(target=athenad.upload_handler, args=(end_event,))
@ -296,7 +297,7 @@ class TestAthenadMethods(unittest.TestCase):
self.assertEqual(len(items), 0)
@with_http_server
def test_listUploadQueueCurrent(self, host):
def test_listUploadQueueCurrent(self, host: str):
fn = self._create_file('qlog.bz2')
item = athenad.UploadItem(path=fn, url=f"{host}/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='', allow_cellular=True)
@ -321,7 +322,7 @@ class TestAthenadMethods(unittest.TestCase):
items = dispatcher["listUploadQueue"]()
self.assertEqual(len(items), 1)
self.assertDictEqual(items[0], item._asdict())
self.assertDictEqual(items[0], asdict(item))
self.assertFalse(items[0]['current'])
athenad.cancelled_uploads.add(item.id)
@ -346,7 +347,7 @@ class TestAthenadMethods(unittest.TestCase):
athenad.UploadQueueCache.initialize(athenad.upload_queue)
self.assertEqual(athenad.upload_queue.qsize(), 1)
self.assertDictEqual(athenad.upload_queue.queue[-1]._asdict(), item1._asdict())
self.assertDictEqual(asdict(athenad.upload_queue.queue[-1]), asdict(item1))
@mock.patch('selfdrive.athena.athenad.create_connection')
def test_startLocalProxy(self, mock_create_connection):
@ -417,5 +418,6 @@ class TestAthenadMethods(unittest.TestCase):
sl = athenad.get_logs_to_send_sorted()
self.assertListEqual(sl, fl[:-1])
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save