jotpluggler: better handle sparse message data and bools (#36124)

* better handle sparse message data

* fix plotting of bools

* add type for msg._valid

* fix typing

* add assert in case something changes in future
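
For context, a minimal sketch of the sparse representation this change introduces (toy segment, made-up field names and numbers, not the real schema): each field in a processed segment becomes a dict with `values`, a `sparse` flag, and, for sparse fields, a `t_index` array of uint16 offsets into the segment's shared `t` timestamps; `_get_field_times_values` then rebuilds aligned (times, values) pairs.

```python
import numpy as np

# toy segment with one dense and one sparse field (names and numbers are made up)
segment = {
  't': np.array([0.00, 0.01, 0.02, 0.03]),
  'speed': {'values': np.array([1.0, 1.1, 1.2, 1.3]), 'sparse': False},
  'leadDistance': {'values': np.array([0.5, 0.6]), 'sparse': True,
                   't_index': np.array([1, 3], dtype=np.uint16)},
}

def get_field_times_values(segment, field_name):
  # simplified version of data.py's _get_field_times_values:
  # sparse fields index into the shared 't' array, dense fields use it directly
  field_data = segment[field_name]
  if field_data['sparse']:
    return segment['t'][field_data['t_index']], field_data['values']
  return segment['t'], field_data['values']

print(get_field_times_values(segment, 'leadDistance'))  # (array([0.01, 0.03]), array([0.5, 0.6]))
```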
commit 6b13175338 (pull/36130/head)
parent d0171084b5
Author: Jimmy, committed 1 month ago via GitHub
tools/jotpluggler/data.py  | 99
tools/jotpluggler/views.py | 18

tools/jotpluggler/data.py:

@@ -3,7 +3,7 @@ import threading
 import multiprocessing
 import bisect
 from collections import defaultdict
-import tqdm
+from tqdm import tqdm
 from openpilot.common.swaglog import cloudlog
 from openpilot.tools.lib.logreader import _LogFileReader, LogReader
@@ -70,9 +70,6 @@ def extract_field_types(schema, prefix, field_types_dict):
 def _convert_to_optimal_dtype(values_list, capnp_type):
-  if not values_list:
-    return np.array([])
   dtype_mapping = {
     'bool': np.bool_, 'int8': np.int8, 'int16': np.int16, 'int32': np.int32, 'int64': np.int64,
     'uint8': np.uint8, 'uint16': np.uint16, 'uint32': np.uint32, 'uint64': np.uint64,
@@ -80,8 +77,8 @@ def _convert_to_optimal_dtype(values_list, capnp_type):
     'enum': object, 'anyPointer': object,
   }
-  target_dtype = dtype_mapping.get(capnp_type)
-  return np.array(values_list, dtype=target_dtype) if target_dtype else np.array(values_list)
+  target_dtype = dtype_mapping.get(capnp_type, object)
+  return np.array(values_list, dtype=target_dtype)
 
 
 def _match_field_type(field_path, field_types):
@@ -94,6 +91,21 @@ def _match_field_type(field_path, field_types):
   return field_types.get(template_path)
+
+
+def _get_field_times_values(segment, field_name):
+  if field_name not in segment:
+    return None, None
+  field_data = segment[field_name]
+  segment_times = segment['t']
+  if field_data['sparse']:
+    if len(field_data['t_index']) == 0:
+      return None, None
+    return segment_times[field_data['t_index']], field_data['values']
+  else:
+    return segment_times, field_data['values']
 
 
 def msgs_to_time_series(msgs):
   """Extract scalar fields and return (time_series_data, start_time, end_time)."""
   collected_data = defaultdict(lambda: {'timestamps': [], 'columns': defaultdict(list), 'sparse_fields': set()})
@@ -110,16 +122,22 @@ def msgs_to_time_series(msgs):
     max_time = timestamp
     sub_msg = getattr(msg, typ)
-    if not hasattr(sub_msg, 'to_dict') or typ in ('qcomGnss', 'ubloxGnss'):
+    if not hasattr(sub_msg, 'to_dict'):
       continue
     if hasattr(sub_msg, 'schema') and typ not in extracted_schemas:
       extract_field_types(sub_msg.schema, typ, field_types)
       extracted_schemas.add(typ)
+    try:
+      msg_dict = sub_msg.to_dict(verbose=True)
+    except Exception as e:
+      cloudlog.warning(f"Failed to convert sub_msg.to_dict() for message of type: {typ}: {e}")
+      continue
     flat_dict = flatten_dict(msg_dict)
     flat_dict['_valid'] = msg.valid
+    field_types[f"{typ}/_valid"] = 'bool'
     type_data = collected_data[typ]
     columns, sparse_fields = type_data['columns'], type_data['sparse_fields']
@@ -152,11 +170,26 @@ def msgs_to_time_series(msgs):
         values = [None] * (len(data['timestamps']) - len(values)) + values
         sparse_fields.add(field_name)
-      if field_name in sparse_fields:
-        typ_result[field_name] = np.array(values, dtype=object)
-      else:
-        capnp_type = _match_field_type(f"{typ}/{field_name}", field_types)
-        typ_result[field_name] = _convert_to_optimal_dtype(values, capnp_type)
+      capnp_type = _match_field_type(f"{typ}/{field_name}", field_types)
+      if field_name in sparse_fields:  # extract non-None values and their indices
+        non_none_indices = []
+        non_none_values = []
+        for i, value in enumerate(values):
+          if value is not None:
+            non_none_indices.append(i)
+            non_none_values.append(value)
+        if non_none_values:  # check if indices > uint16 max, currently would require a 1000+ Hz signal since indices are within segments
+          assert max(non_none_indices) <= 65535, f"Sparse field {typ}/{field_name} has timestamp indices exceeding uint16 max. Max: {max(non_none_indices)}"
+        typ_result[field_name] = {
+          'values': _convert_to_optimal_dtype(non_none_values, capnp_type),
+          'sparse': True,
+          't_index': np.array(non_none_indices, dtype=np.uint16),
+        }
+      else:  # dense representation
+        typ_result[field_name] = {'values': _convert_to_optimal_dtype(values, capnp_type), 'sparse': False}
 
     final_result[typ] = typ_result
@@ -195,18 +228,27 @@ class DataManager:
       times, values = [], []
       for segment in self._segments:
-        if msg_type in segment and field in segment[msg_type]:
-          times.append(segment[msg_type]['t'])
-          values.append(segment[msg_type][field])
+        if msg_type in segment:
+          field_times, field_values = _get_field_times_values(segment[msg_type], field)
+          if field_times is not None:
+            times.append(field_times)
+            values.append(field_values)
       if not times:
-        return [], []
+        return np.array([]), np.array([])
       combined_times = np.concatenate(times) - self._start_time
-      if len(values) > 1 and any(arr.dtype != values[0].dtype for arr in values):
-        values = [arr.astype(object) for arr in values]
-      return combined_times, np.concatenate(values)
+      if len(values) > 1:
+        first_dtype = values[0].dtype
+        if all(arr.dtype == first_dtype for arr in values):  # check if all arrays have compatible dtypes
+          combined_values = np.concatenate(values)
+        else:
+          combined_values = np.concatenate([arr.astype(object) for arr in values])
+      else:
+        combined_values = values[0] if values else np.array([])
+      return combined_times, combined_values
 
   def get_value_at(self, path: str, time: float):
     with self._lock:
@@ -218,14 +260,14 @@ class DataManager:
         if not 0 <= index < len(self._segments):
           continue
         segment = self._segments[index].get(message_type)
-        if not segment or field not in segment:
+        if not segment:
          continue
-        times = segment['t']
-        if len(times) == 0 or (index != current_index and absolute_time - times[-1] > MAX_LOOKBACK):
+        times, values = _get_field_times_values(segment, field)
+        if times is None or len(times) == 0 or (index != current_index and absolute_time - times[-1] > MAX_LOOKBACK):
           continue
         position = np.searchsorted(times, absolute_time, 'right') - 1
         if position >= 0 and absolute_time - times[position] <= MAX_LOOKBACK:
-          return segment[field][position]
+          return values[position]
       return None
 
   def get_all_paths(self):
@@ -237,10 +279,9 @@ class DataManager:
     return self._duration
 
   def is_plottable(self, path: str):
-    data = self.get_timeseries(path)
-    if data is None:
+    _, values = self.get_timeseries(path)
+    if len(values) == 0:
       return False
-    _, values = data
     return np.issubdtype(values.dtype, np.number) or np.issubdtype(values.dtype, np.bool_)
 
   def add_observer(self, callback):
@@ -272,7 +313,7 @@ class DataManager:
       return
     num_processes = max(1, multiprocessing.cpu_count() // 2)
-    with multiprocessing.Pool(processes=num_processes) as pool, tqdm.tqdm(total=len(lr.logreader_identifiers), desc="Processing Segments") as pbar:
+    with multiprocessing.Pool(processes=num_processes) as pool, tqdm(total=len(lr.logreader_identifiers), desc="Processing Segments") as pbar:
       for segment_result, start_time, end_time in pool.imap(_process_segment, lr.logreader_identifiers):
         pbar.update(1)
         if segment_result:
@@ -292,9 +333,9 @@ class DataManager:
       self._duration = end_time - self._start_time
       for msg_type, data in segment_data.items():
-        for field in data.keys():
-          if field != 't':
-            self._paths.add(f"{msg_type}/{field}")
+        for field_name in data.keys():
+          if field_name != 't':
+            self._paths.add(f"{msg_type}/{field_name}")
       observers = self._observers.copy()
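
A note on the new uint16 `t_index` arrays and the assert in the hunk above: indices are local to a segment, and openpilot segments are roughly a minute long, so even a 1000 Hz signal stays near 60000 samples, just under the uint16 ceiling of 65535. A rough, illustrative comparison of the old None-padded object storage versus the new sparse layout (the sample counts and rate below are made up for the example):

```python
import numpy as np

# hypothetical sparse field observed 50 times in a 6000-sample segment (60 s at 100 Hz assumed)
dense_with_nones = np.array([None] * 5950 + [1.0] * 50, dtype=object)  # old: None-padded object array
values = np.ones(50)                                                   # new: only the observed samples
t_index = np.arange(5950, 6000, dtype=np.uint16)                       # new: 2-byte offsets into segment['t']

print(dense_with_nones.nbytes)         # 48000 bytes of pointers alone (boxed floats not counted)
print(values.nbytes + t_index.nbytes)  # 500 bytes
print(np.iinfo(np.uint16).max // 60)   # ~1092 Hz before per-segment indices would overflow uint16
```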

tools/jotpluggler/views.py:

@@ -46,10 +46,10 @@ class TimeSeriesPanel(ViewPanel):
     self.y_axis_tag = f"{self.plot_tag}_y_axis"
     self.timeline_indicator_tag = f"{self.plot_tag}_timeline"
     self._ui_created = False
-    self._series_data: dict[str, tuple[list, list]] = {}
+    self._series_data: dict[str, tuple[np.ndarray, np.ndarray]] = {}
     self._last_plot_duration = 0
     self._update_lock = threading.RLock()
-    self.results_deque: deque[tuple[str, list, list]] = deque()
+    self._results_deque: deque[tuple[str, list, list]] = deque()
     self._new_data = False
 
   def create_ui(self, parent_tag: str):
@@ -75,12 +75,12 @@ class TimeSeriesPanel(ViewPanel):
       for series_path in list(self._series_data.keys()):
         self.add_series(series_path, update=True)
-      while self.results_deque:  # handle downsampled results in main thread
-        results = self.results_deque.popleft()
+      while self._results_deque:  # handle downsampled results in main thread
+        results = self._results_deque.popleft()
         for series_path, downsampled_time, downsampled_values in results:
           series_tag = f"series_{self.panel_id}_{series_path}"
           if dpg.does_item_exist(series_tag):
-            dpg.set_value(series_tag, [downsampled_time, downsampled_values])
+            dpg.set_value(series_tag, (downsampled_time, downsampled_values.astype(float)))
 
       # update timeline
       current_time_s = self.playback_manager.current_time_s
@@ -118,11 +118,11 @@ class TimeSeriesPanel(ViewPanel):
           target_points = max(int(target_points_per_second * series_duration), plot_width)
           work_items.append((series_path, time_array, value_array, target_points))
         elif dpg.does_item_exist(f"series_{self.panel_id}_{series_path}"):
-          dpg.set_value(f"series_{self.panel_id}_{series_path}", [time_array, value_array])
+          dpg.set_value(f"series_{self.panel_id}_{series_path}", (time_array, value_array.astype(float)))
 
     if work_items:
       self.worker_manager.submit_task(
-        TimeSeriesPanel._downsample_worker, work_items, callback=lambda results: self.results_deque.append(results), task_id=f"downsample_{self.panel_id}"
+        TimeSeriesPanel._downsample_worker, work_items, callback=lambda results: self._results_deque.append(results), task_id=f"downsample_{self.panel_id}"
       )
 
   def add_series(self, series_path: str, update: bool = False):
@@ -133,9 +133,9 @@ class TimeSeriesPanel(ViewPanel):
       time_array, value_array = self._series_data[series_path]
       series_tag = f"series_{self.panel_id}_{series_path}"
       if dpg.does_item_exist(series_tag):
-        dpg.set_value(series_tag, [time_array, value_array])
+        dpg.set_value(series_tag, (time_array, value_array.astype(float)))
       else:
-        line_series_tag = dpg.add_line_series(x=time_array, y=value_array, label=series_path, parent=self.y_axis_tag, tag=series_tag)
+        line_series_tag = dpg.add_line_series(x=time_array, y=value_array.astype(float), label=series_path, parent=self.y_axis_tag, tag=series_tag)
         dpg.bind_item_theme(line_series_tag, "global_line_theme")
         dpg.fit_axis_data(self.x_axis_tag)
         dpg.fit_axis_data(self.y_axis_tag)
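
The views.py side pairs with the new `_valid` bool field: `is_plottable` now accepts bool dtypes, and series values are cast with `.astype(float)` before being handed to dearpygui, so a bool signal renders as a 0/1 line. A trivial illustration of both checks (plain numpy, nothing tool-specific assumed):

```python
import numpy as np

valid = np.array([True, True, False, True])   # how a bool column like _valid is stored
print(np.issubdtype(valid.dtype, np.number)
      or np.issubdtype(valid.dtype, np.bool_))  # True -> is_plottable accepts it
print(valid.astype(float))                      # [1. 1. 0. 1.] -> what the line series receives
```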
