diff --git a/tools/lib/route.py b/tools/lib/route.py index 93788e7ab3..d3df93ccda 100644 --- a/tools/lib/route.py +++ b/tools/lib/route.py @@ -1,6 +1,6 @@ import os import re -from urllib.parse import parse_qs, urlparse +from urllib.parse import urlparse from collections import defaultdict from itertools import chain from typing import Optional @@ -231,27 +231,8 @@ class SegmentName: def __str__(self) -> str: return self._canonical_name -def parse_useradmin(segment_range): - if "useradmin.comma.ai" in segment_range: - query = parse_qs(urlparse(segment_range).query) - return query["onebox"][0] - return segment_range - -def parse_cabana(segment_range): - if "cabana.comma.ai" in segment_range: - query = parse_qs(urlparse(segment_range).query) - return query["route"][0] - return segment_range - -def parse_cd(segment_range): - return segment_range.replace("cd:/", "") - class SegmentRange: def __init__(self, segment_range: str): - segment_range = parse_useradmin(segment_range) - segment_range = parse_cabana(segment_range) - segment_range = parse_cd(segment_range) - self.m = re.fullmatch(RE.SEGMENT_RANGE, segment_range) assert self.m, f"Segment range is not valid {segment_range}" diff --git a/tools/lib/srreader.py b/tools/lib/srreader.py index 08abd03d33..db2ebd1936 100644 --- a/tools/lib/srreader.py +++ b/tools/lib/srreader.py @@ -1,6 +1,9 @@ import enum -import re import numpy as np +import pathlib +import re +from urllib.parse import parse_qs, urlparse + from openpilot.selfdrive.test.openpilotci import get_url from openpilot.tools.lib.helpers import RE from openpilot.tools.lib.logreader import LogReader @@ -37,6 +40,10 @@ def comma_api_source(sr: SegmentRange, mode=ReadMode.RLOG, sort_by_time=False): log_paths = route.log_paths() if mode == ReadMode.RLOG else route.qlog_paths() + invalid_segs = [seg for seg in segs if log_paths[seg] is None] + + assert not len(invalid_segs), f"Some of the requested segments are not available: {invalid_segs}" + for seg in segs: yield LogReader(log_paths[seg], sort_by_time=sort_by_time) @@ -52,6 +59,9 @@ def openpilotci_source(sr: SegmentRange, mode=ReadMode.RLOG, sort_by_time=False) for seg in segs: yield LogReader(get_url(sr.route_name, seg, 'rlog' if mode == ReadMode.RLOG else 'qlog'), sort_by_time=sort_by_time) +def direct_source(file_or_url, sort_by_time): + yield LogReader(file_or_url, sort_by_time=sort_by_time) + def auto_source(*args, **kwargs): # Automatically determine viable source @@ -69,14 +79,61 @@ def auto_source(*args, **kwargs): return comma_api_source(*args, **kwargs) +def parse_useradmin(identifier): + if "useradmin.comma.ai" in identifier: + query = parse_qs(urlparse(identifier).query) + return query["onebox"][0] + return None + +def parse_cabana(identifier): + if "cabana.comma.ai" in identifier: + query = parse_qs(urlparse(identifier).query) + return query["route"][0] + return None + +def parse_cd(identifier): + if "cd:/" in identifier: + return identifier.replace("cd:/", "") + return None + +def parse_direct(identifier): + if "https://" in identifier or "http://" in identifier or pathlib.Path(identifier).exists(): + return identifier + return None + +def parse_indirect(identifier): + parsed = parse_useradmin(identifier) or parse_cabana(identifier) + + if parsed is not None: + return parsed, comma_api_source, True + + parsed = parse_cd(identifier) + if parsed is not None: + return parsed, internal_source, True + + return identifier, None, False class SegmentRangeReader: - def __init__(self, segment_range: str, default_mode=ReadMode.RLOG, default_source=auto_source, sort_by_time=False): - sr = SegmentRange(segment_range) + def _logreaders_from_identifier(self, identifier): + parsed, source, is_indirect = parse_indirect(identifier) + + if not is_indirect: + direct_parsed = parse_direct(identifier) + if direct_parsed is not None: + return direct_source(identifier, sort_by_time=self.sort_by_time) + + sr = SegmentRange(parsed) + mode = self.default_mode if sr.selector is None else ReadMode(sr.selector) + source = self.default_source if source is None else source + + return source(sr, mode, sort_by_time=self.sort_by_time) - mode = default_mode if sr.selector is None else ReadMode(sr.selector) + def __init__(self, identifier: str, default_mode=ReadMode.RLOG, default_source=auto_source, sort_by_time=False): + self.default_mode = default_mode + self.default_source = default_source + self.sort_by_time = sort_by_time - self.lrs = default_source(sr, mode, sort_by_time) + self.lrs = self._logreaders_from_identifier(identifier) def __iter__(self): for lr in self.lrs: diff --git a/tools/lib/tests/test_srreader.py b/tools/lib/tests/test_srreader.py index e1c54557be..b22599dee8 100644 --- a/tools/lib/tests/test_srreader.py +++ b/tools/lib/tests/test_srreader.py @@ -1,13 +1,17 @@ +import shutil +import tempfile import numpy as np import unittest from parameterized import parameterized +import requests from openpilot.tools.lib.route import SegmentRange -from openpilot.tools.lib.srreader import ReadMode, SegmentRangeReader, parse_slice +from openpilot.tools.lib.srreader import ReadMode, SegmentRangeReader, parse_slice, parse_indirect NUM_SEGS = 17 # number of segments in the test route ALL_SEGS = list(np.arange(NUM_SEGS)) TEST_ROUTE = "344c5c15b34f2d8a/2024-01-03--09-37-12" +QLOG_FILE = "https://commadataci.blob.core.windows.net/openpilotci/0375fdf7b1ce594d/2019-06-13--08-32-25/3/qlog.bz2" class TestSegmentRangeReader(unittest.TestCase): @parameterized.expand([ @@ -36,11 +40,23 @@ class TestSegmentRangeReader(unittest.TestCase): (f"https://cabana.comma.ai/?route={TEST_ROUTE}", ALL_SEGS), (f"cd:/{TEST_ROUTE}", ALL_SEGS), ]) - def test_parse_slice(self, segment_range, expected): - sr = SegmentRange(segment_range) + def test_indirect_parsing(self, identifier, expected): + parsed, _, _ = parse_indirect(identifier) + sr = SegmentRange(parsed) segs = parse_slice(sr) self.assertListEqual(list(segs), expected) + def test_direct_parsing(self): + qlog = tempfile.NamedTemporaryFile(mode='wb', delete=False) + + with requests.get(QLOG_FILE, stream=True) as r: + with qlog as f: + shutil.copyfileobj(r.raw, f) + + for f in [QLOG_FILE, qlog.name]: + l = len(list(SegmentRangeReader(f))) + self.assertGreater(l, 100) + @parameterized.expand([ (f"{TEST_ROUTE}///",), (f"{TEST_ROUTE}---",),