segmentrangereader: support direct parsing (#30973)

* use correct source

* revert

* cleanup imports

* clean

* direct parsing

* rename

* move up

* fixes

* fix that

* better error message
old-commit-hash: eb09294fc2
chrysler-long2
Justin Newberry 1 year ago committed by GitHub
parent 502322ed11
commit 1b65d5cd85
  1. 21
      tools/lib/route.py
  2. 67
      tools/lib/srreader.py
  3. 22
      tools/lib/tests/test_srreader.py

@ -1,6 +1,6 @@
import os import os
import re import re
from urllib.parse import parse_qs, urlparse from urllib.parse import urlparse
from collections import defaultdict from collections import defaultdict
from itertools import chain from itertools import chain
from typing import Optional from typing import Optional
@ -231,27 +231,8 @@ class SegmentName:
def __str__(self) -> str: return self._canonical_name def __str__(self) -> str: return self._canonical_name
def parse_useradmin(segment_range):
if "useradmin.comma.ai" in segment_range:
query = parse_qs(urlparse(segment_range).query)
return query["onebox"][0]
return segment_range
def parse_cabana(segment_range):
if "cabana.comma.ai" in segment_range:
query = parse_qs(urlparse(segment_range).query)
return query["route"][0]
return segment_range
def parse_cd(segment_range):
return segment_range.replace("cd:/", "")
class SegmentRange: class SegmentRange:
def __init__(self, segment_range: str): def __init__(self, segment_range: str):
segment_range = parse_useradmin(segment_range)
segment_range = parse_cabana(segment_range)
segment_range = parse_cd(segment_range)
self.m = re.fullmatch(RE.SEGMENT_RANGE, segment_range) self.m = re.fullmatch(RE.SEGMENT_RANGE, segment_range)
assert self.m, f"Segment range is not valid {segment_range}" assert self.m, f"Segment range is not valid {segment_range}"

@ -1,6 +1,9 @@
import enum import enum
import re
import numpy as np import numpy as np
import pathlib
import re
from urllib.parse import parse_qs, urlparse
from openpilot.selfdrive.test.openpilotci import get_url from openpilot.selfdrive.test.openpilotci import get_url
from openpilot.tools.lib.helpers import RE from openpilot.tools.lib.helpers import RE
from openpilot.tools.lib.logreader import LogReader from openpilot.tools.lib.logreader import LogReader
@ -37,6 +40,10 @@ def comma_api_source(sr: SegmentRange, mode=ReadMode.RLOG, sort_by_time=False):
log_paths = route.log_paths() if mode == ReadMode.RLOG else route.qlog_paths() log_paths = route.log_paths() if mode == ReadMode.RLOG else route.qlog_paths()
invalid_segs = [seg for seg in segs if log_paths[seg] is None]
assert not len(invalid_segs), f"Some of the requested segments are not available: {invalid_segs}"
for seg in segs: for seg in segs:
yield LogReader(log_paths[seg], sort_by_time=sort_by_time) yield LogReader(log_paths[seg], sort_by_time=sort_by_time)
@ -52,6 +59,9 @@ def openpilotci_source(sr: SegmentRange, mode=ReadMode.RLOG, sort_by_time=False)
for seg in segs: for seg in segs:
yield LogReader(get_url(sr.route_name, seg, 'rlog' if mode == ReadMode.RLOG else 'qlog'), sort_by_time=sort_by_time) yield LogReader(get_url(sr.route_name, seg, 'rlog' if mode == ReadMode.RLOG else 'qlog'), sort_by_time=sort_by_time)
def direct_source(file_or_url, sort_by_time):
yield LogReader(file_or_url, sort_by_time=sort_by_time)
def auto_source(*args, **kwargs): def auto_source(*args, **kwargs):
# Automatically determine viable source # Automatically determine viable source
@ -69,14 +79,61 @@ def auto_source(*args, **kwargs):
return comma_api_source(*args, **kwargs) return comma_api_source(*args, **kwargs)
def parse_useradmin(identifier):
if "useradmin.comma.ai" in identifier:
query = parse_qs(urlparse(identifier).query)
return query["onebox"][0]
return None
def parse_cabana(identifier):
if "cabana.comma.ai" in identifier:
query = parse_qs(urlparse(identifier).query)
return query["route"][0]
return None
def parse_cd(identifier):
if "cd:/" in identifier:
return identifier.replace("cd:/", "")
return None
def parse_direct(identifier):
if "https://" in identifier or "http://" in identifier or pathlib.Path(identifier).exists():
return identifier
return None
def parse_indirect(identifier):
parsed = parse_useradmin(identifier) or parse_cabana(identifier)
if parsed is not None:
return parsed, comma_api_source, True
parsed = parse_cd(identifier)
if parsed is not None:
return parsed, internal_source, True
return identifier, None, False
class SegmentRangeReader: class SegmentRangeReader:
def __init__(self, segment_range: str, default_mode=ReadMode.RLOG, default_source=auto_source, sort_by_time=False): def _logreaders_from_identifier(self, identifier):
sr = SegmentRange(segment_range) parsed, source, is_indirect = parse_indirect(identifier)
if not is_indirect:
direct_parsed = parse_direct(identifier)
if direct_parsed is not None:
return direct_source(identifier, sort_by_time=self.sort_by_time)
sr = SegmentRange(parsed)
mode = self.default_mode if sr.selector is None else ReadMode(sr.selector)
source = self.default_source if source is None else source
return source(sr, mode, sort_by_time=self.sort_by_time)
mode = default_mode if sr.selector is None else ReadMode(sr.selector) def __init__(self, identifier: str, default_mode=ReadMode.RLOG, default_source=auto_source, sort_by_time=False):
self.default_mode = default_mode
self.default_source = default_source
self.sort_by_time = sort_by_time
self.lrs = default_source(sr, mode, sort_by_time) self.lrs = self._logreaders_from_identifier(identifier)
def __iter__(self): def __iter__(self):
for lr in self.lrs: for lr in self.lrs:

@ -1,13 +1,17 @@
import shutil
import tempfile
import numpy as np import numpy as np
import unittest import unittest
from parameterized import parameterized from parameterized import parameterized
import requests
from openpilot.tools.lib.route import SegmentRange from openpilot.tools.lib.route import SegmentRange
from openpilot.tools.lib.srreader import ReadMode, SegmentRangeReader, parse_slice from openpilot.tools.lib.srreader import ReadMode, SegmentRangeReader, parse_slice, parse_indirect
NUM_SEGS = 17 # number of segments in the test route NUM_SEGS = 17 # number of segments in the test route
ALL_SEGS = list(np.arange(NUM_SEGS)) ALL_SEGS = list(np.arange(NUM_SEGS))
TEST_ROUTE = "344c5c15b34f2d8a/2024-01-03--09-37-12" TEST_ROUTE = "344c5c15b34f2d8a/2024-01-03--09-37-12"
QLOG_FILE = "https://commadataci.blob.core.windows.net/openpilotci/0375fdf7b1ce594d/2019-06-13--08-32-25/3/qlog.bz2"
class TestSegmentRangeReader(unittest.TestCase): class TestSegmentRangeReader(unittest.TestCase):
@parameterized.expand([ @parameterized.expand([
@ -36,11 +40,23 @@ class TestSegmentRangeReader(unittest.TestCase):
(f"https://cabana.comma.ai/?route={TEST_ROUTE}", ALL_SEGS), (f"https://cabana.comma.ai/?route={TEST_ROUTE}", ALL_SEGS),
(f"cd:/{TEST_ROUTE}", ALL_SEGS), (f"cd:/{TEST_ROUTE}", ALL_SEGS),
]) ])
def test_parse_slice(self, segment_range, expected): def test_indirect_parsing(self, identifier, expected):
sr = SegmentRange(segment_range) parsed, _, _ = parse_indirect(identifier)
sr = SegmentRange(parsed)
segs = parse_slice(sr) segs = parse_slice(sr)
self.assertListEqual(list(segs), expected) self.assertListEqual(list(segs), expected)
def test_direct_parsing(self):
qlog = tempfile.NamedTemporaryFile(mode='wb', delete=False)
with requests.get(QLOG_FILE, stream=True) as r:
with qlog as f:
shutil.copyfileobj(r.raw, f)
for f in [QLOG_FILE, qlog.name]:
l = len(list(SegmentRangeReader(f)))
self.assertGreater(l, 100)
@parameterized.expand([ @parameterized.expand([
(f"{TEST_ROUTE}///",), (f"{TEST_ROUTE}///",),
(f"{TEST_ROUTE}---",), (f"{TEST_ROUTE}---",),

Loading…
Cancel
Save