import bisect
import multiprocessing
import threading
from collections import defaultdict

import numpy as np
import tqdm

from openpilot.common.swaglog import cloudlog
from openpilot.tools.lib.logreader import _LogFileReader, LogReader


def flatten_dict(d: dict, sep: str = "/", prefix: str | None = None) -> dict:
  """Flatten nested dicts and lists into a single dict keyed by sep-joined paths."""
  result = {}
  stack: list[tuple] = [(d, prefix)]

  while stack:
    obj, current_prefix = stack.pop()

    if isinstance(obj, dict):
      for key, val in obj.items():
        new_prefix = key if current_prefix is None else f"{current_prefix}{sep}{key}"
        if isinstance(val, (dict, list)):
          stack.append((val, new_prefix))
        else:
          result[new_prefix] = val
    elif isinstance(obj, list):
      for i, item in enumerate(obj):
        new_prefix = f"{current_prefix}{sep}{i}"
        if isinstance(item, (dict, list)):
          stack.append((item, new_prefix))
        else:
          result[new_prefix] = item
    else:
      if current_prefix is not None:
        result[current_prefix] = obj
  return result


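# Illustrative example (keys are made up): nested dicts and lists flatten to "/"-joined paths,
# with list indices as path components:
#   flatten_dict({"pose": {"x": 1.0, "vel": [3.0, 4.0]}})
#   == {"pose/x": 1.0, "pose/vel/0": 3.0, "pose/vel/1": 4.0}

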
def extract_field_types(schema, prefix, field_types_dict):
  """Walk a capnp schema iteratively and record each field path's capnp type ('*' marks list elements)."""
  stack = [(schema, prefix)]

  while stack:
    current_schema, current_prefix = stack.pop()

    for field in current_schema.fields_list:
      field_name = field.proto.name
      field_path = f"{current_prefix}/{field_name}"
      field_proto = field.proto
      field_which = field_proto.which()

      field_type = field_proto.slot.type.which() if field_which == 'slot' else field_which
      field_types_dict[field_path] = field_type

      if field_which == 'slot':
        slot_type = field_proto.slot.type
        type_which = slot_type.which()

        if type_which == 'list':
          element_type = slot_type.list.elementType.which()
          list_path = f"{field_path}/*"
          field_types_dict[list_path] = element_type

          if element_type == 'struct':
            stack.append((field.schema.elementType, list_path))

        elif type_which == 'struct':
          stack.append((field.schema, field_path))

      elif field_which == 'group':
        stack.append((field.schema, field_path))


def _convert_to_optimal_dtype(values_list, capnp_type):
  if not values_list:
    return np.array([])

  dtype_mapping = {
    'bool': np.bool_, 'int8': np.int8, 'int16': np.int16, 'int32': np.int32, 'int64': np.int64,
    'uint8': np.uint8, 'uint16': np.uint16, 'uint32': np.uint32, 'uint64': np.uint64,
    'float32': np.float32, 'float64': np.float64, 'text': object, 'data': object,
    'enum': object, 'anyPointer': object,
  }

  target_dtype = dtype_mapping.get(capnp_type)
  return np.array(values_list, dtype=target_dtype) if target_dtype else np.array(values_list)


def _match_field_type(field_path, field_types):
  if field_path in field_types:
    return field_types[field_path]

  path_parts = field_path.split('/')
  template_parts = [p if not p.isdigit() else '*' for p in path_parts]
  template_path = '/'.join(template_parts)
  return field_types.get(template_path)


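# Illustrative lookup (made-up path): "someMsg/points/3/dRel" has no exact entry, so the numeric
# index is replaced with '*' and the generic template "someMsg/points/*/dRel" recorded by
# extract_field_types is used instead.

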
def msgs_to_time_series(msgs):
  """Extract scalar fields and return (time_series_data, start_time, end_time)."""
  collected_data = defaultdict(lambda: {'timestamps': [], 'columns': defaultdict(list), 'sparse_fields': set()})
  field_types = {}
  extracted_schemas = set()
  min_time = max_time = None

  for msg in msgs:
    typ = msg.which()
    timestamp = msg.logMonoTime * 1e-9
    if typ != 'initData':
      if min_time is None:
        min_time = timestamp
      max_time = timestamp

    sub_msg = getattr(msg, typ)
    if not hasattr(sub_msg, 'to_dict') or typ in ('qcomGnss', 'ubloxGnss'):
      continue

    if hasattr(sub_msg, 'schema') and typ not in extracted_schemas:
      extract_field_types(sub_msg.schema, typ, field_types)
      extracted_schemas.add(typ)

    msg_dict = sub_msg.to_dict(verbose=True)
    flat_dict = flatten_dict(msg_dict)
    flat_dict['_valid'] = msg.valid

    type_data = collected_data[typ]
    columns, sparse_fields = type_data['columns'], type_data['sparse_fields']
    known_fields = set(columns.keys())
    missing_fields = known_fields - flat_dict.keys()

    for field, value in flat_dict.items():
      if field not in known_fields and type_data['timestamps']:
        sparse_fields.add(field)
      columns[field].append(value)
      if value is None:
        sparse_fields.add(field)

    for field in missing_fields:
      columns[field].append(None)
      sparse_fields.add(field)

    type_data['timestamps'].append(timestamp)

  final_result = {}
  for typ, data in collected_data.items():
    if not data['timestamps']:
      continue

    typ_result = {'t': np.array(data['timestamps'], dtype=np.float64)}
    sparse_fields = data['sparse_fields']

    for field_name, values in data['columns'].items():
      if len(values) < len(data['timestamps']):
        values = [None] * (len(data['timestamps']) - len(values)) + values
        sparse_fields.add(field_name)

      if field_name in sparse_fields:
        typ_result[field_name] = np.array(values, dtype=object)
      else:
        capnp_type = _match_field_type(f"{typ}/{field_name}", field_types)
        typ_result[field_name] = _convert_to_optimal_dtype(values, capnp_type)

    final_result[typ] = typ_result

  return final_result, min_time or 0.0, max_time or 0.0


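# Illustrative return shape (message/field names are just examples):
#   ({'carState': {'t': np.ndarray, 'vEgo': np.ndarray, '_valid': np.ndarray}, ...},
#    start_time, end_time)

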
def _process_segment(segment_identifier: str):
  try:
    lr = _LogFileReader(segment_identifier, sort_by_time=True)
    return msgs_to_time_series(lr)
  except Exception as e:
    cloudlog.warning(f"Warning: Failed to process segment {segment_identifier}: {e}")
    return {}, 0.0, 0.0


class DataManager:
  def __init__(self):
    self._segments = []
    self._segment_starts = []
    self._start_time = 0.0
    self._duration = 0.0
    self._paths = set()
    self._observers = []
    self.loading = False
    self._lock = threading.RLock()

  def load_route(self, route: str) -> None:
    if self.loading:
      return
    self._reset()
    threading.Thread(target=self._load_async, args=(route,), daemon=True).start()

  def get_timeseries(self, path: str):
    with self._lock:
      msg_type, field = path.split('/', 1)
      times, values = [], []

      for segment in self._segments:
        if msg_type in segment and field in segment[msg_type]:
          times.append(segment[msg_type]['t'])
          values.append(segment[msg_type][field])

      if not times:
        return [], []

      combined_times = np.concatenate(times) - self._start_time
      if len(values) > 1 and any(arr.dtype != values[0].dtype for arr in values):
        values = [arr.astype(object) for arr in values]

      return combined_times, np.concatenate(values)

  def get_value_at(self, path: str, time: float):
    with self._lock:
      absolute_time = self._start_time + time
      message_type, field = path.split('/', 1)
      current_index = bisect.bisect_right(self._segment_starts, absolute_time) - 1
      for index in (current_index, current_index - 1):
        if not 0 <= index < len(self._segments):
          continue
        segment = self._segments[index].get(message_type)
        if not segment or field not in segment:
          continue
        times = segment['t']
        if len(times) == 0 or (index != current_index and absolute_time - times[-1] > 1):
          continue
        position = np.searchsorted(times, absolute_time, 'right') - 1
        if position >= 0 and absolute_time - times[position] <= 1:
          return segment[field][position]
      return None

  def get_all_paths(self):
    with self._lock:
      return sorted(self._paths)

  def get_duration(self):
    with self._lock:
      return self._duration

  def is_plottable(self, path: str):
    # get_timeseries returns empty lists when the path has no data, so guard on length
    # rather than None before touching .dtype.
    _, values = self.get_timeseries(path)
    if len(values) == 0:
      return False
    return np.issubdtype(values.dtype, np.number) or np.issubdtype(values.dtype, np.bool_)

  def add_observer(self, callback):
    with self._lock:
      self._observers.append(callback)

  def remove_observer(self, callback):
    with self._lock:
      if callback in self._observers:
        self._observers.remove(callback)

  def _reset(self):
    with self._lock:
      self.loading = True
      self._segments.clear()
      self._segment_starts.clear()
      self._paths.clear()
      self._start_time = self._duration = 0.0

  def _load_async(self, route: str):
    try:
      lr = LogReader(route, sort_by_time=True)
      if not lr.logreader_identifiers:
        cloudlog.warning(f"Warning: No log segments found for route: {route}")
        return

      num_processes = max(1, multiprocessing.cpu_count() // 2)
      with multiprocessing.Pool(processes=num_processes) as pool, tqdm.tqdm(total=len(lr.logreader_identifiers), desc="Processing Segments") as pbar:
        for segment_result, start_time, end_time in pool.imap(_process_segment, lr.logreader_identifiers):
          pbar.update(1)
          if segment_result:
            self._add_segment(segment_result, start_time, end_time)
    except Exception:
      cloudlog.exception(f"Error loading route {route}:")
    finally:
      self._finalize_loading()

  def _add_segment(self, segment_data: dict, start_time: float, end_time: float):
    with self._lock:
      self._segments.append(segment_data)
      self._segment_starts.append(start_time)

      if len(self._segments) == 1:
        self._start_time = start_time
      self._duration = end_time - self._start_time

      for msg_type, data in segment_data.items():
        for field in data.keys():
          if field != 't':
            self._paths.add(f"{msg_type}/{field}")

      observers = self._observers.copy()

    for callback in observers:
      callback({'segment_added': True, 'duration': self._duration, 'segment_count': len(self._segments)})

  def _finalize_loading(self):
    with self._lock:
      self.loading = False
      observers = self._observers.copy()
      duration = self._duration

    for callback in observers:
      callback({'loading_complete': True, 'duration': duration})
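

# Minimal usage sketch (illustrative only, not part of the module's public entry points):
# load a route in the background, wait for loading to finish, then inspect the available
# plot paths. Expects an openpilot route identifier as the first CLI argument.
if __name__ == "__main__":
  import sys
  import time

  dm = DataManager()
  dm.add_observer(lambda update: print(update))
  dm.load_route(sys.argv[1])
  while dm.loading:
    time.sleep(0.5)
  print(f"{len(dm.get_all_paths())} paths over {dm.get_duration():.1f}s")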