webrtcd: lazy import of aiortc (#31304)

* Lazy imports in webrtcd

* Lazy imports in web.py

* Type hints

* Remove FrameReaderVideoStreamTrack

* leave the aiohttp.web import

* Leave the client session

* main leftover
pull/31314/head
Kacper Rączy 1 year ago committed by GitHub
parent b17f24d68e
commit 35d848ad52
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 25
      system/webrtc/device/video.py
  2. 34
      system/webrtc/webrtcd.py
  3. 15
      tools/bodyteleop/web.py

@ -5,7 +5,6 @@ import av
from teleoprtc.tracks import TiciVideoStreamTrack from teleoprtc.tracks import TiciVideoStreamTrack
from cereal import messaging from cereal import messaging
from openpilot.tools.lib.framereader import FrameReader
from openpilot.common.realtime import DT_MDL, DT_DMON from openpilot.common.realtime import DT_MDL, DT_DMON
@ -43,27 +42,3 @@ class LiveStreamVideoStreamTrack(TiciVideoStreamTrack):
def codec_preference(self) -> Optional[str]: def codec_preference(self) -> Optional[str]:
return "H264" return "H264"
class FrameReaderVideoStreamTrack(TiciVideoStreamTrack):
def __init__(self, input_file: str, dt: float = DT_MDL, camera_type: str = "driver"):
super().__init__(camera_type, dt)
frame_reader = FrameReader(input_file)
self._frames = [frame_reader.get(i, pix_fmt="rgb24") for i in range(frame_reader.frame_count)]
self._frame_count = len(self.frames)
self._frame_index = 0
self._pts = 0
async def recv(self):
self.log_debug("track sending frame %s", self._pts)
img = self._frames[self._frame_index]
new_frame = av.VideoFrame.from_ndarray(img, format="rgb24")
new_frame.pts = self._pts
new_frame.time_base = self._time_base
self._frame_index = (self._frame_index + 1) % self._frame_count
self._pts = await self.next_pts(self._pts)
return new_frame

@ -6,34 +6,27 @@ import json
import uuid import uuid
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, List, Optional, Union from typing import Any, List, Optional, Union, TYPE_CHECKING
# aiortc and its dependencies have lots of internal warnings :( # aiortc and its dependencies have lots of internal warnings :(
import warnings import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=DeprecationWarning)
import aiortc
from aiortc.mediastreams import VideoStreamTrack, AudioStreamTrack
from aiortc.contrib.media import MediaBlackhole
from aiortc.exceptions import InvalidStateError
from aiohttp import web
import capnp import capnp
from teleoprtc import WebRTCAnswerBuilder from aiohttp import web
from teleoprtc.info import parse_info_from_offer if TYPE_CHECKING:
from aiortc.rtcdatachannel import RTCDataChannel
from openpilot.system.webrtc.device.video import LiveStreamVideoStreamTrack
from openpilot.system.webrtc.device.audio import AudioInputStreamTrack, AudioOutputSpeaker
from openpilot.system.webrtc.schema import generate_field from openpilot.system.webrtc.schema import generate_field
from cereal import messaging, log from cereal import messaging, log
class CerealOutgoingMessageProxy: class CerealOutgoingMessageProxy:
def __init__(self, sm: messaging.SubMaster): def __init__(self, sm: messaging.SubMaster):
self.sm = sm self.sm = sm
self.channels: List[aiortc.RTCDataChannel] = [] self.channels: List['RTCDataChannel'] = []
def add_channel(self, channel: aiortc.RTCDataChannel): def add_channel(self, channel: 'RTCDataChannel'):
self.channels.append(channel) self.channels.append(channel)
def to_json(self, msg_content: Any): def to_json(self, msg_content: Any):
@ -96,6 +89,8 @@ class CerealProxyRunner:
self.task = None self.task = None
async def run(self): async def run(self):
from aiortc.exceptions import InvalidStateError
while True: while True:
try: try:
self.proxy.update() self.proxy.update()
@ -109,6 +104,13 @@ class CerealProxyRunner:
class StreamSession: class StreamSession:
def __init__(self, sdp: str, cameras: List[str], incoming_services: List[str], outgoing_services: List[str], debug_mode: bool = False): def __init__(self, sdp: str, cameras: List[str], incoming_services: List[str], outgoing_services: List[str], debug_mode: bool = False):
from aiortc.mediastreams import VideoStreamTrack, AudioStreamTrack
from aiortc.contrib.media import MediaBlackhole
from openpilot.system.webrtc.device.video import LiveStreamVideoStreamTrack
from openpilot.system.webrtc.device.audio import AudioInputStreamTrack, AudioOutputSpeaker
from teleoprtc import WebRTCAnswerBuilder
from teleoprtc.info import parse_info_from_offer
config = parse_info_from_offer(sdp) config = parse_info_from_offer(sdp)
builder = WebRTCAnswerBuilder(sdp) builder = WebRTCAnswerBuilder(sdp)
@ -192,7 +194,7 @@ class StreamRequestBody:
bridge_services_out: List[str] = field(default_factory=list) bridge_services_out: List[str] = field(default_factory=list)
async def get_stream(request: web.Request): async def get_stream(request: 'web.Request'):
stream_dict, debug_mode = request.app['streams'], request.app['debug'] stream_dict, debug_mode = request.app['streams'], request.app['debug']
raw_body = await request.json() raw_body = await request.json()
body = StreamRequestBody(**raw_body) body = StreamRequestBody(**raw_body)
@ -206,7 +208,7 @@ async def get_stream(request: web.Request):
return web.json_response({"sdp": answer.sdp, "type": answer.type}) return web.json_response({"sdp": answer.sdp, "type": answer.type})
async def get_schema(request: web.Request): async def get_schema(request: 'web.Request'):
services = request.query["services"].split(",") services = request.query["services"].split(",")
services = [s for s in services if s] services = [s for s in services if s]
assert all(s in log.Event.schema.fields and not s.endswith("DEPRECATED") for s in services), "Invalid service name" assert all(s in log.Event.schema.fields and not s.endswith("DEPRECATED") for s in services), "Invalid service name"
@ -214,7 +216,7 @@ async def get_schema(request: web.Request):
return web.json_response(schema_dict) return web.json_response(schema_dict)
async def on_shutdown(app: web.Application): async def on_shutdown(app: 'web.Application'):
for session in app['streams'].values(): for session in app['streams'].values():
session.stop() session.stop()
del app['streams'] del app['streams']

@ -6,9 +6,10 @@ import os
import ssl import ssl
import subprocess import subprocess
from aiohttp import web, ClientSession
import pyaudio import pyaudio
import wave import wave
from aiohttp import web
from aiohttp import ClientSession
from openpilot.common.basedir import BASEDIR from openpilot.common.basedir import BASEDIR
from openpilot.system.webrtc.webrtcd import StreamRequestBody from openpilot.system.webrtc.webrtcd import StreamRequestBody
@ -22,7 +23,7 @@ WEBRTCD_HOST, WEBRTCD_PORT = "localhost", 5001
## UTILS ## UTILS
async def play_sound(sound): async def play_sound(sound: str):
SOUNDS = { SOUNDS = {
"engage": "selfdrive/assets/sounds/engage.wav", "engage": "selfdrive/assets/sounds/engage.wav",
"disengage": "selfdrive/assets/sounds/disengage.wav", "disengage": "selfdrive/assets/sounds/disengage.wav",
@ -51,7 +52,7 @@ async def play_sound(sound):
p.terminate() p.terminate()
## SSL ## SSL
def create_ssl_cert(cert_path, key_path): def create_ssl_cert(cert_path: str, key_path: str):
try: try:
proc = subprocess.run(f'openssl req -x509 -newkey rsa:4096 -nodes -out {cert_path} -keyout {key_path} \ proc = subprocess.run(f'openssl req -x509 -newkey rsa:4096 -nodes -out {cert_path} -keyout {key_path} \
-days 365 -subj "/C=US/ST=California/O=commaai/OU=comma body"', -days 365 -subj "/C=US/ST=California/O=commaai/OU=comma body"',
@ -75,17 +76,17 @@ def create_ssl_context():
return ssl_context return ssl_context
## ENDPOINTS ## ENDPOINTS
async def index(request): async def index(request: 'web.Request'):
with open(os.path.join(TELEOPDIR, "static", "index.html"), "r") as f: with open(os.path.join(TELEOPDIR, "static", "index.html"), "r") as f:
content = f.read() content = f.read()
return web.Response(content_type="text/html", text=content) return web.Response(content_type="text/html", text=content)
async def ping(request): async def ping(request: 'web.Request'):
return web.Response(text="pong") return web.Response(text="pong")
async def sound(request): async def sound(request: 'web.Request'):
params = await request.json() params = await request.json()
sound_to_play = params["sound"] sound_to_play = params["sound"]
@ -93,7 +94,7 @@ async def sound(request):
return web.json_response({"status": "ok"}) return web.json_response({"status": "ok"})
async def offer(request): async def offer(request: 'web.Request'):
params = await request.json() params = await request.json()
body = StreamRequestBody(params["sdp"], ["driver"], ["testJoystick"], ["carState"]) body = StreamRequestBody(params["sdp"], ["driver"], ["testJoystick"], ["carState"])
body_json = json.dumps(dataclasses.asdict(body)) body_json = json.dumps(dataclasses.asdict(body))

Loading…
Cancel
Save