switch cereal to pytest (#32950)

pytest
old-commit-hash: 133f25eecb
fix-exp-path
Maxime Desroches 10 months ago committed by GitHub
parent abd29fa646
commit d0e2572507
  1. 52
      cereal/messaging/tests/test_messaging.py
  2. 44
      cereal/messaging/tests/test_pub_sub_master.py
  3. 13
      cereal/messaging/tests/test_services.py
  4. 2
      pyproject.toml

@ -1,4 +1,3 @@
#!/usr/bin/env python3
import os import os
import capnp import capnp
import multiprocessing import multiprocessing
@ -6,8 +5,8 @@ import numbers
import random import random
import threading import threading
import time import time
import unittest
from parameterized import parameterized from parameterized import parameterized
import pytest
from cereal import log, car from cereal import log, car
import cereal.messaging as messaging import cereal.messaging as messaging
@ -28,12 +27,6 @@ def zmq_sleep(t=1):
if "ZMQ" in os.environ: if "ZMQ" in os.environ:
time.sleep(t) time.sleep(t)
def zmq_expected_failure(func):
if "ZMQ" in os.environ:
return unittest.expectedFailure(func)
else:
return func
# TODO: this should take any capnp struct and returrn a msg with random populated data # TODO: this should take any capnp struct and returrn a msg with random populated data
def random_carstate(): def random_carstate():
@ -58,12 +51,12 @@ def delayed_send(delay, sock, dat):
threading.Timer(delay, send_func).start() threading.Timer(delay, send_func).start()
class TestMessaging(unittest.TestCase): class TestMessaging:
def setUp(self): def setUp(self):
# TODO: ZMQ tests are too slow; all sleeps will need to be # TODO: ZMQ tests are too slow; all sleeps will need to be
# replaced with logic to block on the necessary condition # replaced with logic to block on the necessary condition
if "ZMQ" in os.environ: if "ZMQ" in os.environ:
raise unittest.SkipTest pytest.skip()
# ZMQ pub socket takes too long to die # ZMQ pub socket takes too long to die
# sleep to prevent multiple publishers error between tests # sleep to prevent multiple publishers error between tests
@ -75,9 +68,9 @@ class TestMessaging(unittest.TestCase):
msg = messaging.new_message(evt) msg = messaging.new_message(evt)
except capnp.lib.capnp.KjException: except capnp.lib.capnp.KjException:
msg = messaging.new_message(evt, random.randrange(200)) msg = messaging.new_message(evt, random.randrange(200))
self.assertLess(time.monotonic() - msg.logMonoTime, 0.1) assert (time.monotonic() - msg.logMonoTime) < 0.1
self.assertFalse(msg.valid) assert not msg.valid
self.assertEqual(evt, msg.which()) assert evt == msg.which()
@parameterized.expand(events) @parameterized.expand(events)
def test_pub_sock(self, evt): def test_pub_sock(self, evt):
@ -99,8 +92,8 @@ class TestMessaging(unittest.TestCase):
# no wait and no msgs in queue # no wait and no msgs in queue
msgs = func(sub_sock) msgs = func(sub_sock)
self.assertIsInstance(msgs, list) assert isinstance(msgs, list)
self.assertEqual(len(msgs), 0) assert len(msgs) == 0
# no wait but msgs are queued up # no wait but msgs are queued up
num_msgs = random.randrange(3, 10) num_msgs = random.randrange(3, 10)
@ -108,9 +101,9 @@ class TestMessaging(unittest.TestCase):
pub_sock.send(messaging.new_message(sock).to_bytes()) pub_sock.send(messaging.new_message(sock).to_bytes())
time.sleep(0.1) time.sleep(0.1)
msgs = func(sub_sock) msgs = func(sub_sock)
self.assertIsInstance(msgs, list) assert isinstance(msgs, list)
self.assertTrue(all(isinstance(msg, expected_type) for msg in msgs)) assert all(isinstance(msg, expected_type) for msg in msgs)
self.assertEqual(len(msgs), num_msgs) assert len(msgs) == num_msgs
def test_recv_sock(self): def test_recv_sock(self):
sock = "carState" sock = "carState"
@ -120,14 +113,14 @@ class TestMessaging(unittest.TestCase):
# no wait and no msg in queue, socket should timeout # no wait and no msg in queue, socket should timeout
recvd = messaging.recv_sock(sub_sock) recvd = messaging.recv_sock(sub_sock)
self.assertTrue(recvd is None) assert recvd is None
# no wait and one msg in queue # no wait and one msg in queue
msg = random_carstate() msg = random_carstate()
pub_sock.send(msg.to_bytes()) pub_sock.send(msg.to_bytes())
time.sleep(0.01) time.sleep(0.01)
recvd = messaging.recv_sock(sub_sock) recvd = messaging.recv_sock(sub_sock)
self.assertIsInstance(recvd, capnp._DynamicStructReader) assert isinstance(recvd, capnp._DynamicStructReader)
# https://github.com/python/mypy/issues/13038 # https://github.com/python/mypy/issues/13038
assert_carstate(msg.carState, recvd.carState) assert_carstate(msg.carState, recvd.carState)
@ -139,16 +132,16 @@ class TestMessaging(unittest.TestCase):
# no msg in queue, socket should timeout # no msg in queue, socket should timeout
recvd = messaging.recv_one(sub_sock) recvd = messaging.recv_one(sub_sock)
self.assertTrue(recvd is None) assert recvd is None
# one msg in queue # one msg in queue
msg = random_carstate() msg = random_carstate()
pub_sock.send(msg.to_bytes()) pub_sock.send(msg.to_bytes())
recvd = messaging.recv_one(sub_sock) recvd = messaging.recv_one(sub_sock)
self.assertIsInstance(recvd, capnp._DynamicStructReader) assert isinstance(recvd, capnp._DynamicStructReader)
assert_carstate(msg.carState, recvd.carState) assert_carstate(msg.carState, recvd.carState)
@zmq_expected_failure @pytest.mark.xfail(condition="ZMQ" in os.environ, reason='ZMQ detected')
def test_recv_one_or_none(self): def test_recv_one_or_none(self):
sock = "carState" sock = "carState"
pub_sock = messaging.pub_sock(sock) pub_sock = messaging.pub_sock(sock)
@ -157,13 +150,13 @@ class TestMessaging(unittest.TestCase):
# no msg in queue, socket shouldn't block # no msg in queue, socket shouldn't block
recvd = messaging.recv_one_or_none(sub_sock) recvd = messaging.recv_one_or_none(sub_sock)
self.assertTrue(recvd is None) assert recvd is None
# one msg in queue # one msg in queue
msg = random_carstate() msg = random_carstate()
pub_sock.send(msg.to_bytes()) pub_sock.send(msg.to_bytes())
recvd = messaging.recv_one_or_none(sub_sock) recvd = messaging.recv_one_or_none(sub_sock)
self.assertIsInstance(recvd, capnp._DynamicStructReader) assert isinstance(recvd, capnp._DynamicStructReader)
assert_carstate(msg.carState, recvd.carState) assert_carstate(msg.carState, recvd.carState)
def test_recv_one_retry(self): def test_recv_one_retry(self):
@ -179,7 +172,7 @@ class TestMessaging(unittest.TestCase):
p = multiprocessing.Process(target=messaging.recv_one_retry, args=(sub_sock,)) p = multiprocessing.Process(target=messaging.recv_one_retry, args=(sub_sock,))
p.start() p.start()
time.sleep(sock_timeout*15) time.sleep(sock_timeout*15)
self.assertTrue(p.is_alive()) assert p.is_alive()
p.terminate() p.terminate()
# wait 15 socket timeouts before sending # wait 15 socket timeouts before sending
@ -187,9 +180,6 @@ class TestMessaging(unittest.TestCase):
delayed_send(sock_timeout*15, pub_sock, msg.to_bytes()) delayed_send(sock_timeout*15, pub_sock, msg.to_bytes())
start_time = time.monotonic() start_time = time.monotonic()
recvd = messaging.recv_one_retry(sub_sock) recvd = messaging.recv_one_retry(sub_sock)
self.assertGreaterEqual(time.monotonic() - start_time, sock_timeout*15) assert (time.monotonic() - start_time) >= sock_timeout*15
self.assertIsInstance(recvd, capnp._DynamicStructReader) assert isinstance(recvd, capnp._DynamicStructReader)
assert_carstate(msg.carState, recvd.carState) assert_carstate(msg.carState, recvd.carState)
if __name__ == "__main__":
unittest.main()

@ -1,8 +1,6 @@
#!/usr/bin/env python3
import random import random
import time import time
from typing import Sized, cast from typing import Sized, cast
import unittest
import cereal.messaging as messaging import cereal.messaging as messaging
from cereal.messaging.tests.test_messaging import events, random_sock, random_socks, \ from cereal.messaging.tests.test_messaging import events, random_sock, random_socks, \
@ -10,9 +8,9 @@ from cereal.messaging.tests.test_messaging import events, random_sock, random_so
zmq_sleep zmq_sleep
class TestSubMaster(unittest.TestCase): class TestSubMaster:
def setUp(self): def setup_method(self):
# ZMQ pub socket takes too long to die # ZMQ pub socket takes too long to die
# sleep to prevent multiple publishers error between tests # sleep to prevent multiple publishers error between tests
zmq_sleep(3) zmq_sleep(3)
@ -21,21 +19,21 @@ class TestSubMaster(unittest.TestCase):
sm = messaging.SubMaster(events) sm = messaging.SubMaster(events)
for p in [sm.updated, sm.recv_time, sm.recv_frame, sm.alive, for p in [sm.updated, sm.recv_time, sm.recv_frame, sm.alive,
sm.sock, sm.data, sm.logMonoTime, sm.valid]: sm.sock, sm.data, sm.logMonoTime, sm.valid]:
self.assertEqual(len(cast(Sized, p)), len(events)) assert len(cast(Sized, p)) == len(events)
def test_init_state(self): def test_init_state(self):
socks = random_socks() socks = random_socks()
sm = messaging.SubMaster(socks) sm = messaging.SubMaster(socks)
self.assertEqual(sm.frame, -1) assert sm.frame == -1
self.assertFalse(any(sm.updated.values())) assert not any(sm.updated.values())
self.assertFalse(any(sm.alive.values())) assert not any(sm.alive.values())
self.assertTrue(all(t == 0. for t in sm.recv_time.values())) assert all(t == 0. for t in sm.recv_time.values())
self.assertTrue(all(f == 0 for f in sm.recv_frame.values())) assert all(f == 0 for f in sm.recv_frame.values())
self.assertTrue(all(t == 0 for t in sm.logMonoTime.values())) assert all(t == 0 for t in sm.logMonoTime.values())
for p in [sm.updated, sm.recv_time, sm.recv_frame, sm.alive, for p in [sm.updated, sm.recv_time, sm.recv_frame, sm.alive,
sm.sock, sm.data, sm.logMonoTime, sm.valid]: sm.sock, sm.data, sm.logMonoTime, sm.valid]:
self.assertEqual(len(cast(Sized, p)), len(socks)) assert len(cast(Sized, p)) == len(socks)
def test_getitem(self): def test_getitem(self):
sock = "carState" sock = "carState"
@ -59,8 +57,8 @@ class TestSubMaster(unittest.TestCase):
msg = messaging.new_message(sock) msg = messaging.new_message(sock)
pub_sock.send(msg.to_bytes()) pub_sock.send(msg.to_bytes())
sm.update(1000) sm.update(1000)
self.assertEqual(sm.frame, i) assert sm.frame == i
self.assertTrue(all(sm.updated.values())) assert all(sm.updated.values())
def test_update_timeout(self): def test_update_timeout(self):
sock = random_sock() sock = random_sock()
@ -70,9 +68,9 @@ class TestSubMaster(unittest.TestCase):
start_time = time.monotonic() start_time = time.monotonic()
sm.update(timeout) sm.update(timeout)
t = time.monotonic() - start_time t = time.monotonic() - start_time
self.assertGreaterEqual(t, timeout/1000.) assert t >= timeout/1000.
self.assertLess(t, 5) assert t < 5
self.assertFalse(any(sm.updated.values())) assert not any(sm.updated.values())
def test_avg_frequency_checks(self): def test_avg_frequency_checks(self):
for poll in (True, False): for poll in (True, False):
@ -118,12 +116,12 @@ class TestSubMaster(unittest.TestCase):
pub_sock.send(msg.to_bytes()) pub_sock.send(msg.to_bytes())
time.sleep(0.01) time.sleep(0.01)
sm.update(1000) sm.update(1000)
self.assertEqual(sm[sock].vEgo, n) assert sm[sock].vEgo == n
class TestPubMaster(unittest.TestCase): class TestPubMaster:
def setUp(self): def setup_method(self):
# ZMQ pub socket takes too long to die # ZMQ pub socket takes too long to die
# sleep to prevent multiple publishers error between tests # sleep to prevent multiple publishers error between tests
zmq_sleep(3) zmq_sleep(3)
@ -156,8 +154,4 @@ class TestPubMaster(unittest.TestCase):
if capnp: if capnp:
msg.clear_write_flag() msg.clear_write_flag()
msg = msg.to_bytes() msg = msg.to_bytes()
self.assertEqual(msg, recvd, i) assert msg == recvd, i
if __name__ == "__main__":
unittest.main()

@ -1,26 +1,21 @@
#!/usr/bin/env python3
import os import os
import tempfile import tempfile
from typing import Dict from typing import Dict
import unittest
from parameterized import parameterized from parameterized import parameterized
import cereal.services as services import cereal.services as services
from cereal.services import SERVICE_LIST from cereal.services import SERVICE_LIST
class TestServices(unittest.TestCase): class TestServices:
@parameterized.expand(SERVICE_LIST.keys()) @parameterized.expand(SERVICE_LIST.keys())
def test_services(self, s): def test_services(self, s):
service = SERVICE_LIST[s] service = SERVICE_LIST[s]
self.assertTrue(service.frequency <= 104) assert service.frequency <= 104
self.assertTrue(service.decimation != 0) assert service.decimation != 0
def test_generated_header(self): def test_generated_header(self):
with tempfile.NamedTemporaryFile(suffix=".h") as f: with tempfile.NamedTemporaryFile(suffix=".h") as f:
ret = os.system(f"python3 {services.__file__} > {f.name} && clang++ {f.name}") ret = os.system(f"python3 {services.__file__} > {f.name} && clang++ {f.name}")
self.assertEqual(ret, 0, "generated services header is not valid C") assert ret == 0, "generated services header is not valid C"
if __name__ == "__main__":
unittest.main()

@ -136,7 +136,7 @@ packages = [ "." ]
[tool.pytest.ini_options] [tool.pytest.ini_options]
minversion = "6.0" minversion = "6.0"
addopts = "--ignore=openpilot/ --ignore=cereal/ --ignore=opendbc/ --ignore=panda/ --ignore=rednose_repo/ --ignore=tinygrad_repo/ --ignore=teleoprtc_repo/ --ignore=msgq/ -Werror --strict-config --strict-markers --durations=10 -n auto --dist=loadgroup" addopts = "--ignore=openpilot/ --ignore=opendbc/ --ignore=panda/ --ignore=rednose_repo/ --ignore=tinygrad_repo/ --ignore=teleoprtc_repo/ --ignore=msgq/ -Werror --strict-config --strict-markers --durations=10 -n auto --dist=loadgroup"
cpp_files = "test_*" cpp_files = "test_*"
cpp_harness = "selfdrive/test/cpp_harness.py" cpp_harness = "selfdrive/test/cpp_harness.py"
python_files = "test_*.py" python_files = "test_*.py"

Loading…
Cancel
Save