From a14ca845d6d6bf02e66493016a7a463913d77106 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20R=C4=85czy?= Date: Mon, 4 Dec 2023 20:43:19 -0800 Subject: [PATCH] webrtcd: endpoint for message schema retrieval (#30578) * Capnp json schema conversion * Schema get endpoint * Type annotation for generate_field * Filter empty services old-commit-hash: 10eb70daf707760f785490f69b2d17621698ed75 --- system/webrtc/schema.py | 43 ++++++++++++++++++++++++++++++++++++++++ system/webrtc/webrtcd.py | 12 ++++++++++- 2 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 system/webrtc/schema.py diff --git a/system/webrtc/schema.py b/system/webrtc/schema.py new file mode 100644 index 0000000000..f659b34293 --- /dev/null +++ b/system/webrtc/schema.py @@ -0,0 +1,43 @@ +import capnp +from typing import Union, List, Dict, Any + + +def generate_type(type_walker, schema_walker) -> Union[str, List[Any], Dict[str, Any]]: + data_type = next(type_walker) + if data_type.which() == 'struct': + return generate_struct(next(schema_walker)) + elif data_type.which() == 'list': + _ = next(schema_walker) + return [generate_type(type_walker, schema_walker)] + elif data_type.which() == 'enum': + return "text" + else: + return str(data_type.which()) + + +def generate_struct(schema: capnp.lib.capnp._StructSchema) -> Dict[str, Any]: + return {field: generate_field(schema.fields[field]) for field in schema.fields if not field.endswith("DEPRECATED")} + + +def generate_field(field: capnp.lib.capnp._StructSchemaField) -> Union[str, List[Any], Dict[str, Any]]: + def schema_walker(field): + yield field.schema + + s = field.schema + while hasattr(s, 'elementType'): + s = s.elementType + yield s + + def type_walker(field): + yield field.proto.slot.type + + t = field.proto.slot.type + while hasattr(getattr(t, t.which()), 'elementType'): + t = getattr(t, t.which()).elementType + yield t + + if field.proto.which() == "slot": + schema_gen, type_gen = schema_walker(field), type_walker(field) + return generate_type(type_gen, schema_gen) + else: + return generate_struct(field.schema) diff --git a/system/webrtc/webrtcd.py b/system/webrtc/webrtcd.py index 3f2ef2ceb4..12f9328532 100755 --- a/system/webrtc/webrtcd.py +++ b/system/webrtc/webrtcd.py @@ -23,8 +23,9 @@ from teleoprtc.info import parse_info_from_offer from openpilot.system.webrtc.device.video import LiveStreamVideoStreamTrack from openpilot.system.webrtc.device.audio import AudioInputStreamTrack, AudioOutputSpeaker +from openpilot.system.webrtc.schema import generate_field -from cereal import messaging +from cereal import messaging, log class CerealOutgoingMessageProxy: @@ -205,6 +206,14 @@ async def get_stream(request: web.Request): return web.json_response({"sdp": answer.sdp, "type": answer.type}) +async def get_schema(request: web.Request): + services = request.query["services"].split(",") + services = [s for s in services if s] + assert all(s in log.Event.schema.fields and not s.endswith("DEPRECATED") for s in services), "Invalid service name" + schema_dict = {s: generate_field(log.Event.schema.fields[s]) for s in services} + return web.json_response(schema_dict) + + async def on_shutdown(app: web.Application): for session in app['streams'].values(): session.stop() @@ -223,6 +232,7 @@ def webrtcd_thread(host: str, port: int, debug: bool): app['debug'] = debug app.on_shutdown.append(on_shutdown) app.router.add_post("/stream", get_stream) + app.router.add_get("/schema", get_schema) web.run_app(app, host=host, port=port)