diff --git a/selfdrive/test/process_replay/model_replay.py b/selfdrive/test/process_replay/model_replay.py index 0578a61588..201baf6b0e 100755 --- a/selfdrive/test/process_replay/model_replay.py +++ b/selfdrive/test/process_replay/model_replay.py @@ -15,15 +15,15 @@ from openpilot.system.hardware import PC from openpilot.tools.lib.openpilotci import get_url from openpilot.selfdrive.test.process_replay.compare_logs import compare_logs, format_diff from openpilot.selfdrive.test.process_replay.process_replay import get_process_config, replay_process -from openpilot.tools.lib.framereader import FrameReader, NumpyFrameReader +from openpilot.tools.lib.framereader import FrameReader from openpilot.tools.lib.logreader import LogReader, save_log from openpilot.tools.lib.github_utils import GithubUtils TEST_ROUTE = "8494c69d3c710e81|000001d4--2648a9a404" SEGMENT = 4 -MAX_FRAMES = 100 if PC else 400 +START_FRAME = 0 +END_FRAME = 60 -NO_MODEL = "NO_MODEL" in os.environ SEND_EXTRA_INPUTS = bool(int(os.getenv("SEND_EXTRA_INPUTS", "0"))) DATA_TOKEN = os.getenv("CI_ARTIFACTS_TOKEN","") @@ -125,16 +125,15 @@ def comment_replay_report(proposed, master, full_logs): comment = f"ref for commit {commit}: {link}/{log_name}" + diff_plots + all_plots GITHUB.comment_on_pr(comment, PR_BRANCH, "commaci-public", True) -def trim_logs_to_max_frames(logs, max_frames, frs_types, include_all_types): +def trim_logs(logs, start_frame, end_frame, frs_types, include_all_types): all_msgs = [] cam_state_counts = defaultdict(int) - # keep adding messages until cam states are equal to MAX_FRAMES for msg in sorted(logs, key=lambda m: m.logMonoTime): - all_msgs.append(msg) if msg.which() in frs_types: cam_state_counts[msg.which()] += 1 - - if all(cam_state_counts[state] == max_frames for state in frs_types): + if any(cam_state_counts[state] >= start_frame for state in frs_types): + all_msgs.append(msg) + if all(cam_state_counts[state] == end_frame for state in frs_types): break if len(include_all_types) != 0: @@ -146,9 +145,9 @@ def trim_logs_to_max_frames(logs, max_frames, frs_types, include_all_types): def model_replay(lr, frs): # modeld is using frame pairs - modeld_logs = trim_logs_to_max_frames(lr, MAX_FRAMES, {"roadCameraState", "wideRoadCameraState"}, - {"roadEncodeIdx", "wideRoadEncodeIdx", "carParams", "carState", "carControl"}) - dmodeld_logs = trim_logs_to_max_frames(lr, MAX_FRAMES, {"driverCameraState"}, {"driverEncodeIdx", "carParams"}) + modeld_logs = trim_logs(lr, START_FRAME, END_FRAME, {"roadCameraState", "wideRoadCameraState"}, + {"roadEncodeIdx", "wideRoadEncodeIdx", "carParams", "carState", "carControl", "can"}) + dmodeld_logs = trim_logs(lr, START_FRAME, END_FRAME, {"driverCameraState"}, {"driverEncodeIdx", "carParams", "can"}) if not SEND_EXTRA_INPUTS: modeld_logs = [msg for msg in modeld_logs if msg.which() != 'liveCalibration'] @@ -165,9 +164,6 @@ def model_replay(lr, frs): dmonitoringmodeld = get_process_config("dmonitoringmodeld") modeld_msgs = replay_process(modeld, modeld_logs, frs) - if isinstance(frs['roadCameraState'], NumpyFrameReader): - del frs['roadCameraState'].frames - del frs['wideRoadCameraState'].frames dmonitoringmodeld_msgs = replay_process(dmonitoringmodeld, dmodeld_logs, frs) msgs = modeld_msgs + dmonitoringmodeld_msgs @@ -198,42 +194,21 @@ def model_replay(lr, frs): return msgs -def get_frames(): - regen_cache = "--regen-cache" in sys.argv - cache = "--cache" in sys.argv or not PC or regen_cache - videos = ('fcamera.hevc', 'dcamera.hevc', 'ecamera.hevc') - cams = ('roadCameraState', 'driverCameraState', 'wideRoadCameraState') - - if cache: - frames_cache = '/tmp/model_replay_cache' if PC else '/data/model_replay_cache' - os.makedirs(frames_cache, exist_ok=True) - - cache_size = 200 - for v in videos: - if not all(os.path.isfile(f'{frames_cache}/{TEST_ROUTE}_{v}_{i}.npy') for i in range(MAX_FRAMES//cache_size)) or regen_cache: - f = FrameReader(get_url(TEST_ROUTE, SEGMENT, v)).get(0, MAX_FRAMES + 1, pix_fmt="nv12") - print(f'Caching {v}...') - for i in range(MAX_FRAMES//cache_size): - np.save(f'{frames_cache}/{TEST_ROUTE}_{v}_{i}', f[(i * cache_size) + 1:((i + 1) * cache_size) + 1]) - del f - - return {c : NumpyFrameReader(f"{frames_cache}/{TEST_ROUTE}_{v}", 1928, 1208, cache_size) for c,v in zip(cams, videos, strict=True)} - else: - return {c : FrameReader(get_url(TEST_ROUTE, SEGMENT, v), readahead=True) for c,v in zip(cams, videos, strict=True)} - - if __name__ == "__main__": update = "--update" in sys.argv or (os.getenv("GIT_BRANCH", "") == 'master') replay_dir = os.path.dirname(os.path.abspath(__file__)) # load logs lr = list(LogReader(get_url(TEST_ROUTE, SEGMENT, "rlog.zst"))) - frs = get_frames() + frs = { + 'roadCameraState': FrameReader(get_url(TEST_ROUTE, SEGMENT, "fcamera.hevc"), readahead=True), + 'driverCameraState': FrameReader(get_url(TEST_ROUTE, SEGMENT, "dcamera.hevc"), readahead=True), + 'wideRoadCameraState': FrameReader(get_url(TEST_ROUTE, SEGMENT, "ecamera.hevc"), readahead=True) + } log_msgs = [] # run replays - if not NO_MODEL: - log_msgs += model_replay(lr, frs) + log_msgs += model_replay(lr, frs) # get diff failed = False @@ -242,13 +217,10 @@ if __name__ == "__main__": try: all_logs = list(LogReader(GITHUB.get_file_url(MODEL_REPLAY_BUCKET, log_fn))) cmp_log = [] - - # logs are ordered based on type: modelV2, drivingModelData, driverStateV2 - if not NO_MODEL: - model_start_index = next(i for i, m in enumerate(all_logs) if m.which() in ("modelV2", "drivingModelData", "cameraOdometry")) - cmp_log += all_logs[model_start_index:model_start_index + MAX_FRAMES*3] - dmon_start_index = next(i for i, m in enumerate(all_logs) if m.which() == "driverStateV2") - cmp_log += all_logs[dmon_start_index:dmon_start_index + MAX_FRAMES] + model_start_index = next(i for i, m in enumerate(all_logs) if m.which() in ("modelV2", "drivingModelData", "cameraOdometry")) + cmp_log += all_logs[model_start_index+START_FRAME*3:model_start_index + END_FRAME*3] + dmon_start_index = next(i for i, m in enumerate(all_logs) if m.which() == "driverStateV2") + cmp_log += all_logs[dmon_start_index+START_FRAME:dmon_start_index + END_FRAME] ignore = [ 'logMonoTime', diff --git a/tools/lib/framereader.py b/tools/lib/framereader.py index 7c01992a28..275b9b65b8 100644 --- a/tools/lib/framereader.py +++ b/tools/lib/framereader.py @@ -535,25 +535,3 @@ def FrameIterator(fn, pix_fmt, **kwargs): else: for i in range(fr.frame_count): yield fr.get(i, pix_fmt=pix_fmt)[0] - - -class NumpyFrameReader: - def __init__(self, name, w, h, cache_size): - self.name = name - self.pos = -1 - self.frames = None - self.w = w - self.h = h - self.cache_size = cache_size - - def close(self): - pass - - def get(self, num, count=1, pix_fmt="nv12"): - num -= 1 - q = num // self.cache_size - if q != self.pos: - del self.frames - self.pos = q - self.frames = np.load(f'{self.name}_{self.pos}.npy') - return [self.frames[num % self.cache_size]]