LogReader: no redownloading on multiple iterations (#31141)

* no redownload

* sort
pull/30934/head
Justin Newberry 1 year ago committed by GitHub
parent b9ad854451
commit 88dcaa51c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 19
      tools/lib/logreader.py
  2. 11
      tools/lib/tests/test_logreader.py

@ -12,7 +12,7 @@ import sys
import urllib.parse import urllib.parse
import warnings import warnings
from typing import Iterable, Iterator, List, Type from typing import Dict, Iterable, Iterator, List, Type
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
from cereal import log as capnp_log from cereal import log as capnp_log
@ -232,20 +232,25 @@ class LogReader:
self.sort_by_time = sort_by_time self.sort_by_time = sort_by_time
self.only_union_types = only_union_types self.only_union_types = only_union_types
self.__lrs: Dict[int, _LogFileReader] = {}
self.reset() self.reset()
# Return the _LogFileReader for index i into self.logreader_identifiers.
# The reader is constructed on the first request for that index and stored in
# self.__lrs; later requests for the same index return the stored instance,
# so iterating the LogReader more than once does not build the reader again.
def _get_lr(self, i):
if i not in self.__lrs:
# first request for this index: build the reader from its identifier and cache it
self.__lrs[i] = _LogFileReader(self.logreader_identifiers[i])
return self.__lrs[i]
def __iter__(self): def __iter__(self):
for identifier in self.logreader_identifiers: for i in range(len(self.logreader_identifiers)):
yield from _LogFileReader(identifier) yield from self._get_lr(i)
def _run_on_segment(self, func, identifier): def _run_on_segment(self, func, i):
lr = _LogFileReader(identifier) return func(self._get_lr(i))
return func(lr)
def run_across_segments(self, num_processes, func): def run_across_segments(self, num_processes, func):
with multiprocessing.Pool(num_processes) as pool: with multiprocessing.Pool(num_processes) as pool:
ret = [] ret = []
for p in pool.map(partial(self._run_on_segment, func), self.logreader_identifiers): for p in pool.map(partial(self._run_on_segment, func), range(len(self.logreader_identifiers))):
ret.extend(p) ret.extend(p)
return ret return ret

@ -3,8 +3,11 @@ import tempfile
import numpy as np import numpy as np
import unittest import unittest
import pytest import pytest
from parameterized import parameterized
import requests import requests
from parameterized import parameterized
from unittest import mock
from openpilot.tools.lib.logreader import LogReader, parse_indirect, parse_slice, ReadMode from openpilot.tools.lib.logreader import LogReader, parse_indirect, parse_slice, ReadMode
from openpilot.tools.lib.route import SegmentRange from openpilot.tools.lib.route import SegmentRange
@ -104,11 +107,15 @@ class TestLogReader(unittest.TestCase):
self.assertEqual(qlog_len*2, qlog_len_2) self.assertEqual(qlog_len*2, qlog_len_2)
@pytest.mark.slow @pytest.mark.slow
def test_multiple_iterations(self): @mock.patch("openpilot.tools.lib.logreader._LogFileReader")
def test_multiple_iterations(self, init_mock):
lr = LogReader(f"{TEST_ROUTE}/0/q") lr = LogReader(f"{TEST_ROUTE}/0/q")
qlog_len1 = len(list(lr)) qlog_len1 = len(list(lr))
qlog_len2 = len(list(lr)) qlog_len2 = len(list(lr))
# ensure we don't create multiple instances of _LogFileReader, which would mean downloading the files twice
self.assertEqual(init_mock.call_count, 1)
self.assertEqual(qlog_len1, qlog_len2) self.assertEqual(qlog_len1, qlog_len2)

Loading…
Cancel
Save