diff --git a/common/file_helpers.py b/common/file_helpers.py index c7a70ab879..a344a71fef 100644 --- a/common/file_helpers.py +++ b/common/file_helpers.py @@ -79,6 +79,25 @@ class NamedTemporaryDir(): self.close() +class CallbackReader: + """Wraps a file, but overrides the read method to also + call a callback function with the number of bytes read so far.""" + def __init__(self, f, callback, *args): + self.f = f + self.callback = callback + self.cb_args = args + self.total_read = 0 + + def __getattr__(self, attr): + return getattr(self.f, attr) + + def read(self, *args, **kwargs): + chunk = self.f.read(*args, **kwargs) + self.total_read += len(chunk) + self.callback(*self.cb_args, self.total_read) + return chunk + + def _get_fileobject_func(writer, temp_dir): def _get_fileobject(): file_obj = writer.get_fileobject(dir=temp_dir) diff --git a/selfdrive/athena/athenad.py b/selfdrive/athena/athenad.py index 86751fad86..97ea18d5cd 100755 --- a/selfdrive/athena/athenad.py +++ b/selfdrive/athena/athenad.py @@ -22,6 +22,7 @@ from websocket import ABNF, WebSocketTimeoutException, WebSocketException, creat import cereal.messaging as messaging from cereal.services import service_list from common.api import Api +from common.file_helpers import CallbackReader from common.basedir import PERSIST from common.params import Params from common.realtime import sec_since_boot @@ -49,7 +50,9 @@ upload_queue: Any = queue.Queue() log_send_queue: Any = queue.Queue() log_recv_queue: Any = queue.Queue() cancelled_uploads: Any = set() -UploadItem = namedtuple('UploadItem', ['path', 'url', 'headers', 'created_at', 'id', 'retry_count'], defaults=(0,)) +UploadItem = namedtuple('UploadItem', ['path', 'url', 'headers', 'created_at', 'id', 'retry_count', 'current', 'progress'], defaults=(0, False, 0)) + +cur_upload_items = {} def handle_long_poll(ws): @@ -100,35 +103,53 @@ def jsonrpc_handler(end_event): def upload_handler(end_event): + tid = threading.get_ident() + while not end_event.is_set(): + cur_upload_items[tid] = None + try: - item = upload_queue.get(timeout=1) - if item.id in cancelled_uploads: - cancelled_uploads.remove(item.id) + cur_upload_items[tid] = upload_queue.get(timeout=1)._replace(current=True) + if cur_upload_items[tid].id in cancelled_uploads: + cancelled_uploads.remove(cur_upload_items[tid].id) continue try: - _do_upload(item) - except (requests.exceptions.Timeout, requests.exceptions.ConnectionError, requests.exceptions.SSLError) as e: - cloudlog.warning(f"athena.upload_handler.retry {e} {item}") + def cb(sz, cur): + cur_upload_items[tid] = cur_upload_items[tid]._replace(progress=cur / sz if sz else 1) - if item.retry_count < MAX_RETRY_COUNT: - item = item._replace(retry_count=item.retry_count + 1) + _do_upload(cur_upload_items[tid], cb) + except (requests.exceptions.Timeout, requests.exceptions.ConnectionError, requests.exceptions.SSLError) as e: + cloudlog.warning(f"athena.upload_handler.retry {e} {cur_upload_items[tid]}") + + if cur_upload_items[tid].retry_count < MAX_RETRY_COUNT: + item = cur_upload_items[tid] + item = item._replace( + retry_count=item.retry_count + 1, + progress=0, + current=False + ) upload_queue.put_nowait(item) + cur_upload_items[tid] = None for _ in range(RETRY_DELAY): time.sleep(1) if end_event.is_set(): break + except queue.Empty: pass except Exception: cloudlog.exception("athena.upload_handler.exception") -def _do_upload(upload_item): +def _do_upload(upload_item, callback=None): with open(upload_item.path, "rb") as f: size = os.fstat(f.fileno()).st_size + + if callback: + f = CallbackReader(f, callback, size) + return requests.put(upload_item.url, data=f, headers={**upload_item.headers, 'Content-Length': str(size)}, @@ -212,7 +233,8 @@ def uploadFileToUrl(fn, url, headers): @dispatcher.add_method def listUploadQueue(): - return [item._asdict() for item in list(upload_queue.queue)] + items = list(upload_queue.queue) + list(cur_upload_items.values()) + return [i._asdict() for i in items if i is not None] @dispatcher.add_method @@ -514,6 +536,8 @@ def main(): manage_tokens(api) conn_retries = 0 + cur_upload_items.clear() + handle_long_poll(ws) except (KeyboardInterrupt, SystemExit): break diff --git a/selfdrive/athena/tests/test_athenad.py b/selfdrive/athena/tests/test_athenad.py index 67069f136f..afe8988191 100755 --- a/selfdrive/athena/tests/test_athenad.py +++ b/selfdrive/athena/tests/test_athenad.py @@ -30,6 +30,10 @@ class TestAthenadMethods(unittest.TestCase): athenad.Api = MockApi athenad.LOCAL_PORT_WHITELIST = set([cls.SOCKET_PORT]) + def tearDown(self): + athenad.upload_queue = queue.Queue() + athenad.cur_upload_items.clear() + def wait_for_upload(self): now = time.time() while time.time() - now < 5: @@ -96,7 +100,6 @@ class TestAthenadMethods(unittest.TestCase): self.assertIsNotNone(resp['item'].get('id')) self.assertEqual(athenad.upload_queue.qsize(), 1) finally: - athenad.upload_queue = queue.Queue() os.unlink(fn) @with_http_server @@ -118,7 +121,6 @@ class TestAthenadMethods(unittest.TestCase): self.assertEqual(athenad.upload_queue.qsize(), 0) finally: end_event.set() - athenad.upload_queue = queue.Queue() os.unlink(fn) def test_upload_handler_timeout(self): @@ -150,7 +152,6 @@ class TestAthenadMethods(unittest.TestCase): finally: end_event.set() - athenad.upload_queue = queue.Queue() os.unlink(fn) def test_cancelUpload(self): @@ -171,18 +172,41 @@ class TestAthenadMethods(unittest.TestCase): self.assertEqual(len(athenad.cancelled_uploads), 0) finally: end_event.set() - athenad.upload_queue = queue.Queue() - def test_listUploadQueue(self): - item = athenad.UploadItem(path="qlog.bz2", url="http://localhost:44444/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='id') - athenad.upload_queue.put_nowait(item) + def test_listUploadQueueEmpty(self): + items = dispatcher["listUploadQueue"]() + self.assertEqual(len(items), 0) + + @with_http_server + def test_listUploadQueueCurrent(self, host): + fn = os.path.join(athenad.ROOT, 'qlog.bz2') + Path(fn).touch() + item = athenad.UploadItem(path=fn, url=f"{host}/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='') + + end_event = threading.Event() + thread = threading.Thread(target=athenad.upload_handler, args=(end_event,)) + thread.start() try: + athenad.upload_queue.put_nowait(item) + self.wait_for_upload() + items = dispatcher["listUploadQueue"]() self.assertEqual(len(items), 1) - self.assertDictEqual(items[0], item._asdict()) + self.assertTrue(items[0]['current']) + finally: - athenad.upload_queue = queue.Queue() + end_event.set() + os.unlink(fn) + + def test_listUploadQueue(self): + item = athenad.UploadItem(path="qlog.bz2", url="http://localhost:44444/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='id') + athenad.upload_queue.put_nowait(item) + + items = dispatcher["listUploadQueue"]() + self.assertEqual(len(items), 1) + self.assertDictEqual(items[0], item._asdict()) + self.assertFalse(items[0]['current']) @mock.patch('selfdrive.athena.athenad.create_connection') def test_startLocalProxy(self, mock_create_connection):