simplified data.py

pull/36045/head
Quantizr (Jimmy) 1 month ago
parent 2001898d8b
commit 6048eb568a
  1. 167
      tools/jotpluggler/data.py
  2. 16
      tools/jotpluggler/pluggle.py
  3. 25
      tools/jotpluggler/views.py

@ -1,107 +1,34 @@
import threading
import numpy as np
from abc import ABC, abstractmethod
from typing import Any
from collections.abc import Callable
from openpilot.common.swaglog import cloudlog
from openpilot.tools.lib.logreader import LogReader
from openpilot.tools.lib.log_time_series import msgs_to_time_series
# TODO: support cereal/ZMQ streaming
class DataSource(ABC):
@abstractmethod
def load_data(self) -> dict[str, Any]:
pass
@abstractmethod
def get_duration(self) -> float:
pass
class LogReaderSource(DataSource):
def __init__(self, route_name: str):
self.route_name = route_name
self._duration = 0.0
self._start_time_mono = 0.0
def load_data(self) -> dict[str, Any]:
lr = LogReader(self.route_name)
raw_time_series = msgs_to_time_series(lr)
processed_data = self._expand_list_fields(raw_time_series)
min_time = float('inf')
max_time = float('-inf')
for data in processed_data.values():
min_time = min(min_time, data['t'][0])
max_time = max(max_time, data['t'][-1])
self._start_time_mono = min_time
self._duration = max_time - min_time
return {'time_series_data': processed_data, 'route_start_time_mono': self._start_time_mono, 'duration': self._duration}
def get_duration(self) -> float:
return self._duration
# TODO: lists are expanded, but lists of structs are not
def _expand_list_fields(self, time_series_data):
expanded_data = {}
for msg_type, data in time_series_data.items():
expanded_data[msg_type] = {}
for field, values in data.items():
if field == 't':
expanded_data[msg_type]['t'] = values
continue
if values.dtype == object: # ragged array
lens = np.fromiter((len(v) for v in values), dtype=int, count=len(values))
max_len = lens.max() if lens.size else 0
if max_len > 0:
arr = np.full((len(values), max_len), None, dtype=object)
for i, v in enumerate(values):
arr[i, : lens[i]] = v
for i in range(max_len):
sub_arr = arr[:, i]
expanded_data[msg_type][f"{field}/{i}"] = sub_arr
elif values.ndim > 1: # regular multidimensional array
for i in range(values.shape[1]):
col_data = values[:, i]
expanded_data[msg_type][f"{field}/{i}"] = col_data
else:
expanded_data[msg_type][field] = values
return expanded_data
class DataLoadedEvent:
def __init__(self, data: dict[str, Any]):
self.data = data
class Observer(ABC):
@abstractmethod
def on_data_loaded(self, event: DataLoadedEvent):
pass
class DataManager:
def __init__(self):
self.time_series_data = {}
self.loading = False
self.route_start_time_mono = 0.0
self.duration = 0.0
self._observers: list[Observer] = []
self._callbacks: list[Callable[[dict], None]] = []
def add_observer(self, observer: Observer):
self._observers.append(observer)
def add_callback(self, callback: Callable[[dict], None]):
self._callbacks.append(callback)
def remove_observer(self, observer: Observer):
if observer in self._observers:
self._observers.remove(observer)
def remove_callback(self, callback: Callable[[dict], None]):
if callback in self._callbacks:
self._callbacks.remove(callback)
def _notify_observers(self, event: DataLoadedEvent):
for observer in self._observers:
observer.on_data_loaded(event)
def _notify_callbacks(self, data: dict):
for callback in self._callbacks:
try:
callback(data)
except Exception as e:
cloudlog.exception(f"Error in data callback: {e}")
def get_current_value_for_path(self, path: str, time_s: float, last_index: int | None = None):
def get_current_value(self, path: str, time_s: float, last_index: int | None = None):
try:
abs_time_s = self.route_start_time_mono + time_s
msg_type, field_path = path.split('/', 1)
@ -136,24 +63,24 @@ class DataManager:
try:
msg_type, field_path = path.split('/', 1)
value_array = self.time_series_data.get(msg_type, {}).get(field_path)
if value_array is not None: # only numbers and bools are plottable
if value_array is not None:
return np.issubdtype(value_array.dtype, np.number) or np.issubdtype(value_array.dtype, np.bool_)
except (ValueError, KeyError):
pass
return False
def get_time_series_data(self, path: str) -> tuple | None:
def get_time_series(self, path: str):
try:
msg_type, field_path = path.split('/', 1)
ts_data = self.time_series_data[msg_type]
time_array = ts_data['t']
plot_values = ts_data[field_path]
values = ts_data[field_path]
if len(time_array) == 0:
return None
rel_time_array = time_array - self.route_start_time_mono
return rel_time_array, plot_values
rel_time = time_array - self.route_start_time_mono
return rel_time, values
except (KeyError, ValueError):
return None
@ -163,19 +90,55 @@ class DataManager:
return
self.loading = True
data_source = LogReaderSource(route_name)
threading.Thread(target=self._load_in_background, args=(data_source,), daemon=True).start()
threading.Thread(target=self._load_route_background, args=(route_name,), daemon=True).start()
def _load_in_background(self, data_source: DataSource):
def _load_route_background(self, route_name: str):
try:
data = data_source.load_data()
self.time_series_data = data['time_series_data']
self.route_start_time_mono = data['route_start_time_mono']
self.duration = data['duration']
lr = LogReader(route_name)
raw_data = msgs_to_time_series(lr)
processed_data = self._expand_list_fields(raw_data)
self._notify_observers(DataLoadedEvent(data))
min_time = float('inf')
max_time = float('-inf')
for data in processed_data.values():
if len(data['t']) > 0:
min_time = min(min_time, data['t'][0])
max_time = max(max_time, data['t'][-1])
self.time_series_data = processed_data
self.route_start_time_mono = min_time if min_time != float('inf') else 0.0
self.duration = max_time - min_time if max_time != float('-inf') else 0.0
self._notify_callbacks({'time_series_data': processed_data, 'route_start_time_mono': self.route_start_time_mono, 'duration': self.duration})
except Exception:
cloudlog.exception("Error loading route:")
except Exception as e:
cloudlog.exception(f"Error loading route {route_name}: {e}")
finally:
self.loading = False
def _expand_list_fields(self, time_series_data):
expanded_data = {}
for msg_type, data in time_series_data.items():
expanded_data[msg_type] = {}
for field, values in data.items():
if field == 't':
expanded_data[msg_type]['t'] = values
continue
if values.dtype == object: # ragged array
lens = np.fromiter((len(v) for v in values), dtype=int, count=len(values))
max_len = lens.max() if lens.size else 0
if max_len > 0:
arr = np.full((len(values), max_len), None, dtype=object)
for i, v in enumerate(values):
arr[i, : lens[i]] = v
for i in range(max_len):
sub_arr = arr[:, i]
expanded_data[msg_type][f"{field}/{i}"] = sub_arr
elif values.ndim > 1: # regular multidimensional array
for i in range(values.shape[1]):
col_data = values[:, i]
expanded_data[msg_type][f"{field}/{i}"] = col_data
else:
expanded_data[msg_type][field] = values
return expanded_data

@ -6,7 +6,7 @@ import subprocess
import dearpygui.dearpygui as dpg
import threading
from openpilot.common.basedir import BASEDIR
from openpilot.tools.jotpluggler.data import DataManager, Observer, DataLoadedEvent
from openpilot.tools.jotpluggler.data import DataManager
from openpilot.tools.jotpluggler.views import DataTreeView
from openpilot.tools.jotpluggler.layout import PlotLayoutManager, SplitterNode, LeafNode
@ -61,7 +61,7 @@ def format_and_truncate(value, available_width: float, avg_char_width: float) ->
return s
class MainController(Observer):
class MainController:
def __init__(self, scale: float = 1.0):
self.ui_lock = threading.Lock()
self.scale = scale
@ -70,7 +70,7 @@ class MainController(Observer):
self._create_global_themes()
self.data_tree_view = DataTreeView(self.data_manager, self.ui_lock)
self.plot_layout_manager = PlotLayoutManager(self.data_manager, self.playback_manager, scale=self.scale)
self.data_manager.add_observer(self)
self.data_manager.add_callback(self.on_data_loaded)
self.avg_char_width = None
def _create_global_themes(self):
@ -85,12 +85,12 @@ class MainController(Observer):
dpg.add_theme_style(dpg.mvPlotStyleVar_LineWeight, scaled_thickness, category=dpg.mvThemeCat_Plots)
dpg.add_theme_color(dpg.mvPlotCol_Line, (255, 0, 0, 128), category=dpg.mvThemeCat_Plots)
def on_data_loaded(self, event: DataLoadedEvent):
self.playback_manager.set_route_duration(event.data['duration'])
num_msg_types = len(event.data['time_series_data'])
def on_data_loaded(self, data: dict):
self.playback_manager.set_route_duration(data['duration'])
num_msg_types = len(data['time_series_data'])
dpg.set_value("load_status", f"Loaded {num_msg_types} message types")
dpg.configure_item("load_button", enabled=True)
dpg.configure_item("timeline_slider", max_value=event.data['duration'])
dpg.configure_item("timeline_slider", max_value=data['duration'])
def setup_ui(self):
with dpg.item_handler_registry(tag="tree_node_handler"):
@ -180,7 +180,7 @@ class MainController(Observer):
if dpg.does_item_exist(value_tag) and dpg.is_item_visible(value_tag):
last_index = self.playback_manager.last_indices.get(path)
value, new_idx = self.data_manager.get_current_value_for_path(path, self.playback_manager.current_time_s, last_index)
value, new_idx = self.data_manager.get_current_value(path, self.playback_manager.current_time_s, last_index)
if value is not None:
self.playback_manager.update_index(path, new_idx)

@ -4,7 +4,7 @@ import uuid
import threading
import dearpygui.dearpygui as dpg
from abc import ABC, abstractmethod
from openpilot.tools.jotpluggler.data import Observer, DataLoadedEvent, DataManager
from openpilot.tools.jotpluggler.data import DataManager
class ViewPanel(ABC):
@ -31,7 +31,7 @@ class ViewPanel(ABC):
pass
class TimeSeriesPanel(ViewPanel, Observer):
class TimeSeriesPanel(ViewPanel):
def __init__(self, data_manager: DataManager, playback_manager, panel_id: str | None = None):
super().__init__(panel_id)
self.data_manager = data_manager
@ -45,14 +45,13 @@ class TimeSeriesPanel(ViewPanel, Observer):
self._ui_created = False
self._preserved_series_data: list[tuple[str, tuple]] = [] # TODO: the way we do this right now doesn't make much sense
self._series_legend_tags: dict[str, str] = {} # Maps series_path to legend tag
self.data_manager.add_observer(self)
self.data_manager.add_callback(self.on_data_loaded)
def preserve_data(self):
self._preserved_series_data = []
if self.plotted_series and self._ui_created:
for series_path in self.plotted_series:
time_value_data = self.data_manager.get_time_series_data(series_path)
time_value_data = self.data_manager.get_time_series(series_path)
if time_value_data:
self._preserved_series_data.append((series_path, time_value_data))
@ -88,7 +87,7 @@ class TimeSeriesPanel(ViewPanel, Observer):
if self.plotted_series: # update legend labels with current values
for series_path in self.plotted_series:
last_index = self.playback_manager.last_indices.get(series_path)
value, new_idx = self.data_manager.get_current_value_for_path(series_path, current_time_s, last_index)
value, new_idx = self.data_manager.get_current_value(series_path, current_time_s, last_index)
if value is not None:
self.playback_manager.update_index(series_path, new_idx)
@ -126,7 +125,7 @@ class TimeSeriesPanel(ViewPanel, Observer):
if self.plot_tag and dpg.does_item_exist(self.plot_tag):
dpg.delete_item(self.plot_tag)
# self.data_manager.remove_observer(self)
# self.data_manager.remove_callback(self.on_data_loaded)
self._series_legend_tags.clear()
self._ui_created = False
@ -137,7 +136,7 @@ class TimeSeriesPanel(ViewPanel, Observer):
if series_path in self.plotted_series:
return False
time_value_data = self.data_manager.get_time_series_data(series_path)
time_value_data = self.data_manager.get_time_series(series_path)
if time_value_data is None:
return False
@ -157,12 +156,12 @@ class TimeSeriesPanel(ViewPanel, Observer):
if series_path in self._series_legend_tags: # Clean up legend tag mapping
del self._series_legend_tags[series_path]
def on_data_loaded(self, event: DataLoadedEvent):
def on_data_loaded(self, data: dict):
for series_path in self.plotted_series.copy():
self._update_series_data(series_path)
def _update_series_data(self, series_path: str) -> bool:
time_value_data = self.data_manager.get_time_series_data(series_path)
time_value_data = self.data_manager.get_time_series(series_path)
if time_value_data is None:
return False
@ -190,16 +189,16 @@ class DataTreeNode:
self.children: dict[str, DataTreeNode] = {}
self.is_leaf = False
class DataTreeView(Observer):
class DataTreeView:
def __init__(self, data_manager: DataManager, ui_lock: threading.Lock):
self.data_manager = data_manager
self.ui_lock = ui_lock
self.current_search = ""
self.data_tree = DataTreeNode(name="root")
self.active_leaf_nodes: list[DataTreeNode] = []
self.data_manager.add_observer(self)
self.data_manager.add_callback(self.on_data_loaded)
def on_data_loaded(self, event: DataLoadedEvent):
def on_data_loaded(self, data: dict):
with self.ui_lock:
self.populate_data_tree()

Loading…
Cancel
Save