import numpy as np
import threading
import multiprocessing
import bisect
from collections import defaultdict
from tqdm import tqdm

from openpilot.common.swaglog import cloudlog
from openpilot.selfdrive.test.process_replay.migration import migrate_all
from openpilot.tools.lib.logreader import _LogFileReader, LogReader


def flatten_dict(d: dict, sep: str = "/", prefix: str | None = None) -> dict:
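  """Flatten nested dicts/lists into a single dict keyed by sep-joined paths (list indices become path components)."""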
  result = {}
  stack: list[tuple] = [(d, prefix)]

  while stack:
    obj, current_prefix = stack.pop()

    if isinstance(obj, dict):
      for key, val in obj.items():
        new_prefix = key if current_prefix is None else f"{current_prefix}{sep}{key}"
        if isinstance(val, (dict, list)):
          stack.append((val, new_prefix))
        else:
          result[new_prefix] = val
    elif isinstance(obj, list):
      for i, item in enumerate(obj):
        new_prefix = f"{current_prefix}{sep}{i}"
        if isinstance(item, (dict, list)):
          stack.append((item, new_prefix))
        else:
          result[new_prefix] = item
    else:
      if current_prefix is not None:
        result[current_prefix] = obj
  return result


def extract_field_types(schema, prefix, field_types_dict):
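  """Walk a capnp schema and record the capnp type of every field path in field_types_dict ('*' marks list elements)."""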
  stack = [(schema, prefix)]

  while stack:
    current_schema, current_prefix = stack.pop()

    for field in current_schema.fields_list:
      field_name = field.proto.name
      field_path = f"{current_prefix}/{field_name}"
      field_proto = field.proto
      field_which = field_proto.which()

      field_type = field_proto.slot.type.which() if field_which == 'slot' else field_which
      field_types_dict[field_path] = field_type

      if field_which == 'slot':
        slot_type = field_proto.slot.type
        type_which = slot_type.which()

        if type_which == 'list':
          element_type = slot_type.list.elementType.which()
          list_path = f"{field_path}/*"
          field_types_dict[list_path] = element_type

          if element_type == 'struct':
            stack.append((field.schema.elementType, list_path))

        elif type_which == 'struct':
          stack.append((field.schema, field_path))

      elif field_which == 'group':
        stack.append((field.schema, field_path))


def _convert_to_optimal_dtype(values_list, capnp_type):
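  """Build a numpy array from values_list using the dtype that matches its capnp type (object as fallback)."""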
  dtype_mapping = {
    'bool': np.bool_, 'int8': np.int8, 'int16': np.int16, 'int32': np.int32, 'int64': np.int64,
    'uint8': np.uint8, 'uint16': np.uint16, 'uint32': np.uint32, 'uint64': np.uint64,
    'float32': np.float32, 'float64': np.float64, 'text': object, 'data': object,
    'enum': object, 'anyPointer': object,
  }

  target_dtype = dtype_mapping.get(capnp_type, object)
  return np.array(values_list, dtype=target_dtype)


def _match_field_type(field_path, field_types):
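  """Look up a field's capnp type, falling back to a template path with numeric list indices replaced by '*'."""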
  if field_path in field_types:
    return field_types[field_path]

  path_parts = field_path.split('/')
  template_parts = [p if not p.isdigit() else '*' for p in path_parts]
  template_path = '/'.join(template_parts)
  return field_types.get(template_path)


def _get_field_times_values(segment, field_name):
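  """Return (timestamps, values) for a field in a segment, handling sparse and dense storage; (None, None) if absent."""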
  if field_name not in segment:
    return None, None

  field_data = segment[field_name]
  segment_times = segment['t']

  if field_data['sparse']:
    if len(field_data['t_index']) == 0:
      return None, None
    return segment_times[field_data['t_index']], field_data['values']
  else:
    return segment_times, field_data['values']


def msgs_to_time_series(msgs):
  """Extract scalar fields and return (time_series_data, start_time, end_time)."""
  collected_data = defaultdict(lambda: {'timestamps': [], 'columns': defaultdict(list), 'sparse_fields': set()})
  field_types = {}
  extracted_schemas = set()
  min_time = max_time = None

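  # First pass: flatten each message and append its values to per-type columns, tracking fields that are not always present.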
  for msg in msgs:
    typ = msg.which()
    timestamp = msg.logMonoTime * 1e-9
    if typ != 'initData':
      if min_time is None:
        min_time = timestamp
      max_time = timestamp

    sub_msg = getattr(msg, typ)
    if not hasattr(sub_msg, 'to_dict'):
      continue

    if hasattr(sub_msg, 'schema') and typ not in extracted_schemas:
      extract_field_types(sub_msg.schema, typ, field_types)
      extracted_schemas.add(typ)

    try:
      msg_dict = sub_msg.to_dict(verbose=True)
    except Exception as e:
      cloudlog.warning(f"Failed to convert sub_msg.to_dict() for message type {typ}: {e}")
      continue

    flat_dict = flatten_dict(msg_dict)
    flat_dict['_valid'] = msg.valid
    field_types[f"{typ}/_valid"] = 'bool'

    type_data = collected_data[typ]
    columns, sparse_fields = type_data['columns'], type_data['sparse_fields']
    known_fields = set(columns.keys())
    missing_fields = known_fields - flat_dict.keys()

    for field, value in flat_dict.items():
      if field not in known_fields and type_data['timestamps']:
        sparse_fields.add(field)
      columns[field].append(value)
      if value is None:
        sparse_fields.add(field)

    for field in missing_fields:
      columns[field].append(None)
      sparse_fields.add(field)

    type_data['timestamps'].append(timestamp)

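  # Second pass: convert each column to a numpy array; sparse fields keep only non-None values plus an index into 't'.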
  final_result = {}
  for typ, data in collected_data.items():
    if not data['timestamps']:
      continue

    typ_result = {'t': np.array(data['timestamps'], dtype=np.float64)}
    sparse_fields = data['sparse_fields']

    for field_name, values in data['columns'].items():
      if len(values) < len(data['timestamps']):
        values = [None] * (len(data['timestamps']) - len(values)) + values
        sparse_fields.add(field_name)

      capnp_type = _match_field_type(f"{typ}/{field_name}", field_types)

      if field_name in sparse_fields:  # extract non-None values and their indices
        non_none_indices = []
        non_none_values = []
        for i, value in enumerate(values):
          if value is not None:
            non_none_indices.append(i)
            non_none_values.append(value)

        if non_none_values:  # check if indices > uint16 max, currently would require a 1000+ Hz signal since indices are within segments
          assert max(non_none_indices) <= 65535, f"Sparse field {typ}/{field_name} has timestamp indices exceeding uint16 max. Max: {max(non_none_indices)}"

        typ_result[field_name] = {
          'values': _convert_to_optimal_dtype(non_none_values, capnp_type),
          'sparse': True,
          't_index': np.array(non_none_indices, dtype=np.uint16),
        }
      else:  # dense representation
        typ_result[field_name] = {'values': _convert_to_optimal_dtype(values, capnp_type), 'sparse': False}

    final_result[typ] = typ_result

  return final_result, min_time or 0.0, max_time or 0.0


def _process_segment(segment_identifier: str):
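  """Read, migrate and convert a single log segment to time series; returns empty data on failure."""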
  try:
    lr = _LogFileReader(segment_identifier, sort_by_time=True)
    migrated_msgs = migrate_all(lr)
    return msgs_to_time_series(migrated_msgs)
  except Exception as e:
    cloudlog.warning(f"Warning: Failed to process segment {segment_identifier}: {e}")
    return {}, 0.0, 0.0


class DataManager:
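  """Loads a route's log segments in the background, stores per-segment time series, and notifies observers of progress."""
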
  def __init__(self):
    self._segments = []
    self._segment_starts = []
    self._start_time = 0.0
    self._duration = 0.0
    self._paths = set()
    self._observers = []
    self._loading = False
    self._lock = threading.RLock()

  def load_route(self, route: str) -> None:
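    """Start loading a route in a background thread; ignored if a load is already in progress."""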
    if self._loading:
      return
    self._reset()
    threading.Thread(target=self._load_async, args=(route,), daemon=True).start()

  def get_timeseries(self, path: str):
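    """Return (times, values) for a 'msgType/field' path across all loaded segments, with times relative to the route start."""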
    with self._lock:
      msg_type, field = path.split('/', 1)
      times, values = [], []

      for segment in self._segments:
        if msg_type in segment:
          field_times, field_values = _get_field_times_values(segment[msg_type], field)
          if field_times is not None:
            times.append(field_times)
            values.append(field_values)

      if not times:
        return np.array([]), np.array([])

      combined_times = np.concatenate(times) - self._start_time

      if len(values) > 1:
        first_dtype = values[0].dtype
        if all(arr.dtype == first_dtype for arr in values):  # check if all arrays have compatible dtypes
          combined_values = np.concatenate(values)
        else:
          combined_values = np.concatenate([arr.astype(object) for arr in values])
      else:
        combined_values = values[0] if values else np.array([])

      return combined_times, combined_values

  def get_value_at(self, path: str, time: float):
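    """Return the most recent value of a 'msgType/field' path at a route-relative time, looking back at most MAX_LOOKBACK seconds."""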
    with self._lock:
      MAX_LOOKBACK = 5.0  # seconds
      absolute_time = self._start_time + time
      message_type, field = path.split('/', 1)
      current_index = bisect.bisect_right(self._segment_starts, absolute_time) - 1
      for index in (current_index, current_index - 1):
        if not 0 <= index < len(self._segments):
          continue
        segment = self._segments[index].get(message_type)
        if not segment:
          continue
        times, values = _get_field_times_values(segment, field)
        if times is None or len(times) == 0 or (index != current_index and absolute_time - times[-1] > MAX_LOOKBACK):
          continue
        position = np.searchsorted(times, absolute_time, 'right') - 1
        if position >= 0 and absolute_time - times[position] <= MAX_LOOKBACK:
          return values[position]
      return None

  def get_all_paths(self):
    with self._lock:
      return sorted(self._paths)

  def get_duration(self):
    with self._lock:
      return self._duration

  def is_plottable(self, path: str):
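    """Return True if the path holds numeric or boolean data that can be plotted."""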
    _, values = self.get_timeseries(path)
    if len(values) == 0:
      return False
    return np.issubdtype(values.dtype, np.number) or np.issubdtype(values.dtype, np.bool_)

  def add_observer(self, callback):
    with self._lock:
      self._observers.append(callback)

  def remove_observer(self, callback):
    with self._lock:
      if callback in self._observers:
        self._observers.remove(callback)

  def _reset(self):
    with self._lock:
      self._loading = True
      self._segments.clear()
      self._segment_starts.clear()
      self._paths.clear()
      self._start_time = self._duration = 0.0
      observers = self._observers.copy()

    for callback in observers:
      callback({'reset': True})

  def _load_async(self, route: str):
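    """Process all segments of a route in a worker pool, adding results and notifying observers as they arrive."""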
    try:
      lr = LogReader(route, sort_by_time=True)
      if not lr.logreader_identifiers:
        cloudlog.warning(f"Warning: No log segments found for route: {route}")
        return

      total_segments = len(lr.logreader_identifiers)
      with self._lock:
        observers = self._observers.copy()
      for callback in observers:
        callback({'metadata_loaded': True, 'total_segments': total_segments})

      num_processes = max(1, multiprocessing.cpu_count() // 2)
      with multiprocessing.Pool(processes=num_processes) as pool, tqdm(total=len(lr.logreader_identifiers), desc="Processing Segments") as pbar:
        for segment_result, start_time, end_time in pool.imap(_process_segment, lr.logreader_identifiers):
          pbar.update(1)
          if segment_result:
            self._add_segment(segment_result, start_time, end_time)
    except Exception:
      cloudlog.exception(f"Error loading route {route}:")
    finally:
      self._finalize_loading()

  def _add_segment(self, segment_data: dict, start_time: float, end_time: float):
    with self._lock:
      self._segments.append(segment_data)
      self._segment_starts.append(start_time)

      if len(self._segments) == 1:
        self._start_time = start_time
      self._duration = end_time - self._start_time

      for msg_type, data in segment_data.items():
        for field_name in data.keys():
          if field_name != 't':
            self._paths.add(f"{msg_type}/{field_name}")

      observers = self._observers.copy()

    for callback in observers:
      callback({'segment_added': True, 'duration': self._duration, 'segment_count': len(self._segments)})

  def _finalize_loading(self):
    with self._lock:
      self._loading = False
      observers = self._observers.copy()
      duration = self._duration

    for callback in observers:
      callback({'loading_complete': True, 'duration': duration})
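

if __name__ == "__main__":
  # Minimal usage sketch: load a route, wait for loading to complete, then read one time series.
  # The fallback route string and the "carState/vEgo" path are illustrative placeholders.
  import sys

  done = threading.Event()

  def _on_event(event: dict):
    if event.get('loading_complete'):
      done.set()

  dm = DataManager()
  dm.add_observer(_on_event)
  dm.load_route(sys.argv[1] if len(sys.argv) > 1 else "<route_name>")
  done.wait()

  print(f"duration: {dm.get_duration():.1f}s, {len(dm.get_all_paths())} paths")
  times, values = dm.get_timeseries("carState/vEgo")
  print(f"carState/vEgo: {len(times)} samples")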