From e964c5944d4342b6e08ca872eacf9d524a49b479 Mon Sep 17 00:00:00 2001 From: Shane Smiskol Date: Fri, 8 Mar 2024 02:49:24 -0800 Subject: [PATCH] LogReader: fix sort by time and union types (#31565) * fix :( * test_sort_by_time * this isn't required * not slow, and just compare sorted * messy * works * clean up * clean up * not here * clean up * clean up * clean up * makes network call --------- Co-authored-by: Justin Newberry --- tools/lib/logreader.py | 2 +- tools/lib/tests/test_logreader.py | 39 +++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/tools/lib/logreader.py b/tools/lib/logreader.py index 6247bbc9db..7a1e972e19 100755 --- a/tools/lib/logreader.py +++ b/tools/lib/logreader.py @@ -248,7 +248,7 @@ are uploaded or auto fallback to qlogs with '/a' selector at the end of the rout def _get_lr(self, i): if i not in self.__lrs: - self.__lrs[i] = _LogFileReader(self.logreader_identifiers[i]) + self.__lrs[i] = _LogFileReader(self.logreader_identifiers[i], sort_by_time=self.sort_by_time, only_union_types=self.only_union_types) return self.__lrs[i] def __iter__(self): diff --git a/tools/lib/tests/test_logreader.py b/tools/lib/tests/test_logreader.py index 2141915b87..fc72202b26 100755 --- a/tools/lib/tests/test_logreader.py +++ b/tools/lib/tests/test_logreader.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +import capnp import contextlib import io import shutil @@ -11,6 +12,7 @@ import requests from parameterized import parameterized from unittest import mock +from cereal import log as capnp_log from openpilot.tools.lib.logreader import LogIterable, LogReader, comma_api_source, parse_indirect, ReadMode, InternalUnavailableException from openpilot.tools.lib.route import SegmentRange from openpilot.tools.lib.url_file import URLFileException @@ -216,6 +218,43 @@ class TestLogReader(unittest.TestCase): log_len = len(list(lr)) self.assertEqual(qlog_len, log_len) + @pytest.mark.slow + def test_sort_by_time(self): + msgs = list(LogReader(f"{TEST_ROUTE}/0/q")) + self.assertNotEqual(msgs, sorted(msgs, key=lambda m: m.logMonoTime)) + + msgs = list(LogReader(f"{TEST_ROUTE}/0/q", sort_by_time=True)) + self.assertEqual(msgs, sorted(msgs, key=lambda m: m.logMonoTime)) + + def test_only_union_types(self): + with tempfile.NamedTemporaryFile() as qlog: + # write valid Event messages + num_msgs = 100 + with open(qlog.name, "wb") as f: + f.write(b"".join(capnp_log.Event.new_message().to_bytes() for _ in range(num_msgs))) + + msgs = list(LogReader(qlog.name)) + self.assertEqual(len(msgs), num_msgs) + [m.which() for m in msgs] + + # append non-union Event message + event_msg = capnp_log.Event.new_message() + non_union_bytes = bytearray(event_msg.to_bytes()) + non_union_bytes[event_msg.total_size.word_count * 8] = 0xff # set discriminant value out of range using Event word offset + with open(qlog.name, "ab") as f: + f.write(non_union_bytes) + + # ensure new message is added, but is not a union type + msgs = list(LogReader(qlog.name)) + self.assertEqual(len(msgs), num_msgs + 1) + with self.assertRaises(capnp.KjException): + [m.which() for m in msgs] + + # should not be added when only_union_types=True + msgs = list(LogReader(qlog.name, only_union_types=True)) + self.assertEqual(len(msgs), num_msgs) + [m.which() for m in msgs] + if __name__ == "__main__": unittest.main()