diff --git a/release/files_common b/release/files_common index 30c6cd9cc9..d1a0a61fa3 100644 --- a/release/files_common +++ b/release/files_common @@ -64,7 +64,6 @@ release/build_release2.sh selfdrive/version.py selfdrive/__init__.py -selfdrive/registration.py selfdrive/config.py selfdrive/crash.py selfdrive/swaglog.py @@ -77,6 +76,7 @@ selfdrive/rtshield.py selfdrive/athena/__init__.py selfdrive/athena/athenad.py selfdrive/athena/manage_athenad.py +selfdrive/athena/registration.py selfdrive/boardd/.gitignore selfdrive/boardd/SConscript diff --git a/selfdrive/registration.py b/selfdrive/athena/registration.py similarity index 94% rename from selfdrive/registration.py rename to selfdrive/athena/registration.py index ad4db74f1a..aa23b37def 100644 --- a/selfdrive/registration.py +++ b/selfdrive/athena/registration.py @@ -14,7 +14,7 @@ from selfdrive.hardware import HARDWARE, PC from selfdrive.swaglog import cloudlog -def register(show_spinner=False): +def register(show_spinner=False) -> str: params = Params() params.put("SubscriberInfo", HARDWARE.get_subscriber_info()) @@ -41,8 +41,9 @@ def register(show_spinner=False): spinner.update("registering device") # Create registration token, in the future, this key will make JWTs directly - private_key = open(PERSIST+"/comma/id_rsa").read() - public_key = open(PERSIST+"/comma/id_rsa.pub").read() + with open(PERSIST+"/comma/id_rsa.pub") as f1, open(PERSIST+"/comma/id_rsa") as f2: + public_key = f1.read() + private_key = f2.read() # Block until we get the imei imei1, imei2 = None, None diff --git a/selfdrive/athena/tests/helpers.py b/selfdrive/athena/tests/helpers.py index 5a6fba362d..831b668297 100644 --- a/selfdrive/athena/tests/helpers.py +++ b/selfdrive/athena/tests/helpers.py @@ -8,6 +8,14 @@ from multiprocessing import Process from common.timeout import Timeout + +class MockResponse: + def __init__(self, json, status_code): + self.json = json + self.text = json + self.status_code = status_code + + class EchoSocket(): def __init__(self, port): self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) diff --git a/selfdrive/athena/tests/test_registration.py b/selfdrive/athena/tests/test_registration.py new file mode 100755 index 0000000000..ffd0761cd1 --- /dev/null +++ b/selfdrive/athena/tests/test_registration.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +import json +import os +import tempfile +import unittest +from Crypto.PublicKey import RSA +from pathlib import Path +from unittest import mock + +from common.params import Params +from selfdrive.athena.registration import register +from selfdrive.athena.tests.helpers import MockResponse + + +class TestRegistration(unittest.TestCase): + + def setUp(self): + # clear params and setup key paths + self.params = Params() + self.params.clear_all() + + self.persist = tempfile.TemporaryDirectory() + os.mkdir(os.path.join(self.persist.name, "comma")) + self.priv_key = Path(os.path.join(self.persist.name, "comma/id_rsa")) + self.pub_key = Path(os.path.join(self.persist.name, "comma/id_rsa.pub")) + self.persist_patcher = mock.patch("selfdrive.athena.registration.PERSIST", self.persist.name) + self.persist_patcher.start() + + def tearDown(self): + self.persist_patcher.stop() + self.persist.cleanup() + + def _generate_keys(self): + self.pub_key.touch() + k = RSA.generate(2048) + with open(self.priv_key, "wb") as f: + f.write(k.export_key()) + with open(self.pub_key, "wb") as f: + f.write(k.publickey().export_key()) + + def test_valid_cache(self): + # if all params are written, return the cached dongle id + self.params.put("IMEI", "imei") + self.params.put("HardwareSerial", "serial") + self._generate_keys() + + with mock.patch("selfdrive.athena.registration.api_get", autospec=True) as m: + dongle = "DONGLE_ID_123" + self.params.put("DongleId", dongle) + self.assertEqual(register(), dongle) + self.assertFalse(m.called) + + def test_missing_cache(self): + # keys exist but no dongle id + self._generate_keys() + with mock.patch("selfdrive.athena.registration.api_get", autospec=True) as m: + dongle = "DONGLE_ID_123" + m.return_value = MockResponse(json.dumps({'dongle_id': dongle}), 200) + self.assertEqual(register(), dongle) + self.assertEqual(m.call_count, 1) + + # call again, shouldn't hit the API this time + self.assertEqual(register(), dongle) + self.assertEqual(m.call_count, 1) + self.assertEqual(self.params.get("DongleId", encoding='utf-8'), dongle) + + def test_unregistered_pc(self): + # no keys, no dongle id + with mock.patch("selfdrive.athena.registration.api_get", autospec=True) as m, \ + mock.patch("selfdrive.athena.registration.PC", new=True): + m.return_value = MockResponse(None, 402) + dongle = register() + self.assertGreater(len(dongle), 0) + self.assertEqual(m.call_count, 1) + self.assertEqual(self.params.get("DongleId", encoding='utf-8'), dongle) + + def test_unregistered_non_pc(self): + # no keys, no dongle id + with mock.patch("selfdrive.athena.registration.api_get", autospec=True) as m, \ + mock.patch("selfdrive.athena.registration.PC", new=False): + m.return_value = MockResponse(None, 402) + self.assertIs(register(), None) + self.assertEqual(m.call_count, 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/selfdrive/manager/manager.py b/selfdrive/manager/manager.py index 9e04913d2a..ac80187abd 100755 --- a/selfdrive/manager/manager.py +++ b/selfdrive/manager/manager.py @@ -16,7 +16,7 @@ from selfdrive.hardware import HARDWARE, PC, TICI from selfdrive.manager.helpers import unblock_stdout from selfdrive.manager.process import ensure_running from selfdrive.manager.process_config import managed_processes -from selfdrive.registration import register +from selfdrive.athena.registration import register from selfdrive.swaglog import cloudlog, add_file_handler from selfdrive.version import dirty, get_git_commit, version, origin, branch, commit, \ terms_version, training_version, \