LogReader: support passing list of sources (#35749)

* far too long

* this is a better experience

* no rename for now
pull/35748/head^2
Shane Smiskol 1 week ago committed by GitHub
parent 54da96dbdf
commit c553c1f872
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 9
      selfdrive/car/tests/test_models.py
  2. 17
      tools/lib/logreader.py

@ -4,7 +4,6 @@ import pytest
import random
import unittest # noqa: TID251
from collections import defaultdict, Counter
from functools import partial
import hypothesis.strategies as st
from hypothesis import Phase, given, settings
from parameterized import parameterized_class
@ -23,7 +22,7 @@ from openpilot.selfdrive.pandad import can_capnp_to_list
from openpilot.selfdrive.test.helpers import read_segment_list
from openpilot.system.hardware.hw import DEFAULT_DOWNLOAD_CACHE_ROOT
from openpilot.tools.lib.logreader import LogReader, LogsUnavailable, openpilotci_source_zst, openpilotci_source, internal_source, \
internal_source_zst, comma_api_source, auto_source
internal_source_zst, comma_api_source
from openpilot.tools.lib.route import SegmentName
SafetyModel = car.CarParams.SafetyModel
@ -125,9 +124,9 @@ class TestCarModelBase(unittest.TestCase):
segment_range = f"{cls.test_route.route}/{seg}"
try:
source = partial(auto_source, sources=[internal_source, internal_source_zst] if len(INTERNAL_SEG_LIST) else \
[openpilotci_source_zst, openpilotci_source, comma_api_source])
lr = LogReader(segment_range, source=source, sort_by_time=True)
sources = ([internal_source, internal_source_zst] if len(INTERNAL_SEG_LIST) else
[openpilotci_source_zst, openpilotci_source, comma_api_source])
lr = LogReader(segment_range, sources=sources, sort_by_time=True)
return cls.get_testing_data_from_logreader(lr)
except (LogsUnavailable, AssertionError):
pass

@ -186,7 +186,7 @@ def openpilotci_source_zst(sr: SegmentRange, mode: ReadMode) -> list[LogPath]:
return openpilotci_source(sr, mode, "zst")
def comma_car_segments_source(sr: SegmentRange, mode=ReadMode.RLOG) -> list[LogPath]:
def comma_car_segments_source(sr: SegmentRange, mode: ReadMode = ReadMode.RLOG) -> list[LogPath]:
return [get_comma_segments_url(sr.route_name, seg) for seg in sr.seg_idxs]
@ -213,13 +213,10 @@ def check_source(source: Source, *args) -> list[LogPath]:
return files
def auto_source(sr: SegmentRange, mode=ReadMode.RLOG, sources: list[Source] = None) -> list[LogPath]:
def auto_source(sr: SegmentRange, sources: list[Source], mode: ReadMode = ReadMode.RLOG) -> list[LogPath]:
if mode == ReadMode.SANITIZED:
return comma_car_segments_source(sr, mode)
if sources is None:
sources = [internal_source, internal_source_zst, openpilotci_source, openpilotci_source_zst,
comma_api_source, comma_car_segments_source, testing_closet_source]
exceptions = {}
# for automatic fallback modes, auto_source needs to first check if rlogs exist for any source
@ -267,7 +264,7 @@ class LogReader:
sr = SegmentRange(identifier)
mode = self.default_mode if sr.selector is None else ReadMode(sr.selector)
identifiers = self.source(sr, mode)
identifiers = auto_source(sr, self.sources, mode)
invalid_count = len(list(get_invalid_files(identifiers)))
assert invalid_count == 0, (f"{invalid_count}/{len(identifiers)} invalid log(s) found, please ensure all logs " +
@ -275,9 +272,13 @@ class LogReader:
return identifiers
def __init__(self, identifier: str | list[str], default_mode: ReadMode = ReadMode.RLOG,
source: Source = auto_source, sort_by_time=False, only_union_types=False):
sources: list[Source] = None, sort_by_time=False, only_union_types=False):
if sources is None:
sources = [internal_source, internal_source_zst, openpilotci_source, openpilotci_source_zst,
comma_api_source, comma_car_segments_source, testing_closet_source]
self.default_mode = default_mode
self.source = source
self.sources = sources
self.identifier = identifier
if isinstance(identifier, str):
self.identifier = [identifier]

Loading…
Cancel
Save