auto_source: optimize api calls and use head to determine if file exists (#31025)

* fast

* catch all

* source

* fix file_exists

* remove duplicate reset

* test multiple loops

* iterations

* cleanup imports
pull/31028/head
Justin Newberry 1 year ago committed by GitHub
parent 5c24527683
commit 2967cada71
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 8
      tools/lib/filereader.py
  2. 46
      tools/lib/logreader.py
  3. 13
      tools/lib/tests/test_logreader.py

@ -1,4 +1,6 @@
import os
import requests
from openpilot.tools.lib.url_file import URLFile
DATA_ENDPOINT = os.getenv("DATA_ENDPOINT", "http://data-raw.comma.internal/")
@ -8,6 +10,12 @@ def resolve_name(fn):
return fn.replace("cd:/", DATA_ENDPOINT)
return fn
def file_exists(fn):
fn = resolve_name(fn)
if fn.startswith(("http://", "https://")):
return requests.head(fn, allow_redirects=True).status_code == 200
return os.path.exists(fn)
def FileReader(fn, debug=False):
fn = resolve_name(fn)
if fn.startswith(("http://", "https://")):

@ -17,7 +17,7 @@ from urllib.parse import parse_qs, urlparse
from cereal import log as capnp_log
from openpilot.tools.lib.openpilotci import get_url
from openpilot.tools.lib.filereader import FileReader
from openpilot.tools.lib.filereader import FileReader, file_exists
from openpilot.tools.lib.helpers import RE
from openpilot.tools.lib.route import Route, SegmentRange
@ -84,15 +84,13 @@ def create_slice_from_string(s: str):
return start
return slice(start, end, step)
def parse_slice(sr: SegmentRange):
route = Route(sr.route_name)
def parse_slice(sr: SegmentRange, route: Route):
segs = np.arange(route.max_seg_number+1)
s = create_slice_from_string(sr._slice)
return segs[s] if isinstance(s, slice) else [segs[s]]
def comma_api_source(sr: SegmentRange, mode=ReadMode.RLOG):
segs = parse_slice(sr)
route = Route(sr.route_name)
def comma_api_source(sr: SegmentRange, route: Route, mode=ReadMode.RLOG):
segs = parse_slice(sr, route)
log_paths = route.log_paths() if mode == ReadMode.RLOG else route.qlog_paths()
@ -102,35 +100,33 @@ def comma_api_source(sr: SegmentRange, mode=ReadMode.RLOG):
return [(log_paths[seg]) for seg in segs]
def internal_source(sr: SegmentRange, mode=ReadMode.RLOG):
segs = parse_slice(sr)
def internal_source(sr: SegmentRange, route: Route, mode=ReadMode.RLOG):
segs = parse_slice(sr, route)
return [f"cd:/{sr.dongle_id}/{sr.timestamp}/{seg}/{'rlog' if mode == ReadMode.RLOG else 'qlog'}.bz2" for seg in segs]
def openpilotci_source(sr: SegmentRange, mode=ReadMode.RLOG):
segs = parse_slice(sr)
def openpilotci_source(sr: SegmentRange, route: Route, mode=ReadMode.RLOG):
segs = parse_slice(sr, route)
return [get_url(sr.route_name, seg, 'rlog' if mode == ReadMode.RLOG else 'qlog') for seg in segs]
def direct_source(file_or_url):
return [file_or_url]
def auto_source(*args):
# Automatically determine viable source
def check_source(source, *args):
try:
identifiers = internal_source(*args)
_LogFileReader(identifiers[0])
return internal_source(*args)
files = source(*args)
assert all(file_exists(f) for f in files)
return True, files
except Exception:
pass
return False, None
try:
identifiers = openpilotci_source(*args)
_LogFileReader(identifiers[0])
return openpilotci_source(*args)
except Exception:
pass
def auto_source(*args):
# Automatically determine viable source
for source in [internal_source, openpilotci_source]:
valid, ret = check_source(source, *args)
if valid:
return ret
return comma_api_source(*args)
@ -173,10 +169,11 @@ class LogReader:
return direct_source(identifier)
sr = SegmentRange(parsed)
route = Route(sr.route_name)
mode = self.default_mode if sr.selector is None else ReadMode(sr.selector)
source = self.default_source if source is None else source
return source(sr, mode)
return source(sr, route, mode)
def __init__(self, identifier: str | List[str], default_mode=ReadMode.RLOG, default_source=auto_source, sort_by_time=False, only_union_types=False):
self.default_mode = default_mode
@ -189,7 +186,6 @@ class LogReader:
self.reset()
def __iter__(self):
self.reset()
for identifier in self.logreader_identifiers:
yield from _LogFileReader(identifier)

@ -6,7 +6,7 @@ import pytest
from parameterized import parameterized
import requests
from openpilot.tools.lib.logreader import LogReader, parse_indirect, parse_slice, ReadMode
from openpilot.tools.lib.route import SegmentRange
from openpilot.tools.lib.route import Route, SegmentRange
NUM_SEGS = 17 # number of segments in the test route
ALL_SEGS = list(np.arange(NUM_SEGS))
@ -42,7 +42,8 @@ class TestLogReader(unittest.TestCase):
def test_indirect_parsing(self, identifier, expected):
parsed, _, _ = parse_indirect(identifier)
sr = SegmentRange(parsed)
segs = parse_slice(sr)
route = Route(sr.route_name)
segs = parse_slice(sr, route)
self.assertListEqual(list(segs), expected)
def test_direct_parsing(self):
@ -91,6 +92,14 @@ class TestLogReader(unittest.TestCase):
self.assertEqual(qlog_len*2, qlog_len_2)
@pytest.mark.slow
def test_multiple_iterations(self):
lr = LogReader(f"{TEST_ROUTE}/0/q")
qlog_len1 = len(list(lr))
qlog_len2 = len(list(lr))
self.assertEqual(qlog_len1, qlog_len2)
if __name__ == "__main__":
unittest.main()

Loading…
Cancel
Save