diff --git a/tools/lib/filereader.py b/tools/lib/filereader.py index 4aec965f1a..15e618f649 100644 --- a/tools/lib/filereader.py +++ b/tools/lib/filereader.py @@ -1,4 +1,6 @@ import os +import requests + from openpilot.tools.lib.url_file import URLFile DATA_ENDPOINT = os.getenv("DATA_ENDPOINT", "http://data-raw.comma.internal/") @@ -8,6 +10,12 @@ def resolve_name(fn): return fn.replace("cd:/", DATA_ENDPOINT) return fn +def file_exists(fn): + fn = resolve_name(fn) + if fn.startswith(("http://", "https://")): + return requests.head(fn, allow_redirects=True).status_code == 200 + return os.path.exists(fn) + def FileReader(fn, debug=False): fn = resolve_name(fn) if fn.startswith(("http://", "https://")): diff --git a/tools/lib/logreader.py b/tools/lib/logreader.py index 36cff0d724..36d099de50 100755 --- a/tools/lib/logreader.py +++ b/tools/lib/logreader.py @@ -17,7 +17,7 @@ from urllib.parse import parse_qs, urlparse from cereal import log as capnp_log from openpilot.tools.lib.openpilotci import get_url -from openpilot.tools.lib.filereader import FileReader +from openpilot.tools.lib.filereader import FileReader, file_exists from openpilot.tools.lib.helpers import RE from openpilot.tools.lib.route import Route, SegmentRange @@ -84,15 +84,13 @@ def create_slice_from_string(s: str): return start return slice(start, end, step) -def parse_slice(sr: SegmentRange): - route = Route(sr.route_name) +def parse_slice(sr: SegmentRange, route: Route): segs = np.arange(route.max_seg_number+1) s = create_slice_from_string(sr._slice) return segs[s] if isinstance(s, slice) else [segs[s]] -def comma_api_source(sr: SegmentRange, mode=ReadMode.RLOG): - segs = parse_slice(sr) - route = Route(sr.route_name) +def comma_api_source(sr: SegmentRange, route: Route, mode=ReadMode.RLOG): + segs = parse_slice(sr, route) log_paths = route.log_paths() if mode == ReadMode.RLOG else route.qlog_paths() @@ -102,35 +100,33 @@ def comma_api_source(sr: SegmentRange, mode=ReadMode.RLOG): return [(log_paths[seg]) for seg in segs] -def internal_source(sr: SegmentRange, mode=ReadMode.RLOG): - segs = parse_slice(sr) +def internal_source(sr: SegmentRange, route: Route, mode=ReadMode.RLOG): + segs = parse_slice(sr, route) return [f"cd:/{sr.dongle_id}/{sr.timestamp}/{seg}/{'rlog' if mode == ReadMode.RLOG else 'qlog'}.bz2" for seg in segs] -def openpilotci_source(sr: SegmentRange, mode=ReadMode.RLOG): - segs = parse_slice(sr) +def openpilotci_source(sr: SegmentRange, route: Route, mode=ReadMode.RLOG): + segs = parse_slice(sr, route) return [get_url(sr.route_name, seg, 'rlog' if mode == ReadMode.RLOG else 'qlog') for seg in segs] def direct_source(file_or_url): return [file_or_url] -def auto_source(*args): - # Automatically determine viable source - +def check_source(source, *args): try: - identifiers = internal_source(*args) - _LogFileReader(identifiers[0]) - return internal_source(*args) + files = source(*args) + assert all(file_exists(f) for f in files) + return True, files except Exception: - pass + return False, None - try: - identifiers = openpilotci_source(*args) - _LogFileReader(identifiers[0]) - return openpilotci_source(*args) - except Exception: - pass +def auto_source(*args): + # Automatically determine viable source + for source in [internal_source, openpilotci_source]: + valid, ret = check_source(source, *args) + if valid: + return ret return comma_api_source(*args) @@ -173,10 +169,11 @@ class LogReader: return direct_source(identifier) sr = SegmentRange(parsed) + route = Route(sr.route_name) mode = self.default_mode if sr.selector is None else ReadMode(sr.selector) source = self.default_source if source is None else source - return source(sr, mode) + return source(sr, route, mode) def __init__(self, identifier: str | List[str], default_mode=ReadMode.RLOG, default_source=auto_source, sort_by_time=False, only_union_types=False): self.default_mode = default_mode @@ -189,7 +186,6 @@ class LogReader: self.reset() def __iter__(self): - self.reset() for identifier in self.logreader_identifiers: yield from _LogFileReader(identifier) diff --git a/tools/lib/tests/test_logreader.py b/tools/lib/tests/test_logreader.py index a9769e5e63..74ee82d152 100644 --- a/tools/lib/tests/test_logreader.py +++ b/tools/lib/tests/test_logreader.py @@ -6,7 +6,7 @@ import pytest from parameterized import parameterized import requests from openpilot.tools.lib.logreader import LogReader, parse_indirect, parse_slice, ReadMode -from openpilot.tools.lib.route import SegmentRange +from openpilot.tools.lib.route import Route, SegmentRange NUM_SEGS = 17 # number of segments in the test route ALL_SEGS = list(np.arange(NUM_SEGS)) @@ -42,7 +42,8 @@ class TestLogReader(unittest.TestCase): def test_indirect_parsing(self, identifier, expected): parsed, _, _ = parse_indirect(identifier) sr = SegmentRange(parsed) - segs = parse_slice(sr) + route = Route(sr.route_name) + segs = parse_slice(sr, route) self.assertListEqual(list(segs), expected) def test_direct_parsing(self): @@ -91,6 +92,14 @@ class TestLogReader(unittest.TestCase): self.assertEqual(qlog_len*2, qlog_len_2) + @pytest.mark.slow + def test_multiple_iterations(self): + lr = LogReader(f"{TEST_ROUTE}/0/q") + qlog_len1 = len(list(lr)) + qlog_len2 = len(list(lr)) + + self.assertEqual(qlog_len1, qlog_len2) + if __name__ == "__main__": unittest.main()