multiprocessed data ingest

pull/36045/head
Quantizr (Jimmy) 2 weeks ago
parent 6048eb568a
commit 2eb697a730
3 changed files:
  tools/jotpluggler/data.py (393)
  tools/jotpluggler/pluggle.py (30)
  tools/jotpluggler/views.py (25)

tools/jotpluggler/data.py:

@@ -1,144 +1,301 @@
-import threading
 import numpy as np
-from collections.abc import Callable
+import threading
+import multiprocessing
+import bisect
-from collections import defaultdict
+from typing import Any
+import tqdm
 from openpilot.common.swaglog import cloudlog
-from openpilot.tools.lib.logreader import LogReader
+from openpilot.tools.lib.log_time_series import msgs_to_time_series
+from openpilot.tools.lib.logreader import _LogFileReader, LogReader
-def flatten_dict(d: dict, sep: str = "/", prefix: str = None) -> dict:
-  result = {}
-  stack = [(d, prefix)]
-  while stack:
-    obj, current_prefix = stack.pop()
-    if isinstance(obj, dict):
-      for key, val in obj.items():
-        new_prefix = key if current_prefix is None else f"{current_prefix}{sep}{key}"
-        if isinstance(val, (dict, list)):
-          stack.append((val, new_prefix))
-        else:
-          result[new_prefix] = val
-    elif isinstance(obj, list):
-      for i, item in enumerate(obj):
-        new_prefix = f"{current_prefix}{sep}{i}"
-        if isinstance(item, (dict, list)):
-          stack.append((item, new_prefix))
-        else:
-          result[new_prefix] = item
-    else:
-      if current_prefix is not None:
-        result[current_prefix] = obj
-  return result
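
For reference, this removed helper flattened nested message dicts into slash-delimited scalar paths. A minimal sketch of its behavior (toy input; result key order may differ):

nested = {"wheelSpeeds": {"fl": 1.0, "fr": 2.0}, "events": [{"name": "stop"}]}
flatten_dict(nested)
# -> {'wheelSpeeds/fl': 1.0, 'wheelSpeeds/fr': 2.0, 'events/0/name': 'stop'}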
-def extract_field_types(schema, prefix, field_types_dict):
-  stack = [(schema, prefix)]
-  while stack:
-    current_schema, current_prefix = stack.pop()
-    for field in current_schema.fields_list:
-      field_name = field.proto.name
-      field_path = f"{current_prefix}/{field_name}"
-      field_proto = field.proto
-      field_which = field_proto.which()
-      field_type = field_proto.slot.type.which() if field_which == 'slot' else field_which
-      field_types_dict[field_path] = field_type
-      if field_which == 'slot':
-        slot_type = field_proto.slot.type
-        type_which = slot_type.which()
-        if type_which == 'list':
-          element_type = slot_type.list.elementType.which()
-          list_path = f"{field_path}/*"
-          field_types_dict[list_path] = element_type
-          if element_type == 'struct':
-            stack.append((field.schema.elementType, list_path))
-        elif type_which == 'struct':
-          stack.append((field.schema, field_path))
-      elif field_which == 'group':
-        stack.append((field.schema, field_path))

-def _convert_to_optimal_dtype(values_list, capnp_type):
-  if not values_list:
-    return np.array([])
-  dtype_mapping = {
-    'bool': np.bool_, 'int8': np.int8, 'int16': np.int16, 'int32': np.int32, 'int64': np.int64,
-    'uint8': np.uint8, 'uint16': np.uint16, 'uint32': np.uint32, 'uint64': np.uint64,
-    'float32': np.float32, 'float64': np.float64, 'text': object, 'data': object,
-    'enum': object, 'anyPointer': object,
-  }
-  target_dtype = dtype_mapping.get(capnp_type)
-  return np.array(values_list, dtype=target_dtype) if target_dtype else np.array(values_list)

-def _match_field_type(field_path, field_types):
-  if field_path in field_types:
-    return field_types[field_path]
-  path_parts = field_path.split('/')
-  template_parts = [p if not p.isdigit() else '*' for p in path_parts]
-  template_path = '/'.join(template_parts)
-  return field_types.get(template_path)
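
The wildcard lookup above lets a concrete indexed path fall back to its list template, since extract_field_types records list elements under '*'. A small illustration (path and type are hypothetical):

field_types = {'carState/accels/*': 'float32'}  # hypothetical entry recorded by extract_field_types
_match_field_type('carState/accels/2', field_types)  # -> 'float32', via template 'carState/accels/*'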
-def msgs_to_time_series(msgs):
-  """Extract scalar fields and return (time_series_data, start_time, end_time)."""
-  collected_data = defaultdict(lambda: {'timestamps': [], 'columns': defaultdict(list), 'sparse_fields': set()})
-  field_types = {}
-  extracted_schemas = set()
-  min_time = max_time = None
-  for msg in msgs:
-    typ = msg.which()
-    timestamp = msg.logMonoTime * 1e-9
-    if typ != 'initData':
-      if min_time is None:
-        min_time = timestamp
-      max_time = timestamp
-    sub_msg = getattr(msg, typ)
-    if not hasattr(sub_msg, 'to_dict') or typ in ('qcomGnss', 'ubloxGnss'):
-      continue
-    if hasattr(sub_msg, 'schema') and typ not in extracted_schemas:
-      extract_field_types(sub_msg.schema, typ, field_types)
-      extracted_schemas.add(typ)
-    msg_dict = sub_msg.to_dict(verbose=True)
-    flat_dict = flatten_dict(msg_dict)
-    flat_dict['_valid'] = msg.valid
-    type_data = collected_data[typ]
-    columns, sparse_fields = type_data['columns'], type_data['sparse_fields']
-    known_fields = set(columns.keys())
-    missing_fields = known_fields - flat_dict.keys()
-    for field, value in flat_dict.items():
-      if field not in known_fields and type_data['timestamps']:
-        sparse_fields.add(field)
-      columns[field].append(value)
-      if value is None:
-        sparse_fields.add(field)
-    for field in missing_fields:
-      columns[field].append(None)
-      sparse_fields.add(field)
-    type_data['timestamps'].append(timestamp)
-  final_result = {}
-  for typ, data in collected_data.items():
-    if not data['timestamps']:
-      continue
-    typ_result = {'t': np.array(data['timestamps'], dtype=np.float64)}
-    sparse_fields = data['sparse_fields']
-    for field_name, values in data['columns'].items():
-      if len(values) < len(data['timestamps']):
-        values = [None] * (len(data['timestamps']) - len(values)) + values
-        sparse_fields.add(field_name)
-      if field_name in sparse_fields:
-        typ_result[field_name] = np.array(values, dtype=object)
-      else:
-        capnp_type = _match_field_type(f"{typ}/{field_name}", field_types)
-        typ_result[field_name] = _convert_to_optimal_dtype(values, capnp_type)
-    final_result[typ] = typ_result
-  return final_result, min_time or 0.0, max_time or 0.0
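
This function is no longer defined here; per the new import above, it now comes from openpilot.tools.lib.log_time_series. Roughly, the returned triple looks like this (message and field names illustrative):

time_series, start_time, end_time = msgs_to_time_series(LogReader(route))
# time_series['carState']['t']      -> float64 array of mono timestamps, in seconds
# time_series['carState']['vEgo']   -> array with a dtype chosen from the capnp schema
# time_series['carState']['_valid'] -> per-message valid flags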
+def _process_segment(segment_identifier: str) -> tuple[dict[str, Any], float, float]:
+  try:
+    lr = _LogFileReader(segment_identifier, sort_by_time=True)
+    return msgs_to_time_series(lr)
+  except Exception as e:
+    cloudlog.warning(f"Warning: Failed to process segment {segment_identifier}: {e}")
+    return {}, 0.0, 0.0
 class DataManager:
   def __init__(self):
-    self.time_series_data = {}
+    self._segments = []
+    self._segment_starts = []
+    self._start_time = 0.0
+    self._duration = 0.0
+    self._paths = set()
+    self._observers = []
     self.loading = False
-    self.route_start_time_mono = 0.0
-    self.duration = 0.0
-    self._callbacks: list[Callable[[dict], None]] = []
-
-  def add_callback(self, callback: Callable[[dict], None]):
-    self._callbacks.append(callback)
+    self._lock = threading.RLock()
-
-  def remove_callback(self, callback: Callable[[dict], None]):
-    if callback in self._callbacks:
-      self._callbacks.remove(callback)

+  def load_route(self, route: str) -> None:
+    if self.loading:
+      return
+    self._reset()
+    threading.Thread(target=self._load_async, args=(route,), daemon=True).start()

-  def _notify_callbacks(self, data: dict):
-    for callback in self._callbacks:
-      try:
-        callback(data)
-      except Exception as e:
-        cloudlog.exception(f"Error in data callback: {e}")

+  def get_timeseries(self, path: str):
+    with self._lock:
+      msg_type, field = path.split('/', 1)
+      times, values = [], []

-  def get_current_value(self, path: str, time_s: float, last_index: int | None = None):
-    try:
-      abs_time_s = self.route_start_time_mono + time_s
-      msg_type, field_path = path.split('/', 1)
-      ts_data = self.time_series_data[msg_type]
-      t, v = ts_data['t'], ts_data[field_path]
-      if len(t) == 0:
-        return None, None
-      if last_index is None:  # jump
-        idx = np.searchsorted(t, abs_time_s, side='right') - 1
-      else:  # continuous playback
-        idx = last_index
-        while idx < len(t) - 1 and t[idx + 1] < abs_time_s:
-          idx += 1
-      idx = max(0, idx)
-      return v[idx], idx
-    except (KeyError, IndexError):
-      return None, None

-  def get_all_paths(self) -> list[str]:
-    all_paths = []
-    for msg_type, data in self.time_series_data.items():
-      for key in data.keys():
-        if key != 't':
-          all_paths.append(f"{msg_type}/{key}")
-    return all_paths

-  def is_path_plottable(self, path: str) -> bool:
-    try:
-      msg_type, field_path = path.split('/', 1)
-      value_array = self.time_series_data.get(msg_type, {}).get(field_path)
-      if value_array is not None:
-        return np.issubdtype(value_array.dtype, np.number) or np.issubdtype(value_array.dtype, np.bool_)
-    except (ValueError, KeyError):
-      pass
-    return False

-  def get_time_series(self, path: str):
-    try:
-      msg_type, field_path = path.split('/', 1)
-      ts_data = self.time_series_data[msg_type]
-      time_array = ts_data['t']
-      values = ts_data[field_path]
+      for segment in self._segments:
+        if msg_type in segment and field in segment[msg_type]:
+          times.append(segment[msg_type]['t'])
+          values.append(segment[msg_type][field])
-      if len(time_array) == 0:
+      if not times:
         return None
-      rel_time = time_array - self.route_start_time_mono
-      return rel_time, values
+      combined_times = np.concatenate(times) - self._start_time
+      if len(values) > 1 and any(arr.dtype != values[0].dtype for arr in values):
+        values = [arr.astype(object) for arr in values]
+      return combined_times, np.concatenate(values)
-    except (KeyError, ValueError):

+  def get_value_at(self, path: str, time: float):
+    with self._lock:
+      absolute_time = self._start_time + time
+      message_type, field = path.split('/', 1)
+      current_index = bisect.bisect_right(self._segment_starts, absolute_time) - 1
+      for index in (current_index, current_index - 1):
+        if not 0 <= index < len(self._segments):
+          continue
+        segment = self._segments[index].get(message_type)
+        if not segment or field not in segment:
+          continue
+        times = segment['t']
+        if len(times) == 0 or (index != current_index and absolute_time - times[-1] > 1):
+          continue
+        position = np.searchsorted(times, absolute_time, 'right') - 1
+        if position >= 0 and absolute_time - times[position] <= 1:
+          return segment[field][position]
+      return None
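
get_value_at is a two-level lookup: bisect over per-segment start times picks the segment, np.searchsorted finds the newest sample at or before the query time, and a one-second staleness cutoff keeps values from a gap (or the tail of the previous segment) from being returned. A self-contained sketch of the same pattern (toy data):

import bisect
import numpy as np

segment_starts = [0.0, 60.0]  # mono start time of each segment
segments = [{'carState': {'t': np.array([0.0, 0.5]), 'vEgo': np.array([1.0, 2.0])}},
            {'carState': {'t': np.array([60.0, 60.5]), 'vEgo': np.array([3.0, 4.0])}}]

t = 60.2
seg = segments[bisect.bisect_right(segment_starts, t) - 1]['carState']
pos = np.searchsorted(seg['t'], t, 'right') - 1  # newest sample at or before t
print(seg['vEgo'][pos] if pos >= 0 and t - seg['t'][pos] <= 1 else None)  # 3.0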
-  def load_route(self, route_name: str):
-    if self.loading:
-      return
+  def get_all_paths(self):
+    with self._lock:
+      return sorted(self._paths)
-    self.loading = True
-    threading.Thread(target=self._load_route_background, args=(route_name,), daemon=True).start()

+  def get_duration(self):
+    with self._lock:
+      return self._duration

-  def _load_route_background(self, route_name: str):
-    try:
-      lr = LogReader(route_name)
-      raw_data = msgs_to_time_series(lr)
-      processed_data = self._expand_list_fields(raw_data)
+  def is_plottable(self, path: str):
+    data = self.get_timeseries(path)
+    if data is None:
+      return False
+    _, values = data
+    return np.issubdtype(values.dtype, np.number) or np.issubdtype(values.dtype, np.bool_)

-      min_time = float('inf')
-      max_time = float('-inf')
-      for data in processed_data.values():
-        if len(data['t']) > 0:
-          min_time = min(min_time, data['t'][0])
-          max_time = max(max_time, data['t'][-1])
+  def add_observer(self, callback):
+    with self._lock:
+      self._observers.append(callback)

-      self.time_series_data = processed_data
-      self.route_start_time_mono = min_time if min_time != float('inf') else 0.0
-      self.duration = max_time - min_time if max_time != float('-inf') else 0.0
+  def _reset(self):
+    with self._lock:
+      self.loading = True
+      self._segments.clear()
+      self._segment_starts.clear()
+      self._paths.clear()
+      self._start_time = self._duration = 0.0

-      self._notify_callbacks({'time_series_data': processed_data, 'route_start_time_mono': self.route_start_time_mono, 'duration': self.duration})
+  def _load_async(self, route: str):
+    try:
+      lr = LogReader(route, sort_by_time=True)
+      if not lr.logreader_identifiers:
+        cloudlog.warning(f"Warning: No log segments found for route: {route}")
+        return
+      with multiprocessing.Pool() as pool, tqdm.tqdm(total=len(lr.logreader_identifiers), desc="Processing Segments") as pbar:
+        for segment_result, start_time, end_time in pool.imap(_process_segment, lr.logreader_identifiers):
+          pbar.update(1)
+          if segment_result:
+            self._add_segment(segment_result, start_time, end_time)
     except Exception as e:
-      cloudlog.exception(f"Error loading route {route_name}: {e}")
+      cloudlog.exception(f"Error loading route {route}:")
     finally:
-      self.loading = False
+      self._finalize_loading()
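
The ingest pattern here: _process_segment is a top-level function (so it pickles for multiprocessing), and pool.imap streams results back in segment order, letting _add_segment publish each segment while later ones are still being parsed. A minimal standalone sketch of the same idea (toy worker, illustrative names):

import multiprocessing
import tqdm

def parse_one(seg_id: str) -> dict:  # top-level so the Pool can pickle it
  return {'id': seg_id}  # stand-in for the real per-segment time-series dict

if __name__ == '__main__':
  seg_ids = [f'route--{i}' for i in range(8)]
  with multiprocessing.Pool() as pool, tqdm.tqdm(total=len(seg_ids)) as pbar:
    for result in pool.imap(parse_one, seg_ids):  # ordered, streaming results
      pbar.update(1)  # each segment becomes usable as soon as it arrives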
-  def _expand_list_fields(self, time_series_data):
-    expanded_data = {}
-    for msg_type, data in time_series_data.items():
-      expanded_data[msg_type] = {}
-      for field, values in data.items():
-        if field == 't':
-          expanded_data[msg_type]['t'] = values
-          continue
+  def _add_segment(self, segment_data: dict, start_time: float, end_time: float):
+    with self._lock:
+      self._segments.append(segment_data)
+      self._segment_starts.append(start_time)
-        if values.dtype == object:  # ragged array
-          lens = np.fromiter((len(v) for v in values), dtype=int, count=len(values))
-          max_len = lens.max() if lens.size else 0
-          if max_len > 0:
-            arr = np.full((len(values), max_len), None, dtype=object)
-            for i, v in enumerate(values):
-              arr[i, : lens[i]] = v
-            for i in range(max_len):
-              sub_arr = arr[:, i]
-              expanded_data[msg_type][f"{field}/{i}"] = sub_arr
-        elif values.ndim > 1:  # regular multidimensional array
-          for i in range(values.shape[1]):
-            col_data = values[:, i]
-            expanded_data[msg_type][f"{field}/{i}"] = col_data
-        else:
-          expanded_data[msg_type][field] = values
-    return expanded_data
+      if len(self._segments) == 1:
+        self._start_time = start_time
+      self._duration = end_time - self._start_time
+      for msg_type, data in segment_data.items():
+        for field in data.keys():
+          if field != 't':
+            self._paths.add(f"{msg_type}/{field}")
+      observers = self._observers.copy()
+      for callback in observers:
+        callback({'segment_added': True, 'duration': self._duration})

+  def _finalize_loading(self):
+    with self._lock:
+      self.loading = False
+      observers = self._observers.copy()
+      duration = self._duration
+    for callback in observers:
+      callback({'loading_complete': True, 'duration': duration})
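
Consumers subscribe with add_observer and receive plain dict events: a stream of segment_added notifications while segments arrive, then a single loading_complete. A usage sketch (route name illustrative):

dm = DataManager()

def on_update(event: dict):
  if event.get('segment_added'):
    print(f"loading... duration so far {event['duration']:.1f}s")
  elif event.get('loading_complete'):
    print(f"done, total duration {event['duration']:.1f}s")

dm.add_observer(on_update)
dm.load_route("a2a0ccea32023010|2023-07-27--13-01-19")  # illustrative route name

Since load_route only spawns a daemon thread, a real caller (like the UI below) keeps the process alive while events stream in.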

tools/jotpluggler/pluggle.py:

@@ -18,7 +18,6 @@ class PlaybackManager:
     self.is_playing = False
     self.current_time_s = 0.0
     self.duration_s = 0.0
-    self.last_indices = {}

   def set_route_duration(self, duration: float):
     self.duration_s = duration
@@ -32,7 +31,6 @@ class PlaybackManager:
   def seek(self, time_s: float):
     self.is_playing = False
     self.current_time_s = max(0.0, min(time_s, self.duration_s))
-    self.last_indices.clear()

   def update_time(self, delta_t: float):
     if self.is_playing:
@@ -41,10 +39,6 @@ class PlaybackManager:
       self.is_playing = False
     return self.current_time_s

-  def update_index(self, path: str, new_idx: int | None):
-    if new_idx is not None:
-      self.last_indices[path] = new_idx

 def calculate_avg_char_width(font):
   sample_text = "abcdefghijklmnopqrstuvwxyz0123456789"
@@ -70,7 +64,7 @@ class MainController:
     self._create_global_themes()
     self.data_tree_view = DataTreeView(self.data_manager, self.ui_lock)
     self.plot_layout_manager = PlotLayoutManager(self.data_manager, self.playback_manager, scale=self.scale)
-    self.data_manager.add_callback(self.on_data_loaded)
+    self.data_manager.add_observer(self.on_data_loaded)
     self.avg_char_width = None

   def _create_global_themes(self):
@@ -86,11 +80,18 @@ class MainController:
       dpg.add_theme_color(dpg.mvPlotCol_Line, (255, 0, 0, 128), category=dpg.mvThemeCat_Plots)

   def on_data_loaded(self, data: dict):
-    self.playback_manager.set_route_duration(data['duration'])
-    num_msg_types = len(data['time_series_data'])
-    dpg.set_value("load_status", f"Loaded {num_msg_types} message types")
-    dpg.configure_item("load_button", enabled=True)
-    dpg.configure_item("timeline_slider", max_value=data['duration'])
+    duration = data.get('duration', 0.0)
+    self.playback_manager.set_route_duration(duration)
+    if data.get('loading_complete'):
+      num_paths = len(self.data_manager.get_all_paths())
+      dpg.set_value("load_status", f"Loaded {num_paths} data paths")
+      dpg.configure_item("load_button", enabled=True)
+    elif data.get('segment_added'):
+      segment_count = data.get('segment_count', 0)
+      dpg.set_value("load_status", f"Loading... {segment_count} segments processed")
+    dpg.configure_item("timeline_slider", max_value=duration)

   def setup_ui(self):
     with dpg.item_handler_registry(tag="tree_node_handler"):
@@ -179,11 +180,8 @@ class MainController:
       value_tag = f"value_{path}"
       if dpg.does_item_exist(value_tag) and dpg.is_item_visible(value_tag):
-        last_index = self.playback_manager.last_indices.get(path)
-        value, new_idx = self.data_manager.get_current_value(path, self.playback_manager.current_time_s, last_index)
+        value = self.data_manager.get_value_at(path, self.playback_manager.current_time_s)
         if value is not None:
-          self.playback_manager.update_index(path, new_idx)
           formatted_value = format_and_truncate(value, value_column_width, self.avg_char_width)
           dpg.set_value(value_tag, formatted_value)

tools/jotpluggler/views.py:

@@ -45,13 +45,13 @@ class TimeSeriesPanel(ViewPanel):
     self._ui_created = False
     self._preserved_series_data: list[tuple[str, tuple]] = []  # TODO: the way we do this right now doesn't make much sense
     self._series_legend_tags: dict[str, str] = {}  # Maps series_path to legend tag
-    self.data_manager.add_callback(self.on_data_loaded)
+    self.data_manager.add_observer(self.on_data_loaded)

   def preserve_data(self):
     self._preserved_series_data = []
     if self.plotted_series and self._ui_created:
       for series_path in self.plotted_series:
-        time_value_data = self.data_manager.get_time_series(series_path)
+        time_value_data = self.data_manager.get_timeseries(series_path)
         if time_value_data:
           self._preserved_series_data.append((series_path, time_value_data))
@@ -86,12 +86,9 @@
     if self.plotted_series:  # update legend labels with current values
       for series_path in self.plotted_series:
-        last_index = self.playback_manager.last_indices.get(series_path)
-        value, new_idx = self.data_manager.get_current_value(series_path, current_time_s, last_index)
+        value = self.data_manager.get_value_at(series_path, current_time_s)
         if value is not None:
-          self.playback_manager.update_index(series_path, new_idx)
           if isinstance(value, (int, float)):
             if isinstance(value, float):
               formatted_value = f"{value:.4f}" if abs(value) < 1000 else f"{value:.3e}"
@@ -100,7 +97,6 @@
           else:
             formatted_value = str(value)
-          # Update the series label to include current value
           series_tag = f"series_{self.panel_id}_{series_path.replace('/', '_')}"
           legend_label = f"{series_path}: {formatted_value}"
@@ -125,7 +121,6 @@
     if self.plot_tag and dpg.does_item_exist(self.plot_tag):
       dpg.delete_item(self.plot_tag)
-    # self.data_manager.remove_callback(self.on_data_loaded)
     self._series_legend_tags.clear()
     self._ui_created = False
@@ -136,7 +131,7 @@
     if series_path in self.plotted_series:
       return False
-    time_value_data = self.data_manager.get_time_series(series_path)
+    time_value_data = self.data_manager.get_timeseries(series_path)
     if time_value_data is None:
       return False
@@ -153,7 +148,7 @@
     if dpg.does_item_exist(series_tag):
       dpg.delete_item(series_tag)
     self.plotted_series.remove(series_path)
-    if series_path in self._series_legend_tags:  # Clean up legend tag mapping
+    if series_path in self._series_legend_tags:
       del self._series_legend_tags[series_path]

   def on_data_loaded(self, data: dict):
@@ -161,7 +156,7 @@
       self._update_series_data(series_path)

   def _update_series_data(self, series_path: str) -> bool:
-    time_value_data = self.data_manager.get_time_series(series_path)
+    time_value_data = self.data_manager.get_timeseries(series_path)
     if time_value_data is None:
       return False
@@ -196,7 +191,7 @@ class DataTreeView:
     self.current_search = ""
     self.data_tree = DataTreeNode(name="root")
     self.active_leaf_nodes: list[DataTreeNode] = []
-    self.data_manager.add_callback(self.on_data_loaded)
+    self.data_manager.add_observer(self.on_data_loaded)

   def on_data_loaded(self, data: dict):
     with self.ui_lock:
@@ -246,7 +241,7 @@
       for child in sorted_children:
         if child.is_leaf:
-          is_plottable = self.data_manager.is_path_plottable(child.full_path)
+          is_plottable = self.data_manager.is_plottable(child.full_path)
           # Create draggable item
           with dpg.group(parent=parent_tag) as draggable_group:
@@ -266,10 +261,6 @@
         node_tag = f"tree_{child.full_path}"
         label = child.name
-        if '/' not in child.full_path:
-          sample_count = len(self.data_manager.time_series_data.get(child.full_path, {}).get('t', []))
-          label = f"{child.name} ({sample_count} samples)"
         should_open = bool(search_term) and len(search_term) > 1 and any(search_term in path for path in self._get_all_descendant_paths(child))
         with dpg.tree_node(label=label, parent=parent_tag, tag=node_tag, default_open=should_open):
