athenad: fix thread safety issues in upload handing (#34199)

* fix thread safety issues in upload handing

* remove cancelled_uploads

* remove None from current upload items & atomic updates
pull/34200/head
Dean Lee 5 months ago committed by GitHub
parent 015aadd48c
commit dcb3113c4b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 47
      system/athena/athenad.py
  2. 8
      system/athena/tests/test_athenad.py

@ -100,9 +100,9 @@ send_queue: Queue[str] = queue.Queue()
upload_queue: Queue[UploadItem] = queue.Queue() upload_queue: Queue[UploadItem] = queue.Queue()
low_priority_send_queue: Queue[str] = queue.Queue() low_priority_send_queue: Queue[str] = queue.Queue()
log_recv_queue: Queue[str] = queue.Queue() log_recv_queue: Queue[str] = queue.Queue()
cancelled_uploads: set[str] = set()
cur_upload_items: dict[int, UploadItem | None] = {} cur_upload_items: dict[int, UploadItem | None] = {}
cur_upload_items_lock = threading.Lock()
def strip_zst_extension(fn: str) -> str: def strip_zst_extension(fn: str) -> str:
@ -130,8 +130,9 @@ class UploadQueueCache:
@staticmethod @staticmethod
def cache(upload_queue: Queue[UploadItem]) -> None: def cache(upload_queue: Queue[UploadItem]) -> None:
try: try:
queue: list[UploadItem | None] = list(upload_queue.queue) with upload_queue.mutex:
items = [asdict(i) for i in queue if i is not None and (i.id not in cancelled_uploads)] items = [asdict(item) for item in upload_queue.queue]
Params().put("AthenadUploadQueue", json.dumps(items)) Params().put("AthenadUploadQueue", json.dumps(items))
except Exception: except Exception:
cloudlog.exception("athena.UploadQueueCache.cache.exception") cloudlog.exception("athena.UploadQueueCache.cache.exception")
@ -198,10 +199,12 @@ def retry_upload(tid: int, end_event: threading.Event, increase_count: bool = Tr
progress=0, progress=0,
current=False current=False
) )
upload_queue.put_nowait(item)
UploadQueueCache.cache(upload_queue)
cur_upload_items[tid] = None with cur_upload_items_lock:
upload_queue.put_nowait(item)
cur_upload_items[tid] = None
UploadQueueCache.cache(upload_queue)
for _ in range(RETRY_DELAY): for _ in range(RETRY_DELAY):
time.sleep(1) time.sleep(1)
@ -221,7 +224,8 @@ def cb(sm, item, tid, end_event: threading.Event, sz: int, cur: int) -> None:
if end_event.is_set(): if end_event.is_set():
raise AbortTransferException raise AbortTransferException
cur_upload_items[tid] = replace(item, progress=cur / sz if sz else 1) with cur_upload_items_lock:
cur_upload_items[tid] = replace(item, progress=cur / sz if sz else 1)
def upload_handler(end_event: threading.Event) -> None: def upload_handler(end_event: threading.Event) -> None:
@ -229,14 +233,10 @@ def upload_handler(end_event: threading.Event) -> None:
tid = threading.get_ident() tid = threading.get_ident()
while not end_event.is_set(): while not end_event.is_set():
cur_upload_items[tid] = None
try: try:
cur_upload_items[tid] = item = replace(upload_queue.get(timeout=1), current=True) with cur_upload_items_lock:
cur_upload_items[tid] = None
if item.id in cancelled_uploads: cur_upload_items[tid] = item = replace(upload_queue.get(timeout=1), current=True)
cancelled_uploads.remove(item.id)
continue
# Remove item if too old # Remove item if too old
age = datetime.now() - datetime.fromtimestamp(item.created_at / 1000) age = datetime.now() - datetime.fromtimestamp(item.created_at / 1000)
@ -415,8 +415,10 @@ def uploadFilesToUrls(files_data: list[UploadFileDict]) -> UploadFilesToUrlRespo
@dispatcher.add_method @dispatcher.add_method
def listUploadQueue() -> list[UploadItemDict]: def listUploadQueue() -> list[UploadItemDict]:
items = list(upload_queue.queue) + list(cur_upload_items.values()) with cur_upload_items_lock, upload_queue.mutex:
return [asdict(i) for i in items if (i is not None) and (i.id not in cancelled_uploads)] items = list(upload_queue.queue) + [item for item in cur_upload_items.values() if item is not None]
return [asdict(item) for item in items]
@dispatcher.add_method @dispatcher.add_method
@ -424,13 +426,14 @@ def cancelUpload(upload_id: str | list[str]) -> dict[str, int | str]:
if not isinstance(upload_id, list): if not isinstance(upload_id, list):
upload_id = [upload_id] upload_id = [upload_id]
uploading_ids = {item.id for item in list(upload_queue.queue)} with upload_queue.mutex:
cancelled_ids = uploading_ids.intersection(upload_id) remaining_items = [item for item in upload_queue.queue if item.id not in upload_id]
if len(cancelled_ids) == 0: if len(remaining_items) == len(upload_queue.queue):
return {"success": 0, "error": "not found"} return {"success": 0, "error": "not found"}
cancelled_uploads.update(cancelled_ids) upload_queue.queue.clear()
return {"success": 1} upload_queue.queue.extend(remaining_items)
return {"success": 1}
@dispatcher.add_method @dispatcher.add_method
def setRouteViewed(route: str) -> dict[str, int | str]: def setRouteViewed(route: str) -> dict[str, int | str]:

@ -78,7 +78,6 @@ class TestAthenadMethods:
athenad.upload_queue = queue.Queue() athenad.upload_queue = queue.Queue()
athenad.cur_upload_items.clear() athenad.cur_upload_items.clear()
athenad.cancelled_uploads.clear()
for i in os.listdir(Paths.log_root()): for i in os.listdir(Paths.log_root()):
p = os.path.join(Paths.log_root(), i) p = os.path.join(Paths.log_root(), i)
@ -282,13 +281,10 @@ class TestAthenadMethods:
athenad.upload_queue.put_nowait(item) athenad.upload_queue.put_nowait(item)
dispatcher["cancelUpload"](item.id) dispatcher["cancelUpload"](item.id)
assert item.id in athenad.cancelled_uploads
self._wait_for_upload() self._wait_for_upload()
time.sleep(0.1) time.sleep(0.1)
assert athenad.upload_queue.qsize() == 0 assert athenad.upload_queue.qsize() == 0
assert len(athenad.cancelled_uploads) == 0
@with_upload_handler @with_upload_handler
def test_cancel_expiry(self): def test_cancel_expiry(self):
@ -331,7 +327,7 @@ class TestAthenadMethods:
assert items[0] == asdict(item) assert items[0] == asdict(item)
assert not items[0]['current'] assert not items[0]['current']
athenad.cancelled_uploads.add(item.id) dispatcher["cancelUpload"](item.id)
items = dispatcher["listUploadQueue"]() items = dispatcher["listUploadQueue"]()
assert len(items) == 0 assert len(items) == 0
@ -343,7 +339,7 @@ class TestAthenadMethods:
athenad.upload_queue.put_nowait(item2) athenad.upload_queue.put_nowait(item2)
# Ensure canceled items are not persisted # Ensure canceled items are not persisted
athenad.cancelled_uploads.add(item2.id) dispatcher["cancelUpload"](item2.id)
# serialize item # serialize item
athenad.UploadQueueCache.cache(athenad.upload_queue) athenad.UploadQueueCache.cache(athenad.upload_queue)

Loading…
Cancel
Save