From 6048eb568aa57153affc69e7c74c1975a694cc8b Mon Sep 17 00:00:00 2001 From: "Quantizr (Jimmy)" <9859727+Quantizr@users.noreply.github.com> Date: Sat, 23 Aug 2025 03:15:15 -0700 Subject: [PATCH] simplified data.py --- tools/jotpluggler/data.py | 167 ++++++++++++++--------------------- tools/jotpluggler/pluggle.py | 16 ++-- tools/jotpluggler/views.py | 25 +++--- 3 files changed, 85 insertions(+), 123 deletions(-) diff --git a/tools/jotpluggler/data.py b/tools/jotpluggler/data.py index a9eca9766a..2b31df1811 100644 --- a/tools/jotpluggler/data.py +++ b/tools/jotpluggler/data.py @@ -1,107 +1,34 @@ import threading import numpy as np -from abc import ABC, abstractmethod -from typing import Any +from collections.abc import Callable from openpilot.common.swaglog import cloudlog from openpilot.tools.lib.logreader import LogReader from openpilot.tools.lib.log_time_series import msgs_to_time_series -# TODO: support cereal/ZMQ streaming -class DataSource(ABC): - @abstractmethod - def load_data(self) -> dict[str, Any]: - pass - - @abstractmethod - def get_duration(self) -> float: - pass - - -class LogReaderSource(DataSource): - def __init__(self, route_name: str): - self.route_name = route_name - self._duration = 0.0 - self._start_time_mono = 0.0 - - def load_data(self) -> dict[str, Any]: - lr = LogReader(self.route_name) - raw_time_series = msgs_to_time_series(lr) - processed_data = self._expand_list_fields(raw_time_series) - - min_time = float('inf') - max_time = float('-inf') - for data in processed_data.values(): - min_time = min(min_time, data['t'][0]) - max_time = max(max_time, data['t'][-1]) - self._start_time_mono = min_time - self._duration = max_time - min_time - - return {'time_series_data': processed_data, 'route_start_time_mono': self._start_time_mono, 'duration': self._duration} - - def get_duration(self) -> float: - return self._duration - - # TODO: lists are expanded, but lists of structs are not - def _expand_list_fields(self, time_series_data): - expanded_data = {} - for msg_type, data in time_series_data.items(): - expanded_data[msg_type] = {} - for field, values in data.items(): - if field == 't': - expanded_data[msg_type]['t'] = values - continue - - if values.dtype == object: # ragged array - lens = np.fromiter((len(v) for v in values), dtype=int, count=len(values)) - max_len = lens.max() if lens.size else 0 - if max_len > 0: - arr = np.full((len(values), max_len), None, dtype=object) - for i, v in enumerate(values): - arr[i, : lens[i]] = v - for i in range(max_len): - sub_arr = arr[:, i] - expanded_data[msg_type][f"{field}/{i}"] = sub_arr - elif values.ndim > 1: # regular multidimensional array - for i in range(values.shape[1]): - col_data = values[:, i] - expanded_data[msg_type][f"{field}/{i}"] = col_data - else: - expanded_data[msg_type][field] = values - return expanded_data - - -class DataLoadedEvent: - def __init__(self, data: dict[str, Any]): - self.data = data - - -class Observer(ABC): - @abstractmethod - def on_data_loaded(self, event: DataLoadedEvent): - pass - - class DataManager: def __init__(self): self.time_series_data = {} self.loading = False self.route_start_time_mono = 0.0 self.duration = 0.0 - self._observers: list[Observer] = [] + self._callbacks: list[Callable[[dict], None]] = [] - def add_observer(self, observer: Observer): - self._observers.append(observer) + def add_callback(self, callback: Callable[[dict], None]): + self._callbacks.append(callback) - def remove_observer(self, observer: Observer): - if observer in self._observers: - self._observers.remove(observer) + def remove_callback(self, callback: Callable[[dict], None]): + if callback in self._callbacks: + self._callbacks.remove(callback) - def _notify_observers(self, event: DataLoadedEvent): - for observer in self._observers: - observer.on_data_loaded(event) + def _notify_callbacks(self, data: dict): + for callback in self._callbacks: + try: + callback(data) + except Exception as e: + cloudlog.exception(f"Error in data callback: {e}") - def get_current_value_for_path(self, path: str, time_s: float, last_index: int | None = None): + def get_current_value(self, path: str, time_s: float, last_index: int | None = None): try: abs_time_s = self.route_start_time_mono + time_s msg_type, field_path = path.split('/', 1) @@ -136,24 +63,24 @@ class DataManager: try: msg_type, field_path = path.split('/', 1) value_array = self.time_series_data.get(msg_type, {}).get(field_path) - if value_array is not None: # only numbers and bools are plottable + if value_array is not None: return np.issubdtype(value_array.dtype, np.number) or np.issubdtype(value_array.dtype, np.bool_) except (ValueError, KeyError): pass return False - def get_time_series_data(self, path: str) -> tuple | None: + def get_time_series(self, path: str): try: msg_type, field_path = path.split('/', 1) ts_data = self.time_series_data[msg_type] time_array = ts_data['t'] - plot_values = ts_data[field_path] + values = ts_data[field_path] if len(time_array) == 0: return None - rel_time_array = time_array - self.route_start_time_mono - return rel_time_array, plot_values + rel_time = time_array - self.route_start_time_mono + return rel_time, values except (KeyError, ValueError): return None @@ -163,19 +90,55 @@ class DataManager: return self.loading = True - data_source = LogReaderSource(route_name) - threading.Thread(target=self._load_in_background, args=(data_source,), daemon=True).start() + threading.Thread(target=self._load_route_background, args=(route_name,), daemon=True).start() - def _load_in_background(self, data_source: DataSource): + def _load_route_background(self, route_name: str): try: - data = data_source.load_data() - self.time_series_data = data['time_series_data'] - self.route_start_time_mono = data['route_start_time_mono'] - self.duration = data['duration'] + lr = LogReader(route_name) + raw_data = msgs_to_time_series(lr) + processed_data = self._expand_list_fields(raw_data) - self._notify_observers(DataLoadedEvent(data)) + min_time = float('inf') + max_time = float('-inf') + for data in processed_data.values(): + if len(data['t']) > 0: + min_time = min(min_time, data['t'][0]) + max_time = max(max_time, data['t'][-1]) - except Exception: - cloudlog.exception("Error loading route:") + self.time_series_data = processed_data + self.route_start_time_mono = min_time if min_time != float('inf') else 0.0 + self.duration = max_time - min_time if max_time != float('-inf') else 0.0 + + self._notify_callbacks({'time_series_data': processed_data, 'route_start_time_mono': self.route_start_time_mono, 'duration': self.duration}) + + except Exception as e: + cloudlog.exception(f"Error loading route {route_name}: {e}") finally: self.loading = False + + def _expand_list_fields(self, time_series_data): + expanded_data = {} + for msg_type, data in time_series_data.items(): + expanded_data[msg_type] = {} + for field, values in data.items(): + if field == 't': + expanded_data[msg_type]['t'] = values + continue + + if values.dtype == object: # ragged array + lens = np.fromiter((len(v) for v in values), dtype=int, count=len(values)) + max_len = lens.max() if lens.size else 0 + if max_len > 0: + arr = np.full((len(values), max_len), None, dtype=object) + for i, v in enumerate(values): + arr[i, : lens[i]] = v + for i in range(max_len): + sub_arr = arr[:, i] + expanded_data[msg_type][f"{field}/{i}"] = sub_arr + elif values.ndim > 1: # regular multidimensional array + for i in range(values.shape[1]): + col_data = values[:, i] + expanded_data[msg_type][f"{field}/{i}"] = col_data + else: + expanded_data[msg_type][field] = values + return expanded_data diff --git a/tools/jotpluggler/pluggle.py b/tools/jotpluggler/pluggle.py index c1feb6db0a..e2c9df159b 100755 --- a/tools/jotpluggler/pluggle.py +++ b/tools/jotpluggler/pluggle.py @@ -6,7 +6,7 @@ import subprocess import dearpygui.dearpygui as dpg import threading from openpilot.common.basedir import BASEDIR -from openpilot.tools.jotpluggler.data import DataManager, Observer, DataLoadedEvent +from openpilot.tools.jotpluggler.data import DataManager from openpilot.tools.jotpluggler.views import DataTreeView from openpilot.tools.jotpluggler.layout import PlotLayoutManager, SplitterNode, LeafNode @@ -61,7 +61,7 @@ def format_and_truncate(value, available_width: float, avg_char_width: float) -> return s -class MainController(Observer): +class MainController: def __init__(self, scale: float = 1.0): self.ui_lock = threading.Lock() self.scale = scale @@ -70,7 +70,7 @@ class MainController(Observer): self._create_global_themes() self.data_tree_view = DataTreeView(self.data_manager, self.ui_lock) self.plot_layout_manager = PlotLayoutManager(self.data_manager, self.playback_manager, scale=self.scale) - self.data_manager.add_observer(self) + self.data_manager.add_callback(self.on_data_loaded) self.avg_char_width = None def _create_global_themes(self): @@ -85,12 +85,12 @@ class MainController(Observer): dpg.add_theme_style(dpg.mvPlotStyleVar_LineWeight, scaled_thickness, category=dpg.mvThemeCat_Plots) dpg.add_theme_color(dpg.mvPlotCol_Line, (255, 0, 0, 128), category=dpg.mvThemeCat_Plots) - def on_data_loaded(self, event: DataLoadedEvent): - self.playback_manager.set_route_duration(event.data['duration']) - num_msg_types = len(event.data['time_series_data']) + def on_data_loaded(self, data: dict): + self.playback_manager.set_route_duration(data['duration']) + num_msg_types = len(data['time_series_data']) dpg.set_value("load_status", f"Loaded {num_msg_types} message types") dpg.configure_item("load_button", enabled=True) - dpg.configure_item("timeline_slider", max_value=event.data['duration']) + dpg.configure_item("timeline_slider", max_value=data['duration']) def setup_ui(self): with dpg.item_handler_registry(tag="tree_node_handler"): @@ -180,7 +180,7 @@ class MainController(Observer): if dpg.does_item_exist(value_tag) and dpg.is_item_visible(value_tag): last_index = self.playback_manager.last_indices.get(path) - value, new_idx = self.data_manager.get_current_value_for_path(path, self.playback_manager.current_time_s, last_index) + value, new_idx = self.data_manager.get_current_value(path, self.playback_manager.current_time_s, last_index) if value is not None: self.playback_manager.update_index(path, new_idx) diff --git a/tools/jotpluggler/views.py b/tools/jotpluggler/views.py index 0e86a18c44..feeff7ab24 100644 --- a/tools/jotpluggler/views.py +++ b/tools/jotpluggler/views.py @@ -4,7 +4,7 @@ import uuid import threading import dearpygui.dearpygui as dpg from abc import ABC, abstractmethod -from openpilot.tools.jotpluggler.data import Observer, DataLoadedEvent, DataManager +from openpilot.tools.jotpluggler.data import DataManager class ViewPanel(ABC): @@ -31,7 +31,7 @@ class ViewPanel(ABC): pass -class TimeSeriesPanel(ViewPanel, Observer): +class TimeSeriesPanel(ViewPanel): def __init__(self, data_manager: DataManager, playback_manager, panel_id: str | None = None): super().__init__(panel_id) self.data_manager = data_manager @@ -45,14 +45,13 @@ class TimeSeriesPanel(ViewPanel, Observer): self._ui_created = False self._preserved_series_data: list[tuple[str, tuple]] = [] # TODO: the way we do this right now doesn't make much sense self._series_legend_tags: dict[str, str] = {} # Maps series_path to legend tag - - self.data_manager.add_observer(self) + self.data_manager.add_callback(self.on_data_loaded) def preserve_data(self): self._preserved_series_data = [] if self.plotted_series and self._ui_created: for series_path in self.plotted_series: - time_value_data = self.data_manager.get_time_series_data(series_path) + time_value_data = self.data_manager.get_time_series(series_path) if time_value_data: self._preserved_series_data.append((series_path, time_value_data)) @@ -88,7 +87,7 @@ class TimeSeriesPanel(ViewPanel, Observer): if self.plotted_series: # update legend labels with current values for series_path in self.plotted_series: last_index = self.playback_manager.last_indices.get(series_path) - value, new_idx = self.data_manager.get_current_value_for_path(series_path, current_time_s, last_index) + value, new_idx = self.data_manager.get_current_value(series_path, current_time_s, last_index) if value is not None: self.playback_manager.update_index(series_path, new_idx) @@ -126,7 +125,7 @@ class TimeSeriesPanel(ViewPanel, Observer): if self.plot_tag and dpg.does_item_exist(self.plot_tag): dpg.delete_item(self.plot_tag) - # self.data_manager.remove_observer(self) + # self.data_manager.remove_callback(self.on_data_loaded) self._series_legend_tags.clear() self._ui_created = False @@ -137,7 +136,7 @@ class TimeSeriesPanel(ViewPanel, Observer): if series_path in self.plotted_series: return False - time_value_data = self.data_manager.get_time_series_data(series_path) + time_value_data = self.data_manager.get_time_series(series_path) if time_value_data is None: return False @@ -157,12 +156,12 @@ class TimeSeriesPanel(ViewPanel, Observer): if series_path in self._series_legend_tags: # Clean up legend tag mapping del self._series_legend_tags[series_path] - def on_data_loaded(self, event: DataLoadedEvent): + def on_data_loaded(self, data: dict): for series_path in self.plotted_series.copy(): self._update_series_data(series_path) def _update_series_data(self, series_path: str) -> bool: - time_value_data = self.data_manager.get_time_series_data(series_path) + time_value_data = self.data_manager.get_time_series(series_path) if time_value_data is None: return False @@ -190,16 +189,16 @@ class DataTreeNode: self.children: dict[str, DataTreeNode] = {} self.is_leaf = False -class DataTreeView(Observer): +class DataTreeView: def __init__(self, data_manager: DataManager, ui_lock: threading.Lock): self.data_manager = data_manager self.ui_lock = ui_lock self.current_search = "" self.data_tree = DataTreeNode(name="root") self.active_leaf_nodes: list[DataTreeNode] = [] - self.data_manager.add_observer(self) + self.data_manager.add_callback(self.on_data_loaded) - def on_data_loaded(self, event: DataLoadedEvent): + def on_data_loaded(self, data: dict): with self.ui_lock: self.populate_data_tree()