diff --git a/selfdrive/car/vin.py b/selfdrive/car/vin.py index f69771546f..676593255a 100755 --- a/selfdrive/car/vin.py +++ b/selfdrive/car/vin.py @@ -42,7 +42,7 @@ def get_vin(logcan, sendcan, buses, timeout=0.1, retry=3, debug=False): if vin.startswith(b'\x11'): vin = vin[1:18] - cloudlog.warning(f"got vin with {request=}") + cloudlog.error(f"got vin with {request=}") return get_rx_addr_for_tx_addr(addr), bus, vin.decode() except Exception: cloudlog.exception("VIN query exception") diff --git a/system/webrtc/device/video.py b/system/webrtc/device/video.py index 1ecb6dbd74..314f812834 100644 --- a/system/webrtc/device/video.py +++ b/system/webrtc/device/video.py @@ -5,7 +5,6 @@ import av from teleoprtc.tracks import TiciVideoStreamTrack from cereal import messaging -from openpilot.tools.lib.framereader import FrameReader from openpilot.common.realtime import DT_MDL, DT_DMON @@ -43,27 +42,3 @@ class LiveStreamVideoStreamTrack(TiciVideoStreamTrack): def codec_preference(self) -> Optional[str]: return "H264" - - -class FrameReaderVideoStreamTrack(TiciVideoStreamTrack): - def __init__(self, input_file: str, dt: float = DT_MDL, camera_type: str = "driver"): - super().__init__(camera_type, dt) - - frame_reader = FrameReader(input_file) - self._frames = [frame_reader.get(i, pix_fmt="rgb24") for i in range(frame_reader.frame_count)] - self._frame_count = len(self.frames) - self._frame_index = 0 - self._pts = 0 - - async def recv(self): - self.log_debug("track sending frame %s", self._pts) - img = self._frames[self._frame_index] - - new_frame = av.VideoFrame.from_ndarray(img, format="rgb24") - new_frame.pts = self._pts - new_frame.time_base = self._time_base - - self._frame_index = (self._frame_index + 1) % self._frame_count - self._pts = await self.next_pts(self._pts) - - return new_frame diff --git a/system/webrtc/webrtcd.py b/system/webrtc/webrtcd.py index 12f9328532..cc26d50daf 100755 --- a/system/webrtc/webrtcd.py +++ b/system/webrtc/webrtcd.py @@ -6,34 +6,27 @@ import json import uuid import logging from dataclasses import dataclass, field -from typing import Any, List, Optional, Union +from typing import Any, List, Optional, Union, TYPE_CHECKING # aiortc and its dependencies have lots of internal warnings :( import warnings warnings.filterwarnings("ignore", category=DeprecationWarning) -import aiortc -from aiortc.mediastreams import VideoStreamTrack, AudioStreamTrack -from aiortc.contrib.media import MediaBlackhole -from aiortc.exceptions import InvalidStateError -from aiohttp import web import capnp -from teleoprtc import WebRTCAnswerBuilder -from teleoprtc.info import parse_info_from_offer +from aiohttp import web +if TYPE_CHECKING: + from aiortc.rtcdatachannel import RTCDataChannel -from openpilot.system.webrtc.device.video import LiveStreamVideoStreamTrack -from openpilot.system.webrtc.device.audio import AudioInputStreamTrack, AudioOutputSpeaker from openpilot.system.webrtc.schema import generate_field - from cereal import messaging, log class CerealOutgoingMessageProxy: def __init__(self, sm: messaging.SubMaster): self.sm = sm - self.channels: List[aiortc.RTCDataChannel] = [] + self.channels: List['RTCDataChannel'] = [] - def add_channel(self, channel: aiortc.RTCDataChannel): + def add_channel(self, channel: 'RTCDataChannel'): self.channels.append(channel) def to_json(self, msg_content: Any): @@ -96,6 +89,8 @@ class CerealProxyRunner: self.task = None async def run(self): + from aiortc.exceptions import InvalidStateError + while True: try: self.proxy.update() @@ -109,6 +104,13 @@ class CerealProxyRunner: class StreamSession: def __init__(self, sdp: str, cameras: List[str], incoming_services: List[str], outgoing_services: List[str], debug_mode: bool = False): + from aiortc.mediastreams import VideoStreamTrack, AudioStreamTrack + from aiortc.contrib.media import MediaBlackhole + from openpilot.system.webrtc.device.video import LiveStreamVideoStreamTrack + from openpilot.system.webrtc.device.audio import AudioInputStreamTrack, AudioOutputSpeaker + from teleoprtc import WebRTCAnswerBuilder + from teleoprtc.info import parse_info_from_offer + config = parse_info_from_offer(sdp) builder = WebRTCAnswerBuilder(sdp) @@ -192,7 +194,7 @@ class StreamRequestBody: bridge_services_out: List[str] = field(default_factory=list) -async def get_stream(request: web.Request): +async def get_stream(request: 'web.Request'): stream_dict, debug_mode = request.app['streams'], request.app['debug'] raw_body = await request.json() body = StreamRequestBody(**raw_body) @@ -206,7 +208,7 @@ async def get_stream(request: web.Request): return web.json_response({"sdp": answer.sdp, "type": answer.type}) -async def get_schema(request: web.Request): +async def get_schema(request: 'web.Request'): services = request.query["services"].split(",") services = [s for s in services if s] assert all(s in log.Event.schema.fields and not s.endswith("DEPRECATED") for s in services), "Invalid service name" @@ -214,7 +216,7 @@ async def get_schema(request: web.Request): return web.json_response(schema_dict) -async def on_shutdown(app: web.Application): +async def on_shutdown(app: 'web.Application'): for session in app['streams'].values(): session.stop() del app['streams'] diff --git a/tools/bodyteleop/web.py b/tools/bodyteleop/web.py index 53077af67e..b1fb9525db 100644 --- a/tools/bodyteleop/web.py +++ b/tools/bodyteleop/web.py @@ -6,9 +6,10 @@ import os import ssl import subprocess -from aiohttp import web, ClientSession import pyaudio import wave +from aiohttp import web +from aiohttp import ClientSession from openpilot.common.basedir import BASEDIR from openpilot.system.webrtc.webrtcd import StreamRequestBody @@ -22,7 +23,7 @@ WEBRTCD_HOST, WEBRTCD_PORT = "localhost", 5001 ## UTILS -async def play_sound(sound): +async def play_sound(sound: str): SOUNDS = { "engage": "selfdrive/assets/sounds/engage.wav", "disengage": "selfdrive/assets/sounds/disengage.wav", @@ -51,7 +52,7 @@ async def play_sound(sound): p.terminate() ## SSL -def create_ssl_cert(cert_path, key_path): +def create_ssl_cert(cert_path: str, key_path: str): try: proc = subprocess.run(f'openssl req -x509 -newkey rsa:4096 -nodes -out {cert_path} -keyout {key_path} \ -days 365 -subj "/C=US/ST=California/O=commaai/OU=comma body"', @@ -75,17 +76,17 @@ def create_ssl_context(): return ssl_context ## ENDPOINTS -async def index(request): +async def index(request: 'web.Request'): with open(os.path.join(TELEOPDIR, "static", "index.html"), "r") as f: content = f.read() return web.Response(content_type="text/html", text=content) -async def ping(request): +async def ping(request: 'web.Request'): return web.Response(text="pong") -async def sound(request): +async def sound(request: 'web.Request'): params = await request.json() sound_to_play = params["sound"] @@ -93,7 +94,7 @@ async def sound(request): return web.json_response({"status": "ok"}) -async def offer(request): +async def offer(request: 'web.Request'): params = await request.json() body = StreamRequestBody(params["sdp"], ["driver"], ["testJoystick"], ["carState"]) body_json = json.dumps(dataclasses.asdict(body)) diff --git a/tools/lib/tests/test_logreader.py b/tools/lib/tests/test_logreader.py index 676d2bbadf..7a8ea20b76 100644 --- a/tools/lib/tests/test_logreader.py +++ b/tools/lib/tests/test_logreader.py @@ -1,6 +1,7 @@ import shutil import tempfile import numpy as np +import os import unittest import pytest import requests @@ -8,7 +9,7 @@ import requests from parameterized import parameterized from unittest import mock -from openpilot.tools.lib.logreader import LogReader, parse_indirect, parse_slice, ReadMode +from openpilot.tools.lib.logreader import LogIterable, LogReader, parse_indirect, parse_slice, ReadMode from openpilot.tools.lib.route import SegmentRange NUM_SEGS = 17 # number of segments in the test route @@ -16,6 +17,11 @@ ALL_SEGS = list(np.arange(NUM_SEGS)) TEST_ROUTE = "344c5c15b34f2d8a/2024-01-03--09-37-12" QLOG_FILE = "https://commadataci.blob.core.windows.net/openpilotci/0375fdf7b1ce594d/2019-06-13--08-32-25/3/qlog.bz2" + +def noop(segment: LogIterable): + return segment + + class TestLogReader(unittest.TestCase): @parameterized.expand([ (f"{TEST_ROUTE}", ALL_SEGS), @@ -124,6 +130,13 @@ class TestLogReader(unittest.TestCase): self.assertEqual(lr.first("carParams").carFingerprint, "SUBARU OUTBACK 6TH GEN") self.assertTrue(0 < len(list(lr.filter("carParams"))) < len(list(lr))) + @parameterized.expand([(True,), (False,)]) + @pytest.mark.slow + def test_run_across_segments(self, cache_enabled): + os.environ["FILEREADER_CACHE"] = "1" if cache_enabled else "0" + lr = LogReader(f"{TEST_ROUTE}/0:4") + self.assertEqual(len(lr.run_across_segments(4, noop)), len(list(lr))) + if __name__ == "__main__": unittest.main() diff --git a/tools/lib/url_file.py b/tools/lib/url_file.py index 5c6f187eee..be9c815c93 100644 --- a/tools/lib/url_file.py +++ b/tools/lib/url_file.py @@ -6,6 +6,7 @@ from hashlib import sha256 from urllib3 import PoolManager from urllib3.util import Timeout from tenacity import retry, wait_random_exponential, stop_after_attempt +from typing import Optional from openpilot.common.file_helpers import atomic_write_in_dir from openpilot.system.hardware.hw import Paths @@ -25,9 +26,12 @@ class URLFileException(Exception): class URLFile: - _tlocal = threading.local() + _pid: Optional[int] = None + _pool_manager: Optional[PoolManager] = None + _pool_manager_lock = threading.Lock() def __init__(self, url, debug=False, cache=None): + self._pool_manager = None self._url = url self._pos = 0 self._length = None @@ -41,11 +45,6 @@ class URLFile: if not self._force_download: os.makedirs(Paths.download_cache_root(), exist_ok=True) - try: - self._http_client = URLFile._tlocal.http_client - except AttributeError: - self._http_client = URLFile._tlocal.http_client = PoolManager() - def __enter__(self): return self @@ -55,10 +54,20 @@ class URLFile: self._local_file.close() self._local_file = None + def _http_client(self) -> PoolManager: + if self._pool_manager is None: + pid = os.getpid() + with URLFile._pool_manager_lock: + if URLFile._pid != pid or URLFile._pool_manager is None: # unsafe to share after fork + URLFile._pid = pid + URLFile._pool_manager = PoolManager(num_pools=10, maxsize=10) + self._pool_manager = URLFile._pool_manager + return self._pool_manager + @retry(wait=wait_random_exponential(multiplier=1, max=5), stop=stop_after_attempt(3), reraise=True) def get_length_online(self): timeout = Timeout(connect=50.0, read=500.0) - response = self._http_client.request('HEAD', self._url, timeout=timeout, preload_content=False) + response = self._http_client().request('HEAD', self._url, timeout=timeout, preload_content=False) if not (200 <= response.status <= 299): return -1 length = response.headers.get('content-length', 0) @@ -131,7 +140,7 @@ class URLFile: t1 = time.time() timeout = Timeout(connect=50.0, read=500.0) - response = self._http_client.request('GET', self._url, timeout=timeout, preload_content=False, headers=headers) + response = self._http_client().request('GET', self._url, timeout=timeout, preload_content=False, headers=headers) ret = response.data if self._debug: