Kacper Rączy 1 day ago committed by GitHub
commit 150a196727
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 7
      selfdrive/test/fuzzy_generation.py
  2. 2
      selfdrive/test/process_replay/test_fuzzy.py
  3. 48
      selfdrive/test/process_replay/test_log_compat.py

@ -1,11 +1,10 @@
import capnp import capnp
import hypothesis.strategies as st import hypothesis.strategies as st
from typing import Any from typing import Any
from types import ModuleType
from collections.abc import Callable from collections.abc import Callable
from functools import cache from functools import cache
from cereal import log
DrawType = Callable[[st.SearchStrategy], Any] DrawType = Callable[[st.SearchStrategy], Any]
@ -76,6 +75,6 @@ class FuzzyGenerator:
return data return data
@classmethod @classmethod
def get_random_event_msg(cls, draw: DrawType, events: list[str], real_floats: bool = False) -> list[dict[str, Any]]: def get_random_event_msg(cls, draw: DrawType, log_schema: ModuleType, events: list[str], real_floats: bool = False) -> list[dict[str, Any]]:
fg = cls(draw, real_floats=real_floats) fg = cls(draw, real_floats=real_floats)
return [draw(fg.generate_struct(log.Event.schema, e)) for e in sorted(events)] return [draw(fg.generate_struct(log_schema.Event.schema, e)) for e in sorted(events)]

@ -25,7 +25,7 @@ class TestFuzzProcesses:
@settings(phases=[Phase.generate, Phase.target], max_examples=MAX_EXAMPLES, deadline=1000, @settings(phases=[Phase.generate, Phase.target], max_examples=MAX_EXAMPLES, deadline=1000,
suppress_health_check=[HealthCheck.too_slow, HealthCheck.data_too_large]) suppress_health_check=[HealthCheck.too_slow, HealthCheck.data_too_large])
def test_fuzz_process(self, proc_name, cfg, data): def test_fuzz_process(self, proc_name, cfg, data):
msgs = FuzzyGenerator.get_random_event_msg(data.draw, events=cfg.pubs, real_floats=True) msgs = FuzzyGenerator.get_random_event_msg(data.draw, log_schema=log, events=cfg.pubs, real_floats=True)
lr = [log.Event.new_message(**m).as_reader() for m in msgs] lr = [log.Event.new_message(**m).as_reader() for m in msgs]
cfg.timeout = 5 cfg.timeout = 5
pr.replay_process(cfg, lr, fingerprint=TOYOTA.TOYOTA_COROLLA_TSS2, disable_progress=True) pr.replay_process(cfg, lr, fingerprint=TOYOTA.TOYOTA_COROLLA_TSS2, disable_progress=True)

@ -0,0 +1,48 @@
import os
import capnp
import pytest
import shutil
import hypothesis.strategies as st
from hypothesis import given, settings, HealthCheck
from cereal import CEREAL_PATH
from openpilot.selfdrive.test.fuzzy_generation import FuzzyGenerator
from openpilot.tools.lib.logreader import LogReader
from openpilot.common.run import run_cmd
MAX_EXAMPLES = int(os.environ.get("MAX_EXAMPLES", "10"))
TARGET_REMOTE = os.environ.get("TARGET_REMOTE", "origin")
TARGET_BRANCH = os.environ.get("TARGET_BRANCH", "master")
@pytest.fixture(scope="module")
def parent_schema_file(tmp_path_factory):
tmp_dir = tmp_path_factory.mktemp("cereal")
# FIXME this is an ugly way to do this, but for some reason capnp.load ignores the `imports``, and only looks at dir where the file is
# how it supposed to work is: capnp.load(my_custom_log_capnp, imports=[CEREAL_PATH])
shutil.copytree(CEREAL_PATH, tmp_dir, dirs_exist_ok=True)
commit = run_cmd(["git", "merge-base", f"{TARGET_REMOTE}/{TARGET_BRANCH}", "HEAD"])
log_capnp_url = f"https://raw.githubusercontent.com/commaai/openpilot/{commit}/cereal/log.capnp"
tmp_log_capnp_path = tmp_dir / f"{commit}-log.capnp"
if not tmp_log_capnp_path.exists():
run_cmd(["curl", "-o", str(tmp_log_capnp_path), log_capnp_url])
return str(tmp_log_capnp_path)
@given(st.data())
@settings(max_examples=MAX_EXAMPLES, suppress_health_check=[HealthCheck.large_base_example])
def test_log_backwards_compatibility(parent_schema_file, data):
# capnp global parser needs to be cleaned up to avoid schema/struct ID conflicts
capnp_parser = capnp.SchemaParser()
old_log = capnp_parser.load(os.path.abspath(parent_schema_file))
msgs_dicts = FuzzyGenerator.get_random_event_msg(data.draw, log_schema=old_log, events=old_log.Event.schema.union_fields, real_floats=True)
msgs = [old_log.Event.new_message(**m).as_reader() for m in msgs_dicts]
dat = b"".join(msg.as_builder().to_bytes() for msg in msgs)
lr = list(LogReader.from_bytes(dat))
assert len(lr) == len(msgs)
# calling which() on a removed union type will raise an exception
assert {m.which() for m in lr} == {msg.which() for msg in msgs}
Loading…
Cancel
Save