SegmentRange: simplify slice (#31455)

* simplify slicing

* rm
old-commit-hash: 8fe9bc7a69
chrysler-long2
Shane Smiskol 1 year ago committed by GitHub
parent ae65a01afd
commit 5c16ae62d8
  1. 51
      tools/lib/logreader.py
  2. 22
      tools/lib/route.py
  3. 10
      tools/lib/tests/test_logreader.py

@ -4,10 +4,8 @@ from functools import partial
import multiprocessing import multiprocessing
import capnp import capnp
import enum import enum
import numpy as np
import os import os
import pathlib import pathlib
import re
import sys import sys
import tqdm import tqdm
import urllib.parse import urllib.parse
@ -21,7 +19,6 @@ from openpilot.common.swaglog import cloudlog
from openpilot.tools.lib.comma_car_segments import get_url as get_comma_segments_url from openpilot.tools.lib.comma_car_segments import get_url as get_comma_segments_url
from openpilot.tools.lib.openpilotci import get_url from openpilot.tools.lib.openpilotci import get_url
from openpilot.tools.lib.filereader import FileReader, file_exists, internal_source_available from openpilot.tools.lib.filereader import FileReader, file_exists, internal_source_available
from openpilot.tools.lib.helpers import RE
from openpilot.tools.lib.route import Route, SegmentRange from openpilot.tools.lib.route import Route, SegmentRange
LogMessage = Type[capnp._DynamicStructReader] LogMessage = Type[capnp._DynamicStructReader]
@ -79,19 +76,6 @@ class ReadMode(enum.StrEnum):
AUTO_INTERACIVE = "i" # default to rlogs, fallback to qlogs with a prompt from the user AUTO_INTERACIVE = "i" # default to rlogs, fallback to qlogs with a prompt from the user
def create_slice_from_string(s: str):
m = re.fullmatch(RE.SLICE, s)
assert m is not None, f"Invalid slice: {s}"
start, end, step = m.groups()
start = int(start) if start is not None else None
end = int(end) if end is not None else None
step = int(step) if step is not None else None
if start is not None and ":" not in s and end is None and step is None:
return start
return slice(start, end, step)
def default_valid_file(fn): def default_valid_file(fn):
return fn is not None and file_exists(fn) return fn is not None and file_exists(fn)
@ -121,27 +105,11 @@ def apply_strategy(mode: ReadMode, rlog_paths, qlog_paths, valid_file=default_va
return auto_strategy(rlog_paths, qlog_paths, True, valid_file) return auto_strategy(rlog_paths, qlog_paths, True, valid_file)
def parse_slice(sr: SegmentRange):
s = create_slice_from_string(sr._slice)
if isinstance(s, slice):
if s.stop is None or s.stop < 0 or (s.start is not None and s.start < 0): # we need the number of segments in order to parse this slice
segs = np.arange(sr.get_max_seg_number() + 1)
else:
segs = np.arange(s.stop + 1)
return segs[s]
else:
if s < 0:
s = sr.get_max_seg_number() + s + 1
return [s]
def comma_api_source(sr: SegmentRange, mode: ReadMode): def comma_api_source(sr: SegmentRange, mode: ReadMode):
segs = parse_slice(sr)
route = Route(sr.route_name) route = Route(sr.route_name)
rlog_paths = [route.log_paths()[seg] for seg in segs] rlog_paths = [route.log_paths()[seg] for seg in sr.seg_idxs]
qlog_paths = [route.qlog_paths()[seg] for seg in segs] qlog_paths = [route.qlog_paths()[seg] for seg in sr.seg_idxs]
# comma api will have already checked if the file exists # comma api will have already checked if the file exists
def valid_file(fn): def valid_file(fn):
@ -154,30 +122,25 @@ def internal_source(sr: SegmentRange, mode: ReadMode):
if not internal_source_available(): if not internal_source_available():
raise Exception("Internal source not available") raise Exception("Internal source not available")
segs = parse_slice(sr)
def get_internal_url(sr: SegmentRange, seg, file): def get_internal_url(sr: SegmentRange, seg, file):
return f"cd:/{sr.dongle_id}/{sr.timestamp}/{seg}/{file}.bz2" return f"cd:/{sr.dongle_id}/{sr.timestamp}/{seg}/{file}.bz2"
rlog_paths = [get_internal_url(sr, seg, "rlog") for seg in segs] rlog_paths = [get_internal_url(sr, seg, "rlog") for seg in sr.seg_idxs]
qlog_paths = [get_internal_url(sr, seg, "qlog") for seg in segs] qlog_paths = [get_internal_url(sr, seg, "qlog") for seg in sr.seg_idxs]
return apply_strategy(mode, rlog_paths, qlog_paths) return apply_strategy(mode, rlog_paths, qlog_paths)
def openpilotci_source(sr: SegmentRange, mode: ReadMode): def openpilotci_source(sr: SegmentRange, mode: ReadMode):
segs = parse_slice(sr) rlog_paths = [get_url(sr.route_name, seg, "rlog") for seg in sr.seg_idxs]
qlog_paths = [get_url(sr.route_name, seg, "qlog") for seg in sr.seg_idxs]
rlog_paths = [get_url(sr.route_name, seg, "rlog") for seg in segs]
qlog_paths = [get_url(sr.route_name, seg, "qlog") for seg in segs]
return apply_strategy(mode, rlog_paths, qlog_paths) return apply_strategy(mode, rlog_paths, qlog_paths)
def comma_car_segments_source(sr: SegmentRange, mode=ReadMode.RLOG): def comma_car_segments_source(sr: SegmentRange, mode=ReadMode.RLOG):
segs = parse_slice(sr) return [get_comma_segments_url(sr.route_name, seg) for seg in sr.seg_idxs]
return [get_comma_segments_url(sr.route_name, seg) for seg in segs]
def direct_source(file_or_url): def direct_source(file_or_url):

@ -251,9 +251,6 @@ class SegmentRange:
assert m is not None, f"Segment range is not valid {segment_range}" assert m is not None, f"Segment range is not valid {segment_range}"
self.m = m self.m = m
def get_max_seg_number(self):
return get_max_seg_number_cached(self)
@property @property
def route_name(self) -> str: def route_name(self) -> str:
return self.m.group("route_name") return self.m.group("route_name")
@ -270,6 +267,25 @@ class SegmentRange:
def _slice(self) -> str: def _slice(self) -> str:
return self.m.group("slice") return self.m.group("slice")
@property
def seg_idxs(self) -> list[int]:
m = re.fullmatch(RE.SLICE, self._slice)
assert m is not None, f"Invalid slice: {self._slice}"
start, end, step = (None if s is None else int(s) for s in m.groups())
# one segment specified
if start is not None and end is None and ':' not in self._slice:
if start < 0:
start += get_max_seg_number_cached(self) + 1
return [start]
s = slice(start, end, step)
# no specified end or using relative indexing, need number of segments
if end is None or end < 0 or (start is not None and start < 0):
return list(range(get_max_seg_number_cached(self) + 1))[s]
else:
return list(range(end + 1))[s]
@property @property
def selector(self) -> str: def selector(self) -> str:
return self.m.group("selector") return self.m.group("selector")

@ -8,7 +8,7 @@ import requests
from parameterized import parameterized from parameterized import parameterized
from unittest import mock from unittest import mock
from openpilot.tools.lib.logreader import LogIterable, LogReader, comma_api_source, parse_indirect, parse_slice, ReadMode from openpilot.tools.lib.logreader import LogIterable, LogReader, comma_api_source, parse_indirect, ReadMode
from openpilot.tools.lib.route import SegmentRange from openpilot.tools.lib.route import SegmentRange
NUM_SEGS = 17 # number of segments in the test route NUM_SEGS = 17 # number of segments in the test route
@ -50,8 +50,7 @@ class TestLogReader(unittest.TestCase):
def test_indirect_parsing(self, identifier, expected): def test_indirect_parsing(self, identifier, expected):
parsed, _, _ = parse_indirect(identifier) parsed, _, _ = parse_indirect(identifier)
sr = SegmentRange(parsed) sr = SegmentRange(parsed)
segs = parse_slice(sr) self.assertListEqual(list(sr.seg_idxs), expected, identifier)
self.assertListEqual(list(segs), expected)
@parameterized.expand([ @parameterized.expand([
(f"{TEST_ROUTE}", f"{TEST_ROUTE}"), (f"{TEST_ROUTE}", f"{TEST_ROUTE}"),
@ -87,8 +86,7 @@ class TestLogReader(unittest.TestCase):
]) ])
def test_bad_ranges(self, segment_range): def test_bad_ranges(self, segment_range):
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
sr = SegmentRange(segment_range) _ = SegmentRange(segment_range).seg_idxs
parse_slice(sr)
@parameterized.expand([ @parameterized.expand([
(f"{TEST_ROUTE}/0", False), (f"{TEST_ROUTE}/0", False),
@ -100,7 +98,7 @@ class TestLogReader(unittest.TestCase):
def test_slicing_api_call(self, segment_range, api_call): def test_slicing_api_call(self, segment_range, api_call):
with mock.patch("openpilot.tools.lib.route.get_max_seg_number_cached") as max_seg_mock: with mock.patch("openpilot.tools.lib.route.get_max_seg_number_cached") as max_seg_mock:
max_seg_mock.return_value = NUM_SEGS max_seg_mock.return_value = NUM_SEGS
parse_slice(SegmentRange(segment_range)) _ = SegmentRange(segment_range).seg_idxs
self.assertEqual(api_call, max_seg_mock.called) self.assertEqual(api_call, max_seg_mock.called)
@pytest.mark.slow @pytest.mark.slow

Loading…
Cancel
Save