diff --git a/tools/lib/logreader.py b/tools/lib/logreader.py index ee893c7e15..36cff0d724 100755 --- a/tools/lib/logreader.py +++ b/tools/lib/logreader.py @@ -1,8 +1,9 @@ #!/usr/bin/env python3 import bz2 +from functools import partial +import multiprocessing import capnp import enum -import itertools import numpy as np import os import pathlib @@ -89,7 +90,7 @@ def parse_slice(sr: SegmentRange): s = create_slice_from_string(sr._slice) return segs[s] if isinstance(s, slice) else [segs[s]] -def comma_api_source(sr: SegmentRange, mode=ReadMode.RLOG, **kwargs): +def comma_api_source(sr: SegmentRange, mode=ReadMode.RLOG): segs = parse_slice(sr) route = Route(sr.route_name) @@ -99,40 +100,39 @@ def comma_api_source(sr: SegmentRange, mode=ReadMode.RLOG, **kwargs): assert not len(invalid_segs), f"Some of the requested segments are not available: {invalid_segs}" - for seg in segs: - yield _LogFileReader(log_paths[seg], **kwargs) + return [(log_paths[seg]) for seg in segs] -def internal_source(sr: SegmentRange, mode=ReadMode.RLOG, **kwargs): +def internal_source(sr: SegmentRange, mode=ReadMode.RLOG): segs = parse_slice(sr) - for seg in segs: - yield _LogFileReader(f"cd:/{sr.dongle_id}/{sr.timestamp}/{seg}/{'rlog' if mode == ReadMode.RLOG else 'qlog'}.bz2", **kwargs) + return [f"cd:/{sr.dongle_id}/{sr.timestamp}/{seg}/{'rlog' if mode == ReadMode.RLOG else 'qlog'}.bz2" for seg in segs] -def openpilotci_source(sr: SegmentRange, mode=ReadMode.RLOG, **kwargs): +def openpilotci_source(sr: SegmentRange, mode=ReadMode.RLOG): segs = parse_slice(sr) - for seg in segs: - yield _LogFileReader(get_url(sr.route_name, seg, 'rlog' if mode == ReadMode.RLOG else 'qlog'), **kwargs) + return [get_url(sr.route_name, seg, 'rlog' if mode == ReadMode.RLOG else 'qlog') for seg in segs] -def direct_source(file_or_url, **kwargs): - yield _LogFileReader(file_or_url, **kwargs) +def direct_source(file_or_url): + return [file_or_url] -def auto_source(*args, **kwargs): +def auto_source(*args): # Automatically determine viable source try: - next(internal_source(*args, **kwargs)) - return internal_source(*args, **kwargs) + identifiers = internal_source(*args) + _LogFileReader(identifiers[0]) + return internal_source(*args) except Exception: pass try: - next(openpilotci_source(*args, **kwargs)) - return openpilotci_source(*args, **kwargs) + identifiers = openpilotci_source(*args) + _LogFileReader(identifiers[0]) + return openpilotci_source(*args) except Exception: pass - return comma_api_source(*args, **kwargs) + return comma_api_source(*args) def parse_useradmin(identifier): if "useradmin.comma.ai" in identifier: @@ -161,22 +161,22 @@ def parse_indirect(identifier): class LogReader: - def _logreaders_from_identifier(self, identifier: str | List[str]): + def _parse_identifiers(self, identifier: str | List[str]): if isinstance(identifier, list): - return [LogReader(i) for i in identifier] + return [i for j in identifier for i in self._parse_identifiers(j)] parsed, source, is_indirect = parse_indirect(identifier) if not is_indirect: direct_parsed = parse_direct(identifier) if direct_parsed is not None: - return direct_source(identifier, sort_by_time=self.sort_by_time) + return direct_source(identifier) sr = SegmentRange(parsed) mode = self.default_mode if sr.selector is None else ReadMode(sr.selector) source = self.default_source if source is None else source - return source(sr, mode, sort_by_time=self.sort_by_time, only_union_types=self.only_union_types) + return source(sr, mode) def __init__(self, identifier: str | List[str], default_mode=ReadMode.RLOG, default_source=auto_source, sort_by_time=False, only_union_types=False): self.default_mode = default_mode @@ -190,14 +190,22 @@ class LogReader: def __iter__(self): self.reset() - return self + for identifier in self.logreader_identifiers: + yield from _LogFileReader(identifier) - def __next__(self): - return next(self.chain) + def _run_on_segment(self, func, identifier): + lr = _LogFileReader(identifier) + return func(lr) + + def run_across_segments(self, num_processes, func): + with multiprocessing.Pool(num_processes) as pool: + ret = [] + for p in pool.map(partial(self._run_on_segment, func), self.logreader_identifiers): + ret.extend(p) + return ret def reset(self): - self.lrs = self._logreaders_from_identifier(self.identifier) - self.chain = itertools.chain(*self.lrs) + self.logreader_identifiers = self._parse_identifiers(self.identifier) @staticmethod def from_bytes(dat): diff --git a/tools/plotjuggler/juggle.py b/tools/plotjuggler/juggle.py index b497bdaa39..292db5a50a 100755 --- a/tools/plotjuggler/juggle.py +++ b/tools/plotjuggler/juggle.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -import multiprocessing import os import sys import platform @@ -76,10 +75,7 @@ def process(can, lr): def juggle_route(route_or_segment_name, can, layout, dbc=None): sr = LogReader(route_or_segment_name) - with multiprocessing.Pool(24) as pool: - all_data = [] - for p in pool.map(partial(process, can), sr.lrs): - all_data.extend(p) + all_data = sr.run_across_segments(24, partial(process, can)) # Infer DBC name from logs if dbc is None: diff --git a/tools/plotjuggler/test_plotjuggler.py b/tools/plotjuggler/test_plotjuggler.py index 1cb2dc0674..17287fb803 100755 --- a/tools/plotjuggler/test_plotjuggler.py +++ b/tools/plotjuggler/test_plotjuggler.py @@ -31,6 +31,8 @@ class TestPlotJuggler(unittest.TestCase): self.assertEqual(p.poll(), None) os.killpg(os.getpgid(p.pid), signal.SIGTERM) + self.assertNotIn("Raw file read failed", output) + # TODO: also test that layouts successfully load def test_layouts(self): bad_strings = ( diff --git a/tools/replay/can_replay.py b/tools/replay/can_replay.py index d0a5304cff..8ed6c63aa4 100755 --- a/tools/replay/can_replay.py +++ b/tools/replay/can_replay.py @@ -3,7 +3,6 @@ import argparse import os import time import threading -import multiprocessing os.environ['FILEREADER_CACHE'] = '1' @@ -99,10 +98,7 @@ if __name__ == "__main__": sr = LogReader(args.route_or_segment_name) - with multiprocessing.Pool(24) as pool: - CAN_MSGS = [] - for p in pool.map(process, sr.lrs): - CAN_MSGS.extend(p) + CAN_MSGS = sr.run_across_segments(24, process) print("Finished loading...")