SegmentRange: simplify slice (#31455)

* simplify slicing

* rm
pull/31451/head
Shane Smiskol 1 year ago committed by GitHub
parent c4f7991bb6
commit 8fe9bc7a69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 51
      tools/lib/logreader.py
  2. 22
      tools/lib/route.py
  3. 10
      tools/lib/tests/test_logreader.py

@ -4,10 +4,8 @@ from functools import partial
import multiprocessing
import capnp
import enum
import numpy as np
import os
import pathlib
import re
import sys
import tqdm
import urllib.parse
@ -21,7 +19,6 @@ from openpilot.common.swaglog import cloudlog
from openpilot.tools.lib.comma_car_segments import get_url as get_comma_segments_url
from openpilot.tools.lib.openpilotci import get_url
from openpilot.tools.lib.filereader import FileReader, file_exists, internal_source_available
from openpilot.tools.lib.helpers import RE
from openpilot.tools.lib.route import Route, SegmentRange
LogMessage = Type[capnp._DynamicStructReader]
@ -79,19 +76,6 @@ class ReadMode(enum.StrEnum):
AUTO_INTERACIVE = "i" # default to rlogs, fallback to qlogs with a prompt from the user
def create_slice_from_string(s: str):
m = re.fullmatch(RE.SLICE, s)
assert m is not None, f"Invalid slice: {s}"
start, end, step = m.groups()
start = int(start) if start is not None else None
end = int(end) if end is not None else None
step = int(step) if step is not None else None
if start is not None and ":" not in s and end is None and step is None:
return start
return slice(start, end, step)
def default_valid_file(fn):
return fn is not None and file_exists(fn)
@ -121,27 +105,11 @@ def apply_strategy(mode: ReadMode, rlog_paths, qlog_paths, valid_file=default_va
return auto_strategy(rlog_paths, qlog_paths, True, valid_file)
def parse_slice(sr: SegmentRange):
s = create_slice_from_string(sr._slice)
if isinstance(s, slice):
if s.stop is None or s.stop < 0 or (s.start is not None and s.start < 0): # we need the number of segments in order to parse this slice
segs = np.arange(sr.get_max_seg_number() + 1)
else:
segs = np.arange(s.stop + 1)
return segs[s]
else:
if s < 0:
s = sr.get_max_seg_number() + s + 1
return [s]
def comma_api_source(sr: SegmentRange, mode: ReadMode):
segs = parse_slice(sr)
route = Route(sr.route_name)
rlog_paths = [route.log_paths()[seg] for seg in segs]
qlog_paths = [route.qlog_paths()[seg] for seg in segs]
rlog_paths = [route.log_paths()[seg] for seg in sr.seg_idxs]
qlog_paths = [route.qlog_paths()[seg] for seg in sr.seg_idxs]
# comma api will have already checked if the file exists
def valid_file(fn):
@ -154,30 +122,25 @@ def internal_source(sr: SegmentRange, mode: ReadMode):
if not internal_source_available():
raise Exception("Internal source not available")
segs = parse_slice(sr)
def get_internal_url(sr: SegmentRange, seg, file):
return f"cd:/{sr.dongle_id}/{sr.timestamp}/{seg}/{file}.bz2"
rlog_paths = [get_internal_url(sr, seg, "rlog") for seg in segs]
qlog_paths = [get_internal_url(sr, seg, "qlog") for seg in segs]
rlog_paths = [get_internal_url(sr, seg, "rlog") for seg in sr.seg_idxs]
qlog_paths = [get_internal_url(sr, seg, "qlog") for seg in sr.seg_idxs]
return apply_strategy(mode, rlog_paths, qlog_paths)
def openpilotci_source(sr: SegmentRange, mode: ReadMode):
segs = parse_slice(sr)
rlog_paths = [get_url(sr.route_name, seg, "rlog") for seg in segs]
qlog_paths = [get_url(sr.route_name, seg, "qlog") for seg in segs]
rlog_paths = [get_url(sr.route_name, seg, "rlog") for seg in sr.seg_idxs]
qlog_paths = [get_url(sr.route_name, seg, "qlog") for seg in sr.seg_idxs]
return apply_strategy(mode, rlog_paths, qlog_paths)
def comma_car_segments_source(sr: SegmentRange, mode=ReadMode.RLOG):
segs = parse_slice(sr)
return [get_comma_segments_url(sr.route_name, seg) for seg in sr.seg_idxs]
return [get_comma_segments_url(sr.route_name, seg) for seg in segs]
def direct_source(file_or_url):

@ -251,9 +251,6 @@ class SegmentRange:
assert m is not None, f"Segment range is not valid {segment_range}"
self.m = m
def get_max_seg_number(self):
return get_max_seg_number_cached(self)
@property
def route_name(self) -> str:
return self.m.group("route_name")
@ -270,6 +267,25 @@ class SegmentRange:
def _slice(self) -> str:
return self.m.group("slice")
@property
def seg_idxs(self) -> list[int]:
m = re.fullmatch(RE.SLICE, self._slice)
assert m is not None, f"Invalid slice: {self._slice}"
start, end, step = (None if s is None else int(s) for s in m.groups())
# one segment specified
if start is not None and end is None and ':' not in self._slice:
if start < 0:
start += get_max_seg_number_cached(self) + 1
return [start]
s = slice(start, end, step)
# no specified end or using relative indexing, need number of segments
if end is None or end < 0 or (start is not None and start < 0):
return list(range(get_max_seg_number_cached(self) + 1))[s]
else:
return list(range(end + 1))[s]
@property
def selector(self) -> str:
return self.m.group("selector")

@ -8,7 +8,7 @@ import requests
from parameterized import parameterized
from unittest import mock
from openpilot.tools.lib.logreader import LogIterable, LogReader, comma_api_source, parse_indirect, parse_slice, ReadMode
from openpilot.tools.lib.logreader import LogIterable, LogReader, comma_api_source, parse_indirect, ReadMode
from openpilot.tools.lib.route import SegmentRange
NUM_SEGS = 17 # number of segments in the test route
@ -50,8 +50,7 @@ class TestLogReader(unittest.TestCase):
def test_indirect_parsing(self, identifier, expected):
parsed, _, _ = parse_indirect(identifier)
sr = SegmentRange(parsed)
segs = parse_slice(sr)
self.assertListEqual(list(segs), expected)
self.assertListEqual(list(sr.seg_idxs), expected, identifier)
@parameterized.expand([
(f"{TEST_ROUTE}", f"{TEST_ROUTE}"),
@ -87,8 +86,7 @@ class TestLogReader(unittest.TestCase):
])
def test_bad_ranges(self, segment_range):
with self.assertRaises(AssertionError):
sr = SegmentRange(segment_range)
parse_slice(sr)
_ = SegmentRange(segment_range).seg_idxs
@parameterized.expand([
(f"{TEST_ROUTE}/0", False),
@ -100,7 +98,7 @@ class TestLogReader(unittest.TestCase):
def test_slicing_api_call(self, segment_range, api_call):
with mock.patch("openpilot.tools.lib.route.get_max_seg_number_cached") as max_seg_mock:
max_seg_mock.return_value = NUM_SEGS
parse_slice(SegmentRange(segment_range))
_ = SegmentRange(segment_range).seg_idxs
self.assertEqual(api_call, max_seg_mock.called)
@pytest.mark.slow

Loading…
Cancel
Save