parent 6048eb568a
commit 2eb697a730
3 changed files with 297 additions and 151 deletions
@@ -1,144 +1,301 @@
import threading
import multiprocessing
import bisect
from collections.abc import Callable
from collections import defaultdict
from typing import Any

import numpy as np
import tqdm

from openpilot.common.swaglog import cloudlog
from openpilot.tools.lib.logreader import _LogFileReader, LogReader


def flatten_dict(d: dict, sep: str = "/", prefix: str | None = None) -> dict:
  result = {}
  stack = [(d, prefix)]

  while stack:
    obj, current_prefix = stack.pop()

    if isinstance(obj, dict):
      for key, val in obj.items():
        new_prefix = key if current_prefix is None else f"{current_prefix}{sep}{key}"
        if isinstance(val, (dict, list)):
          stack.append((val, new_prefix))
        else:
          result[new_prefix] = val
    elif isinstance(obj, list):
      for i, item in enumerate(obj):
        new_prefix = f"{current_prefix}{sep}{i}"
        if isinstance(item, (dict, list)):
          stack.append((item, new_prefix))
        else:
          result[new_prefix] = item
    else:
      if current_prefix is not None:
        result[current_prefix] = obj
  return result
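
# Example (illustrative input): flatten_dict collapses nested dicts and lists into a
# single level of slash-joined paths, with list indices as path components:
#   flatten_dict({'pose': {'x': 1.0, 'accel': [0.1, 0.2]}})
#   -> {'pose/x': 1.0, 'pose/accel/0': 0.1, 'pose/accel/1': 0.2}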


def extract_field_types(schema, prefix, field_types_dict):
  stack = [(schema, prefix)]

  while stack:
    current_schema, current_prefix = stack.pop()

    for field in current_schema.fields_list:
      field_name = field.proto.name
      field_path = f"{current_prefix}/{field_name}"
      field_proto = field.proto
      field_which = field_proto.which()

      field_type = field_proto.slot.type.which() if field_which == 'slot' else field_which
      field_types_dict[field_path] = field_type

      if field_which == 'slot':
        slot_type = field_proto.slot.type
        type_which = slot_type.which()

        if type_which == 'list':
          element_type = slot_type.list.elementType.which()
          list_path = f"{field_path}/*"
          field_types_dict[list_path] = element_type

          if element_type == 'struct':
            stack.append((field.schema.elementType, list_path))

        elif type_which == 'struct':
          stack.append((field.schema, field_path))

      elif field_which == 'group':
        stack.append((field.schema, field_path))


def _convert_to_optimal_dtype(values_list, capnp_type):
  if not values_list:
    return np.array([])

  dtype_mapping = {
    'bool': np.bool_, 'int8': np.int8, 'int16': np.int16, 'int32': np.int32, 'int64': np.int64,
    'uint8': np.uint8, 'uint16': np.uint16, 'uint32': np.uint32, 'uint64': np.uint64,
    'float32': np.float32, 'float64': np.float64, 'text': object, 'data': object,
    'enum': object, 'anyPointer': object,
  }

  target_dtype = dtype_mapping.get(capnp_type)
  return np.array(values_list, dtype=target_dtype) if target_dtype else np.array(values_list)
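
# Example (illustrative): the Cap'n Proto type name selects the numpy dtype, and unknown
# types fall back to numpy's own inference:
#   _convert_to_optimal_dtype([1.0, 2.5], 'float32').dtype -> dtype('float32')
#   _convert_to_optimal_dtype([1, 2], None).dtype          -> numpy's default int dtype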


def _match_field_type(field_path, field_types):
  if field_path in field_types:
    return field_types[field_path]

  path_parts = field_path.split('/')
  template_parts = [p if not p.isdigit() else '*' for p in path_parts]
  template_path = '/'.join(template_parts)
  return field_types.get(template_path)
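
# Example (hypothetical schema paths): extract_field_types stores list element types
# under a '*' wildcard, and _match_field_type folds numeric indices back onto that template:
#   field_types = {'modelV2/velocity/x': 'list', 'modelV2/velocity/x/*': 'float32'}
#   _match_field_type('modelV2/velocity/x/3', field_types) -> 'float32'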


def msgs_to_time_series(msgs):
  """Extract scalar fields and return (time_series_data, start_time, end_time)."""
  collected_data = defaultdict(lambda: {'timestamps': [], 'columns': defaultdict(list), 'sparse_fields': set()})
  field_types = {}
  extracted_schemas = set()
  min_time = max_time = None

  for msg in msgs:
    typ = msg.which()
    timestamp = msg.logMonoTime * 1e-9
    if typ != 'initData':
      if min_time is None:
        min_time = timestamp
      max_time = timestamp

    sub_msg = getattr(msg, typ)
    if not hasattr(sub_msg, 'to_dict') or typ in ('qcomGnss', 'ubloxGnss'):
      continue

    if hasattr(sub_msg, 'schema') and typ not in extracted_schemas:
      extract_field_types(sub_msg.schema, typ, field_types)
      extracted_schemas.add(typ)

    msg_dict = sub_msg.to_dict(verbose=True)
    flat_dict = flatten_dict(msg_dict)
    flat_dict['_valid'] = msg.valid

    type_data = collected_data[typ]
    columns, sparse_fields = type_data['columns'], type_data['sparse_fields']
    known_fields = set(columns.keys())
    missing_fields = known_fields - flat_dict.keys()

    for field, value in flat_dict.items():
      if field not in known_fields and type_data['timestamps']:
        sparse_fields.add(field)
      columns[field].append(value)
      if value is None:
        sparse_fields.add(field)

    for field in missing_fields:
      columns[field].append(None)
      sparse_fields.add(field)

    type_data['timestamps'].append(timestamp)

  final_result = {}
  for typ, data in collected_data.items():
    if not data['timestamps']:
      continue

    typ_result = {'t': np.array(data['timestamps'], dtype=np.float64)}
    sparse_fields = data['sparse_fields']

    for field_name, values in data['columns'].items():
      if len(values) < len(data['timestamps']):
        values = [None] * (len(data['timestamps']) - len(values)) + values
        sparse_fields.add(field_name)

      if field_name in sparse_fields:
        typ_result[field_name] = np.array(values, dtype=object)
      else:
        capnp_type = _match_field_type(f"{typ}/{field_name}", field_types)
        typ_result[field_name] = _convert_to_optimal_dtype(values, capnp_type)

    final_result[typ] = typ_result

  return final_result, min_time or 0.0, max_time or 0.0
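
# Example (illustrative shape of the return value; the field names are assumptions):
#   data, start_time, end_time = msgs_to_time_series(lr)
#   data['carState']['t']       -> float64 array of logMonoTime, in seconds
#   data['carState']['vEgo']    -> float32 array aligned with 't'
#   data['carState']['_valid']  -> per-message valid flag for the same timestamps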


def _process_segment(segment_identifier: str) -> tuple[dict[str, Any], float, float]:
  try:
    lr = _LogFileReader(segment_identifier, sort_by_time=True)
    return msgs_to_time_series(lr)
  except Exception as e:
    cloudlog.warning(f"Warning: Failed to process segment {segment_identifier}: {e}")
    return {}, 0.0, 0.0


class DataManager:
  def __init__(self):
    self.time_series_data = {}
    self._segments = []
    self._segment_starts = []
    self._start_time = 0.0
    self._duration = 0.0
    self._paths = set()
    self._observers = []
    self.loading = False
    self.route_start_time_mono = 0.0
    self.duration = 0.0
    self._callbacks: list[Callable[[dict], None]] = []
    self._lock = threading.RLock()

  def add_callback(self, callback: Callable[[dict], None]):
    self._callbacks.append(callback)

  def remove_callback(self, callback: Callable[[dict], None]):
    if callback in self._callbacks:
      self._callbacks.remove(callback)

  def load_route(self, route: str) -> None:
    if self.loading:
      return
    self._reset()
    threading.Thread(target=self._load_async, args=(route,), daemon=True).start()

  def _notify_callbacks(self, data: dict):
    for callback in self._callbacks:
      try:
        callback(data)
      except Exception as e:
        cloudlog.exception(f"Error in data callback: {e}")

  def get_timeseries(self, path: str):
    with self._lock:
      msg_type, field = path.split('/', 1)
      times, values = [], []
      for segment in self._segments:
        if msg_type in segment and field in segment[msg_type]:
          times.append(segment[msg_type]['t'])
          values.append(segment[msg_type][field])
      if not times:
        return None
      combined_times = np.concatenate(times) - self._start_time
      if len(values) > 1 and any(arr.dtype != values[0].dtype for arr in values):
        values = [arr.astype(object) for arr in values]
      return combined_times, np.concatenate(values)

  def get_current_value(self, path: str, time_s: float, last_index: int | None = None):
    try:
      abs_time_s = self.route_start_time_mono + time_s
      msg_type, field_path = path.split('/', 1)
      ts_data = self.time_series_data[msg_type]
      t, v = ts_data['t'], ts_data[field_path]

      if len(t) == 0:
        return None, None

      if last_index is None:  # jump
        idx = np.searchsorted(t, abs_time_s, side='right') - 1
      else:  # continuous playback
        idx = last_index
        while idx < len(t) - 1 and t[idx + 1] < abs_time_s:
          idx += 1

      idx = max(0, idx)
      return v[idx], idx

    except (KeyError, IndexError):
      return None, None
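
  # Example (illustrative, assuming dm is a loaded DataManager): a seek does a fresh
  # binary search, while playback passes the previous index back in so the lookup only
  # scans forward from there:
  #   value, idx = dm.get_current_value('carState/vEgo', 12.3)                  # seek
  #   value, idx = dm.get_current_value('carState/vEgo', 12.4, last_index=idx)  # playback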

  def get_all_paths(self) -> list[str]:
    all_paths = []
    for msg_type, data in self.time_series_data.items():
      for key in data.keys():
        if key != 't':
          all_paths.append(f"{msg_type}/{key}")
    return all_paths

  def is_path_plottable(self, path: str) -> bool:
    try:
      msg_type, field_path = path.split('/', 1)
      value_array = self.time_series_data.get(msg_type, {}).get(field_path)
      if value_array is not None:
        return np.issubdtype(value_array.dtype, np.number) or np.issubdtype(value_array.dtype, np.bool_)
    except (ValueError, KeyError):
      pass
    return False

  def get_time_series(self, path: str):
    try:
      msg_type, field_path = path.split('/', 1)
      ts_data = self.time_series_data[msg_type]
      time_array = ts_data['t']
      values = ts_data[field_path]

      if len(time_array) == 0:
        return None

      rel_time = time_array - self.route_start_time_mono
      return rel_time, values

    except (KeyError, ValueError):
      return None

  def get_value_at(self, path: str, time: float):
    with self._lock:
      absolute_time = self._start_time + time
      message_type, field = path.split('/', 1)
      current_index = bisect.bisect_right(self._segment_starts, absolute_time) - 1
      for index in (current_index, current_index - 1):
        if not 0 <= index < len(self._segments):
          continue
        segment = self._segments[index].get(message_type)
        if not segment or field not in segment:
          continue
        times = segment['t']
        if len(times) == 0 or (index != current_index and absolute_time - times[-1] > 1):
          continue
        position = np.searchsorted(times, absolute_time, 'right') - 1
        if position >= 0 and absolute_time - times[position] <= 1:
          return segment[field][position]
      return None
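
  # Note (illustrative): samples more than one second older than the requested time are
  # treated as stale, so gaps between segments yield None rather than a frozen value:
  #   dm.get_value_at('carState/vEgo', t) -> latest sample within 1s of t, else None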

  def load_route(self, route_name: str):
    if self.loading:
      return
    self.loading = True
    threading.Thread(target=self._load_route_background, args=(route_name,), daemon=True).start()

  def get_all_paths(self):
    with self._lock:
      return sorted(self._paths)

  def get_duration(self):
    with self._lock:
      return self._duration

  def is_plottable(self, path: str):
    data = self.get_timeseries(path)
    if data is None:
      return False
    _, values = data
    return np.issubdtype(values.dtype, np.number) or np.issubdtype(values.dtype, np.bool_)

  def add_observer(self, callback):
    with self._lock:
      self._observers.append(callback)

  def _reset(self):
    with self._lock:
      self.loading = True
      self._segments.clear()
      self._segment_starts.clear()
      self._paths.clear()
      self._start_time = self._duration = 0.0

  def _load_route_background(self, route_name: str):
    try:
      lr = LogReader(route_name)
      raw_data = msgs_to_time_series(lr)
      processed_data = self._expand_list_fields(raw_data)

      min_time = float('inf')
      max_time = float('-inf')
      for data in processed_data.values():
        if len(data['t']) > 0:
          min_time = min(min_time, data['t'][0])
          max_time = max(max_time, data['t'][-1])

      self.time_series_data = processed_data
      self.route_start_time_mono = min_time if min_time != float('inf') else 0.0
      self.duration = max_time - min_time if max_time != float('-inf') else 0.0

      self._notify_callbacks({'time_series_data': processed_data, 'route_start_time_mono': self.route_start_time_mono, 'duration': self.duration})
    except Exception as e:
      cloudlog.exception(f"Error loading route {route_name}: {e}")
    finally:
      self.loading = False

  def _load_async(self, route: str):
    try:
      lr = LogReader(route, sort_by_time=True)
      if not lr.logreader_identifiers:
        cloudlog.warning(f"Warning: No log segments found for route: {route}")
        return

      with multiprocessing.Pool() as pool, tqdm.tqdm(total=len(lr.logreader_identifiers), desc="Processing Segments") as pbar:
        for segment_result, start_time, end_time in pool.imap(_process_segment, lr.logreader_identifiers):
          pbar.update(1)
          if segment_result:
            self._add_segment(segment_result, start_time, end_time)
    except Exception:
      cloudlog.exception(f"Error loading route {route}:")
    finally:
      self._finalize_loading()

  def _expand_list_fields(self, time_series_data):
    expanded_data = {}
    for msg_type, data in time_series_data.items():
      expanded_data[msg_type] = {}
      for field, values in data.items():
        if field == 't':
          expanded_data[msg_type]['t'] = values
          continue

        if values.dtype == object:  # ragged array
          lens = np.fromiter((len(v) for v in values), dtype=int, count=len(values))
          max_len = lens.max() if lens.size else 0
          if max_len > 0:
            arr = np.full((len(values), max_len), None, dtype=object)
            for i, v in enumerate(values):
              arr[i, : lens[i]] = v
            for i in range(max_len):
              sub_arr = arr[:, i]
              expanded_data[msg_type][f"{field}/{i}"] = sub_arr
        elif values.ndim > 1:  # regular multidimensional array
          for i in range(values.shape[1]):
            col_data = values[:, i]
            expanded_data[msg_type][f"{field}/{i}"] = col_data
        else:
          expanded_data[msg_type][field] = values
    return expanded_data
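
  # Example (illustrative): a field whose per-message value is a list, e.g. three
  # samples of [x, y] coordinates, is split into one column per index:
  #   {'pos': array([[1, 2], [3, 4], [5, 6]])} -> {'pos/0': array([1, 3, 5]), 'pos/1': array([2, 4, 6])}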

  def _add_segment(self, segment_data: dict, start_time: float, end_time: float):
    with self._lock:
      self._segments.append(segment_data)
      self._segment_starts.append(start_time)

      if len(self._segments) == 1:
        self._start_time = start_time
      self._duration = end_time - self._start_time

      for msg_type, data in segment_data.items():
        for field in data.keys():
          if field != 't':
            self._paths.add(f"{msg_type}/{field}")

      observers = self._observers.copy()

    for callback in observers:
      callback({'segment_added': True, 'duration': self._duration})

  def _finalize_loading(self):
    with self._lock:
      self.loading = False
      observers = self._observers.copy()
      duration = self._duration

    for callback in observers:
      callback({'loading_complete': True, 'duration': duration})
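
# Example usage (illustrative; route_name is a placeholder for a comma connect route id):
#   dm = DataManager()
#   dm.add_observer(lambda update: print(update))
#   dm.load_route(route_name)
#   # ...once segments have loaded:
#   t, v = dm.get_timeseries('carState/vEgo')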