athenad: more types (#25877)

* add typing hints

* missed these

* revert functional changes and changes to uploader

* remove

* try any

* add types to test code

* try dataclass instead

* mypy needs this

* comments

* remove Any type

* remove unused method

* cleanup

Co-authored-by: Adeeb Shihadeh <adeebshihadeh@gmail.com>
Cameron Clough, committed by GitHub
parent e6e33531ca
commit e1c739f709
Changed files:
  selfdrive/athena/athenad.py (223 lines changed)
  selfdrive/athena/tests/test_athenad.py (10 lines changed)

--- a/selfdrive/athena/athenad.py
+++ b/selfdrive/athena/athenad.py
@@ -1,4 +1,6 @@
 #!/usr/bin/env python3
+from __future__ import annotations
+
 import base64
 import bz2
 import hashlib
@@ -14,14 +16,15 @@ import sys
 import tempfile
 import threading
 import time
-from collections import namedtuple
+from dataclasses import asdict, dataclass, replace
 from datetime import datetime
 from functools import partial
-from typing import Any, Dict
+from queue import Queue
+from typing import BinaryIO, Callable, Dict, List, Optional, Set, Union, cast
 
 import requests
 from jsonrpc import JSONRPCResponseManager, dispatcher
-from websocket import (ABNF, WebSocketException, WebSocketTimeoutException,
+from websocket import (ABNF, WebSocket, WebSocketException, WebSocketTimeoutException,
                        create_connection)
 
 import cereal.messaging as messaging
@@ -54,19 +57,54 @@ WS_FRAME_SIZE = 4096
 
 NetworkType = log.DeviceState.NetworkType
 
+UploadFileDict = Dict[str, Union[str, int, float, bool]]
+UploadItemDict = Dict[str, Union[str, bool, int, float, Dict[str, str]]]
+
+UploadFilesToUrlResponse = Dict[str, Union[int, List[UploadItemDict], List[str]]]
+
+
+@dataclass
+class UploadFile:
+  fn: str
+  url: str
+  headers: Dict[str, str]
+  allow_cellular: bool
+
+  @classmethod
+  def from_dict(cls, d: Dict) -> UploadFile:
+    return cls(d.get("fn", ""), d.get("url", ""), d.get("headers", {}), d.get("allow_cellular", False))
+
+
+@dataclass
+class UploadItem:
+  path: str
+  url: str
+  headers: Dict[str, str]
+  created_at: int
+  id: Optional[str]
+  retry_count: int = 0
+  current: bool = False
+  progress: float = 0
+  allow_cellular: bool = False
+
+  @classmethod
+  def from_dict(cls, d: Dict) -> UploadItem:
+    return cls(d["path"], d["url"], d["headers"], d["created_at"], d["id"], d["retry_count"], d["current"],
+               d["progress"], d["allow_cellular"])
+
+
 dispatcher["echo"] = lambda s: s
-recv_queue: Any = queue.Queue()
-send_queue: Any = queue.Queue()
-upload_queue: Any = queue.Queue()
-low_priority_send_queue: Any = queue.Queue()
-log_recv_queue: Any = queue.Queue()
-cancelled_uploads: Any = set()
-
-UploadItem = namedtuple('UploadItem', ['path', 'url', 'headers', 'created_at', 'id', 'retry_count', 'current', 'progress', 'allow_cellular'], defaults=(0, False, 0, False))
+recv_queue: Queue[str] = queue.Queue()
+send_queue: Queue[str] = queue.Queue()
+upload_queue: Queue[UploadItem] = queue.Queue()
+low_priority_send_queue: Queue[str] = queue.Queue()
+log_recv_queue: Queue[str] = queue.Queue()
+cancelled_uploads: Set[str] = set()
 
-cur_upload_items: Dict[int, Any] = {}
+cur_upload_items: Dict[int, Optional[UploadItem]] = {}
 
 
-def strip_bz2_extension(fn):
+def strip_bz2_extension(fn: str) -> str:
   if fn.endswith('.bz2'):
     return fn[:-4]
   return fn
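
The namedtuple-to-dataclass swap above is the heart of this diff. A minimal, standalone sketch (UploadItem trimmed to a few fields) of the new idioms; note that `from __future__ import annotations` is what lets `Queue[UploadItem]` appear in annotations on Python 3.8, where `queue.Queue` is not subscriptable at runtime:

    # Standalone sketch; mirrors the pattern above rather than the full class.
    from __future__ import annotations  # annotations stay strings, so Queue[UploadItem] is safe on 3.8

    import queue
    from dataclasses import asdict, dataclass, replace
    from queue import Queue
    from typing import Dict, Optional

    @dataclass
    class UploadItem:
      path: str
      url: str
      headers: Dict[str, str]
      created_at: int
      id: Optional[str]
      retry_count: int = 0

    item = UploadItem(path="qlog.bz2", url="http://example.com/qlog.bz2", headers={}, created_at=0, id=None)
    d = asdict(item)                      # stands in for namedtuple's _asdict()
    item2 = replace(item, retry_count=1)  # stands in for namedtuple's _replace()

    upload_queue: Queue[UploadItem] = queue.Queue()
    upload_queue.put(item2)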
@@ -76,29 +114,30 @@ class AbortTransferException(Exception):
   pass
 
 
-class UploadQueueCache():
+class UploadQueueCache:
   params = Params()
 
   @staticmethod
-  def initialize(upload_queue):
+  def initialize(upload_queue: Queue[UploadItem]) -> None:
     try:
       upload_queue_json = UploadQueueCache.params.get("AthenadUploadQueue")
       if upload_queue_json is not None:
         for item in json.loads(upload_queue_json):
-          upload_queue.put(UploadItem(**item))
+          upload_queue.put(UploadItem.from_dict(item))
     except Exception:
       cloudlog.exception("athena.UploadQueueCache.initialize.exception")
 
   @staticmethod
-  def cache(upload_queue):
+  def cache(upload_queue: Queue[UploadItem]) -> None:
     try:
-      items = [i._asdict() for i in upload_queue.queue if i.id not in cancelled_uploads]
+      queue: List[Optional[UploadItem]] = list(upload_queue.queue)
+      items = [asdict(i) for i in queue if i is not None and (i.id not in cancelled_uploads)]
       UploadQueueCache.params.put("AthenadUploadQueue", json.dumps(items))
     except Exception:
       cloudlog.exception("athena.UploadQueueCache.cache.exception")
 
 
-def handle_long_poll(ws):
+def handle_long_poll(ws: WebSocket) -> None:
   end_event = threading.Event()
 
   threads = [
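
UploadQueueCache now persists dataclasses via asdict() and restores them via UploadItem.from_dict(). A hedged sketch of that round-trip, reusing the trimmed UploadItem and item2 from the sketch above and standing a plain string in for the Params store (the real code uses from_dict and the "AthenadUploadQueue" param):

    import json

    blob = json.dumps([asdict(item2)])                       # what cache() writes
    restored = [UploadItem(**d) for d in json.loads(blob)]   # initialize() uses UploadItem.from_dict
    assert restored[0] == item2                              # dataclasses compare by field values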
@@ -126,7 +165,7 @@ def handle_long_poll(ws):
     thread.join()
 
 
-def jsonrpc_handler(end_event):
+def jsonrpc_handler(end_event: threading.Event) -> None:
   dispatcher["startLocalProxy"] = partial(startLocalProxy, end_event)
   while not end_event.is_set():
     try:
@@ -147,11 +186,12 @@ def jsonrpc_handler(end_event):
 
 
 def retry_upload(tid: int, end_event: threading.Event, increase_count: bool = True) -> None:
-  if cur_upload_items[tid].retry_count < MAX_RETRY_COUNT:
-    item = cur_upload_items[tid]
+  item = cur_upload_items[tid]
+  if item is not None and item.retry_count < MAX_RETRY_COUNT:
     new_retry_count = item.retry_count + 1 if increase_count else item.retry_count
-    item = item._replace(
+    item = replace(
+      item,
       retry_count=new_retry_count,
       progress=0,
       current=False
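
The reshuffle in retry_upload is not just cosmetic: cur_upload_items is now Dict[int, Optional[UploadItem]], so mypy insists on a None check before attribute access. The narrowing pattern, sketched with the item from the first example:

    from dataclasses import replace
    from typing import Dict, Optional

    cur_items: Dict[int, Optional[UploadItem]] = {0: item}
    maybe = cur_items[0]
    if maybe is not None and maybe.retry_count < 5:  # mypy narrows Optional[UploadItem] to UploadItem
      cur_items[0] = replace(maybe, retry_count=maybe.retry_count + 1)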
@@ -175,44 +215,44 @@ def upload_handler(end_event: threading.Event) -> None:
     cur_upload_items[tid] = None
 
     try:
-      cur_upload_items[tid] = upload_queue.get(timeout=1)._replace(current=True)
+      cur_upload_items[tid] = item = replace(upload_queue.get(timeout=1), current=True)
 
-      if cur_upload_items[tid].id in cancelled_uploads:
-        cancelled_uploads.remove(cur_upload_items[tid].id)
+      if item.id in cancelled_uploads:
+        cancelled_uploads.remove(item.id)
         continue
 
       # Remove item if too old
-      age = datetime.now() - datetime.fromtimestamp(cur_upload_items[tid].created_at / 1000)
+      age = datetime.now() - datetime.fromtimestamp(item.created_at / 1000)
       if age.total_seconds() > MAX_AGE:
-        cloudlog.event("athena.upload_handler.expired", item=cur_upload_items[tid], error=True)
+        cloudlog.event("athena.upload_handler.expired", item=item, error=True)
         continue
 
       # Check if uploading over metered connection is allowed
       sm.update(0)
       metered = sm['deviceState'].networkMetered
       network_type = sm['deviceState'].networkType.raw
-      if metered and (not cur_upload_items[tid].allow_cellular):
+      if metered and (not item.allow_cellular):
         retry_upload(tid, end_event, False)
         continue
 
       try:
-        def cb(sz, cur):
+        def cb(sz: int, cur: int) -> None:
           # Abort transfer if connection changed to metered after starting upload
           sm.update(0)
           metered = sm['deviceState'].networkMetered
-          if metered and (not cur_upload_items[tid].allow_cellular):
+          if metered and (not item.allow_cellular):
             raise AbortTransferException
-          cur_upload_items[tid] = cur_upload_items[tid]._replace(progress=cur / sz if sz else 1)
+          cur_upload_items[tid] = replace(item, progress=cur / sz if sz else 1)
 
-        fn = cur_upload_items[tid].path
+        fn = item.path
         try:
           sz = os.path.getsize(fn)
         except OSError:
           sz = -1
 
-        cloudlog.event("athena.upload_handler.upload_start", fn=fn, sz=sz, network_type=network_type, metered=metered, retry_count=cur_upload_items[tid].retry_count)
-        response = _do_upload(cur_upload_items[tid], cb)
+        cloudlog.event("athena.upload_handler.upload_start", fn=fn, sz=sz, network_type=network_type, metered=metered, retry_count=item.retry_count)
+        response = _do_upload(item, cb)
 
         if response.status_code not in (200, 201, 401, 403, 412):
           cloudlog.event("athena.upload_handler.retry", status_code=response.status_code, fn=fn, sz=sz, network_type=network_type, metered=metered)
@@ -234,7 +274,7 @@ def upload_handler(end_event: threading.Event) -> None:
       cloudlog.exception("athena.upload_handler.exception")
 
 
-def _do_upload(upload_item, callback=None):
+def _do_upload(upload_item: UploadItem, callback: Optional[Callable] = None) -> requests.Response:
   path = upload_item.path
   compress = False
@@ -244,27 +284,25 @@ def _do_upload(upload_item, callback=None):
     compress = True
 
   with open(path, "rb") as f:
+    data: BinaryIO
     if compress:
       cloudlog.event("athena.upload_handler.compress", fn=path, fn_orig=upload_item.path)
-      data = bz2.compress(f.read())
-      size = len(data)
-      data = io.BytesIO(data)
+      compressed = bz2.compress(f.read())
+      size = len(compressed)
+      data = io.BytesIO(compressed)
     else:
       size = os.fstat(f.fileno()).st_size
       data = f
 
-    if callback:
-      data = CallbackReader(data, callback, size)
-
     return requests.put(upload_item.url,
-                        data=data,
+                        data=CallbackReader(data, callback, size) if callback else data,
                         headers={**upload_item.headers, 'Content-Length': str(size)},
                         timeout=30)
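
The new `data: BinaryIO` declaration gives both branches (the BytesIO for compressed data and the raw file object) a single type, and CallbackReader, which is defined elsewhere in openpilot and untouched by this diff, wraps it for progress reporting. A hypothetical minimal analogue, assuming only the cb(sz, cur) callback shape used in upload_handler above:

    import io
    from typing import BinaryIO, Callable

    class ProgressReader(io.RawIOBase):
      # hypothetical stand-in for openpilot's CallbackReader
      def __init__(self, f: BinaryIO, callback: Callable[[int, int], None], size: int):
        super().__init__()
        self.f, self.callback, self.size, self.done = f, callback, size, 0

      def read(self, size: int = -1) -> bytes:
        chunk = self.f.read(size)
        self.done += len(chunk)
        self.callback(self.size, self.done)  # cb(sz, cur): total size, bytes read so far
        return chunk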
 
 
 # security: user should be able to request any message from their car
 @dispatcher.add_method
-def getMessage(service=None, timeout=1000):
+def getMessage(service: str, timeout: int = 1000) -> Dict:
   if service is None or service not in service_list:
     raise Exception("invalid service")
@@ -274,7 +312,8 @@ def getMessage(service=None, timeout=1000):
   if ret is None:
     raise TimeoutError
 
-  return ret.to_dict()
+  # this is because capnp._DynamicStructReader doesn't have typing information
+  return cast(Dict, ret.to_dict())
 
 
 @dispatcher.add_method
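
typing.cast is a runtime no-op that only changes what the checker believes, which is exactly what an untyped capnp to_dict() needs. A small illustration (function name invented):

    from typing import Dict, cast

    def untyped_to_dict():               # imagine a library call with no type stubs
      return {"speed": 0.0}

    ret = cast(Dict, untyped_to_dict())  # no conversion happens; mypy now sees Dict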
@@ -288,7 +327,7 @@ def getVersion() -> Dict[str, str]:
 
 
 @dispatcher.add_method
-def setNavDestination(latitude=0, longitude=0, place_name=None, place_details=None):
+def setNavDestination(latitude: int = 0, longitude: int = 0, place_name: Optional[str] = None, place_details: Optional[str] = None) -> Dict[str, int]:
   destination = {
     "latitude": latitude,
     "longitude": longitude,
@@ -300,8 +339,8 @@ def setNavDestination(latitude=0, longitude=0, place_name=None, place_details=None):
   return {"success": 1}
 
 
-def scan_dir(path, prefix):
-  files = list()
+def scan_dir(path: str, prefix: str) -> List[str]:
+  files = []
   # only walk directories that match the prefix
   # (glob and friends traverse entire dir tree)
   with os.scandir(path) as i:
@@ -320,18 +359,18 @@ def scan_dir(path, prefix):
   return files
 
 
 @dispatcher.add_method
-def listDataDirectory(prefix=''):
+def listDataDirectory(prefix='') -> List[str]:
   return scan_dir(ROOT, prefix)
 
 
 @dispatcher.add_method
-def reboot():
+def reboot() -> Dict[str, int]:
   sock = messaging.sub_sock("deviceState", timeout=1000)
   ret = messaging.recv_one(sock)
   if ret is None or ret.deviceState.started:
     raise Exception("Reboot unavailable")
 
-  def do_reboot():
+  def do_reboot() -> None:
     time.sleep(2)
     HARDWARE.reboot()
@@ -341,50 +380,53 @@ def reboot():
 
 
 @dispatcher.add_method
-def uploadFileToUrl(fn, url, headers):
-  return uploadFilesToUrls([{
+def uploadFileToUrl(fn: str, url: str, headers: Dict[str, str]) -> UploadFilesToUrlResponse:
+  # this is because mypy doesn't understand that the decorator doesn't change the return type
+  response: UploadFilesToUrlResponse = uploadFilesToUrls([{
     "fn": fn,
     "url": url,
     "headers": headers,
   }])
+  return response
 
 
 @dispatcher.add_method
-def uploadFilesToUrls(files_data):
-  items = []
-  failed = []
-  for file in files_data:
-    fn = file.get('fn', '')
-    if len(fn) == 0 or fn[0] == '/' or '..' in fn or 'url' not in file:
-      failed.append(fn)
+def uploadFilesToUrls(files_data: List[UploadFileDict]) -> UploadFilesToUrlResponse:
+  files = map(UploadFile.from_dict, files_data)
+
+  items: List[UploadItemDict] = []
+  failed: List[str] = []
+  for file in files:
+    if len(file.fn) == 0 or file.fn[0] == '/' or '..' in file.fn or len(file.url) == 0:
+      failed.append(file.fn)
       continue
 
-    path = os.path.join(ROOT, fn)
+    path = os.path.join(ROOT, file.fn)
     if not os.path.exists(path) and not os.path.exists(strip_bz2_extension(path)):
-      failed.append(fn)
+      failed.append(file.fn)
       continue
 
     # Skip item if already in queue
-    url = file['url'].split('?')[0]
+    url = file.url.split('?')[0]
     if any(url == item['url'].split('?')[0] for item in listUploadQueue()):
       continue
 
     item = UploadItem(
       path=path,
-      url=file['url'],
-      headers=file.get('headers', {}),
+      url=file.url,
+      headers=file.headers,
       created_at=int(time.time() * 1000),
       id=None,
-      allow_cellular=file.get('allow_cellular', False),
+      allow_cellular=file.allow_cellular,
     )
     upload_id = hashlib.sha1(str(item).encode()).hexdigest()
-    item = item._replace(id=upload_id)
+    item = replace(item, id=upload_id)
 
     upload_queue.put_nowait(item)
-    items.append(item._asdict())
+    items.append(asdict(item))
 
   UploadQueueCache.cache(upload_queue)
 
-  resp = {"enqueued": len(items), "items": items}
+  resp: UploadFilesToUrlResponse = {"enqueued": len(items), "items": items}
   if failed:
     resp["failed"] = failed
@@ -392,32 +434,32 @@ def uploadFilesToUrls(files_data):
 
 
 @dispatcher.add_method
-def listUploadQueue():
+def listUploadQueue() -> List[UploadItemDict]:
   items = list(upload_queue.queue) + list(cur_upload_items.values())
-  return [i._asdict() for i in items if (i is not None) and (i.id not in cancelled_uploads)]
+  return [asdict(i) for i in items if (i is not None) and (i.id not in cancelled_uploads)]
 
 
 @dispatcher.add_method
-def cancelUpload(upload_id):
+def cancelUpload(upload_id: Union[str, List[str]]) -> Dict[str, Union[int, str]]:
   if not isinstance(upload_id, list):
     upload_id = [upload_id]
 
   uploading_ids = {item.id for item in list(upload_queue.queue)}
   cancelled_ids = uploading_ids.intersection(upload_id)
   if len(cancelled_ids) == 0:
-    return 404
+    return {"success": 0, "error": "not found"}
 
   cancelled_uploads.update(cancelled_ids)
   return {"success": 1}
 
 
 @dispatcher.add_method
-def primeActivated(activated):
+def primeActivated(activated: bool) -> Dict[str, int]:
   return {"success": 1}
 
 
 @dispatcher.add_method
-def setBandwithLimit(upload_speed_kbps, download_speed_kbps):
+def setBandwithLimit(upload_speed_kbps: int, download_speed_kbps: int) -> Dict[str, Union[int, str]]:
   if not AGNOS:
     return {"success": 0, "error": "only supported on AGNOS"}
@@ -428,7 +470,7 @@ def setBandwithLimit(upload_speed_kbps, download_speed_kbps):
     return {"success": 0, "error": "failed to set limit", "stdout": e.stdout, "stderr": e.stderr}
 
 
-def startLocalProxy(global_end_event, remote_ws_uri, local_port):
+def startLocalProxy(global_end_event: threading.Event, remote_ws_uri: str, local_port: int) -> Dict[str, int]:
   try:
     if local_port not in LOCAL_PORT_WHITELIST:
       raise Exception("Requested local port not whitelisted")
@@ -462,7 +504,7 @@ def startLocalProxy(global_end_event, remote_ws_uri, local_port):
 
 
 @dispatcher.add_method
-def getPublicKey():
+def getPublicKey() -> Optional[str]:
   if not os.path.isfile(PERSIST + '/comma/id_rsa.pub'):
     return None
@@ -471,7 +513,7 @@ def getPublicKey():
 
 
 @dispatcher.add_method
-def getSshAuthorizedKeys():
+def getSshAuthorizedKeys() -> str:
   return Params().get("GithubSshKeys", encoding='utf8') or ''
@@ -486,7 +528,7 @@ def getNetworkType():
 
 
 @dispatcher.add_method
-def getNetworkMetered():
+def getNetworkMetered() -> bool:
   network_type = HARDWARE.get_network_type()
   return HARDWARE.get_network_metered(network_type)
@@ -497,7 +539,7 @@ def getNetworks():
 
 
 @dispatcher.add_method
-def takeSnapshot():
+def takeSnapshot() -> Optional[Union[str, Dict[str, str]]]:
   from system.camerad.snapshot.snapshot import jpeg_write, snapshot
   ret = snapshot()
   if ret is not None:
@@ -514,16 +556,19 @@ def takeSnapshot():
     raise Exception("not available while camerad is started")
 
 
-def get_logs_to_send_sorted():
+def get_logs_to_send_sorted() -> List[str]:
   # TODO: scan once then use inotify to detect file creation/deletion
   curr_time = int(time.time())
   logs = []
   for log_entry in os.listdir(SWAGLOG_DIR):
     log_path = os.path.join(SWAGLOG_DIR, log_entry)
+    time_sent = 0
     try:
-      time_sent = int.from_bytes(getxattr(log_path, LOG_ATTR_NAME), sys.byteorder)
+      value = getxattr(log_path, LOG_ATTR_NAME)
+      if value is not None:
+        time_sent = int.from_bytes(value, sys.byteorder)
     except (ValueError, TypeError):
-      time_sent = 0
+      pass
 
     # assume send failed and we lost the response if sent more than one hour ago
     if not time_sent or curr_time - time_sent > 3600:
       logs.append(log_entry)
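
This hunk also fixes a latent assumption: getxattr may return None, and the old code only survived because int.from_bytes(None, ...) happens to raise TypeError. Initializing time_sent first makes the default explicit. The pattern, as a small self-contained sketch (parse_time_sent is invented for illustration):

    import sys

    def parse_time_sent(attr_bytes) -> int:
      time_sent = 0                # default first; a missing xattr no longer relies on TypeError
      if attr_bytes is not None:
        try:
          time_sent = int.from_bytes(attr_bytes, sys.byteorder)
        except (ValueError, TypeError):
          pass
      return time_sent

    assert parse_time_sent(None) == 0
    assert parse_time_sent((42).to_bytes(4, sys.byteorder)) == 42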
@@ -531,7 +576,7 @@ def get_logs_to_send_sorted():
   return sorted(logs)[:-1]
 
 
-def log_handler(end_event):
+def log_handler(end_event: threading.Event) -> None:
   if PC:
     return
@@ -593,7 +638,7 @@ def log_handler(end_event):
       cloudlog.exception("athena.log_handler.exception")
 
 
-def stat_handler(end_event):
+def stat_handler(end_event: threading.Event) -> None:
   while not end_event.is_set():
     last_scan = 0
     curr_scan = sec_since_boot()
@@ -619,7 +664,7 @@ def stat_handler(end_event):
     time.sleep(0.1)
 
 
-def ws_proxy_recv(ws, local_sock, ssock, end_event, global_end_event):
+def ws_proxy_recv(ws: WebSocket, local_sock: socket.socket, ssock: socket.socket, end_event: threading.Event, global_end_event: threading.Event) -> None:
   while not (end_event.is_set() or global_end_event.is_set()):
     try:
       data = ws.recv()
@@ -638,7 +683,7 @@ def ws_proxy_recv(ws, local_sock, ssock, end_event, global_end_event):
   end_event.set()
 
 
-def ws_proxy_send(ws, local_sock, signal_sock, end_event):
+def ws_proxy_send(ws: WebSocket, local_sock: socket.socket, signal_sock: socket.socket, end_event: threading.Event) -> None:
   while not end_event.is_set():
     try:
       r, _, _ = select.select((local_sock, signal_sock), (), ())
@@ -663,7 +708,7 @@ def ws_proxy_send(ws, local_sock, signal_sock, end_event):
       cloudlog.debug("athena.ws_proxy_send done closing sockets")
 
 
-def ws_recv(ws, end_event):
+def ws_recv(ws: WebSocket, end_event: threading.Event) -> None:
   last_ping = int(sec_since_boot() * 1e9)
   while not end_event.is_set():
     try:
@@ -685,7 +730,7 @@ def ws_recv(ws, end_event):
   end_event.set()
 
 
-def ws_send(ws, end_event):
+def ws_send(ws: WebSocket, end_event: threading.Event) -> None:
   while not end_event.is_set():
     try:
       try:
@@ -704,7 +749,7 @@ def ws_send(ws, end_event):
   end_event.set()
 
 
-def backoff(retries):
+def backoff(retries: int) -> int:
   return random.randrange(0, min(128, int(2 ** retries)))
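
backoff() is randomized exponential backoff capped at 128 seconds; a quick check of the bounds it produces:

    import random

    def backoff(retries: int) -> int:
      return random.randrange(0, min(128, int(2 ** retries)))

    assert backoff(0) == 0           # randrange(0, 1) can only return 0
    assert 0 <= backoff(3) < 8       # 2 ** 3
    assert 0 <= backoff(100) < 128   # capped at 128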

--- a/selfdrive/athena/tests/test_athenad.py
+++ b/selfdrive/athena/tests/test_athenad.py
@@ -8,6 +8,7 @@ import time
 import threading
 import queue
 import unittest
+from dataclasses import asdict, replace
 from datetime import datetime, timedelta
 from typing import Optional
@@ -226,7 +227,7 @@ class TestAthenadMethods(unittest.TestCase):
     """When an upload times out or fails to connect it should be placed back in the queue"""
     fn = self._create_file('qlog.bz2')
     item = athenad.UploadItem(path=fn, url="http://localhost:44444/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='', allow_cellular=True)
-    item_no_retry = item._replace(retry_count=MAX_RETRY_COUNT)
+    item_no_retry = replace(item, retry_count=MAX_RETRY_COUNT)
 
     end_event = threading.Event()
     thread = threading.Thread(target=athenad.upload_handler, args=(end_event,))
@@ -296,7 +297,7 @@ class TestAthenadMethods(unittest.TestCase):
     self.assertEqual(len(items), 0)
 
   @with_http_server
-  def test_listUploadQueueCurrent(self, host):
+  def test_listUploadQueueCurrent(self, host: str):
     fn = self._create_file('qlog.bz2')
     item = athenad.UploadItem(path=fn, url=f"{host}/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='', allow_cellular=True)
@@ -321,7 +322,7 @@ class TestAthenadMethods(unittest.TestCase):
     items = dispatcher["listUploadQueue"]()
     self.assertEqual(len(items), 1)
-    self.assertDictEqual(items[0], item._asdict())
+    self.assertDictEqual(items[0], asdict(item))
     self.assertFalse(items[0]['current'])
 
     athenad.cancelled_uploads.add(item.id)
@@ -346,7 +347,7 @@ class TestAthenadMethods(unittest.TestCase):
     athenad.UploadQueueCache.initialize(athenad.upload_queue)
     self.assertEqual(athenad.upload_queue.qsize(), 1)
-    self.assertDictEqual(athenad.upload_queue.queue[-1]._asdict(), item1._asdict())
+    self.assertDictEqual(asdict(athenad.upload_queue.queue[-1]), asdict(item1))
 
   @mock.patch('selfdrive.athena.athenad.create_connection')
   def test_startLocalProxy(self, mock_create_connection):
@@ -417,5 +418,6 @@ class TestAthenadMethods(unittest.TestCase):
     sl = athenad.get_logs_to_send_sorted()
     self.assertListEqual(sl, fl[:-1])
 
+
 if __name__ == '__main__':
   unittest.main()
