parent
78d3dfe8ae
commit
6466776773
5 changed files with 665 additions and 0 deletions
@ -0,0 +1,322 @@ |
|||||||
|
#!/usr/bin/env python3.7 |
||||||
|
import json |
||||||
|
import os |
||||||
|
import hashlib |
||||||
|
import io |
||||||
|
import random |
||||||
|
import select |
||||||
|
import socket |
||||||
|
import time |
||||||
|
import threading |
||||||
|
import base64 |
||||||
|
import requests |
||||||
|
import queue |
||||||
|
from collections import namedtuple |
||||||
|
from functools import partial |
||||||
|
from jsonrpc import JSONRPCResponseManager, dispatcher |
||||||
|
from websocket import create_connection, WebSocketTimeoutException, ABNF |
||||||
|
from selfdrive.loggerd.config import ROOT |
||||||
|
|
||||||
|
import cereal.messaging as messaging |
||||||
|
from common import android |
||||||
|
from common.api import Api |
||||||
|
from common.params import Params |
||||||
|
from cereal.services import service_list |
||||||
|
from selfdrive.swaglog import cloudlog |
||||||
|
|
||||||
|
ATHENA_HOST = os.getenv('ATHENA_HOST', 'wss://athena.comma.ai') |
||||||
|
HANDLER_THREADS = os.getenv('HANDLER_THREADS', 4) |
||||||
|
LOCAL_PORT_WHITELIST = set([8022]) |
||||||
|
|
||||||
|
dispatcher["echo"] = lambda s: s |
||||||
|
payload_queue = queue.Queue() |
||||||
|
response_queue = queue.Queue() |
||||||
|
upload_queue = queue.Queue() |
||||||
|
cancelled_uploads = set() |
||||||
|
UploadItem = namedtuple('UploadItem', ['path', 'url', 'headers', 'created_at', 'id']) |
||||||
|
|
||||||
|
def handle_long_poll(ws): |
||||||
|
end_event = threading.Event() |
||||||
|
|
||||||
|
threads = [ |
||||||
|
threading.Thread(target=ws_recv, args=(ws, end_event)), |
||||||
|
threading.Thread(target=ws_send, args=(ws, end_event)), |
||||||
|
threading.Thread(target=upload_handler, args=(end_event,)) |
||||||
|
] + [ |
||||||
|
threading.Thread(target=jsonrpc_handler, args=(end_event,)) |
||||||
|
for x in range(HANDLER_THREADS) |
||||||
|
] |
||||||
|
|
||||||
|
for thread in threads: |
||||||
|
thread.start() |
||||||
|
try: |
||||||
|
while not end_event.is_set(): |
||||||
|
time.sleep(0.1) |
||||||
|
except (KeyboardInterrupt, SystemExit): |
||||||
|
end_event.set() |
||||||
|
raise |
||||||
|
finally: |
||||||
|
for i, thread in enumerate(threads): |
||||||
|
thread.join() |
||||||
|
|
||||||
|
def jsonrpc_handler(end_event): |
||||||
|
dispatcher["startLocalProxy"] = partial(startLocalProxy, end_event) |
||||||
|
while not end_event.is_set(): |
||||||
|
try: |
||||||
|
data = payload_queue.get(timeout=1) |
||||||
|
response = JSONRPCResponseManager.handle(data, dispatcher) |
||||||
|
response_queue.put_nowait(response) |
||||||
|
except queue.Empty: |
||||||
|
pass |
||||||
|
except Exception as e: |
||||||
|
cloudlog.exception("athena jsonrpc handler failed") |
||||||
|
response_queue.put_nowait(json.dumps({"error": str(e)})) |
||||||
|
|
||||||
|
def upload_handler(end_event): |
||||||
|
while not end_event.is_set(): |
||||||
|
try: |
||||||
|
item = upload_queue.get(timeout=1) |
||||||
|
if item.id in cancelled_uploads: |
||||||
|
cancelled_uploads.remove(item.id) |
||||||
|
continue |
||||||
|
_do_upload(item) |
||||||
|
except queue.Empty: |
||||||
|
pass |
||||||
|
except Exception: |
||||||
|
cloudlog.exception("athena.upload_handler.exception") |
||||||
|
|
||||||
|
def _do_upload(upload_item): |
||||||
|
with open(upload_item.path, "rb") as f: |
||||||
|
size = os.fstat(f.fileno()).st_size |
||||||
|
return requests.put(upload_item.url, |
||||||
|
data=f, |
||||||
|
headers={**upload_item.headers, 'Content-Length': str(size)}, |
||||||
|
timeout=10) |
||||||
|
|
||||||
|
# security: user should be able to request any message from their car |
||||||
|
@dispatcher.add_method |
||||||
|
def getMessage(service=None, timeout=1000): |
||||||
|
if service is None or service not in service_list: |
||||||
|
raise Exception("invalid service") |
||||||
|
|
||||||
|
socket = messaging.sub_sock(service, timeout=timeout) |
||||||
|
ret = messaging.recv_one(socket) |
||||||
|
|
||||||
|
if ret is None: |
||||||
|
raise TimeoutError |
||||||
|
|
||||||
|
return ret.to_dict() |
||||||
|
|
||||||
|
@dispatcher.add_method |
||||||
|
def listDataDirectory(): |
||||||
|
files = [os.path.relpath(os.path.join(dp, f), ROOT) for dp, dn, fn in os.walk(ROOT) for f in fn] |
||||||
|
return files |
||||||
|
|
||||||
|
@dispatcher.add_method |
||||||
|
def reboot(): |
||||||
|
thermal_sock = messaging.sub_sock("thermal", timeout=1000) |
||||||
|
ret = messaging.recv_one(thermal_sock) |
||||||
|
if ret is None or ret.thermal.started: |
||||||
|
raise Exception("Reboot unavailable") |
||||||
|
|
||||||
|
def do_reboot(): |
||||||
|
time.sleep(2) |
||||||
|
android.reboot() |
||||||
|
|
||||||
|
threading.Thread(target=do_reboot).start() |
||||||
|
|
||||||
|
return {"success": 1} |
||||||
|
|
||||||
|
@dispatcher.add_method |
||||||
|
def uploadFileToUrl(fn, url, headers): |
||||||
|
if len(fn) == 0 or fn[0] == '/' or '..' in fn: |
||||||
|
return 500 |
||||||
|
path = os.path.join(ROOT, fn) |
||||||
|
if not os.path.exists(path): |
||||||
|
return 404 |
||||||
|
|
||||||
|
item = UploadItem(path=path, url=url, headers=headers, created_at=int(time.time()*1000), id=None) |
||||||
|
upload_id = hashlib.sha1(str(item).encode()).hexdigest() |
||||||
|
item = item._replace(id=upload_id) |
||||||
|
|
||||||
|
upload_queue.put_nowait(item) |
||||||
|
|
||||||
|
return {"enqueued": 1, "item": item._asdict()} |
||||||
|
|
||||||
|
@dispatcher.add_method |
||||||
|
def listUploadQueue(): |
||||||
|
return [item._asdict() for item in list(upload_queue.queue)] |
||||||
|
|
||||||
|
@dispatcher.add_method |
||||||
|
def cancelUpload(upload_id): |
||||||
|
upload_ids = set(item.id for item in list(upload_queue.queue)) |
||||||
|
if upload_id not in upload_ids: |
||||||
|
return 404 |
||||||
|
|
||||||
|
cancelled_uploads.add(upload_id) |
||||||
|
return {"success": 1} |
||||||
|
|
||||||
|
def startLocalProxy(global_end_event, remote_ws_uri, local_port): |
||||||
|
try: |
||||||
|
if local_port not in LOCAL_PORT_WHITELIST: |
||||||
|
raise Exception("Requested local port not whitelisted") |
||||||
|
|
||||||
|
params = Params() |
||||||
|
dongle_id = params.get("DongleId").decode('utf8') |
||||||
|
identity_token = Api(dongle_id).get_token() |
||||||
|
ws = create_connection(remote_ws_uri, |
||||||
|
cookie="jwt=" + identity_token, |
||||||
|
enable_multithread=True) |
||||||
|
|
||||||
|
ssock, csock = socket.socketpair() |
||||||
|
local_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
||||||
|
local_sock.connect(('127.0.0.1', local_port)) |
||||||
|
local_sock.setblocking(0) |
||||||
|
|
||||||
|
proxy_end_event = threading.Event() |
||||||
|
threads = [ |
||||||
|
threading.Thread(target=ws_proxy_recv, args=(ws, local_sock, ssock, proxy_end_event, global_end_event)), |
||||||
|
threading.Thread(target=ws_proxy_send, args=(ws, local_sock, csock, proxy_end_event)) |
||||||
|
] |
||||||
|
for thread in threads: |
||||||
|
thread.start() |
||||||
|
|
||||||
|
return {"success": 1} |
||||||
|
except Exception as e: |
||||||
|
cloudlog.exception("athenad.startLocalProxy.exception") |
||||||
|
raise e |
||||||
|
|
||||||
|
@dispatcher.add_method |
||||||
|
def getPublicKey(): |
||||||
|
if not os.path.isfile('/persist/comma/id_rsa.pub'): |
||||||
|
return None |
||||||
|
|
||||||
|
with open('/persist/comma/id_rsa.pub', 'r') as f: |
||||||
|
return f.read() |
||||||
|
|
||||||
|
@dispatcher.add_method |
||||||
|
def getSshAuthorizedKeys(): |
||||||
|
return Params().get("GithubSshKeys", encoding='utf8') or '' |
||||||
|
|
||||||
|
@dispatcher.add_method |
||||||
|
def getSimInfo(): |
||||||
|
sim_state = android.getprop("gsm.sim.state").split(",") |
||||||
|
network_type = android.getprop("gsm.network.type").split(',') |
||||||
|
mcc_mnc = android.getprop("gsm.sim.operator.numeric") or None |
||||||
|
|
||||||
|
sim_id = android.parse_service_call_string(android.service_call(['iphonesubinfo', '11'])) |
||||||
|
cell_data_state = android.parse_service_call_unpack(android.service_call(['phone', '46']), ">q") |
||||||
|
cell_data_connected = (cell_data_state == 2) |
||||||
|
|
||||||
|
return { |
||||||
|
'sim_id': sim_id, |
||||||
|
'mcc_mnc': mcc_mnc, |
||||||
|
'network_type': network_type, |
||||||
|
'sim_state': sim_state, |
||||||
|
'data_connected': cell_data_connected |
||||||
|
} |
||||||
|
|
||||||
|
@dispatcher.add_method |
||||||
|
def takeSnapshot(): |
||||||
|
from selfdrive.camerad.snapshot.snapshot import snapshot, jpeg_write |
||||||
|
ret = snapshot() |
||||||
|
if ret is not None: |
||||||
|
def b64jpeg(x): |
||||||
|
if x is not None: |
||||||
|
f = io.BytesIO() |
||||||
|
jpeg_write(f, x) |
||||||
|
return base64.b64encode(f.getvalue()).decode("utf-8") |
||||||
|
else: |
||||||
|
return None |
||||||
|
return {'jpegBack': b64jpeg(ret[0]), |
||||||
|
'jpegFront': b64jpeg(ret[1])} |
||||||
|
else: |
||||||
|
raise Exception("not available while camerad is started") |
||||||
|
|
||||||
|
def ws_proxy_recv(ws, local_sock, ssock, end_event, global_end_event): |
||||||
|
while not (end_event.is_set() or global_end_event.is_set()): |
||||||
|
try: |
||||||
|
data = ws.recv() |
||||||
|
local_sock.sendall(data) |
||||||
|
except WebSocketTimeoutException: |
||||||
|
pass |
||||||
|
except Exception: |
||||||
|
cloudlog.exception("athenad.ws_proxy_recv.exception") |
||||||
|
break |
||||||
|
|
||||||
|
ssock.close() |
||||||
|
local_sock.close() |
||||||
|
end_event.set() |
||||||
|
|
||||||
|
def ws_proxy_send(ws, local_sock, signal_sock, end_event): |
||||||
|
while not end_event.is_set(): |
||||||
|
try: |
||||||
|
r, _, _ = select.select((local_sock, signal_sock), (), ()) |
||||||
|
if r: |
||||||
|
if r[0].fileno() == signal_sock.fileno(): |
||||||
|
# got end signal from ws_proxy_recv |
||||||
|
end_event.set() |
||||||
|
break |
||||||
|
data = local_sock.recv(4096) |
||||||
|
if not data: |
||||||
|
# local_sock is dead |
||||||
|
end_event.set() |
||||||
|
break |
||||||
|
|
||||||
|
ws.send(data, ABNF.OPCODE_BINARY) |
||||||
|
except Exception: |
||||||
|
cloudlog.exception("athenad.ws_proxy_send.exception") |
||||||
|
end_event.set() |
||||||
|
|
||||||
|
def ws_recv(ws, end_event): |
||||||
|
while not end_event.is_set(): |
||||||
|
try: |
||||||
|
data = ws.recv() |
||||||
|
payload_queue.put_nowait(data) |
||||||
|
except WebSocketTimeoutException: |
||||||
|
pass |
||||||
|
except Exception: |
||||||
|
cloudlog.exception("athenad.ws_recv.exception") |
||||||
|
end_event.set() |
||||||
|
|
||||||
|
def ws_send(ws, end_event): |
||||||
|
while not end_event.is_set(): |
||||||
|
try: |
||||||
|
response = response_queue.get(timeout=1) |
||||||
|
ws.send(response.json) |
||||||
|
except queue.Empty: |
||||||
|
pass |
||||||
|
except Exception: |
||||||
|
cloudlog.exception("athenad.ws_send.exception") |
||||||
|
end_event.set() |
||||||
|
|
||||||
|
def backoff(retries): |
||||||
|
return random.randrange(0, min(128, int(2 ** retries))) |
||||||
|
|
||||||
|
def main(gctx=None): |
||||||
|
params = Params() |
||||||
|
dongle_id = params.get("DongleId").decode('utf-8') |
||||||
|
ws_uri = ATHENA_HOST + "/ws/v2/" + dongle_id |
||||||
|
|
||||||
|
api = Api(dongle_id) |
||||||
|
|
||||||
|
conn_retries = 0 |
||||||
|
while 1: |
||||||
|
try: |
||||||
|
ws = create_connection(ws_uri, |
||||||
|
cookie="jwt=" + api.get_token(), |
||||||
|
enable_multithread=True) |
||||||
|
cloudlog.event("athenad.main.connected_ws", ws_uri=ws_uri) |
||||||
|
ws.settimeout(1) |
||||||
|
conn_retries = 0 |
||||||
|
handle_long_poll(ws) |
||||||
|
except (KeyboardInterrupt, SystemExit): |
||||||
|
break |
||||||
|
except Exception: |
||||||
|
cloudlog.exception("athenad.main.exception") |
||||||
|
conn_retries += 1 |
||||||
|
|
||||||
|
time.sleep(backoff(conn_retries)) |
||||||
|
|
||||||
|
if __name__ == "__main__": |
||||||
|
main() |
@ -0,0 +1,36 @@ |
|||||||
|
#!/usr/bin/env python3 |
||||||
|
|
||||||
|
import time |
||||||
|
from multiprocessing import Process |
||||||
|
|
||||||
|
import selfdrive.crash as crash |
||||||
|
from common.params import Params |
||||||
|
from selfdrive.launcher import launcher |
||||||
|
from selfdrive.swaglog import cloudlog |
||||||
|
from selfdrive.version import version, dirty |
||||||
|
|
||||||
|
ATHENA_MGR_PID_PARAM = "AthenadPid" |
||||||
|
|
||||||
|
def main(): |
||||||
|
params = Params() |
||||||
|
dongle_id = params.get("DongleId").decode('utf-8') |
||||||
|
cloudlog.bind_global(dongle_id=dongle_id, version=version, dirty=dirty, is_eon=True) |
||||||
|
crash.bind_user(id=dongle_id) |
||||||
|
crash.bind_extra(version=version, dirty=dirty, is_eon=True) |
||||||
|
crash.install() |
||||||
|
|
||||||
|
try: |
||||||
|
while 1: |
||||||
|
cloudlog.info("starting athena daemon") |
||||||
|
proc = Process(name='athenad', target=launcher, args=('selfdrive.athena.athenad',)) |
||||||
|
proc.start() |
||||||
|
proc.join() |
||||||
|
cloudlog.event("athenad exited", exitcode=proc.exitcode) |
||||||
|
time.sleep(5) |
||||||
|
except: |
||||||
|
cloudlog.exception("manage_athenad.exception") |
||||||
|
finally: |
||||||
|
params.delete(ATHENA_MGR_PID_PARAM) |
||||||
|
|
||||||
|
if __name__ == '__main__': |
||||||
|
main() |
@ -0,0 +1,193 @@ |
|||||||
|
#!/usr/bin/env python3 |
||||||
|
import json |
||||||
|
import os |
||||||
|
import requests |
||||||
|
import tempfile |
||||||
|
import time |
||||||
|
import threading |
||||||
|
import queue |
||||||
|
import unittest |
||||||
|
|
||||||
|
from multiprocessing import Process |
||||||
|
from pathlib import Path |
||||||
|
from unittest import mock |
||||||
|
from websocket import ABNF |
||||||
|
from websocket._exceptions import WebSocketConnectionClosedException |
||||||
|
|
||||||
|
from selfdrive.athena import athenad |
||||||
|
from selfdrive.athena.athenad import dispatcher |
||||||
|
from selfdrive.athena.test_helpers import MockWebsocket, MockParams, MockApi, EchoSocket, with_http_server |
||||||
|
from cereal import messaging |
||||||
|
|
||||||
|
class TestAthenadMethods(unittest.TestCase): |
||||||
|
@classmethod |
||||||
|
def setUpClass(cls): |
||||||
|
cls.SOCKET_PORT = 45454 |
||||||
|
athenad.ROOT = tempfile.mkdtemp() |
||||||
|
athenad.Params = MockParams |
||||||
|
athenad.Api = MockApi |
||||||
|
athenad.LOCAL_PORT_WHITELIST = set([cls.SOCKET_PORT]) |
||||||
|
|
||||||
|
def test_echo(self): |
||||||
|
assert dispatcher["echo"]("bob") == "bob" |
||||||
|
|
||||||
|
def test_getMessage(self): |
||||||
|
with self.assertRaises(TimeoutError) as _: |
||||||
|
dispatcher["getMessage"]("controlsState") |
||||||
|
|
||||||
|
def send_thermal(): |
||||||
|
messaging.context = messaging.Context() |
||||||
|
pub_sock = messaging.pub_sock("thermal") |
||||||
|
start = time.time() |
||||||
|
|
||||||
|
while time.time() - start < 1: |
||||||
|
msg = messaging.new_message() |
||||||
|
msg.init('thermal') |
||||||
|
pub_sock.send(msg.to_bytes()) |
||||||
|
time.sleep(0.01) |
||||||
|
|
||||||
|
p = Process(target=send_thermal) |
||||||
|
p.start() |
||||||
|
time.sleep(0.1) |
||||||
|
try: |
||||||
|
thermal = dispatcher["getMessage"]("thermal") |
||||||
|
assert thermal['thermal'] |
||||||
|
finally: |
||||||
|
p.terminate() |
||||||
|
|
||||||
|
def test_listDataDirectory(self): |
||||||
|
print(dispatcher["listDataDirectory"]()) |
||||||
|
|
||||||
|
@with_http_server |
||||||
|
def test_do_upload(self, host): |
||||||
|
fn = os.path.join(athenad.ROOT, 'qlog.bz2') |
||||||
|
Path(fn).touch() |
||||||
|
|
||||||
|
try: |
||||||
|
item = athenad.UploadItem(path=fn, url="http://localhost:1238", headers={}, created_at=int(time.time()*1000), id='') |
||||||
|
try: |
||||||
|
athenad._do_upload(item) |
||||||
|
except requests.exceptions.ConnectionError: |
||||||
|
pass |
||||||
|
|
||||||
|
item = athenad.UploadItem(path=fn, url=f"{host}/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='') |
||||||
|
resp = athenad._do_upload(item) |
||||||
|
self.assertEqual(resp.status_code, 201) |
||||||
|
finally: |
||||||
|
os.unlink(fn) |
||||||
|
|
||||||
|
@with_http_server |
||||||
|
def test_uploadFileToUrl(self, host): |
||||||
|
not_exists_resp = dispatcher["uploadFileToUrl"]("does_not_exist.bz2", "http://localhost:1238", {}) |
||||||
|
self.assertEqual(not_exists_resp, 404) |
||||||
|
|
||||||
|
fn = os.path.join(athenad.ROOT, 'qlog.bz2') |
||||||
|
Path(fn).touch() |
||||||
|
|
||||||
|
try: |
||||||
|
resp = dispatcher["uploadFileToUrl"]("qlog.bz2", f"{host}/qlog.bz2", {}) |
||||||
|
self.assertEqual(resp['enqueued'], 1) |
||||||
|
self.assertDictContainsSubset({"path": fn, "url": f"{host}/qlog.bz2", "headers": {}}, resp['item']) |
||||||
|
self.assertIsNotNone(resp['item'].get('id')) |
||||||
|
self.assertEqual(athenad.upload_queue.qsize(), 1) |
||||||
|
finally: |
||||||
|
athenad.upload_queue = queue.Queue() |
||||||
|
os.unlink(fn) |
||||||
|
|
||||||
|
@with_http_server |
||||||
|
def test_upload_handler(self, host): |
||||||
|
fn = os.path.join(athenad.ROOT, 'qlog.bz2') |
||||||
|
Path(fn).touch() |
||||||
|
item = athenad.UploadItem(path=fn, url=f"{host}/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='') |
||||||
|
|
||||||
|
end_event = threading.Event() |
||||||
|
thread = threading.Thread(target=athenad.upload_handler, args=(end_event,)) |
||||||
|
thread.start() |
||||||
|
|
||||||
|
athenad.upload_queue.put_nowait(item) |
||||||
|
try: |
||||||
|
now = time.time() |
||||||
|
while time.time() - now < 5: |
||||||
|
if athenad.upload_queue.qsize() == 0: |
||||||
|
break |
||||||
|
self.assertEqual(athenad.upload_queue.qsize(), 0) |
||||||
|
finally: |
||||||
|
end_event.set() |
||||||
|
athenad.upload_queue = queue.Queue() |
||||||
|
os.unlink(fn) |
||||||
|
|
||||||
|
def test_cancelUpload(self): |
||||||
|
item = athenad.UploadItem(path="qlog.bz2", url="http://localhost:44444/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='id') |
||||||
|
athenad.upload_queue.put_nowait(item) |
||||||
|
dispatcher["cancelUpload"](item.id) |
||||||
|
|
||||||
|
self.assertIn(item.id, athenad.cancelled_uploads) |
||||||
|
|
||||||
|
end_event = threading.Event() |
||||||
|
thread = threading.Thread(target=athenad.upload_handler, args=(end_event,)) |
||||||
|
thread.start() |
||||||
|
try: |
||||||
|
now = time.time() |
||||||
|
while time.time() - now < 5: |
||||||
|
if athenad.upload_queue.qsize() == 0 and len(athenad.cancelled_uploads) == 0: |
||||||
|
break |
||||||
|
self.assertEqual(athenad.upload_queue.qsize(), 0) |
||||||
|
self.assertEqual(len(athenad.cancelled_uploads), 0) |
||||||
|
finally: |
||||||
|
end_event.set() |
||||||
|
athenad.upload_queue = queue.Queue() |
||||||
|
|
||||||
|
def test_listUploadQueue(self): |
||||||
|
item = athenad.UploadItem(path="qlog.bz2", url="http://localhost:44444/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='id') |
||||||
|
athenad.upload_queue.put_nowait(item) |
||||||
|
|
||||||
|
try: |
||||||
|
items = dispatcher["listUploadQueue"]() |
||||||
|
self.assertEqual(len(items), 1) |
||||||
|
self.assertDictEqual(items[0], item._asdict()) |
||||||
|
finally: |
||||||
|
athenad.upload_queue = queue.Queue() |
||||||
|
|
||||||
|
@mock.patch('selfdrive.athena.athenad.create_connection') |
||||||
|
def test_startLocalProxy(self, mock_create_connection): |
||||||
|
end_event = threading.Event() |
||||||
|
|
||||||
|
ws_recv = queue.Queue() |
||||||
|
ws_send = queue.Queue() |
||||||
|
mock_ws = MockWebsocket(ws_recv, ws_send) |
||||||
|
mock_create_connection.return_value = mock_ws |
||||||
|
|
||||||
|
echo_socket = EchoSocket(self.SOCKET_PORT) |
||||||
|
socket_thread = threading.Thread(target=echo_socket.run) |
||||||
|
socket_thread.start() |
||||||
|
|
||||||
|
athenad.startLocalProxy(end_event, 'ws://localhost:1234', self.SOCKET_PORT) |
||||||
|
|
||||||
|
ws_recv.put_nowait(b'ping') |
||||||
|
try: |
||||||
|
recv = ws_send.get(timeout=5) |
||||||
|
assert recv == (b'ping', ABNF.OPCODE_BINARY), recv |
||||||
|
finally: |
||||||
|
# signal websocket close to athenad.ws_proxy_recv |
||||||
|
ws_recv.put_nowait(WebSocketConnectionClosedException()) |
||||||
|
socket_thread.join() |
||||||
|
|
||||||
|
def test_getSshAuthorizedKeys(self): |
||||||
|
keys = dispatcher["getSshAuthorizedKeys"]() |
||||||
|
self.assertEqual(keys, MockParams().params["GithubSshKeys"].decode('utf-8')) |
||||||
|
|
||||||
|
def test_jsonrpc_handler(self): |
||||||
|
end_event = threading.Event() |
||||||
|
thread = threading.Thread(target=athenad.jsonrpc_handler, args=(end_event,)) |
||||||
|
thread.daemon = True |
||||||
|
thread.start() |
||||||
|
athenad.payload_queue.put_nowait(json.dumps({"method": "echo", "params": ["hello"], "jsonrpc": "2.0", "id": 0})) |
||||||
|
try: |
||||||
|
resp = athenad.response_queue.get(timeout=3) |
||||||
|
self.assertDictEqual(resp.data, {'result': 'hello', 'id': 0, 'jsonrpc': '2.0'}) |
||||||
|
finally: |
||||||
|
end_event.set() |
||||||
|
thread.join() |
||||||
|
|
||||||
|
if __name__ == '__main__': |
||||||
|
unittest.main() |
@ -0,0 +1,114 @@ |
|||||||
|
import http.server |
||||||
|
import multiprocessing |
||||||
|
import queue |
||||||
|
import random |
||||||
|
import requests |
||||||
|
import socket |
||||||
|
import time |
||||||
|
from functools import wraps |
||||||
|
from multiprocessing import Process |
||||||
|
|
||||||
|
class EchoSocket(): |
||||||
|
def __init__(self, port): |
||||||
|
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
||||||
|
self.socket.bind(('127.0.0.1', port)) |
||||||
|
self.socket.listen(1) |
||||||
|
|
||||||
|
def run(self): |
||||||
|
conn, client_address = self.socket.accept() |
||||||
|
conn.settimeout(5.0) |
||||||
|
|
||||||
|
try: |
||||||
|
while True: |
||||||
|
data = conn.recv(4096) |
||||||
|
if data: |
||||||
|
print(f'EchoSocket got {data}') |
||||||
|
conn.sendall(data) |
||||||
|
else: |
||||||
|
break |
||||||
|
finally: |
||||||
|
conn.shutdown(0) |
||||||
|
conn.close() |
||||||
|
self.socket.shutdown(0) |
||||||
|
self.socket.close() |
||||||
|
|
||||||
|
class MockApi(): |
||||||
|
def __init__(self, dongle_id): |
||||||
|
pass |
||||||
|
|
||||||
|
def get_token(self): |
||||||
|
return "fake-token" |
||||||
|
|
||||||
|
class MockParams(): |
||||||
|
def __init__(self): |
||||||
|
self.params = { |
||||||
|
"DongleId": b"0000000000000000", |
||||||
|
"GithubSshKeys": b"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQC307aE+nuHzTAgaJhzSf5v7ZZQW9gaperjhCmyPyl4PzY7T1mDGenTlVTN7yoVFZ9UfO9oMQqo0n1OwDIiqbIFxqnhrHU0cYfj88rI85m5BEKlNu5RdaVTj1tcbaPpQc5kZEolaI1nDDjzV0lwS7jo5VYDHseiJHlik3HH1SgtdtsuamGR2T80q1SyW+5rHoMOJG73IH2553NnWuikKiuikGHUYBd00K1ilVAK2xSiMWJp55tQfZ0ecr9QjEsJ+J/efL4HqGNXhffxvypCXvbUYAFSddOwXUPo5BTKevpxMtH+2YrkpSjocWA04VnTYFiPG6U4ItKmbLOTFZtPzoez private" |
||||||
|
} |
||||||
|
|
||||||
|
def get(self, k, encoding=None): |
||||||
|
ret = self.params.get(k) |
||||||
|
if ret is not None and encoding is not None: |
||||||
|
ret = ret.decode(encoding) |
||||||
|
return ret |
||||||
|
|
||||||
|
class MockWebsocket(): |
||||||
|
def __init__(self, recv_queue, send_queue): |
||||||
|
self.recv_queue = recv_queue |
||||||
|
self.send_queue = send_queue |
||||||
|
|
||||||
|
def recv(self): |
||||||
|
data = self.recv_queue.get() |
||||||
|
if isinstance(data, Exception): |
||||||
|
raise data |
||||||
|
return data |
||||||
|
|
||||||
|
def send(self, data, opcode): |
||||||
|
self.send_queue.put_nowait((data, opcode)) |
||||||
|
|
||||||
|
class HTTPRequestHandler(http.server.SimpleHTTPRequestHandler): |
||||||
|
def do_PUT(self): |
||||||
|
length = int(self.headers['Content-Length']) |
||||||
|
self.rfile.read(length) |
||||||
|
self.send_response(201, "Created") |
||||||
|
self.end_headers() |
||||||
|
|
||||||
|
def http_server(port_queue, **kwargs): |
||||||
|
while 1: |
||||||
|
try: |
||||||
|
port = random.randrange(40000, 50000) |
||||||
|
port_queue.put(port) |
||||||
|
http.server.test(**kwargs, port=port) |
||||||
|
except OSError as e: |
||||||
|
if e.errno == 98: |
||||||
|
continue |
||||||
|
|
||||||
|
def with_http_server(func): |
||||||
|
@wraps(func) |
||||||
|
def inner(*args, **kwargs): |
||||||
|
port_queue = multiprocessing.Queue() |
||||||
|
host = '127.0.0.1' |
||||||
|
p = Process(target=http_server, |
||||||
|
args=(port_queue,), |
||||||
|
kwargs={ |
||||||
|
'HandlerClass': HTTPRequestHandler, |
||||||
|
'bind': host}) |
||||||
|
p.start() |
||||||
|
now = time.time() |
||||||
|
port = None |
||||||
|
while 1: |
||||||
|
if time.time() - now > 5: |
||||||
|
raise Exception('HTTP Server did not start') |
||||||
|
try: |
||||||
|
port = port_queue.get(timeout=0.1) |
||||||
|
requests.put(f'http://{host}:{port}/qlog.bz2', data='') |
||||||
|
break |
||||||
|
except (requests.exceptions.ConnectionError, queue.Empty): |
||||||
|
time.sleep(0.1) |
||||||
|
|
||||||
|
try: |
||||||
|
return func(*args, f'http://{host}:{port}', **kwargs) |
||||||
|
finally: |
||||||
|
p.terminate() |
||||||
|
|
||||||
|
return inner |
Loading…
Reference in new issue