openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

278 lines
9.4 KiB

#!/usr/bin/env python3
import argparse
import asyncio
import json
import uuid
import logging
from dataclasses import dataclass, field
from typing import Any, TYPE_CHECKING
# aiortc and its dependencies have lots of internal warnings :(
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
import capnp
from aiohttp import web
if TYPE_CHECKING:
from aiortc.rtcdatachannel import RTCDataChannel
from openpilot.system.webrtc.schema import generate_field
from cereal import messaging, log
class CerealOutgoingMessageProxy:
    """Pushes freshly updated SubMaster services to data channels as JSON."""

    def __init__(self, sm: messaging.SubMaster):
        self.sm = sm
        # Every registered channel receives every outgoing message.
        self.channels: list['RTCDataChannel'] = []

    def add_channel(self, channel: 'RTCDataChannel'):
        """Register another data channel as a recipient."""
        self.channels.append(channel)

    def to_json(self, msg_content: Any):
        """Turn capnp readers and bytes into plain Python values.

        Struct readers are converted with ``to_dict()``, list readers are
        converted element by element, bytes are decoded to ``str`` and
        anything else is returned untouched.
        """
        if isinstance(msg_content, capnp._DynamicStructReader):
            return msg_content.to_dict()
        if isinstance(msg_content, capnp._DynamicListReader):
            return [self.to_json(item) for item in msg_content]
        if isinstance(msg_content, bytes):
            return msg_content.decode()
        return msg_content

    def update(self):
        """Poll the SubMaster once and send each updated service to all channels."""
        # this is blocking in async context...
        self.sm.update(0)
        for name, was_updated in self.sm.updated.items():
            if was_updated:
                data = self.to_json(self.sm[name])
                payload = {
                    "type": name,
                    "logMonoTime": self.sm.logMonoTime[name],
                    "valid": self.sm.valid[name],
                    "data": data,
                }
                wire_bytes = json.dumps(payload).encode()
                for recipient in self.channels:
                    recipient.send(wire_bytes)
class CerealIncomingMessageProxy:
    """Publishes JSON messages received from a data channel through a PubMaster."""

    def __init__(self, pm: messaging.PubMaster):
        self.pm = pm

    def send(self, message: bytes):
        """Decode one JSON message and publish it under its ``type``.

        The JSON object must carry ``"type"`` (the service name) and
        ``"data"`` (the value assigned to that field of the new message).
        """
        decoded = json.loads(message)
        msg_type = decoded["type"]
        msg_data = decoded["data"]
        # Non-dict payloads pass their length as ``size``; dict payloads pass None.
        size = None if isinstance(msg_data, dict) else len(msg_data)
        outgoing = messaging.new_message(msg_type, size=size)
        setattr(outgoing, msg_type, msg_data)
        self.pm.send(msg_type, outgoing)
class CerealProxyRunner:
    """Drives a CerealOutgoingMessageProxy's ``update()`` in an asyncio task."""

    def __init__(self, proxy: 'CerealOutgoingMessageProxy'):
        self.proxy = proxy
        # Not updated anywhere in this class; kept so existing readers of the attribute still work.
        self.is_running = False
        self.task = None
        self.logger = logging.getLogger("webrtcd")

    def start(self):
        """Schedule ``run()`` on the running event loop. Must not already be started."""
        assert self.task is None
        self.task = asyncio.create_task(self.run())

    def stop(self):
        """Cancel the polling task if it is still running and forget it.

        The task reference is always cleared, including when the task has
        already finished on its own (``run()`` exits after InvalidStateError).
        Otherwise a finished runner could never be started again, because
        ``start()`` requires ``self.task`` to be None.
        """
        task, self.task = self.task, None
        if task is not None and not task.done():
            task.cancel()

    async def run(self):
        """Call ``proxy.update()`` in a loop, sleeping 10 ms between iterations.

        The loop ends when ``update()`` raises aiortc's InvalidStateError
        (logged as a closed connection). Any other exception is logged and the
        loop keeps going.
        """
        from aiortc.exceptions import InvalidStateError

        while True:
            try:
                self.proxy.update()
            except InvalidStateError:
                self.logger.warning("Cereal outgoing proxy invalid state (connection closed)")
                break
            except Exception:
                self.logger.exception("Cereal outgoing proxy failure")

            await asyncio.sleep(0.01)
class DynamicPubMaster(messaging.PubMaster):
    """PubMaster that can gain publishing sockets after construction."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Held while sockets are being added.
        self.lock = asyncio.Lock()

    async def add_services_if_needed(self, services):
        """Create a pub socket for every service in ``services`` that has none yet."""
        async with self.lock:
            for name in services:
                if name in self.sock:
                    continue
                self.sock[name] = messaging.pub_sock(name)
class StreamSession:
    """A single WebRTC session: media tracks plus optional cereal bridges.

    ``__init__`` builds the stream from the client's SDP offer, ``start()``
    schedules ``run()``, and ``run()`` waits for the connection, wires up the
    data channel and incoming audio, waits for disconnection and cleans up.
    """

    # One publisher shared by all sessions; run() adds sockets to it on demand.
    shared_pub_master = DynamicPubMaster([])

    def __init__(self, sdp: str, cameras: list[str], incoming_services: list[str], outgoing_services: list[str], debug_mode: bool = False):
        """Build the stream for an SDP offer.

        Args:
            sdp: the client's SDP offer.
            cameras: camera names, one video track is added per entry; the
                count must match what the offer expects.
            incoming_services: services the client may publish to this device.
            outgoing_services: services forwarded from this device to the client.
            debug_mode: use aiortc's generic tracks / MediaBlackhole instead of
                the device-backed video and audio classes.
        """
        # Imported here rather than at module level.
        from aiortc.mediastreams import VideoStreamTrack, AudioStreamTrack
        from aiortc.contrib.media import MediaBlackhole
        from openpilot.system.webrtc.device.video import LiveStreamVideoStreamTrack
        from openpilot.system.webrtc.device.audio import AudioInputStreamTrack, AudioOutputSpeaker
        from teleoprtc import WebRTCAnswerBuilder
        from teleoprtc.info import parse_info_from_offer

        config = parse_info_from_offer(sdp)
        builder = WebRTCAnswerBuilder(sdp)

        assert len(cameras) == config.n_expected_camera_tracks, "Incoming stream has misconfigured number of video tracks"
        for cam in cameras:
            track = LiveStreamVideoStreamTrack(cam) if not debug_mode else VideoStreamTrack()
            builder.add_video_stream(cam, track)
        if config.expected_audio_track:
            track = AudioInputStreamTrack() if not debug_mode else AudioStreamTrack()
            builder.add_audio_stream(track)
        if config.incoming_audio_track:
            self.audio_output_cls = AudioOutputSpeaker if not debug_mode else MediaBlackhole
            builder.offer_to_receive_audio_stream()

        self.stream = builder.stream()
        self.identifier = str(uuid.uuid4())

        self.incoming_bridge: CerealIncomingMessageProxy | None = None
        self.incoming_bridge_services = incoming_services
        self.outgoing_bridge: CerealOutgoingMessageProxy | None = None
        self.outgoing_bridge_runner: CerealProxyRunner | None = None
        if len(incoming_services) > 0:
            self.incoming_bridge = CerealIncomingMessageProxy(self.shared_pub_master)
        if len(outgoing_services) > 0:
            self.outgoing_bridge = CerealOutgoingMessageProxy(messaging.SubMaster(outgoing_services))
            self.outgoing_bridge_runner = CerealProxyRunner(self.outgoing_bridge)

        self.audio_output: AudioOutputSpeaker | MediaBlackhole | None = None
        self.run_task: asyncio.Task | None = None
        # Holds the cleanup scheduled by stop() so the task is not garbage collected mid-flight.
        self.cleanup_task: asyncio.Task | None = None
        self.logger = logging.getLogger("webrtcd")
        self.logger.info("New stream session (%s), cameras %s, audio in %s out %s, incoming services %s, outgoing services %s",
                         self.identifier, cameras, config.incoming_audio_track, config.expected_audio_track, incoming_services, outgoing_services)

    def start(self):
        """Schedule ``run()`` on the running event loop."""
        self.run_task = asyncio.create_task(self.run())

    def stop(self):
        """Cancel the session task and release the session's resources.

        Safe to call when the session was never started, already finished, or
        was already stopped.

        Returns:
            The asyncio task performing cleanup when called from inside a
            running event loop (callers may await it), otherwise None.
        """
        task = self.run_task
        if task is None or task.done():
            return None
        task.cancel()
        self.run_task = None

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread, so cleanup can be driven synchronously.
            asyncio.run(self.post_run_cleanup())
            return None

        # asyncio.run() raises RuntimeError inside a running loop (e.g. when called
        # from an async shutdown handler), so schedule the cleanup on that loop instead.
        self.cleanup_task = loop.create_task(self.post_run_cleanup())
        return self.cleanup_task

    async def get_answer(self):
        """Start the stream and return its result (the SDP answer sent back to the client)."""
        return await self.stream.start()

    async def message_handler(self, message: bytes):
        """Forward one data-channel message to the incoming bridge, logging any failure."""
        assert self.incoming_bridge is not None
        try:
            self.incoming_bridge.send(message)
        except Exception:
            self.logger.exception("Cereal incoming proxy failure")

    async def run(self):
        """Serve the session from connection to disconnection, then clean up.

        Cleanup runs whether or not the serving phase failed, so a session that
        errors out while connecting or wiring up does not leak its stream. If
        the task is cancelled, cleanup is left to ``stop()``.
        """
        try:
            await self.stream.wait_for_connection()
            if self.stream.has_messaging_channel():
                if self.incoming_bridge is not None:
                    await self.shared_pub_master.add_services_if_needed(self.incoming_bridge_services)
                    self.stream.set_message_handler(self.message_handler)
                if self.outgoing_bridge_runner is not None:
                    channel = self.stream.get_messaging_channel()
                    self.outgoing_bridge_runner.proxy.add_channel(channel)
                    self.outgoing_bridge_runner.start()
            if self.stream.has_incoming_audio_track():
                track = self.stream.get_incoming_audio_track(buffered=False)
                self.audio_output = self.audio_output_cls()
                self.audio_output.addTrack(track)
                # MediaBlackhole (debug mode) exposes start() as a coroutine that must be awaited.
                started = self.audio_output.start()
                if asyncio.iscoroutine(started):
                    await started
            self.logger.info("Stream session (%s) connected", self.identifier)

            await self.stream.wait_for_disconnection()
        except Exception:
            self.logger.exception("Stream session failure")

        try:
            await self.post_run_cleanup()
            self.logger.info("Stream session (%s) ended", self.identifier)
        except Exception:
            self.logger.exception("Stream session failure")

    async def post_run_cleanup(self):
        """Stop the stream, the outgoing bridge runner and the audio output, if present."""
        await self.stream.stop()
        if self.outgoing_bridge_runner is not None:
            self.outgoing_bridge_runner.stop()
        if self.audio_output:
            # MediaBlackhole (debug mode) exposes stop() as a coroutine that must be awaited.
            stopped = self.audio_output.stop()
            if asyncio.iscoroutine(stopped):
                await stopped
@dataclass
class StreamRequestBody:
    """Fields of the JSON body that get_stream unpacks for a new session."""

    # SDP offer passed to StreamSession.
    sdp: str
    # Camera names passed to StreamSession.
    cameras: list[str]
    # Optional service lists; each instance gets its own fresh empty list by default.
    bridge_services_in: list[str] = field(default_factory=list)
    bridge_services_out: list[str] = field(default_factory=list)
async def get_stream(request: 'web.Request'):
    """Create a stream session from the request's JSON body and return its answer.

    The body is unpacked into StreamRequestBody. The new session is started
    and stored in ``request.app['streams']`` under its identifier.

    Returns:
        JSON response with the answer's ``sdp`` and ``type``.

    Raises:
        web.HTTPBadRequest: the body is not valid JSON, is not a JSON object,
            or has missing/unknown fields.
    """
    stream_dict, debug_mode = request.app['streams'], request.app['debug']
    try:
        raw_body = await request.json()
        # TypeError covers a non-object body as well as missing or unexpected keys.
        body = StreamRequestBody(**raw_body)
    except (ValueError, TypeError) as err:
        # Client error: answer 400 instead of letting it surface as a 500.
        raise web.HTTPBadRequest(text="Invalid stream request body") from err

    session = StreamSession(body.sdp, body.cameras, body.bridge_services_in, body.bridge_services_out, debug_mode)
    answer = await session.get_answer()
    session.start()

    stream_dict[session.identifier] = session

    return web.json_response({"sdp": answer.sdp, "type": answer.type})
async def get_schema(request: 'web.Request'):
    """Return the generated schema for each service named in ``?services=a,b,...``.

    Empty entries in the comma-separated list are ignored.

    Returns:
        JSON response mapping each service name to ``generate_field`` of its
        ``log.Event`` schema field.

    Raises:
        web.HTTPBadRequest: the ``services`` parameter is missing, or a name
            is not a ``log.Event`` field or ends with ``DEPRECATED``.
    """
    raw_services = request.query.get("services")
    if raw_services is None:
        raise web.HTTPBadRequest(text="Missing 'services' query parameter")
    services = [s for s in raw_services.split(",") if s]

    # Validate with an explicit check: an assert would be stripped under -O
    # and would otherwise answer client mistakes with a 500.
    event_fields = log.Event.schema.fields
    if any(s not in event_fields or s.endswith("DEPRECATED") for s in services):
        raise web.HTTPBadRequest(text="Invalid service name")

    schema_dict = {s: generate_field(event_fields[s]) for s in services}
    return web.json_response(schema_dict)
async def on_shutdown(app: 'web.Application'):
    """Stop every stream session stored in ``app['streams']`` and drop the mapping.

    Each session is stopped independently: a failure is logged and the
    remaining sessions are still stopped, and ``app['streams']`` is removed
    regardless.
    """
    logger = logging.getLogger("webrtcd")
    for session in app['streams'].values():
        try:
            pending = session.stop()
            # If stop() hands back a future or coroutine for still-pending
            # cleanup, wait for it so the work finishes before the loop shuts down.
            if asyncio.isfuture(pending) or asyncio.iscoroutine(pending):
                await pending
        except Exception:
            logger.exception("Failed to stop stream session")
    del app['streams']
def webrtcd_thread(host: str, port: int, debug: bool):
    """Configure logging, build the aiohttp application and serve it.

    Args:
        host: address to listen on.
        port: port to listen on.
        debug: enables DEBUG logging for the daemon's loggers and is stored in
            ``app['debug']`` for the request handlers.
    """
    # Root logger only lets CRITICAL through; the daemon's own loggers are raised below.
    logging.basicConfig(level=logging.CRITICAL, handlers=[logging.StreamHandler()])
    logging_level = logging.DEBUG if debug else logging.INFO
    for logger_name in ("WebRTCStream", "webrtcd"):
        logging.getLogger(logger_name).setLevel(logging_level)

    app = web.Application()

    # Shared state read by the handlers.
    app['streams'] = {}
    app['debug'] = debug

    app.on_shutdown.append(on_shutdown)

    router = app.router
    router.add_post("/stream", get_stream)
    router.add_get("/schema", get_schema)

    web.run_app(app, host=host, port=port)
def main():
    """Parse the command line and run the WebRTC daemon."""
    cli = argparse.ArgumentParser(description="WebRTC daemon")
    cli.add_argument("--host", default="0.0.0.0", type=str, help="Host to listen on")
    cli.add_argument("--port", default=5001, type=int, help="Port to listen on")
    cli.add_argument("--debug", action="store_true", help="Enable debug mode")
    options = cli.parse_args()

    webrtcd_thread(host=options.host, port=options.port, debug=options.debug)
# Run the daemon only when this file is executed as a script.
if __name__ == "__main__":
    main()