From edd4649d2479ce6e7a5d4927b03688650c9069fc Mon Sep 17 00:00:00 2001 From: Justin Newberry Date: Wed, 14 Feb 2024 13:34:17 -0500 Subject: [PATCH] LogReader: add typing hints (#31464) logreader typing old-commit-hash: 33cf6bda9ef6b1ed19f3a0fed4a5914a414ae653 --- tools/lib/logreader.py | 63 ++++++++++++++++++++++-------------------- 1 file changed, 33 insertions(+), 30 deletions(-) diff --git a/tools/lib/logreader.py b/tools/lib/logreader.py index 0a555708b7..af2c23ef48 100755 --- a/tools/lib/logreader.py +++ b/tools/lib/logreader.py @@ -11,7 +11,7 @@ import tqdm import urllib.parse import warnings -from typing import Dict, Iterable, Iterator, List, Type +from typing import Callable, Dict, Iterable, Iterator, List, Optional, Type from urllib.parse import parse_qs, urlparse from cereal import log as capnp_log @@ -76,11 +76,16 @@ class ReadMode(enum.StrEnum): AUTO_INTERACIVE = "i" # default to rlogs, fallback to qlogs with a prompt from the user -def default_valid_file(fn): +LogPath = Optional[str] +LogPaths = List[LogPath] +ValidFileCallable = Callable[[LogPath], bool] +Source = Callable[[SegmentRange, ReadMode], LogPaths] + +def default_valid_file(fn: LogPath) -> bool: return fn is not None and file_exists(fn) -def auto_strategy(rlog_paths, qlog_paths, interactive, valid_file): +def auto_strategy(rlog_paths: LogPaths, qlog_paths: LogPaths, interactive: bool, valid_file: ValidFileCallable) -> LogPaths: # auto select logs based on availability if any(rlog is None or not valid_file(rlog) for rlog in rlog_paths): if interactive: @@ -89,12 +94,12 @@ def auto_strategy(rlog_paths, qlog_paths, interactive, valid_file): else: cloudlog.warning("Some rlogs were not found, falling back to qlogs for those segments...") - return [rlog if (valid_file(rlog)) else (qlog if (valid_file(qlog)) else None) + return [rlog if valid_file(rlog) else (qlog if valid_file(qlog) else None) for (rlog, qlog) in zip(rlog_paths, qlog_paths, strict=True)] return rlog_paths -def apply_strategy(mode: ReadMode, rlog_paths, qlog_paths, valid_file=default_valid_file): +def apply_strategy(mode: ReadMode, rlog_paths: LogPaths, qlog_paths: LogPaths, valid_file: ValidFileCallable = default_valid_file) -> LogPaths: if mode == ReadMode.RLOG: return rlog_paths elif mode == ReadMode.QLOG: @@ -103,9 +108,10 @@ def apply_strategy(mode: ReadMode, rlog_paths, qlog_paths, valid_file=default_va return auto_strategy(rlog_paths, qlog_paths, False, valid_file) elif mode == ReadMode.AUTO_INTERACIVE: return auto_strategy(rlog_paths, qlog_paths, True, valid_file) + raise Exception(f"invalid mode: {mode}") -def comma_api_source(sr: SegmentRange, mode: ReadMode): +def comma_api_source(sr: SegmentRange, mode: ReadMode) -> LogPaths: route = Route(sr.route_name) rlog_paths = [route.log_paths()[seg] for seg in sr.seg_idxs] @@ -118,7 +124,7 @@ def comma_api_source(sr: SegmentRange, mode: ReadMode): return apply_strategy(mode, rlog_paths, qlog_paths, valid_file=valid_file) -def internal_source(sr: SegmentRange, mode: ReadMode): +def internal_source(sr: SegmentRange, mode: ReadMode) -> LogPaths: if not internal_source_available(): raise Exception("Internal source not available") @@ -131,19 +137,18 @@ def internal_source(sr: SegmentRange, mode: ReadMode): return apply_strategy(mode, rlog_paths, qlog_paths) -def openpilotci_source(sr: SegmentRange, mode: ReadMode): +def openpilotci_source(sr: SegmentRange, mode: ReadMode) -> LogPaths: rlog_paths = [get_url(sr.route_name, seg, "rlog") for seg in sr.seg_idxs] qlog_paths = [get_url(sr.route_name, seg, "qlog") for seg in sr.seg_idxs] return apply_strategy(mode, rlog_paths, qlog_paths) -def comma_car_segments_source(sr: SegmentRange, mode=ReadMode.RLOG): +def comma_car_segments_source(sr: SegmentRange, mode=ReadMode.RLOG) -> LogPaths: return [get_comma_segments_url(sr.route_name, seg) for seg in sr.seg_idxs] - -def direct_source(file_or_url): +def direct_source(file_or_url: str) -> LogPaths: return [file_or_url] @@ -153,52 +158,49 @@ def get_invalid_files(files): yield f -def check_source(source, *args): - try: - files = source(*args) - assert next(get_invalid_files(files), None) is None - return None, files - except Exception as e: - return e, None +def check_source(source: Source, *args) -> LogPaths: + files = source(*args) + assert next(get_invalid_files(files), None) is None + return files -def auto_source(sr: SegmentRange, mode=ReadMode.RLOG): +def auto_source(sr: SegmentRange, mode=ReadMode.RLOG) -> LogPaths: if mode == ReadMode.SANITIZED: return comma_car_segments_source(sr, mode) + SOURCES: List[Source] = [internal_source, openpilotci_source, comma_api_source, comma_car_segments_source,] exceptions = [] # Automatically determine viable source - for source in [internal_source, openpilotci_source, comma_api_source, comma_car_segments_source]: - exception, ret = check_source(source, sr, mode) - if exception is None: - return ret - else: - exceptions.append(exception) + for source in SOURCES: + try: + return check_source(source, sr, mode) + except Exception as e: + exceptions.append(e) raise Exception(f"auto_source could not find any valid source, exceptions for sources: {exceptions}") -def parse_useradmin(identifier): +def parse_useradmin(identifier: str): if "useradmin.comma.ai" in identifier: query = parse_qs(urlparse(identifier).query) return query["onebox"][0] return None -def parse_cabana(identifier): +def parse_cabana(identifier: str): if "cabana.comma.ai" in identifier: query = parse_qs(urlparse(identifier).query) return query["route"][0] return None -def parse_direct(identifier): +def parse_direct(identifier: str): if identifier.startswith(("http://", "https://", "cd:/")) or pathlib.Path(identifier).exists(): return identifier return None -def parse_indirect(identifier): +def parse_indirect(identifier: str): parsed = parse_useradmin(identifier) or parse_cabana(identifier) if parsed is not None: @@ -230,7 +232,8 @@ class LogReader: are uploaded or auto fallback to qlogs with '/a' selector at the end of the route name." return identifiers - def __init__(self, identifier: str | List[str], default_mode=ReadMode.RLOG, default_source=auto_source, sort_by_time=False, only_union_types=False): + def __init__(self, identifier: str | List[str], default_mode: ReadMode = ReadMode.RLOG, + default_source=auto_source, sort_by_time=False, only_union_types=False): self.default_mode = default_mode self.default_source = default_source self.identifier = identifier