diff --git a/tools/lib/route.py b/tools/lib/route.py index 529e42e8e6..471fa2226f 100644 --- a/tools/lib/route.py +++ b/tools/lib/route.py @@ -4,7 +4,7 @@ from functools import cache from urllib.parse import urlparse from collections import defaultdict from itertools import chain -from typing import Optional +from typing import Optional, cast from openpilot.tools.lib.auth_config import get_token from openpilot.tools.lib.api import CommaApi @@ -237,44 +237,45 @@ class SegmentName: @cache -def get_max_seg_number_cached(sr: 'SegmentRange'): +def get_max_seg_number_cached(sr: 'SegmentRange') -> int: try: api = CommaApi(get_token()) - return api.get("/v1/route/" + sr.route_name.replace("/", "|"))["segment_numbers"][-1] + return cast(int, api.get("/v1/route/" + sr.route_name.replace("/", "|"))["segment_numbers"][-1]) except Exception as e: raise Exception("unable to get max_segment_number. ensure you have access to this route or the route is public.") from e class SegmentRange: def __init__(self, segment_range: str): - self.m = re.fullmatch(RE.SEGMENT_RANGE, segment_range) - assert self.m, f"Segment range is not valid {segment_range}" + m = re.fullmatch(RE.SEGMENT_RANGE, segment_range) + assert m is not None, f"Segment range is not valid {segment_range}" + self.m = m def get_max_seg_number(self): return get_max_seg_number_cached(self) @property - def route_name(self): + def route_name(self) -> str: return self.m.group("route_name") @property - def dongle_id(self): + def dongle_id(self) -> str: return self.m.group("dongle_id") @property - def timestamp(self): + def timestamp(self) -> str: return self.m.group("timestamp") @property - def _slice(self): + def _slice(self) -> str: return self.m.group("slice") @property - def selector(self): + def selector(self) -> str: return self.m.group("selector") - def __str__(self): + def __str__(self) -> str: return f"{self.dongle_id}/{self.timestamp}" + (f"/{self._slice}" if self._slice else "") + (f"/{self.selector}" if self.selector else "") - def __repr__(self): + def __repr__(self) -> str: return self.__str__() diff --git a/tools/lib/tests/test_logreader.py b/tools/lib/tests/test_logreader.py index 5131835017..c21c94342c 100644 --- a/tools/lib/tests/test_logreader.py +++ b/tools/lib/tests/test_logreader.py @@ -1,6 +1,5 @@ import shutil import tempfile -import numpy as np import os import unittest import pytest @@ -13,7 +12,7 @@ from openpilot.tools.lib.logreader import LogIterable, LogReader, comma_api_sour from openpilot.tools.lib.route import SegmentRange NUM_SEGS = 17 # number of segments in the test route -ALL_SEGS = list(np.arange(NUM_SEGS)) +ALL_SEGS = list(range(NUM_SEGS)) TEST_ROUTE = "344c5c15b34f2d8a/2024-01-03--09-37-12" QLOG_FILE = "https://commadataci.blob.core.windows.net/openpilotci/0375fdf7b1ce594d/2019-06-13--08-32-25/3/qlog.bz2"