From 35d848ad5290ed9c0747b0ac20a5affab6a561d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20R=C4=85czy?= Date: Mon, 5 Feb 2024 19:04:59 -0800 Subject: [PATCH] webrtcd: lazy import of aiortc (#31304) * Lazy imports in webrtcd * Lazy imports in web.py * Type hints * Remove FrameReaderVideoStreamTrack * leave the aiohttp.web import * Leave the client session * main leftover --- system/webrtc/device/video.py | 25 ------------------------- system/webrtc/webrtcd.py | 34 ++++++++++++++++++---------------- tools/bodyteleop/web.py | 15 ++++++++------- 3 files changed, 26 insertions(+), 48 deletions(-) diff --git a/system/webrtc/device/video.py b/system/webrtc/device/video.py index 1ecb6dbd74..314f812834 100644 --- a/system/webrtc/device/video.py +++ b/system/webrtc/device/video.py @@ -5,7 +5,6 @@ import av from teleoprtc.tracks import TiciVideoStreamTrack from cereal import messaging -from openpilot.tools.lib.framereader import FrameReader from openpilot.common.realtime import DT_MDL, DT_DMON @@ -43,27 +42,3 @@ class LiveStreamVideoStreamTrack(TiciVideoStreamTrack): def codec_preference(self) -> Optional[str]: return "H264" - - -class FrameReaderVideoStreamTrack(TiciVideoStreamTrack): - def __init__(self, input_file: str, dt: float = DT_MDL, camera_type: str = "driver"): - super().__init__(camera_type, dt) - - frame_reader = FrameReader(input_file) - self._frames = [frame_reader.get(i, pix_fmt="rgb24") for i in range(frame_reader.frame_count)] - self._frame_count = len(self.frames) - self._frame_index = 0 - self._pts = 0 - - async def recv(self): - self.log_debug("track sending frame %s", self._pts) - img = self._frames[self._frame_index] - - new_frame = av.VideoFrame.from_ndarray(img, format="rgb24") - new_frame.pts = self._pts - new_frame.time_base = self._time_base - - self._frame_index = (self._frame_index + 1) % self._frame_count - self._pts = await self.next_pts(self._pts) - - return new_frame diff --git a/system/webrtc/webrtcd.py b/system/webrtc/webrtcd.py index 12f9328532..cc26d50daf 100755 --- a/system/webrtc/webrtcd.py +++ b/system/webrtc/webrtcd.py @@ -6,34 +6,27 @@ import json import uuid import logging from dataclasses import dataclass, field -from typing import Any, List, Optional, Union +from typing import Any, List, Optional, Union, TYPE_CHECKING # aiortc and its dependencies have lots of internal warnings :( import warnings warnings.filterwarnings("ignore", category=DeprecationWarning) -import aiortc -from aiortc.mediastreams import VideoStreamTrack, AudioStreamTrack -from aiortc.contrib.media import MediaBlackhole -from aiortc.exceptions import InvalidStateError -from aiohttp import web import capnp -from teleoprtc import WebRTCAnswerBuilder -from teleoprtc.info import parse_info_from_offer +from aiohttp import web +if TYPE_CHECKING: + from aiortc.rtcdatachannel import RTCDataChannel -from openpilot.system.webrtc.device.video import LiveStreamVideoStreamTrack -from openpilot.system.webrtc.device.audio import AudioInputStreamTrack, AudioOutputSpeaker from openpilot.system.webrtc.schema import generate_field - from cereal import messaging, log class CerealOutgoingMessageProxy: def __init__(self, sm: messaging.SubMaster): self.sm = sm - self.channels: List[aiortc.RTCDataChannel] = [] + self.channels: List['RTCDataChannel'] = [] - def add_channel(self, channel: aiortc.RTCDataChannel): + def add_channel(self, channel: 'RTCDataChannel'): self.channels.append(channel) def to_json(self, msg_content: Any): @@ -96,6 +89,8 @@ class CerealProxyRunner: self.task = None async def run(self): + from aiortc.exceptions import InvalidStateError + while True: try: self.proxy.update() @@ -109,6 +104,13 @@ class CerealProxyRunner: class StreamSession: def __init__(self, sdp: str, cameras: List[str], incoming_services: List[str], outgoing_services: List[str], debug_mode: bool = False): + from aiortc.mediastreams import VideoStreamTrack, AudioStreamTrack + from aiortc.contrib.media import MediaBlackhole + from openpilot.system.webrtc.device.video import LiveStreamVideoStreamTrack + from openpilot.system.webrtc.device.audio import AudioInputStreamTrack, AudioOutputSpeaker + from teleoprtc import WebRTCAnswerBuilder + from teleoprtc.info import parse_info_from_offer + config = parse_info_from_offer(sdp) builder = WebRTCAnswerBuilder(sdp) @@ -192,7 +194,7 @@ class StreamRequestBody: bridge_services_out: List[str] = field(default_factory=list) -async def get_stream(request: web.Request): +async def get_stream(request: 'web.Request'): stream_dict, debug_mode = request.app['streams'], request.app['debug'] raw_body = await request.json() body = StreamRequestBody(**raw_body) @@ -206,7 +208,7 @@ async def get_stream(request: web.Request): return web.json_response({"sdp": answer.sdp, "type": answer.type}) -async def get_schema(request: web.Request): +async def get_schema(request: 'web.Request'): services = request.query["services"].split(",") services = [s for s in services if s] assert all(s in log.Event.schema.fields and not s.endswith("DEPRECATED") for s in services), "Invalid service name" @@ -214,7 +216,7 @@ async def get_schema(request: web.Request): return web.json_response(schema_dict) -async def on_shutdown(app: web.Application): +async def on_shutdown(app: 'web.Application'): for session in app['streams'].values(): session.stop() del app['streams'] diff --git a/tools/bodyteleop/web.py b/tools/bodyteleop/web.py index 53077af67e..b1fb9525db 100644 --- a/tools/bodyteleop/web.py +++ b/tools/bodyteleop/web.py @@ -6,9 +6,10 @@ import os import ssl import subprocess -from aiohttp import web, ClientSession import pyaudio import wave +from aiohttp import web +from aiohttp import ClientSession from openpilot.common.basedir import BASEDIR from openpilot.system.webrtc.webrtcd import StreamRequestBody @@ -22,7 +23,7 @@ WEBRTCD_HOST, WEBRTCD_PORT = "localhost", 5001 ## UTILS -async def play_sound(sound): +async def play_sound(sound: str): SOUNDS = { "engage": "selfdrive/assets/sounds/engage.wav", "disengage": "selfdrive/assets/sounds/disengage.wav", @@ -51,7 +52,7 @@ async def play_sound(sound): p.terminate() ## SSL -def create_ssl_cert(cert_path, key_path): +def create_ssl_cert(cert_path: str, key_path: str): try: proc = subprocess.run(f'openssl req -x509 -newkey rsa:4096 -nodes -out {cert_path} -keyout {key_path} \ -days 365 -subj "/C=US/ST=California/O=commaai/OU=comma body"', @@ -75,17 +76,17 @@ def create_ssl_context(): return ssl_context ## ENDPOINTS -async def index(request): +async def index(request: 'web.Request'): with open(os.path.join(TELEOPDIR, "static", "index.html"), "r") as f: content = f.read() return web.Response(content_type="text/html", text=content) -async def ping(request): +async def ping(request: 'web.Request'): return web.Response(text="pong") -async def sound(request): +async def sound(request: 'web.Request'): params = await request.json() sound_to_play = params["sound"] @@ -93,7 +94,7 @@ async def sound(request): return web.json_response({"status": "ok"}) -async def offer(request): +async def offer(request: 'web.Request'): params = await request.json() body = StreamRequestBody(params["sdp"], ["driver"], ["testJoystick"], ["carState"]) body_json = json.dumps(dataclasses.asdict(body))