diff --git a/tools/lib/logreader.py b/tools/lib/logreader.py index 29e3acc076..58883c6865 100755 --- a/tools/lib/logreader.py +++ b/tools/lib/logreader.py @@ -12,7 +12,7 @@ import sys import urllib.parse import warnings -from typing import Iterable, Iterator, List, Type +from typing import Dict, Iterable, Iterator, List, Type from urllib.parse import parse_qs, urlparse from cereal import log as capnp_log @@ -232,20 +232,25 @@ class LogReader: self.sort_by_time = sort_by_time self.only_union_types = only_union_types + self.__lrs: Dict[int, _LogFileReader] = {} self.reset() + def _get_lr(self, i): + if i not in self.__lrs: + self.__lrs[i] = _LogFileReader(self.logreader_identifiers[i]) + return self.__lrs[i] + def __iter__(self): - for identifier in self.logreader_identifiers: - yield from _LogFileReader(identifier) + for i in range(len(self.logreader_identifiers)): + yield from self._get_lr(i) - def _run_on_segment(self, func, identifier): - lr = _LogFileReader(identifier) - return func(lr) + def _run_on_segment(self, func, i): + return func(self._get_lr(i)) def run_across_segments(self, num_processes, func): with multiprocessing.Pool(num_processes) as pool: ret = [] - for p in pool.map(partial(self._run_on_segment, func), self.logreader_identifiers): + for p in pool.map(partial(self._run_on_segment, func), range(len(self.logreader_identifiers))): ret.extend(p) return ret diff --git a/tools/lib/tests/test_logreader.py b/tools/lib/tests/test_logreader.py index f7874a3fb3..d8a9c14088 100644 --- a/tools/lib/tests/test_logreader.py +++ b/tools/lib/tests/test_logreader.py @@ -3,8 +3,11 @@ import tempfile import numpy as np import unittest import pytest -from parameterized import parameterized import requests + +from parameterized import parameterized +from unittest import mock + from openpilot.tools.lib.logreader import LogReader, parse_indirect, parse_slice, ReadMode from openpilot.tools.lib.route import SegmentRange @@ -104,11 +107,15 @@ class 
TestLogReader(unittest.TestCase): self.assertEqual(qlog_len*2, qlog_len_2) @pytest.mark.slow - def test_multiple_iterations(self): + @mock.patch("openpilot.tools.lib.logreader._LogFileReader") + def test_multiple_iterations(self, init_mock): lr = LogReader(f"{TEST_ROUTE}/0/q") qlog_len1 = len(list(lr)) qlog_len2 = len(list(lr)) + # ensure _LogFileReader is only instantiated once; creating a second instance would download the files again + self.assertEqual(init_mock.call_count, 1) + self.assertEqual(qlog_len1, qlog_len2)