diff --git a/selfdrive/car/tests/test_models.py b/selfdrive/car/tests/test_models.py index 25db426d23..ab0e56827c 100644 --- a/selfdrive/car/tests/test_models.py +++ b/selfdrive/car/tests/test_models.py @@ -132,7 +132,7 @@ class TestCarModelBase(unittest.TestCase): segment_range = f"{cls.test_route.route}/{seg}" try: - lr = LogReader(segment_range, default_source=internal_source if is_internal else openpilotci_source) + lr = LogReader(segment_range, source=internal_source if is_internal else openpilotci_source) return cls.get_testing_data_from_logreader(lr) except Exception: pass diff --git a/tools/lib/logreader.py b/tools/lib/logreader.py index 1f5309c1f3..eb422b0273 100755 --- a/tools/lib/logreader.py +++ b/tools/lib/logreader.py @@ -227,18 +227,11 @@ def auto_source(sr: SegmentRange, mode=ReadMode.RLOG) -> list[LogPath]: "\n - ".join([f"{k}: {repr(v)}" for k, v in exceptions.items()])) -def parse_useradmin(identifier: str): +def parse_indirect(identifier: str) -> str: if "useradmin.comma.ai" in identifier: query = parse_qs(urlparse(identifier).query) return query["onebox"][0] - return None - - -def parse_cabana(identifier: str): - if "cabana.comma.ai" in identifier: - query = parse_qs(urlparse(identifier).query) - return query["route"][0] - return None + return identifier def parse_direct(identifier: str): @@ -247,32 +240,20 @@ def parse_direct(identifier: str): return None -def parse_indirect(identifier: str): - parsed = parse_useradmin(identifier) or parse_cabana(identifier) - - if parsed is not None: - return parsed, comma_api_source, True - - return identifier, None, False - - class LogReader: - def _parse_identifiers(self, identifier: str | list[str]): - if isinstance(identifier, list): - return [i for j in identifier for i in self._parse_identifiers(j)] - - parsed, source, is_indirect = parse_indirect(identifier) + def _parse_identifier(self, identifier: str) -> list[LogPath]: + # useradmin, etc. + identifier = parse_indirect(identifier) - if not is_indirect: - direct_parsed = parse_direct(identifier) - if direct_parsed is not None: - return direct_source(identifier) + # direct url or file + direct_parsed = parse_direct(identifier) + if direct_parsed is not None: + return direct_source(identifier) - sr = SegmentRange(parsed) + sr = SegmentRange(identifier) mode = self.default_mode if sr.selector is None else ReadMode(sr.selector) - source = self.default_source if source is None else source - identifiers = source(sr, mode) + identifiers = self.source(sr, mode) invalid_count = len(list(get_invalid_files(identifiers))) assert invalid_count == 0, (f"{invalid_count}/{len(identifiers)} invalid log(s) found, please ensure all logs " + @@ -280,10 +261,12 @@ class LogReader: return identifiers def __init__(self, identifier: str | list[str], default_mode: ReadMode = ReadMode.RLOG, - default_source: Source = auto_source, sort_by_time=False, only_union_types=False): + source: Source = auto_source, sort_by_time=False, only_union_types=False): self.default_mode = default_mode - self.default_source = default_source + self.source = source self.identifier = identifier + if isinstance(identifier, str): + self.identifier = [identifier] self.sort_by_time = sort_by_time self.only_union_types = only_union_types @@ -312,7 +295,9 @@ class LogReader: return ret def reset(self): - self.logreader_identifiers = self._parse_identifiers(self.identifier) + self.logreader_identifiers = [] + for identifier in self.identifier: + self.logreader_identifiers.extend(self._parse_identifier(identifier)) @staticmethod def from_bytes(dat): diff --git a/tools/lib/tests/test_logreader.py b/tools/lib/tests/test_logreader.py index f827c25902..230b6a65ea 100644 --- a/tools/lib/tests/test_logreader.py +++ b/tools/lib/tests/test_logreader.py @@ -70,10 +70,9 @@ class TestLogReader: (f"https://useradmin.comma.ai/?onebox={TEST_ROUTE}", ALL_SEGS), (f"https://useradmin.comma.ai/?onebox={TEST_ROUTE.replace('/', '|')}", ALL_SEGS), (f"https://useradmin.comma.ai/?onebox={TEST_ROUTE.replace('/', '%7C')}", ALL_SEGS), - (f"https://cabana.comma.ai/?route={TEST_ROUTE}", ALL_SEGS), ]) def test_indirect_parsing(self, identifier, expected): - parsed, _, _ = parse_indirect(identifier) + parsed = parse_indirect(identifier) sr = SegmentRange(parsed) assert list(sr.seg_idxs) == expected, identifier @@ -194,17 +193,17 @@ class TestLogReader: with subtests.test("interactive_yes"): mocker.patch("sys.stdin", new=io.StringIO("y\n")) - lr = LogReader(f"{TEST_ROUTE}/0", default_mode=ReadMode.AUTO_INTERACTIVE, default_source=comma_api_source) + lr = LogReader(f"{TEST_ROUTE}/0", default_mode=ReadMode.AUTO_INTERACTIVE, source=comma_api_source) log_len = len(list(lr)) assert qlog_len == log_len with subtests.test("interactive_no"): mocker.patch("sys.stdin", new=io.StringIO("n\n")) with pytest.raises(AssertionError): - lr = LogReader(f"{TEST_ROUTE}/0", default_mode=ReadMode.AUTO_INTERACTIVE, default_source=comma_api_source) + lr = LogReader(f"{TEST_ROUTE}/0", default_mode=ReadMode.AUTO_INTERACTIVE, source=comma_api_source) with subtests.test("non_interactive"): - lr = LogReader(f"{TEST_ROUTE}/0", default_mode=ReadMode.AUTO, default_source=comma_api_source) + lr = LogReader(f"{TEST_ROUTE}/0", default_mode=ReadMode.AUTO, source=comma_api_source) log_len = len(list(lr)) assert qlog_len == log_len