diff --git a/tools/jotpluggler/data.py b/tools/jotpluggler/data.py
index 2b31df1811..57b6ac972c 100644
--- a/tools/jotpluggler/data.py
+++ b/tools/jotpluggler/data.py
@@ -1,144 +1,301 @@
-import threading
 import numpy as np
-from collections.abc import Callable
+import threading
+import multiprocessing
+import bisect
+from collections import defaultdict
+from typing import Any
+import tqdm
 
 from openpilot.common.swaglog import cloudlog
-from openpilot.tools.lib.logreader import LogReader
-from openpilot.tools.lib.log_time_series import msgs_to_time_series
+from openpilot.tools.lib.logreader import _LogFileReader, LogReader
+
+
+def flatten_dict(d: dict, sep: str = "/", prefix: str | None = None) -> dict:
+  result = {}
+  stack = [(d, prefix)]
+
+  while stack:
+    obj, current_prefix = stack.pop()
+
+    if isinstance(obj, dict):
+      for key, val in obj.items():
+        new_prefix = key if current_prefix is None else f"{current_prefix}{sep}{key}"
+        if isinstance(val, (dict, list)):
+          stack.append((val, new_prefix))
+        else:
+          result[new_prefix] = val
+    elif isinstance(obj, list):
+      for i, item in enumerate(obj):
+        new_prefix = f"{current_prefix}{sep}{i}"
+        if isinstance(item, (dict, list)):
+          stack.append((item, new_prefix))
+        else:
+          result[new_prefix] = item
+    else:
+      if current_prefix is not None:
+        result[current_prefix] = obj
+  return result
+
+
+def extract_field_types(schema, prefix, field_types_dict):
+  stack = [(schema, prefix)]
+
+  while stack:
+    current_schema, current_prefix = stack.pop()
+
+    for field in current_schema.fields_list:
+      field_name = field.proto.name
+      field_path = f"{current_prefix}/{field_name}"
+      field_proto = field.proto
+      field_which = field_proto.which()
+
+      field_type = field_proto.slot.type.which() if field_which == 'slot' else field_which
+      field_types_dict[field_path] = field_type
+
+      if field_which == 'slot':
+        slot_type = field_proto.slot.type
+        type_which = slot_type.which()
+
+        if type_which == 'list':
+          element_type = slot_type.list.elementType.which()
+          list_path = f"{field_path}/*"
+          field_types_dict[list_path] = element_type
+
+          if element_type == 'struct':
+            stack.append((field.schema.elementType, list_path))
+
+        elif type_which == 'struct':
+          stack.append((field.schema, field_path))
+
+      elif field_which == 'group':
+        stack.append((field.schema, field_path))
+
+
+def _convert_to_optimal_dtype(values_list, capnp_type):
+  if not values_list:
+    return np.array([])
+
+  dtype_mapping = {
+    'bool': np.bool_, 'int8': np.int8, 'int16': np.int16, 'int32': np.int32, 'int64': np.int64,
+    'uint8': np.uint8, 'uint16': np.uint16, 'uint32': np.uint32, 'uint64': np.uint64,
+    'float32': np.float32, 'float64': np.float64, 'text': object, 'data': object,
+    'enum': object, 'anyPointer': object,
+  }
+
+  target_dtype = dtype_mapping.get(capnp_type)
+  return np.array(values_list, dtype=target_dtype) if target_dtype else np.array(values_list)
+
+
+def _match_field_type(field_path, field_types):
+  if field_path in field_types:
+    return field_types[field_path]
+
+  path_parts = field_path.split('/')
+  template_parts = [p if not p.isdigit() else '*' for p in path_parts]
+  template_path = '/'.join(template_parts)
+  return field_types.get(template_path)
+
+
+def msgs_to_time_series(msgs):
+  """Extract scalar fields and return (time_series_data, start_time, end_time)."""
+  collected_data = defaultdict(lambda: {'timestamps': [], 'columns': defaultdict(list), 'sparse_fields': set()})
+  field_types = {}
+  extracted_schemas = set()
+  min_time = max_time = None
+
+  for msg in msgs:
+    typ = msg.which()
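+    # logMonoTime is integer nanoseconds since boot; convert once to float seconds for all downstream time math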
+    timestamp = msg.logMonoTime * 1e-9
+    if typ != 'initData':
+      if min_time is None:
+        min_time = timestamp
+      max_time = timestamp
+
+    sub_msg = getattr(msg, typ)
+    if not hasattr(sub_msg, 'to_dict') or typ in ('qcomGnss', 'ubloxGnss'):
+      continue
+
+    if hasattr(sub_msg, 'schema') and typ not in extracted_schemas:
+      extract_field_types(sub_msg.schema, typ, field_types)
+      extracted_schemas.add(typ)
+
+    msg_dict = sub_msg.to_dict(verbose=True)
+    flat_dict = flatten_dict(msg_dict)
+    flat_dict['_valid'] = msg.valid
+
+    type_data = collected_data[typ]
+    columns, sparse_fields = type_data['columns'], type_data['sparse_fields']
+    known_fields = set(columns.keys())
+    missing_fields = known_fields - flat_dict.keys()
+
+    for field, value in flat_dict.items():
+      if field not in known_fields and type_data['timestamps']:
+        sparse_fields.add(field)
+      columns[field].append(value)
+      if value is None:
+        sparse_fields.add(field)
+
+    for field in missing_fields:
+      columns[field].append(None)
+      sparse_fields.add(field)
+
+    type_data['timestamps'].append(timestamp)
+
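+  # second pass: pad late-appearing fields and freeze each type's column lists into numpy arrays (object dtype for sparse fields)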
+  final_result = {}
+  for typ, data in collected_data.items():
+    if not data['timestamps']:
+      continue
+
+    typ_result = {'t': np.array(data['timestamps'], dtype=np.float64)}
+    sparse_fields = data['sparse_fields']
+
+    for field_name, values in data['columns'].items():
+      if len(values) < len(data['timestamps']):
+        values = [None] * (len(data['timestamps']) - len(values)) + values
+        sparse_fields.add(field_name)
+
+      if field_name in sparse_fields:
+        typ_result[field_name] = np.array(values, dtype=object)
+      else:
+        capnp_type = _match_field_type(f"{typ}/{field_name}", field_types)
+        typ_result[field_name] = _convert_to_optimal_dtype(values, capnp_type)
+
+    final_result[typ] = typ_result
+
+  return final_result, min_time or 0.0, max_time or 0.0
+
+
+def _process_segment(segment_identifier: str) -> tuple[dict[str, Any], float, float]:
+  try:
+    lr = _LogFileReader(segment_identifier, sort_by_time=True)
+    return msgs_to_time_series(lr)
+  except Exception as e:
+    cloudlog.warning(f"Failed to process segment {segment_identifier}: {e}")
+    return {}, 0.0, 0.0
 
 
 class DataManager:
   def __init__(self):
-    self.time_series_data = {}
+    self._segments = []
+    self._segment_starts = []
+    self._start_time = 0.0
+    self._duration = 0.0
+    self._paths = set()
+    self._observers = []
     self.loading = False
-    self.route_start_time_mono = 0.0
-    self.duration = 0.0
-    self._callbacks: list[Callable[[dict], None]] = []
-
-  def add_callback(self, callback: Callable[[dict], None]):
-    self._callbacks.append(callback)
+    self._lock = threading.RLock()
 
-  def remove_callback(self, callback: Callable[[dict], None]):
-    if callback in self._callbacks:
-      self._callbacks.remove(callback)
+  def load_route(self, route: str) -> None:
+    if self.loading:
+      return
+    self._reset()
+    threading.Thread(target=self._load_async, args=(route,), daemon=True).start()
 
-  def _notify_callbacks(self, data: dict):
-    for callback in self._callbacks:
-      try:
-        callback(data)
-      except Exception as e:
-        cloudlog.exception(f"Error in data callback: {e}")
+  def get_timeseries(self, path: str):
+    with self._lock:
+      msg_type, field = path.split('/', 1)
+      times, values = [], []
 
-  def get_current_value(self, path: str, time_s: float, last_index: int | None = None):
-    try:
-      abs_time_s = self.route_start_time_mono + time_s
-      msg_type, field_path = path.split('/', 1)
-      ts_data = self.time_series_data[msg_type]
-      t, v = ts_data['t'], ts_data[field_path]
-
-      if len(t) == 0:
-        return None, None
-
-      if last_index is None:  # jump
-        idx = np.searchsorted(t, abs_time_s, side='right') - 1
-      else:  # continuous playback
-        idx = last_index
-        while idx < len(t) - 1 and t[idx + 1] < abs_time_s:
-          idx += 1
-
-      idx = max(0, idx)
-      return v[idx], idx
-
-    except (KeyError, IndexError):
-      return None, None
-
-  def get_all_paths(self) -> list[str]:
-    all_paths = []
-    for msg_type, data in self.time_series_data.items():
-      for key in data.keys():
-        if key != 't':
-          all_paths.append(f"{msg_type}/{key}")
-    return all_paths
-
-  def is_path_plottable(self, path: str) -> bool:
-    try:
-      msg_type, field_path = path.split('/', 1)
-      value_array = self.time_series_data.get(msg_type, {}).get(field_path)
-      if value_array is not None:
-        return np.issubdtype(value_array.dtype, np.number) or np.issubdtype(value_array.dtype, np.bool_)
-    except (ValueError, KeyError):
-      pass
-    return False
-
-  def get_time_series(self, path: str):
-    try:
-      msg_type, field_path = path.split('/', 1)
-      ts_data = self.time_series_data[msg_type]
-      time_array = ts_data['t']
-      values = ts_data[field_path]
+      for segment in self._segments:
+        if msg_type in segment and field in segment[msg_type]:
+          times.append(segment[msg_type]['t'])
+          values.append(segment[msg_type][field])
 
-      if len(time_array) == 0:
+      if not times:
         return None
 
-      rel_time = time_array - self.route_start_time_mono
-      return rel_time, values
+      combined_times = np.concatenate(times) - self._start_time
+      if len(values) > 1 and any(arr.dtype != values[0].dtype for arr in values):
+        values = [arr.astype(object) for arr in values]
+
+      return combined_times, np.concatenate(values)
 
-    except (KeyError, ValueError):
+  def get_value_at(self, path: str, time: float):
+    with self._lock:
+      absolute_time = self._start_time + time
+      message_type, field = path.split('/', 1)
+      current_index = bisect.bisect_right(self._segment_starts, absolute_time) - 1
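+      # search the segment containing the timestamp, then the previous one; only accept samples within 1s of the requested time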
+      for index in (current_index, current_index - 1):
+        if not 0 <= index < len(self._segments):
+          continue
+        segment = self._segments[index].get(message_type)
+        if not segment or field not in segment:
+          continue
+        times = segment['t']
+        if len(times) == 0 or (index != current_index and absolute_time - times[-1] > 1):
+          continue
+        position = np.searchsorted(times, absolute_time, 'right') - 1
+        if position >= 0 and absolute_time - times[position] <= 1:
+          return segment[field][position]
       return None
 
-  def load_route(self, route_name: str):
-    if self.loading:
-      return
+  def get_all_paths(self):
+    with self._lock:
+      return sorted(self._paths)
 
-    self.loading = True
-    threading.Thread(target=self._load_route_background, args=(route_name,), daemon=True).start()
+  def get_duration(self):
+    with self._lock:
+      return self._duration
 
-  def _load_route_background(self, route_name: str):
-    try:
-      lr = LogReader(route_name)
-      raw_data = msgs_to_time_series(lr)
-      processed_data = self._expand_list_fields(raw_data)
+  def is_plottable(self, path: str):
+    data = self.get_timeseries(path)
+    if data is None:
+      return False
+    _, values = data
+    return np.issubdtype(values.dtype, np.number) or np.issubdtype(values.dtype, np.bool_)
 
-      min_time = float('inf')
-      max_time = float('-inf')
-      for data in processed_data.values():
-        if len(data['t']) > 0:
-          min_time = min(min_time, data['t'][0])
-          max_time = max(max_time, data['t'][-1])
+  def add_observer(self, callback):
+    with self._lock:
+      self._observers.append(callback)
 
-      self.time_series_data = processed_data
-      self.route_start_time_mono = min_time if min_time != float('inf') else 0.0
-      self.duration = max_time - min_time if max_time != float('-inf') else 0.0
+  def _reset(self):
+    with self._lock:
+      self.loading = True
+      self._segments.clear()
+      self._segment_starts.clear()
+      self._paths.clear()
+      self._start_time = self._duration = 0.0
 
-      self._notify_callbacks({'time_series_data': processed_data, 'route_start_time_mono': self.route_start_time_mono, 'duration': self.duration})
+  def _load_async(self, route: str):
+    try:
+      lr = LogReader(route, sort_by_time=True)
+      if not lr.logreader_identifiers:
+        cloudlog.warning(f"No log segments found for route: {route}")
+        return
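+      # fan segments out to worker processes; imap yields results in submission order, so segments arrive chronologically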
+      with multiprocessing.Pool() as pool, tqdm.tqdm(total=len(lr.logreader_identifiers), desc="Processing Segments") as pbar:
+        for segment_result, start_time, end_time in pool.imap(_process_segment, lr.logreader_identifiers):
+          pbar.update(1)
+          if segment_result:
+            self._add_segment(segment_result, start_time, end_time)
     except Exception as e:
-      cloudlog.exception(f"Error loading route {route_name}: {e}")
+      cloudlog.exception(f"Error loading route {route}:")
     finally:
-      self.loading = False
+      self._finalize_loading()
 
-  def _expand_list_fields(self, time_series_data):
-    expanded_data = {}
-    for msg_type, data in time_series_data.items():
-      expanded_data[msg_type] = {}
-      for field, values in data.items():
-        if field == 't':
-          expanded_data[msg_type]['t'] = values
-          continue
+  def _add_segment(self, segment_data: dict, start_time: float, end_time: float):
+    with self._lock:
+      self._segments.append(segment_data)
+      self._segment_starts.append(start_time)
 
-        if values.dtype == object:  # ragged array
-          lens = np.fromiter((len(v) for v in values), dtype=int, count=len(values))
-          max_len = lens.max() if lens.size else 0
-          if max_len > 0:
-            arr = np.full((len(values), max_len), None, dtype=object)
-            for i, v in enumerate(values):
-              arr[i, : lens[i]] = v
-            for i in range(max_len):
-              sub_arr = arr[:, i]
-              expanded_data[msg_type][f"{field}/{i}"] = sub_arr
-        elif values.ndim > 1:  # regular multidimensional array
-          for i in range(values.shape[1]):
-            col_data = values[:, i]
-            expanded_data[msg_type][f"{field}/{i}"] = col_data
-        else:
-          expanded_data[msg_type][field] = values
-    return expanded_data
+      if len(self._segments) == 1:
+        self._start_time = start_time
+      self._duration = end_time - self._start_time
+
+      for msg_type, data in segment_data.items():
+        for field in data.keys():
+          if field != 't':
+            self._paths.add(f"{msg_type}/{field}")
+
+      observers = self._observers.copy()
+      duration = self._duration
+      segment_count = len(self._segments)
+
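+    # snapshot state under the lock, notify outside it so slow observer callbacks never block readers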
+    for callback in observers:
+      callback({'segment_added': True, 'segment_count': segment_count, 'duration': duration})
+
+  def _finalize_loading(self):
+    with self._lock:
+      self.loading = False
+      observers = self._observers.copy()
+      duration = self._duration
+
+    for callback in observers:
+      callback({'loading_complete': True, 'duration': duration})
diff --git a/tools/jotpluggler/pluggle.py b/tools/jotpluggler/pluggle.py
index e2c9df159b..614bb53e8e 100755
--- a/tools/jotpluggler/pluggle.py
+++ b/tools/jotpluggler/pluggle.py
@@ -18,7 +18,6 @@ class PlaybackManager:
     self.is_playing = False
     self.current_time_s = 0.0
     self.duration_s = 0.0
-    self.last_indices = {}
 
   def set_route_duration(self, duration: float):
     self.duration_s = duration
@@ -32,7 +31,6 @@
   def seek(self, time_s: float):
     self.is_playing = False
     self.current_time_s = max(0.0, min(time_s, self.duration_s))
-    self.last_indices.clear()
 
   def update_time(self, delta_t: float):
     if self.is_playing:
@@ -41,10 +39,6 @@
       self.is_playing = False
     return self.current_time_s
 
-  def update_index(self, path: str, new_idx: int | None):
-    if new_idx is not None:
-      self.last_indices[path] = new_idx
-
 
 def calculate_avg_char_width(font):
   sample_text = "abcdefghijklmnopqrstuvwxyz0123456789"
@@ -70,7 +64,7 @@ class MainController:
     self._create_global_themes()
     self.data_tree_view = DataTreeView(self.data_manager, self.ui_lock)
     self.plot_layout_manager = PlotLayoutManager(self.data_manager, self.playback_manager, scale=self.scale)
-    self.data_manager.add_callback(self.on_data_loaded)
+    self.data_manager.add_observer(self.on_data_loaded)
     self.avg_char_width = None
 
   def _create_global_themes(self):
@@ -86,11 +80,18 @@ class MainController:
       dpg.add_theme_color(dpg.mvPlotCol_Line, (255, 0, 0, 128), category=dpg.mvThemeCat_Plots)
 
   def on_data_loaded(self, data: dict):
-    self.playback_manager.set_route_duration(data['duration'])
-    num_msg_types = len(data['time_series_data'])
-    dpg.set_value("load_status", f"Loaded {num_msg_types} message types")
-    dpg.configure_item("load_button", enabled=True)
-    dpg.configure_item("timeline_slider", max_value=data['duration'])
+    duration = data.get('duration', 0.0)
+    self.playback_manager.set_route_duration(duration)
+
+    if data.get('loading_complete'):
+      num_paths = len(self.data_manager.get_all_paths())
+      dpg.set_value("load_status", f"Loaded {num_paths} data paths")
+      dpg.configure_item("load_button", enabled=True)
+    elif data.get('segment_added'):
+      segment_count = data.get('segment_count', 0)
+      dpg.set_value("load_status", f"Loading... {segment_count} segments processed")
+
+    dpg.configure_item("timeline_slider", max_value=duration)
 
   def setup_ui(self):
     with dpg.item_handler_registry(tag="tree_node_handler"):
@@ -179,11 +180,8 @@ class MainController:
       value_tag = f"value_{path}"
 
       if dpg.does_item_exist(value_tag) and dpg.is_item_visible(value_tag):
-        last_index = self.playback_manager.last_indices.get(path)
-        value, new_idx = self.data_manager.get_current_value(path, self.playback_manager.current_time_s, last_index)
-
+        value = self.data_manager.get_value_at(path, self.playback_manager.current_time_s)
         if value is not None:
-          self.playback_manager.update_index(path, new_idx)
           formatted_value = format_and_truncate(value, value_column_width, self.avg_char_width)
           dpg.set_value(value_tag, formatted_value)
diff --git a/tools/jotpluggler/views.py b/tools/jotpluggler/views.py
index feeff7ab24..feb8abac69 100644
--- a/tools/jotpluggler/views.py
+++ b/tools/jotpluggler/views.py
@@ -45,13 +45,13 @@ class TimeSeriesPanel(ViewPanel):
     self._ui_created = False
     self._preserved_series_data: list[tuple[str, tuple]] = []  # TODO: the way we do this right now doesn't make much sense
     self._series_legend_tags: dict[str, str] = {}  # Maps series_path to legend tag
-    self.data_manager.add_callback(self.on_data_loaded)
+    self.data_manager.add_observer(self.on_data_loaded)
 
   def preserve_data(self):
     self._preserved_series_data = []
     if self.plotted_series and self._ui_created:
       for series_path in self.plotted_series:
-        time_value_data = self.data_manager.get_time_series(series_path)
+        time_value_data = self.data_manager.get_timeseries(series_path)
         if time_value_data:
           self._preserved_series_data.append((series_path, time_value_data))
 
@@ -86,12 +86,9 @@ class TimeSeriesPanel(ViewPanel):
     if self.plotted_series:  # update legend labels with current values
       for series_path in self.plotted_series:
-        last_index = self.playback_manager.last_indices.get(series_path)
-        value, new_idx = self.data_manager.get_current_value(series_path, current_time_s, last_index)
+        value = self.data_manager.get_value_at(series_path, current_time_s)
 
         if value is not None:
-          self.playback_manager.update_index(series_path, new_idx)
-
           if isinstance(value, (int, float)):
             if isinstance(value, float):
               formatted_value = f"{value:.4f}" if abs(value) < 1000 else f"{value:.3e}"
@@ -100,7 +97,6 @@
           else:
             formatted_value = str(value)
 
-          # Update the series label to include current value
           series_tag = f"series_{self.panel_id}_{series_path.replace('/', '_')}"
           legend_label = f"{series_path}: {formatted_value}"
 
@@ -125,7 +121,6 @@
     if self.plot_tag and dpg.does_item_exist(self.plot_tag):
       dpg.delete_item(self.plot_tag)
 
-    # self.data_manager.remove_callback(self.on_data_loaded)
     self._series_legend_tags.clear()
     self._ui_created = False
 
@@ -136,7 +131,7 @@
     if series_path in self.plotted_series:
       return False
 
-    time_value_data = self.data_manager.get_time_series(series_path)
+    time_value_data = self.data_manager.get_timeseries(series_path)
     if time_value_data is None:
       return False
 
@@ -153,7 +148,7 @@
     if dpg.does_item_exist(series_tag):
       dpg.delete_item(series_tag)
       self.plotted_series.remove(series_path)
-      if series_path in self._series_legend_tags:  # Clean up legend tag mapping
+      if series_path in self._series_legend_tags:
         del self._series_legend_tags[series_path]
 
   def on_data_loaded(self, data: dict):
@@ -161,7 +156,7 @@
       self._update_series_data(series_path)
 
   def _update_series_data(self, series_path: str) -> bool:
-    time_value_data = self.data_manager.get_time_series(series_path)
+    time_value_data = self.data_manager.get_timeseries(series_path)
    if time_value_data is None:
      return False
 
@@ -196,7 +191,7 @@ class DataTreeView:
     self.current_search = ""
     self.data_tree = DataTreeNode(name="root")
     self.active_leaf_nodes: list[DataTreeNode] = []
-    self.data_manager.add_callback(self.on_data_loaded)
+    self.data_manager.add_observer(self.on_data_loaded)
 
   def on_data_loaded(self, data: dict):
     with self.ui_lock:
@@ -246,7 +241,7 @@
     for child in sorted_children:
       if child.is_leaf:
-        is_plottable = self.data_manager.is_path_plottable(child.full_path)
+        is_plottable = self.data_manager.is_plottable(child.full_path)
 
         # Create draggable item
         with dpg.group(parent=parent_tag) as draggable_group:
@@ -266,10 +261,6 @@
       node_tag = f"tree_{child.full_path}"
       label = child.name
 
-      if '/' not in child.full_path:
-        sample_count = len(self.data_manager.time_series_data.get(child.full_path, {}).get('t', []))
-        label = f"{child.name} ({sample_count} samples)"
-
       should_open = bool(search_term) and len(search_term) > 1 and any(search_term in path for path in self._get_all_descendant_paths(child))
 
       with dpg.tree_node(label=label, parent=parent_tag, tag=node_tag, default_open=should_open):