webrtcd: webrtc streaming server (audio/video/cereal) (#30186)
* WebRTCClient and WebRTCServer abstractions * webrtc client implementation * Interactive test scripts * Send localDescriptions as offer/asnwer, as they are different * Tracks need to be added after setting remote description for multi-cam streaming to work * Remove WebRTCStreamingMetadata * Wait for tracks * Move stuff to separate files, rename some things * Refactor everything, create WebRTCStreamBuilder for both offer and answers * ta flight done time to grind * wait for incoming tracks and channels * Dummy track and frame reader track. Fix timing. * dt based on camera type * first trial of the new api * Fix audio track * methods for checking for incoming tracks * Web migration part 2 * Fixes for stream api * use rtc description for web.py * experimental cereal proxy * remove old code from bodyav * fix is_started * serialize session description * fix audio * messaging channel wrapper * fix audiotrack * h264 codec preference * Add codec preference to tracks * override sdp codecs * add logging * Move cli stuff to separate file * slight cleanup * Fix audio track * create codec_mime inside force_codec function * fix incoming media estimation * move builders to __init__ * stream updates following builders * Update example script * web.py support for new builder * web speaker fixes * StreamingMediaInfo API * Move things around * should_add_data_channel rename * is_connected_and_ready * fix linter errors * make cli executable * remove dumb comments * logging support * fix parse_info_from_offer * improve type annotations * satisfy linters * Support for waiting for disconnection * Split device tracks into video/audio files. 
Move audio speaker to audio.py * default dt for dummy video track * Fix cli * new speaker fixes * Remove almost all functionality from web.py * webrtcd * continue refactoring web.py * after handling joystick reset in controlsd with #30409, controls are not necessary anymore * ping endpoint * Update js files to at least support what worked previously * Fixes after some tests on the body * Streaming fixes * Remove the use of WebRTCStreamBuilder. Subclass use is now required * Add todo * delete all streams on shutdown * Replace lastPing with lastChannelMessageTime * Update ping text only if rtc is still on * That should affect the chart too * Fix paths in web * use protocol in SSLContext * remove warnings since aiortc is not used directly anymore * check if task is done in stop * remove channel handler wrapper, since theres only one channel * Move things around * Moved webrtc abstractions to separate repository * Moved webrtcd to tools/webrtc * Update imports * Add bodyrtc as dependency * Add webrtcd to process_config * Remove usage of DummyVideoStreamTrack * Add main to webrtcd * Move webrtcd to system * Fix imports * Move cereal proxy logic outside of runner * Incoming proxy abstractions * Add some tests * Make it executable * Fix process config * Fix imports * Additional tests. Add tests to pyproject.toml * Update poetry lock * New line * Bump aiortc to 1.6.0 * Added teleoprtc_repo as submodule, and linked its source dir * Add init file to webrtc module * Handle aiortc warnings * Ignore deprecation warnings * Ignore resource warning too * Ignore the warnings * find free port for test_webrtcd * Start process inside the test case * random sleep test * test 2 * Test endpoint function instead * Update comment * Add system/webrtc to release * default arguments for body fields * Add teleoprtc to release * Bump teleoprtc * Exclude teleoprtc from static analysis * Use separate event loop for stream session testspull/30581/head
parent
e34ee43eea
commit
f058b5d64e
18 changed files with 788 additions and 474 deletions
@ -0,0 +1,110 @@ |
||||
import asyncio |
||||
import io |
||||
from typing import Optional, List, Tuple |
||||
|
||||
import aiortc |
||||
import av |
||||
import numpy as np |
||||
import pyaudio |
||||
|
||||
|
||||
class AudioInputStreamTrack(aiortc.mediastreams.AudioStreamTrack):
  """Audio track that captures microphone audio via PyAudio and yields av.AudioFrames.

  Frames are `packet_time` seconds long and timestamped with a running
  sample-count pts.
  """

  # pyaudio sample format -> libav sample format name
  PYAUDIO_TO_AV_FORMAT_MAP = {
    pyaudio.paUInt8: 'u8',
    pyaudio.paInt16: 's16',
    pyaudio.paInt24: 's24',
    pyaudio.paInt32: 's32',
    pyaudio.paFloat32: 'flt',
  }
  # pyaudio sample format -> numpy dtype used to parse the raw capture buffer.
  # paInt24 is intentionally absent: numpy has no packed 24-bit dtype, so that
  # format cannot be parsed here.
  PYAUDIO_TO_NP_DTYPE_MAP = {
    pyaudio.paUInt8: np.uint8,
    pyaudio.paInt16: np.int16,
    pyaudio.paInt32: np.int32,
    pyaudio.paFloat32: np.float32,
  }

  def __init__(self, audio_format: int = pyaudio.paInt16, rate: int = 16000, channels: int = 1, packet_time: float = 0.020, device_index: Optional[int] = None):
    """Open a PyAudio input stream.

    Args:
      audio_format: pyaudio sample format constant (see maps above).
      rate: sample rate in Hz.
      channels: number of capture channels.
      packet_time: duration of each emitted frame, in seconds.
      device_index: optional PyAudio input device index (None = default).
    """
    super().__init__()

    self.p = pyaudio.PyAudio()
    chunk_size = int(packet_time * rate)  # samples per frame
    self.stream = self.p.open(format=audio_format,
                              channels=channels,
                              rate=rate,
                              frames_per_buffer=chunk_size,
                              input=True,
                              input_device_index=device_index)
    self.format = audio_format
    self.rate = rate
    self.channels = channels
    self.packet_time = packet_time
    self.chunk_size = chunk_size
    self.pts = 0  # running presentation timestamp, in samples

  async def recv(self):
    # NOTE: stream.read blocks the event loop for up to packet_time; acceptable
    # for these small chunks, but an executor would be cleaner.
    mic_data = self.stream.read(self.chunk_size)
    # Bug fix: the dtype must follow the configured audio format rather than
    # being hard-coded to int16, otherwise non-s16 captures are misparsed.
    mic_array = np.frombuffer(mic_data, dtype=self.PYAUDIO_TO_NP_DTYPE_MAP[self.format])
    # libav packed-audio layout expects shape (1, samples * channels)
    mic_array = np.expand_dims(mic_array, axis=0)
    layout = 'stereo' if self.channels > 1 else 'mono'
    frame = av.AudioFrame.from_ndarray(mic_array, format=self.PYAUDIO_TO_AV_FORMAT_MAP[self.format], layout=layout)
    frame.rate = self.rate
    frame.pts = self.pts
    self.pts += frame.samples

    return frame
||||
|
||||
|
||||
class AudioOutputSpeaker:
  """Plays incoming aiortc audio tracks through a PyAudio output stream.

  Consumer tasks append raw PCM from each track into an in-memory byte
  buffer; the PyAudio callback drains that buffer on the device's schedule.
  """

  def __init__(self, audio_format: int = pyaudio.paInt16, rate: int = 48000, channels: int = 2, packet_time: float = 0.2, device_index: Optional[int] = None):

    chunk_size = int(packet_time * rate)  # samples per device callback
    self.p = pyaudio.PyAudio()
    self.buffer = io.BytesIO()
    self.channels = channels
    self.stream = self.p.open(format=audio_format,
                              channels=channels,
                              rate=rate,
                              frames_per_buffer=chunk_size,
                              output=True,
                              output_device_index=device_index,
                              stream_callback=self.__pyaudio_callback)
    # each registered track paired with the asyncio.Task consuming it
    # (None until start() is called)
    self.tracks_and_tasks: List[Tuple[aiortc.MediaStreamTrack, Optional[asyncio.Task]]] = []

  def __pyaudio_callback(self, in_data, frame_count, time_info, status):
    # PyAudio pull callback: must always return frame_count frames of data.
    # Sample math below assumes 16-bit samples (2 bytes each).
    if self.buffer.getbuffer().nbytes < frame_count * self.channels * 2:
      # underrun: emit silence
      buff = b'\x00\x00' * frame_count * self.channels
    elif self.buffer.getbuffer().nbytes > 115200:  # 3x the usual read size
      # overrun: read ahead (2x) and truncate to skip stale audio
      self.buffer.seek(0)
      buff = self.buffer.read(frame_count * self.channels * 4)
      buff = buff[:frame_count * self.channels * 2]
      self.buffer.seek(2)
    else:
      self.buffer.seek(0)
      buff = self.buffer.read(frame_count * self.channels * 2)
      # NOTE(review): seek(2) repositions the write cursor near the start
      # instead of truncating consumed bytes — presumably deliberate buffer
      # recycling, but confirm it cannot replay stale samples.
      self.buffer.seek(2)
    return (buff, pyaudio.paContinue)

  async def __consume(self, track):
    # Pull frames off the track until it ends, appending raw plane bytes.
    while True:
      try:
        frame = await track.recv()
      except aiortc.MediaStreamError:
        return

      self.buffer.write(bytes(frame.planes[0]))

  def hasTrack(self, track: aiortc.MediaStreamTrack) -> bool:
    """Return True if `track` was already added."""
    return any(t == track for t, _ in self.tracks_and_tasks)

  def addTrack(self, track: aiortc.MediaStreamTrack):
    """Register a track for playback; consumption begins on start()."""
    if not self.hasTrack(track):
      self.tracks_and_tasks.append((track, None))

  def start(self):
    """Spawn a consumer task for every track that does not have one yet."""
    for index, (track, task) in enumerate(self.tracks_and_tasks):
      if task is None:
        self.tracks_and_tasks[index] = (track, asyncio.create_task(self.__consume(track)))

  def stop(self):
    """Cancel all consumers and release the PyAudio stream and context."""
    for _, task in self.tracks_and_tasks:
      if task is not None:
        task.cancel()

    self.tracks_and_tasks = []
    self.stream.stop_stream()
    self.stream.close()
    self.p.terminate()
@ -0,0 +1,69 @@ |
||||
import asyncio |
||||
from typing import Optional |
||||
|
||||
import av |
||||
from teleoprtc.tracks import TiciVideoStreamTrack |
||||
|
||||
from cereal import messaging |
||||
from openpilot.tools.lib.framereader import FrameReader |
||||
from openpilot.common.realtime import DT_MDL, DT_DMON |
||||
|
||||
|
||||
class LiveStreamVideoStreamTrack(TiciVideoStreamTrack):
  """Video track backed by the device's livestream encoder sockets.

  Forwards already-encoded bitstream packets (no re-encoding) with a
  locally generated pts.
  """

  # camera type -> cereal livestream encode service name
  camera_to_sock_mapping = {
    "driver": "livestreamDriverEncodeData",
    "wideRoad": "livestreamWideRoadEncodeData",
    "road": "livestreamRoadEncodeData",
  }

  def __init__(self, camera_type: str):
    # driver cam runs at the driver-monitoring rate, the others at model rate
    dt = DT_DMON if camera_type == "driver" else DT_MDL
    super().__init__(camera_type, dt)

    self._sock = messaging.sub_sock(self.camera_to_sock_mapping[camera_type], conflate=True)
    self._pts = 0

  async def recv(self):
    # poll the conflated socket without blocking the event loop
    while True:
      msg = messaging.recv_one_or_none(self._sock)
      if msg is not None:
        break
      await asyncio.sleep(0.005)

    evta = getattr(msg, msg.which())

    # wrap the pre-encoded header + payload in an av.Packet as-is
    packet = av.Packet(evta.header + evta.data)
    packet.time_base = self._time_base
    packet.pts = self._pts

    self.log_debug("track sending frame %s", self._pts)
    # NOTE(review): pts accumulates as a float (dt * clock_rate); av coerces it
    # on assignment — confirm rounding drift is acceptable for long sessions.
    self._pts += self._dt * self._clock_rate

    return packet

  def codec_preference(self) -> Optional[str]:
    """Prefer H264, since the source packets are H264-encoded."""
    return "H264"
||||
|
||||
|
||||
class FrameReaderVideoStreamTrack(TiciVideoStreamTrack):
  """Video track that replays frames from a recorded log file in a loop.

  All frames are decoded up front into RGB24 arrays; recv() cycles through
  them at the configured dt.
  """

  def __init__(self, input_file: str, dt: float = DT_MDL, camera_type: str = "driver"):
    super().__init__(camera_type, dt)

    frame_reader = FrameReader(input_file)
    self._frames = [frame_reader.get(i, pix_fmt="rgb24") for i in range(frame_reader.frame_count)]
    # Bug fix: was `len(self.frames)` — the attribute is `self._frames`, so the
    # original raised AttributeError during construction.
    self._frame_count = len(self._frames)
    self._frame_index = 0
    self._pts = 0

  async def recv(self):
    self.log_debug("track sending frame %s", self._pts)
    img = self._frames[self._frame_index]

    new_frame = av.VideoFrame.from_ndarray(img, format="rgb24")
    new_frame.pts = self._pts
    new_frame.time_base = self._time_base

    # wrap around to loop the recording indefinitely
    self._frame_index = (self._frame_index + 1) % self._frame_count
    # next_pts also paces frame delivery to real time
    self._pts = await self.next_pts(self._pts)

    return new_frame
@ -0,0 +1,108 @@ |
||||
#!/usr/bin/env python3 |
||||
import asyncio |
||||
import unittest |
||||
from unittest.mock import Mock, MagicMock, patch |
||||
import json |
||||
# for aiortc and its dependencies |
||||
import warnings |
||||
warnings.filterwarnings("ignore", category=DeprecationWarning) |
||||
|
||||
from aiortc import RTCDataChannel |
||||
from aiortc.mediastreams import VIDEO_CLOCK_RATE, VIDEO_TIME_BASE |
||||
import capnp |
||||
import pyaudio |
||||
|
||||
from cereal import messaging, log |
||||
|
||||
from openpilot.system.webrtc.webrtcd import CerealOutgoingMessageProxy, CerealIncomingMessageProxy |
||||
from openpilot.system.webrtc.device.video import LiveStreamVideoStreamTrack |
||||
from openpilot.system.webrtc.device.audio import AudioInputStreamTrack |
||||
from openpilot.common.realtime import DT_DMON |
||||
|
||||
|
||||
class TestStreamSession(unittest.TestCase):
  """Unit tests for the cereal<->WebRTC proxies and the device media tracks."""

  def setUp(self):
    # each test drives its coroutines on a private event loop
    self.loop = asyncio.new_event_loop()

  def tearDown(self):
    self.loop.stop()
    self.loop.close()

  def test_outgoing_proxy(self):
    """An updated service should be serialized to JSON and sent on the channel."""
    test_msg = log.Event.new_message()
    test_msg.logMonoTime = 123
    test_msg.valid = True
    test_msg.customReservedRawData0 = b"test"
    expected_dict = {"type": "customReservedRawData0", "logMonoTime": 123, "valid": True, "data": "test"}
    expected_json = json.dumps(expected_dict).encode()

    channel = Mock(spec=RTCDataChannel)
    mocked_submaster = messaging.SubMaster(["customReservedRawData0"])
    # patched SubMaster.update injects our message instead of polling sockets
    def mocked_update(t):
      mocked_submaster.update_msgs(0, [test_msg])

    with patch.object(messaging.SubMaster, "update", side_effect=mocked_update):
      proxy = CerealOutgoingMessageProxy(mocked_submaster)
      proxy.add_channel(channel)

      proxy.update()

      channel.send.assert_called_once_with(expected_json)

  def test_incoming_proxy(self):
    """JSON payloads of each shape (primitive/list/dict) should publish a capnp message."""
    tested_msgs = [
      {"type": "customReservedRawData0", "data": "test"},  # primitive
      {"type": "can", "data": [{"address": 0, "busTime": 0, "dat": "", "src": 0}]},  # list
      {"type": "testJoystick", "data": {"axes": [0, 0], "buttons": [False]}},  # dict
    ]

    mocked_pubmaster = MagicMock(spec=messaging.PubMaster)

    proxy = CerealIncomingMessageProxy(mocked_pubmaster)

    for msg in tested_msgs:
      proxy.send(json.dumps(msg).encode())

      mocked_pubmaster.send.assert_called_once()
      mt, md = mocked_pubmaster.send.call_args.args
      self.assertEqual(mt, msg["type"])
      self.assertIsInstance(md, capnp._DynamicStructBuilder)
      self.assertTrue(hasattr(md, msg["type"]))

      mocked_pubmaster.reset_mock()

  def test_livestream_track(self):
    """The livestream track should produce H264-preferenced packets with dt-spaced pts."""
    fake_msg = messaging.new_message("livestreamDriverEncodeData")

    config = {"receive.return_value": fake_msg.to_bytes()}
    with patch("cereal.messaging.SubSocket", spec=True, **config):
      track = LiveStreamVideoStreamTrack("driver")

      self.assertTrue(track.id.startswith("driver"))
      self.assertEqual(track.codec_preference(), "H264")

      for i in range(5):
        packet = self.loop.run_until_complete(track.recv())
        self.assertEqual(packet.time_base, VIDEO_TIME_BASE)
        self.assertEqual(packet.pts, int(i * DT_DMON * VIDEO_CLOCK_RATE))
        # empty fake message -> zero-length payload
        self.assertEqual(packet.size, 0)

  def test_input_audio_track(self):
    """The mic track should emit frames with the configured rate and sample-count pts."""
    packet_time, rate = 0.02, 16000
    sample_count = int(packet_time * rate)
    mocked_stream = MagicMock(spec=pyaudio.Stream)
    # 16-bit silence: 2 bytes per sample
    mocked_stream.read.return_value = b"\x00" * 2 * sample_count

    config = {"open.side_effect": lambda *args, **kwargs: mocked_stream}
    with patch("pyaudio.PyAudio", spec=True, **config):
      track = AudioInputStreamTrack(audio_format=pyaudio.paInt16, packet_time=packet_time, rate=rate)

      for i in range(5):
        frame = self.loop.run_until_complete(track.recv())
        self.assertEqual(frame.rate, rate)
        self.assertEqual(frame.samples, sample_count)
        self.assertEqual(frame.pts, i * sample_count)


if __name__ == "__main__":
  unittest.main()
@ -0,0 +1,60 @@ |
||||
#!/usr/bin/env python |
||||
import asyncio |
||||
import json |
||||
import unittest |
||||
from unittest.mock import MagicMock, AsyncMock |
||||
# for aiortc and its dependencies |
||||
import warnings |
||||
warnings.filterwarnings("ignore", category=DeprecationWarning) |
||||
|
||||
from openpilot.system.webrtc.webrtcd import get_stream |
||||
|
||||
import aiortc |
||||
from teleoprtc import WebRTCOfferBuilder |
||||
|
||||
|
||||
class TestWebrtcdProc(unittest.IsolatedAsyncioTestCase):
  """End-to-end test: drive webrtcd's get_stream handler with a real WebRTC offer."""

  async def assertCompletesWithTimeout(self, awaitable, timeout=1):
    """Fail the test if `awaitable` does not finish within `timeout` seconds."""
    try:
      async with asyncio.timeout(timeout):
        await awaitable
    except asyncio.TimeoutError:
      self.fail("Timeout while waiting for awaitable to complete")

  async def test_webrtcd(self):
    mock_request = MagicMock()
    # emulate the HTTP POST /stream round trip without starting a server
    async def connect(offer):
      body = {'sdp': offer.sdp, 'cameras': offer.video, 'bridge_services_in': [], 'bridge_services_out': []}
      mock_request.json.side_effect = AsyncMock(return_value=body)
      response = await get_stream(mock_request)
      response_json = json.loads(response.text)
      return aiortc.RTCSessionDescription(**response_json)

    builder = WebRTCOfferBuilder(connect)
    builder.offer_to_receive_video_stream("road")
    builder.offer_to_receive_audio_stream()
    builder.add_messaging()

    stream = builder.stream()

    await self.assertCompletesWithTimeout(stream.start())
    await self.assertCompletesWithTimeout(stream.wait_for_connection())

    self.assertTrue(stream.has_incoming_video_track("road"))
    self.assertTrue(stream.has_incoming_audio_track())
    self.assertTrue(stream.has_messaging_channel())

    video_track, audio_track = stream.get_incoming_video_track("road"), stream.get_incoming_audio_track()
    await self.assertCompletesWithTimeout(video_track.recv())
    await self.assertCompletesWithTimeout(audio_track.recv())

    await self.assertCompletesWithTimeout(stream.stop())

    # cleanup, very implementation specific, test may break if it changes
    self.assertTrue(mock_request.app["streams"].__setitem__.called, "Implementation changed, please update this test")
    _, session = mock_request.app["streams"].__setitem__.call_args.args
    await self.assertCompletesWithTimeout(session.post_run_cleanup())


if __name__ == "__main__":
  unittest.main()
@ -0,0 +1,237 @@ |
||||
#!/usr/bin/env python3 |
||||
|
||||
import argparse |
||||
import asyncio |
||||
import json |
||||
import uuid |
||||
import logging |
||||
from dataclasses import dataclass, field |
||||
from typing import Any, List, Optional, Union |
||||
|
||||
# aiortc and its dependencies have lots of internal warnings :( |
||||
import warnings |
||||
warnings.filterwarnings("ignore", category=DeprecationWarning) |
||||
|
||||
import aiortc |
||||
from aiortc.mediastreams import VideoStreamTrack, AudioStreamTrack |
||||
from aiortc.contrib.media import MediaBlackhole |
||||
from aiohttp import web |
||||
import capnp |
||||
from teleoprtc import WebRTCAnswerBuilder |
||||
from teleoprtc.info import parse_info_from_offer |
||||
|
||||
from openpilot.system.webrtc.device.video import LiveStreamVideoStreamTrack |
||||
from openpilot.system.webrtc.device.audio import AudioInputStreamTrack, AudioOutputSpeaker |
||||
|
||||
from cereal import messaging |
||||
|
||||
|
||||
class CerealOutgoingMessageProxy:
  """Forwards freshly updated cereal (SubMaster) services to RTC data channels as JSON."""

  def __init__(self, sm: messaging.SubMaster):
    self.sm = sm
    self.channels: List[aiortc.RTCDataChannel] = []

  def add_channel(self, channel: aiortc.RTCDataChannel):
    """Register another data channel to broadcast on."""
    self.channels.append(channel)

  def to_json(self, msg_content: Any):
    """Recursively convert a capnp reader (or primitive) into JSON-serializable data."""
    if isinstance(msg_content, capnp._DynamicStructReader):
      return msg_content.to_dict()
    if isinstance(msg_content, capnp._DynamicListReader):
      return [self.to_json(item) for item in msg_content]
    if isinstance(msg_content, bytes):
      return msg_content.decode()
    return msg_content

  def update(self):
    """Poll SubMaster once and send every updated service to all channels."""
    # this is blocking in async context...
    self.sm.update(0)
    for service, updated in self.sm.updated.items():
      if not updated:
        continue
      outgoing_msg = {
        "type": service,
        "logMonoTime": self.sm.logMonoTime[service],
        "valid": self.sm.valid[service],
        "data": self.to_json(self.sm[service]),
      }
      encoded_msg = json.dumps(outgoing_msg).encode()
      for channel in self.channels:
        channel.send(encoded_msg)
||||
|
||||
|
||||
class CerealIncomingMessageProxy:
  """Publishes JSON messages arriving over the data channel onto cereal (PubMaster)."""

  def __init__(self, pm: messaging.PubMaster):
    self.pm = pm

  def send(self, message: bytes):
    """Decode a {"type": ..., "data": ...} JSON payload and publish it."""
    parsed = json.loads(message)
    msg_type, msg_data = parsed["type"], parsed["data"]
    # list/primitive payloads need an explicit capnp allocation size; dicts do not
    size = None if isinstance(msg_data, dict) else len(msg_data)

    msg = messaging.new_message(msg_type, size=size)
    setattr(msg, msg_type, msg_data)
    self.pm.send(msg_type, msg)
||||
|
||||
|
||||
class CerealProxyRunner:
  """Drives a CerealOutgoingMessageProxy periodically inside an asyncio task."""

  def __init__(self, proxy: CerealOutgoingMessageProxy):
    self.proxy = proxy
    self.is_running = False
    self.task = None
    self.logger = logging.getLogger("webrtcd")

  def start(self):
    """Begin polling; may only be called while no task exists."""
    assert self.task is None
    self.task = asyncio.create_task(self.run())

  def stop(self):
    """Cancel the polling task if it is still live."""
    if self.task is None or self.task.done():
      return
    self.task.cancel()
    self.task = None

  async def run(self):
    # poll at ~100Hz; errors are logged and the loop keeps going
    log_error = self.logger.error
    while True:
      try:
        self.proxy.update()
      except Exception as err:
        log_error("Cereal outgoing proxy failure: %s", err)
      await asyncio.sleep(0.01)
||||
|
||||
|
||||
class StreamSession:
  """One WebRTC streaming session: device media tracks plus a cereal messaging bridge.

  Built from an incoming SDP offer. In debug mode, synthetic aiortc tracks and
  a MediaBlackhole replace the real device tracks and speaker.
  """

  def __init__(self, sdp: str, cameras: List[str], incoming_services: List[str], outgoing_services: List[str], debug_mode: bool = False):
    config = parse_info_from_offer(sdp)
    builder = WebRTCAnswerBuilder(sdp)

    assert len(cameras) == config.n_expected_camera_tracks, "Incoming stream has misconfigured number of video tracks"
    for cam in cameras:
      track = LiveStreamVideoStreamTrack(cam) if not debug_mode else VideoStreamTrack()
      builder.add_video_stream(cam, track)
    if config.expected_audio_track:
      track = AudioInputStreamTrack() if not debug_mode else AudioStreamTrack()
      builder.add_audio_stream(track)
    if config.incoming_audio_track:
      # NOTE(review): audio_output_cls is only defined on this branch, yet run()
      # reads it whenever an incoming audio track appears — presumably the offer
      # always advertises the track in that case; confirm.
      self.audio_output_cls = AudioOutputSpeaker if not debug_mode else MediaBlackhole
      builder.offer_to_receive_audio_stream()

    self.stream = builder.stream()
    self.identifier = str(uuid.uuid4())

    # bridges between the RTC data channel (JSON) and cereal pub/sub
    self.outgoing_bridge = CerealOutgoingMessageProxy(messaging.SubMaster(outgoing_services))
    self.incoming_bridge = CerealIncomingMessageProxy(messaging.PubMaster(incoming_services))
    self.outgoing_bridge_runner = CerealProxyRunner(self.outgoing_bridge)

    self.audio_output: Optional[Union[AudioOutputSpeaker, MediaBlackhole]] = None
    self.run_task: Optional[asyncio.Task] = None
    self.logger = logging.getLogger("webrtcd")
    self.logger.info("New stream session (%s), cameras %s, audio in %s out %s, incoming services %s, outgoing services %s",
                     self.identifier, cameras, config.incoming_audio_track, config.expected_audio_track, incoming_services, outgoing_services)

  def start(self):
    # launch the session lifecycle (connection -> bridging -> disconnection)
    self.run_task = asyncio.create_task(self.run())

  def stop(self):
    if self.run_task.done():
      return
    self.run_task.cancel()
    self.run_task = None
    # NOTE(review): asyncio.run() raises RuntimeError when a loop is already
    # running — this is invoked from the async on_shutdown handler; confirm.
    asyncio.run(self.post_run_cleanup())

  async def get_answer(self):
    """Start the underlying stream and return the SDP answer for the client."""
    return await self.stream.start()

  async def message_handler(self, message: bytes):
    # data-channel messages get published to cereal; failures are logged, not fatal
    try:
      self.incoming_bridge.send(message)
    except Exception as ex:
      self.logger.error("Cereal incoming proxy failure: %s", ex)

  async def run(self):
    """Session lifecycle: wire up channels/tracks after connect, clean up after disconnect."""
    try:
      await self.stream.wait_for_connection()
      if self.stream.has_messaging_channel():
        self.stream.set_message_handler(self.message_handler)
        channel = self.stream.get_messaging_channel()
        self.outgoing_bridge_runner.proxy.add_channel(channel)
        self.outgoing_bridge_runner.start()
      if self.stream.has_incoming_audio_track():
        track = self.stream.get_incoming_audio_track(buffered=False)
        self.audio_output = self.audio_output_cls()
        self.audio_output.addTrack(track)
        self.audio_output.start()
      self.logger.info("Stream session (%s) connected", self.identifier)

      await self.stream.wait_for_disconnection()
      await self.post_run_cleanup()

      self.logger.info("Stream session (%s) ended", self.identifier)
    except Exception as ex:
      self.logger.error("Stream session failure: %s", ex)

  async def post_run_cleanup(self):
    """Tear down the stream, the outgoing bridge runner, and any speaker output."""
    await self.stream.stop()
    self.outgoing_bridge_runner.stop()
    if self.audio_output:
      self.audio_output.stop()
||||
|
||||
|
||||
@dataclass
class StreamRequestBody:
  """Parsed JSON body of a POST /stream request."""
  sdp: str  # the client's WebRTC offer SDP
  cameras: List[str]  # camera types to stream, keys of camera_to_sock_mapping
  bridge_services_in: List[str] = field(default_factory=list)  # cereal services the client publishes
  bridge_services_out: List[str] = field(default_factory=list)  # cereal services forwarded to the client
||||
|
||||
|
||||
async def get_stream(request: web.Request):
  """HTTP handler: accept a WebRTC offer, create a StreamSession, return the SDP answer."""
  stream_dict, debug_mode = request.app['streams'], request.app['debug']
  body = StreamRequestBody(**(await request.json()))

  session = StreamSession(body.sdp, body.cameras, body.bridge_services_in, body.bridge_services_out, debug_mode)
  answer = await session.get_answer()
  session.start()

  # keep the session alive (and stoppable on shutdown) by id
  stream_dict[session.identifier] = session

  return web.json_response({"sdp": answer.sdp, "type": answer.type})
||||
|
||||
|
||||
async def on_shutdown(app: web.Application):
  """aiohttp shutdown hook: stop every live session, then drop the registry."""
  sessions = app['streams']
  for session in sessions.values():
    session.stop()
  del app['streams']
||||
|
||||
|
||||
def webrtcd_thread(host: str, port: int, debug: bool):
  """Configure logging and run the aiohttp app serving POST /stream (blocks)."""
  # root stays CRITICAL; only our loggers are opened up
  logging.basicConfig(level=logging.CRITICAL, handlers=[logging.StreamHandler()])
  level = logging.DEBUG if debug else logging.INFO
  for logger_name in ("WebRTCStream", "webrtcd"):
    logging.getLogger(logger_name).setLevel(level)

  app = web.Application()

  app['streams'] = {}
  app['debug'] = debug
  app.on_shutdown.append(on_shutdown)
  app.router.add_post("/stream", get_stream)

  web.run_app(app, host=host, port=port)
||||
|
||||
|
||||
def main():
  """Parse CLI arguments and start the WebRTC daemon."""
  parser = argparse.ArgumentParser(description="WebRTC daemon")
  parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to listen on")
  parser.add_argument("--port", type=int, default=5001, help="Port to listen on")
  parser.add_argument("--debug", action="store_true", help="Enable debug mode")
  args = parser.parse_args()

  webrtcd_thread(args.host, args.port, args.debug)


if __name__ == "__main__":
  main()
@ -0,0 +1 @@ |
||||
Subproject commit 8ec477868591eed9a6136a44f16428bc0468b4e9 |
@ -1,159 +0,0 @@ |
||||
import asyncio |
||||
import io |
||||
import numpy as np |
||||
import pyaudio |
||||
import wave |
||||
|
||||
from aiortc.contrib.media import MediaBlackhole |
||||
from aiortc.mediastreams import AudioStreamTrack, MediaStreamError, MediaStreamTrack |
||||
from aiortc.mediastreams import VIDEO_CLOCK_RATE, VIDEO_TIME_BASE |
||||
from aiortc.rtcrtpsender import RTCRtpSender |
||||
from av import CodecContext, Packet |
||||
from pydub import AudioSegment |
||||
import cereal.messaging as messaging |
||||
|
||||
AUDIO_RATE = 16000 |
||||
SOUNDS = { |
||||
'engage': '../../selfdrive/assets/sounds/engage.wav', |
||||
'disengage': '../../selfdrive/assets/sounds/disengage.wav', |
||||
'error': '../../selfdrive/assets/sounds/warning_immediate.wav', |
||||
} |
||||
|
||||
|
||||
def force_codec(pc, sender, forced_codec='video/VP9', stream_type="video"):
  """Restrict the given sender's transceiver to a single codec by MIME type.

  Args:
    pc: RTCPeerConnection whose transceivers are searched.
    sender: the RTCRtpSender whose transceiver should be constrained.
    forced_codec: MIME type to keep, e.g. 'video/VP9'.
    stream_type: capability kind to query ("video" or "audio").
  """
  codecs = RTCRtpSender.getCapabilities(stream_type).codecs
  codec = [codec for codec in codecs if codec.mimeType == forced_codec]
  transceiver = next(t for t in pc.getTransceivers() if t.sender == sender)
  transceiver.setCodecPreferences(codec)
||||
|
||||
|
||||
class EncodedBodyVideo(MediaStreamTrack):
  """Video track forwarding pre-encoded driver-camera packets from cereal."""

  kind = "video"

  _start: float
  _timestamp: int

  def __init__(self):
    super().__init__()
    sock_name = 'livestreamDriverEncodeData'
    # fresh messaging context for this track's subscription
    messaging.context = messaging.Context()
    self.sock = messaging.sub_sock(sock_name, None, conflate=True)
    self.pts = 0

  async def recv(self) -> Packet:
    # poll (with a short async sleep) until an encoded frame arrives
    while True:
      msg = messaging.recv_one_or_none(self.sock)
      if msg is not None:
        break
      await asyncio.sleep(0.005)

    evta = getattr(msg, msg.which())
    self.last_idx = evta.idx.encodeId

    # wrap the already-encoded header + payload; no re-encoding
    packet = Packet(evta.header + evta.data)
    packet.time_base = VIDEO_TIME_BASE
    packet.pts = self.pts
    # assumes 0.05s (20fps) per frame — TODO confirm against encoder rate
    self.pts += 0.05 * VIDEO_CLOCK_RATE
    return packet
||||
|
||||
|
||||
class WebClientSpeaker(MediaBlackhole):
  """Plays received audio tracks through PyAudio by buffering raw PCM.

  Subclasses MediaBlackhole to reuse its track bookkeeping, replacing the
  discard sink with a real speaker output.
  """

  def __init__(self):
    super().__init__()
    self.p = pyaudio.PyAudio()
    self.buffer = io.BytesIO()
    self.channels = 2
    self.stream = self.p.open(format=pyaudio.paInt16, channels=self.channels, rate=48000, frames_per_buffer=9600,
                              output=True, stream_callback=self.pyaudio_callback)

  def pyaudio_callback(self, in_data, frame_count, time_info, status):
    # PyAudio pull callback: silence on underrun, read-ahead + truncate on overrun.
    # Sample math assumes 16-bit samples (2 bytes each).
    if self.buffer.getbuffer().nbytes < frame_count * self.channels * 2:
      buff = np.zeros((frame_count, 2), dtype=np.int16).tobytes()
    elif self.buffer.getbuffer().nbytes > 115200:  # 3x the usual read size
      self.buffer.seek(0)
      buff = self.buffer.read(frame_count * self.channels * 4)
      buff = buff[:frame_count * self.channels * 2]
      self.buffer.seek(2)
    else:
      self.buffer.seek(0)
      buff = self.buffer.read(frame_count * self.channels * 2)
      # NOTE(review): seek(2) rewinds the write cursor instead of truncating
      # consumed bytes — presumably intentional recycling; confirm.
      self.buffer.seek(2)
    return (buff, pyaudio.paContinue)

  async def consume(self, track):
    # append each received frame's raw plane bytes to the playback buffer
    while True:
      try:
        frame = await track.recv()
      except MediaStreamError:
        return
      bio = bytes(frame.planes[0])
      self.buffer.write(bio)

  async def start(self):
    # NOTE(review): reaches into MediaBlackhole's name-mangled __tracks dict;
    # brittle against aiortc internals — re-check on aiortc upgrades.
    for track, task in self._MediaBlackhole__tracks.items():
      if task is None:
        self._MediaBlackhole__tracks[track] = asyncio.ensure_future(self.consume(track))

  async def stop(self):
    # cancel consumers, then release the PyAudio stream and context
    for task in self._MediaBlackhole__tracks.values():
      if task is not None:
        task.cancel()
    self._MediaBlackhole__tracks = {}
    self.stream.stop_stream()
    self.stream.close()
    self.p.terminate()
||||
|
||||
|
||||
class BodyMic(AudioStreamTrack):
  """Microphone track: 16kHz mono capture, upmixed to stereo and boosted by 3dB."""

  def __init__(self):
    super().__init__()

    self.sample_rate = AUDIO_RATE
    self.AUDIO_PTIME = 0.020  # 20ms audio packetization
    self.samples = int(self.AUDIO_PTIME * self.sample_rate)
    self.FORMAT = pyaudio.paInt16
    self.CHANNELS = 2
    self.RATE = self.sample_rate
    self.CHUNK = int(AUDIO_RATE * 0.020)  # samples per 20ms read
    self.p = pyaudio.PyAudio()
    self.mic_stream = self.p.open(format=self.FORMAT, channels=1, rate=self.RATE, input=True, frames_per_buffer=self.CHUNK)

    # pcm_s16le decoder turns the raw stereo PCM back into av frames
    self.codec = CodecContext.create('pcm_s16le', 'r')
    self.codec.sample_rate = self.RATE
    self.codec.channels = 2
    self.audio_samples = 0
    self.chunk_number = 0

  async def recv(self):
    # NOTE: blocking read in async context; chunk is only 20ms of audio
    mic_data = self.mic_stream.read(self.CHUNK)
    mic_sound = AudioSegment(mic_data, sample_width=2, channels=1, frame_rate=self.RATE)
    # duplicate the mono channel to produce stereo
    mic_sound = AudioSegment.from_mono_audiosegments(mic_sound, mic_sound)
    mic_sound += 3  # increase volume by 3db
    packet = Packet(mic_sound.raw_data)
    frame = self.codec.decode(packet)[0]
    frame.pts = self.audio_samples  # pts in samples
    self.audio_samples += frame.samples
    self.chunk_number = self.chunk_number + 1
    return frame
||||
|
||||
|
||||
async def play_sound(sound):
  """Asynchronously play one of the predefined SOUNDS wav files on the default output.

  Args:
    sound: key into the SOUNDS mapping ('engage', 'disengage', 'error').
  """
  chunk = 5120
  with wave.open(SOUNDS[sound], 'rb') as wf:
    # PyAudio pulls wav data on demand through this callback
    def callback(in_data, frame_count, time_info, status):
      data = wf.readframes(frame_count)
      return data, pyaudio.paContinue

    p = pyaudio.PyAudio()
    stream = p.open(format=p.get_format_from_width(wf.getsampwidth()),
                    channels=wf.getnchannels(),
                    rate=wf.getframerate(),
                    output=True,
                    frames_per_buffer=chunk,
                    stream_callback=callback)
    stream.start_stream()
    # yield to the event loop while the device drains the file
    while stream.is_active():
      await asyncio.sleep(0)
    stream.stop_stream()
    stream.close()
    p.terminate()
Loading…
Reference in new issue