diff --git a/tools/lib/logreader.py b/tools/lib/logreader.py index 2269430792..7957712aff 100755 --- a/tools/lib/logreader.py +++ b/tools/lib/logreader.py @@ -81,6 +81,8 @@ LogPaths = List[LogPath] ValidFileCallable = Callable[[LogPath], bool] Source = Callable[[SegmentRange, ReadMode], LogPaths] +InternalUnavailableException = Exception("Internal source not available") + def default_valid_file(fn: LogPath) -> bool: return fn is not None and file_exists(fn) @@ -126,7 +128,7 @@ def comma_api_source(sr: SegmentRange, mode: ReadMode) -> LogPaths: def internal_source(sr: SegmentRange, mode: ReadMode) -> LogPaths: if not internal_source_available(): - raise Exception("Internal source not available") + raise InternalUnavailableException def get_internal_url(sr: SegmentRange, seg, file): return f"cd:/{sr.dongle_id}/{sr.timestamp}/{seg}/{file}.bz2" @@ -160,7 +162,7 @@ def get_invalid_files(files): def check_source(source: Source, *args) -> LogPaths: files = source(*args) - assert next(get_invalid_files(files), None) is None + assert next(get_invalid_files(files), False) is False return files diff --git a/tools/lib/tests/test_logreader.py b/tools/lib/tests/test_logreader.py index 974182d638..2141915b87 100755 --- a/tools/lib/tests/test_logreader.py +++ b/tools/lib/tests/test_logreader.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +import contextlib import io import shutil import tempfile @@ -10,7 +11,7 @@ import requests from parameterized import parameterized from unittest import mock -from openpilot.tools.lib.logreader import LogIterable, LogReader, comma_api_source, parse_indirect, ReadMode +from openpilot.tools.lib.logreader import LogIterable, LogReader, comma_api_source, parse_indirect, ReadMode, InternalUnavailableException from openpilot.tools.lib.route import SegmentRange from openpilot.tools.lib.url_file import URLFileException @@ -24,6 +25,24 @@ def noop(segment: LogIterable): return segment +@contextlib.contextmanager +def setup_source_scenario(is_internal=False): + with ( + mock.patch("openpilot.tools.lib.logreader.internal_source") as internal_source_mock, + mock.patch("openpilot.tools.lib.logreader.openpilotci_source") as openpilotci_source_mock, + mock.patch("openpilot.tools.lib.logreader.comma_api_source") as comma_api_source_mock, + ): + if is_internal: + internal_source_mock.return_value = [QLOG_FILE] + else: + internal_source_mock.side_effect = InternalUnavailableException + + openpilotci_source_mock.return_value = [None] + comma_api_source_mock.return_value = [QLOG_FILE] + + yield + + class TestLogReader(unittest.TestCase): @parameterized.expand([ (f"{TEST_ROUTE}", ALL_SEGS), @@ -186,6 +205,17 @@ class TestLogReader(unittest.TestCase): log_len = len(list(lr)) self.assertEqual(qlog_len, log_len) + @parameterized.expand([(True,), (False,)]) + @pytest.mark.slow + def test_auto_source_scenarios(self, is_internal): + lr = LogReader(QLOG_FILE) + qlog_len = len(list(lr)) + + with setup_source_scenario(is_internal=is_internal): + lr = LogReader(f"{TEST_ROUTE}/0/q") + log_len = len(list(lr)) + self.assertEqual(qlog_len, log_len) + if __name__ == "__main__": unittest.main()