SegmentRange: type annotations (#31453)

* type annotate SegmentRange

* proper formatting

* oops

* numpy?

format test too

* draft

* fixed

* clean up

* rm

* more

* clean up

* clean up

* rm

* not here

* revert
pull/31456/head
Shane Smiskol 1 year ago committed by GitHub
parent 0846175f44
commit 8276371009
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 25
      tools/lib/route.py
  2. 3
      tools/lib/tests/test_logreader.py

@ -4,7 +4,7 @@ from functools import cache
from urllib.parse import urlparse
from collections import defaultdict
from itertools import chain
from typing import Optional
from typing import Optional, cast
from openpilot.tools.lib.auth_config import get_token
from openpilot.tools.lib.api import CommaApi
@ -237,44 +237,45 @@ class SegmentName:
@cache
def get_max_seg_number_cached(sr: 'SegmentRange'):
def get_max_seg_number_cached(sr: 'SegmentRange') -> int:
try:
api = CommaApi(get_token())
return api.get("/v1/route/" + sr.route_name.replace("/", "|"))["segment_numbers"][-1]
return cast(int, api.get("/v1/route/" + sr.route_name.replace("/", "|"))["segment_numbers"][-1])
except Exception as e:
raise Exception("unable to get max_segment_number. ensure you have access to this route or the route is public.") from e
class SegmentRange:
def __init__(self, segment_range: str):
self.m = re.fullmatch(RE.SEGMENT_RANGE, segment_range)
assert self.m, f"Segment range is not valid {segment_range}"
m = re.fullmatch(RE.SEGMENT_RANGE, segment_range)
assert m is not None, f"Segment range is not valid {segment_range}"
self.m = m
def get_max_seg_number(self):
return get_max_seg_number_cached(self)
@property
def route_name(self):
def route_name(self) -> str:
return self.m.group("route_name")
@property
def dongle_id(self):
def dongle_id(self) -> str:
return self.m.group("dongle_id")
@property
def timestamp(self):
def timestamp(self) -> str:
return self.m.group("timestamp")
@property
def _slice(self):
def _slice(self) -> str:
return self.m.group("slice")
@property
def selector(self):
def selector(self) -> str:
return self.m.group("selector")
def __str__(self):
def __str__(self) -> str:
return f"{self.dongle_id}/{self.timestamp}" + (f"/{self._slice}" if self._slice else "") + (f"/{self.selector}" if self.selector else "")
def __repr__(self):
def __repr__(self) -> str:
return self.__str__()

@ -1,6 +1,5 @@
import shutil
import tempfile
import numpy as np
import os
import unittest
import pytest
@ -13,7 +12,7 @@ from openpilot.tools.lib.logreader import LogIterable, LogReader, comma_api_sour
from openpilot.tools.lib.route import SegmentRange
NUM_SEGS = 17 # number of segments in the test route
ALL_SEGS = list(np.arange(NUM_SEGS))
ALL_SEGS = list(range(NUM_SEGS))
TEST_ROUTE = "344c5c15b34f2d8a/2024-01-03--09-37-12"
QLOG_FILE = "https://commadataci.blob.core.windows.net/openpilotci/0375fdf7b1ce594d/2019-06-13--08-32-25/3/qlog.bz2"

Loading…
Cancel
Save