parent
78d3dfe8ae
commit
6466776773
5 changed files with 665 additions and 0 deletions
@ -0,0 +1,322 @@ |
||||
#!/usr/bin/env python3.7 |
||||
import json |
||||
import os |
||||
import hashlib |
||||
import io |
||||
import random |
||||
import select |
||||
import socket |
||||
import time |
||||
import threading |
||||
import base64 |
||||
import requests |
||||
import queue |
||||
from collections import namedtuple |
||||
from functools import partial |
||||
from jsonrpc import JSONRPCResponseManager, dispatcher |
||||
from websocket import create_connection, WebSocketTimeoutException, ABNF |
||||
from selfdrive.loggerd.config import ROOT |
||||
|
||||
import cereal.messaging as messaging |
||||
from common import android |
||||
from common.api import Api |
||||
from common.params import Params |
||||
from cereal.services import service_list |
||||
from selfdrive.swaglog import cloudlog |
||||
|
||||
# Athena backend endpoint; override via the environment for testing.
ATHENA_HOST = os.getenv('ATHENA_HOST', 'wss://athena.comma.ai')
# Number of jsonrpc worker threads. os.getenv returns a *string* when the
# variable is set, so coerce to int — otherwise range(HANDLER_THREADS) in
# handle_long_poll raises TypeError.
HANDLER_THREADS = int(os.getenv('HANDLER_THREADS', 4))
# Only these local ports may be targeted by startLocalProxy.
LOCAL_PORT_WHITELIST = set([8022])

dispatcher["echo"] = lambda s: s
# Queues shared between the websocket threads and the worker threads.
payload_queue = queue.Queue()    # raw jsonrpc requests received from the server
response_queue = queue.Queue()   # jsonrpc responses waiting to be sent back
upload_queue = queue.Queue()     # UploadItems pending upload
cancelled_uploads = set()        # ids of queued uploads that were cancelled
UploadItem = namedtuple('UploadItem', ['path', 'url', 'headers', 'created_at', 'id'])
||||
|
||||
def handle_long_poll(ws):
  """Service one websocket connection with a pool of worker threads.

  Starts recv/send threads, the upload worker, and HANDLER_THREADS jsonrpc
  workers, then blocks until any of them sets end_event (e.g. on a socket
  error). All threads are joined before returning to the reconnect loop in
  main().
  """
  end_event = threading.Event()

  threads = [
    threading.Thread(target=ws_recv, args=(ws, end_event)),
    threading.Thread(target=ws_send, args=(ws, end_event)),
    threading.Thread(target=upload_handler, args=(end_event,))
  ] + [
    threading.Thread(target=jsonrpc_handler, args=(end_event,))
    for x in range(HANDLER_THREADS)
  ]

  for thread in threads:
    thread.start()
  try:
    # poll rather than join so KeyboardInterrupt is delivered promptly
    while not end_event.is_set():
      time.sleep(0.1)
  except (KeyboardInterrupt, SystemExit):
    end_event.set()
    raise
  finally:
    for i, thread in enumerate(threads):
      thread.join()
||||
|
||||
def jsonrpc_handler(end_event):
  """Worker thread: pop raw jsonrpc payloads, dispatch them, queue responses."""
  # registered here so the proxy inherits this connection's end_event
  dispatcher["startLocalProxy"] = partial(startLocalProxy, end_event)
  while not end_event.is_set():
    try:
      data = payload_queue.get(timeout=1)
      response = JSONRPCResponseManager.handle(data, dispatcher)
      response_queue.put_nowait(response)
    except queue.Empty:
      # timeout just lets us re-check end_event
      pass
    except Exception as e:
      cloudlog.exception("athena jsonrpc handler failed")
      response_queue.put_nowait(json.dumps({"error": str(e)}))
||||
|
||||
def upload_handler(end_event):
  """Worker thread: drain upload_queue, skipping uploads cancelled via cancelUpload."""
  while not end_event.is_set():
    try:
      item = upload_queue.get(timeout=1)
      if item.id in cancelled_uploads:
        cancelled_uploads.remove(item.id)
        continue
      _do_upload(item)
    except queue.Empty:
      # timeout just lets us re-check end_event
      pass
    except Exception:
      # best-effort: log and keep the worker alive for the next item
      cloudlog.exception("athena.upload_handler.exception")
||||
|
||||
def _do_upload(upload_item):
  """PUT the file at upload_item.path to upload_item.url; return the response."""
  with open(upload_item.path, "rb") as src:
    content_length = os.fstat(src.fileno()).st_size
    headers = dict(upload_item.headers)
    headers['Content-Length'] = str(content_length)
    return requests.put(upload_item.url, data=src, headers=headers, timeout=10)
||||
|
||||
# security: user should be able to request any message from their car
@dispatcher.add_method
def getMessage(service=None, timeout=1000):
  """Return one message from the given cereal service as a dict.

  Raises:
    Exception: if service is missing or not a known service.
    TimeoutError: if no message arrives within `timeout` milliseconds.
  """
  if service is None or service not in service_list:
    raise Exception("invalid service")

  # named sub_sock (not `socket`) to avoid shadowing the socket module
  # imported at the top of this file
  sub_sock = messaging.sub_sock(service, timeout=timeout)
  ret = messaging.recv_one(sub_sock)

  if ret is None:
    raise TimeoutError

  return ret.to_dict()
||||
|
||||
@dispatcher.add_method
def listDataDirectory():
  """Return every file under ROOT as a path relative to ROOT."""
  found = []
  for dirpath, _, filenames in os.walk(ROOT):
    for name in filenames:
      found.append(os.path.relpath(os.path.join(dirpath, name), ROOT))
  return found
||||
|
||||
@dispatcher.add_method
def reboot():
  """Reboot the device; refused while openpilot is started.

  Raises:
    Exception: if thermal state can't be read or the car is started.
  """
  thermal_sock = messaging.sub_sock("thermal", timeout=1000)
  ret = messaging.recv_one(thermal_sock)
  if ret is None or ret.thermal.started:
    raise Exception("Reboot unavailable")

  def do_reboot():
    # short delay so the jsonrpc response can be delivered before rebooting
    time.sleep(2)
    android.reboot()

  threading.Thread(target=do_reboot).start()

  return {"success": 1}
||||
|
||||
@dispatcher.add_method
def uploadFileToUrl(fn, url, headers):
  """Queue the file `fn` (relative to ROOT) for upload to `url`.

  Returns 500 for an empty/absolute/traversal path, 404 if the file does
  not exist, otherwise a dict describing the enqueued item.
  """
  # reject absolute paths and '..' traversal so only files under ROOT leave
  if len(fn) == 0 or fn[0] == '/' or '..' in fn:
    return 500
  path = os.path.join(ROOT, fn)
  if not os.path.exists(path):
    return 404

  item = UploadItem(path=path, url=url, headers=headers, created_at=int(time.time()*1000), id=None)
  # id derives from the item's contents (including created_at timestamp)
  upload_id = hashlib.sha1(str(item).encode()).hexdigest()
  item = item._replace(id=upload_id)

  upload_queue.put_nowait(item)

  return {"enqueued": 1, "item": item._asdict()}
||||
|
||||
@dispatcher.add_method
def listUploadQueue():
  """Return the currently queued upload items as dicts."""
  snapshot = list(upload_queue.queue)
  return [entry._asdict() for entry in snapshot]
||||
|
||||
@dispatcher.add_method
def cancelUpload(upload_id):
  """Mark a queued upload for cancellation; returns 404 if the id isn't queued."""
  queued_ids = {entry.id for entry in list(upload_queue.queue)}
  if upload_id not in queued_ids:
    return 404

  cancelled_uploads.add(upload_id)
  return {"success": 1}
||||
|
||||
def startLocalProxy(global_end_event, remote_ws_uri, local_port):
  """Bridge a remote websocket to a whitelisted local TCP port (e.g. ssh).

  Opens a websocket to `remote_ws_uri` (authenticated with the device JWT)
  and a TCP connection to 127.0.0.1:`local_port`, then starts two threads
  shuttling bytes between them until either side dies or global_end_event
  is set.

  Raises:
    Exception: if local_port is not whitelisted, or on any setup failure.
  """
  try:
    if local_port not in LOCAL_PORT_WHITELIST:
      raise Exception("Requested local port not whitelisted")

    params = Params()
    dongle_id = params.get("DongleId").decode('utf8')
    identity_token = Api(dongle_id).get_token()
    ws = create_connection(remote_ws_uri,
                           cookie="jwt=" + identity_token,
                           enable_multithread=True)

    # socketpair is used purely as a signal channel: ws_proxy_recv closes
    # ssock on exit, which wakes the select() in ws_proxy_send via csock.
    ssock, csock = socket.socketpair()
    local_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    local_sock.connect(('127.0.0.1', local_port))
    local_sock.setblocking(False)

    proxy_end_event = threading.Event()
    threads = [
      threading.Thread(target=ws_proxy_recv, args=(ws, local_sock, ssock, proxy_end_event, global_end_event)),
      threading.Thread(target=ws_proxy_send, args=(ws, local_sock, csock, proxy_end_event))
    ]
    for thread in threads:
      thread.start()

    return {"success": 1}
  except Exception:
    cloudlog.exception("athenad.startLocalProxy.exception")
    # bare raise re-raises the active exception without rebinding it
    raise
||||
|
||||
@dispatcher.add_method
def getPublicKey():
  """Return the device's SSH public key, or None if not provisioned.

  EAFP: open directly instead of an isfile() pre-check, avoiding a second
  filesystem access and the check/open race.
  """
  try:
    with open('/persist/comma/id_rsa.pub', 'r') as f:
      return f.read()
  except FileNotFoundError:
    # no persist partition / key not provisioned on this device
    return None
||||
|
||||
@dispatcher.add_method
def getSshAuthorizedKeys():
  """Return the stored GitHub SSH keys, or '' when unset."""
  keys = Params().get("GithubSshKeys", encoding='utf8')
  if keys:
    return keys
  return ''
||||
|
||||
@dispatcher.add_method
def getSimInfo():
  """Collect SIM/cellular info from Android system properties and services."""
  sim_state = android.getprop("gsm.sim.state").split(",")
  network_type = android.getprop("gsm.network.type").split(',')
  mcc_mnc = android.getprop("gsm.sim.operator.numeric") or None

  # NOTE(review): service-call codes are Android-version specific —
  # presumably iphonesubinfo/11 is the SIM serial and phone/46 the data
  # connection state (2 == connected); confirm against the target Android.
  sim_id = android.parse_service_call_string(android.service_call(['iphonesubinfo', '11']))
  cell_data_state = android.parse_service_call_unpack(android.service_call(['phone', '46']), ">q")
  cell_data_connected = (cell_data_state == 2)

  return {
    'sim_id': sim_id,
    'mcc_mnc': mcc_mnc,
    'network_type': network_type,
    'sim_state': sim_state,
    'data_connected': cell_data_connected
  }
||||
|
||||
@dispatcher.add_method
def takeSnapshot():
  """Take front and back camera snapshots; return them base64-jpeg encoded.

  Raises:
    Exception: when snapshots are unavailable (camerad running).
  """
  # imported lazily so athenad can start without camerad's dependencies
  from selfdrive.camerad.snapshot.snapshot import snapshot, jpeg_write
  ret = snapshot()
  if ret is not None:
    def b64jpeg(x):
      # encode one frame to jpeg then base64; None frames pass through
      if x is not None:
        f = io.BytesIO()
        jpeg_write(f, x)
        return base64.b64encode(f.getvalue()).decode("utf-8")
      else:
        return None
    return {'jpegBack': b64jpeg(ret[0]),
            'jpegFront': b64jpeg(ret[1])}
  else:
    raise Exception("not available while camerad is started")
||||
|
||||
def ws_proxy_recv(ws, local_sock, ssock, end_event, global_end_event):
  """Proxy thread: websocket -> local socket.

  On exit closes ssock, which makes its socketpair peer readable and wakes
  the select() in ws_proxy_send so both directions shut down together.
  """
  while not (end_event.is_set() or global_end_event.is_set()):
    try:
      data = ws.recv()
      local_sock.sendall(data)
    except WebSocketTimeoutException:
      # periodic timeout just lets us re-check the end events
      pass
    except Exception:
      cloudlog.exception("athenad.ws_proxy_recv.exception")
      break

  ssock.close()
  local_sock.close()
  end_event.set()
||||
|
||||
def ws_proxy_send(ws, local_sock, signal_sock, end_event):
  """Proxy thread: local socket -> websocket, sent as binary frames.

  signal_sock becomes readable when ws_proxy_recv closes its socketpair
  peer; that is the shutdown signal for this direction.
  """
  while not end_event.is_set():
    try:
      r, _, _ = select.select((local_sock, signal_sock), (), ())
      if r:
        if r[0].fileno() == signal_sock.fileno():
          # got end signal from ws_proxy_recv
          end_event.set()
          break
        data = local_sock.recv(4096)
        if not data:
          # local_sock is dead
          end_event.set()
          break

        ws.send(data, ABNF.OPCODE_BINARY)
    except Exception:
      cloudlog.exception("athenad.ws_proxy_send.exception")
      end_event.set()
||||
|
||||
def ws_recv(ws, end_event):
  """Receive raw jsonrpc payloads from the websocket into payload_queue."""
  while not end_event.is_set():
    try:
      payload_queue.put_nowait(ws.recv())
    except WebSocketTimeoutException:
      # recv timed out (ws.settimeout in main); poll end_event again
      continue
    except Exception:
      cloudlog.exception("athenad.ws_recv.exception")
      end_event.set()
||||
|
||||
def ws_send(ws, end_event):
  """Drain response_queue onto the websocket until end_event is set."""
  while not end_event.is_set():
    try:
      message = response_queue.get(timeout=1)
      ws.send(message.json)
    except queue.Empty:
      # nothing to send; re-check end_event
      continue
    except Exception:
      cloudlog.exception("athenad.ws_send.exception")
      end_event.set()
||||
|
||||
def backoff(retries):
  """Random reconnect delay in seconds, growing exponentially, capped below 128."""
  upper = min(128, 2 ** retries)
  return random.randrange(upper)
||||
|
||||
def main(gctx=None):
  """Connect to the Athena backend and service it forever.

  Reconnects with randomized exponential backoff on any failure; exits
  cleanly on KeyboardInterrupt/SystemExit.
  """
  params = Params()
  dongle_id = params.get("DongleId").decode('utf-8')
  ws_uri = ATHENA_HOST + "/ws/v2/" + dongle_id

  api = Api(dongle_id)

  conn_retries = 0
  while 1:
    try:
      ws = create_connection(ws_uri,
                             cookie="jwt=" + api.get_token(),
                             enable_multithread=True)
      cloudlog.event("athenad.main.connected_ws", ws_uri=ws_uri)
      # 1s recv timeout so worker threads can poll their end_event
      ws.settimeout(1)
      conn_retries = 0
      handle_long_poll(ws)
    except (KeyboardInterrupt, SystemExit):
      break
    except Exception:
      cloudlog.exception("athenad.main.exception")
      conn_retries += 1

    time.sleep(backoff(conn_retries))
||||
|
||||
# allow running the daemon directly for development
if __name__ == "__main__":
  main()
@ -0,0 +1,36 @@ |
||||
#!/usr/bin/env python3 |
||||
|
||||
import time |
||||
from multiprocessing import Process |
||||
|
||||
import selfdrive.crash as crash |
||||
from common.params import Params |
||||
from selfdrive.launcher import launcher |
||||
from selfdrive.swaglog import cloudlog |
||||
from selfdrive.version import version, dirty |
||||
|
||||
ATHENA_MGR_PID_PARAM = "AthenadPid" |
||||
|
||||
def main():
  """Supervise athenad: restart it whenever it exits, 5 seconds apart.

  Binds crash/log metadata for the device, then loops launching athenad in
  a child process. The manager pid param is cleared on exit.
  """
  params = Params()
  dongle_id = params.get("DongleId").decode('utf-8')
  cloudlog.bind_global(dongle_id=dongle_id, version=version, dirty=dirty, is_eon=True)
  crash.bind_user(id=dongle_id)
  crash.bind_extra(version=version, dirty=dirty, is_eon=True)
  crash.install()

  try:
    while 1:
      cloudlog.info("starting athena daemon")
      proc = Process(name='athenad', target=launcher, args=('selfdrive.athena.athenad',))
      proc.start()
      proc.join()
      cloudlog.event("athenad exited", exitcode=proc.exitcode)
      time.sleep(5)
  except Exception:
    # was a bare `except:`, which also swallowed SystemExit and
    # KeyboardInterrupt and prevented a clean shutdown
    cloudlog.exception("manage_athenad.exception")
  finally:
    params.delete(ATHENA_MGR_PID_PARAM)
||||
|
||||
# allow running the manager directly for development
if __name__ == '__main__':
  main()
@ -0,0 +1,193 @@ |
||||
#!/usr/bin/env python3 |
||||
import json |
||||
import os |
||||
import requests |
||||
import tempfile |
||||
import time |
||||
import threading |
||||
import queue |
||||
import unittest |
||||
|
||||
from multiprocessing import Process |
||||
from pathlib import Path |
||||
from unittest import mock |
||||
from websocket import ABNF |
||||
from websocket._exceptions import WebSocketConnectionClosedException |
||||
|
||||
from selfdrive.athena import athenad |
||||
from selfdrive.athena.athenad import dispatcher |
||||
from selfdrive.athena.test_helpers import MockWebsocket, MockParams, MockApi, EchoSocket, with_http_server |
||||
from cereal import messaging |
||||
|
||||
class TestAthenadMethods(unittest.TestCase):
  """Exercises athenad's jsonrpc methods with its external deps mocked out."""

  @classmethod
  def setUpClass(cls):
    # redirect athenad's module-level dependencies to test doubles
    cls.SOCKET_PORT = 45454
    athenad.ROOT = tempfile.mkdtemp()
    athenad.Params = MockParams
    athenad.Api = MockApi
    athenad.LOCAL_PORT_WHITELIST = set([cls.SOCKET_PORT])

  def test_echo(self):
    assert dispatcher["echo"]("bob") == "bob"

  def test_getMessage(self):
    # no publisher yet -> getMessage must time out
    with self.assertRaises(TimeoutError) as _:
      dispatcher["getMessage"]("controlsState")

    def send_thermal():
      # publish thermal messages for ~1s from a separate process
      messaging.context = messaging.Context()
      pub_sock = messaging.pub_sock("thermal")
      start = time.time()

      while time.time() - start < 1:
        msg = messaging.new_message()
        msg.init('thermal')
        pub_sock.send(msg.to_bytes())
        time.sleep(0.01)

    p = Process(target=send_thermal)
    p.start()
    time.sleep(0.1)
    try:
      thermal = dispatcher["getMessage"]("thermal")
      assert thermal['thermal']
    finally:
      p.terminate()

  def test_listDataDirectory(self):
    # smoke test only: just verify the call succeeds
    print(dispatcher["listDataDirectory"]())

  @with_http_server
  def test_do_upload(self, host):
    fn = os.path.join(athenad.ROOT, 'qlog.bz2')
    Path(fn).touch()

    try:
      # nothing listens on 1238, so a connection error is expected
      item = athenad.UploadItem(path=fn, url="http://localhost:1238", headers={}, created_at=int(time.time()*1000), id='')
      try:
        athenad._do_upload(item)
      except requests.exceptions.ConnectionError:
        pass

      # against the real test server the PUT must succeed
      item = athenad.UploadItem(path=fn, url=f"{host}/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='')
      resp = athenad._do_upload(item)
      self.assertEqual(resp.status_code, 201)
    finally:
      os.unlink(fn)

  @with_http_server
  def test_uploadFileToUrl(self, host):
    not_exists_resp = dispatcher["uploadFileToUrl"]("does_not_exist.bz2", "http://localhost:1238", {})
    self.assertEqual(not_exists_resp, 404)

    fn = os.path.join(athenad.ROOT, 'qlog.bz2')
    Path(fn).touch()

    try:
      resp = dispatcher["uploadFileToUrl"]("qlog.bz2", f"{host}/qlog.bz2", {})
      self.assertEqual(resp['enqueued'], 1)
      self.assertDictContainsSubset({"path": fn, "url": f"{host}/qlog.bz2", "headers": {}}, resp['item'])
      self.assertIsNotNone(resp['item'].get('id'))
      self.assertEqual(athenad.upload_queue.qsize(), 1)
    finally:
      # reset shared module state so later tests start clean
      athenad.upload_queue = queue.Queue()
      os.unlink(fn)

  @with_http_server
  def test_upload_handler(self, host):
    fn = os.path.join(athenad.ROOT, 'qlog.bz2')
    Path(fn).touch()
    item = athenad.UploadItem(path=fn, url=f"{host}/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='')

    end_event = threading.Event()
    thread = threading.Thread(target=athenad.upload_handler, args=(end_event,))
    thread.start()

    athenad.upload_queue.put_nowait(item)
    try:
      now = time.time()
      # NOTE(review): busy-waits without sleeping; a short time.sleep here
      # would reduce CPU spin
      while time.time() - now < 5:
        if athenad.upload_queue.qsize() == 0:
          break
      self.assertEqual(athenad.upload_queue.qsize(), 0)
    finally:
      end_event.set()
      athenad.upload_queue = queue.Queue()
      os.unlink(fn)

  def test_cancelUpload(self):
    item = athenad.UploadItem(path="qlog.bz2", url="http://localhost:44444/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='id')
    athenad.upload_queue.put_nowait(item)
    dispatcher["cancelUpload"](item.id)

    self.assertIn(item.id, athenad.cancelled_uploads)

    end_event = threading.Event()
    thread = threading.Thread(target=athenad.upload_handler, args=(end_event,))
    thread.start()
    try:
      now = time.time()
      # NOTE(review): busy-waits without sleeping, same as test_upload_handler
      while time.time() - now < 5:
        if athenad.upload_queue.qsize() == 0 and len(athenad.cancelled_uploads) == 0:
          break
      self.assertEqual(athenad.upload_queue.qsize(), 0)
      self.assertEqual(len(athenad.cancelled_uploads), 0)
    finally:
      end_event.set()
      athenad.upload_queue = queue.Queue()

  def test_listUploadQueue(self):
    item = athenad.UploadItem(path="qlog.bz2", url="http://localhost:44444/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='id')
    athenad.upload_queue.put_nowait(item)

    try:
      items = dispatcher["listUploadQueue"]()
      self.assertEqual(len(items), 1)
      self.assertDictEqual(items[0], item._asdict())
    finally:
      athenad.upload_queue = queue.Queue()

  @mock.patch('selfdrive.athena.athenad.create_connection')
  def test_startLocalProxy(self, mock_create_connection):
    end_event = threading.Event()

    # fake websocket backed by two queues, echoed against a local TCP server
    ws_recv = queue.Queue()
    ws_send = queue.Queue()
    mock_ws = MockWebsocket(ws_recv, ws_send)
    mock_create_connection.return_value = mock_ws

    echo_socket = EchoSocket(self.SOCKET_PORT)
    socket_thread = threading.Thread(target=echo_socket.run)
    socket_thread.start()

    athenad.startLocalProxy(end_event, 'ws://localhost:1234', self.SOCKET_PORT)

    ws_recv.put_nowait(b'ping')
    try:
      recv = ws_send.get(timeout=5)
      assert recv == (b'ping', ABNF.OPCODE_BINARY), recv
    finally:
      # signal websocket close to athenad.ws_proxy_recv
      ws_recv.put_nowait(WebSocketConnectionClosedException())
      socket_thread.join()

  def test_getSshAuthorizedKeys(self):
    keys = dispatcher["getSshAuthorizedKeys"]()
    self.assertEqual(keys, MockParams().params["GithubSshKeys"].decode('utf-8'))

  def test_jsonrpc_handler(self):
    end_event = threading.Event()
    thread = threading.Thread(target=athenad.jsonrpc_handler, args=(end_event,))
    thread.daemon = True
    thread.start()
    athenad.payload_queue.put_nowait(json.dumps({"method": "echo", "params": ["hello"], "jsonrpc": "2.0", "id": 0}))
    try:
      resp = athenad.response_queue.get(timeout=3)
      self.assertDictEqual(resp.data, {'result': 'hello', 'id': 0, 'jsonrpc': '2.0'})
    finally:
      end_event.set()
      thread.join()
||||
|
||||
# run this test module directly with `python test_athenad.py`
if __name__ == '__main__':
  unittest.main()
@ -0,0 +1,114 @@ |
||||
import http.server |
||||
import multiprocessing |
||||
import queue |
||||
import random |
||||
import requests |
||||
import socket |
||||
import time |
||||
from functools import wraps |
||||
from multiprocessing import Process |
||||
|
||||
class EchoSocket():
  """Minimal TCP server that echoes everything from its first client, then exits."""

  def __init__(self, port):
    self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.socket.bind(('127.0.0.1', port))
    self.socket.listen(1)

  def run(self):
    # serve exactly one connection, echoing until the peer closes
    conn, client_address = self.socket.accept()
    conn.settimeout(5.0)

    try:
      while True:
        data = conn.recv(4096)
        if data:
          print(f'EchoSocket got {data}')
          conn.sendall(data)
        else:
          # empty read: peer closed the connection
          break
    finally:
      conn.shutdown(0)
      conn.close()
      self.socket.shutdown(0)
      self.socket.close()
||||
|
||||
class MockApi:
  """Stand-in for common.api.Api that hands back a canned token."""

  def __init__(self, dongle_id):
    # the real Api uses dongle_id to mint a JWT; the mock ignores it
    pass

  def get_token(self):
    return "fake-token"
||||
|
||||
class MockParams:
  """In-memory replacement for common.params.Params used by the tests."""

  def __init__(self):
    self.params = {
      "DongleId": b"0000000000000000",
      "GithubSshKeys": b"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQC307aE+nuHzTAgaJhzSf5v7ZZQW9gaperjhCmyPyl4PzY7T1mDGenTlVTN7yoVFZ9UfO9oMQqo0n1OwDIiqbIFxqnhrHU0cYfj88rI85m5BEKlNu5RdaVTj1tcbaPpQc5kZEolaI1nDDjzV0lwS7jo5VYDHseiJHlik3HH1SgtdtsuamGR2T80q1SyW+5rHoMOJG73IH2553NnWuikKiuikGHUYBd00K1ilVAK2xSiMWJp55tQfZ0ecr9QjEsJ+J/efL4HqGNXhffxvypCXvbUYAFSddOwXUPo5BTKevpxMtH+2YrkpSjocWA04VnTYFiPG6U4ItKmbLOTFZtPzoez private"
    }

  def get(self, k, encoding=None):
    """Look up key `k`; decode the bytes with `encoding` when both are present."""
    value = self.params.get(k)
    if encoding is None or value is None:
      return value
    return value.decode(encoding)
||||
|
||||
class MockWebsocket:
  """Queue-backed fake of the websocket-client connection object."""

  def __init__(self, recv_queue, send_queue):
    self.recv_queue = recv_queue
    self.send_queue = send_queue

  def recv(self):
    """Pop the next queued item; an Exception instance is raised instead of returned."""
    item = self.recv_queue.get()
    if isinstance(item, Exception):
      raise item
    return item

  def send(self, data, opcode):
    """Record (data, opcode) pairs so tests can inspect what was sent."""
    self.send_queue.put_nowait((data, opcode))
||||
|
||||
class HTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
  """SimpleHTTPRequestHandler extended with a PUT that accepts any upload."""

  def do_PUT(self):
    # drain the request body and acknowledge; nothing is stored
    length = int(self.headers['Content-Length'])
    self.rfile.read(length)
    self.send_response(201, "Created")
    self.end_headers()
||||
|
||||
def http_server(port_queue, **kwargs):
  """Serve HTTP on a random high port, retrying while the port is taken.

  The chosen port is pushed onto port_queue before binding so the parent
  process knows where to connect. http.server.test() blocks forever once
  the bind succeeds.
  """
  while 1:
    try:
      port = random.randrange(40000, 50000)
      port_queue.put(port)
      http.server.test(**kwargs, port=port)
    except OSError as e:
      if e.errno == 98:
        # EADDRINUSE: pick another port and retry
        continue
      # any other OSError is a real failure; don't silently loop over it
      raise
||||
|
||||
def with_http_server(func):
  """Decorator: run the wrapped test with a live local HTTP server.

  Spawns http_server in a subprocess, waits up to 5s until it answers a
  PUT, then calls func with an extra `http://host:port` positional
  argument. The server process is terminated afterwards.
  """
  @wraps(func)
  def inner(*args, **kwargs):
    port_queue = multiprocessing.Queue()
    host = '127.0.0.1'
    p = Process(target=http_server,
                args=(port_queue,),
                kwargs={
                  'HandlerClass': HTTPRequestHandler,
                  'bind': host})
    p.start()
    now = time.time()
    port = None
    while 1:
      if time.time() - now > 5:
        raise Exception('HTTP Server did not start')
      try:
        # the server may retry several ports; probe until one answers
        port = port_queue.get(timeout=0.1)
        requests.put(f'http://{host}:{port}/qlog.bz2', data='')
        break
      except (requests.exceptions.ConnectionError, queue.Empty):
        time.sleep(0.1)

    try:
      return func(*args, f'http://{host}:{port}', **kwargs)
    finally:
      p.terminate()

  return inner
Loading…
Reference in new issue