diff --git a/selfdrive/test/process_replay/process_replay.py b/selfdrive/test/process_replay/process_replay.py index 2b9096b423..43b329e916 100755 --- a/selfdrive/test/process_replay/process_replay.py +++ b/selfdrive/test/process_replay/process_replay.py @@ -5,7 +5,8 @@ import sys import threading import time import signal -from collections import namedtuple +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Callable import capnp @@ -28,7 +29,20 @@ TIMEOUT = 15 PROC_REPLAY_DIR = os.path.dirname(os.path.abspath(__file__)) FAKEDATA = os.path.join(PROC_REPLAY_DIR, "fakedata/") -ProcessConfig = namedtuple('ProcessConfig', ['proc_name', 'pub_sub', 'ignore', 'init_callback', 'should_recv_callback', 'tolerance', 'fake_pubsubmaster', 'submaster_config', 'environ', 'subtest_name', "field_tolerances"], defaults=({}, {}, "", {})) +@dataclass +class ProcessConfig: + proc_name: str + pub_sub: Dict[str, List[str]] + ignore: List[str] + init_callback: Optional[Callable] + should_recv_callback: Optional[Callable] + tolerance: Optional[float] + fake_pubsubmaster: bool + submaster_config: Dict[str, List[str]] = field(default_factory=dict) + environ: Dict[str, str] = field(default_factory=dict) + subtest_name: str = "" + field_tolerances: Dict[str, float] = field(default_factory=dict) + timeout: int = 30 def wait_for_event(evt): @@ -365,6 +379,7 @@ CONFIGS = [ should_recv_callback=None, tolerance=NUMPY_TOLERANCE, fake_pubsubmaster=False, + timeout=60*10, # first messages are blocked on internet assistance ), ProcessConfig( proc_name="torqued", @@ -532,8 +547,9 @@ def python_replay_process(cfg, lr, fingerprint=None): def replay_process_with_sockets(cfg, lr, fingerprint=None): - sub_sockets = [s for _, sub in cfg.pub_sub.items() for s in sub] pm = messaging.PubMaster(cfg.pub_sub.keys()) + sub_sockets = [s for _, sub in cfg.pub_sub.items() for s in sub] + sockets = {s: messaging.sub_sock(s, timeout=100) for s in sub_sockets} all_msgs = sorted(lr, key=lambda msg: msg.logMonoTime) pub_msgs = [msg for msg in all_msgs if msg.which() in list(cfg.pub_sub.keys())] @@ -556,22 +572,26 @@ def replay_process_with_sockets(cfg, lr, fingerprint=None): while not any(pm.all_readers_updated(s) for s in cfg.pub_sub.keys()): time.sleep(0) - # Make sure all subscribers are connected - sockets = {s: messaging.sub_sock(s, timeout=2000) for s in sub_sockets} - for s in sub_sockets: - messaging.recv_one_or_none(sockets[s]) + for s in sockets.values(): + messaging.recv_one_or_none(s) # Do the replay cnt = 0 for msg in pub_msgs: - with Timeout(TIMEOUT, error_msg=f"timed out testing process {repr(cfg.proc_name)}, {cnt}/{len(pub_msgs)} msgs done"): + with Timeout(cfg.timeout, error_msg=f"timed out testing process {repr(cfg.proc_name)}, {cnt}/{len(pub_msgs)} msgs done"): + resp_sockets = cfg.pub_sub[msg.which()] + if cfg.should_recv_callback is not None: + resp_sockets, _ = cfg.should_recv_callback(msg, None, None, None) + + # Make sure all subscribers are connected + if len(log_msgs) == 0 and len(resp_sockets) > 0: + for s in sockets.values(): + messaging.recv_one_or_none(s) + pm.send(msg.which(), msg.as_builder()) while not pm.all_readers_updated(msg.which()): time.sleep(0) - resp_sockets = cfg.pub_sub[msg.which()] - if cfg.should_recv_callback is not None: - resp_sockets, _ = cfg.should_recv_callback(msg, None, None, None) for s in resp_sockets: m = messaging.recv_one_retry(sockets[s]) m = m.as_builder()