diff --git a/.github/workflows/selfdrive_tests.yaml b/.github/workflows/selfdrive_tests.yaml index 2be83b1654..85c4073072 100644 --- a/.github/workflows/selfdrive_tests.yaml +++ b/.github/workflows/selfdrive_tests.yaml @@ -147,7 +147,7 @@ jobs: - name: Run valgrind timeout-minutes: 1 run: | - ${{ env.RUN }} "python selfdrive/test/test_valgrind_replay.py" + ${{ env.RUN }} "pytest selfdrive/test/test_valgrind_replay.py" - name: Print logs if: always() run: cat selfdrive/test/valgrind_logs.txt diff --git a/common/tests/test_file_helpers.py b/common/tests/test_file_helpers.py index 1817f77cd2..a9977c2362 100644 --- a/common/tests/test_file_helpers.py +++ b/common/tests/test_file_helpers.py @@ -1,11 +1,10 @@ import os -import unittest from uuid import uuid4 from openpilot.common.file_helpers import atomic_write_in_dir -class TestFileHelpers(unittest.TestCase): +class TestFileHelpers: def run_atomic_write_func(self, atomic_write_func): path = f"/tmp/tmp{uuid4()}" with atomic_write_func(path) as f: @@ -13,12 +12,8 @@ class TestFileHelpers(unittest.TestCase): assert not os.path.exists(path) with open(path) as f: - self.assertEqual(f.read(), "test") + assert f.read() == "test" os.remove(path) def test_atomic_write_in_dir(self): self.run_atomic_write_func(atomic_write_in_dir) - - -if __name__ == "__main__": - unittest.main() diff --git a/common/tests/test_numpy_fast.py b/common/tests/test_numpy_fast.py index de7bb972e7..aa53851db0 100644 --- a/common/tests/test_numpy_fast.py +++ b/common/tests/test_numpy_fast.py @@ -1,10 +1,9 @@ import numpy as np -import unittest from openpilot.common.numpy_fast import interp -class InterpTest(unittest.TestCase): +class TestInterp: def test_correctness_controls(self): _A_CRUISE_MIN_BP = np.asarray([0., 5., 10., 20., 40.]) _A_CRUISE_MIN_V = np.asarray([-1.0, -.8, -.67, -.5, -.30]) @@ -20,7 +19,3 @@ class InterpTest(unittest.TestCase): expected = np.interp(v_ego, _A_CRUISE_MIN_BP, _A_CRUISE_MIN_V) actual = interp(v_ego, _A_CRUISE_MIN_BP, _A_CRUISE_MIN_V) np.testing.assert_equal(actual, expected) - - -if __name__ == "__main__": - unittest.main() diff --git a/common/tests/test_params.py b/common/tests/test_params.py index 490ee122be..16cbc45295 100644 --- a/common/tests/test_params.py +++ b/common/tests/test_params.py @@ -1,13 +1,13 @@ +import pytest import os import threading import time import uuid -import unittest from openpilot.common.params import Params, ParamKeyType, UnknownKeyName -class TestParams(unittest.TestCase): - def setUp(self): +class TestParams: + def setup_method(self): self.params = Params() def test_params_put_and_get(self): @@ -49,16 +49,16 @@ class TestParams(unittest.TestCase): assert self.params.get("CarParams", True) == b"test" def test_params_unknown_key_fails(self): - with self.assertRaises(UnknownKeyName): + with pytest.raises(UnknownKeyName): self.params.get("swag") - with self.assertRaises(UnknownKeyName): + with pytest.raises(UnknownKeyName): self.params.get_bool("swag") - with self.assertRaises(UnknownKeyName): + with pytest.raises(UnknownKeyName): self.params.put("swag", "abc") - with self.assertRaises(UnknownKeyName): + with pytest.raises(UnknownKeyName): self.params.put_bool("swag", True) def test_remove_not_there(self): @@ -68,19 +68,19 @@ class TestParams(unittest.TestCase): def test_get_bool(self): self.params.remove("IsMetric") - self.assertFalse(self.params.get_bool("IsMetric")) + assert not self.params.get_bool("IsMetric") self.params.put_bool("IsMetric", True) - self.assertTrue(self.params.get_bool("IsMetric")) + assert self.params.get_bool("IsMetric") self.params.put_bool("IsMetric", False) - self.assertFalse(self.params.get_bool("IsMetric")) + assert not self.params.get_bool("IsMetric") self.params.put("IsMetric", "1") - self.assertTrue(self.params.get_bool("IsMetric")) + assert self.params.get_bool("IsMetric") self.params.put("IsMetric", "0") - self.assertFalse(self.params.get_bool("IsMetric")) + assert not self.params.get_bool("IsMetric") def test_put_non_blocking_with_get_block(self): q = Params() @@ -107,7 +107,3 @@ class TestParams(unittest.TestCase): assert len(keys) > 20 assert len(keys) == len(set(keys)) assert b"CarParams" in keys - - -if __name__ == "__main__": - unittest.main() diff --git a/common/tests/test_simple_kalman.py b/common/tests/test_simple_kalman.py index f641cd19e6..f4a967e58a 100644 --- a/common/tests/test_simple_kalman.py +++ b/common/tests/test_simple_kalman.py @@ -1,10 +1,8 @@ -import unittest - from openpilot.common.simple_kalman import KF1D -class TestSimpleKalman(unittest.TestCase): - def setUp(self): +class TestSimpleKalman: + def setup_method(self): dt = 0.01 x0_0 = 0.0 x1_0 = 0.0 @@ -24,12 +22,8 @@ class TestSimpleKalman(unittest.TestCase): def test_getter_setter(self): self.kf.set_x([[1.0], [1.0]]) - self.assertEqual(self.kf.x, [[1.0], [1.0]]) + assert self.kf.x == [[1.0], [1.0]] def update_returns_state(self): x = self.kf.update(100) - self.assertEqual(x, self.kf.x) - - -if __name__ == "__main__": - unittest.main() + assert x == self.kf.x diff --git a/common/transformations/tests/test_coordinates.py b/common/transformations/tests/test_coordinates.py index 7ae79403bd..41076d9b3f 100755 --- a/common/transformations/tests/test_coordinates.py +++ b/common/transformations/tests/test_coordinates.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 import numpy as np -import unittest import openpilot.common.transformations.coordinates as coord @@ -44,7 +43,7 @@ ned_offsets_batch = np.array([[ 53.88103168, 43.83445935, -46.27488057], [ 78.56272609, 18.53100158, -43.25290759]]) -class TestNED(unittest.TestCase): +class TestNED: def test_small_distances(self): start_geodetic = np.array([33.8042184, -117.888593, 0.0]) local_coord = coord.LocalCoord.from_geodetic(start_geodetic) @@ -54,13 +53,13 @@ class TestNED(unittest.TestCase): west_geodetic = start_geodetic + [0, -0.0005, 0] west_ned = local_coord.geodetic2ned(west_geodetic) - self.assertLess(np.abs(west_ned[0]), 1e-3) - self.assertLess(west_ned[1], 0) + assert np.abs(west_ned[0]) < 1e-3 + assert west_ned[1] < 0 southwest_geodetic = start_geodetic + [-0.0005, -0.002, 0] southwest_ned = local_coord.geodetic2ned(southwest_geodetic) - self.assertLess(southwest_ned[0], 0) - self.assertLess(southwest_ned[1], 0) + assert southwest_ned[0] < 0 + assert southwest_ned[1] < 0 def test_ecef_geodetic(self): # testing single @@ -105,5 +104,3 @@ class TestNED(unittest.TestCase): np.testing.assert_allclose(converter.ned2ecef(ned_offsets_batch), ecef_positions_offset_batch, rtol=1e-9, atol=1e-7) -if __name__ == "__main__": - unittest.main() diff --git a/common/transformations/tests/test_orientation.py b/common/transformations/tests/test_orientation.py index f77827d2f9..695642774e 100755 --- a/common/transformations/tests/test_orientation.py +++ b/common/transformations/tests/test_orientation.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 import numpy as np -import unittest from openpilot.common.transformations.orientation import euler2quat, quat2euler, euler2rot, rot2euler, \ rot2quat, quat2rot, \ @@ -32,7 +31,7 @@ ned_eulers = np.array([[ 0.46806039, -0.4881889 , 1.65697808], [ 2.50450101, 0.36304151, 0.33136365]]) -class TestOrientation(unittest.TestCase): +class TestOrientation: def test_quat_euler(self): for i, eul in enumerate(eulers): np.testing.assert_allclose(quats[i], euler2quat(eul), rtol=1e-7) @@ -62,7 +61,3 @@ class TestOrientation(unittest.TestCase): np.testing.assert_allclose(ned_eulers[i], ned_euler_from_ecef(ecef_positions[i], eulers[i]), rtol=1e-7) #np.testing.assert_allclose(eulers[i], ecef_euler_from_ned(ecef_positions[i], ned_eulers[i]), rtol=1e-7) # np.testing.assert_allclose(ned_eulers, ned_euler_from_ecef(ecef_positions, eulers), rtol=1e-7) - - -if __name__ == "__main__": - unittest.main() diff --git a/opendbc b/opendbc index 91a9bb4824..e2408cb272 160000 --- a/opendbc +++ b/opendbc @@ -1 +1 @@ -Subproject commit 91a9bb4824381c39b8a15b443d5b265ec550b62b +Subproject commit e2408cb2725671ad63e827e02f37e0f1739b68c6 diff --git a/poetry.lock b/poetry.lock index 978574a1e2..f1f8d9f518 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiohttp" @@ -6641,6 +6641,24 @@ pluggy = ">=1.5,<2.0" [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.23.6" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-asyncio-0.23.6.tar.gz", hash = "sha256:ffe523a89c1c222598c76856e76852b787504ddb72dd5d9b6617ffa8aa2cde5f"}, + {file = "pytest_asyncio-0.23.6-py3-none-any.whl", hash = "sha256:68516fdd1018ac57b846c9846b954f0393b26f094764a28c955eabb0536a4e8a"}, +] + +[package.dependencies] +pytest = ">=7.0.0,<9" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "pytest-cov" version = "5.0.0" @@ -6674,6 +6692,23 @@ files = [ colorama = "*" pytest = ">=7.0" +[[package]] +name = "pytest-mock" +version = "3.14.0" +description = "Thin-wrapper around the mock package for easier use with pytest" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"}, + {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"}, +] + +[package.dependencies] +pytest = ">=6.2.5" + +[package.extras] +dev = ["pre-commit", "pytest-asyncio", "tox"] + [[package]] name = "pytest-randomly" version = "3.15.0" @@ -7987,4 +8022,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "~3.11" -content-hash = "5f0a1b6f26faa3effeaa5393b73d9188be385a72c1d3b9befb3f03df3b38c86d" +content-hash = "9f69dc7862f33f61e94e960f0ead2cbcd306b4502163d1934381d476143344f4" diff --git a/pyproject.toml b/pyproject.toml index a25fcc1645..a4e31145de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -154,6 +154,8 @@ pytest-subtests = "*" pytest-xdist = "*" pytest-timeout = "*" pytest-randomly = "*" +pytest-asyncio = "*" +pytest-mock = "*" ruff = "*" sphinx = "*" sphinx-rtd-theme = "*" @@ -197,6 +199,7 @@ lint.flake8-implicit-str-concat.allow-multiline=false "third_party".msg = "Use openpilot.third_party" "tools".msg = "Use openpilot.tools" "pytest.main".msg = "pytest.main requires special handling that is easy to mess up!" +"unittest".msg = "Use pytest" [tool.coverage.run] concurrency = ["multiprocessing", "thread"] diff --git a/selfdrive/athena/tests/test_athenad.py b/selfdrive/athena/tests/test_athenad.py index 4850ab9a3f..bdce3dccef 100755 --- a/selfdrive/athena/tests/test_athenad.py +++ b/selfdrive/athena/tests/test_athenad.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -from functools import partial, wraps +import pytest +from functools import wraps import json import multiprocessing import os @@ -8,12 +9,9 @@ import shutil import time import threading import queue -import unittest from dataclasses import asdict, replace from datetime import datetime, timedelta -from parameterized import parameterized -from unittest import mock from websocket import ABNF from websocket._exceptions import WebSocketConnectionClosedException @@ -24,7 +22,7 @@ from openpilot.common.timeout import Timeout from openpilot.selfdrive.athena import athenad from openpilot.selfdrive.athena.athenad import MAX_RETRY_COUNT, dispatcher from openpilot.selfdrive.athena.tests.helpers import HTTPRequestHandler, MockWebsocket, MockApi, EchoSocket -from openpilot.selfdrive.test.helpers import with_http_server +from openpilot.selfdrive.test.helpers import http_server_context from openpilot.system.hardware.hw import Paths @@ -37,10 +35,6 @@ def seed_athena_server(host, port): except requests.exceptions.ConnectionError: time.sleep(0.1) - -with_mock_athena = partial(with_http_server, handler=HTTPRequestHandler, setup=seed_athena_server) - - def with_upload_handler(func): @wraps(func) def wrapper(*args, **kwargs): @@ -54,15 +48,23 @@ def with_upload_handler(func): thread.join() return wrapper +@pytest.fixture +def mock_create_connection(mocker): + return mocker.patch('openpilot.selfdrive.athena.athenad.create_connection') + +@pytest.fixture +def host(): + with http_server_context(handler=HTTPRequestHandler, setup=seed_athena_server) as (host, port): + yield f"http://{host}:{port}" -class TestAthenadMethods(unittest.TestCase): +class TestAthenadMethods: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.SOCKET_PORT = 45454 athenad.Api = MockApi athenad.LOCAL_PORT_WHITELIST = {cls.SOCKET_PORT} - def setUp(self): + def setup_method(self): self.default_params = { "DongleId": "0000000000000000", "GithubSshKeys": b"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQC307aE+nuHzTAgaJhzSf5v7ZZQW9gaperjhCmyPyl4PzY7T1mDGenTlVTN7yoVFZ9UfO9oMQqo0n1OwDIiqbIFxqnhrHU0cYfj88rI85m5BEKlNu5RdaVTj1tcbaPpQc5kZEolaI1nDDjzV0lwS7jo5VYDHseiJHlik3HH1SgtdtsuamGR2T80q1SyW+5rHoMOJG73IH2553NnWuikKiuikGHUYBd00K1ilVAK2xSiMWJp55tQfZ0ecr9QjEsJ+J/efL4HqGNXhffxvypCXvbUYAFSddOwXUPo5BTKevpxMtH+2YrkpSjocWA04VnTYFiPG6U4ItKmbLOTFZtPzoez private", # noqa: E501 @@ -109,8 +111,8 @@ class TestAthenadMethods(unittest.TestCase): def test_echo(self): assert dispatcher["echo"]("bob") == "bob" - def test_getMessage(self): - with self.assertRaises(TimeoutError) as _: + def test_get_message(self): + with pytest.raises(TimeoutError) as _: dispatcher["getMessage"]("controlsState") end_event = multiprocessing.Event() @@ -133,7 +135,7 @@ class TestAthenadMethods(unittest.TestCase): end_event.set() p.join() - def test_listDataDirectory(self): + def test_list_data_directory(self): route = '2021-03-29--13-32-47' segments = [0, 1, 2, 3, 11] @@ -143,69 +145,66 @@ class TestAthenadMethods(unittest.TestCase): self._create_file(file) resp = dispatcher["listDataDirectory"]() - self.assertTrue(resp, 'list empty!') - self.assertCountEqual(resp, files) + assert resp, 'list empty!' + assert len(resp) == len(files) resp = dispatcher["listDataDirectory"](f'{route}--123') - self.assertCountEqual(resp, []) + assert len(resp) == 0 prefix = f'{route}' - expected = filter(lambda f: f.startswith(prefix), files) + expected = list(filter(lambda f: f.startswith(prefix), files)) resp = dispatcher["listDataDirectory"](prefix) - self.assertTrue(resp, 'list empty!') - self.assertCountEqual(resp, expected) + assert resp, 'list empty!' + assert len(resp) == len(expected) prefix = f'{route}--1' - expected = filter(lambda f: f.startswith(prefix), files) + expected = list(filter(lambda f: f.startswith(prefix), files)) resp = dispatcher["listDataDirectory"](prefix) - self.assertTrue(resp, 'list empty!') - self.assertCountEqual(resp, expected) + assert resp, 'list empty!' + assert len(resp) == len(expected) prefix = f'{route}--1/' - expected = filter(lambda f: f.startswith(prefix), files) + expected = list(filter(lambda f: f.startswith(prefix), files)) resp = dispatcher["listDataDirectory"](prefix) - self.assertTrue(resp, 'list empty!') - self.assertCountEqual(resp, expected) + assert resp, 'list empty!' + assert len(resp) == len(expected) prefix = f'{route}--1/q' - expected = filter(lambda f: f.startswith(prefix), files) + expected = list(filter(lambda f: f.startswith(prefix), files)) resp = dispatcher["listDataDirectory"](prefix) - self.assertTrue(resp, 'list empty!') - self.assertCountEqual(resp, expected) + assert resp, 'list empty!' + assert len(resp) == len(expected) def test_strip_bz2_extension(self): fn = self._create_file('qlog.bz2') if fn.endswith('.bz2'): - self.assertEqual(athenad.strip_bz2_extension(fn), fn[:-4]) + assert athenad.strip_bz2_extension(fn) == fn[:-4] - @parameterized.expand([(True,), (False,)]) - @with_mock_athena - def test_do_upload(self, compress, host): + @pytest.mark.parametrize("compress", [True, False]) + def test_do_upload(self, host, compress): # random bytes to ensure rather large object post-compression fn = self._create_file('qlog', data=os.urandom(10000 * 1024)) upload_fn = fn + ('.bz2' if compress else '') item = athenad.UploadItem(path=upload_fn, url="http://localhost:1238", headers={}, created_at=int(time.time()*1000), id='') - with self.assertRaises(requests.exceptions.ConnectionError): + with pytest.raises(requests.exceptions.ConnectionError): athenad._do_upload(item) item = athenad.UploadItem(path=upload_fn, url=f"{host}/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='') resp = athenad._do_upload(item) - self.assertEqual(resp.status_code, 201) + assert resp.status_code == 201 - @with_mock_athena - def test_uploadFileToUrl(self, host): + def test_upload_file_to_url(self, host): fn = self._create_file('qlog.bz2') resp = dispatcher["uploadFileToUrl"]("qlog.bz2", f"{host}/qlog.bz2", {}) - self.assertEqual(resp['enqueued'], 1) - self.assertNotIn('failed', resp) - self.assertLessEqual({"path": fn, "url": f"{host}/qlog.bz2", "headers": {}}.items(), resp['items'][0].items()) - self.assertIsNotNone(resp['items'][0].get('id')) - self.assertEqual(athenad.upload_queue.qsize(), 1) - - @with_mock_athena - def test_uploadFileToUrl_duplicate(self, host): + assert resp['enqueued'] == 1 + assert 'failed' not in resp + assert {"path": fn, "url": f"{host}/qlog.bz2", "headers": {}}.items() <= resp['items'][0].items() + assert resp['items'][0].get('id') is not None + assert athenad.upload_queue.qsize() == 1 + + def test_upload_file_to_url_duplicate(self, host): self._create_file('qlog.bz2') url1 = f"{host}/qlog.bz2?sig=sig1" @@ -214,14 +213,12 @@ class TestAthenadMethods(unittest.TestCase): # Upload same file again, but with different signature url2 = f"{host}/qlog.bz2?sig=sig2" resp = dispatcher["uploadFileToUrl"]("qlog.bz2", url2, {}) - self.assertEqual(resp, {'enqueued': 0, 'items': []}) + assert resp == {'enqueued': 0, 'items': []} - @with_mock_athena - def test_uploadFileToUrl_does_not_exist(self, host): + def test_upload_file_to_url_does_not_exist(self, host): not_exists_resp = dispatcher["uploadFileToUrl"]("does_not_exist.bz2", "http://localhost:1238", {}) - self.assertEqual(not_exists_resp, {'enqueued': 0, 'items': [], 'failed': ['does_not_exist.bz2']}) + assert not_exists_resp == {'enqueued': 0, 'items': [], 'failed': ['does_not_exist.bz2']} - @with_mock_athena @with_upload_handler def test_upload_handler(self, host): fn = self._create_file('qlog.bz2') @@ -233,13 +230,12 @@ class TestAthenadMethods(unittest.TestCase): # TODO: verify that upload actually succeeded # TODO: also check that end_event and metered network raises AbortTransferException - self.assertEqual(athenad.upload_queue.qsize(), 0) + assert athenad.upload_queue.qsize() == 0 - @parameterized.expand([(500, True), (412, False)]) - @with_mock_athena - @mock.patch('requests.put') + @pytest.mark.parametrize("status,retry", [(500,True), (412,False)]) @with_upload_handler - def test_upload_handler_retry(self, status, retry, mock_put, host): + def test_upload_handler_retry(self, mocker, host, status, retry): + mock_put = mocker.patch('requests.put') mock_put.return_value.status_code = status fn = self._create_file('qlog.bz2') item = athenad.UploadItem(path=fn, url=f"{host}/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='', allow_cellular=True) @@ -248,10 +244,10 @@ class TestAthenadMethods(unittest.TestCase): self._wait_for_upload() time.sleep(0.1) - self.assertEqual(athenad.upload_queue.qsize(), 1 if retry else 0) + assert athenad.upload_queue.qsize() == (1 if retry else 0) if retry: - self.assertEqual(athenad.upload_queue.get().retry_count, 1) + assert athenad.upload_queue.get().retry_count == 1 @with_upload_handler def test_upload_handler_timeout(self): @@ -265,33 +261,33 @@ class TestAthenadMethods(unittest.TestCase): time.sleep(0.1) # Check that upload with retry count exceeded is not put back - self.assertEqual(athenad.upload_queue.qsize(), 0) + assert athenad.upload_queue.qsize() == 0 athenad.upload_queue.put_nowait(item) self._wait_for_upload() time.sleep(0.1) # Check that upload item was put back in the queue with incremented retry count - self.assertEqual(athenad.upload_queue.qsize(), 1) - self.assertEqual(athenad.upload_queue.get().retry_count, 1) + assert athenad.upload_queue.qsize() == 1 + assert athenad.upload_queue.get().retry_count == 1 @with_upload_handler - def test_cancelUpload(self): + def test_cancel_upload(self): item = athenad.UploadItem(path="qlog.bz2", url="http://localhost:44444/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='id', allow_cellular=True) athenad.upload_queue.put_nowait(item) dispatcher["cancelUpload"](item.id) - self.assertIn(item.id, athenad.cancelled_uploads) + assert item.id in athenad.cancelled_uploads self._wait_for_upload() time.sleep(0.1) - self.assertEqual(athenad.upload_queue.qsize(), 0) - self.assertEqual(len(athenad.cancelled_uploads), 0) + assert athenad.upload_queue.qsize() == 0 + assert len(athenad.cancelled_uploads) == 0 @with_upload_handler - def test_cancelExpiry(self): + def test_cancel_expiry(self): t_future = datetime.now() - timedelta(days=40) ts = int(t_future.strftime("%s")) * 1000 @@ -303,15 +299,14 @@ class TestAthenadMethods(unittest.TestCase): self._wait_for_upload() time.sleep(0.1) - self.assertEqual(athenad.upload_queue.qsize(), 0) + assert athenad.upload_queue.qsize() == 0 - def test_listUploadQueueEmpty(self): + def test_list_upload_queue_empty(self): items = dispatcher["listUploadQueue"]() - self.assertEqual(len(items), 0) + assert len(items) == 0 - @with_http_server @with_upload_handler - def test_listUploadQueueCurrent(self, host: str): + def test_list_upload_queue_current(self, host: str): fn = self._create_file('qlog.bz2') item = athenad.UploadItem(path=fn, url=f"{host}/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='', allow_cellular=True) @@ -319,22 +314,22 @@ class TestAthenadMethods(unittest.TestCase): self._wait_for_upload() items = dispatcher["listUploadQueue"]() - self.assertEqual(len(items), 1) - self.assertTrue(items[0]['current']) + assert len(items) == 1 + assert items[0]['current'] - def test_listUploadQueue(self): + def test_list_upload_queue(self): item = athenad.UploadItem(path="qlog.bz2", url="http://localhost:44444/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='id', allow_cellular=True) athenad.upload_queue.put_nowait(item) items = dispatcher["listUploadQueue"]() - self.assertEqual(len(items), 1) - self.assertDictEqual(items[0], asdict(item)) - self.assertFalse(items[0]['current']) + assert len(items) == 1 + assert items[0] == asdict(item) + assert not items[0]['current'] athenad.cancelled_uploads.add(item.id) items = dispatcher["listUploadQueue"]() - self.assertEqual(len(items), 0) + assert len(items) == 0 def test_upload_queue_persistence(self): item1 = athenad.UploadItem(path="_", url="_", headers={}, created_at=int(time.time()), id='id1') @@ -353,11 +348,10 @@ class TestAthenadMethods(unittest.TestCase): athenad.upload_queue.queue.clear() athenad.UploadQueueCache.initialize(athenad.upload_queue) - self.assertEqual(athenad.upload_queue.qsize(), 1) - self.assertDictEqual(asdict(athenad.upload_queue.queue[-1]), asdict(item1)) + assert athenad.upload_queue.qsize() == 1 + assert asdict(athenad.upload_queue.queue[-1]) == asdict(item1) - @mock.patch('openpilot.selfdrive.athena.athenad.create_connection') - def test_startLocalProxy(self, mock_create_connection): + def test_start_local_proxy(self, mock_create_connection): end_event = threading.Event() ws_recv = queue.Queue() @@ -380,21 +374,21 @@ class TestAthenadMethods(unittest.TestCase): ws_recv.put_nowait(WebSocketConnectionClosedException()) socket_thread.join() - def test_getSshAuthorizedKeys(self): + def test_get_ssh_authorized_keys(self): keys = dispatcher["getSshAuthorizedKeys"]() - self.assertEqual(keys, self.default_params["GithubSshKeys"].decode('utf-8')) + assert keys == self.default_params["GithubSshKeys"].decode('utf-8') - def test_getGithubUsername(self): + def test_get_github_username(self): keys = dispatcher["getGithubUsername"]() - self.assertEqual(keys, self.default_params["GithubUsername"].decode('utf-8')) + assert keys == self.default_params["GithubUsername"].decode('utf-8') - def test_getVersion(self): + def test_get_version(self): resp = dispatcher["getVersion"]() keys = ["version", "remote", "branch", "commit"] - self.assertEqual(list(resp.keys()), keys) + assert list(resp.keys()) == keys for k in keys: - self.assertIsInstance(resp[k], str, f"{k} is not a string") - self.assertTrue(len(resp[k]) > 0, f"{k} has no value") + assert isinstance(resp[k], str), f"{k} is not a string" + assert len(resp[k]) > 0, f"{k} has no value" def test_jsonrpc_handler(self): end_event = threading.Event() @@ -405,15 +399,15 @@ class TestAthenadMethods(unittest.TestCase): # with params athenad.recv_queue.put_nowait(json.dumps({"method": "echo", "params": ["hello"], "jsonrpc": "2.0", "id": 0})) resp = athenad.send_queue.get(timeout=3) - self.assertDictEqual(json.loads(resp), {'result': 'hello', 'id': 0, 'jsonrpc': '2.0'}) + assert json.loads(resp) == {'result': 'hello', 'id': 0, 'jsonrpc': '2.0'} # without params athenad.recv_queue.put_nowait(json.dumps({"method": "getNetworkType", "jsonrpc": "2.0", "id": 0})) resp = athenad.send_queue.get(timeout=3) - self.assertDictEqual(json.loads(resp), {'result': 1, 'id': 0, 'jsonrpc': '2.0'}) + assert json.loads(resp) == {'result': 1, 'id': 0, 'jsonrpc': '2.0'} # log forwarding athenad.recv_queue.put_nowait(json.dumps({'result': {'success': 1}, 'id': 0, 'jsonrpc': '2.0'})) resp = athenad.log_recv_queue.get(timeout=3) - self.assertDictEqual(json.loads(resp), {'result': {'success': 1}, 'id': 0, 'jsonrpc': '2.0'}) + assert json.loads(resp) == {'result': {'success': 1}, 'id': 0, 'jsonrpc': '2.0'} finally: end_event.set() thread.join() @@ -427,8 +421,4 @@ class TestAthenadMethods(unittest.TestCase): # ensure the list is all logs except most recent sl = athenad.get_logs_to_send_sorted() - self.assertListEqual(sl, fl[:-1]) - - -if __name__ == '__main__': - unittest.main() + assert sl == fl[:-1] diff --git a/selfdrive/athena/tests/test_athenad_ping.py b/selfdrive/athena/tests/test_athenad_ping.py index f56fcac8b5..44fa0b8481 100755 --- a/selfdrive/athena/tests/test_athenad_ping.py +++ b/selfdrive/athena/tests/test_athenad_ping.py @@ -1,10 +1,9 @@ #!/usr/bin/env python3 +import pytest import subprocess import threading import time -import unittest from typing import cast -from unittest import mock from openpilot.common.params import Params from openpilot.common.timeout import Timeout @@ -22,7 +21,7 @@ def wifi_radio(on: bool) -> None: subprocess.run(["nmcli", "radio", "wifi", "on" if on else "off"], check=True) -class TestAthenadPing(unittest.TestCase): +class TestAthenadPing: params: Params dongle_id: str @@ -39,10 +38,10 @@ class TestAthenadPing(unittest.TestCase): return self._get_ping_time() is not None @classmethod - def tearDownClass(cls) -> None: + def teardown_class(cls) -> None: wifi_radio(True) - def setUp(self) -> None: + def setup_method(self) -> None: self.params = Params() self.dongle_id = self.params.get("DongleId", encoding="utf-8") @@ -52,21 +51,23 @@ class TestAthenadPing(unittest.TestCase): self.exit_event = threading.Event() self.athenad = threading.Thread(target=athenad.main, args=(self.exit_event,)) - def tearDown(self) -> None: + def teardown_method(self) -> None: if self.athenad.is_alive(): self.exit_event.set() self.athenad.join() - @mock.patch('openpilot.selfdrive.athena.athenad.create_connection', new_callable=lambda: mock.MagicMock(wraps=athenad.create_connection)) - def assertTimeout(self, reconnect_time: float, mock_create_connection: mock.MagicMock) -> None: + def assertTimeout(self, reconnect_time: float, subtests, mocker) -> None: self.athenad.start() + mock_create_connection = mocker.patch('openpilot.selfdrive.athena.athenad.create_connection', + new_callable=lambda: mocker.MagicMock(wraps=athenad.create_connection)) + time.sleep(1) mock_create_connection.assert_called_once() mock_create_connection.reset_mock() # check normal behaviour, server pings on connection - with self.subTest("Wi-Fi: receives ping"), Timeout(70, "no ping received"): + with subtests.test("Wi-Fi: receives ping"), Timeout(70, "no ping received"): while not self._received_ping(): time.sleep(0.1) print("ping received") @@ -74,7 +75,7 @@ class TestAthenadPing(unittest.TestCase): mock_create_connection.assert_not_called() # websocket should attempt reconnect after short time - with self.subTest("LTE: attempt reconnect"): + with subtests.test("LTE: attempt reconnect"): wifi_radio(False) print("waiting for reconnect attempt") start_time = time.monotonic() @@ -86,21 +87,17 @@ class TestAthenadPing(unittest.TestCase): self._clear_ping_time() # check ping received after reconnect - with self.subTest("LTE: receives ping"), Timeout(70, "no ping received"): + with subtests.test("LTE: receives ping"), Timeout(70, "no ping received"): while not self._received_ping(): time.sleep(0.1) print("ping received") - @unittest.skipIf(not TICI, "only run on desk") - def test_offroad(self) -> None: + @pytest.mark.skipif(not TICI, reason="only run on desk") + def test_offroad(self, subtests, mocker) -> None: write_onroad_params(False, self.params) - self.assertTimeout(60 + TIMEOUT_TOLERANCE) # based using TCP keepalive settings + self.assertTimeout(60 + TIMEOUT_TOLERANCE, subtests, mocker) # based using TCP keepalive settings - @unittest.skipIf(not TICI, "only run on desk") - def test_onroad(self) -> None: + @pytest.mark.skipif(not TICI, reason="only run on desk") + def test_onroad(self, subtests, mocker) -> None: write_onroad_params(True, self.params) - self.assertTimeout(21 + TIMEOUT_TOLERANCE) - - -if __name__ == "__main__": - unittest.main() + self.assertTimeout(21 + TIMEOUT_TOLERANCE, subtests, mocker) diff --git a/selfdrive/athena/tests/test_registration.py b/selfdrive/athena/tests/test_registration.py index e7ad63a370..a808dd5668 100755 --- a/selfdrive/athena/tests/test_registration.py +++ b/selfdrive/athena/tests/test_registration.py @@ -1,9 +1,7 @@ #!/usr/bin/env python3 import json -import unittest from Crypto.PublicKey import RSA from pathlib import Path -from unittest import mock from openpilot.common.params import Params from openpilot.selfdrive.athena.registration import register, UNREGISTERED_DONGLE_ID @@ -11,9 +9,9 @@ from openpilot.selfdrive.athena.tests.helpers import MockResponse from openpilot.system.hardware.hw import Paths -class TestRegistration(unittest.TestCase): +class TestRegistration: - def setUp(self): + def setup_method(self): # clear params and setup key paths self.params = Params() self.params.clear_all() @@ -32,50 +30,46 @@ class TestRegistration(unittest.TestCase): with open(self.pub_key, "wb") as f: f.write(k.publickey().export_key()) - def test_valid_cache(self): + def test_valid_cache(self, mocker): # if all params are written, return the cached dongle id self.params.put("IMEI", "imei") self.params.put("HardwareSerial", "serial") self._generate_keys() - with mock.patch("openpilot.selfdrive.athena.registration.api_get", autospec=True) as m: - dongle = "DONGLE_ID_123" - self.params.put("DongleId", dongle) - self.assertEqual(register(), dongle) - self.assertFalse(m.called) + m = mocker.patch("openpilot.selfdrive.athena.registration.api_get", autospec=True) + dongle = "DONGLE_ID_123" + self.params.put("DongleId", dongle) + assert register() == dongle + assert not m.called - def test_no_keys(self): + def test_no_keys(self, mocker): # missing pubkey - with mock.patch("openpilot.selfdrive.athena.registration.api_get", autospec=True) as m: - dongle = register() - self.assertEqual(m.call_count, 0) - self.assertEqual(dongle, UNREGISTERED_DONGLE_ID) - self.assertEqual(self.params.get("DongleId", encoding='utf-8'), dongle) + m = mocker.patch("openpilot.selfdrive.athena.registration.api_get", autospec=True) + dongle = register() + assert m.call_count == 0 + assert dongle == UNREGISTERED_DONGLE_ID + assert self.params.get("DongleId", encoding='utf-8') == dongle - def test_missing_cache(self): + def test_missing_cache(self, mocker): # keys exist but no dongle id self._generate_keys() - with mock.patch("openpilot.selfdrive.athena.registration.api_get", autospec=True) as m: - dongle = "DONGLE_ID_123" - m.return_value = MockResponse(json.dumps({'dongle_id': dongle}), 200) - self.assertEqual(register(), dongle) - self.assertEqual(m.call_count, 1) + m = mocker.patch("openpilot.selfdrive.athena.registration.api_get", autospec=True) + dongle = "DONGLE_ID_123" + m.return_value = MockResponse(json.dumps({'dongle_id': dongle}), 200) + assert register() == dongle + assert m.call_count == 1 - # call again, shouldn't hit the API this time - self.assertEqual(register(), dongle) - self.assertEqual(m.call_count, 1) - self.assertEqual(self.params.get("DongleId", encoding='utf-8'), dongle) + # call again, shouldn't hit the API this time + assert register() == dongle + assert m.call_count == 1 + assert self.params.get("DongleId", encoding='utf-8') == dongle - def test_unregistered(self): + def test_unregistered(self, mocker): # keys exist, but unregistered self._generate_keys() - with mock.patch("openpilot.selfdrive.athena.registration.api_get", autospec=True) as m: - m.return_value = MockResponse(None, 402) - dongle = register() - self.assertEqual(m.call_count, 1) - self.assertEqual(dongle, UNREGISTERED_DONGLE_ID) - self.assertEqual(self.params.get("DongleId", encoding='utf-8'), dongle) - - -if __name__ == "__main__": - unittest.main() + m = mocker.patch("openpilot.selfdrive.athena.registration.api_get", autospec=True) + m.return_value = MockResponse(None, 402) + dongle = register() + assert m.call_count == 1 + assert dongle == UNREGISTERED_DONGLE_ID + assert self.params.get("DongleId", encoding='utf-8') == dongle diff --git a/selfdrive/boardd/tests/test_boardd_loopback.py b/selfdrive/boardd/tests/test_boardd_loopback.py index 3ab3a9c5b1..fa9eb957c2 100755 --- a/selfdrive/boardd/tests/test_boardd_loopback.py +++ b/selfdrive/boardd/tests/test_boardd_loopback.py @@ -4,7 +4,6 @@ import copy import random import time import pytest -import unittest from collections import defaultdict from pprint import pprint @@ -107,7 +106,3 @@ class TestBoarddLoopback: pprint(sm['pandaStates']) # may drop messages due to RX buffer overflow for bus in sent_loopback.keys(): assert not len(sent_loopback[bus]), f"loop {i}: bus {bus} missing {len(sent_loopback[bus])} out of {sent_total[bus]} messages" - - -if __name__ == "__main__": - unittest.main() diff --git a/selfdrive/boardd/tests/test_pandad.py b/selfdrive/boardd/tests/test_pandad.py index 3434be3fe4..65f3ad657c 100755 --- a/selfdrive/boardd/tests/test_pandad.py +++ b/selfdrive/boardd/tests/test_pandad.py @@ -2,7 +2,6 @@ import os import pytest import time -import unittest import cereal.messaging as messaging from cereal import log @@ -16,16 +15,16 @@ HERE = os.path.dirname(os.path.realpath(__file__)) @pytest.mark.tici -class TestPandad(unittest.TestCase): +class TestPandad: - def setUp(self): + def setup_method(self): # ensure panda is up if len(Panda.list()) == 0: self._run_test(60) self.spi = HARDWARE.get_device_type() != 'tici' - def tearDown(self): + def teardown_method(self): managed_processes['pandad'].stop() def _run_test(self, timeout=30) -> float: @@ -65,7 +64,7 @@ class TestPandad(unittest.TestCase): assert Panda.wait_for_panda(None, 10) if expect_mismatch: - with self.assertRaises(PandaProtocolMismatch): + with pytest.raises(PandaProtocolMismatch): Panda() else: with Panda() as p: @@ -108,9 +107,10 @@ class TestPandad(unittest.TestCase): assert 0.1 < (sum(ts)/len(ts)) < (0.5 if self.spi else 5.0) print("startup times", ts, sum(ts) / len(ts)) + def test_protocol_version_check(self): if not self.spi: - raise unittest.SkipTest("SPI test") + pytest.skip("SPI test") # flash old fw fn = os.path.join(HERE, "bootstub.panda_h7_spiv0.bin") self._flash_bootstub_and_test(fn, expect_mismatch=True) @@ -127,7 +127,3 @@ class TestPandad(unittest.TestCase): self._assert_no_panda() self._run_test(60) - - -if __name__ == "__main__": - unittest.main() diff --git a/selfdrive/car/ford/tests/test_ford.py b/selfdrive/car/ford/tests/test_ford.py index 5d7b2c3332..72dd69980a 100755 --- a/selfdrive/car/ford/tests/test_ford.py +++ b/selfdrive/car/ford/tests/test_ford.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 import random -import unittest from collections.abc import Iterable import capnp @@ -43,31 +42,31 @@ ECU_PART_NUMBER = { } -class TestFordFW(unittest.TestCase): +class TestFordFW: def test_fw_query_config(self): for (ecu, addr, subaddr) in FW_QUERY_CONFIG.extra_ecus: - self.assertIn(ecu, ECU_ADDRESSES, "Unknown ECU") - self.assertEqual(addr, ECU_ADDRESSES[ecu], "ECU address mismatch") - self.assertIsNone(subaddr, "Unexpected ECU subaddress") + assert ecu in ECU_ADDRESSES, "Unknown ECU" + assert addr == ECU_ADDRESSES[ecu], "ECU address mismatch" + assert subaddr is None, "Unexpected ECU subaddress" @parameterized.expand(FW_VERSIONS.items()) def test_fw_versions(self, car_model: str, fw_versions: dict[tuple[capnp.lib.capnp._EnumModule, int, int | None], Iterable[bytes]]): for (ecu, addr, subaddr), fws in fw_versions.items(): - self.assertIn(ecu, ECU_PART_NUMBER, "Unexpected ECU") - self.assertEqual(addr, ECU_ADDRESSES[ecu], "ECU address mismatch") - self.assertIsNone(subaddr, "Unexpected ECU subaddress") + assert ecu in ECU_PART_NUMBER, "Unexpected ECU" + assert addr == ECU_ADDRESSES[ecu], "ECU address mismatch" + assert subaddr is None, "Unexpected ECU subaddress" for fw in fws: - self.assertEqual(len(fw), 24, "Expected ECU response to be 24 bytes") + assert len(fw) == 24, "Expected ECU response to be 24 bytes" match = FW_PATTERN.match(fw) - self.assertIsNotNone(match, f"Unable to parse FW: {fw!r}") + assert match is not None, f"Unable to parse FW: {fw!r}" if match: part_number = match.group("part_number") - self.assertIn(part_number, ECU_PART_NUMBER[ecu], f"Unexpected part number for {fw!r}") + assert part_number in ECU_PART_NUMBER[ecu], f"Unexpected part number for {fw!r}" codes = get_platform_codes([fw]) - self.assertEqual(1, len(codes), f"Unable to parse FW: {fw!r}") + assert 1 == len(codes), f"Unable to parse FW: {fw!r}" @settings(max_examples=100) @given(data=st.data()) @@ -85,7 +84,7 @@ class TestFordFW(unittest.TestCase): b"PJ6T-14H102-ABJ\x00\x00\x00\x00\x00\x00\x00\x00\x00", b"LB5A-14C204-EAC\x00\x00\x00\x00\x00\x00\x00\x00\x00", ]) - self.assertEqual(results, {(b"X6A", b"J"), (b"Z6T", b"N"), (b"J6T", b"P"), (b"B5A", b"L")}) + assert results == {(b"X6A", b"J"), (b"Z6T", b"N"), (b"J6T", b"P"), (b"B5A", b"L")} def test_fuzzy_match(self): for platform, fw_by_addr in FW_VERSIONS.items(): @@ -100,7 +99,7 @@ class TestFordFW(unittest.TestCase): CP = car.CarParams.new_message(carFw=car_fw) matches = FW_QUERY_CONFIG.match_fw_to_car_fuzzy(build_fw_dict(CP.carFw), CP.carVin, FW_VERSIONS) - self.assertEqual(matches, {platform}) + assert matches == {platform} def test_match_fw_fuzzy(self): offline_fw = { @@ -132,18 +131,14 @@ class TestFordFW(unittest.TestCase): (0x706, None): {b"LB5T-14F397-XX\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, } candidates = FW_QUERY_CONFIG.match_fw_to_car_fuzzy(live_fw, '', {expected_fingerprint: offline_fw}) - self.assertEqual(candidates, {expected_fingerprint}) + assert candidates == {expected_fingerprint} # model year hint in between the range should match live_fw[(0x706, None)] = {b"MB5T-14F397-XX\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"} candidates = FW_QUERY_CONFIG.match_fw_to_car_fuzzy(live_fw, '', {expected_fingerprint: offline_fw,}) - self.assertEqual(candidates, {expected_fingerprint}) + assert candidates == {expected_fingerprint} # unseen model year hint should not match live_fw[(0x760, None)] = {b"M1MC-2D053-XX\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"} candidates = FW_QUERY_CONFIG.match_fw_to_car_fuzzy(live_fw, '', {expected_fingerprint: offline_fw}) - self.assertEqual(len(candidates), 0, "Should not match new model year hint") - - -if __name__ == "__main__": - unittest.main() + assert len(candidates) == 0, "Should not match new model year hint" diff --git a/selfdrive/car/gm/tests/test_gm.py b/selfdrive/car/gm/tests/test_gm.py index 01ec8533b8..389d0636c9 100755 --- a/selfdrive/car/gm/tests/test_gm.py +++ b/selfdrive/car/gm/tests/test_gm.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 from parameterized import parameterized -import unittest from openpilot.selfdrive.car.gm.fingerprints import FINGERPRINTS from openpilot.selfdrive.car.gm.values import CAMERA_ACC_CAR, GM_RX_OFFSET @@ -8,19 +7,15 @@ from openpilot.selfdrive.car.gm.values import CAMERA_ACC_CAR, GM_RX_OFFSET CAMERA_DIAGNOSTIC_ADDRESS = 0x24b -class TestGMFingerprint(unittest.TestCase): +class TestGMFingerprint: @parameterized.expand(FINGERPRINTS.items()) def test_can_fingerprints(self, car_model, fingerprints): - self.assertGreater(len(fingerprints), 0) + assert len(fingerprints) > 0 - self.assertTrue(all(len(finger) for finger in fingerprints)) + assert all(len(finger) for finger in fingerprints) # The camera can sometimes be communicating on startup if car_model in CAMERA_ACC_CAR: for finger in fingerprints: for required_addr in (CAMERA_DIAGNOSTIC_ADDRESS, CAMERA_DIAGNOSTIC_ADDRESS + GM_RX_OFFSET): - self.assertEqual(finger.get(required_addr), 8, required_addr) - - -if __name__ == "__main__": - unittest.main() + assert finger.get(required_addr) == 8, required_addr diff --git a/selfdrive/car/honda/tests/test_honda.py b/selfdrive/car/honda/tests/test_honda.py index 60d91b84a8..54d177d2ed 100755 --- a/selfdrive/car/honda/tests/test_honda.py +++ b/selfdrive/car/honda/tests/test_honda.py @@ -1,20 +1,15 @@ #!/usr/bin/env python3 import re -import unittest from openpilot.selfdrive.car.honda.fingerprints import FW_VERSIONS HONDA_FW_VERSION_RE = br"[A-Z0-9]{5}-[A-Z0-9]{3}(-|,)[A-Z0-9]{4}(\x00){2}$" -class TestHondaFingerprint(unittest.TestCase): +class TestHondaFingerprint: def test_fw_version_format(self): # Asserts all FW versions follow an expected format for fw_by_ecu in FW_VERSIONS.values(): for fws in fw_by_ecu.values(): for fw in fws: - self.assertTrue(re.match(HONDA_FW_VERSION_RE, fw) is not None, fw) - - -if __name__ == "__main__": - unittest.main() + assert re.match(HONDA_FW_VERSION_RE, fw) is not None, fw diff --git a/selfdrive/car/hyundai/tests/test_hyundai.py b/selfdrive/car/hyundai/tests/test_hyundai.py index 0753b372e1..db2110b0de 100755 --- a/selfdrive/car/hyundai/tests/test_hyundai.py +++ b/selfdrive/car/hyundai/tests/test_hyundai.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 from hypothesis import settings, given, strategies as st -import unittest + +import pytest from cereal import car from openpilot.selfdrive.car.fw_versions import build_fw_dict @@ -39,19 +40,19 @@ NO_DATES_PLATFORMS = { CANFD_EXPECTED_ECUS = {Ecu.fwdCamera, Ecu.fwdRadar} -class TestHyundaiFingerprint(unittest.TestCase): +class TestHyundaiFingerprint: def test_can_features(self): # Test no EV/HEV in any gear lists (should all use ELECT_GEAR) - self.assertEqual(set.union(*CAN_GEARS.values()) & (HYBRID_CAR | EV_CAR), set()) + assert set.union(*CAN_GEARS.values()) & (HYBRID_CAR | EV_CAR) == set() # Test CAN FD car not in CAN feature lists can_specific_feature_list = set.union(*CAN_GEARS.values(), *CHECKSUM.values(), LEGACY_SAFETY_MODE_CAR, UNSUPPORTED_LONGITUDINAL_CAR, CAMERA_SCC_CAR) for car_model in CANFD_CAR: - self.assertNotIn(car_model, can_specific_feature_list, "CAN FD car unexpectedly found in a CAN feature list") + assert car_model not in can_specific_feature_list, "CAN FD car unexpectedly found in a CAN feature list" def test_hybrid_ev_sets(self): - self.assertEqual(HYBRID_CAR & EV_CAR, set(), "Shared cars between hybrid and EV") - self.assertEqual(CANFD_CAR & HYBRID_CAR, set(), "Hard coding CAN FD cars as hybrid is no longer supported") + assert HYBRID_CAR & EV_CAR == set(), "Shared cars between hybrid and EV" + assert CANFD_CAR & HYBRID_CAR == set(), "Hard coding CAN FD cars as hybrid is no longer supported" def test_canfd_ecu_whitelist(self): # Asserts only expected Ecus can exist in database for CAN-FD cars @@ -59,34 +60,34 @@ class TestHyundaiFingerprint(unittest.TestCase): ecus = {fw[0] for fw in FW_VERSIONS[car_model].keys()} ecus_not_in_whitelist = ecus - CANFD_EXPECTED_ECUS ecu_strings = ", ".join([f"Ecu.{ECU_NAME[ecu]}" for ecu in ecus_not_in_whitelist]) - self.assertEqual(len(ecus_not_in_whitelist), 0, - f"{car_model}: Car model has unexpected ECUs: {ecu_strings}") + assert len(ecus_not_in_whitelist) == 0, \ + f"{car_model}: Car model has unexpected ECUs: {ecu_strings}" - def test_blacklisted_parts(self): + def test_blacklisted_parts(self, subtests): # Asserts no ECUs known to be shared across platforms exist in the database. # Tucson having Santa Cruz camera and EPS for example for car_model, ecus in FW_VERSIONS.items(): - with self.subTest(car_model=car_model.value): + with subtests.test(car_model=car_model.value): if car_model == CAR.HYUNDAI_SANTA_CRUZ_1ST_GEN: - raise unittest.SkipTest("Skip checking Santa Cruz for its parts") + pytest.skip("Skip checking Santa Cruz for its parts") for code, _ in get_platform_codes(ecus[(Ecu.fwdCamera, 0x7c4, None)]): if b"-" not in code: continue part = code.split(b"-")[1] - self.assertFalse(part.startswith(b'CW'), "Car has bad part number") + assert not part.startswith(b'CW'), "Car has bad part number" - def test_correct_ecu_response_database(self): + def test_correct_ecu_response_database(self, subtests): """ Assert standard responses for certain ECUs, since they can respond to multiple queries with different data """ expected_fw_prefix = HYUNDAI_VERSION_REQUEST_LONG[1:] for car_model, ecus in FW_VERSIONS.items(): - with self.subTest(car_model=car_model.value): + with subtests.test(car_model=car_model.value): for ecu, fws in ecus.items(): - self.assertTrue(all(fw.startswith(expected_fw_prefix) for fw in fws), - f"FW from unexpected request in database: {(ecu, fws)}") + assert all(fw.startswith(expected_fw_prefix) for fw in fws), \ + f"FW from unexpected request in database: {(ecu, fws)}" @settings(max_examples=100) @given(data=st.data()) @@ -96,10 +97,10 @@ class TestHyundaiFingerprint(unittest.TestCase): fws = data.draw(fw_strategy) get_platform_codes(fws) - def test_expected_platform_codes(self): + def test_expected_platform_codes(self, subtests): # Ensures we don't accidentally add multiple platform codes for a car unless it is intentional for car_model, ecus in FW_VERSIONS.items(): - with self.subTest(car_model=car_model.value): + with subtests.test(car_model=car_model.value): for ecu, fws in ecus.items(): if ecu[0] not in PLATFORM_CODE_ECUS: continue @@ -107,37 +108,37 @@ class TestHyundaiFingerprint(unittest.TestCase): # Third and fourth character are usually EV/hybrid identifiers codes = {code.split(b"-")[0][:2] for code, _ in get_platform_codes(fws)} if car_model == CAR.HYUNDAI_PALISADE: - self.assertEqual(codes, {b"LX", b"ON"}, f"Car has unexpected platform codes: {car_model} {codes}") + assert codes == {b"LX", b"ON"}, f"Car has unexpected platform codes: {car_model} {codes}" elif car_model == CAR.HYUNDAI_KONA_EV and ecu[0] == Ecu.fwdCamera: - self.assertEqual(codes, {b"OE", b"OS"}, f"Car has unexpected platform codes: {car_model} {codes}") + assert codes == {b"OE", b"OS"}, f"Car has unexpected platform codes: {car_model} {codes}" else: - self.assertEqual(len(codes), 1, f"Car has multiple platform codes: {car_model} {codes}") + assert len(codes) == 1, f"Car has multiple platform codes: {car_model} {codes}" # Tests for platform codes, part numbers, and FW dates which Hyundai will use to fuzzy # fingerprint in the absence of full FW matches: - def test_platform_code_ecus_available(self): + def test_platform_code_ecus_available(self, subtests): # TODO: add queries for these non-CAN FD cars to get EPS no_eps_platforms = CANFD_CAR | {CAR.KIA_SORENTO, CAR.KIA_OPTIMA_G4, CAR.KIA_OPTIMA_G4_FL, CAR.KIA_OPTIMA_H, CAR.KIA_OPTIMA_H_G4_FL, CAR.HYUNDAI_SONATA_LF, CAR.HYUNDAI_TUCSON, CAR.GENESIS_G90, CAR.GENESIS_G80, CAR.HYUNDAI_ELANTRA} # Asserts ECU keys essential for fuzzy fingerprinting are available on all platforms for car_model, ecus in FW_VERSIONS.items(): - with self.subTest(car_model=car_model.value): + with subtests.test(car_model=car_model.value): for platform_code_ecu in PLATFORM_CODE_ECUS: if platform_code_ecu in (Ecu.fwdRadar, Ecu.eps) and car_model == CAR.HYUNDAI_GENESIS: continue if platform_code_ecu == Ecu.eps and car_model in no_eps_platforms: continue - self.assertIn(platform_code_ecu, [e[0] for e in ecus]) + assert platform_code_ecu in [e[0] for e in ecus] - def test_fw_format(self): + def test_fw_format(self, subtests): # Asserts: # - every supported ECU FW version returns one platform code # - every supported ECU FW version has a part number # - expected parsing of ECU FW dates for car_model, ecus in FW_VERSIONS.items(): - with self.subTest(car_model=car_model.value): + with subtests.test(car_model=car_model.value): for ecu, fws in ecus.items(): if ecu[0] not in PLATFORM_CODE_ECUS: continue @@ -145,40 +146,40 @@ class TestHyundaiFingerprint(unittest.TestCase): codes = set() for fw in fws: result = get_platform_codes([fw]) - self.assertEqual(1, len(result), f"Unable to parse FW: {fw}") + assert 1 == len(result), f"Unable to parse FW: {fw}" codes |= result if ecu[0] not in DATE_FW_ECUS or car_model in NO_DATES_PLATFORMS: - self.assertTrue(all(date is None for _, date in codes)) + assert all(date is None for _, date in codes) else: - self.assertTrue(all(date is not None for _, date in codes)) + assert all(date is not None for _, date in codes) if car_model == CAR.HYUNDAI_GENESIS: - raise unittest.SkipTest("No part numbers for car model") + pytest.skip("No part numbers for car model") # Hyundai places the ECU part number in their FW versions, assert all parsable # Some examples of valid formats: b"56310-L0010", b"56310L0010", b"56310/M6300" - self.assertTrue(all(b"-" in code for code, _ in codes), - f"FW does not have part number: {fw}") + assert all(b"-" in code for code, _ in codes), \ + f"FW does not have part number: {fw}" def test_platform_codes_spot_check(self): # Asserts basic platform code parsing behavior for a few cases results = get_platform_codes([b"\xf1\x00DH LKAS 1.1 -150210"]) - self.assertEqual(results, {(b"DH", b"150210")}) + assert results == {(b"DH", b"150210")} # Some cameras and all radars do not have dates results = get_platform_codes([b"\xf1\x00AEhe SCC H-CUP 1.01 1.01 96400-G2000 "]) - self.assertEqual(results, {(b"AEhe-G2000", None)}) + assert results == {(b"AEhe-G2000", None)} results = get_platform_codes([b"\xf1\x00CV1_ RDR ----- 1.00 1.01 99110-CV000 "]) - self.assertEqual(results, {(b"CV1-CV000", None)}) + assert results == {(b"CV1-CV000", None)} results = get_platform_codes([ b"\xf1\x00DH LKAS 1.1 -150210", b"\xf1\x00AEhe SCC H-CUP 1.01 1.01 96400-G2000 ", b"\xf1\x00CV1_ RDR ----- 1.00 1.01 99110-CV000 ", ]) - self.assertEqual(results, {(b"DH", b"150210"), (b"AEhe-G2000", None), (b"CV1-CV000", None)}) + assert results == {(b"DH", b"150210"), (b"AEhe-G2000", None), (b"CV1-CV000", None)} results = get_platform_codes([ b"\xf1\x00LX2 MFC AT USA LHD 1.00 1.07 99211-S8100 220222", @@ -186,8 +187,8 @@ class TestHyundaiFingerprint(unittest.TestCase): b"\xf1\x00ON MFC AT USA LHD 1.00 1.01 99211-S9100 190405", b"\xf1\x00ON MFC AT USA LHD 1.00 1.03 99211-S9100 190720", ]) - self.assertEqual(results, {(b"LX2-S8100", b"220222"), (b"LX2-S8100", b"211103"), - (b"ON-S9100", b"190405"), (b"ON-S9100", b"190720")}) + assert results == {(b"LX2-S8100", b"220222"), (b"LX2-S8100", b"211103"), + (b"ON-S9100", b"190405"), (b"ON-S9100", b"190720")} def test_fuzzy_excluded_platforms(self): # Asserts a list of platforms that will not fuzzy fingerprint with platform codes due to them being shared. @@ -211,12 +212,8 @@ class TestHyundaiFingerprint(unittest.TestCase): CP = car.CarParams.new_message(carFw=car_fw) matches = FW_QUERY_CONFIG.match_fw_to_car_fuzzy(build_fw_dict(CP.carFw), CP.carVin, FW_VERSIONS) if len(matches) == 1: - self.assertEqual(list(matches)[0], platform) + assert list(matches)[0] == platform else: platforms_with_shared_codes.add(platform) - self.assertEqual(platforms_with_shared_codes, excluded_platforms) - - -if __name__ == "__main__": - unittest.main() + assert platforms_with_shared_codes == excluded_platforms diff --git a/selfdrive/car/subaru/tests/test_subaru.py b/selfdrive/car/subaru/tests/test_subaru.py index c8cdf66065..33040442b6 100644 --- a/selfdrive/car/subaru/tests/test_subaru.py +++ b/selfdrive/car/subaru/tests/test_subaru.py @@ -1,5 +1,4 @@ from cereal import car -import unittest from openpilot.selfdrive.car.subaru.fingerprints import FW_VERSIONS Ecu = car.CarParams.Ecu @@ -7,14 +6,11 @@ Ecu = car.CarParams.Ecu ECU_NAME = {v: k for k, v in Ecu.schema.enumerants.items()} -class TestSubaruFingerprint(unittest.TestCase): +class TestSubaruFingerprint: def test_fw_version_format(self): for platform, fws_per_ecu in FW_VERSIONS.items(): for (ecu, _, _), fws in fws_per_ecu.items(): fw_size = len(fws[0]) for fw in fws: - self.assertEqual(len(fw), fw_size, f"{platform} {ecu}: {len(fw)} {fw_size}") + assert len(fw) == fw_size, f"{platform} {ecu}: {len(fw)} {fw_size}" - -if __name__ == "__main__": - unittest.main() diff --git a/selfdrive/car/tests/test_can_fingerprint.py b/selfdrive/car/tests/test_can_fingerprint.py index 8df7007339..bb585d567f 100755 --- a/selfdrive/car/tests/test_can_fingerprint.py +++ b/selfdrive/car/tests/test_can_fingerprint.py @@ -1,13 +1,12 @@ #!/usr/bin/env python3 from parameterized import parameterized -import unittest from cereal import log, messaging from openpilot.selfdrive.car.car_helpers import FRAME_FINGERPRINT, can_fingerprint from openpilot.selfdrive.car.fingerprints import _FINGERPRINTS as FINGERPRINTS -class TestCanFingerprint(unittest.TestCase): +class TestCanFingerprint: @parameterized.expand(list(FINGERPRINTS.items())) def test_can_fingerprint(self, car_model, fingerprints): """Tests online fingerprinting function on offline fingerprints""" @@ -21,12 +20,12 @@ class TestCanFingerprint(unittest.TestCase): empty_can = messaging.new_message('can', 0) car_fingerprint, finger = can_fingerprint(lambda: next(fingerprint_iter, empty_can)) # noqa: B023 - self.assertEqual(car_fingerprint, car_model) - self.assertEqual(finger[0], fingerprint) - self.assertEqual(finger[1], fingerprint) - self.assertEqual(finger[2], {}) + assert car_fingerprint == car_model + assert finger[0] == fingerprint + assert finger[1] == fingerprint + assert finger[2] == {} - def test_timing(self): + def test_timing(self, subtests): # just pick any CAN fingerprinting car car_model = "CHEVROLET_BOLT_EUV" fingerprint = FINGERPRINTS[car_model][0] @@ -50,7 +49,7 @@ class TestCanFingerprint(unittest.TestCase): cases.append((FRAME_FINGERPRINT * 2, None, can)) for expected_frames, car_model, can in cases: - with self.subTest(expected_frames=expected_frames, car_model=car_model): + with subtests.test(expected_frames=expected_frames, car_model=car_model): frames = 0 def test(): @@ -59,9 +58,5 @@ class TestCanFingerprint(unittest.TestCase): return can # noqa: B023 car_fingerprint, _ = can_fingerprint(test) - self.assertEqual(car_fingerprint, car_model) - self.assertEqual(frames, expected_frames + 2) # TODO: fix extra frames - - -if __name__ == "__main__": - unittest.main() + assert car_fingerprint == car_model + assert frames == expected_frames + 2# TODO: fix extra frames diff --git a/selfdrive/car/tests/test_car_interfaces.py b/selfdrive/car/tests/test_car_interfaces.py index dfcd9b0527..4bbecd99fe 100755 --- a/selfdrive/car/tests/test_car_interfaces.py +++ b/selfdrive/car/tests/test_car_interfaces.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 import os import math -import unittest import hypothesis.strategies as st from hypothesis import Phase, given, settings import importlib @@ -45,7 +44,7 @@ def get_fuzzy_car_interface_args(draw: DrawType) -> dict: return params -class TestCarInterfaces(unittest.TestCase): +class TestCarInterfaces: # FIXME: Due to the lists used in carParams, Phase.target is very slow and will cause # many generated examples to overrun when max_examples > ~20, don't use it @parameterized.expand([(car,) for car in sorted(all_known_cars())]) @@ -63,28 +62,28 @@ class TestCarInterfaces(unittest.TestCase): assert car_params assert car_interface - self.assertGreater(car_params.mass, 1) - self.assertGreater(car_params.wheelbase, 0) + assert car_params.mass > 1 + assert car_params.wheelbase > 0 # centerToFront is center of gravity to front wheels, assert a reasonable range - self.assertTrue(car_params.wheelbase * 0.3 < car_params.centerToFront < car_params.wheelbase * 0.7) - self.assertGreater(car_params.maxLateralAccel, 0) + assert car_params.wheelbase * 0.3 < car_params.centerToFront < car_params.wheelbase * 0.7 + assert car_params.maxLateralAccel > 0 # Longitudinal sanity checks - self.assertEqual(len(car_params.longitudinalTuning.kpV), len(car_params.longitudinalTuning.kpBP)) - self.assertEqual(len(car_params.longitudinalTuning.kiV), len(car_params.longitudinalTuning.kiBP)) - self.assertEqual(len(car_params.longitudinalTuning.deadzoneV), len(car_params.longitudinalTuning.deadzoneBP)) + assert len(car_params.longitudinalTuning.kpV) == len(car_params.longitudinalTuning.kpBP) + assert len(car_params.longitudinalTuning.kiV) == len(car_params.longitudinalTuning.kiBP) + assert len(car_params.longitudinalTuning.deadzoneV) == len(car_params.longitudinalTuning.deadzoneBP) # Lateral sanity checks if car_params.steerControlType != car.CarParams.SteerControlType.angle: tune = car_params.lateralTuning if tune.which() == 'pid': - self.assertTrue(not math.isnan(tune.pid.kf) and tune.pid.kf > 0) - self.assertTrue(len(tune.pid.kpV) > 0 and len(tune.pid.kpV) == len(tune.pid.kpBP)) - self.assertTrue(len(tune.pid.kiV) > 0 and len(tune.pid.kiV) == len(tune.pid.kiBP)) + assert not math.isnan(tune.pid.kf) and tune.pid.kf > 0 + assert len(tune.pid.kpV) > 0 and len(tune.pid.kpV) == len(tune.pid.kpBP) + assert len(tune.pid.kiV) > 0 and len(tune.pid.kiV) == len(tune.pid.kiBP) elif tune.which() == 'torque': - self.assertTrue(not math.isnan(tune.torque.kf) and tune.torque.kf > 0) - self.assertTrue(not math.isnan(tune.torque.friction) and tune.torque.friction > 0) + assert not math.isnan(tune.torque.kf) and tune.torque.kf > 0 + assert not math.isnan(tune.torque.friction) and tune.torque.friction > 0 cc_msg = FuzzyGenerator.get_random_msg(data.draw, car.CarControl, real_floats=True) # Run car interface @@ -128,33 +127,29 @@ class TestCarInterfaces(unittest.TestCase): if not car_params.radarUnavailable and radar_interface.rcp is not None: cans = [messaging.new_message('can', 1).to_bytes() for _ in range(5)] rr = radar_interface.update(cans) - self.assertTrue(rr is None or len(rr.errors) > 0) + assert rr is None or len(rr.errors) > 0 def test_interface_attrs(self): """Asserts basic behavior of interface attribute getter""" num_brands = len(get_interface_attr('CAR')) - self.assertGreaterEqual(num_brands, 13) + assert num_brands >= 13 # Should return value for all brands when not combining, even if attribute doesn't exist ret = get_interface_attr('FAKE_ATTR') - self.assertEqual(len(ret), num_brands) + assert len(ret) == num_brands # Make sure we can combine dicts ret = get_interface_attr('DBC', combine_brands=True) - self.assertGreaterEqual(len(ret), 160) + assert len(ret) >= 160 # We don't support combining non-dicts ret = get_interface_attr('CAR', combine_brands=True) - self.assertEqual(len(ret), 0) + assert len(ret) == 0 # If brand has None value, it shouldn't return when ignore_none=True is specified none_brands = {b for b, v in get_interface_attr('FINGERPRINTS').items() if v is None} - self.assertGreaterEqual(len(none_brands), 1) + assert len(none_brands) >= 1 ret = get_interface_attr('FINGERPRINTS', ignore_none=True) none_brands_in_ret = none_brands.intersection(ret) - self.assertEqual(len(none_brands_in_ret), 0, f'Brands with None values in ignore_none=True result: {none_brands_in_ret}') - - -if __name__ == "__main__": - unittest.main() + assert len(none_brands_in_ret) == 0, f'Brands with None values in ignore_none=True result: {none_brands_in_ret}' diff --git a/selfdrive/car/tests/test_docs.py b/selfdrive/car/tests/test_docs.py index 143b402d5f..0ed95e18f2 100755 --- a/selfdrive/car/tests/test_docs.py +++ b/selfdrive/car/tests/test_docs.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 from collections import defaultdict import os +import pytest import re -import unittest from openpilot.common.basedir import BASEDIR from openpilot.selfdrive.car.car_helpers import interfaces @@ -14,9 +14,9 @@ from openpilot.selfdrive.debug.dump_car_docs import dump_car_docs from openpilot.selfdrive.debug.print_docs_diff import print_car_docs_diff -class TestCarDocs(unittest.TestCase): +class TestCarDocs: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.all_cars = get_all_car_docs() def test_generator(self): @@ -24,8 +24,7 @@ class TestCarDocs(unittest.TestCase): with open(CARS_MD_OUT) as f: current_cars_md = f.read() - self.assertEqual(generated_cars_md, current_cars_md, - "Run selfdrive/car/docs.py to update the compatibility documentation") + assert generated_cars_md == current_cars_md, "Run selfdrive/car/docs.py to update the compatibility documentation" def test_docs_diff(self): dump_path = os.path.join(BASEDIR, "selfdrive", "car", "tests", "cars_dump") @@ -33,65 +32,61 @@ class TestCarDocs(unittest.TestCase): print_car_docs_diff(dump_path) os.remove(dump_path) - def test_duplicate_years(self): + def test_duplicate_years(self, subtests): make_model_years = defaultdict(list) for car in self.all_cars: - with self.subTest(car_docs_name=car.name): + with subtests.test(car_docs_name=car.name): make_model = (car.make, car.model) for year in car.year_list: - self.assertNotIn(year, make_model_years[make_model], f"{car.name}: Duplicate model year") + assert year not in make_model_years[make_model], f"{car.name}: Duplicate model year" make_model_years[make_model].append(year) - def test_missing_car_docs(self): + def test_missing_car_docs(self, subtests): all_car_docs_platforms = [name for name, config in PLATFORMS.items()] for platform in sorted(interfaces.keys()): - with self.subTest(platform=platform): - self.assertTrue(platform in all_car_docs_platforms, f"Platform: {platform} doesn't have a CarDocs entry") + with subtests.test(platform=platform): + assert platform in all_car_docs_platforms, f"Platform: {platform} doesn't have a CarDocs entry" - def test_naming_conventions(self): + def test_naming_conventions(self, subtests): # Asserts market-standard car naming conventions by brand for car in self.all_cars: - with self.subTest(car=car): + with subtests.test(car=car.name): tokens = car.model.lower().split(" ") if car.car_name == "hyundai": - self.assertNotIn("phev", tokens, "Use `Plug-in Hybrid`") - self.assertNotIn("hev", tokens, "Use `Hybrid`") + assert "phev" not in tokens, "Use `Plug-in Hybrid`" + assert "hev" not in tokens, "Use `Hybrid`" if "plug-in hybrid" in car.model.lower(): - self.assertIn("Plug-in Hybrid", car.model, "Use correct capitalization") + assert "Plug-in Hybrid" in car.model, "Use correct capitalization" if car.make != "Kia": - self.assertNotIn("ev", tokens, "Use `Electric`") + assert "ev" not in tokens, "Use `Electric`" elif car.car_name == "toyota": if "rav4" in tokens: - self.assertIn("RAV4", car.model, "Use correct capitalization") + assert "RAV4" in car.model, "Use correct capitalization" - def test_torque_star(self): + def test_torque_star(self, subtests): # Asserts brand-specific assumptions around steering torque star for car in self.all_cars: - with self.subTest(car=car): + with subtests.test(car=car.name): # honda sanity check, it's the definition of a no torque star if car.car_fingerprint in (HONDA.HONDA_ACCORD, HONDA.HONDA_CIVIC, HONDA.HONDA_CRV, HONDA.HONDA_ODYSSEY, HONDA.HONDA_PILOT): - self.assertEqual(car.row[Column.STEERING_TORQUE], Star.EMPTY, f"{car.name} has full torque star") + assert car.row[Column.STEERING_TORQUE] == Star.EMPTY, f"{car.name} has full torque star" elif car.car_name in ("toyota", "hyundai"): - self.assertNotEqual(car.row[Column.STEERING_TORQUE], Star.EMPTY, f"{car.name} has no torque star") + assert car.row[Column.STEERING_TORQUE] != Star.EMPTY, f"{car.name} has no torque star" - def test_year_format(self): + def test_year_format(self, subtests): for car in self.all_cars: - with self.subTest(car=car): - self.assertIsNone(re.search(r"\d{4}-\d{4}", car.name), f"Format years correctly: {car.name}") + with subtests.test(car=car.name): + assert re.search(r"\d{4}-\d{4}", car.name) is None, f"Format years correctly: {car.name}" - def test_harnesses(self): + def test_harnesses(self, subtests): for car in self.all_cars: - with self.subTest(car=car): + with subtests.test(car=car.name): if car.name == "comma body": - raise unittest.SkipTest + pytest.skip() car_part_type = [p.part_type for p in car.car_parts.all_parts()] car_parts = list(car.car_parts.all_parts()) - self.assertTrue(len(car_parts) > 0, f"Need to specify car parts: {car.name}") - self.assertTrue(car_part_type.count(PartType.connector) == 1, f"Need to specify one harness connector: {car.name}") - self.assertTrue(car_part_type.count(PartType.mount) == 1, f"Need to specify one mount: {car.name}") - self.assertTrue(Cable.right_angle_obd_c_cable_1_5ft in car_parts, f"Need to specify a right angle OBD-C cable (1.5ft): {car.name}") - - -if __name__ == "__main__": - unittest.main() + assert len(car_parts) > 0, f"Need to specify car parts: {car.name}" + assert car_part_type.count(PartType.connector) == 1, f"Need to specify one harness connector: {car.name}" + assert car_part_type.count(PartType.mount) == 1, f"Need to specify one mount: {car.name}" + assert Cable.right_angle_obd_c_cable_1_5ft in car_parts, f"Need to specify a right angle OBD-C cable (1.5ft): {car.name}" diff --git a/selfdrive/car/tests/test_fw_fingerprint.py b/selfdrive/car/tests/test_fw_fingerprint.py index ed5edbef31..230e6f10e1 100755 --- a/selfdrive/car/tests/test_fw_fingerprint.py +++ b/selfdrive/car/tests/test_fw_fingerprint.py @@ -1,10 +1,9 @@ #!/usr/bin/env python3 +import pytest import random import time -import unittest from collections import defaultdict from parameterized import parameterized -from unittest import mock from cereal import car from openpilot.selfdrive.car.car_helpers import interfaces @@ -27,11 +26,11 @@ class FakeSocket: pass -class TestFwFingerprint(unittest.TestCase): +class TestFwFingerprint: def assertFingerprints(self, candidates, expected): candidates = list(candidates) - self.assertEqual(len(candidates), 1, f"got more than one candidate: {candidates}") - self.assertEqual(candidates[0], expected) + assert len(candidates) == 1, f"got more than one candidate: {candidates}" + assert candidates[0] == expected @parameterized.expand([(b, c, e[c], n) for b, e in VERSIONS.items() for c in e for n in (True, False)]) def test_exact_match(self, brand, car_model, ecus, test_non_essential): @@ -62,7 +61,7 @@ class TestFwFingerprint(unittest.TestCase): # Assert brand-specific fuzzy fingerprinting function doesn't disagree with standard fuzzy function config = FW_QUERY_CONFIGS[brand] if config.match_fw_to_car_fuzzy is None: - raise unittest.SkipTest("Brand does not implement custom fuzzy fingerprinting function") + pytest.skip("Brand does not implement custom fuzzy fingerprinting function") CP = car.CarParams.new_message() for _ in range(5): @@ -77,14 +76,14 @@ class TestFwFingerprint(unittest.TestCase): # If both have matches, they must agree if len(matches) == 1 and len(brand_matches) == 1: - self.assertEqual(matches, brand_matches) + assert matches == brand_matches @parameterized.expand([(b, c, e[c]) for b, e in VERSIONS.items() for c in e]) def test_fuzzy_match_ecu_count(self, brand, car_model, ecus): # Asserts that fuzzy matching does not count matching FW, but ECU address keys valid_ecus = [e for e in ecus if e[0] not in FUZZY_EXCLUDE_ECUS] if not len(valid_ecus): - raise unittest.SkipTest("Car model has no compatible ECUs for fuzzy matching") + pytest.skip("Car model has no compatible ECUs for fuzzy matching") fw = [] for ecu in valid_ecus: @@ -99,19 +98,19 @@ class TestFwFingerprint(unittest.TestCase): # Assert no match if there are not enough unique ECUs unique_ecus = {(f['address'], f['subAddress']) for f in fw} if len(unique_ecus) < 2: - self.assertEqual(len(matches), 0, car_model) + assert len(matches) == 0, car_model # There won't always be a match due to shared FW, but if there is it should be correct elif len(matches): self.assertFingerprints(matches, car_model) - def test_fw_version_lists(self): + def test_fw_version_lists(self, subtests): for car_model, ecus in FW_VERSIONS.items(): - with self.subTest(car_model=car_model.value): + with subtests.test(car_model=car_model.value): for ecu, ecu_fw in ecus.items(): - with self.subTest(ecu): + with subtests.test(ecu): duplicates = {fw for fw in ecu_fw if ecu_fw.count(fw) > 1} - self.assertFalse(len(duplicates), f'{car_model}: Duplicate FW versions: Ecu.{ECU_NAME[ecu[0]]}, {duplicates}') - self.assertGreater(len(ecu_fw), 0, f'{car_model}: No FW versions: Ecu.{ECU_NAME[ecu[0]]}') + assert not len(duplicates), f'{car_model}: Duplicate FW versions: Ecu.{ECU_NAME[ecu[0]]}, {duplicates}' + assert len(ecu_fw) > 0, f'{car_model}: No FW versions: Ecu.{ECU_NAME[ecu[0]]}' def test_all_addrs_map_to_one_ecu(self): for brand, cars in VERSIONS.items(): @@ -121,59 +120,59 @@ class TestFwFingerprint(unittest.TestCase): addr_to_ecu[(addr, sub_addr)].add(ecu_type) ecus_for_addr = addr_to_ecu[(addr, sub_addr)] ecu_strings = ", ".join([f'Ecu.{ECU_NAME[ecu]}' for ecu in ecus_for_addr]) - self.assertLessEqual(len(ecus_for_addr), 1, f"{brand} has multiple ECUs that map to one address: {ecu_strings} -> ({hex(addr)}, {sub_addr})") + assert len(ecus_for_addr) <= 1, f"{brand} has multiple ECUs that map to one address: {ecu_strings} -> ({hex(addr)}, {sub_addr})" - def test_data_collection_ecus(self): + def test_data_collection_ecus(self, subtests): # Asserts no extra ECUs are in the fingerprinting database for brand, config in FW_QUERY_CONFIGS.items(): for car_model, ecus in VERSIONS[brand].items(): bad_ecus = set(ecus).intersection(config.extra_ecus) - with self.subTest(car_model=car_model.value): - self.assertFalse(len(bad_ecus), f'{car_model}: Fingerprints contain ECUs added for data collection: {bad_ecus}') + with subtests.test(car_model=car_model.value): + assert not len(bad_ecus), f'{car_model}: Fingerprints contain ECUs added for data collection: {bad_ecus}' - def test_blacklisted_ecus(self): + def test_blacklisted_ecus(self, subtests): blacklisted_addrs = (0x7c4, 0x7d0) # includes A/C ecu and an unknown ecu for car_model, ecus in FW_VERSIONS.items(): - with self.subTest(car_model=car_model.value): + with subtests.test(car_model=car_model.value): CP = interfaces[car_model][0].get_non_essential_params(car_model) if CP.carName == 'subaru': for ecu in ecus.keys(): - self.assertNotIn(ecu[1], blacklisted_addrs, f'{car_model}: Blacklisted ecu: (Ecu.{ECU_NAME[ecu[0]]}, {hex(ecu[1])})') + assert ecu[1] not in blacklisted_addrs, f'{car_model}: Blacklisted ecu: (Ecu.{ECU_NAME[ecu[0]]}, {hex(ecu[1])})' elif CP.carName == "chrysler": # Some HD trucks have a combined TCM and ECM if CP.carFingerprint.startswith("RAM HD"): for ecu in ecus.keys(): - self.assertNotEqual(ecu[0], Ecu.transmission, f"{car_model}: Blacklisted ecu: (Ecu.{ECU_NAME[ecu[0]]}, {hex(ecu[1])})") + assert ecu[0] != Ecu.transmission, f"{car_model}: Blacklisted ecu: (Ecu.{ECU_NAME[ecu[0]]}, {hex(ecu[1])})" - def test_non_essential_ecus(self): + def test_non_essential_ecus(self, subtests): for brand, config in FW_QUERY_CONFIGS.items(): - with self.subTest(brand): + with subtests.test(brand): # These ECUs are already not in ESSENTIAL_ECUS which the fingerprint functions give a pass if missing unnecessary_non_essential_ecus = set(config.non_essential_ecus) - set(ESSENTIAL_ECUS) - self.assertEqual(unnecessary_non_essential_ecus, set(), "Declaring non-essential ECUs non-essential is not required: " + - f"{', '.join([f'Ecu.{ECU_NAME[ecu]}' for ecu in unnecessary_non_essential_ecus])}") + assert unnecessary_non_essential_ecus == set(), "Declaring non-essential ECUs non-essential is not required: " + \ + f"{', '.join([f'Ecu.{ECU_NAME[ecu]}' for ecu in unnecessary_non_essential_ecus])}" - def test_missing_versions_and_configs(self): + def test_missing_versions_and_configs(self, subtests): brand_versions = set(VERSIONS.keys()) brand_configs = set(FW_QUERY_CONFIGS.keys()) if len(brand_configs - brand_versions): - with self.subTest(): - self.fail(f"Brands do not implement FW_VERSIONS: {brand_configs - brand_versions}") + with subtests.test(): + pytest.fail(f"Brands do not implement FW_VERSIONS: {brand_configs - brand_versions}") if len(brand_versions - brand_configs): - with self.subTest(): - self.fail(f"Brands do not implement FW_QUERY_CONFIG: {brand_versions - brand_configs}") + with subtests.test(): + pytest.fail(f"Brands do not implement FW_QUERY_CONFIG: {brand_versions - brand_configs}") # Ensure each brand has at least 1 ECU to query, and extra ECU retrieval for brand, config in FW_QUERY_CONFIGS.items(): - self.assertEqual(len(config.get_all_ecus({}, include_extra_ecus=False)), 0) - self.assertEqual(config.get_all_ecus({}), set(config.extra_ecus)) - self.assertGreater(len(config.get_all_ecus(VERSIONS[brand])), 0) + assert len(config.get_all_ecus({}, include_extra_ecus=False)) == 0 + assert config.get_all_ecus({}) == set(config.extra_ecus) + assert len(config.get_all_ecus(VERSIONS[brand])) > 0 - def test_fw_request_ecu_whitelist(self): + def test_fw_request_ecu_whitelist(self, subtests): for brand, config in FW_QUERY_CONFIGS.items(): - with self.subTest(brand=brand): + with subtests.test(brand=brand): whitelisted_ecus = {ecu for r in config.requests for ecu in r.whitelist_ecus} brand_ecus = {fw[0] for car_fw in VERSIONS[brand].values() for fw in car_fw} brand_ecus |= {ecu[0] for ecu in config.extra_ecus} @@ -182,30 +181,30 @@ class TestFwFingerprint(unittest.TestCase): ecus_not_whitelisted = brand_ecus - whitelisted_ecus ecu_strings = ", ".join([f'Ecu.{ECU_NAME[ecu]}' for ecu in ecus_not_whitelisted]) - self.assertFalse(len(whitelisted_ecus) and len(ecus_not_whitelisted), - f'{brand.title()}: ECUs not in any FW query whitelists: {ecu_strings}') + assert not (len(whitelisted_ecus) and len(ecus_not_whitelisted)), \ + f'{brand.title()}: ECUs not in any FW query whitelists: {ecu_strings}' - def test_fw_requests(self): + def test_fw_requests(self, subtests): # Asserts equal length request and response lists for brand, config in FW_QUERY_CONFIGS.items(): - with self.subTest(brand=brand): + with subtests.test(brand=brand): for request_obj in config.requests: - self.assertEqual(len(request_obj.request), len(request_obj.response)) + assert len(request_obj.request) == len(request_obj.response) # No request on the OBD port (bus 1, multiplexed) should be run on an aux panda - self.assertFalse(request_obj.auxiliary and request_obj.bus == 1 and request_obj.obd_multiplexing, - f"{brand.title()}: OBD multiplexed request is marked auxiliary: {request_obj}") + assert not (request_obj.auxiliary and request_obj.bus == 1 and request_obj.obd_multiplexing), \ + f"{brand.title()}: OBD multiplexed request is marked auxiliary: {request_obj}" def test_brand_ecu_matches(self): empty_response = {brand: set() for brand in FW_QUERY_CONFIGS} - self.assertEqual(get_brand_ecu_matches(set()), empty_response) + assert get_brand_ecu_matches(set()) == empty_response # we ignore bus expected_response = empty_response | {'toyota': {(0x750, 0xf)}} - self.assertEqual(get_brand_ecu_matches({(0x758, 0xf, 99)}), expected_response) + assert get_brand_ecu_matches({(0x758, 0xf, 99)}) == expected_response -class TestFwFingerprintTiming(unittest.TestCase): +class TestFwFingerprintTiming: N: int = 5 TOL: float = 0.05 @@ -223,26 +222,26 @@ class TestFwFingerprintTiming(unittest.TestCase): self.total_time += timeout return {} - def _benchmark_brand(self, brand, num_pandas): + def _benchmark_brand(self, brand, num_pandas, mocker): fake_socket = FakeSocket() self.total_time = 0 - with (mock.patch("openpilot.selfdrive.car.fw_versions.set_obd_multiplexing", self.fake_set_obd_multiplexing), - mock.patch("openpilot.selfdrive.car.isotp_parallel_query.IsoTpParallelQuery.get_data", self.fake_get_data)): - for _ in range(self.N): - # Treat each brand as the most likely (aka, the first) brand with OBD multiplexing initially on - self.current_obd_multiplexing = True + mocker.patch("openpilot.selfdrive.car.fw_versions.set_obd_multiplexing", self.fake_set_obd_multiplexing) + mocker.patch("openpilot.selfdrive.car.isotp_parallel_query.IsoTpParallelQuery.get_data", self.fake_get_data) + for _ in range(self.N): + # Treat each brand as the most likely (aka, the first) brand with OBD multiplexing initially on + self.current_obd_multiplexing = True - t = time.perf_counter() - get_fw_versions(fake_socket, fake_socket, brand, num_pandas=num_pandas) - self.total_time += time.perf_counter() - t + t = time.perf_counter() + get_fw_versions(fake_socket, fake_socket, brand, num_pandas=num_pandas) + self.total_time += time.perf_counter() - t return self.total_time / self.N def _assert_timing(self, avg_time, ref_time): - self.assertLess(avg_time, ref_time + self.TOL) - self.assertGreater(avg_time, ref_time - self.TOL, "Performance seems to have improved, update test refs.") + assert avg_time < ref_time + self.TOL + assert avg_time > ref_time - self.TOL, "Performance seems to have improved, update test refs." - def test_startup_timing(self): + def test_startup_timing(self, subtests, mocker): # Tests worse-case VIN query time and typical present ECU query time vin_ref_times = {'worst': 1.4, 'best': 0.7} # best assumes we go through all queries to get a match present_ecu_ref_time = 0.45 @@ -253,24 +252,24 @@ class TestFwFingerprintTiming(unittest.TestCase): fake_socket = FakeSocket() self.total_time = 0.0 - with (mock.patch("openpilot.selfdrive.car.fw_versions.set_obd_multiplexing", self.fake_set_obd_multiplexing), - mock.patch("openpilot.selfdrive.car.fw_versions.get_ecu_addrs", fake_get_ecu_addrs)): - for _ in range(self.N): - self.current_obd_multiplexing = True - get_present_ecus(fake_socket, fake_socket, num_pandas=2) + mocker.patch("openpilot.selfdrive.car.fw_versions.set_obd_multiplexing", self.fake_set_obd_multiplexing) + mocker.patch("openpilot.selfdrive.car.fw_versions.get_ecu_addrs", fake_get_ecu_addrs) + for _ in range(self.N): + self.current_obd_multiplexing = True + get_present_ecus(fake_socket, fake_socket, num_pandas=2) self._assert_timing(self.total_time / self.N, present_ecu_ref_time) print(f'get_present_ecus, query time={self.total_time / self.N} seconds') for name, args in (('worst', {}), ('best', {'retry': 1})): - with self.subTest(name=name): + with subtests.test(name=name): self.total_time = 0.0 - with (mock.patch("openpilot.selfdrive.car.isotp_parallel_query.IsoTpParallelQuery.get_data", self.fake_get_data)): - for _ in range(self.N): - get_vin(fake_socket, fake_socket, (0, 1), **args) + mocker.patch("openpilot.selfdrive.car.isotp_parallel_query.IsoTpParallelQuery.get_data", self.fake_get_data) + for _ in range(self.N): + get_vin(fake_socket, fake_socket, (0, 1), **args) self._assert_timing(self.total_time / self.N, vin_ref_times[name]) print(f'get_vin {name} case, query time={self.total_time / self.N} seconds') - def test_fw_query_timing(self): + def test_fw_query_timing(self, subtests, mocker): total_ref_time = {1: 7.2, 2: 7.8} brand_ref_times = { 1: { @@ -297,8 +296,8 @@ class TestFwFingerprintTiming(unittest.TestCase): total_times = {1: 0.0, 2: 0.0} for num_pandas in (1, 2): for brand, config in FW_QUERY_CONFIGS.items(): - with self.subTest(brand=brand, num_pandas=num_pandas): - avg_time = self._benchmark_brand(brand, num_pandas) + with subtests.test(brand=brand, num_pandas=num_pandas): + avg_time = self._benchmark_brand(brand, num_pandas, mocker) total_times[num_pandas] += avg_time avg_time = round(avg_time, 2) @@ -311,11 +310,7 @@ class TestFwFingerprintTiming(unittest.TestCase): print(f'{brand=}, {num_pandas=}, {len(config.requests)=}, avg FW query time={avg_time} seconds') for num_pandas in (1, 2): - with self.subTest(brand='all_brands', num_pandas=num_pandas): + with subtests.test(brand='all_brands', num_pandas=num_pandas): total_time = round(total_times[num_pandas], 2) self._assert_timing(total_time, total_ref_time[num_pandas]) print(f'all brands, total FW query time={total_time} seconds') - - -if __name__ == "__main__": - unittest.main() diff --git a/selfdrive/car/tests/test_lateral_limits.py b/selfdrive/car/tests/test_lateral_limits.py index e5cfd972bd..a478bc601a 100755 --- a/selfdrive/car/tests/test_lateral_limits.py +++ b/selfdrive/car/tests/test_lateral_limits.py @@ -2,8 +2,7 @@ from collections import defaultdict import importlib from parameterized import parameterized_class -import sys -import unittest +import pytest from openpilot.common.realtime import DT_CTRL from openpilot.selfdrive.car.car_helpers import interfaces @@ -25,23 +24,23 @@ car_model_jerks: defaultdict[str, dict[str, float]] = defaultdict(dict) @parameterized_class('car_model', [(c,) for c in sorted(CAR_MODELS)]) -class TestLateralLimits(unittest.TestCase): +class TestLateralLimits: car_model: str @classmethod - def setUpClass(cls): + def setup_class(cls): CarInterface, _, _ = interfaces[cls.car_model] CP = CarInterface.get_non_essential_params(cls.car_model) if CP.dashcamOnly: - raise unittest.SkipTest("Platform is behind dashcamOnly") + pytest.skip("Platform is behind dashcamOnly") # TODO: test all platforms if CP.lateralTuning.which() != 'torque': - raise unittest.SkipTest + pytest.skip() if CP.notCar: - raise unittest.SkipTest + pytest.skip() CarControllerParams = importlib.import_module(f'selfdrive.car.{CP.carName}.values').CarControllerParams cls.control_params = CarControllerParams(CP) @@ -66,26 +65,8 @@ class TestLateralLimits(unittest.TestCase): def test_jerk_limits(self): up_jerk, down_jerk = self.calculate_0_5s_jerk(self.control_params, self.torque_params) car_model_jerks[self.car_model] = {"up_jerk": up_jerk, "down_jerk": down_jerk} - self.assertLessEqual(up_jerk, MAX_LAT_JERK_UP + MAX_LAT_JERK_UP_TOLERANCE) - self.assertLessEqual(down_jerk, MAX_LAT_JERK_DOWN) + assert up_jerk <= MAX_LAT_JERK_UP + MAX_LAT_JERK_UP_TOLERANCE + assert down_jerk <= MAX_LAT_JERK_DOWN def test_max_lateral_accel(self): - self.assertLessEqual(self.torque_params["MAX_LAT_ACCEL_MEASURED"], MAX_LAT_ACCEL) - - -if __name__ == "__main__": - result = unittest.main(exit=False) - - print(f"\n\n---- Lateral limit report ({len(CAR_MODELS)} cars) ----\n") - - max_car_model_len = max([len(car_model) for car_model in car_model_jerks]) - for car_model, _jerks in sorted(car_model_jerks.items(), key=lambda i: i[1]['up_jerk'], reverse=True): - violation = _jerks["up_jerk"] > MAX_LAT_JERK_UP + MAX_LAT_JERK_UP_TOLERANCE or \ - _jerks["down_jerk"] > MAX_LAT_JERK_DOWN - violation_str = " - VIOLATION" if violation else "" - - print(f"{car_model:{max_car_model_len}} - up jerk: {round(_jerks['up_jerk'], 2):5} " + - f"m/s^3, down jerk: {round(_jerks['down_jerk'], 2):5} m/s^3{violation_str}") - - # exit with test result - sys.exit(not result.result.wasSuccessful()) + assert self.torque_params["MAX_LAT_ACCEL_MEASURED"] <= MAX_LAT_ACCEL diff --git a/selfdrive/car/tests/test_models.py b/selfdrive/car/tests/test_models.py index dc3d4256a2..026693bdce 100755 --- a/selfdrive/car/tests/test_models.py +++ b/selfdrive/car/tests/test_models.py @@ -4,7 +4,7 @@ import os import importlib import pytest import random -import unittest +import unittest # noqa: TID251 from collections import defaultdict, Counter import hypothesis.strategies as st from hypothesis import Phase, given, settings diff --git a/selfdrive/car/tests/test_platform_configs.py b/selfdrive/car/tests/test_platform_configs.py index 523c331b9e..217189255e 100755 --- a/selfdrive/car/tests/test_platform_configs.py +++ b/selfdrive/car/tests/test_platform_configs.py @@ -1,25 +1,19 @@ #!/usr/bin/env python3 -import unittest - from openpilot.selfdrive.car.values import PLATFORMS -class TestPlatformConfigs(unittest.TestCase): - def test_configs(self): +class TestPlatformConfigs: + def test_configs(self, subtests): for name, platform in PLATFORMS.items(): - with self.subTest(platform=str(platform)): - self.assertTrue(platform.config._frozen) + with subtests.test(platform=str(platform)): + assert platform.config._frozen if platform != "MOCK": - self.assertIn("pt", platform.config.dbc_dict) - self.assertTrue(len(platform.config.platform_str) > 0) - - self.assertEqual(name, platform.config.platform_str) - - self.assertIsNotNone(platform.config.specs) + assert "pt" in platform.config.dbc_dict + assert len(platform.config.platform_str) > 0 + assert name == platform.config.platform_str -if __name__ == "__main__": - unittest.main() + assert platform.config.specs is not None diff --git a/selfdrive/car/toyota/tests/test_toyota.py b/selfdrive/car/toyota/tests/test_toyota.py index e2a9b46eb4..ef49e00551 100755 --- a/selfdrive/car/toyota/tests/test_toyota.py +++ b/selfdrive/car/toyota/tests/test_toyota.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 from hypothesis import given, settings, strategies as st -import unittest from cereal import car from openpilot.selfdrive.car.fw_versions import build_fw_dict @@ -17,59 +16,58 @@ def check_fw_version(fw_version: bytes) -> bool: return b'?' not in fw_version -class TestToyotaInterfaces(unittest.TestCase): +class TestToyotaInterfaces: def test_car_sets(self): - self.assertTrue(len(ANGLE_CONTROL_CAR - TSS2_CAR) == 0) - self.assertTrue(len(RADAR_ACC_CAR - TSS2_CAR) == 0) + assert len(ANGLE_CONTROL_CAR - TSS2_CAR) == 0 + assert len(RADAR_ACC_CAR - TSS2_CAR) == 0 def test_lta_platforms(self): # At this time, only RAV4 2023 is expected to use LTA/angle control - self.assertEqual(ANGLE_CONTROL_CAR, {CAR.TOYOTA_RAV4_TSS2_2023}) + assert ANGLE_CONTROL_CAR == {CAR.TOYOTA_RAV4_TSS2_2023} def test_tss2_dbc(self): # We make some assumptions about TSS2 platforms, # like looking up certain signals only in this DBC for car_model, dbc in DBC.items(): if car_model in TSS2_CAR: - self.assertEqual(dbc["pt"], "toyota_nodsu_pt_generated") + assert dbc["pt"] == "toyota_nodsu_pt_generated" - def test_essential_ecus(self): + def test_essential_ecus(self, subtests): # Asserts standard ECUs exist for each platform common_ecus = {Ecu.fwdRadar, Ecu.fwdCamera} for car_model, ecus in FW_VERSIONS.items(): - with self.subTest(car_model=car_model.value): + with subtests.test(car_model=car_model.value): present_ecus = {ecu[0] for ecu in ecus} missing_ecus = common_ecus - present_ecus - self.assertEqual(len(missing_ecus), 0) + assert len(missing_ecus) == 0 # Some exceptions for other common ECUs if car_model not in (CAR.TOYOTA_ALPHARD_TSS2,): - self.assertIn(Ecu.abs, present_ecus) + assert Ecu.abs in present_ecus if car_model not in (CAR.TOYOTA_MIRAI,): - self.assertIn(Ecu.engine, present_ecus) + assert Ecu.engine in present_ecus if car_model not in (CAR.TOYOTA_PRIUS_V, CAR.LEXUS_CTH): - self.assertIn(Ecu.eps, present_ecus) + assert Ecu.eps in present_ecus -class TestToyotaFingerprint(unittest.TestCase): - def test_non_essential_ecus(self): +class TestToyotaFingerprint: + def test_non_essential_ecus(self, subtests): # Ensures only the cars that have multiple engine ECUs are in the engine non-essential ECU list for car_model, ecus in FW_VERSIONS.items(): - with self.subTest(car_model=car_model.value): + with subtests.test(car_model=car_model.value): engine_ecus = {ecu for ecu in ecus if ecu[0] == Ecu.engine} - self.assertEqual(len(engine_ecus) > 1, - car_model in FW_QUERY_CONFIG.non_essential_ecus[Ecu.engine], - f"Car model unexpectedly {'not ' if len(engine_ecus) > 1 else ''}in non-essential list") + assert (len(engine_ecus) > 1) == (car_model in FW_QUERY_CONFIG.non_essential_ecus[Ecu.engine]), \ + f"Car model unexpectedly {'not ' if len(engine_ecus) > 1 else ''}in non-essential list" - def test_valid_fw_versions(self): + def test_valid_fw_versions(self, subtests): # Asserts all FW versions are valid for car_model, ecus in FW_VERSIONS.items(): - with self.subTest(car_model=car_model.value): + with subtests.test(car_model=car_model.value): for fws in ecus.values(): for fw in fws: - self.assertTrue(check_fw_version(fw), fw) + assert check_fw_version(fw), fw # Tests for part numbers, platform codes, and sub-versions which Toyota will use to fuzzy # fingerprint in the absence of full FW matches: @@ -80,25 +78,25 @@ class TestToyotaFingerprint(unittest.TestCase): fws = data.draw(fw_strategy) get_platform_codes(fws) - def test_platform_code_ecus_available(self): + def test_platform_code_ecus_available(self, subtests): # Asserts ECU keys essential for fuzzy fingerprinting are available on all platforms for car_model, ecus in FW_VERSIONS.items(): - with self.subTest(car_model=car_model.value): + with subtests.test(car_model=car_model.value): for platform_code_ecu in PLATFORM_CODE_ECUS: if platform_code_ecu == Ecu.eps and car_model in (CAR.TOYOTA_PRIUS_V, CAR.LEXUS_CTH,): continue if platform_code_ecu == Ecu.abs and car_model in (CAR.TOYOTA_ALPHARD_TSS2,): continue - self.assertIn(platform_code_ecu, [e[0] for e in ecus]) + assert platform_code_ecu in [e[0] for e in ecus] - def test_fw_format(self): + def test_fw_format(self, subtests): # Asserts: # - every supported ECU FW version returns one platform code # - every supported ECU FW version has a part number # - expected parsing of ECU sub-versions for car_model, ecus in FW_VERSIONS.items(): - with self.subTest(car_model=car_model.value): + with subtests.test(car_model=car_model.value): for ecu, fws in ecus.items(): if ecu[0] not in PLATFORM_CODE_ECUS: continue @@ -107,15 +105,14 @@ class TestToyotaFingerprint(unittest.TestCase): for fw in fws: result = get_platform_codes([fw]) # Check only one platform code and sub-version - self.assertEqual(1, len(result), f"Unable to parse FW: {fw}") - self.assertEqual(1, len(list(result.values())[0]), f"Unable to parse FW: {fw}") + assert 1 == len(result), f"Unable to parse FW: {fw}" + assert 1 == len(list(result.values())[0]), f"Unable to parse FW: {fw}" codes |= result # Toyota places the ECU part number in their FW versions, assert all parsable # Note that there is only one unique part number per ECU across the fleet, so this # is not important for identification, just a sanity check. - self.assertTrue(all(code.count(b"-") > 1 for code in codes), - f"FW does not have part number: {fw} {codes}") + assert all(code.count(b"-") > 1 for code in codes), f"FW does not have part number: {fw} {codes}" def test_platform_codes_spot_check(self): # Asserts basic platform code parsing behavior for a few cases @@ -125,20 +122,20 @@ class TestToyotaFingerprint(unittest.TestCase): b"F152607110\x00\x00\x00\x00\x00\x00", b"F152607180\x00\x00\x00\x00\x00\x00", ]) - self.assertEqual(results, {b"F1526-07-1": {b"10", b"40", b"71", b"80"}}) + assert results == {b"F1526-07-1": {b"10", b"40", b"71", b"80"}} results = get_platform_codes([ b"\x028646F4104100\x00\x00\x00\x008646G5301200\x00\x00\x00\x00", b"\x028646F4104100\x00\x00\x00\x008646G3304000\x00\x00\x00\x00", ]) - self.assertEqual(results, {b"8646F-41-04": {b"100"}}) + assert results == {b"8646F-41-04": {b"100"}} # Short version has no part number results = get_platform_codes([ b"\x0235870000\x00\x00\x00\x00\x00\x00\x00\x00A0202000\x00\x00\x00\x00\x00\x00\x00\x00", b"\x0235883000\x00\x00\x00\x00\x00\x00\x00\x00A0202000\x00\x00\x00\x00\x00\x00\x00\x00", ]) - self.assertEqual(results, {b"58-70": {b"000"}, b"58-83": {b"000"}}) + assert results == {b"58-70": {b"000"}, b"58-83": {b"000"}} results = get_platform_codes([ b"F152607110\x00\x00\x00\x00\x00\x00", @@ -146,7 +143,7 @@ class TestToyotaFingerprint(unittest.TestCase): b"\x028646F4104100\x00\x00\x00\x008646G5301200\x00\x00\x00\x00", b"\x0235879000\x00\x00\x00\x00\x00\x00\x00\x00A4701000\x00\x00\x00\x00\x00\x00\x00\x00", ]) - self.assertEqual(results, {b"F1526-07-1": {b"10", b"40"}, b"8646F-41-04": {b"100"}, b"58-79": {b"000"}}) + assert results == {b"F1526-07-1": {b"10", b"40"}, b"8646F-41-04": {b"100"}, b"58-79": {b"000"}} def test_fuzzy_excluded_platforms(self): # Asserts a list of platforms that will not fuzzy fingerprint with platform codes due to them being shared. @@ -162,13 +159,9 @@ class TestToyotaFingerprint(unittest.TestCase): CP = car.CarParams.new_message(carFw=car_fw) matches = FW_QUERY_CONFIG.match_fw_to_car_fuzzy(build_fw_dict(CP.carFw), CP.carVin, FW_VERSIONS) if len(matches) == 1: - self.assertEqual(list(matches)[0], platform) + assert list(matches)[0] == platform else: # If a platform has multiple matches, add it and its matches platforms_with_shared_codes |= {str(platform), *matches} - self.assertEqual(platforms_with_shared_codes, FUZZY_EXCLUDED_PLATFORMS, (len(platforms_with_shared_codes), len(FW_VERSIONS))) - - -if __name__ == "__main__": - unittest.main() + assert platforms_with_shared_codes == FUZZY_EXCLUDED_PLATFORMS, (len(platforms_with_shared_codes), len(FW_VERSIONS)) diff --git a/selfdrive/car/volkswagen/tests/test_volkswagen.py b/selfdrive/car/volkswagen/tests/test_volkswagen.py index 17331203bb..561d28b9fb 100755 --- a/selfdrive/car/volkswagen/tests/test_volkswagen.py +++ b/selfdrive/car/volkswagen/tests/test_volkswagen.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 import random import re -import unittest from cereal import car from openpilot.selfdrive.car.volkswagen.values import CAR, FW_QUERY_CONFIG, WMI @@ -14,35 +13,35 @@ CHASSIS_CODE_PATTERN = re.compile('[A-Z0-9]{2}') SPARE_PART_FW_PATTERN = re.compile(b'\xf1\x87(?P[0-9][0-9A-Z]{2})(?P[0-9][0-9A-Z][0-9])(?P[0-9A-Z]{2}[0-9])([A-Z0-9]| )') -class TestVolkswagenPlatformConfigs(unittest.TestCase): - def test_spare_part_fw_pattern(self): +class TestVolkswagenPlatformConfigs: + def test_spare_part_fw_pattern(self, subtests): # Relied on for determining if a FW is likely VW for platform, ecus in FW_VERSIONS.items(): - with self.subTest(platform=platform): + with subtests.test(platform=platform.value): for fws in ecus.values(): for fw in fws: - self.assertNotEqual(SPARE_PART_FW_PATTERN.match(fw), None, f"Bad FW: {fw}") + assert SPARE_PART_FW_PATTERN.match(fw) is not None, f"Bad FW: {fw}" - def test_chassis_codes(self): + def test_chassis_codes(self, subtests): for platform in CAR: - with self.subTest(platform=platform): - self.assertTrue(len(platform.config.wmis) > 0, "WMIs not set") - self.assertTrue(len(platform.config.chassis_codes) > 0, "Chassis codes not set") - self.assertTrue(all(CHASSIS_CODE_PATTERN.match(cc) for cc in - platform.config.chassis_codes), "Bad chassis codes") + with subtests.test(platform=platform.value): + assert len(platform.config.wmis) > 0, "WMIs not set" + assert len(platform.config.chassis_codes) > 0, "Chassis codes not set" + assert all(CHASSIS_CODE_PATTERN.match(cc) for cc in \ + platform.config.chassis_codes), "Bad chassis codes" # No two platforms should share chassis codes for comp in CAR: if platform == comp: continue - self.assertEqual(set(), platform.config.chassis_codes & comp.config.chassis_codes, - f"Shared chassis codes: {comp}") + assert set() == platform.config.chassis_codes & comp.config.chassis_codes, \ + f"Shared chassis codes: {comp}" - def test_custom_fuzzy_fingerprinting(self): + def test_custom_fuzzy_fingerprinting(self, subtests): all_radar_fw = list({fw for ecus in FW_VERSIONS.values() for fw in ecus[Ecu.fwdRadar, 0x757, None]}) for platform in CAR: - with self.subTest(platform=platform): + with subtests.test(platform=platform.name): for wmi in WMI: for chassis_code in platform.config.chassis_codes | {"00"}: vin = ["0"] * 17 @@ -59,8 +58,4 @@ class TestVolkswagenPlatformConfigs(unittest.TestCase): matches = FW_QUERY_CONFIG.match_fw_to_car_fuzzy(live_fws, vin, FW_VERSIONS) expected_matches = {platform} if should_match else set() - self.assertEqual(expected_matches, matches, "Bad match") - - -if __name__ == "__main__": - unittest.main() + assert expected_matches == matches, "Bad match" diff --git a/selfdrive/controls/lib/tests/test_alertmanager.py b/selfdrive/controls/lib/tests/test_alertmanager.py index dbd42858a0..c234cc49d6 100755 --- a/selfdrive/controls/lib/tests/test_alertmanager.py +++ b/selfdrive/controls/lib/tests/test_alertmanager.py @@ -1,12 +1,11 @@ #!/usr/bin/env python3 import random -import unittest from openpilot.selfdrive.controls.lib.events import Alert, EVENTS from openpilot.selfdrive.controls.lib.alertmanager import AlertManager -class TestAlertManager(unittest.TestCase): +class TestAlertManager: def test_duration(self): """ @@ -38,8 +37,4 @@ class TestAlertManager(unittest.TestCase): shown = current_alert is not None should_show = frame <= show_duration - self.assertEqual(shown, should_show, msg=f"{frame=} {add_duration=} {duration=}") - - -if __name__ == "__main__": - unittest.main() + assert shown == should_show, f"{frame=} {add_duration=} {duration=}" diff --git a/selfdrive/controls/lib/tests/test_latcontrol.py b/selfdrive/controls/lib/tests/test_latcontrol.py index 838023af72..b731bbd950 100755 --- a/selfdrive/controls/lib/tests/test_latcontrol.py +++ b/selfdrive/controls/lib/tests/test_latcontrol.py @@ -1,6 +1,4 @@ #!/usr/bin/env python3 -import unittest - from parameterized import parameterized from cereal import car, log @@ -15,7 +13,7 @@ from openpilot.selfdrive.controls.lib.vehicle_model import VehicleModel from openpilot.common.mock.generators import generate_liveLocationKalman -class TestLatControl(unittest.TestCase): +class TestLatControl: @parameterized.expand([(HONDA.HONDA_CIVIC, LatControlPID), (TOYOTA.TOYOTA_RAV4, LatControlTorque), (NISSAN.NISSAN_LEAF, LatControlAngle)]) def test_saturation(self, car_name, controller): @@ -36,8 +34,4 @@ class TestLatControl(unittest.TestCase): for _ in range(1000): _, _, lac_log = controller.update(True, CS, VM, params, False, 1, llk) - self.assertTrue(lac_log.saturated) - - -if __name__ == "__main__": - unittest.main() + assert lac_log.saturated diff --git a/selfdrive/controls/lib/tests/test_vehicle_model.py b/selfdrive/controls/lib/tests/test_vehicle_model.py index c3997afdf3..2efcf2fbbd 100755 --- a/selfdrive/controls/lib/tests/test_vehicle_model.py +++ b/selfdrive/controls/lib/tests/test_vehicle_model.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 +import pytest import math -import unittest import numpy as np from control import StateSpace @@ -10,8 +10,8 @@ from openpilot.selfdrive.car.honda.values import CAR from openpilot.selfdrive.controls.lib.vehicle_model import VehicleModel, dyn_ss_sol, create_dyn_state_matrices -class TestVehicleModel(unittest.TestCase): - def setUp(self): +class TestVehicleModel: + def setup_method(self): CP = CarInterface.get_non_essential_params(CAR.HONDA_CIVIC) self.VM = VehicleModel(CP) @@ -23,7 +23,7 @@ class TestVehicleModel(unittest.TestCase): yr = self.VM.yaw_rate(sa, u, roll) new_sa = self.VM.get_steer_from_yaw_rate(yr, u, roll) - self.assertAlmostEqual(sa, new_sa) + assert sa == pytest.approx(new_sa) def test_dyn_ss_sol_against_yaw_rate(self): """Verify that the yaw_rate helper function matches the results @@ -38,7 +38,7 @@ class TestVehicleModel(unittest.TestCase): # Compute yaw rate using direct computations yr2 = self.VM.yaw_rate(sa, u, roll) - self.assertAlmostEqual(float(yr1[0]), yr2) + assert float(yr1[0]) == pytest.approx(yr2) def test_syn_ss_sol_simulate(self): """Verifies that dyn_ss_sol matches a simulation""" @@ -63,8 +63,3 @@ class TestVehicleModel(unittest.TestCase): x2 = dyn_ss_sol(sa, u, roll, self.VM) np.testing.assert_almost_equal(x1, x2, decimal=3) - - - -if __name__ == "__main__": - unittest.main() diff --git a/selfdrive/controls/tests/test_alerts.py b/selfdrive/controls/tests/test_alerts.py index 7b4fba0dce..e29a6322ab 100755 --- a/selfdrive/controls/tests/test_alerts.py +++ b/selfdrive/controls/tests/test_alerts.py @@ -2,7 +2,6 @@ import copy import json import os -import unittest import random from PIL import Image, ImageDraw, ImageFont @@ -25,10 +24,10 @@ for event_types in EVENTS.values(): ALERTS.append(alert) -class TestAlerts(unittest.TestCase): +class TestAlerts: @classmethod - def setUpClass(cls): + def setup_class(cls): with open(OFFROAD_ALERTS_PATH) as f: cls.offroad_alerts = json.loads(f.read()) @@ -45,7 +44,7 @@ class TestAlerts(unittest.TestCase): for name, e in events.items(): if not name.endswith("DEPRECATED"): fail_msg = "%s @%d not in EVENTS" % (name, e) - self.assertTrue(e in EVENTS.keys(), msg=fail_msg) + assert e in EVENTS.keys(), fail_msg # ensure alert text doesn't exceed allowed width def test_alert_text_length(self): @@ -80,7 +79,7 @@ class TestAlerts(unittest.TestCase): left, _, right, _ = draw.textbbox((0, 0), txt, font) width = right - left msg = f"type: {alert.alert_type} msg: {txt}" - self.assertLessEqual(width, max_text_width, msg=msg) + assert width <= max_text_width, msg def test_alert_sanity_check(self): for event_types in EVENTS.values(): @@ -90,21 +89,21 @@ class TestAlerts(unittest.TestCase): continue if a.alert_size == AlertSize.none: - self.assertEqual(len(a.alert_text_1), 0) - self.assertEqual(len(a.alert_text_2), 0) + assert len(a.alert_text_1) == 0 + assert len(a.alert_text_2) == 0 elif a.alert_size == AlertSize.small: - self.assertGreater(len(a.alert_text_1), 0) - self.assertEqual(len(a.alert_text_2), 0) + assert len(a.alert_text_1) > 0 + assert len(a.alert_text_2) == 0 elif a.alert_size == AlertSize.mid: - self.assertGreater(len(a.alert_text_1), 0) - self.assertGreater(len(a.alert_text_2), 0) + assert len(a.alert_text_1) > 0 + assert len(a.alert_text_2) > 0 else: - self.assertGreater(len(a.alert_text_1), 0) + assert len(a.alert_text_1) > 0 - self.assertGreaterEqual(a.duration, 0.) + assert a.duration >= 0. if event_type not in (ET.WARNING, ET.PERMANENT, ET.PRE_ENABLE): - self.assertEqual(a.creation_delay, 0.) + assert a.creation_delay == 0. def test_offroad_alerts(self): params = Params() @@ -113,11 +112,11 @@ class TestAlerts(unittest.TestCase): alert = copy.copy(self.offroad_alerts[a]) set_offroad_alert(a, True) alert['extra'] = '' - self.assertTrue(json.dumps(alert) == params.get(a, encoding='utf8')) + assert json.dumps(alert) == params.get(a, encoding='utf8') # then delete it set_offroad_alert(a, False) - self.assertTrue(params.get(a) is None) + assert params.get(a) is None def test_offroad_alerts_extra_text(self): params = Params() @@ -128,8 +127,5 @@ class TestAlerts(unittest.TestCase): set_offroad_alert(a, True, extra_text="a"*i) written_alert = json.loads(params.get(a, encoding='utf8')) - self.assertTrue("a"*i == written_alert['extra']) - self.assertTrue(alert["text"] == written_alert['text']) - -if __name__ == "__main__": - unittest.main() + assert "a"*i == written_alert['extra'] + assert alert["text"] == written_alert['text'] diff --git a/selfdrive/controls/tests/test_cruise_speed.py b/selfdrive/controls/tests/test_cruise_speed.py index c46d03ad1e..6c46285e81 100755 --- a/selfdrive/controls/tests/test_cruise_speed.py +++ b/selfdrive/controls/tests/test_cruise_speed.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 +import pytest import itertools import numpy as np -import unittest from parameterized import parameterized_class from cereal import log @@ -36,19 +36,19 @@ def run_cruise_simulation(cruise, e2e, personality, t_end=20.): [True, False], # e2e log.LongitudinalPersonality.schema.enumerants, # personality [5,35])) # speed -class TestCruiseSpeed(unittest.TestCase): +class TestCruiseSpeed: def test_cruise_speed(self): print(f'Testing {self.speed} m/s') cruise_speed = float(self.speed) simulation_steady_state = run_cruise_simulation(cruise_speed, self.e2e, self.personality) - self.assertAlmostEqual(simulation_steady_state, cruise_speed, delta=.01, msg=f'Did not reach {self.speed} m/s') + assert simulation_steady_state == pytest.approx(cruise_speed, abs=.01), f'Did not reach {self.speed} m/s' # TODO: test pcmCruise @parameterized_class(('pcm_cruise',), [(False,)]) -class TestVCruiseHelper(unittest.TestCase): - def setUp(self): +class TestVCruiseHelper: + def setup_method(self): self.CP = car.CarParams(pcmCruise=self.pcm_cruise) self.v_cruise_helper = VCruiseHelper(self.CP) self.reset_cruise_speed_state() @@ -75,7 +75,7 @@ class TestVCruiseHelper(unittest.TestCase): CS.buttonEvents = [ButtonEvent(type=btn, pressed=pressed)] self.v_cruise_helper.update_v_cruise(CS, enabled=True, is_metric=False) - self.assertEqual(pressed, self.v_cruise_helper.v_cruise_kph == self.v_cruise_helper.v_cruise_kph_last) + assert pressed == (self.v_cruise_helper.v_cruise_kph == self.v_cruise_helper.v_cruise_kph_last) def test_rising_edge_enable(self): """ @@ -94,7 +94,7 @@ class TestVCruiseHelper(unittest.TestCase): self.enable(V_CRUISE_INITIAL * CV.KPH_TO_MS, False) # Expected diff on enabling. Speed should not change on falling edge of pressed - self.assertEqual(not pressed, self.v_cruise_helper.v_cruise_kph == self.v_cruise_helper.v_cruise_kph_last) + assert not pressed == self.v_cruise_helper.v_cruise_kph == self.v_cruise_helper.v_cruise_kph_last def test_resume_in_standstill(self): """ @@ -111,7 +111,7 @@ class TestVCruiseHelper(unittest.TestCase): # speed should only update if not at standstill and button falling edge should_equal = standstill or pressed - self.assertEqual(should_equal, self.v_cruise_helper.v_cruise_kph == self.v_cruise_helper.v_cruise_kph_last) + assert should_equal == (self.v_cruise_helper.v_cruise_kph == self.v_cruise_helper.v_cruise_kph_last) def test_set_gas_pressed(self): """ @@ -135,7 +135,7 @@ class TestVCruiseHelper(unittest.TestCase): # TODO: fix skipping first run due to enabled on rising edge exception if v_ego == 0.0: continue - self.assertEqual(expected_v_cruise_kph, self.v_cruise_helper.v_cruise_kph) + assert expected_v_cruise_kph == self.v_cruise_helper.v_cruise_kph def test_initialize_v_cruise(self): """ @@ -145,12 +145,8 @@ class TestVCruiseHelper(unittest.TestCase): for experimental_mode in (True, False): for v_ego in np.linspace(0, 100, 101): self.reset_cruise_speed_state() - self.assertFalse(self.v_cruise_helper.v_cruise_initialized) + assert not self.v_cruise_helper.v_cruise_initialized self.enable(float(v_ego), experimental_mode) - self.assertTrue(V_CRUISE_INITIAL <= self.v_cruise_helper.v_cruise_kph <= V_CRUISE_MAX) - self.assertTrue(self.v_cruise_helper.v_cruise_initialized) - - -if __name__ == "__main__": - unittest.main() + assert V_CRUISE_INITIAL <= self.v_cruise_helper.v_cruise_kph <= V_CRUISE_MAX + assert self.v_cruise_helper.v_cruise_initialized diff --git a/selfdrive/controls/tests/test_following_distance.py b/selfdrive/controls/tests/test_following_distance.py index f58e6383c4..5d60911805 100755 --- a/selfdrive/controls/tests/test_following_distance.py +++ b/selfdrive/controls/tests/test_following_distance.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -import unittest +import pytest import itertools from parameterized import parameterized_class @@ -32,14 +32,10 @@ def run_following_distance_simulation(v_lead, t_end=100.0, e2e=False, personalit log.LongitudinalPersonality.standard, log.LongitudinalPersonality.aggressive], [0,10,35])) # speed -class TestFollowingDistance(unittest.TestCase): +class TestFollowingDistance(): def test_following_distance(self): v_lead = float(self.speed) simulation_steady_state = run_following_distance_simulation(v_lead, e2e=self.e2e, personality=self.personality) correct_steady_state = desired_follow_distance(v_lead, v_lead, get_T_FOLLOW(self.personality)) err_ratio = 0.2 if self.e2e else 0.1 - self.assertAlmostEqual(simulation_steady_state, correct_steady_state, delta=(err_ratio * correct_steady_state + .5)) - - -if __name__ == "__main__": - unittest.main() + assert simulation_steady_state == pytest.approx(correct_steady_state, abs=err_ratio * correct_steady_state + .5) diff --git a/selfdrive/controls/tests/test_lateral_mpc.py b/selfdrive/controls/tests/test_lateral_mpc.py index 8c09f46b60..3aa0fd1bce 100644 --- a/selfdrive/controls/tests/test_lateral_mpc.py +++ b/selfdrive/controls/tests/test_lateral_mpc.py @@ -1,4 +1,4 @@ -import unittest +import pytest import numpy as np from openpilot.selfdrive.controls.lib.lateral_mpc_lib.lat_mpc import LateralMpc from openpilot.selfdrive.controls.lib.drive_helpers import CAR_ROTATION_RADIUS @@ -27,20 +27,20 @@ def run_mpc(lat_mpc=None, v_ref=30., x_init=0., y_init=0., psi_init=0., curvatur return lat_mpc.x_sol -class TestLateralMpc(unittest.TestCase): +class TestLateralMpc: def _assert_null(self, sol, curvature=1e-6): for i in range(len(sol)): - self.assertAlmostEqual(sol[0,i,1], 0., delta=curvature) - self.assertAlmostEqual(sol[0,i,2], 0., delta=curvature) - self.assertAlmostEqual(sol[0,i,3], 0., delta=curvature) + assert sol[0,i,1] == pytest.approx(0, abs=curvature) + assert sol[0,i,2] == pytest.approx(0, abs=curvature) + assert sol[0,i,3] == pytest.approx(0, abs=curvature) def _assert_simmetry(self, sol, curvature=1e-6): for i in range(len(sol)): - self.assertAlmostEqual(sol[0,i,1], -sol[1,i,1], delta=curvature) - self.assertAlmostEqual(sol[0,i,2], -sol[1,i,2], delta=curvature) - self.assertAlmostEqual(sol[0,i,3], -sol[1,i,3], delta=curvature) - self.assertAlmostEqual(sol[0,i,0], sol[1,i,0], delta=curvature) + assert sol[0,i,1] == pytest.approx(-sol[1,i,1], abs=curvature) + assert sol[0,i,2] == pytest.approx(-sol[1,i,2], abs=curvature) + assert sol[0,i,3] == pytest.approx(-sol[1,i,3], abs=curvature) + assert sol[0,i,0] == pytest.approx(sol[1,i,0], abs=curvature) def test_straight(self): sol = run_mpc() @@ -74,7 +74,7 @@ class TestLateralMpc(unittest.TestCase): y_init = 1. sol = run_mpc(y_init=y_init) for y in list(sol[:,1]): - self.assertGreaterEqual(y_init, abs(y)) + assert y_init >= abs(y) def test_switch_convergence(self): lat_mpc = LateralMpc() @@ -83,7 +83,3 @@ class TestLateralMpc(unittest.TestCase): sol = run_mpc(lat_mpc=lat_mpc, poly_shift=-3.0, v_ref=7.0) left_psi_deg = np.degrees(sol[:,2]) np.testing.assert_almost_equal(right_psi_deg, -left_psi_deg, decimal=3) - - -if __name__ == "__main__": - unittest.main() diff --git a/selfdrive/controls/tests/test_leads.py b/selfdrive/controls/tests/test_leads.py index a06387a087..f4e97725ff 100755 --- a/selfdrive/controls/tests/test_leads.py +++ b/selfdrive/controls/tests/test_leads.py @@ -1,13 +1,11 @@ #!/usr/bin/env python3 -import unittest - import cereal.messaging as messaging from openpilot.selfdrive.test.process_replay import replay_process_with_name from openpilot.selfdrive.car.toyota.values import CAR as TOYOTA -class TestLeads(unittest.TestCase): +class TestLeads: def test_radar_fault(self): # if there's no radar-related can traffic, radard should either not respond or respond with an error # this is tightly coupled with underlying car radar_interface implementation, but it's a good sanity check @@ -29,8 +27,4 @@ class TestLeads(unittest.TestCase): states = [m for m in out if m.which() == "radarState"] failures = [not state.valid and len(state.radarState.radarErrors) for state in states] - self.assertTrue(len(states) == 0 or all(failures)) - - -if __name__ == "__main__": - unittest.main() + assert len(states) == 0 or all(failures) diff --git a/selfdrive/controls/tests/test_state_machine.py b/selfdrive/controls/tests/test_state_machine.py index d49111752d..b92724ce43 100755 --- a/selfdrive/controls/tests/test_state_machine.py +++ b/selfdrive/controls/tests/test_state_machine.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -import unittest from cereal import car, log from openpilot.common.realtime import DT_CTRL @@ -28,9 +27,9 @@ def make_event(event_types): return 0 -class TestStateMachine(unittest.TestCase): +class TestStateMachine: - def setUp(self): + def setup_method(self): CarInterface, CarController, CarState = interfaces[MOCK.MOCK] CP = CarInterface.get_non_essential_params(MOCK.MOCK) CI = CarInterface(CP, CarController, CarState) @@ -46,7 +45,7 @@ class TestStateMachine(unittest.TestCase): self.controlsd.events.add(make_event([et, ET.IMMEDIATE_DISABLE])) self.controlsd.state = state self.controlsd.state_transition(self.CS) - self.assertEqual(State.disabled, self.controlsd.state) + assert State.disabled == self.controlsd.state self.controlsd.events.clear() def test_user_disable(self): @@ -55,7 +54,7 @@ class TestStateMachine(unittest.TestCase): self.controlsd.events.add(make_event([et, ET.USER_DISABLE])) self.controlsd.state = state self.controlsd.state_transition(self.CS) - self.assertEqual(State.disabled, self.controlsd.state) + assert State.disabled == self.controlsd.state self.controlsd.events.clear() def test_soft_disable(self): @@ -66,7 +65,7 @@ class TestStateMachine(unittest.TestCase): self.controlsd.events.add(make_event([et, ET.SOFT_DISABLE])) self.controlsd.state = state self.controlsd.state_transition(self.CS) - self.assertEqual(self.controlsd.state, State.disabled if state == State.disabled else State.softDisabling) + assert self.controlsd.state == State.disabled if state == State.disabled else State.softDisabling self.controlsd.events.clear() def test_soft_disable_timer(self): @@ -74,17 +73,17 @@ class TestStateMachine(unittest.TestCase): self.controlsd.events.add(make_event([ET.SOFT_DISABLE])) self.controlsd.state_transition(self.CS) for _ in range(int(SOFT_DISABLE_TIME / DT_CTRL)): - self.assertEqual(self.controlsd.state, State.softDisabling) + assert self.controlsd.state == State.softDisabling self.controlsd.state_transition(self.CS) - self.assertEqual(self.controlsd.state, State.disabled) + assert self.controlsd.state == State.disabled def test_no_entry(self): # Make sure noEntry keeps us disabled for et in ENABLE_EVENT_TYPES: self.controlsd.events.add(make_event([ET.NO_ENTRY, et])) self.controlsd.state_transition(self.CS) - self.assertEqual(self.controlsd.state, State.disabled) + assert self.controlsd.state == State.disabled self.controlsd.events.clear() def test_no_entry_pre_enable(self): @@ -92,7 +91,7 @@ class TestStateMachine(unittest.TestCase): self.controlsd.state = State.preEnabled self.controlsd.events.add(make_event([ET.NO_ENTRY, ET.PRE_ENABLE])) self.controlsd.state_transition(self.CS) - self.assertEqual(self.controlsd.state, State.preEnabled) + assert self.controlsd.state == State.preEnabled def test_maintain_states(self): # Given current state's event type, we should maintain state @@ -101,9 +100,5 @@ class TestStateMachine(unittest.TestCase): self.controlsd.state = state self.controlsd.events.add(make_event([et])) self.controlsd.state_transition(self.CS) - self.assertEqual(self.controlsd.state, state) + assert self.controlsd.state == state self.controlsd.events.clear() - - -if __name__ == "__main__": - unittest.main() diff --git a/selfdrive/locationd/test/test_calibrationd.py b/selfdrive/locationd/test/test_calibrationd.py index e2db094397..598d5d2d5f 100755 --- a/selfdrive/locationd/test/test_calibrationd.py +++ b/selfdrive/locationd/test/test_calibrationd.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 import random -import unittest import numpy as np @@ -31,7 +30,7 @@ def process_messages(c, cam_odo_calib, cycles, [0.0, 0.0, HEIGHT_INIT.item()], [cam_odo_height_std, cam_odo_height_std, cam_odo_height_std]) -class TestCalibrationd(unittest.TestCase): +class TestCalibrationd: def test_read_saved_params(self): msg = messaging.new_message('liveCalibration') @@ -43,13 +42,13 @@ class TestCalibrationd(unittest.TestCase): np.testing.assert_allclose(msg.liveCalibration.rpyCalib, c.rpy) np.testing.assert_allclose(msg.liveCalibration.height, c.height) - self.assertEqual(msg.liveCalibration.validBlocks, c.valid_blocks) + assert msg.liveCalibration.validBlocks == c.valid_blocks def test_calibration_basics(self): c = Calibrator(param_put=False) process_messages(c, [0.0, 0.0, 0.0], BLOCK_SIZE * INPUTS_WANTED) - self.assertEqual(c.valid_blocks, INPUTS_WANTED) + assert c.valid_blocks == INPUTS_WANTED np.testing.assert_allclose(c.rpy, np.zeros(3)) np.testing.assert_allclose(c.height, HEIGHT_INIT) c.reset() @@ -59,7 +58,7 @@ class TestCalibrationd(unittest.TestCase): c = Calibrator(param_put=False) process_messages(c, [0.0, 0.0, 0.0], BLOCK_SIZE * INPUTS_WANTED, cam_odo_speed=MIN_SPEED_FILTER - 1) process_messages(c, [0.0, 0.0, 0.0], BLOCK_SIZE * INPUTS_WANTED, carstate_speed=MIN_SPEED_FILTER - 1) - self.assertEqual(c.valid_blocks, 0) + assert c.valid_blocks == 0 np.testing.assert_allclose(c.rpy, np.zeros(3)) np.testing.assert_allclose(c.height, HEIGHT_INIT) @@ -67,7 +66,7 @@ class TestCalibrationd(unittest.TestCase): def test_calibration_yaw_rate_reject(self): c = Calibrator(param_put=False) process_messages(c, [0.0, 0.0, 0.0], BLOCK_SIZE * INPUTS_WANTED, cam_odo_yr=MAX_YAW_RATE_FILTER) - self.assertEqual(c.valid_blocks, 0) + assert c.valid_blocks == 0 np.testing.assert_allclose(c.rpy, np.zeros(3)) np.testing.assert_allclose(c.height, HEIGHT_INIT) @@ -75,43 +74,40 @@ class TestCalibrationd(unittest.TestCase): def test_calibration_speed_std_reject(self): c = Calibrator(param_put=False) process_messages(c, [0.0, 0.0, 0.0], BLOCK_SIZE * INPUTS_WANTED, cam_odo_speed_std=1e3) - self.assertEqual(c.valid_blocks, INPUTS_NEEDED) + assert c.valid_blocks == INPUTS_NEEDED np.testing.assert_allclose(c.rpy, np.zeros(3)) def test_calibration_speed_std_height_reject(self): c = Calibrator(param_put=False) process_messages(c, [0.0, 0.0, 0.0], BLOCK_SIZE * INPUTS_WANTED, cam_odo_height_std=1e3) - self.assertEqual(c.valid_blocks, INPUTS_NEEDED) + assert c.valid_blocks == INPUTS_NEEDED np.testing.assert_allclose(c.rpy, np.zeros(3)) def test_calibration_auto_reset(self): c = Calibrator(param_put=False) process_messages(c, [0.0, 0.0, 0.0], BLOCK_SIZE * INPUTS_NEEDED) - self.assertEqual(c.valid_blocks, INPUTS_NEEDED) + assert c.valid_blocks == INPUTS_NEEDED np.testing.assert_allclose(c.rpy, [0.0, 0.0, 0.0], atol=1e-3) process_messages(c, [0.0, MAX_ALLOWED_PITCH_SPREAD*0.9, MAX_ALLOWED_YAW_SPREAD*0.9], BLOCK_SIZE + 10) - self.assertEqual(c.valid_blocks, INPUTS_NEEDED + 1) - self.assertEqual(c.cal_status, log.LiveCalibrationData.Status.calibrated) + assert c.valid_blocks == INPUTS_NEEDED + 1 + assert c.cal_status == log.LiveCalibrationData.Status.calibrated c = Calibrator(param_put=False) process_messages(c, [0.0, 0.0, 0.0], BLOCK_SIZE * INPUTS_NEEDED) - self.assertEqual(c.valid_blocks, INPUTS_NEEDED) + assert c.valid_blocks == INPUTS_NEEDED np.testing.assert_allclose(c.rpy, [0.0, 0.0, 0.0]) process_messages(c, [0.0, MAX_ALLOWED_PITCH_SPREAD*1.1, 0.0], BLOCK_SIZE + 10) - self.assertEqual(c.valid_blocks, 1) - self.assertEqual(c.cal_status, log.LiveCalibrationData.Status.recalibrating) + assert c.valid_blocks == 1 + assert c.cal_status == log.LiveCalibrationData.Status.recalibrating np.testing.assert_allclose(c.rpy, [0.0, MAX_ALLOWED_PITCH_SPREAD*1.1, 0.0], atol=1e-2) c = Calibrator(param_put=False) process_messages(c, [0.0, 0.0, 0.0], BLOCK_SIZE * INPUTS_NEEDED) - self.assertEqual(c.valid_blocks, INPUTS_NEEDED) + assert c.valid_blocks == INPUTS_NEEDED np.testing.assert_allclose(c.rpy, [0.0, 0.0, 0.0]) process_messages(c, [0.0, 0.0, MAX_ALLOWED_YAW_SPREAD*1.1], BLOCK_SIZE + 10) - self.assertEqual(c.valid_blocks, 1) - self.assertEqual(c.cal_status, log.LiveCalibrationData.Status.recalibrating) + assert c.valid_blocks == 1 + assert c.cal_status == log.LiveCalibrationData.Status.recalibrating np.testing.assert_allclose(c.rpy, [0.0, 0.0, MAX_ALLOWED_YAW_SPREAD*1.1], atol=1e-2) - -if __name__ == "__main__": - unittest.main() diff --git a/selfdrive/locationd/test/test_locationd.py b/selfdrive/locationd/test/test_locationd.py index cd032dbaf0..bac824bada 100755 --- a/selfdrive/locationd/test/test_locationd.py +++ b/selfdrive/locationd/test/test_locationd.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 +import pytest import json import random -import unittest import time import capnp @@ -13,13 +13,11 @@ from openpilot.common.transformations.coordinates import ecef2geodetic from openpilot.selfdrive.manager.process_config import managed_processes -class TestLocationdProc(unittest.TestCase): +class TestLocationdProc: LLD_MSGS = ['gpsLocationExternal', 'cameraOdometry', 'carState', 'liveCalibration', 'accelerometer', 'gyroscope', 'magnetometer'] - def setUp(self): - random.seed(123489234) - + def setup_method(self): self.pm = messaging.PubMaster(self.LLD_MSGS) self.params = Params() @@ -27,7 +25,7 @@ class TestLocationdProc(unittest.TestCase): managed_processes['locationd'].prepare() managed_processes['locationd'].start() - def tearDown(self): + def teardown_method(self): managed_processes['locationd'].stop() def get_msg(self, name, t): @@ -65,6 +63,7 @@ class TestLocationdProc(unittest.TestCase): return msg def test_params_gps(self): + random.seed(123489234) self.params.remove('LastGPSPosition') self.x = -2710700 + (random.random() * 1e5) @@ -86,10 +85,6 @@ class TestLocationdProc(unittest.TestCase): time.sleep(1) # wait for async params write lastGPS = json.loads(self.params.get('LastGPSPosition')) - self.assertAlmostEqual(lastGPS['latitude'], self.lat, places=3) - self.assertAlmostEqual(lastGPS['longitude'], self.lon, places=3) - self.assertAlmostEqual(lastGPS['altitude'], self.alt, places=3) - - -if __name__ == "__main__": - unittest.main() + assert lastGPS['latitude'] == pytest.approx(self.lat, abs=0.001) + assert lastGPS['longitude'] == pytest.approx(self.lon, abs=0.001) + assert lastGPS['altitude'] == pytest.approx(self.alt, abs=0.001) diff --git a/selfdrive/locationd/test/test_locationd_scenarios.py b/selfdrive/locationd/test/test_locationd_scenarios.py index 3fdd47275f..be95c6fffb 100755 --- a/selfdrive/locationd/test/test_locationd_scenarios.py +++ b/selfdrive/locationd/test/test_locationd_scenarios.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 import pytest -import unittest import numpy as np from collections import defaultdict from enum import Enum @@ -99,7 +98,7 @@ def run_scenarios(scenario, logs): @pytest.mark.xdist_group("test_locationd_scenarios") @pytest.mark.shared_download_cache -class TestLocationdScenarios(unittest.TestCase): +class TestLocationdScenarios: """ Test locationd with different scenarios. In all these scenarios, we expect the following: - locationd kalman filter should never go unstable (we care mostly about yaw_rate, roll, gpsOK, inputsOK, sensorsOK) @@ -107,7 +106,7 @@ class TestLocationdScenarios(unittest.TestCase): """ @classmethod - def setUpClass(cls): + def setup_class(cls): cls.logs = migrate_all(LogReader(TEST_ROUTE)) def test_base(self): @@ -118,8 +117,8 @@ class TestLocationdScenarios(unittest.TestCase): - roll: unchanged """ orig_data, replayed_data = run_scenarios(Scenario.BASE, self.logs) - self.assertTrue(np.allclose(orig_data['yaw_rate'], replayed_data['yaw_rate'], atol=np.radians(0.2))) - self.assertTrue(np.allclose(orig_data['roll'], replayed_data['roll'], atol=np.radians(0.5))) + assert np.allclose(orig_data['yaw_rate'], replayed_data['yaw_rate'], atol=np.radians(0.2)) + assert np.allclose(orig_data['roll'], replayed_data['roll'], atol=np.radians(0.5)) def test_gps_off(self): """ @@ -130,9 +129,9 @@ class TestLocationdScenarios(unittest.TestCase): - gpsOK: False """ orig_data, replayed_data = run_scenarios(Scenario.GPS_OFF, self.logs) - self.assertTrue(np.allclose(orig_data['yaw_rate'], replayed_data['yaw_rate'], atol=np.radians(0.2))) - self.assertTrue(np.allclose(orig_data['roll'], replayed_data['roll'], atol=np.radians(0.5))) - self.assertTrue(np.all(replayed_data['gps_flag'] == 0.0)) + assert np.allclose(orig_data['yaw_rate'], replayed_data['yaw_rate'], atol=np.radians(0.2)) + assert np.allclose(orig_data['roll'], replayed_data['roll'], atol=np.radians(0.5)) + assert np.all(replayed_data['gps_flag'] == 0.0) def test_gps_off_midway(self): """ @@ -143,9 +142,9 @@ class TestLocationdScenarios(unittest.TestCase): - gpsOK: True for the first half, False for the second half """ orig_data, replayed_data = run_scenarios(Scenario.GPS_OFF_MIDWAY, self.logs) - self.assertTrue(np.allclose(orig_data['yaw_rate'], replayed_data['yaw_rate'], atol=np.radians(0.2))) - self.assertTrue(np.allclose(orig_data['roll'], replayed_data['roll'], atol=np.radians(0.5))) - self.assertTrue(np.diff(replayed_data['gps_flag'])[512] == -1.0) + assert np.allclose(orig_data['yaw_rate'], replayed_data['yaw_rate'], atol=np.radians(0.2)) + assert np.allclose(orig_data['roll'], replayed_data['roll'], atol=np.radians(0.5)) + assert np.diff(replayed_data['gps_flag'])[512] == -1.0 def test_gps_on_midway(self): """ @@ -156,9 +155,9 @@ class TestLocationdScenarios(unittest.TestCase): - gpsOK: False for the first half, True for the second half """ orig_data, replayed_data = run_scenarios(Scenario.GPS_ON_MIDWAY, self.logs) - self.assertTrue(np.allclose(orig_data['yaw_rate'], replayed_data['yaw_rate'], atol=np.radians(0.2))) - self.assertTrue(np.allclose(orig_data['roll'], replayed_data['roll'], atol=np.radians(1.5))) - self.assertTrue(np.diff(replayed_data['gps_flag'])[505] == 1.0) + assert np.allclose(orig_data['yaw_rate'], replayed_data['yaw_rate'], atol=np.radians(0.2)) + assert np.allclose(orig_data['roll'], replayed_data['roll'], atol=np.radians(1.5)) + assert np.diff(replayed_data['gps_flag'])[505] == 1.0 def test_gps_tunnel(self): """ @@ -169,10 +168,10 @@ class TestLocationdScenarios(unittest.TestCase): - gpsOK: False for the middle section, True for the rest """ orig_data, replayed_data = run_scenarios(Scenario.GPS_TUNNEL, self.logs) - self.assertTrue(np.allclose(orig_data['yaw_rate'], replayed_data['yaw_rate'], atol=np.radians(0.2))) - self.assertTrue(np.allclose(orig_data['roll'], replayed_data['roll'], atol=np.radians(0.5))) - self.assertTrue(np.diff(replayed_data['gps_flag'])[213] == -1.0) - self.assertTrue(np.diff(replayed_data['gps_flag'])[805] == 1.0) + assert np.allclose(orig_data['yaw_rate'], replayed_data['yaw_rate'], atol=np.radians(0.2)) + assert np.allclose(orig_data['roll'], replayed_data['roll'], atol=np.radians(0.5)) + assert np.diff(replayed_data['gps_flag'])[213] == -1.0 + assert np.diff(replayed_data['gps_flag'])[805] == 1.0 def test_gyro_off(self): """ @@ -183,9 +182,9 @@ class TestLocationdScenarios(unittest.TestCase): - sensorsOK: False """ _, replayed_data = run_scenarios(Scenario.GYRO_OFF, self.logs) - self.assertTrue(np.allclose(replayed_data['yaw_rate'], 0.0)) - self.assertTrue(np.allclose(replayed_data['roll'], 0.0)) - self.assertTrue(np.all(replayed_data['sensors_flag'] == 0.0)) + assert np.allclose(replayed_data['yaw_rate'], 0.0) + assert np.allclose(replayed_data['roll'], 0.0) + assert np.all(replayed_data['sensors_flag'] == 0.0) def test_gyro_spikes(self): """ @@ -196,10 +195,10 @@ class TestLocationdScenarios(unittest.TestCase): - inputsOK: False for some time after the spike, True for the rest """ orig_data, replayed_data = run_scenarios(Scenario.GYRO_SPIKE_MIDWAY, self.logs) - self.assertTrue(np.allclose(orig_data['yaw_rate'], replayed_data['yaw_rate'], atol=np.radians(0.2))) - self.assertTrue(np.allclose(orig_data['roll'], replayed_data['roll'], atol=np.radians(0.5))) - self.assertTrue(np.diff(replayed_data['inputs_flag'])[500] == -1.0) - self.assertTrue(np.diff(replayed_data['inputs_flag'])[694] == 1.0) + assert np.allclose(orig_data['yaw_rate'], replayed_data['yaw_rate'], atol=np.radians(0.2)) + assert np.allclose(orig_data['roll'], replayed_data['roll'], atol=np.radians(0.5)) + assert np.diff(replayed_data['inputs_flag'])[500] == -1.0 + assert np.diff(replayed_data['inputs_flag'])[694] == 1.0 def test_accel_off(self): """ @@ -210,9 +209,9 @@ class TestLocationdScenarios(unittest.TestCase): - sensorsOK: False """ _, replayed_data = run_scenarios(Scenario.ACCEL_OFF, self.logs) - self.assertTrue(np.allclose(replayed_data['yaw_rate'], 0.0)) - self.assertTrue(np.allclose(replayed_data['roll'], 0.0)) - self.assertTrue(np.all(replayed_data['sensors_flag'] == 0.0)) + assert np.allclose(replayed_data['yaw_rate'], 0.0) + assert np.allclose(replayed_data['roll'], 0.0) + assert np.all(replayed_data['sensors_flag'] == 0.0) def test_accel_spikes(self): """ @@ -221,9 +220,5 @@ class TestLocationdScenarios(unittest.TestCase): Expected Result: Right now, the kalman filter is not robust to small spikes like it is to gyroscope spikes. """ orig_data, replayed_data = run_scenarios(Scenario.ACCEL_SPIKE_MIDWAY, self.logs) - self.assertTrue(np.allclose(orig_data['yaw_rate'], replayed_data['yaw_rate'], atol=np.radians(0.2))) - self.assertTrue(np.allclose(orig_data['roll'], replayed_data['roll'], atol=np.radians(0.5))) - - -if __name__ == "__main__": - unittest.main() + assert np.allclose(orig_data['yaw_rate'], replayed_data['yaw_rate'], atol=np.radians(0.2)) + assert np.allclose(orig_data['roll'], replayed_data['roll'], atol=np.radians(0.5)) diff --git a/selfdrive/manager/test/test_manager.py b/selfdrive/manager/test/test_manager.py index 1ae94b26a1..4cdc99c240 100755 --- a/selfdrive/manager/test/test_manager.py +++ b/selfdrive/manager/test/test_manager.py @@ -3,7 +3,6 @@ import os import pytest import signal import time -import unittest from parameterized import parameterized @@ -21,15 +20,15 @@ BLACKLIST_PROCS = ['manage_athenad', 'pandad', 'pigeond'] @pytest.mark.tici -class TestManager(unittest.TestCase): - def setUp(self): +class TestManager: + def setup_method(self): HARDWARE.set_power_save(False) # ensure clean CarParams params = Params() params.clear_all() - def tearDown(self): + def teardown_method(self): manager.manager_cleanup() def test_manager_prepare(self): @@ -38,7 +37,7 @@ class TestManager(unittest.TestCase): def test_blacklisted_procs(self): # TODO: ensure there are blacklisted procs until we have a dedicated test - self.assertTrue(len(BLACKLIST_PROCS), "No blacklisted procs to test not_run") + assert len(BLACKLIST_PROCS), "No blacklisted procs to test not_run" @parameterized.expand([(i,) for i in range(10)]) def test_startup_time(self, index): @@ -48,8 +47,8 @@ class TestManager(unittest.TestCase): t = time.monotonic() - start assert t < MAX_STARTUP_TIME, f"startup took {t}s, expected <{MAX_STARTUP_TIME}s" - @unittest.skip("this test is flaky the way it's currently written, should be moved to test_onroad") - def test_clean_exit(self): + @pytest.mark.skip("this test is flaky the way it's currently written, should be moved to test_onroad") + def test_clean_exit(self, subtests): """ Ensure all processes exit cleanly when stopped. """ @@ -62,21 +61,17 @@ class TestManager(unittest.TestCase): time.sleep(10) for p in procs: - with self.subTest(proc=p.name): + with subtests.test(proc=p.name): state = p.get_process_state_msg() - self.assertTrue(state.running, f"{p.name} not running") + assert state.running, f"{p.name} not running" exit_code = p.stop(retry=False) - self.assertNotIn(p.name, BLACKLIST_PROCS, f"{p.name} was started") + assert p.name not in BLACKLIST_PROCS, f"{p.name} was started" - self.assertTrue(exit_code is not None, f"{p.name} failed to exit") + assert exit_code is not None, f"{p.name} failed to exit" # TODO: interrupted blocking read exits with 1 in cereal. use a more unique return code exit_codes = [0, 1] if p.sigkill: exit_codes = [-signal.SIGKILL] - self.assertIn(exit_code, exit_codes, f"{p.name} died with {exit_code}") - - -if __name__ == "__main__": - unittest.main() + assert exit_code in exit_codes, f"{p.name} died with {exit_code}" diff --git a/selfdrive/modeld/tests/test_modeld.py b/selfdrive/modeld/tests/test_modeld.py index 67c6f71038..a18ce8fa42 100755 --- a/selfdrive/modeld/tests/test_modeld.py +++ b/selfdrive/modeld/tests/test_modeld.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -import unittest import numpy as np import random @@ -16,9 +15,9 @@ IMG = np.zeros(int(CAM.width*CAM.height*(3/2)), dtype=np.uint8) IMG_BYTES = IMG.flatten().tobytes() -class TestModeld(unittest.TestCase): +class TestModeld: - def setUp(self): + def setup_method(self): self.vipc_server = VisionIpcServer("camerad") self.vipc_server.create_buffers(VisionStreamType.VISION_STREAM_ROAD, 40, False, CAM.width, CAM.height) self.vipc_server.create_buffers(VisionStreamType.VISION_STREAM_DRIVER, 40, False, CAM.width, CAM.height) @@ -32,7 +31,7 @@ class TestModeld(unittest.TestCase): managed_processes['modeld'].start() self.pm.wait_for_readers_to_update("roadCameraState", 10) - def tearDown(self): + def teardown_method(self): managed_processes['modeld'].stop() del self.vipc_server @@ -65,15 +64,15 @@ class TestModeld(unittest.TestCase): self._wait() mdl = self.sm['modelV2'] - self.assertEqual(mdl.frameId, n) - self.assertEqual(mdl.frameIdExtra, n) - self.assertEqual(mdl.timestampEof, cs.timestampEof) - self.assertEqual(mdl.frameAge, 0) - self.assertEqual(mdl.frameDropPerc, 0) + assert mdl.frameId == n + assert mdl.frameIdExtra == n + assert mdl.timestampEof == cs.timestampEof + assert mdl.frameAge == 0 + assert mdl.frameDropPerc == 0 odo = self.sm['cameraOdometry'] - self.assertEqual(odo.frameId, n) - self.assertEqual(odo.timestampEof, cs.timestampEof) + assert odo.frameId == n + assert odo.timestampEof == cs.timestampEof def test_dropped_frames(self): """ @@ -95,13 +94,9 @@ class TestModeld(unittest.TestCase): mdl = self.sm['modelV2'] odo = self.sm['cameraOdometry'] - self.assertEqual(mdl.frameId, frame_id) - self.assertEqual(mdl.frameIdExtra, frame_id) - self.assertEqual(odo.frameId, frame_id) + assert mdl.frameId == frame_id + assert mdl.frameIdExtra == frame_id + assert odo.frameId == frame_id if n != frame_id: - self.assertFalse(self.sm.updated['modelV2']) - self.assertFalse(self.sm.updated['cameraOdometry']) - - -if __name__ == "__main__": - unittest.main() + assert not self.sm.updated['modelV2'] + assert not self.sm.updated['cameraOdometry'] diff --git a/selfdrive/monitoring/test_monitoring.py b/selfdrive/monitoring/test_monitoring.py index 50b2746e2d..9395960b65 100755 --- a/selfdrive/monitoring/test_monitoring.py +++ b/selfdrive/monitoring/test_monitoring.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -import unittest import numpy as np from cereal import car, log @@ -53,7 +52,7 @@ always_true = [True] * int(TEST_TIMESPAN / DT_DMON) always_false = [False] * int(TEST_TIMESPAN / DT_DMON) # TODO: this only tests DriverStatus -class TestMonitoring(unittest.TestCase): +class TestMonitoring: def _run_seq(self, msgs, interaction, engaged, standstill): DS = DriverStatus() events = [] @@ -69,7 +68,7 @@ class TestMonitoring(unittest.TestCase): return events, DS def _assert_no_events(self, events): - self.assertTrue(all(not len(e) for e in events)) + assert all(not len(e) for e in events) # engaged, driver is attentive all the time def test_fully_aware_driver(self): @@ -79,27 +78,27 @@ class TestMonitoring(unittest.TestCase): # engaged, driver is distracted and does nothing def test_fully_distracted_driver(self): events, d_status = self._run_seq(always_distracted, always_false, always_true, always_false) - self.assertEqual(len(events[int((d_status.settings._DISTRACTED_TIME-d_status.settings._DISTRACTED_PRE_TIME_TILL_TERMINAL)/2/DT_DMON)]), 0) - self.assertEqual(events[int((d_status.settings._DISTRACTED_TIME-d_status.settings._DISTRACTED_PRE_TIME_TILL_TERMINAL + - ((d_status.settings._DISTRACTED_PRE_TIME_TILL_TERMINAL-d_status.settings._DISTRACTED_PROMPT_TIME_TILL_TERMINAL)/2))/DT_DMON)].names[0], - EventName.preDriverDistracted) - self.assertEqual(events[int((d_status.settings._DISTRACTED_TIME-d_status.settings._DISTRACTED_PROMPT_TIME_TILL_TERMINAL + - ((d_status.settings._DISTRACTED_PROMPT_TIME_TILL_TERMINAL)/2))/DT_DMON)].names[0], EventName.promptDriverDistracted) - self.assertEqual(events[int((d_status.settings._DISTRACTED_TIME + - ((TEST_TIMESPAN-10-d_status.settings._DISTRACTED_TIME)/2))/DT_DMON)].names[0], EventName.driverDistracted) - self.assertIs(type(d_status.awareness), float) + assert len(events[int((d_status.settings._DISTRACTED_TIME-d_status.settings._DISTRACTED_PRE_TIME_TILL_TERMINAL)/2/DT_DMON)]) == 0 + assert events[int((d_status.settings._DISTRACTED_TIME-d_status.settings._DISTRACTED_PRE_TIME_TILL_TERMINAL + \ + ((d_status.settings._DISTRACTED_PRE_TIME_TILL_TERMINAL-d_status.settings._DISTRACTED_PROMPT_TIME_TILL_TERMINAL)/2))/DT_DMON)].names[0] == \ + EventName.preDriverDistracted + assert events[int((d_status.settings._DISTRACTED_TIME-d_status.settings._DISTRACTED_PROMPT_TIME_TILL_TERMINAL + \ + ((d_status.settings._DISTRACTED_PROMPT_TIME_TILL_TERMINAL)/2))/DT_DMON)].names[0] == EventName.promptDriverDistracted + assert events[int((d_status.settings._DISTRACTED_TIME + \ + ((TEST_TIMESPAN-10-d_status.settings._DISTRACTED_TIME)/2))/DT_DMON)].names[0] == EventName.driverDistracted + assert isinstance(d_status.awareness, float) # engaged, no face detected the whole time, no action def test_fully_invisible_driver(self): events, d_status = self._run_seq(always_no_face, always_false, always_true, always_false) - self.assertTrue(len(events[int((d_status.settings._AWARENESS_TIME-d_status.settings._AWARENESS_PRE_TIME_TILL_TERMINAL)/2/DT_DMON)]) == 0) - self.assertEqual(events[int((d_status.settings._AWARENESS_TIME-d_status.settings._AWARENESS_PRE_TIME_TILL_TERMINAL + - ((d_status.settings._AWARENESS_PRE_TIME_TILL_TERMINAL-d_status.settings._AWARENESS_PROMPT_TIME_TILL_TERMINAL)/2))/DT_DMON)].names[0], - EventName.preDriverUnresponsive) - self.assertEqual(events[int((d_status.settings._AWARENESS_TIME-d_status.settings._AWARENESS_PROMPT_TIME_TILL_TERMINAL + - ((d_status.settings._AWARENESS_PROMPT_TIME_TILL_TERMINAL)/2))/DT_DMON)].names[0], EventName.promptDriverUnresponsive) - self.assertEqual(events[int((d_status.settings._AWARENESS_TIME + - ((TEST_TIMESPAN-10-d_status.settings._AWARENESS_TIME)/2))/DT_DMON)].names[0], EventName.driverUnresponsive) + assert len(events[int((d_status.settings._AWARENESS_TIME-d_status.settings._AWARENESS_PRE_TIME_TILL_TERMINAL)/2/DT_DMON)]) == 0 + assert events[int((d_status.settings._AWARENESS_TIME-d_status.settings._AWARENESS_PRE_TIME_TILL_TERMINAL + \ + ((d_status.settings._AWARENESS_PRE_TIME_TILL_TERMINAL-d_status.settings._AWARENESS_PROMPT_TIME_TILL_TERMINAL)/2))/DT_DMON)].names[0] == \ + EventName.preDriverUnresponsive + assert events[int((d_status.settings._AWARENESS_TIME-d_status.settings._AWARENESS_PROMPT_TIME_TILL_TERMINAL + \ + ((d_status.settings._AWARENESS_PROMPT_TIME_TILL_TERMINAL)/2))/DT_DMON)].names[0] == EventName.promptDriverUnresponsive + assert events[int((d_status.settings._AWARENESS_TIME + \ + ((TEST_TIMESPAN-10-d_status.settings._AWARENESS_TIME)/2))/DT_DMON)].names[0] == EventName.driverUnresponsive # engaged, down to orange, driver pays attention, back to normal; then down to orange, driver touches wheel # - should have short orange recovery time and no green afterwards; wheel touch only recovers when paying attention @@ -111,12 +110,12 @@ class TestMonitoring(unittest.TestCase): interaction_vector = [car_interaction_NOT_DETECTED] * int(DISTRACTED_SECONDS_TO_ORANGE*3/DT_DMON) + \ [car_interaction_DETECTED] * (int(TEST_TIMESPAN/DT_DMON)-int(DISTRACTED_SECONDS_TO_ORANGE*3/DT_DMON)) events, _ = self._run_seq(ds_vector, interaction_vector, always_true, always_false) - self.assertEqual(len(events[int(DISTRACTED_SECONDS_TO_ORANGE*0.5/DT_DMON)]), 0) - self.assertEqual(events[int((DISTRACTED_SECONDS_TO_ORANGE-0.1)/DT_DMON)].names[0], EventName.promptDriverDistracted) - self.assertEqual(len(events[int(DISTRACTED_SECONDS_TO_ORANGE*1.5/DT_DMON)]), 0) - self.assertEqual(events[int((DISTRACTED_SECONDS_TO_ORANGE*3-0.1)/DT_DMON)].names[0], EventName.promptDriverDistracted) - self.assertEqual(events[int((DISTRACTED_SECONDS_TO_ORANGE*3+0.1)/DT_DMON)].names[0], EventName.promptDriverDistracted) - self.assertEqual(len(events[int((DISTRACTED_SECONDS_TO_ORANGE*3+2.5)/DT_DMON)]), 0) + assert len(events[int(DISTRACTED_SECONDS_TO_ORANGE*0.5/DT_DMON)]) == 0 + assert events[int((DISTRACTED_SECONDS_TO_ORANGE-0.1)/DT_DMON)].names[0] == EventName.promptDriverDistracted + assert len(events[int(DISTRACTED_SECONDS_TO_ORANGE*1.5/DT_DMON)]) == 0 + assert events[int((DISTRACTED_SECONDS_TO_ORANGE*3-0.1)/DT_DMON)].names[0] == EventName.promptDriverDistracted + assert events[int((DISTRACTED_SECONDS_TO_ORANGE*3+0.1)/DT_DMON)].names[0] == EventName.promptDriverDistracted + assert len(events[int((DISTRACTED_SECONDS_TO_ORANGE*3+2.5)/DT_DMON)]) == 0 # engaged, down to orange, driver dodges camera, then comes back still distracted, down to red, \ # driver dodges, and then touches wheel to no avail, disengages and reengages @@ -135,10 +134,10 @@ class TestMonitoring(unittest.TestCase): op_vector[int((DISTRACTED_SECONDS_TO_RED+2*_invisible_time+2.5)/DT_DMON):int((DISTRACTED_SECONDS_TO_RED+2*_invisible_time+3)/DT_DMON)] \ = [False] * int(0.5/DT_DMON) events, _ = self._run_seq(ds_vector, interaction_vector, op_vector, always_false) - self.assertEqual(events[int((DISTRACTED_SECONDS_TO_ORANGE+0.5*_invisible_time)/DT_DMON)].names[0], EventName.promptDriverDistracted) - self.assertEqual(events[int((DISTRACTED_SECONDS_TO_RED+1.5*_invisible_time)/DT_DMON)].names[0], EventName.driverDistracted) - self.assertEqual(events[int((DISTRACTED_SECONDS_TO_RED+2*_invisible_time+1.5)/DT_DMON)].names[0], EventName.driverDistracted) - self.assertTrue(len(events[int((DISTRACTED_SECONDS_TO_RED+2*_invisible_time+3.5)/DT_DMON)]) == 0) + assert events[int((DISTRACTED_SECONDS_TO_ORANGE+0.5*_invisible_time)/DT_DMON)].names[0] == EventName.promptDriverDistracted + assert events[int((DISTRACTED_SECONDS_TO_RED+1.5*_invisible_time)/DT_DMON)].names[0] == EventName.driverDistracted + assert events[int((DISTRACTED_SECONDS_TO_RED+2*_invisible_time+1.5)/DT_DMON)].names[0] == EventName.driverDistracted + assert len(events[int((DISTRACTED_SECONDS_TO_RED+2*_invisible_time+3.5)/DT_DMON)]) == 0 # engaged, invisible driver, down to orange, driver touches wheel; then down to orange again, driver appears # - both actions should clear the alert, but momentary appearance should not @@ -150,15 +149,15 @@ class TestMonitoring(unittest.TestCase): [msg_ATTENTIVE] * int(_visible_time/DT_DMON) interaction_vector[int((INVISIBLE_SECONDS_TO_ORANGE)/DT_DMON):int((INVISIBLE_SECONDS_TO_ORANGE+1)/DT_DMON)] = [True] * int(1/DT_DMON) events, _ = self._run_seq(ds_vector, interaction_vector, 2*always_true, 2*always_false) - self.assertTrue(len(events[int(INVISIBLE_SECONDS_TO_ORANGE*0.5/DT_DMON)]) == 0) - self.assertEqual(events[int((INVISIBLE_SECONDS_TO_ORANGE-0.1)/DT_DMON)].names[0], EventName.promptDriverUnresponsive) - self.assertTrue(len(events[int((INVISIBLE_SECONDS_TO_ORANGE+0.1)/DT_DMON)]) == 0) + assert len(events[int(INVISIBLE_SECONDS_TO_ORANGE*0.5/DT_DMON)]) == 0 + assert events[int((INVISIBLE_SECONDS_TO_ORANGE-0.1)/DT_DMON)].names[0] == EventName.promptDriverUnresponsive + assert len(events[int((INVISIBLE_SECONDS_TO_ORANGE+0.1)/DT_DMON)]) == 0 if _visible_time == 0.5: - self.assertEqual(events[int((INVISIBLE_SECONDS_TO_ORANGE*2+1-0.1)/DT_DMON)].names[0], EventName.promptDriverUnresponsive) - self.assertEqual(events[int((INVISIBLE_SECONDS_TO_ORANGE*2+1+0.1+_visible_time)/DT_DMON)].names[0], EventName.preDriverUnresponsive) + assert events[int((INVISIBLE_SECONDS_TO_ORANGE*2+1-0.1)/DT_DMON)].names[0] == EventName.promptDriverUnresponsive + assert events[int((INVISIBLE_SECONDS_TO_ORANGE*2+1+0.1+_visible_time)/DT_DMON)].names[0] == EventName.preDriverUnresponsive elif _visible_time == 10: - self.assertEqual(events[int((INVISIBLE_SECONDS_TO_ORANGE*2+1-0.1)/DT_DMON)].names[0], EventName.promptDriverUnresponsive) - self.assertTrue(len(events[int((INVISIBLE_SECONDS_TO_ORANGE*2+1+0.1+_visible_time)/DT_DMON)]) == 0) + assert events[int((INVISIBLE_SECONDS_TO_ORANGE*2+1-0.1)/DT_DMON)].names[0] == EventName.promptDriverUnresponsive + assert len(events[int((INVISIBLE_SECONDS_TO_ORANGE*2+1+0.1+_visible_time)/DT_DMON)]) == 0 # engaged, invisible driver, down to red, driver appears and then touches wheel, then disengages/reengages # - only disengage will clear the alert @@ -171,18 +170,18 @@ class TestMonitoring(unittest.TestCase): interaction_vector[int((INVISIBLE_SECONDS_TO_RED+_visible_time)/DT_DMON):int((INVISIBLE_SECONDS_TO_RED+_visible_time+1)/DT_DMON)] = [True] * int(1/DT_DMON) op_vector[int((INVISIBLE_SECONDS_TO_RED+_visible_time+1)/DT_DMON):int((INVISIBLE_SECONDS_TO_RED+_visible_time+0.5)/DT_DMON)] = [False] * int(0.5/DT_DMON) events, _ = self._run_seq(ds_vector, interaction_vector, op_vector, always_false) - self.assertTrue(len(events[int(INVISIBLE_SECONDS_TO_ORANGE*0.5/DT_DMON)]) == 0) - self.assertEqual(events[int((INVISIBLE_SECONDS_TO_ORANGE-0.1)/DT_DMON)].names[0], EventName.promptDriverUnresponsive) - self.assertEqual(events[int((INVISIBLE_SECONDS_TO_RED-0.1)/DT_DMON)].names[0], EventName.driverUnresponsive) - self.assertEqual(events[int((INVISIBLE_SECONDS_TO_RED+0.5*_visible_time)/DT_DMON)].names[0], EventName.driverUnresponsive) - self.assertEqual(events[int((INVISIBLE_SECONDS_TO_RED+_visible_time+0.5)/DT_DMON)].names[0], EventName.driverUnresponsive) - self.assertTrue(len(events[int((INVISIBLE_SECONDS_TO_RED+_visible_time+1+0.1)/DT_DMON)]) == 0) + assert len(events[int(INVISIBLE_SECONDS_TO_ORANGE*0.5/DT_DMON)]) == 0 + assert events[int((INVISIBLE_SECONDS_TO_ORANGE-0.1)/DT_DMON)].names[0] == EventName.promptDriverUnresponsive + assert events[int((INVISIBLE_SECONDS_TO_RED-0.1)/DT_DMON)].names[0] == EventName.driverUnresponsive + assert events[int((INVISIBLE_SECONDS_TO_RED+0.5*_visible_time)/DT_DMON)].names[0] == EventName.driverUnresponsive + assert events[int((INVISIBLE_SECONDS_TO_RED+_visible_time+0.5)/DT_DMON)].names[0] == EventName.driverUnresponsive + assert len(events[int((INVISIBLE_SECONDS_TO_RED+_visible_time+1+0.1)/DT_DMON)]) == 0 # disengaged, always distracted driver # - dm should stay quiet when not engaged def test_pure_dashcam_user(self): events, _ = self._run_seq(always_distracted, always_false, always_false, always_false) - self.assertTrue(sum(len(event) for event in events) == 0) + assert sum(len(event) for event in events) == 0 # engaged, car stops at traffic light, down to orange, no action, then car starts moving # - should only reach green when stopped, but continues counting down on launch @@ -191,10 +190,10 @@ class TestMonitoring(unittest.TestCase): standstill_vector = always_true[:] standstill_vector[int(_redlight_time/DT_DMON):] = [False] * int((TEST_TIMESPAN-_redlight_time)/DT_DMON) events, d_status = self._run_seq(always_distracted, always_false, always_true, standstill_vector) - self.assertEqual(events[int((d_status.settings._DISTRACTED_TIME-d_status.settings._DISTRACTED_PRE_TIME_TILL_TERMINAL+1)/DT_DMON)].names[0], - EventName.preDriverDistracted) - self.assertEqual(events[int((_redlight_time-0.1)/DT_DMON)].names[0], EventName.preDriverDistracted) - self.assertEqual(events[int((_redlight_time+0.5)/DT_DMON)].names[0], EventName.promptDriverDistracted) + assert events[int((d_status.settings._DISTRACTED_TIME-d_status.settings._DISTRACTED_PRE_TIME_TILL_TERMINAL+1)/DT_DMON)].names[0] == \ + EventName.preDriverDistracted + assert events[int((_redlight_time-0.1)/DT_DMON)].names[0] == EventName.preDriverDistracted + assert events[int((_redlight_time+0.5)/DT_DMON)].names[0] == EventName.promptDriverDistracted # engaged, model is somehow uncertain and driver is distracted # - should fall back to wheel touch after uncertain alert @@ -202,13 +201,10 @@ class TestMonitoring(unittest.TestCase): ds_vector = [msg_DISTRACTED_BUT_SOMEHOW_UNCERTAIN] * int(TEST_TIMESPAN/DT_DMON) interaction_vector = always_false[:] events, d_status = self._run_seq(ds_vector, interaction_vector, always_true, always_false) - self.assertTrue(EventName.preDriverUnresponsive in - events[int((INVISIBLE_SECONDS_TO_ORANGE-1+DT_DMON*d_status.settings._HI_STD_FALLBACK_TIME-0.1)/DT_DMON)].names) - self.assertTrue(EventName.promptDriverUnresponsive in - events[int((INVISIBLE_SECONDS_TO_ORANGE-1+DT_DMON*d_status.settings._HI_STD_FALLBACK_TIME+0.1)/DT_DMON)].names) - self.assertTrue(EventName.driverUnresponsive in - events[int((INVISIBLE_SECONDS_TO_RED-1+DT_DMON*d_status.settings._HI_STD_FALLBACK_TIME+0.1)/DT_DMON)].names) + assert EventName.preDriverUnresponsive in \ + events[int((INVISIBLE_SECONDS_TO_ORANGE-1+DT_DMON*d_status.settings._HI_STD_FALLBACK_TIME-0.1)/DT_DMON)].names + assert EventName.promptDriverUnresponsive in \ + events[int((INVISIBLE_SECONDS_TO_ORANGE-1+DT_DMON*d_status.settings._HI_STD_FALLBACK_TIME+0.1)/DT_DMON)].names + assert EventName.driverUnresponsive in \ + events[int((INVISIBLE_SECONDS_TO_RED-1+DT_DMON*d_status.settings._HI_STD_FALLBACK_TIME+0.1)/DT_DMON)].names - -if __name__ == "__main__": - unittest.main() diff --git a/selfdrive/navd/tests/test_map_renderer.py b/selfdrive/navd/tests/test_map_renderer.py index 832e0d1eab..52b594a57e 100755 --- a/selfdrive/navd/tests/test_map_renderer.py +++ b/selfdrive/navd/tests/test_map_renderer.py @@ -3,7 +3,6 @@ import time import numpy as np import os import pytest -import unittest import requests import threading import http.server @@ -66,11 +65,11 @@ class MapBoxInternetDisabledServer(threading.Thread): @pytest.mark.skip(reason="not used") -class TestMapRenderer(unittest.TestCase): +class TestMapRenderer: server: MapBoxInternetDisabledServer @classmethod - def setUpClass(cls): + def setup_class(cls): assert "MAPBOX_TOKEN" in os.environ cls.original_token = os.environ["MAPBOX_TOKEN"] cls.server = MapBoxInternetDisabledServer() @@ -78,10 +77,10 @@ class TestMapRenderer(unittest.TestCase): time.sleep(0.5) # wait for server to startup @classmethod - def tearDownClass(cls) -> None: + def teardown_class(cls) -> None: cls.server.stop() - def setUp(self): + def setup_method(self): self.server.enable_internet() os.environ['MAPS_HOST'] = f'http://localhost:{self.server.port}' @@ -203,15 +202,12 @@ class TestMapRenderer(unittest.TestCase): def assert_stat(stat, nominal, tol=0.3): tol = (nominal / (1+tol)), (nominal * (1+tol)) - self.assertTrue(tol[0] < stat < tol[1], f"{stat} not in tolerance {tol}") + assert tol[0] < stat < tol[1], f"{stat} not in tolerance {tol}" assert_stat(_mean, 0.030) assert_stat(_median, 0.027) assert_stat(_stddev, 0.0078) - self.assertLess(_max, 0.065) - self.assertGreater(_min, 0.015) + assert _max < 0.065 + assert _min > 0.015 - -if __name__ == "__main__": - unittest.main() diff --git a/selfdrive/navd/tests/test_navd.py b/selfdrive/navd/tests/test_navd.py index 61be6cc387..07f9303653 100755 --- a/selfdrive/navd/tests/test_navd.py +++ b/selfdrive/navd/tests/test_navd.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 import json import random -import unittest import numpy as np from parameterized import parameterized @@ -11,12 +10,12 @@ from openpilot.common.params import Params from openpilot.selfdrive.manager.process_config import managed_processes -class TestNavd(unittest.TestCase): - def setUp(self): +class TestNavd: + def setup_method(self): self.params = Params() self.sm = messaging.SubMaster(['navRoute', 'navInstruction']) - def tearDown(self): + def teardown_method(self): managed_processes['navd'].stop() def _check_route(self, start, end, check_coords=True): @@ -57,7 +56,3 @@ class TestNavd(unittest.TestCase): start = {"latitude": random.uniform(-90, 90), "longitude": random.uniform(-180, 180)} end = {"latitude": random.uniform(-90, 90), "longitude": random.uniform(-180, 180)} self._check_route(start, end, check_coords=False) - - -if __name__ == "__main__": - unittest.main() diff --git a/selfdrive/test/helpers.py b/selfdrive/test/helpers.py index ce8918aec3..148d142bb6 100644 --- a/selfdrive/test/helpers.py +++ b/selfdrive/test/helpers.py @@ -3,6 +3,7 @@ import http.server import os import threading import time +import pytest from functools import wraps @@ -32,15 +33,15 @@ def phone_only(f): @wraps(f) def wrap(self, *args, **kwargs): if PC: - self.skipTest("This test is not meant to run on PC") - f(self, *args, **kwargs) + pytest.skip("This test is not meant to run on PC") + return f(self, *args, **kwargs) return wrap def release_only(f): @wraps(f) def wrap(self, *args, **kwargs): if "RELEASE" not in os.environ: - self.skipTest("This test is only for release branches") + pytest.skip("This test is only for release branches") f(self, *args, **kwargs) return wrap diff --git a/selfdrive/test/longitudinal_maneuvers/test_longitudinal.py b/selfdrive/test/longitudinal_maneuvers/test_longitudinal.py index 713b7801f8..0ad6d6d4fd 100755 --- a/selfdrive/test/longitudinal_maneuvers/test_longitudinal.py +++ b/selfdrive/test/longitudinal_maneuvers/test_longitudinal.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 import itertools -import unittest from parameterized import parameterized_class from openpilot.selfdrive.controls.lib.longitudinal_mpc_lib.long_mpc import STOP_DISTANCE @@ -144,17 +143,13 @@ def create_maneuvers(kwargs): @parameterized_class(("e2e", "force_decel"), itertools.product([True, False], repeat=2)) -class LongitudinalControl(unittest.TestCase): +class TestLongitudinalControl: e2e: bool force_decel: bool - def test_maneuver(self): + def test_maneuver(self, subtests): for maneuver in create_maneuvers({"e2e": self.e2e, "force_decel": self.force_decel}): - with self.subTest(title=maneuver.title, e2e=maneuver.e2e, force_decel=maneuver.force_decel): + with subtests.test(title=maneuver.title, e2e=maneuver.e2e, force_decel=maneuver.force_decel): print(maneuver.title, f'in {"e2e" if maneuver.e2e else "acc"} mode') valid, _ = maneuver.evaluate() - self.assertTrue(valid) - - -if __name__ == "__main__": - unittest.main(failfast=True) + assert valid diff --git a/selfdrive/test/process_replay/test_fuzzy.py b/selfdrive/test/process_replay/test_fuzzy.py index 6c81119fbf..d295092b20 100755 --- a/selfdrive/test/process_replay/test_fuzzy.py +++ b/selfdrive/test/process_replay/test_fuzzy.py @@ -3,7 +3,6 @@ import copy from hypothesis import given, HealthCheck, Phase, settings import hypothesis.strategies as st from parameterized import parameterized -import unittest from cereal import log from openpilot.selfdrive.car.toyota.values import CAR as TOYOTA @@ -17,7 +16,7 @@ NOT_TESTED = ['controlsd', 'plannerd', 'calibrationd', 'dmonitoringd', 'paramsd' TEST_CASES = [(cfg.proc_name, copy.deepcopy(cfg)) for cfg in pr.CONFIGS if cfg.proc_name not in NOT_TESTED] -class TestFuzzProcesses(unittest.TestCase): +class TestFuzzProcesses: # TODO: make this faster and increase examples @parameterized.expand(TEST_CASES) @@ -28,6 +27,3 @@ class TestFuzzProcesses(unittest.TestCase): lr = [log.Event.new_message(**m).as_reader() for m in msgs] cfg.timeout = 5 pr.replay_process(cfg, lr, fingerprint=TOYOTA.TOYOTA_COROLLA_TSS2, disable_progress=True) - -if __name__ == "__main__": - unittest.main() diff --git a/selfdrive/test/process_replay/test_regen.py b/selfdrive/test/process_replay/test_regen.py index d989635497..c27d9e8f7b 100755 --- a/selfdrive/test/process_replay/test_regen.py +++ b/selfdrive/test/process_replay/test_regen.py @@ -1,7 +1,5 @@ #!/usr/bin/env python3 -import unittest - from parameterized import parameterized from openpilot.selfdrive.test.process_replay.regen import regen_segment, DummyFrameReader @@ -30,7 +28,7 @@ def ci_setup_data_readers(route, sidx): return lr, frs -class TestRegen(unittest.TestCase): +class TestRegen: @parameterized.expand(TESTED_SEGMENTS) def test_engaged(self, case_name, segment): route, sidx = segment.rsplit("--", 1) @@ -38,8 +36,4 @@ class TestRegen(unittest.TestCase): output_logs = regen_segment(lr, frs, disable_tqdm=True) engaged = check_openpilot_enabled(output_logs) - self.assertTrue(engaged, f"openpilot not engaged in {case_name}") - - -if __name__=='__main__': - unittest.main() + assert engaged, f"openpilot not engaged in {case_name}" diff --git a/selfdrive/test/test_onroad.py b/selfdrive/test/test_onroad.py index cd0846c894..9aca9afcd9 100755 --- a/selfdrive/test/test_onroad.py +++ b/selfdrive/test/test_onroad.py @@ -10,7 +10,6 @@ import shutil import subprocess import time import numpy as np -import unittest from collections import Counter, defaultdict from functools import cached_property from pathlib import Path @@ -102,10 +101,10 @@ def cputime_total(ct): @pytest.mark.tici -class TestOnroad(unittest.TestCase): +class TestOnroad: @classmethod - def setUpClass(cls): + def setup_class(cls): if "DEBUG" in os.environ: segs = filter(lambda x: os.path.exists(os.path.join(x, "rlog")), Path(Paths.log_root()).iterdir()) segs = sorted(segs, key=lambda x: x.stat().st_mtime) @@ -181,7 +180,7 @@ class TestOnroad(unittest.TestCase): msgs[m.which()].append(m) return msgs - def test_service_frequencies(self): + def test_service_frequencies(self, subtests): for s, msgs in self.service_msgs.items(): if s in ('initData', 'sentinel'): continue @@ -190,18 +189,18 @@ class TestOnroad(unittest.TestCase): if s in ('ubloxGnss', 'ubloxRaw', 'gnssMeasurements', 'gpsLocation', 'gpsLocationExternal', 'qcomGnss'): continue - with self.subTest(service=s): + with subtests.test(service=s): assert len(msgs) >= math.floor(SERVICE_LIST[s].frequency*55) def test_cloudlog_size(self): msgs = [m for m in self.lr if m.which() == 'logMessage'] total_size = sum(len(m.as_builder().to_bytes()) for m in msgs) - self.assertLess(total_size, 3.5e5) + assert total_size < 3.5e5 cnt = Counter(json.loads(m.logMessage)['filename'] for m in msgs) big_logs = [f for f, n in cnt.most_common(3) if n / sum(cnt.values()) > 30.] - self.assertEqual(len(big_logs), 0, f"Log spam: {big_logs}") + assert len(big_logs) == 0, f"Log spam: {big_logs}" def test_log_sizes(self): for f, sz in self.log_sizes.items(): @@ -230,15 +229,15 @@ class TestOnroad(unittest.TestCase): result += "------------------------------------------------\n" print(result) - self.assertLess(max(ts), 250.) - self.assertLess(np.mean(ts), 10.) + assert max(ts) < 250. + assert np.mean(ts) < 10. #self.assertLess(np.std(ts), 5.) # some slow frames are expected since camerad/modeld can preempt ui veryslow = [x for x in ts if x > 40.] assert len(veryslow) < 5, f"Too many slow frame draw times: {veryslow}" - def test_cpu_usage(self): + def test_cpu_usage(self, subtests): result = "\n" result += "------------------------------------------------\n" result += "------------------ CPU Usage -------------------\n" @@ -286,12 +285,12 @@ class TestOnroad(unittest.TestCase): # Ensure there's no missing procs all_procs = {p.name for p in self.service_msgs['managerState'][0].managerState.processes if p.shouldBeRunning} for p in all_procs: - with self.subTest(proc=p): + with subtests.test(proc=p): assert any(p in pp for pp in PROCS.keys()), f"Expected CPU usage missing for {p}" # total CPU check procs_tot = sum([(max(x) if isinstance(x, tuple) else x) for x in PROCS.values()]) - with self.subTest(name="total CPU"): + with subtests.test(name="total CPU"): assert procs_tot < MAX_TOTAL_CPU, "Total CPU budget exceeded" result += "------------------------------------------------\n" result += f"Total allocated CPU usage is {procs_tot}%, budget is {MAX_TOTAL_CPU}%, {MAX_TOTAL_CPU-procs_tot:.1f}% left\n" @@ -299,7 +298,7 @@ class TestOnroad(unittest.TestCase): print(result) - self.assertTrue(cpu_ok) + assert cpu_ok def test_memory_usage(self): mems = [m.deviceState.memoryUsagePercent for m in self.service_msgs['deviceState']] @@ -307,10 +306,10 @@ class TestOnroad(unittest.TestCase): # check for big leaks. note that memory usage is # expected to go up while the MSGQ buffers fill up - self.assertLessEqual(max(mems) - min(mems), 3.0) + assert max(mems) - min(mems) <= 3.0 def test_gpu_usage(self): - self.assertEqual(self.gpu_procs, {"weston", "ui", "camerad", "selfdrive.modeld.modeld"}) + assert self.gpu_procs == {"weston", "ui", "camerad", "selfdrive.modeld.modeld"} def test_camera_processing_time(self): result = "\n" @@ -319,14 +318,14 @@ class TestOnroad(unittest.TestCase): result += "------------------------------------------------\n" ts = [getattr(m, m.which()).processingTime for m in self.lr if 'CameraState' in m.which()] - self.assertLess(min(ts), 0.025, f"high execution time: {min(ts)}") + assert min(ts) < 0.025, f"high execution time: {min(ts)}" result += f"execution time: min {min(ts):.5f}s\n" result += f"execution time: max {max(ts):.5f}s\n" result += f"execution time: mean {np.mean(ts):.5f}s\n" result += "------------------------------------------------\n" print(result) - @unittest.skip("TODO: enable once timings are fixed") + @pytest.mark.skip("TODO: enable once timings are fixed") def test_camera_frame_timings(self): result = "\n" result += "------------------------------------------------\n" @@ -336,7 +335,7 @@ class TestOnroad(unittest.TestCase): ts = [getattr(m, m.which()).timestampSof for m in self.lr if name in m.which()] d_ms = np.diff(ts) / 1e6 d50 = np.abs(d_ms-50) - self.assertLess(max(d50), 1.0, f"high sof delta vs 50ms: {max(d50)}") + assert max(d50) < 1.0, f"high sof delta vs 50ms: {max(d50)}" result += f"{name} sof delta vs 50ms: min {min(d50):.5f}s\n" result += f"{name} sof delta vs 50ms: max {max(d50):.5f}s\n" result += f"{name} sof delta vs 50ms: mean {d50.mean():.5f}s\n" @@ -352,8 +351,8 @@ class TestOnroad(unittest.TestCase): cfgs = [("longitudinalPlan", 0.05, 0.05),] for (s, instant_max, avg_max) in cfgs: ts = [getattr(m, s).solverExecutionTime for m in self.service_msgs[s]] - self.assertLess(max(ts), instant_max, f"high '{s}' execution time: {max(ts)}") - self.assertLess(np.mean(ts), avg_max, f"high avg '{s}' execution time: {np.mean(ts)}") + assert max(ts) < instant_max, f"high '{s}' execution time: {max(ts)}" + assert np.mean(ts) < avg_max, f"high avg '{s}' execution time: {np.mean(ts)}" result += f"'{s}' execution time: min {min(ts):.5f}s\n" result += f"'{s}' execution time: max {max(ts):.5f}s\n" result += f"'{s}' execution time: mean {np.mean(ts):.5f}s\n" @@ -372,8 +371,8 @@ class TestOnroad(unittest.TestCase): ] for (s, instant_max, avg_max) in cfgs: ts = [getattr(m, s).modelExecutionTime for m in self.service_msgs[s]] - self.assertLess(max(ts), instant_max, f"high '{s}' execution time: {max(ts)}") - self.assertLess(np.mean(ts), avg_max, f"high avg '{s}' execution time: {np.mean(ts)}") + assert max(ts) < instant_max, f"high '{s}' execution time: {max(ts)}" + assert np.mean(ts) < avg_max, f"high avg '{s}' execution time: {np.mean(ts)}" result += f"'{s}' execution time: min {min(ts):.5f}s\n" result += f"'{s}' execution time: max {max(ts):.5f}s\n" result += f"'{s}' execution time: mean {np.mean(ts):.5f}s\n" @@ -409,7 +408,7 @@ class TestOnroad(unittest.TestCase): result += f"{''.ljust(40)} {np.max(np.absolute([np.max(ts)/dt, np.min(ts)/dt]))} {np.std(ts)/dt}\n" result += "="*67 print(result) - self.assertTrue(passed) + assert passed @release_only def test_startup(self): @@ -420,7 +419,7 @@ class TestOnroad(unittest.TestCase): startup_alert = msg.controlsState.alertText1 break expected = EVENTS[car.CarEvent.EventName.startup][ET.PERMANENT].alert_text_1 - self.assertEqual(startup_alert, expected, "wrong startup alert") + assert startup_alert == expected, "wrong startup alert" def test_engagable(self): no_entries = Counter() @@ -432,7 +431,3 @@ class TestOnroad(unittest.TestCase): eng = [m.controlsState.engageable for m in self.service_msgs['controlsState']] assert all(eng), \ f"Not engageable for whole segment:\n- controlsState.engageable: {Counter(eng)}\n- No entry events: {no_entries}" - - -if __name__ == "__main__": - unittest.main() diff --git a/selfdrive/test/test_updated.py b/selfdrive/test/test_updated.py index dd79e03de4..5220cfb288 100755 --- a/selfdrive/test/test_updated.py +++ b/selfdrive/test/test_updated.py @@ -4,7 +4,6 @@ import os import pytest import time import tempfile -import unittest import shutil import signal import subprocess @@ -15,9 +14,9 @@ from openpilot.common.params import Params @pytest.mark.tici -class TestUpdated(unittest.TestCase): +class TestUpdated: - def setUp(self): + def setup_method(self): self.updated_proc = None self.tmp_dir = tempfile.TemporaryDirectory() @@ -59,7 +58,7 @@ class TestUpdated(unittest.TestCase): self.params.clear_all() os.sync() - def tearDown(self): + def teardown_method(self): try: if self.updated_proc is not None: self.updated_proc.terminate() @@ -167,7 +166,7 @@ class TestUpdated(unittest.TestCase): t = self._read_param("LastUpdateTime") last_update_time = datetime.datetime.fromisoformat(t) td = datetime.datetime.utcnow() - last_update_time - self.assertLess(td.total_seconds(), 10) + assert td.total_seconds() < 10 self.params.remove("LastUpdateTime") # wait a bit for the rest of the params to be written @@ -175,13 +174,13 @@ class TestUpdated(unittest.TestCase): # check params update = self._read_param("UpdateAvailable") - self.assertEqual(update == "1", update_available, f"UpdateAvailable: {repr(update)}") - self.assertEqual(self._read_param("UpdateFailedCount"), "0") + assert update == "1" == update_available, f"UpdateAvailable: {repr(update)}" + assert self._read_param("UpdateFailedCount") == "0" # TODO: check that the finalized update actually matches remote # check the .overlay_init and .overlay_consistent flags - self.assertTrue(os.path.isfile(os.path.join(self.basedir, ".overlay_init"))) - self.assertEqual(os.path.isfile(os.path.join(self.finalized_dir, ".overlay_consistent")), update_available) + assert os.path.isfile(os.path.join(self.basedir, ".overlay_init")) + assert os.path.isfile(os.path.join(self.finalized_dir, ".overlay_consistent")) == update_available # *** test cases *** @@ -214,7 +213,7 @@ class TestUpdated(unittest.TestCase): self._check_update_state(True) # Let the updater run for 10 cycles, and write an update every cycle - @unittest.skip("need to make this faster") + @pytest.mark.skip("need to make this faster") def test_update_loop(self): self._start_updater() @@ -243,12 +242,12 @@ class TestUpdated(unittest.TestCase): # run another cycle, should have a new mtime self._wait_for_update(clear_param=True) second_mtime = os.path.getmtime(overlay_init_fn) - self.assertTrue(first_mtime != second_mtime) + assert first_mtime != second_mtime # run another cycle, mtime should be same as last cycle self._wait_for_update(clear_param=True) new_mtime = os.path.getmtime(overlay_init_fn) - self.assertTrue(second_mtime == new_mtime) + assert second_mtime == new_mtime # Make sure updated exits if another instance is running def test_multiple_instances(self): @@ -260,7 +259,7 @@ class TestUpdated(unittest.TestCase): # start another instance second_updated = self._get_updated_proc() ret_code = second_updated.wait(timeout=5) - self.assertTrue(ret_code is not None) + assert ret_code is not None # *** test cases with NEOS updates *** @@ -277,10 +276,10 @@ class TestUpdated(unittest.TestCase): self._start_updater() self._wait_for_update(clear_param=True) self._check_update_state(False) - self.assertFalse(os.path.isdir(self.neosupdate_dir)) + assert not os.path.isdir(self.neosupdate_dir) # Let the updater run with no update for a cycle, then write an update - @unittest.skip("TODO: only runs on device") + @pytest.mark.skip("TODO: only runs on device") def test_update_with_neos_update(self): # bump the NEOS version and commit it self._run([ @@ -295,8 +294,4 @@ class TestUpdated(unittest.TestCase): self._check_update_state(True) # TODO: more comprehensive check - self.assertTrue(os.path.isdir(self.neosupdate_dir)) - - -if __name__ == "__main__": - unittest.main() + assert os.path.isdir(self.neosupdate_dir) diff --git a/selfdrive/test/test_valgrind_replay.py b/selfdrive/test/test_valgrind_replay.py index 75520df91b..dffd60df97 100755 --- a/selfdrive/test/test_valgrind_replay.py +++ b/selfdrive/test/test_valgrind_replay.py @@ -2,7 +2,6 @@ import os import threading import time -import unittest import subprocess import signal @@ -35,7 +34,7 @@ CONFIGS = [ ] -class TestValgrind(unittest.TestCase): +class TestValgrind: def extract_leak_sizes(self, log): if "All heap blocks were freed -- no leaks are possible" in log: return (0,0,0) @@ -110,8 +109,4 @@ class TestValgrind(unittest.TestCase): while self.leak is None: time.sleep(0.1) # Wait for the valgrind to finish - self.assertFalse(self.leak) - - -if __name__ == "__main__": - unittest.main() + assert not self.leak diff --git a/selfdrive/thermald/tests/test_fan_controller.py b/selfdrive/thermald/tests/test_fan_controller.py index 7081e1353e..c2b1a64509 100755 --- a/selfdrive/thermald/tests/test_fan_controller.py +++ b/selfdrive/thermald/tests/test_fan_controller.py @@ -1,17 +1,15 @@ #!/usr/bin/env python3 -import unittest -from unittest.mock import Mock, patch -from parameterized import parameterized +import pytest from openpilot.selfdrive.thermald.fan_controller import TiciFanController -ALL_CONTROLLERS = [(TiciFanController,)] +ALL_CONTROLLERS = [TiciFanController] -def patched_controller(controller_class): - with patch("os.system", new=Mock()): - return controller_class() +def patched_controller(mocker, controller_class): + mocker.patch("os.system", new=mocker.Mock()) + return controller_class() -class TestFanController(unittest.TestCase): +class TestFanController: def wind_up(self, controller, ignition=True): for _ in range(1000): controller.update(100, ignition) @@ -20,37 +18,34 @@ class TestFanController(unittest.TestCase): for _ in range(1000): controller.update(10, ignition) - @parameterized.expand(ALL_CONTROLLERS) - def test_hot_onroad(self, controller_class): - controller = patched_controller(controller_class) + @pytest.mark.parametrize("controller_class", ALL_CONTROLLERS) + def test_hot_onroad(self, mocker, controller_class): + controller = patched_controller(mocker, controller_class) self.wind_up(controller) - self.assertGreaterEqual(controller.update(100, True), 70) + assert controller.update(100, True) >= 70 - @parameterized.expand(ALL_CONTROLLERS) - def test_offroad_limits(self, controller_class): - controller = patched_controller(controller_class) + @pytest.mark.parametrize("controller_class", ALL_CONTROLLERS) + def test_offroad_limits(self, mocker, controller_class): + controller = patched_controller(mocker, controller_class) self.wind_up(controller) - self.assertLessEqual(controller.update(100, False), 30) + assert controller.update(100, False) <= 30 - @parameterized.expand(ALL_CONTROLLERS) - def test_no_fan_wear(self, controller_class): - controller = patched_controller(controller_class) + @pytest.mark.parametrize("controller_class", ALL_CONTROLLERS) + def test_no_fan_wear(self, mocker, controller_class): + controller = patched_controller(mocker, controller_class) self.wind_down(controller) - self.assertEqual(controller.update(10, False), 0) + assert controller.update(10, False) == 0 - @parameterized.expand(ALL_CONTROLLERS) - def test_limited(self, controller_class): - controller = patched_controller(controller_class) + @pytest.mark.parametrize("controller_class", ALL_CONTROLLERS) + def test_limited(self, mocker, controller_class): + controller = patched_controller(mocker, controller_class) self.wind_up(controller, True) - self.assertEqual(controller.update(100, True), 100) + assert controller.update(100, True) == 100 - @parameterized.expand(ALL_CONTROLLERS) - def test_windup_speed(self, controller_class): - controller = patched_controller(controller_class) + @pytest.mark.parametrize("controller_class", ALL_CONTROLLERS) + def test_windup_speed(self, mocker, controller_class): + controller = patched_controller(mocker, controller_class) self.wind_down(controller, True) for _ in range(10): controller.update(90, True) - self.assertGreaterEqual(controller.update(90, True), 60) - -if __name__ == "__main__": - unittest.main() + assert controller.update(90, True) >= 60 diff --git a/selfdrive/thermald/tests/test_power_monitoring.py b/selfdrive/thermald/tests/test_power_monitoring.py index c3a890f068..f68191475b 100755 --- a/selfdrive/thermald/tests/test_power_monitoring.py +++ b/selfdrive/thermald/tests/test_power_monitoring.py @@ -1,12 +1,10 @@ #!/usr/bin/env python3 -import unittest -from unittest.mock import patch +import pytest from openpilot.common.params import Params from openpilot.selfdrive.thermald.power_monitoring import PowerMonitoring, CAR_BATTERY_CAPACITY_uWh, \ CAR_CHARGING_RATE_W, VBATT_PAUSE_CHARGING, DELAY_SHUTDOWN_TIME_S - # Create fake time ssb = 0. def mock_time_monotonic(): @@ -18,163 +16,169 @@ TEST_DURATION_S = 50 GOOD_VOLTAGE = 12 * 1e3 VOLTAGE_BELOW_PAUSE_CHARGING = (VBATT_PAUSE_CHARGING - 1) * 1e3 -def pm_patch(name, value, constant=False): +def pm_patch(mocker, name, value, constant=False): if constant: - return patch(f"openpilot.selfdrive.thermald.power_monitoring.{name}", value) - return patch(f"openpilot.selfdrive.thermald.power_monitoring.{name}", return_value=value) + mocker.patch(f"openpilot.selfdrive.thermald.power_monitoring.{name}", value) + else: + mocker.patch(f"openpilot.selfdrive.thermald.power_monitoring.{name}", return_value=value) + + +@pytest.fixture(autouse=True) +def mock_time(mocker): + mocker.patch("time.monotonic", mock_time_monotonic) -@patch("time.monotonic", new=mock_time_monotonic) -class TestPowerMonitoring(unittest.TestCase): - def setUp(self): +class TestPowerMonitoring: + def setup_method(self): self.params = Params() # Test to see that it doesn't do anything when pandaState is None - def test_pandaState_present(self): + def test_panda_state_present(self): pm = PowerMonitoring() for _ in range(10): pm.calculate(None, None) - self.assertEqual(pm.get_power_used(), 0) - self.assertEqual(pm.get_car_battery_capacity(), (CAR_BATTERY_CAPACITY_uWh / 10)) + assert pm.get_power_used() == 0 + assert pm.get_car_battery_capacity() == (CAR_BATTERY_CAPACITY_uWh / 10) # Test to see that it doesn't integrate offroad when ignition is True def test_offroad_ignition(self): pm = PowerMonitoring() for _ in range(10): pm.calculate(GOOD_VOLTAGE, True) - self.assertEqual(pm.get_power_used(), 0) + assert pm.get_power_used() == 0 # Test to see that it integrates with discharging battery - def test_offroad_integration_discharging(self): + def test_offroad_integration_discharging(self, mocker): POWER_DRAW = 4 - with pm_patch("HARDWARE.get_current_power_draw", POWER_DRAW): - pm = PowerMonitoring() - for _ in range(TEST_DURATION_S + 1): - pm.calculate(GOOD_VOLTAGE, False) - expected_power_usage = ((TEST_DURATION_S/3600) * POWER_DRAW * 1e6) - self.assertLess(abs(pm.get_power_used() - expected_power_usage), 10) + pm_patch(mocker, "HARDWARE.get_current_power_draw", POWER_DRAW) + pm = PowerMonitoring() + for _ in range(TEST_DURATION_S + 1): + pm.calculate(GOOD_VOLTAGE, False) + expected_power_usage = ((TEST_DURATION_S/3600) * POWER_DRAW * 1e6) + assert abs(pm.get_power_used() - expected_power_usage) < 10 # Test to check positive integration of car_battery_capacity - def test_car_battery_integration_onroad(self): + def test_car_battery_integration_onroad(self, mocker): POWER_DRAW = 4 - with pm_patch("HARDWARE.get_current_power_draw", POWER_DRAW): - pm = PowerMonitoring() - pm.car_battery_capacity_uWh = 0 - for _ in range(TEST_DURATION_S + 1): - pm.calculate(GOOD_VOLTAGE, True) - expected_capacity = ((TEST_DURATION_S/3600) * CAR_CHARGING_RATE_W * 1e6) - self.assertLess(abs(pm.get_car_battery_capacity() - expected_capacity), 10) + pm_patch(mocker, "HARDWARE.get_current_power_draw", POWER_DRAW) + pm = PowerMonitoring() + pm.car_battery_capacity_uWh = 0 + for _ in range(TEST_DURATION_S + 1): + pm.calculate(GOOD_VOLTAGE, True) + expected_capacity = ((TEST_DURATION_S/3600) * CAR_CHARGING_RATE_W * 1e6) + assert abs(pm.get_car_battery_capacity() - expected_capacity) < 10 # Test to check positive integration upper limit - def test_car_battery_integration_upper_limit(self): + def test_car_battery_integration_upper_limit(self, mocker): POWER_DRAW = 4 - with pm_patch("HARDWARE.get_current_power_draw", POWER_DRAW): - pm = PowerMonitoring() - pm.car_battery_capacity_uWh = CAR_BATTERY_CAPACITY_uWh - 1000 - for _ in range(TEST_DURATION_S + 1): - pm.calculate(GOOD_VOLTAGE, True) - estimated_capacity = CAR_BATTERY_CAPACITY_uWh + (CAR_CHARGING_RATE_W / 3600 * 1e6) - self.assertLess(abs(pm.get_car_battery_capacity() - estimated_capacity), 10) + pm_patch(mocker, "HARDWARE.get_current_power_draw", POWER_DRAW) + pm = PowerMonitoring() + pm.car_battery_capacity_uWh = CAR_BATTERY_CAPACITY_uWh - 1000 + for _ in range(TEST_DURATION_S + 1): + pm.calculate(GOOD_VOLTAGE, True) + estimated_capacity = CAR_BATTERY_CAPACITY_uWh + (CAR_CHARGING_RATE_W / 3600 * 1e6) + assert abs(pm.get_car_battery_capacity() - estimated_capacity) < 10 # Test to check negative integration of car_battery_capacity - def test_car_battery_integration_offroad(self): + def test_car_battery_integration_offroad(self, mocker): POWER_DRAW = 4 - with pm_patch("HARDWARE.get_current_power_draw", POWER_DRAW): - pm = PowerMonitoring() - pm.car_battery_capacity_uWh = CAR_BATTERY_CAPACITY_uWh - for _ in range(TEST_DURATION_S + 1): - pm.calculate(GOOD_VOLTAGE, False) - expected_capacity = CAR_BATTERY_CAPACITY_uWh - ((TEST_DURATION_S/3600) * POWER_DRAW * 1e6) - self.assertLess(abs(pm.get_car_battery_capacity() - expected_capacity), 10) + pm_patch(mocker, "HARDWARE.get_current_power_draw", POWER_DRAW) + pm = PowerMonitoring() + pm.car_battery_capacity_uWh = CAR_BATTERY_CAPACITY_uWh + for _ in range(TEST_DURATION_S + 1): + pm.calculate(GOOD_VOLTAGE, False) + expected_capacity = CAR_BATTERY_CAPACITY_uWh - ((TEST_DURATION_S/3600) * POWER_DRAW * 1e6) + assert abs(pm.get_car_battery_capacity() - expected_capacity) < 10 # Test to check negative integration lower limit - def test_car_battery_integration_lower_limit(self): + def test_car_battery_integration_lower_limit(self, mocker): POWER_DRAW = 4 - with pm_patch("HARDWARE.get_current_power_draw", POWER_DRAW): - pm = PowerMonitoring() - pm.car_battery_capacity_uWh = 1000 - for _ in range(TEST_DURATION_S + 1): - pm.calculate(GOOD_VOLTAGE, False) - estimated_capacity = 0 - ((1/3600) * POWER_DRAW * 1e6) - self.assertLess(abs(pm.get_car_battery_capacity() - estimated_capacity), 10) + pm_patch(mocker, "HARDWARE.get_current_power_draw", POWER_DRAW) + pm = PowerMonitoring() + pm.car_battery_capacity_uWh = 1000 + for _ in range(TEST_DURATION_S + 1): + pm.calculate(GOOD_VOLTAGE, False) + estimated_capacity = 0 - ((1/3600) * POWER_DRAW * 1e6) + assert abs(pm.get_car_battery_capacity() - estimated_capacity) < 10 # Test to check policy of stopping charging after MAX_TIME_OFFROAD_S - def test_max_time_offroad(self): + def test_max_time_offroad(self, mocker): MOCKED_MAX_OFFROAD_TIME = 3600 POWER_DRAW = 0 # To stop shutting down for other reasons - with pm_patch("MAX_TIME_OFFROAD_S", MOCKED_MAX_OFFROAD_TIME, constant=True), pm_patch("HARDWARE.get_current_power_draw", POWER_DRAW): - pm = PowerMonitoring() - pm.car_battery_capacity_uWh = CAR_BATTERY_CAPACITY_uWh - start_time = ssb - ignition = False - while ssb <= start_time + MOCKED_MAX_OFFROAD_TIME: - pm.calculate(GOOD_VOLTAGE, ignition) - if (ssb - start_time) % 1000 == 0 and ssb < start_time + MOCKED_MAX_OFFROAD_TIME: - self.assertFalse(pm.should_shutdown(ignition, True, start_time, False)) - self.assertTrue(pm.should_shutdown(ignition, True, start_time, False)) - - def test_car_voltage(self): + pm_patch(mocker, "MAX_TIME_OFFROAD_S", MOCKED_MAX_OFFROAD_TIME, constant=True) + pm_patch(mocker, "HARDWARE.get_current_power_draw", POWER_DRAW) + pm = PowerMonitoring() + pm.car_battery_capacity_uWh = CAR_BATTERY_CAPACITY_uWh + start_time = ssb + ignition = False + while ssb <= start_time + MOCKED_MAX_OFFROAD_TIME: + pm.calculate(GOOD_VOLTAGE, ignition) + if (ssb - start_time) % 1000 == 0 and ssb < start_time + MOCKED_MAX_OFFROAD_TIME: + assert not pm.should_shutdown(ignition, True, start_time, False) + assert pm.should_shutdown(ignition, True, start_time, False) + + def test_car_voltage(self, mocker): POWER_DRAW = 0 # To stop shutting down for other reasons TEST_TIME = 350 VOLTAGE_SHUTDOWN_MIN_OFFROAD_TIME_S = 50 - with pm_patch("VOLTAGE_SHUTDOWN_MIN_OFFROAD_TIME_S", VOLTAGE_SHUTDOWN_MIN_OFFROAD_TIME_S, constant=True), \ - pm_patch("HARDWARE.get_current_power_draw", POWER_DRAW): - pm = PowerMonitoring() - pm.car_battery_capacity_uWh = CAR_BATTERY_CAPACITY_uWh - ignition = False - start_time = ssb - for i in range(TEST_TIME): - pm.calculate(VOLTAGE_BELOW_PAUSE_CHARGING, ignition) - if i % 10 == 0: - self.assertEqual(pm.should_shutdown(ignition, True, start_time, True), - (pm.car_voltage_mV < VBATT_PAUSE_CHARGING * 1e3 and - (ssb - start_time) > VOLTAGE_SHUTDOWN_MIN_OFFROAD_TIME_S and - (ssb - start_time) > DELAY_SHUTDOWN_TIME_S)) - self.assertTrue(pm.should_shutdown(ignition, True, start_time, True)) + pm_patch(mocker, "VOLTAGE_SHUTDOWN_MIN_OFFROAD_TIME_S", VOLTAGE_SHUTDOWN_MIN_OFFROAD_TIME_S, constant=True) + pm_patch(mocker, "HARDWARE.get_current_power_draw", POWER_DRAW) + pm = PowerMonitoring() + pm.car_battery_capacity_uWh = CAR_BATTERY_CAPACITY_uWh + ignition = False + start_time = ssb + for i in range(TEST_TIME): + pm.calculate(VOLTAGE_BELOW_PAUSE_CHARGING, ignition) + if i % 10 == 0: + assert pm.should_shutdown(ignition, True, start_time, True) == \ + (pm.car_voltage_mV < VBATT_PAUSE_CHARGING * 1e3 and \ + (ssb - start_time) > VOLTAGE_SHUTDOWN_MIN_OFFROAD_TIME_S and \ + (ssb - start_time) > DELAY_SHUTDOWN_TIME_S) + assert pm.should_shutdown(ignition, True, start_time, True) # Test to check policy of not stopping charging when DisablePowerDown is set - def test_disable_power_down(self): + def test_disable_power_down(self, mocker): POWER_DRAW = 0 # To stop shutting down for other reasons TEST_TIME = 100 self.params.put_bool("DisablePowerDown", True) - with pm_patch("HARDWARE.get_current_power_draw", POWER_DRAW): - pm = PowerMonitoring() - pm.car_battery_capacity_uWh = CAR_BATTERY_CAPACITY_uWh - ignition = False - for i in range(TEST_TIME): - pm.calculate(VOLTAGE_BELOW_PAUSE_CHARGING, ignition) - if i % 10 == 0: - self.assertFalse(pm.should_shutdown(ignition, True, ssb, False)) - self.assertFalse(pm.should_shutdown(ignition, True, ssb, False)) + pm_patch(mocker, "HARDWARE.get_current_power_draw", POWER_DRAW) + pm = PowerMonitoring() + pm.car_battery_capacity_uWh = CAR_BATTERY_CAPACITY_uWh + ignition = False + for i in range(TEST_TIME): + pm.calculate(VOLTAGE_BELOW_PAUSE_CHARGING, ignition) + if i % 10 == 0: + assert not pm.should_shutdown(ignition, True, ssb, False) + assert not pm.should_shutdown(ignition, True, ssb, False) # Test to check policy of not stopping charging when ignition - def test_ignition(self): + def test_ignition(self, mocker): POWER_DRAW = 0 # To stop shutting down for other reasons TEST_TIME = 100 - with pm_patch("HARDWARE.get_current_power_draw", POWER_DRAW): - pm = PowerMonitoring() - pm.car_battery_capacity_uWh = CAR_BATTERY_CAPACITY_uWh - ignition = True - for i in range(TEST_TIME): - pm.calculate(VOLTAGE_BELOW_PAUSE_CHARGING, ignition) - if i % 10 == 0: - self.assertFalse(pm.should_shutdown(ignition, True, ssb, False)) - self.assertFalse(pm.should_shutdown(ignition, True, ssb, False)) + pm_patch(mocker, "HARDWARE.get_current_power_draw", POWER_DRAW) + pm = PowerMonitoring() + pm.car_battery_capacity_uWh = CAR_BATTERY_CAPACITY_uWh + ignition = True + for i in range(TEST_TIME): + pm.calculate(VOLTAGE_BELOW_PAUSE_CHARGING, ignition) + if i % 10 == 0: + assert not pm.should_shutdown(ignition, True, ssb, False) + assert not pm.should_shutdown(ignition, True, ssb, False) # Test to check policy of not stopping charging when harness is not connected - def test_harness_connection(self): + def test_harness_connection(self, mocker): POWER_DRAW = 0 # To stop shutting down for other reasons TEST_TIME = 100 - with pm_patch("HARDWARE.get_current_power_draw", POWER_DRAW): - pm = PowerMonitoring() - pm.car_battery_capacity_uWh = CAR_BATTERY_CAPACITY_uWh + pm_patch(mocker, "HARDWARE.get_current_power_draw", POWER_DRAW) + pm = PowerMonitoring() + pm.car_battery_capacity_uWh = CAR_BATTERY_CAPACITY_uWh - ignition = False - for i in range(TEST_TIME): - pm.calculate(VOLTAGE_BELOW_PAUSE_CHARGING, ignition) - if i % 10 == 0: - self.assertFalse(pm.should_shutdown(ignition, False, ssb, False)) - self.assertFalse(pm.should_shutdown(ignition, False, ssb, False)) + ignition = False + for i in range(TEST_TIME): + pm.calculate(VOLTAGE_BELOW_PAUSE_CHARGING, ignition) + if i % 10 == 0: + assert not pm.should_shutdown(ignition, False, ssb, False) + assert not pm.should_shutdown(ignition, False, ssb, False) def test_delay_shutdown_time(self): pm = PowerMonitoring() @@ -186,15 +190,11 @@ class TestPowerMonitoring(unittest.TestCase): pm.calculate(VOLTAGE_BELOW_PAUSE_CHARGING, ignition) while ssb < offroad_timestamp + DELAY_SHUTDOWN_TIME_S: - self.assertFalse(pm.should_shutdown(ignition, in_car, + assert not pm.should_shutdown(ignition, in_car, offroad_timestamp, - started_seen), - f"Should not shutdown before {DELAY_SHUTDOWN_TIME_S} seconds offroad time") - self.assertTrue(pm.should_shutdown(ignition, in_car, + started_seen), \ + f"Should not shutdown before {DELAY_SHUTDOWN_TIME_S} seconds offroad time" + assert pm.should_shutdown(ignition, in_car, offroad_timestamp, - started_seen), - f"Should shutdown after {DELAY_SHUTDOWN_TIME_S} seconds offroad time") - - -if __name__ == "__main__": - unittest.main() + started_seen), \ + f"Should shutdown after {DELAY_SHUTDOWN_TIME_S} seconds offroad time" diff --git a/selfdrive/ui/tests/test_soundd.py b/selfdrive/ui/tests/test_soundd.py index 94ce26eb47..d15a6c1831 100755 --- a/selfdrive/ui/tests/test_soundd.py +++ b/selfdrive/ui/tests/test_soundd.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -import unittest from cereal import car from cereal import messaging @@ -11,7 +10,7 @@ import time AudibleAlert = car.CarControl.HUDControl.AudibleAlert -class TestSoundd(unittest.TestCase): +class TestSoundd: def test_check_controls_timeout_alert(self): sm = SubMaster(['controlsState']) pm = PubMaster(['controlsState']) @@ -26,16 +25,13 @@ class TestSoundd(unittest.TestCase): sm.update(0) - self.assertFalse(check_controls_timeout_alert(sm)) + assert not check_controls_timeout_alert(sm) for _ in range(CONTROLS_TIMEOUT * 110): sm.update(0) time.sleep(0.01) - self.assertTrue(check_controls_timeout_alert(sm)) + assert check_controls_timeout_alert(sm) # TODO: add test with micd for checking that soundd actually outputs sounds - -if __name__ == "__main__": - unittest.main() diff --git a/selfdrive/ui/tests/test_translations.py b/selfdrive/ui/tests/test_translations.py index 8e50695e70..57de069d0b 100755 --- a/selfdrive/ui/tests/test_translations.py +++ b/selfdrive/ui/tests/test_translations.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 +import pytest import json import os import re -import unittest import shutil import tempfile import xml.etree.ElementTree as ET @@ -21,7 +21,7 @@ FORMAT_ARG = re.compile("%[0-9]+") @parameterized_class(("name", "file"), translation_files.items()) -class TestTranslations(unittest.TestCase): +class TestTranslations: name: str file: str @@ -32,8 +32,8 @@ class TestTranslations(unittest.TestCase): return f.read() def test_missing_translation_files(self): - self.assertTrue(os.path.exists(os.path.join(TRANSLATIONS_DIR, f"{self.file}.ts")), - f"{self.name} has no XML translation file, run selfdrive/ui/update_translations.py") + assert os.path.exists(os.path.join(TRANSLATIONS_DIR, f"{self.file}.ts")), \ + f"{self.name} has no XML translation file, run selfdrive/ui/update_translations.py" def test_translations_updated(self): with tempfile.TemporaryDirectory() as tmpdir: @@ -42,19 +42,19 @@ class TestTranslations(unittest.TestCase): cur_translations = self._read_translation_file(TRANSLATIONS_DIR, self.file) new_translations = self._read_translation_file(tmpdir, self.file) - self.assertEqual(cur_translations, new_translations, - f"{self.file} ({self.name}) XML translation file out of date. Run selfdrive/ui/update_translations.py to update the translation files") + assert cur_translations == new_translations, \ + f"{self.file} ({self.name}) XML translation file out of date. Run selfdrive/ui/update_translations.py to update the translation files" - @unittest.skip("Only test unfinished translations before going to release") + @pytest.mark.skip("Only test unfinished translations before going to release") def test_unfinished_translations(self): cur_translations = self._read_translation_file(TRANSLATIONS_DIR, self.file) - self.assertTrue(UNFINISHED_TRANSLATION_TAG not in cur_translations, - f"{self.file} ({self.name}) translation file has unfinished translations. Finish translations or mark them as completed in Qt Linguist") + assert UNFINISHED_TRANSLATION_TAG not in cur_translations, \ + f"{self.file} ({self.name}) translation file has unfinished translations. Finish translations or mark them as completed in Qt Linguist" def test_vanished_translations(self): cur_translations = self._read_translation_file(TRANSLATIONS_DIR, self.file) - self.assertTrue("" not in cur_translations, - f"{self.file} ({self.name}) translation file has obsolete translations. Run selfdrive/ui/update_translations.py --vanish to remove them") + assert "" not in cur_translations, \ + f"{self.file} ({self.name}) translation file has obsolete translations. Run selfdrive/ui/update_translations.py --vanish to remove them" def test_finished_translations(self): """ @@ -81,27 +81,27 @@ class TestTranslations(unittest.TestCase): numerusform = [t.text for t in translation.findall("numerusform")] for nf in numerusform: - self.assertIsNotNone(nf, f"Ensure all plural translation forms are completed: {source_text}") - self.assertIn("%n", nf, "Ensure numerus argument (%n) exists in translation.") - self.assertIsNone(FORMAT_ARG.search(nf), f"Plural translations must use %n, not %1, %2, etc.: {numerusform}") + assert nf is not None, f"Ensure all plural translation forms are completed: {source_text}" + assert "%n" in nf, "Ensure numerus argument (%n) exists in translation." + assert FORMAT_ARG.search(nf) is None, f"Plural translations must use %n, not %1, %2, etc.: {numerusform}" else: - self.assertIsNotNone(translation.text, f"Ensure translation is completed: {source_text}") + assert translation.text is not None, f"Ensure translation is completed: {source_text}" source_args = FORMAT_ARG.findall(source_text) translation_args = FORMAT_ARG.findall(translation.text) - self.assertEqual(sorted(source_args), sorted(translation_args), - f"Ensure format arguments are consistent: `{source_text}` vs. `{translation.text}`") + assert sorted(source_args) == sorted(translation_args), \ + f"Ensure format arguments are consistent: `{source_text}` vs. `{translation.text}`" def test_no_locations(self): for line in self._read_translation_file(TRANSLATIONS_DIR, self.file).splitlines(): - self.assertFalse(line.strip().startswith(LOCATION_TAG), - f"Line contains location tag: {line.strip()}, remove all line numbers.") + assert not line.strip().startswith(LOCATION_TAG), \ + f"Line contains location tag: {line.strip()}, remove all line numbers." def test_entities_error(self): cur_translations = self._read_translation_file(TRANSLATIONS_DIR, self.file) matches = re.findall(r'@(\w+);', cur_translations) - self.assertEqual(len(matches), 0, f"The string(s) {matches} were found with '@' instead of '&'") + assert len(matches) == 0, f"The string(s) {matches} were found with '@' instead of '&'" def test_bad_language(self): IGNORED_WORDS = {'pédale'} @@ -128,7 +128,3 @@ class TestTranslations(unittest.TestCase): words = set(translation_text.translate(str.maketrans('', '', string.punctuation + '%n')).lower().split()) bad_words_found = words & (banned_words - IGNORED_WORDS) assert not bad_words_found, f"Bad language found in {self.name}: '{translation_text}'. Bad word(s): {', '.join(bad_words_found)}" - - -if __name__ == "__main__": - unittest.main() diff --git a/selfdrive/ui/tests/test_ui/run.py b/selfdrive/ui/tests/test_ui/run.py index c834107780..79f30e79bf 100644 --- a/selfdrive/ui/tests/test_ui/run.py +++ b/selfdrive/ui/tests/test_ui/run.py @@ -8,7 +8,7 @@ import numpy as np import os import pywinctl import time -import unittest +import unittest # noqa: TID251 from parameterized import parameterized from cereal import messaging, car, log diff --git a/selfdrive/updated/tests/test_base.py b/selfdrive/updated/tests/test_base.py index b59f03fe77..615d0de99c 100644 --- a/selfdrive/updated/tests/test_base.py +++ b/selfdrive/updated/tests/test_base.py @@ -6,9 +6,6 @@ import stat import subprocess import tempfile import time -import unittest -from unittest import mock - import pytest from openpilot.common.params import Params @@ -52,13 +49,13 @@ def get_version(path: str) -> str: @pytest.mark.slow # TODO: can we test overlayfs in GHA? -class BaseUpdateTest(unittest.TestCase): +class TestBaseUpdate: @classmethod - def setUpClass(cls): + def setup_class(cls): if "Base" in cls.__name__: - raise unittest.SkipTest + pytest.skip() - def setUp(self): + def setup_method(self): self.tmpdir = tempfile.mkdtemp() run(["sudo", "mount", "-t", "tmpfs", "tmpfs", self.tmpdir]) # overlayfs doesn't work inside of docker unless this is a tmpfs @@ -76,8 +73,6 @@ class BaseUpdateTest(unittest.TestCase): self.remote_dir = self.mock_update_path / "remote" self.remote_dir.mkdir() - mock.patch("openpilot.common.basedir.BASEDIR", self.basedir).start() - os.environ["UPDATER_STAGING_ROOT"] = str(self.staging_root) os.environ["UPDATER_LOCK_FILE"] = str(self.mock_update_path / "safe_staging_overlay.lock") @@ -86,6 +81,10 @@ class BaseUpdateTest(unittest.TestCase): "master": ("0.1.3", "1.2", "0.1.3 release notes"), } + @pytest.fixture(autouse=True) + def mock_basedir(self, mocker): + mocker.patch("openpilot.common.basedir.BASEDIR", self.basedir) + def set_target_branch(self, branch): self.params.put("UpdaterTargetBranch", branch) @@ -102,8 +101,7 @@ class BaseUpdateTest(unittest.TestCase): def additional_context(self): raise NotImplementedError("") - def tearDown(self): - mock.patch.stopall() + def teardown_method(self): try: run(["sudo", "umount", "-l", str(self.staging_root / "merged")]) run(["sudo", "umount", "-l", self.tmpdir]) @@ -125,17 +123,17 @@ class BaseUpdateTest(unittest.TestCase): time.sleep(1) def _test_finalized_update(self, branch, version, agnos_version, release_notes): - self.assertEqual(get_version(str(self.staging_root / "finalized")), version) - self.assertEqual(get_consistent_flag(str(self.staging_root / "finalized")), True) - self.assertTrue(os.access(str(self.staging_root / "finalized" / "launch_env.sh"), os.X_OK)) + assert get_version(str(self.staging_root / "finalized")) == version + assert get_consistent_flag(str(self.staging_root / "finalized")) + assert os.access(str(self.staging_root / "finalized" / "launch_env.sh"), os.X_OK) with open(self.staging_root / "finalized" / "test_symlink") as f: - self.assertIn(version, f.read()) + assert version in f.read() -class ParamsBaseUpdateTest(BaseUpdateTest): +class ParamsBaseUpdateTest(TestBaseUpdate): def _test_finalized_update(self, branch, version, agnos_version, release_notes): - self.assertTrue(self.params.get("UpdaterNewDescription", encoding="utf-8").startswith(f"{version} / {branch}")) - self.assertEqual(self.params.get("UpdaterNewReleaseNotes", encoding="utf-8"), f"

{release_notes}

\n") + assert self.params.get("UpdaterNewDescription", encoding="utf-8").startswith(f"{version} / {branch}") + assert self.params.get("UpdaterNewReleaseNotes", encoding="utf-8") == f"

{release_notes}

\n" super()._test_finalized_update(branch, version, agnos_version, release_notes) def send_check_for_updates_signal(self, updated: ManagerProcess): @@ -145,9 +143,9 @@ class ParamsBaseUpdateTest(BaseUpdateTest): updated.signal(signal.SIGHUP.value) def _test_params(self, branch, fetch_available, update_available): - self.assertEqual(self.params.get("UpdaterTargetBranch", encoding="utf-8"), branch) - self.assertEqual(self.params.get_bool("UpdaterFetchAvailable"), fetch_available) - self.assertEqual(self.params.get_bool("UpdateAvailable"), update_available) + assert self.params.get("UpdaterTargetBranch", encoding="utf-8") == branch + assert self.params.get_bool("UpdaterFetchAvailable") == fetch_available + assert self.params.get_bool("UpdateAvailable") == update_available def wait_for_idle(self): self.wait_for_condition(lambda: self.params.get("UpdaterState", encoding="utf-8") == "idle") @@ -229,17 +227,16 @@ class ParamsBaseUpdateTest(BaseUpdateTest): self._test_params("master", False, True) self._test_finalized_update("master", *self.MOCK_RELEASES["master"]) - def test_agnos_update(self): + def test_agnos_update(self, mocker): # Start on release3, push an update with an agnos change self.setup_remote_release("release3") self.setup_basedir_release("release3") - with self.additional_context(), \ - mock.patch("openpilot.system.hardware.AGNOS", "True"), \ - mock.patch("openpilot.system.hardware.tici.hardware.Tici.get_os_version", "1.2"), \ - mock.patch("openpilot.system.hardware.tici.agnos.get_target_slot_number"), \ - mock.patch("openpilot.system.hardware.tici.agnos.flash_agnos_update"), \ - processes_context(["updated"]) as [updated]: + with self.additional_context(), processes_context(["updated"]) as [updated]: + mocker.patch("openpilot.system.hardware.AGNOS", "True") + mocker.patch("openpilot.system.hardware.tici.hardware.Tici.get_os_version", "1.2") + mocker.patch("openpilot.system.hardware.tici.agnos.get_target_slot_number") + mocker.patch("openpilot.system.hardware.tici.agnos.flash_agnos_update") self._test_params("release3", False, False) self.wait_for_idle() diff --git a/system/camerad/test/test_exposure.py b/system/camerad/test/test_exposure.py index 50467f9db4..36e8522b1d 100755 --- a/system/camerad/test/test_exposure.py +++ b/system/camerad/test/test_exposure.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 import time -import unittest import numpy as np from openpilot.selfdrive.test.helpers import with_processes, phone_only @@ -9,9 +8,9 @@ from openpilot.system.camerad.snapshot.snapshot import get_snapshots TEST_TIME = 45 REPEAT = 5 -class TestCamerad(unittest.TestCase): +class TestCamerad: @classmethod - def setUpClass(cls): + def setup_class(cls): pass def _numpy_rgb2gray(self, im): @@ -49,7 +48,4 @@ class TestCamerad(unittest.TestCase): passed += int(res) time.sleep(2) - self.assertGreaterEqual(passed, REPEAT) - -if __name__ == "__main__": - unittest.main() + assert passed >= REPEAT diff --git a/system/hardware/tici/tests/test_agnos_updater.py b/system/hardware/tici/tests/test_agnos_updater.py index 86bc78881e..462cf6cb5c 100755 --- a/system/hardware/tici/tests/test_agnos_updater.py +++ b/system/hardware/tici/tests/test_agnos_updater.py @@ -1,14 +1,13 @@ #!/usr/bin/env python3 import json import os -import unittest import requests TEST_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__))) MANIFEST = os.path.join(TEST_DIR, "../agnos.json") -class TestAgnosUpdater(unittest.TestCase): +class TestAgnosUpdater: def test_manifest(self): with open(MANIFEST) as f: @@ -17,10 +16,6 @@ class TestAgnosUpdater(unittest.TestCase): for img in m: r = requests.head(img['url'], timeout=10) r.raise_for_status() - self.assertEqual(r.headers['Content-Type'], "application/x-xz") + assert r.headers['Content-Type'] == "application/x-xz" if not img['sparse']: assert img['hash'] == img['hash_raw'] - - -if __name__ == "__main__": - unittest.main() diff --git a/system/hardware/tici/tests/test_amplifier.py b/system/hardware/tici/tests/test_amplifier.py index cd3b0f90fe..dfba84b942 100755 --- a/system/hardware/tici/tests/test_amplifier.py +++ b/system/hardware/tici/tests/test_amplifier.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 +import pytest import time import random -import unittest import subprocess from panda import Panda @@ -10,14 +10,14 @@ from openpilot.system.hardware.tici.hardware import Tici from openpilot.system.hardware.tici.amplifier import Amplifier -class TestAmplifier(unittest.TestCase): +class TestAmplifier: @classmethod - def setUpClass(cls): + def setup_class(cls): if not TICI: - raise unittest.SkipTest + pytest.skip() - def setUp(self): + def setup_method(self): # clear dmesg subprocess.check_call("sudo dmesg -C", shell=True) @@ -25,7 +25,7 @@ class TestAmplifier(unittest.TestCase): Panda.wait_for_panda(None, 30) self.panda = Panda() - def tearDown(self): + def teardown_method(self): HARDWARE.reset_internal_panda() def _check_for_i2c_errors(self, expected): @@ -68,8 +68,4 @@ class TestAmplifier(unittest.TestCase): if self._check_for_i2c_errors(True): break else: - self.fail("didn't hit any i2c errors") - - -if __name__ == "__main__": - unittest.main() + pytest.fail("didn't hit any i2c errors") diff --git a/system/hardware/tici/tests/test_hardware.py b/system/hardware/tici/tests/test_hardware.py index 6c41c383a0..49d4ac7699 100755 --- a/system/hardware/tici/tests/test_hardware.py +++ b/system/hardware/tici/tests/test_hardware.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 import pytest import time -import unittest import numpy as np from openpilot.system.hardware.tici.hardware import Tici @@ -10,7 +9,7 @@ HARDWARE = Tici() @pytest.mark.tici -class TestHardware(unittest.TestCase): +class TestHardware: def test_power_save_time(self): ts = [] @@ -22,7 +21,3 @@ class TestHardware(unittest.TestCase): assert 0.1 < np.mean(ts) < 0.25 assert max(ts) < 0.3 - - -if __name__ == "__main__": - unittest.main() diff --git a/system/hardware/tici/tests/test_power_draw.py b/system/hardware/tici/tests/test_power_draw.py index ba7e0a6d9d..104329da42 100755 --- a/system/hardware/tici/tests/test_power_draw.py +++ b/system/hardware/tici/tests/test_power_draw.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 from collections import defaultdict, deque import pytest -import unittest import time import numpy as np from dataclasses import dataclass @@ -40,15 +39,15 @@ PROCS = [ @pytest.mark.tici -class TestPowerDraw(unittest.TestCase): +class TestPowerDraw: - def setUp(self): + def setup_method(self): write_car_param() # wait a bit for power save to disable time.sleep(5) - def tearDown(self): + def teardown_method(self): manager_cleanup() def get_expected_messages(self, proc): @@ -97,7 +96,7 @@ class TestPowerDraw(unittest.TestCase): return now, msg_counts, time.monotonic() - start_time - SAMPLE_TIME @mock_messages(['liveLocationKalman']) - def test_camera_procs(self): + def test_camera_procs(self, subtests): baseline = get_power() prev = baseline @@ -122,12 +121,8 @@ class TestPowerDraw(unittest.TestCase): expected = proc.power msgs_received = sum(msg_counts[msg] for msg in proc.msgs) tab.append([proc.name, round(expected, 2), round(cur, 2), self.get_expected_messages(proc), msgs_received, round(warmup_time[proc.name], 2)]) - with self.subTest(proc=proc.name): - self.assertTrue(self.valid_msg_count(proc, msg_counts), f"expected {self.get_expected_messages(proc)} msgs, got {msgs_received} msgs") - self.assertTrue(self.valid_power_draw(proc, cur), f"expected {expected:.2f}W, got {cur:.2f}W") + with subtests.test(proc=proc.name): + assert self.valid_msg_count(proc, msg_counts), f"expected {self.get_expected_messages(proc)} msgs, got {msgs_received} msgs" + assert self.valid_power_draw(proc, cur), f"expected {expected:.2f}W, got {cur:.2f}W" print(tabulate(tab)) print(f"Baseline {baseline:.2f}W\n") - - -if __name__ == "__main__": - unittest.main() diff --git a/system/loggerd/tests/loggerd_tests_common.py b/system/loggerd/tests/loggerd_tests_common.py index 877c872b6b..e8a6d031c4 100644 --- a/system/loggerd/tests/loggerd_tests_common.py +++ b/system/loggerd/tests/loggerd_tests_common.py @@ -1,6 +1,5 @@ import os import random -import unittest from pathlib import Path @@ -54,7 +53,7 @@ class MockApiIgnore(): def get_token(self): return "fake-token" -class UploaderTestCase(unittest.TestCase): +class UploaderTestCase: f_type = "UNKNOWN" root: Path @@ -66,7 +65,7 @@ class UploaderTestCase(unittest.TestCase): def set_ignore(self): uploader.Api = MockApiIgnore - def setUp(self): + def setup_method(self): uploader.Api = MockApi uploader.fake_upload = True uploader.force_wifi = True diff --git a/system/loggerd/tests/test_deleter.py b/system/loggerd/tests/test_deleter.py index 37d25507e0..3ba6ad4031 100755 --- a/system/loggerd/tests/test_deleter.py +++ b/system/loggerd/tests/test_deleter.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 import time import threading -import unittest from collections import namedtuple from pathlib import Path from collections.abc import Sequence @@ -17,9 +16,9 @@ class TestDeleter(UploaderTestCase): def fake_statvfs(self, d): return self.fake_stats - def setUp(self): + def setup_method(self): self.f_type = "fcamera.hevc" - super().setUp() + super().setup_method() self.fake_stats = Stats(f_bavail=0, f_blocks=10, f_frsize=4096) deleter.os.statvfs = self.fake_statvfs @@ -64,7 +63,7 @@ class TestDeleter(UploaderTestCase): finally: self.join_thread() - self.assertEqual(deleted_order, f_paths, "Files not deleted in expected order") + assert deleted_order == f_paths, "Files not deleted in expected order" def test_delete_order(self): self.assertDeleteOrder([ @@ -105,7 +104,7 @@ class TestDeleter(UploaderTestCase): time.sleep(0.01) self.join_thread() - self.assertTrue(f_path.exists(), "File deleted with available space") + assert f_path.exists(), "File deleted with available space" def test_no_delete_with_lock_file(self): f_path = self.make_file_with_data(self.seg_dir, self.f_type, lock=True) @@ -116,8 +115,4 @@ class TestDeleter(UploaderTestCase): time.sleep(0.01) self.join_thread() - self.assertTrue(f_path.exists(), "File deleted when locked") - - -if __name__ == "__main__": - unittest.main() + assert f_path.exists(), "File deleted when locked" diff --git a/system/loggerd/tests/test_encoder.py b/system/loggerd/tests/test_encoder.py index bd076dc5f3..f6eb9b9011 100755 --- a/system/loggerd/tests/test_encoder.py +++ b/system/loggerd/tests/test_encoder.py @@ -6,7 +6,6 @@ import random import shutil import subprocess import time -import unittest from pathlib import Path from parameterized import parameterized @@ -33,14 +32,14 @@ FILE_SIZE_TOLERANCE = 0.5 @pytest.mark.tici # TODO: all of loggerd should work on PC -class TestEncoder(unittest.TestCase): +class TestEncoder: - def setUp(self): + def setup_method(self): self._clear_logs() os.environ["LOGGERD_TEST"] = "1" os.environ["LOGGERD_SEGMENT_LENGTH"] = str(SEGMENT_LENGTH) - def tearDown(self): + def teardown_method(self): self._clear_logs() def _clear_logs(self): @@ -85,7 +84,7 @@ class TestEncoder(unittest.TestCase): file_path = f"{route_prefix_path}--{i}/{camera}" # check file exists - self.assertTrue(os.path.exists(file_path), f"segment #{i}: '{file_path}' missing") + assert os.path.exists(file_path), f"segment #{i}: '{file_path}' missing" # TODO: this ffprobe call is really slow # check frame count @@ -98,13 +97,13 @@ class TestEncoder(unittest.TestCase): frame_count = int(probe.split('\n')[0].strip()) counts.append(frame_count) - self.assertEqual(frame_count, expected_frames, - f"segment #{i}: {camera} failed frame count check: expected {expected_frames}, got {frame_count}") + assert frame_count == expected_frames, \ + f"segment #{i}: {camera} failed frame count check: expected {expected_frames}, got {frame_count}" # sanity check file size file_size = os.path.getsize(file_path) - self.assertTrue(math.isclose(file_size, size, rel_tol=FILE_SIZE_TOLERANCE), - f"{file_path} size {file_size} isn't close to target size {size}") + assert math.isclose(file_size, size, rel_tol=FILE_SIZE_TOLERANCE), \ + f"{file_path} size {file_size} isn't close to target size {size}" # Check encodeIdx if encode_idx_name is not None: @@ -118,24 +117,24 @@ class TestEncoder(unittest.TestCase): frame_idxs = [m.frameId for m in encode_msgs] # Check frame count - self.assertEqual(frame_count, len(segment_idxs)) - self.assertEqual(frame_count, len(encode_idxs)) + assert frame_count == len(segment_idxs) + assert frame_count == len(encode_idxs) # Check for duplicates or skips - self.assertEqual(0, segment_idxs[0]) - self.assertEqual(len(set(segment_idxs)), len(segment_idxs)) + assert 0 == segment_idxs[0] + assert len(set(segment_idxs)) == len(segment_idxs) - self.assertTrue(all(valid)) + assert all(valid) - self.assertEqual(expected_frames * i, encode_idxs[0]) + assert expected_frames * i == encode_idxs[0] first_frames.append(frame_idxs[0]) - self.assertEqual(len(set(encode_idxs)), len(encode_idxs)) + assert len(set(encode_idxs)) == len(encode_idxs) - self.assertEqual(1, len(set(first_frames))) + assert 1 == len(set(first_frames)) if TICI: expected_frames = fps * SEGMENT_LENGTH - self.assertEqual(min(counts), expected_frames) + assert min(counts) == expected_frames shutil.rmtree(f"{route_prefix_path}--{i}") try: @@ -150,7 +149,3 @@ class TestEncoder(unittest.TestCase): managed_processes['encoderd'].stop() managed_processes['camerad'].stop() managed_processes['sensord'].stop() - - -if __name__ == "__main__": - unittest.main() diff --git a/system/loggerd/tests/test_uploader.py b/system/loggerd/tests/test_uploader.py index 73917a30cf..c0a9770e53 100755 --- a/system/loggerd/tests/test_uploader.py +++ b/system/loggerd/tests/test_uploader.py @@ -2,7 +2,6 @@ import os import time import threading -import unittest import logging import json from pathlib import Path @@ -38,8 +37,8 @@ cloudlog.addHandler(log_handler) class TestUploader(UploaderTestCase): - def setUp(self): - super().setUp() + def setup_method(self): + super().setup_method() log_handler.reset() def start_thread(self): @@ -80,13 +79,13 @@ class TestUploader(UploaderTestCase): exp_order = self.gen_order([self.seg_num], []) - self.assertTrue(len(log_handler.upload_ignored) == 0, "Some files were ignored") - self.assertFalse(len(log_handler.upload_order) < len(exp_order), "Some files failed to upload") - self.assertFalse(len(log_handler.upload_order) > len(exp_order), "Some files were uploaded twice") + assert len(log_handler.upload_ignored) == 0, "Some files were ignored" + assert not len(log_handler.upload_order) < len(exp_order), "Some files failed to upload" + assert not len(log_handler.upload_order) > len(exp_order), "Some files were uploaded twice" for f_path in exp_order: - self.assertEqual(os.getxattr((Path(Paths.log_root()) / f_path).with_suffix(""), UPLOAD_ATTR_NAME), UPLOAD_ATTR_VALUE, "All files not uploaded") + assert os.getxattr((Path(Paths.log_root()) / f_path).with_suffix(""), UPLOAD_ATTR_NAME) == UPLOAD_ATTR_VALUE, "All files not uploaded" - self.assertTrue(log_handler.upload_order == exp_order, "Files uploaded in wrong order") + assert log_handler.upload_order == exp_order, "Files uploaded in wrong order" def test_upload_with_wrong_xattr(self): self.gen_files(lock=False, xattr=b'0') @@ -98,13 +97,13 @@ class TestUploader(UploaderTestCase): exp_order = self.gen_order([self.seg_num], []) - self.assertTrue(len(log_handler.upload_ignored) == 0, "Some files were ignored") - self.assertFalse(len(log_handler.upload_order) < len(exp_order), "Some files failed to upload") - self.assertFalse(len(log_handler.upload_order) > len(exp_order), "Some files were uploaded twice") + assert len(log_handler.upload_ignored) == 0, "Some files were ignored" + assert not len(log_handler.upload_order) < len(exp_order), "Some files failed to upload" + assert not len(log_handler.upload_order) > len(exp_order), "Some files were uploaded twice" for f_path in exp_order: - self.assertEqual(os.getxattr((Path(Paths.log_root()) / f_path).with_suffix(""), UPLOAD_ATTR_NAME), UPLOAD_ATTR_VALUE, "All files not uploaded") + assert os.getxattr((Path(Paths.log_root()) / f_path).with_suffix(""), UPLOAD_ATTR_NAME) == UPLOAD_ATTR_VALUE, "All files not uploaded" - self.assertTrue(log_handler.upload_order == exp_order, "Files uploaded in wrong order") + assert log_handler.upload_order == exp_order, "Files uploaded in wrong order" def test_upload_ignored(self): self.set_ignore() @@ -117,13 +116,13 @@ class TestUploader(UploaderTestCase): exp_order = self.gen_order([self.seg_num], []) - self.assertTrue(len(log_handler.upload_order) == 0, "Some files were not ignored") - self.assertFalse(len(log_handler.upload_ignored) < len(exp_order), "Some files failed to ignore") - self.assertFalse(len(log_handler.upload_ignored) > len(exp_order), "Some files were ignored twice") + assert len(log_handler.upload_order) == 0, "Some files were not ignored" + assert not len(log_handler.upload_ignored) < len(exp_order), "Some files failed to ignore" + assert not len(log_handler.upload_ignored) > len(exp_order), "Some files were ignored twice" for f_path in exp_order: - self.assertEqual(os.getxattr((Path(Paths.log_root()) / f_path).with_suffix(""), UPLOAD_ATTR_NAME), UPLOAD_ATTR_VALUE, "All files not ignored") + assert os.getxattr((Path(Paths.log_root()) / f_path).with_suffix(""), UPLOAD_ATTR_NAME) == UPLOAD_ATTR_VALUE, "All files not ignored" - self.assertTrue(log_handler.upload_ignored == exp_order, "Files ignored in wrong order") + assert log_handler.upload_ignored == exp_order, "Files ignored in wrong order" def test_upload_files_in_create_order(self): seg1_nums = [0, 1, 2, 10, 20] @@ -142,13 +141,13 @@ class TestUploader(UploaderTestCase): time.sleep(5) self.join_thread() - self.assertTrue(len(log_handler.upload_ignored) == 0, "Some files were ignored") - self.assertFalse(len(log_handler.upload_order) < len(exp_order), "Some files failed to upload") - self.assertFalse(len(log_handler.upload_order) > len(exp_order), "Some files were uploaded twice") + assert len(log_handler.upload_ignored) == 0, "Some files were ignored" + assert not len(log_handler.upload_order) < len(exp_order), "Some files failed to upload" + assert not len(log_handler.upload_order) > len(exp_order), "Some files were uploaded twice" for f_path in exp_order: - self.assertEqual(os.getxattr((Path(Paths.log_root()) / f_path).with_suffix(""), UPLOAD_ATTR_NAME), UPLOAD_ATTR_VALUE, "All files not uploaded") + assert os.getxattr((Path(Paths.log_root()) / f_path).with_suffix(""), UPLOAD_ATTR_NAME) == UPLOAD_ATTR_VALUE, "All files not uploaded" - self.assertTrue(log_handler.upload_order == exp_order, "Files uploaded in wrong order") + assert log_handler.upload_order == exp_order, "Files uploaded in wrong order" def test_no_upload_with_lock_file(self): self.start_thread() @@ -163,7 +162,7 @@ class TestUploader(UploaderTestCase): for f_path in f_paths: fn = f_path.with_suffix(f_path.suffix.replace(".bz2", "")) uploaded = UPLOAD_ATTR_NAME in os.listxattr(fn) and os.getxattr(fn, UPLOAD_ATTR_NAME) == UPLOAD_ATTR_VALUE - self.assertFalse(uploaded, "File upload when locked") + assert not uploaded, "File upload when locked" def test_no_upload_with_xattr(self): self.gen_files(lock=False, xattr=UPLOAD_ATTR_VALUE) @@ -173,7 +172,7 @@ class TestUploader(UploaderTestCase): time.sleep(5) self.join_thread() - self.assertEqual(len(log_handler.upload_order), 0, "File uploaded again") + assert len(log_handler.upload_order) == 0, "File uploaded again" def test_clear_locks_on_startup(self): f_paths = self.gen_files(lock=True, boot=False) @@ -183,8 +182,4 @@ class TestUploader(UploaderTestCase): for f_path in f_paths: lock_path = f_path.with_suffix(f_path.suffix + ".lock") - self.assertFalse(lock_path.is_file(), "File lock not cleared on startup") - - -if __name__ == "__main__": - unittest.main() + assert not lock_path.is_file(), "File lock not cleared on startup" diff --git a/system/qcomgpsd/tests/test_qcomgpsd.py b/system/qcomgpsd/tests/test_qcomgpsd.py index 6c93f7dd93..d47ea5d634 100755 --- a/system/qcomgpsd/tests/test_qcomgpsd.py +++ b/system/qcomgpsd/tests/test_qcomgpsd.py @@ -4,7 +4,6 @@ import pytest import json import time import datetime -import unittest import subprocess import cereal.messaging as messaging @@ -15,24 +14,24 @@ GOOD_SIGNAL = bool(int(os.getenv("GOOD_SIGNAL", '0'))) @pytest.mark.tici -class TestRawgpsd(unittest.TestCase): +class TestRawgpsd: @classmethod - def setUpClass(cls): + def setup_class(cls): os.system("sudo systemctl start systemd-resolved") os.system("sudo systemctl restart ModemManager lte") wait_for_modem() @classmethod - def tearDownClass(cls): + def teardown_class(cls): managed_processes['qcomgpsd'].stop() os.system("sudo systemctl restart systemd-resolved") os.system("sudo systemctl restart ModemManager lte") - def setUp(self): + def setup_method(self): at_cmd("AT+QGPSDEL=0") self.sm = messaging.SubMaster(['qcomGnss', 'gpsLocation', 'gnssMeasurements']) - def tearDown(self): + def teardown_method(self): managed_processes['qcomgpsd'].stop() os.system("sudo systemctl restart systemd-resolved") @@ -57,18 +56,18 @@ class TestRawgpsd(unittest.TestCase): os.system("sudo systemctl restart ModemManager") assert self._wait_for_output(30) - def test_startup_time(self): + def test_startup_time(self, subtests): for internet in (True, False): if not internet: os.system("sudo systemctl stop systemd-resolved") - with self.subTest(internet=internet): + with subtests.test(internet=internet): managed_processes['qcomgpsd'].start() assert self._wait_for_output(7) managed_processes['qcomgpsd'].stop() - def test_turns_off_gnss(self): + def test_turns_off_gnss(self, subtests): for s in (0.1, 1, 5): - with self.subTest(runtime=s): + with subtests.test(runtime=s): managed_processes['qcomgpsd'].start() time.sleep(s) managed_processes['qcomgpsd'].stop() @@ -87,7 +86,7 @@ class TestRawgpsd(unittest.TestCase): if should_be_loaded: assert valid_duration == "10080" # should be max time injected_time = datetime.datetime.strptime(injected_time_str.replace("\"", ""), "%Y/%m/%d,%H:%M:%S") - self.assertLess(abs((datetime.datetime.utcnow() - injected_time).total_seconds()), 60*60*12) + assert abs((datetime.datetime.utcnow() - injected_time).total_seconds()) < 60*60*12 else: valid_duration, injected_time_str = out.split(",", 1) injected_time_str = injected_time_str.replace('\"', '').replace('\'', '') @@ -119,6 +118,3 @@ class TestRawgpsd(unittest.TestCase): time.sleep(15) managed_processes['qcomgpsd'].stop() self.check_assistance(True) - -if __name__ == "__main__": - unittest.main(failfast=True) diff --git a/system/sensord/tests/test_sensord.py b/system/sensord/tests/test_sensord.py index 3075c8a343..1b3b78da88 100755 --- a/system/sensord/tests/test_sensord.py +++ b/system/sensord/tests/test_sensord.py @@ -2,7 +2,6 @@ import os import pytest import time -import unittest import numpy as np from collections import namedtuple, defaultdict @@ -99,9 +98,9 @@ def read_sensor_events(duration_sec): return {k: v for k, v in events.items() if len(v) > 0} @pytest.mark.tici -class TestSensord(unittest.TestCase): +class TestSensord: @classmethod - def setUpClass(cls): + def setup_class(cls): # enable LSM self test os.environ["LSM_SELF_TEST"] = "1" @@ -119,10 +118,10 @@ class TestSensord(unittest.TestCase): managed_processes["sensord"].stop() @classmethod - def tearDownClass(cls): + def teardown_class(cls): managed_processes["sensord"].stop() - def tearDown(self): + def teardown_method(self): managed_processes["sensord"].stop() def test_sensors_present(self): @@ -133,9 +132,9 @@ class TestSensord(unittest.TestCase): m = getattr(measurement, measurement.which()) seen.add((str(m.source), m.which())) - self.assertIn(seen, SENSOR_CONFIGURATIONS) + assert seen in SENSOR_CONFIGURATIONS - def test_lsm6ds3_timing(self): + def test_lsm6ds3_timing(self, subtests): # verify measurements are sampled and published at 104Hz sensor_t = { @@ -152,7 +151,7 @@ class TestSensord(unittest.TestCase): sensor_t[m.sensor].append(m.timestamp) for s, vals in sensor_t.items(): - with self.subTest(sensor=s): + with subtests.test(sensor=s): assert len(vals) > 0 tdiffs = np.diff(vals) / 1e6 # millis @@ -166,9 +165,9 @@ class TestSensord(unittest.TestCase): stddev = np.std(tdiffs) assert stddev < 2.0, f"Standard-dev to big {stddev}" - def test_sensor_frequency(self): + def test_sensor_frequency(self, subtests): for s, msgs in self.events.items(): - with self.subTest(sensor=s): + with subtests.test(sensor=s): freq = len(msgs) / self.sample_secs ef = SERVICE_LIST[s].frequency assert ef*0.85 <= freq <= ef*1.15 @@ -246,6 +245,3 @@ class TestSensord(unittest.TestCase): state_two = get_irq_count(self.sensord_irq) assert state_one == state_two, "Interrupts received after sensord stop!" - -if __name__ == "__main__": - unittest.main() diff --git a/system/tests/test_logmessaged.py b/system/tests/test_logmessaged.py index d27dae46ad..6d59bdcb08 100755 --- a/system/tests/test_logmessaged.py +++ b/system/tests/test_logmessaged.py @@ -2,7 +2,6 @@ import glob import os import time -import unittest import cereal.messaging as messaging from openpilot.selfdrive.manager.process_config import managed_processes @@ -10,8 +9,8 @@ from openpilot.system.hardware.hw import Paths from openpilot.common.swaglog import cloudlog, ipchandler -class TestLogmessaged(unittest.TestCase): - def setUp(self): +class TestLogmessaged: + def setup_method(self): # clear the IPC buffer in case some other tests used cloudlog and filled it ipchandler.close() ipchandler.connect() @@ -25,7 +24,7 @@ class TestLogmessaged(unittest.TestCase): messaging.drain_sock(self.sock) messaging.drain_sock(self.error_sock) - def tearDown(self): + def teardown_method(self): del self.sock del self.error_sock managed_processes['logmessaged'].stop(block=True) @@ -55,6 +54,3 @@ class TestLogmessaged(unittest.TestCase): logsize = sum([os.path.getsize(f) for f in self._get_log_files()]) assert (n*len(msg)) < logsize < (n*(len(msg)+1024)) - -if __name__ == "__main__": - unittest.main() diff --git a/system/ubloxd/tests/test_pigeond.py b/system/ubloxd/tests/test_pigeond.py index 742e20bb90..a24414466a 100755 --- a/system/ubloxd/tests/test_pigeond.py +++ b/system/ubloxd/tests/test_pigeond.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 import pytest import time -import unittest import cereal.messaging as messaging from cereal.services import SERVICE_LIST @@ -13,9 +12,9 @@ from openpilot.system.hardware.tici.pins import GPIO # TODO: test TTFF when we have good A-GNSS @pytest.mark.tici -class TestPigeond(unittest.TestCase): +class TestPigeond: - def tearDown(self): + def teardown_method(self): managed_processes['pigeond'].stop() @with_processes(['pigeond']) @@ -54,7 +53,3 @@ class TestPigeond(unittest.TestCase): assert gpio_read(GPIO.UBLOX_RST_N) == 0 assert gpio_read(GPIO.GNSS_PWR_EN) == 0 - - -if __name__ == "__main__": - unittest.main() diff --git a/system/updated/casync/tests/test_casync.py b/system/updated/casync/tests/test_casync.py index 34427d5625..80c5d2705c 100755 --- a/system/updated/casync/tests/test_casync.py +++ b/system/updated/casync/tests/test_casync.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 +import pytest import os import pathlib -import unittest import tempfile import subprocess @@ -14,9 +14,9 @@ from openpilot.system.updated.casync import tar LOOPBACK = os.environ.get('LOOPBACK', None) -class TestCasync(unittest.TestCase): +class TestCasync: @classmethod - def setUpClass(cls): + def setup_class(cls): cls.tmpdir = tempfile.TemporaryDirectory() # Build example contents @@ -43,7 +43,7 @@ class TestCasync(unittest.TestCase): # Ensure we have chunk reuse assert len(hashes) > len(set(hashes)) - def setUp(self): + def setup_method(self): # Clear target_lo if LOOPBACK is not None: self.target_lo = LOOPBACK @@ -53,7 +53,7 @@ class TestCasync(unittest.TestCase): self.target_fn = os.path.join(self.tmpdir.name, next(tempfile._get_candidate_names())) self.seed_fn = os.path.join(self.tmpdir.name, next(tempfile._get_candidate_names())) - def tearDown(self): + def teardown_method(self): for fn in [self.target_fn, self.seed_fn]: try: os.unlink(fn) @@ -67,9 +67,9 @@ class TestCasync(unittest.TestCase): stats = casync.extract(target, sources, self.target_fn) with open(self.target_fn, 'rb') as target_f: - self.assertEqual(target_f.read(), self.contents) + assert target_f.read() == self.contents - self.assertEqual(stats['remote'], len(self.contents)) + assert stats['remote'] == len(self.contents) def test_seed(self): target = casync.parse_caibx(self.manifest_fn) @@ -83,10 +83,10 @@ class TestCasync(unittest.TestCase): stats = casync.extract(target, sources, self.target_fn) with open(self.target_fn, 'rb') as target_f: - self.assertEqual(target_f.read(), self.contents) + assert target_f.read() == self.contents - self.assertGreater(stats['seed'], 0) - self.assertLess(stats['remote'], len(self.contents)) + assert stats['seed'] > 0 + assert stats['remote'] < len(self.contents) def test_already_done(self): """Test that an already flashed target doesn't download any chunks""" @@ -101,9 +101,9 @@ class TestCasync(unittest.TestCase): stats = casync.extract(target, sources, self.target_fn) with open(self.target_fn, 'rb') as f: - self.assertEqual(f.read(), self.contents) + assert f.read() == self.contents - self.assertEqual(stats['target'], len(self.contents)) + assert stats['target'] == len(self.contents) def test_chunk_reuse(self): """Test that chunks that are reused are only downloaded once""" @@ -119,11 +119,11 @@ class TestCasync(unittest.TestCase): stats = casync.extract(target, sources, self.target_fn) with open(self.target_fn, 'rb') as f: - self.assertEqual(f.read(), self.contents) + assert f.read() == self.contents - self.assertLess(stats['remote'], len(self.contents)) + assert stats['remote'] < len(self.contents) - @unittest.skipUnless(LOOPBACK, "requires loopback device") + @pytest.mark.skipif(not LOOPBACK, reason="requires loopback device") def test_lo_simple_extract(self): target = casync.parse_caibx(self.manifest_fn) sources = [('remote', casync.RemoteChunkReader(self.store_fn), casync.build_chunk_dict(target))] @@ -131,11 +131,11 @@ class TestCasync(unittest.TestCase): stats = casync.extract(target, sources, self.target_lo) with open(self.target_lo, 'rb') as target_f: - self.assertEqual(target_f.read(len(self.contents)), self.contents) + assert target_f.read(len(self.contents)) == self.contents - self.assertEqual(stats['remote'], len(self.contents)) + assert stats['remote'] == len(self.contents) - @unittest.skipUnless(LOOPBACK, "requires loopback device") + @pytest.mark.skipif(not LOOPBACK, reason="requires loopback device") def test_lo_chunk_reuse(self): """Test that chunks that are reused are only downloaded once""" target = casync.parse_caibx(self.manifest_fn) @@ -146,12 +146,12 @@ class TestCasync(unittest.TestCase): stats = casync.extract(target, sources, self.target_lo) with open(self.target_lo, 'rb') as f: - self.assertEqual(f.read(len(self.contents)), self.contents) + assert f.read(len(self.contents)) == self.contents - self.assertLess(stats['remote'], len(self.contents)) + assert stats['remote'] < len(self.contents) -class TestCasyncDirectory(unittest.TestCase): +class TestCasyncDirectory: """Tests extracting a directory stored as a casync tar archive""" NUM_FILES = 16 @@ -174,7 +174,7 @@ class TestCasyncDirectory(unittest.TestCase): os.symlink(f"file_{i}.txt", os.path.join(directory, f"link_{i}.txt")) @classmethod - def setUpClass(cls): + def setup_class(cls): cls.tmpdir = tempfile.TemporaryDirectory() # Create casync files @@ -190,16 +190,16 @@ class TestCasyncDirectory(unittest.TestCase): subprocess.check_output(["casync", "make", "--compression=xz", "--store", cls.store_fn, cls.manifest_fn, cls.orig_fn]) @classmethod - def tearDownClass(cls): + def teardown_class(cls): cls.tmpdir.cleanup() cls.directory_to_extract.cleanup() - def setUp(self): + def setup_method(self): self.cache_dir = tempfile.TemporaryDirectory() self.working_dir = tempfile.TemporaryDirectory() self.out_dir = tempfile.TemporaryDirectory() - def tearDown(self): + def teardown_method(self): self.cache_dir.cleanup() self.working_dir.cleanup() self.out_dir.cleanup() @@ -216,32 +216,32 @@ class TestCasyncDirectory(unittest.TestCase): stats = casync.extract_directory(target, sources, pathlib.Path(self.out_dir.name), tmp_filename) with open(os.path.join(self.out_dir.name, "file_0.txt"), "rb") as f: - self.assertEqual(f.read(), self.contents) + assert f.read() == self.contents with open(os.path.join(self.out_dir.name, "link_0.txt"), "rb") as f: - self.assertEqual(f.read(), self.contents) - self.assertEqual(os.readlink(os.path.join(self.out_dir.name, "link_0.txt")), "file_0.txt") + assert f.read() == self.contents + assert os.readlink(os.path.join(self.out_dir.name, "link_0.txt")) == "file_0.txt" return stats def test_no_cache(self): self.setup_cache(self.cache_dir.name, []) stats = self.run_test() - self.assertGreater(stats['remote'], 0) - self.assertEqual(stats['cache'], 0) + assert stats['remote'] > 0 + assert stats['cache'] == 0 def test_full_cache(self): self.setup_cache(self.cache_dir.name, range(self.NUM_FILES)) stats = self.run_test() - self.assertEqual(stats['remote'], 0) - self.assertGreater(stats['cache'], 0) + assert stats['remote'] == 0 + assert stats['cache'] > 0 def test_one_file_cache(self): self.setup_cache(self.cache_dir.name, range(1)) stats = self.run_test() - self.assertGreater(stats['remote'], 0) - self.assertGreater(stats['cache'], 0) - self.assertLess(stats['cache'], stats['remote']) + assert stats['remote'] > 0 + assert stats['cache'] > 0 + assert stats['cache'] < stats['remote'] def test_one_file_incorrect_cache(self): self.setup_cache(self.cache_dir.name, range(self.NUM_FILES)) @@ -249,19 +249,15 @@ class TestCasyncDirectory(unittest.TestCase): f.write(b"1234") stats = self.run_test() - self.assertGreater(stats['remote'], 0) - self.assertGreater(stats['cache'], 0) - self.assertGreater(stats['cache'], stats['remote']) + assert stats['remote'] > 0 + assert stats['cache'] > 0 + assert stats['cache'] > stats['remote'] def test_one_file_missing_cache(self): self.setup_cache(self.cache_dir.name, range(self.NUM_FILES)) os.unlink(os.path.join(self.cache_dir.name, "file_12.txt")) stats = self.run_test() - self.assertGreater(stats['remote'], 0) - self.assertGreater(stats['cache'], 0) - self.assertGreater(stats['cache'], stats['remote']) - - -if __name__ == "__main__": - unittest.main() + assert stats['remote'] > 0 + assert stats['cache'] > 0 + assert stats['cache'] > stats['remote'] diff --git a/system/webrtc/tests/test_stream_session.py b/system/webrtc/tests/test_stream_session.py index 2173c3806b..d8defab13f 100755 --- a/system/webrtc/tests/test_stream_session.py +++ b/system/webrtc/tests/test_stream_session.py @@ -1,7 +1,5 @@ #!/usr/bin/env python3 import asyncio -import unittest -from unittest.mock import Mock, MagicMock, patch import json # for aiortc and its dependencies import warnings @@ -20,15 +18,15 @@ from openpilot.system.webrtc.device.audio import AudioInputStreamTrack from openpilot.common.realtime import DT_DMON -class TestStreamSession(unittest.TestCase): - def setUp(self): +class TestStreamSession: + def setup_method(self): self.loop = asyncio.new_event_loop() - def tearDown(self): + def teardown_method(self): self.loop.stop() self.loop.close() - def test_outgoing_proxy(self): + def test_outgoing_proxy(self, mocker): test_msg = log.Event.new_message() test_msg.logMonoTime = 123 test_msg.valid = True @@ -36,27 +34,27 @@ class TestStreamSession(unittest.TestCase): expected_dict = {"type": "customReservedRawData0", "logMonoTime": 123, "valid": True, "data": "test"} expected_json = json.dumps(expected_dict).encode() - channel = Mock(spec=RTCDataChannel) + channel = mocker.Mock(spec=RTCDataChannel) mocked_submaster = messaging.SubMaster(["customReservedRawData0"]) def mocked_update(t): mocked_submaster.update_msgs(0, [test_msg]) - with patch.object(messaging.SubMaster, "update", side_effect=mocked_update): - proxy = CerealOutgoingMessageProxy(mocked_submaster) - proxy.add_channel(channel) + mocker.patch.object(messaging.SubMaster, "update", side_effect=mocked_update) + proxy = CerealOutgoingMessageProxy(mocked_submaster) + proxy.add_channel(channel) - proxy.update() + proxy.update() - channel.send.assert_called_once_with(expected_json) + channel.send.assert_called_once_with(expected_json) - def test_incoming_proxy(self): + def test_incoming_proxy(self, mocker): tested_msgs = [ {"type": "customReservedRawData0", "data": "test"}, # primitive {"type": "can", "data": [{"address": 0, "busTime": 0, "dat": "", "src": 0}]}, # list {"type": "testJoystick", "data": {"axes": [0, 0], "buttons": [False]}}, # dict ] - mocked_pubmaster = MagicMock(spec=messaging.PubMaster) + mocked_pubmaster = mocker.MagicMock(spec=messaging.PubMaster) proxy = CerealIncomingMessageProxy(mocked_pubmaster) @@ -65,44 +63,40 @@ class TestStreamSession(unittest.TestCase): mocked_pubmaster.send.assert_called_once() mt, md = mocked_pubmaster.send.call_args.args - self.assertEqual(mt, msg["type"]) - self.assertIsInstance(md, capnp._DynamicStructBuilder) - self.assertTrue(hasattr(md, msg["type"])) + assert mt == msg["type"] + assert isinstance(md, capnp._DynamicStructBuilder) + assert hasattr(md, msg["type"]) mocked_pubmaster.reset_mock() - def test_livestream_track(self): + def test_livestream_track(self, mocker): fake_msg = messaging.new_message("livestreamDriverEncodeData") config = {"receive.return_value": fake_msg.to_bytes()} - with patch("cereal.messaging.SubSocket", spec=True, **config): - track = LiveStreamVideoStreamTrack("driver") + mocker.patch("cereal.messaging.SubSocket", spec=True, **config) + track = LiveStreamVideoStreamTrack("driver") - self.assertTrue(track.id.startswith("driver")) - self.assertEqual(track.codec_preference(), "H264") + assert track.id.startswith("driver") + assert track.codec_preference() == "H264" - for i in range(5): - packet = self.loop.run_until_complete(track.recv()) - self.assertEqual(packet.time_base, VIDEO_TIME_BASE) - self.assertEqual(packet.pts, int(i * DT_DMON * VIDEO_CLOCK_RATE)) - self.assertEqual(packet.size, 0) + for i in range(5): + packet = self.loop.run_until_complete(track.recv()) + assert packet.time_base == VIDEO_TIME_BASE + assert packet.pts == int(i * DT_DMON * VIDEO_CLOCK_RATE) + assert packet.size == 0 - def test_input_audio_track(self): + def test_input_audio_track(self, mocker): packet_time, rate = 0.02, 16000 sample_count = int(packet_time * rate) - mocked_stream = MagicMock(spec=pyaudio.Stream) + mocked_stream = mocker.MagicMock(spec=pyaudio.Stream) mocked_stream.read.return_value = b"\x00" * 2 * sample_count config = {"open.side_effect": lambda *args, **kwargs: mocked_stream} - with patch("pyaudio.PyAudio", spec=True, **config): - track = AudioInputStreamTrack(audio_format=pyaudio.paInt16, packet_time=packet_time, rate=rate) - - for i in range(5): - frame = self.loop.run_until_complete(track.recv()) - self.assertEqual(frame.rate, rate) - self.assertEqual(frame.samples, sample_count) - self.assertEqual(frame.pts, i * sample_count) - - -if __name__ == "__main__": - unittest.main() + mocker.patch("pyaudio.PyAudio", spec=True, **config) + track = AudioInputStreamTrack(audio_format=pyaudio.paInt16, packet_time=packet_time, rate=rate) + + for i in range(5): + frame = self.loop.run_until_complete(track.recv()) + assert frame.rate == rate + assert frame.samples == sample_count + assert frame.pts == i * sample_count diff --git a/system/webrtc/tests/test_webrtcd.py b/system/webrtc/tests/test_webrtcd.py index e5742dba07..684c7cf359 100755 --- a/system/webrtc/tests/test_webrtcd.py +++ b/system/webrtc/tests/test_webrtcd.py @@ -1,8 +1,7 @@ #!/usr/bin/env python +import pytest import asyncio import json -import unittest -from unittest.mock import MagicMock, AsyncMock # for aiortc and its dependencies import warnings warnings.filterwarnings("ignore", category=DeprecationWarning) @@ -20,19 +19,20 @@ from parameterized import parameterized_class (["testJoystick"], []), ([], []), ]) -class TestWebrtcdProc(unittest.IsolatedAsyncioTestCase): +@pytest.mark.asyncio +class TestWebrtcdProc(): async def assertCompletesWithTimeout(self, awaitable, timeout=1): try: async with asyncio.timeout(timeout): await awaitable except TimeoutError: - self.fail("Timeout while waiting for awaitable to complete") + pytest.fail("Timeout while waiting for awaitable to complete") - async def test_webrtcd(self): - mock_request = MagicMock() + async def test_webrtcd(self, mocker): + mock_request = mocker.MagicMock() async def connect(offer): body = {'sdp': offer.sdp, 'cameras': offer.video, 'bridge_services_in': self.in_services, 'bridge_services_out': self.out_services} - mock_request.json.side_effect = AsyncMock(return_value=body) + mock_request.json.side_effect = mocker.AsyncMock(return_value=body) response = await get_stream(mock_request) response_json = json.loads(response.text) return aiortc.RTCSessionDescription(**response_json) @@ -48,9 +48,9 @@ class TestWebrtcdProc(unittest.IsolatedAsyncioTestCase): await self.assertCompletesWithTimeout(stream.start()) await self.assertCompletesWithTimeout(stream.wait_for_connection()) - self.assertTrue(stream.has_incoming_video_track("road")) - self.assertTrue(stream.has_incoming_audio_track()) - self.assertEqual(stream.has_messaging_channel(), len(self.in_services) > 0 or len(self.out_services) > 0) + assert stream.has_incoming_video_track("road") + assert stream.has_incoming_audio_track() + assert stream.has_messaging_channel() == (len(self.in_services) > 0 or len(self.out_services) > 0) video_track, audio_track = stream.get_incoming_video_track("road"), stream.get_incoming_audio_track() await self.assertCompletesWithTimeout(video_track.recv()) @@ -59,10 +59,6 @@ class TestWebrtcdProc(unittest.IsolatedAsyncioTestCase): await self.assertCompletesWithTimeout(stream.stop()) # cleanup, very implementation specific, test may break if it changes - self.assertTrue(mock_request.app["streams"].__setitem__.called, "Implementation changed, please update this test") + assert mock_request.app["streams"].__setitem__.called, "Implementation changed, please update this test" _, session = mock_request.app["streams"].__setitem__.call_args.args await self.assertCompletesWithTimeout(session.post_run_cleanup()) - - -if __name__ == "__main__": - unittest.main() diff --git a/tools/car_porting/test_car_model.py b/tools/car_porting/test_car_model.py index 5f8294fd3c..cf0be1a80a 100755 --- a/tools/car_porting/test_car_model.py +++ b/tools/car_porting/test_car_model.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 import argparse import sys -import unittest +import unittest # noqa: TID251 from openpilot.selfdrive.car.tests.routes import CarTestRoute from openpilot.selfdrive.car.tests.test_models import TestCarModel diff --git a/tools/lib/tests/test_caching.py b/tools/lib/tests/test_caching.py index 5d3dfeba42..32f55df5d1 100755 --- a/tools/lib/tests/test_caching.py +++ b/tools/lib/tests/test_caching.py @@ -1,13 +1,11 @@ #!/usr/bin/env python3 -from functools import partial import http.server import os import shutil import socket -import unittest +import pytest -from parameterized import parameterized -from openpilot.selfdrive.test.helpers import with_http_server +from openpilot.selfdrive.test.helpers import http_server_context from openpilot.system.hardware.hw import Paths from openpilot.tools.lib.url_file import URLFile @@ -31,22 +29,23 @@ class CachingTestRequestHandler(http.server.BaseHTTPRequestHandler): self.end_headers() -with_caching_server = partial(with_http_server, handler=CachingTestRequestHandler) +@pytest.fixture +def host(): + with http_server_context(handler=CachingTestRequestHandler) as (host, port): + yield f"http://{host}:{port}" +class TestFileDownload: -class TestFileDownload(unittest.TestCase): - - @with_caching_server def test_pipeline_defaults(self, host): # TODO: parameterize the defaults so we don't rely on hard-coded values in xx - self.assertEqual(URLFile.pool_manager().pools._maxsize, 10) # PoolManager num_pools param + assert URLFile.pool_manager().pools._maxsize == 10# PoolManager num_pools param pool_manager_defaults = { "maxsize": 100, "socket_options": [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),], } for k, v in pool_manager_defaults.items(): - self.assertEqual(URLFile.pool_manager().connection_pool_kw.get(k), v) + assert URLFile.pool_manager().connection_pool_kw.get(k) == v retry_defaults = { "total": 5, @@ -54,7 +53,7 @@ class TestFileDownload(unittest.TestCase): "status_forcelist": [409, 429, 503, 504], } for k, v in retry_defaults.items(): - self.assertEqual(getattr(URLFile.pool_manager().connection_pool_kw["retries"], k), v) + assert getattr(URLFile.pool_manager().connection_pool_kw["retries"], k) == v # ensure caching off by default and cache dir doesn't get created os.environ.pop("FILEREADER_CACHE", None) @@ -62,7 +61,7 @@ class TestFileDownload(unittest.TestCase): shutil.rmtree(Paths.download_cache_root()) URLFile(f"{host}/test.txt").get_length() URLFile(f"{host}/test.txt").read() - self.assertEqual(os.path.exists(Paths.download_cache_root()), False) + assert not os.path.exists(Paths.download_cache_root()) def compare_loads(self, url, start=0, length=None): """Compares range between cached and non cached version""" @@ -72,21 +71,21 @@ class TestFileDownload(unittest.TestCase): file_cached.seek(start) file_downloaded.seek(start) - self.assertEqual(file_cached.get_length(), file_downloaded.get_length()) - self.assertLessEqual(length + start if length is not None else 0, file_downloaded.get_length()) + assert file_cached.get_length() == file_downloaded.get_length() + assert length + start if length is not None else 0 <= file_downloaded.get_length() response_cached = file_cached.read(ll=length) response_downloaded = file_downloaded.read(ll=length) - self.assertEqual(response_cached, response_downloaded) + assert response_cached == response_downloaded # Now test with cache in place file_cached = URLFile(url, cache=True) file_cached.seek(start) response_cached = file_cached.read(ll=length) - self.assertEqual(file_cached.get_length(), file_downloaded.get_length()) - self.assertEqual(response_cached, response_downloaded) + assert file_cached.get_length() == file_downloaded.get_length() + assert response_cached == response_downloaded def test_small_file(self): # Make sure we don't force cache @@ -117,22 +116,16 @@ class TestFileDownload(unittest.TestCase): self.compare_loads(large_file_url, length - 100, 100) self.compare_loads(large_file_url) - @parameterized.expand([(True, ), (False, )]) - @with_caching_server - def test_recover_from_missing_file(self, cache_enabled, host): + @pytest.mark.parametrize("cache_enabled", [True, False]) + def test_recover_from_missing_file(self, host, cache_enabled): os.environ["FILEREADER_CACHE"] = "1" if cache_enabled else "0" file_url = f"{host}/test.png" CachingTestRequestHandler.FILE_EXISTS = False length = URLFile(file_url).get_length() - self.assertEqual(length, -1) + assert length == -1 CachingTestRequestHandler.FILE_EXISTS = True length = URLFile(file_url).get_length() - self.assertEqual(length, 4) - - - -if __name__ == "__main__": - unittest.main() + assert length == 4 diff --git a/tools/lib/tests/test_comma_car_segments.py b/tools/lib/tests/test_comma_car_segments.py index 91bab94343..b9a4def75f 100644 --- a/tools/lib/tests/test_comma_car_segments.py +++ b/tools/lib/tests/test_comma_car_segments.py @@ -1,5 +1,4 @@ import pytest -import unittest import requests from openpilot.selfdrive.car.fingerprints import MIGRATION from openpilot.tools.lib.comma_car_segments import get_comma_car_segments_database, get_url @@ -8,7 +7,7 @@ from openpilot.tools.lib.route import SegmentRange @pytest.mark.skip(reason="huggingface is flaky, run this test manually to check for issues") -class TestCommaCarSegments(unittest.TestCase): +class TestCommaCarSegments: def test_database(self): database = get_comma_car_segments_database() @@ -28,12 +27,8 @@ class TestCommaCarSegments(unittest.TestCase): url = get_url(sr.route_name, sr.slice) resp = requests.get(url) - self.assertEqual(resp.status_code, 200) + assert resp.status_code == 200 lr = LogReader(url) CP = lr.first("carParams") - self.assertEqual(MIGRATION.get(CP.carFingerprint, CP.carFingerprint), fp) - - -if __name__ == "__main__": - unittest.main() + assert MIGRATION.get(CP.carFingerprint, CP.carFingerprint) == fp diff --git a/tools/lib/tests/test_logreader.py b/tools/lib/tests/test_logreader.py index fc72202b26..58d22a07ef 100755 --- a/tools/lib/tests/test_logreader.py +++ b/tools/lib/tests/test_logreader.py @@ -5,12 +5,10 @@ import io import shutil import tempfile import os -import unittest import pytest import requests from parameterized import parameterized -from unittest import mock from cereal import log as capnp_log from openpilot.tools.lib.logreader import LogIterable, LogReader, comma_api_source, parse_indirect, ReadMode, InternalUnavailableException @@ -28,24 +26,22 @@ def noop(segment: LogIterable): @contextlib.contextmanager -def setup_source_scenario(is_internal=False): - with ( - mock.patch("openpilot.tools.lib.logreader.internal_source") as internal_source_mock, - mock.patch("openpilot.tools.lib.logreader.openpilotci_source") as openpilotci_source_mock, - mock.patch("openpilot.tools.lib.logreader.comma_api_source") as comma_api_source_mock, - ): - if is_internal: - internal_source_mock.return_value = [QLOG_FILE] - else: - internal_source_mock.side_effect = InternalUnavailableException +def setup_source_scenario(mocker, is_internal=False): + internal_source_mock = mocker.patch("openpilot.tools.lib.logreader.internal_source") + openpilotci_source_mock = mocker.patch("openpilot.tools.lib.logreader.openpilotci_source") + comma_api_source_mock = mocker.patch("openpilot.tools.lib.logreader.comma_api_source") + if is_internal: + internal_source_mock.return_value = [QLOG_FILE] + else: + internal_source_mock.side_effect = InternalUnavailableException - openpilotci_source_mock.return_value = [None] - comma_api_source_mock.return_value = [QLOG_FILE] + openpilotci_source_mock.return_value = [None] + comma_api_source_mock.return_value = [QLOG_FILE] - yield + yield -class TestLogReader(unittest.TestCase): +class TestLogReader: @parameterized.expand([ (f"{TEST_ROUTE}", ALL_SEGS), (f"{TEST_ROUTE.replace('/', '|')}", ALL_SEGS), @@ -74,7 +70,7 @@ class TestLogReader(unittest.TestCase): def test_indirect_parsing(self, identifier, expected): parsed, _, _ = parse_indirect(identifier) sr = SegmentRange(parsed) - self.assertListEqual(list(sr.seg_idxs), expected, identifier) + assert list(sr.seg_idxs) == expected, identifier @parameterized.expand([ (f"{TEST_ROUTE}", f"{TEST_ROUTE}"), @@ -86,11 +82,11 @@ class TestLogReader(unittest.TestCase): ]) def test_canonical_name(self, identifier, expected): sr = SegmentRange(identifier) - self.assertEqual(str(sr), expected) + assert str(sr) == expected - @parameterized.expand([(True,), (False,)]) - @mock.patch("openpilot.tools.lib.logreader.file_exists") - def test_direct_parsing(self, cache_enabled, file_exists_mock): + @pytest.mark.parametrize("cache_enabled", [True, False]) + def test_direct_parsing(self, mocker, cache_enabled): + file_exists_mock = mocker.patch("openpilot.tools.lib.logreader.file_exists") os.environ["FILEREADER_CACHE"] = "1" if cache_enabled else "0" qlog = tempfile.NamedTemporaryFile(mode='wb', delete=False) @@ -100,13 +96,13 @@ class TestLogReader(unittest.TestCase): for f in [QLOG_FILE, qlog.name]: l = len(list(LogReader(f))) - self.assertGreater(l, 100) + assert l > 100 - with self.assertRaises(URLFileException) if not cache_enabled else self.assertRaises(AssertionError): + with pytest.raises(URLFileException) if not cache_enabled else pytest.raises(AssertionError): l = len(list(LogReader(QLOG_FILE.replace("/3/", "/200/")))) # file_exists should not be called for direct files - self.assertEqual(file_exists_mock.call_count, 0) + assert file_exists_mock.call_count == 0 @parameterized.expand([ (f"{TEST_ROUTE}///",), @@ -121,110 +117,110 @@ class TestLogReader(unittest.TestCase): (f"{TEST_ROUTE}--3a",), ]) def test_bad_ranges(self, segment_range): - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): _ = SegmentRange(segment_range).seg_idxs - @parameterized.expand([ + @pytest.mark.parametrize("segment_range, api_call", [ (f"{TEST_ROUTE}/0", False), (f"{TEST_ROUTE}/:2", False), (f"{TEST_ROUTE}/0:", True), (f"{TEST_ROUTE}/-1", True), (f"{TEST_ROUTE}", True), ]) - def test_slicing_api_call(self, segment_range, api_call): - with mock.patch("openpilot.tools.lib.route.get_max_seg_number_cached") as max_seg_mock: - max_seg_mock.return_value = NUM_SEGS - _ = SegmentRange(segment_range).seg_idxs - self.assertEqual(api_call, max_seg_mock.called) + def test_slicing_api_call(self, mocker, segment_range, api_call): + max_seg_mock = mocker.patch("openpilot.tools.lib.route.get_max_seg_number_cached") + max_seg_mock.return_value = NUM_SEGS + _ = SegmentRange(segment_range).seg_idxs + assert api_call == max_seg_mock.called @pytest.mark.slow def test_modes(self): qlog_len = len(list(LogReader(f"{TEST_ROUTE}/0", ReadMode.QLOG))) rlog_len = len(list(LogReader(f"{TEST_ROUTE}/0", ReadMode.RLOG))) - self.assertLess(qlog_len * 6, rlog_len) + assert qlog_len * 6 < rlog_len @pytest.mark.slow def test_modes_from_name(self): qlog_len = len(list(LogReader(f"{TEST_ROUTE}/0/q"))) rlog_len = len(list(LogReader(f"{TEST_ROUTE}/0/r"))) - self.assertLess(qlog_len * 6, rlog_len) + assert qlog_len * 6 < rlog_len @pytest.mark.slow def test_list(self): qlog_len = len(list(LogReader(f"{TEST_ROUTE}/0/q"))) qlog_len_2 = len(list(LogReader([f"{TEST_ROUTE}/0/q", f"{TEST_ROUTE}/0/q"]))) - self.assertEqual(qlog_len * 2, qlog_len_2) + assert qlog_len * 2 == qlog_len_2 @pytest.mark.slow - @mock.patch("openpilot.tools.lib.logreader._LogFileReader") - def test_multiple_iterations(self, init_mock): + def test_multiple_iterations(self, mocker): + init_mock = mocker.patch("openpilot.tools.lib.logreader._LogFileReader") lr = LogReader(f"{TEST_ROUTE}/0/q") qlog_len1 = len(list(lr)) qlog_len2 = len(list(lr)) # ensure we don't create multiple instances of _LogFileReader, which means downloading the files twice - self.assertEqual(init_mock.call_count, 1) + assert init_mock.call_count == 1 - self.assertEqual(qlog_len1, qlog_len2) + assert qlog_len1 == qlog_len2 @pytest.mark.slow def test_helpers(self): lr = LogReader(f"{TEST_ROUTE}/0/q") - self.assertEqual(lr.first("carParams").carFingerprint, "SUBARU OUTBACK 6TH GEN") - self.assertTrue(0 < len(list(lr.filter("carParams"))) < len(list(lr))) + assert lr.first("carParams").carFingerprint == "SUBARU OUTBACK 6TH GEN" + assert 0 < len(list(lr.filter("carParams"))) < len(list(lr)) @parameterized.expand([(True,), (False,)]) @pytest.mark.slow def test_run_across_segments(self, cache_enabled): os.environ["FILEREADER_CACHE"] = "1" if cache_enabled else "0" lr = LogReader(f"{TEST_ROUTE}/0:4") - self.assertEqual(len(lr.run_across_segments(4, noop)), len(list(lr))) + assert len(lr.run_across_segments(4, noop)) == len(list(lr)) @pytest.mark.slow - def test_auto_mode(self): + def test_auto_mode(self, subtests, mocker): lr = LogReader(f"{TEST_ROUTE}/0/q") qlog_len = len(list(lr)) - with mock.patch("openpilot.tools.lib.route.Route.log_paths") as log_paths_mock: - log_paths_mock.return_value = [None] * NUM_SEGS - # Should fall back to qlogs since rlogs are not available - - with self.subTest("interactive_yes"): - with mock.patch("sys.stdin", new=io.StringIO("y\n")): - lr = LogReader(f"{TEST_ROUTE}/0", default_mode=ReadMode.AUTO_INTERACTIVE, default_source=comma_api_source) - log_len = len(list(lr)) - self.assertEqual(qlog_len, log_len) - - with self.subTest("interactive_no"): - with mock.patch("sys.stdin", new=io.StringIO("n\n")): - with self.assertRaises(AssertionError): - lr = LogReader(f"{TEST_ROUTE}/0", default_mode=ReadMode.AUTO_INTERACTIVE, default_source=comma_api_source) - - with self.subTest("non_interactive"): - lr = LogReader(f"{TEST_ROUTE}/0", default_mode=ReadMode.AUTO, default_source=comma_api_source) - log_len = len(list(lr)) - self.assertEqual(qlog_len, log_len) + log_paths_mock = mocker.patch("openpilot.tools.lib.route.Route.log_paths") + log_paths_mock.return_value = [None] * NUM_SEGS + # Should fall back to qlogs since rlogs are not available - @parameterized.expand([(True,), (False,)]) + with subtests.test("interactive_yes"): + mocker.patch("sys.stdin", new=io.StringIO("y\n")) + lr = LogReader(f"{TEST_ROUTE}/0", default_mode=ReadMode.AUTO_INTERACTIVE, default_source=comma_api_source) + log_len = len(list(lr)) + assert qlog_len == log_len + + with subtests.test("interactive_no"): + mocker.patch("sys.stdin", new=io.StringIO("n\n")) + with pytest.raises(AssertionError): + lr = LogReader(f"{TEST_ROUTE}/0", default_mode=ReadMode.AUTO_INTERACTIVE, default_source=comma_api_source) + + with subtests.test("non_interactive"): + lr = LogReader(f"{TEST_ROUTE}/0", default_mode=ReadMode.AUTO, default_source=comma_api_source) + log_len = len(list(lr)) + assert qlog_len == log_len + + @pytest.mark.parametrize("is_internal", [True, False]) @pytest.mark.slow - def test_auto_source_scenarios(self, is_internal): + def test_auto_source_scenarios(self, mocker, is_internal): lr = LogReader(QLOG_FILE) qlog_len = len(list(lr)) - with setup_source_scenario(is_internal=is_internal): + with setup_source_scenario(mocker, is_internal=is_internal): lr = LogReader(f"{TEST_ROUTE}/0/q") log_len = len(list(lr)) - self.assertEqual(qlog_len, log_len) + assert qlog_len == log_len @pytest.mark.slow def test_sort_by_time(self): msgs = list(LogReader(f"{TEST_ROUTE}/0/q")) - self.assertNotEqual(msgs, sorted(msgs, key=lambda m: m.logMonoTime)) + assert msgs != sorted(msgs, key=lambda m: m.logMonoTime) msgs = list(LogReader(f"{TEST_ROUTE}/0/q", sort_by_time=True)) - self.assertEqual(msgs, sorted(msgs, key=lambda m: m.logMonoTime)) + assert msgs == sorted(msgs, key=lambda m: m.logMonoTime) def test_only_union_types(self): with tempfile.NamedTemporaryFile() as qlog: @@ -234,7 +230,7 @@ class TestLogReader(unittest.TestCase): f.write(b"".join(capnp_log.Event.new_message().to_bytes() for _ in range(num_msgs))) msgs = list(LogReader(qlog.name)) - self.assertEqual(len(msgs), num_msgs) + assert len(msgs) == num_msgs [m.which() for m in msgs] # append non-union Event message @@ -246,15 +242,11 @@ class TestLogReader(unittest.TestCase): # ensure new message is added, but is not a union type msgs = list(LogReader(qlog.name)) - self.assertEqual(len(msgs), num_msgs + 1) - with self.assertRaises(capnp.KjException): + assert len(msgs) == num_msgs + 1 + with pytest.raises(capnp.KjException): [m.which() for m in msgs] # should not be added when only_union_types=True msgs = list(LogReader(qlog.name, only_union_types=True)) - self.assertEqual(len(msgs), num_msgs) + assert len(msgs) == num_msgs [m.which() for m in msgs] - - -if __name__ == "__main__": - unittest.main() diff --git a/tools/lib/tests/test_readers.py b/tools/lib/tests/test_readers.py index 1f24ae5c8e..f92554872f 100755 --- a/tools/lib/tests/test_readers.py +++ b/tools/lib/tests/test_readers.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -import unittest +import pytest import requests import tempfile @@ -9,16 +9,16 @@ from openpilot.tools.lib.framereader import FrameReader from openpilot.tools.lib.logreader import LogReader -class TestReaders(unittest.TestCase): - @unittest.skip("skip for bandwidth reasons") +class TestReaders: + @pytest.mark.skip("skip for bandwidth reasons") def test_logreader(self): def _check_data(lr): hist = defaultdict(int) for l in lr: hist[l.which()] += 1 - self.assertEqual(hist['carControl'], 6000) - self.assertEqual(hist['logMessage'], 6857) + assert hist['carControl'] == 6000 + assert hist['logMessage'] == 6857 with tempfile.NamedTemporaryFile(suffix=".bz2") as fp: r = requests.get("https://github.com/commaai/comma2k19/blob/master/Example_1/b0c9d2329ad1606b%7C2018-08-02--08-34-47/40/raw_log.bz2?raw=true", timeout=10) @@ -31,15 +31,15 @@ class TestReaders(unittest.TestCase): lr_url = LogReader("https://github.com/commaai/comma2k19/blob/master/Example_1/b0c9d2329ad1606b%7C2018-08-02--08-34-47/40/raw_log.bz2?raw=true") _check_data(lr_url) - @unittest.skip("skip for bandwidth reasons") + @pytest.mark.skip("skip for bandwidth reasons") def test_framereader(self): def _check_data(f): - self.assertEqual(f.frame_count, 1200) - self.assertEqual(f.w, 1164) - self.assertEqual(f.h, 874) + assert f.frame_count == 1200 + assert f.w == 1164 + assert f.h == 874 frame_first_30 = f.get(0, 30) - self.assertEqual(len(frame_first_30), 30) + assert len(frame_first_30) == 30 print(frame_first_30[15]) @@ -62,6 +62,3 @@ class TestReaders(unittest.TestCase): fr_url = FrameReader("https://github.com/commaai/comma2k19/blob/master/Example_1/b0c9d2329ad1606b%7C2018-08-02--08-34-47/40/video.hevc?raw=true") _check_data(fr_url) - -if __name__ == "__main__": - unittest.main() diff --git a/tools/lib/tests/test_route_library.py b/tools/lib/tests/test_route_library.py index 7977f17be2..8f75fa19c0 100755 --- a/tools/lib/tests/test_route_library.py +++ b/tools/lib/tests/test_route_library.py @@ -1,10 +1,9 @@ #!/usr/bin/env python -import unittest from collections import namedtuple from openpilot.tools.lib.route import SegmentName -class TestRouteLibrary(unittest.TestCase): +class TestRouteLibrary: def test_segment_name_formats(self): Case = namedtuple('Case', ['input', 'expected_route', 'expected_segment_num', 'expected_data_dir']) @@ -21,12 +20,9 @@ class TestRouteLibrary(unittest.TestCase): s = SegmentName(route_or_segment_name, allow_route_name=True) - self.assertEqual(str(s.route_name), case.expected_route) - self.assertEqual(s.segment_num, case.expected_segment_num) - self.assertEqual(s.data_dir, case.expected_data_dir) + assert str(s.route_name) == case.expected_route + assert s.segment_num == case.expected_segment_num + assert s.data_dir == case.expected_data_dir for case in cases: _validate(case) - -if __name__ == "__main__": - unittest.main() diff --git a/tools/plotjuggler/test_plotjuggler.py b/tools/plotjuggler/test_plotjuggler.py index 17287fb803..8b811f4847 100755 --- a/tools/plotjuggler/test_plotjuggler.py +++ b/tools/plotjuggler/test_plotjuggler.py @@ -4,7 +4,6 @@ import glob import signal import subprocess import time -import unittest from openpilot.common.basedir import BASEDIR from openpilot.common.timeout import Timeout @@ -12,7 +11,7 @@ from openpilot.tools.plotjuggler.juggle import DEMO_ROUTE, install PJ_DIR = os.path.join(BASEDIR, "tools/plotjuggler") -class TestPlotJuggler(unittest.TestCase): +class TestPlotJuggler: def test_demo(self): install() @@ -28,13 +27,13 @@ class TestPlotJuggler(unittest.TestCase): # ensure plotjuggler didn't crash after exiting the plugin time.sleep(15) - self.assertEqual(p.poll(), None) + assert p.poll() is None os.killpg(os.getpgid(p.pid), signal.SIGTERM) - self.assertNotIn("Raw file read failed", output) + assert "Raw file read failed" not in output # TODO: also test that layouts successfully load - def test_layouts(self): + def test_layouts(self, subtests): bad_strings = ( # if a previously loaded file is defined, # PJ will throw a warning when loading the layout @@ -43,12 +42,8 @@ class TestPlotJuggler(unittest.TestCase): ) for fn in glob.glob(os.path.join(PJ_DIR, "layouts/*")): name = os.path.basename(fn) - with self.subTest(layout=name): + with subtests.test(layout=name): with open(fn) as f: layout = f.read() violations = [s for s in bad_strings if s in layout] assert len(violations) == 0, f"These should be stripped out of the layout: {str(violations)}" - - -if __name__ == "__main__": - unittest.main() diff --git a/tools/sim/tests/test_metadrive_bridge.py b/tools/sim/tests/test_metadrive_bridge.py index 6cb8e1465e..28fa462a1c 100755 --- a/tools/sim/tests/test_metadrive_bridge.py +++ b/tools/sim/tests/test_metadrive_bridge.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 import pytest -import unittest from openpilot.tools.sim.bridge.metadrive.metadrive_bridge import MetaDriveBridge from openpilot.tools.sim.tests.test_sim_bridge import TestSimBridgeBase @@ -9,7 +8,3 @@ from openpilot.tools.sim.tests.test_sim_bridge import TestSimBridgeBase class TestMetaDriveBridge(TestSimBridgeBase): def create_bridge(self): return MetaDriveBridge(False, False) - - -if __name__ == "__main__": - unittest.main() diff --git a/tools/sim/tests/test_sim_bridge.py b/tools/sim/tests/test_sim_bridge.py index d9653d5cfd..6b8b811fbb 100644 --- a/tools/sim/tests/test_sim_bridge.py +++ b/tools/sim/tests/test_sim_bridge.py @@ -1,7 +1,7 @@ import os import subprocess import time -import unittest +import pytest from multiprocessing import Queue @@ -10,13 +10,13 @@ from openpilot.common.basedir import BASEDIR SIM_DIR = os.path.join(BASEDIR, "tools/sim") -class TestSimBridgeBase(unittest.TestCase): +class TestSimBridgeBase: @classmethod - def setUpClass(cls): + def setup_class(cls): if cls is TestSimBridgeBase: - raise unittest.SkipTest("Don't run this base class, run test_metadrive_bridge.py instead") + raise pytest.skip("Don't run this base class, run test_metadrive_bridge.py instead") - def setUp(self): + def setup_method(self): self.processes = [] def test_engage(self): @@ -36,7 +36,7 @@ class TestSimBridgeBase(unittest.TestCase): start_waiting = time.monotonic() while not bridge.started.value and time.monotonic() < start_waiting + max_time_per_step: time.sleep(0.1) - self.assertEqual(p_bridge.exitcode, None, f"Bridge process should be running, but exited with code {p_bridge.exitcode}") + assert p_bridge.exitcode is None, f"Bridge process should be running, but exited with code {p_bridge.exitcode}" start_time = time.monotonic() no_car_events_issues_once = False @@ -52,8 +52,8 @@ class TestSimBridgeBase(unittest.TestCase): no_car_events_issues_once = True break - self.assertTrue(no_car_events_issues_once, - f"Failed because no messages received, or CarEvents '{car_event_issues}' or processes not running '{not_running}'") + assert no_car_events_issues_once, \ + f"Failed because no messages received, or CarEvents '{car_event_issues}' or processes not running '{not_running}'" start_time = time.monotonic() min_counts_control_active = 100 @@ -68,9 +68,9 @@ class TestSimBridgeBase(unittest.TestCase): if control_active == min_counts_control_active: break - self.assertEqual(min_counts_control_active, control_active, f"Simulator did not engage a minimal of {min_counts_control_active} steps was {control_active}") + assert min_counts_control_active == control_active, f"Simulator did not engage a minimal of {min_counts_control_active} steps was {control_active}" - def tearDown(self): + def teardown_method(self): print("Test shutting down. CommIssues are acceptable") for p in reversed(self.processes): p.terminate() @@ -80,7 +80,3 @@ class TestSimBridgeBase(unittest.TestCase): p.wait(15) else: p.join(15) - - -if __name__ == "__main__": - unittest.main()