parent
							
								
									84560ccd55
								
							
						
					
					
						commit
						341c0da987
					
				
				 5 changed files with 665 additions and 0 deletions
			
			
		@ -0,0 +1,322 @@ | 
				
			|||||||
 | 
					#!/usr/bin/env python3.7 | 
				
			||||||
 | 
					import json | 
				
			||||||
 | 
					import os | 
				
			||||||
 | 
					import hashlib | 
				
			||||||
 | 
					import io | 
				
			||||||
 | 
					import random | 
				
			||||||
 | 
					import select | 
				
			||||||
 | 
					import socket | 
				
			||||||
 | 
					import time | 
				
			||||||
 | 
					import threading | 
				
			||||||
 | 
					import base64 | 
				
			||||||
 | 
					import requests | 
				
			||||||
 | 
					import queue | 
				
			||||||
 | 
					from collections import namedtuple | 
				
			||||||
 | 
					from functools import partial | 
				
			||||||
 | 
					from jsonrpc import JSONRPCResponseManager, dispatcher | 
				
			||||||
 | 
					from websocket import create_connection, WebSocketTimeoutException, ABNF | 
				
			||||||
 | 
					from selfdrive.loggerd.config import ROOT | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import cereal.messaging as messaging | 
				
			||||||
 | 
					from common import android | 
				
			||||||
 | 
					from common.api import Api | 
				
			||||||
 | 
					from common.params import Params | 
				
			||||||
 | 
					from cereal.services import service_list | 
				
			||||||
 | 
					from selfdrive.swaglog import cloudlog | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					ATHENA_HOST = os.getenv('ATHENA_HOST', 'wss://athena.comma.ai') | 
				
			||||||
 | 
					HANDLER_THREADS = os.getenv('HANDLER_THREADS', 4) | 
				
			||||||
 | 
					LOCAL_PORT_WHITELIST = set([8022]) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					dispatcher["echo"] = lambda s: s | 
				
			||||||
 | 
					payload_queue = queue.Queue() | 
				
			||||||
 | 
					response_queue = queue.Queue() | 
				
			||||||
 | 
					upload_queue = queue.Queue() | 
				
			||||||
 | 
					cancelled_uploads = set() | 
				
			||||||
 | 
					UploadItem = namedtuple('UploadItem', ['path', 'url', 'headers', 'created_at', 'id']) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def handle_long_poll(ws): | 
				
			||||||
 | 
					  end_event = threading.Event() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  threads = [ | 
				
			||||||
 | 
					    threading.Thread(target=ws_recv, args=(ws, end_event)), | 
				
			||||||
 | 
					    threading.Thread(target=ws_send, args=(ws, end_event)), | 
				
			||||||
 | 
					    threading.Thread(target=upload_handler, args=(end_event,)) | 
				
			||||||
 | 
					  ] + [ | 
				
			||||||
 | 
					    threading.Thread(target=jsonrpc_handler, args=(end_event,)) | 
				
			||||||
 | 
					    for x in range(HANDLER_THREADS) | 
				
			||||||
 | 
					  ] | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  for thread in threads: | 
				
			||||||
 | 
					    thread.start() | 
				
			||||||
 | 
					  try: | 
				
			||||||
 | 
					    while not end_event.is_set(): | 
				
			||||||
 | 
					      time.sleep(0.1) | 
				
			||||||
 | 
					  except (KeyboardInterrupt, SystemExit): | 
				
			||||||
 | 
					    end_event.set() | 
				
			||||||
 | 
					    raise | 
				
			||||||
 | 
					  finally: | 
				
			||||||
 | 
					    for i, thread in enumerate(threads): | 
				
			||||||
 | 
					      thread.join() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def jsonrpc_handler(end_event): | 
				
			||||||
 | 
					  dispatcher["startLocalProxy"] = partial(startLocalProxy, end_event) | 
				
			||||||
 | 
					  while not end_event.is_set(): | 
				
			||||||
 | 
					    try: | 
				
			||||||
 | 
					      data = payload_queue.get(timeout=1) | 
				
			||||||
 | 
					      response = JSONRPCResponseManager.handle(data, dispatcher) | 
				
			||||||
 | 
					      response_queue.put_nowait(response) | 
				
			||||||
 | 
					    except queue.Empty: | 
				
			||||||
 | 
					      pass | 
				
			||||||
 | 
					    except Exception as e: | 
				
			||||||
 | 
					      cloudlog.exception("athena jsonrpc handler failed") | 
				
			||||||
 | 
					      response_queue.put_nowait(json.dumps({"error": str(e)})) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def upload_handler(end_event): | 
				
			||||||
 | 
					  while not end_event.is_set(): | 
				
			||||||
 | 
					    try: | 
				
			||||||
 | 
					      item = upload_queue.get(timeout=1) | 
				
			||||||
 | 
					      if item.id in cancelled_uploads: | 
				
			||||||
 | 
					        cancelled_uploads.remove(item.id) | 
				
			||||||
 | 
					        continue | 
				
			||||||
 | 
					      _do_upload(item) | 
				
			||||||
 | 
					    except queue.Empty: | 
				
			||||||
 | 
					      pass | 
				
			||||||
 | 
					    except Exception: | 
				
			||||||
 | 
					      cloudlog.exception("athena.upload_handler.exception") | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _do_upload(upload_item): | 
				
			||||||
 | 
					  with open(upload_item.path, "rb") as f: | 
				
			||||||
 | 
					    size = os.fstat(f.fileno()).st_size | 
				
			||||||
 | 
					    return requests.put(upload_item.url, | 
				
			||||||
 | 
					                        data=f, | 
				
			||||||
 | 
					                        headers={**upload_item.headers, 'Content-Length': str(size)}, | 
				
			||||||
 | 
					                        timeout=10) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# security: user should be able to request any message from their car | 
				
			||||||
 | 
					@dispatcher.add_method | 
				
			||||||
 | 
					def getMessage(service=None, timeout=1000): | 
				
			||||||
 | 
					  if service is None or service not in service_list: | 
				
			||||||
 | 
					    raise Exception("invalid service") | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  socket = messaging.sub_sock(service, timeout=timeout) | 
				
			||||||
 | 
					  ret = messaging.recv_one(socket) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if ret is None: | 
				
			||||||
 | 
					    raise TimeoutError | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return ret.to_dict() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dispatcher.add_method | 
				
			||||||
 | 
					def listDataDirectory(): | 
				
			||||||
 | 
					  files = [os.path.relpath(os.path.join(dp, f), ROOT) for dp, dn, fn in os.walk(ROOT) for f in fn] | 
				
			||||||
 | 
					  return files | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dispatcher.add_method | 
				
			||||||
 | 
					def reboot(): | 
				
			||||||
 | 
					  thermal_sock = messaging.sub_sock("thermal", timeout=1000) | 
				
			||||||
 | 
					  ret = messaging.recv_one(thermal_sock) | 
				
			||||||
 | 
					  if ret is None or ret.thermal.started: | 
				
			||||||
 | 
					    raise Exception("Reboot unavailable") | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def do_reboot(): | 
				
			||||||
 | 
					    time.sleep(2) | 
				
			||||||
 | 
					    android.reboot() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  threading.Thread(target=do_reboot).start() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return {"success": 1} | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dispatcher.add_method | 
				
			||||||
 | 
					def uploadFileToUrl(fn, url, headers): | 
				
			||||||
 | 
					  if len(fn) == 0 or fn[0] == '/' or '..' in fn: | 
				
			||||||
 | 
					    return 500 | 
				
			||||||
 | 
					  path = os.path.join(ROOT, fn) | 
				
			||||||
 | 
					  if not os.path.exists(path): | 
				
			||||||
 | 
					    return 404 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  item = UploadItem(path=path, url=url, headers=headers, created_at=int(time.time()*1000), id=None) | 
				
			||||||
 | 
					  upload_id = hashlib.sha1(str(item).encode()).hexdigest() | 
				
			||||||
 | 
					  item = item._replace(id=upload_id) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  upload_queue.put_nowait(item) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return {"enqueued": 1, "item": item._asdict()} | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dispatcher.add_method | 
				
			||||||
 | 
					def listUploadQueue(): | 
				
			||||||
 | 
					  return [item._asdict() for item in list(upload_queue.queue)] | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dispatcher.add_method | 
				
			||||||
 | 
					def cancelUpload(upload_id): | 
				
			||||||
 | 
					  upload_ids = set(item.id for item in list(upload_queue.queue)) | 
				
			||||||
 | 
					  if upload_id not in upload_ids: | 
				
			||||||
 | 
					    return 404 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  cancelled_uploads.add(upload_id) | 
				
			||||||
 | 
					  return {"success": 1} | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def startLocalProxy(global_end_event, remote_ws_uri, local_port): | 
				
			||||||
 | 
					  try: | 
				
			||||||
 | 
					    if local_port not in LOCAL_PORT_WHITELIST: | 
				
			||||||
 | 
					      raise Exception("Requested local port not whitelisted") | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    params = Params() | 
				
			||||||
 | 
					    dongle_id = params.get("DongleId").decode('utf8') | 
				
			||||||
 | 
					    identity_token = Api(dongle_id).get_token() | 
				
			||||||
 | 
					    ws = create_connection(remote_ws_uri, | 
				
			||||||
 | 
					                           cookie="jwt=" + identity_token, | 
				
			||||||
 | 
					                           enable_multithread=True) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ssock, csock = socket.socketpair() | 
				
			||||||
 | 
					    local_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | 
				
			||||||
 | 
					    local_sock.connect(('127.0.0.1', local_port)) | 
				
			||||||
 | 
					    local_sock.setblocking(0) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    proxy_end_event = threading.Event() | 
				
			||||||
 | 
					    threads = [ | 
				
			||||||
 | 
					      threading.Thread(target=ws_proxy_recv, args=(ws, local_sock, ssock, proxy_end_event, global_end_event)), | 
				
			||||||
 | 
					      threading.Thread(target=ws_proxy_send, args=(ws, local_sock, csock, proxy_end_event)) | 
				
			||||||
 | 
					    ] | 
				
			||||||
 | 
					    for thread in threads: | 
				
			||||||
 | 
					      thread.start() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return {"success": 1} | 
				
			||||||
 | 
					  except Exception as e: | 
				
			||||||
 | 
					    cloudlog.exception("athenad.startLocalProxy.exception") | 
				
			||||||
 | 
					    raise e | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dispatcher.add_method | 
				
			||||||
 | 
					def getPublicKey(): | 
				
			||||||
 | 
					  if not os.path.isfile('/persist/comma/id_rsa.pub'): | 
				
			||||||
 | 
					    return None | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  with open('/persist/comma/id_rsa.pub', 'r') as f: | 
				
			||||||
 | 
					    return f.read() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dispatcher.add_method | 
				
			||||||
 | 
					def getSshAuthorizedKeys(): | 
				
			||||||
 | 
					  return Params().get("GithubSshKeys", encoding='utf8') or '' | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dispatcher.add_method | 
				
			||||||
 | 
					def getSimInfo(): | 
				
			||||||
 | 
					  sim_state = android.getprop("gsm.sim.state").split(",") | 
				
			||||||
 | 
					  network_type = android.getprop("gsm.network.type").split(',') | 
				
			||||||
 | 
					  mcc_mnc = android.getprop("gsm.sim.operator.numeric") or None | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  sim_id = android.parse_service_call_string(android.service_call(['iphonesubinfo', '11'])) | 
				
			||||||
 | 
					  cell_data_state = android.parse_service_call_unpack(android.service_call(['phone', '46']), ">q") | 
				
			||||||
 | 
					  cell_data_connected = (cell_data_state == 2) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return { | 
				
			||||||
 | 
					    'sim_id': sim_id, | 
				
			||||||
 | 
					    'mcc_mnc': mcc_mnc, | 
				
			||||||
 | 
					    'network_type': network_type, | 
				
			||||||
 | 
					    'sim_state': sim_state, | 
				
			||||||
 | 
					    'data_connected': cell_data_connected | 
				
			||||||
 | 
					  } | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dispatcher.add_method | 
				
			||||||
 | 
					def takeSnapshot(): | 
				
			||||||
 | 
					  from selfdrive.camerad.snapshot.snapshot import snapshot, jpeg_write | 
				
			||||||
 | 
					  ret = snapshot() | 
				
			||||||
 | 
					  if ret is not None: | 
				
			||||||
 | 
					    def b64jpeg(x): | 
				
			||||||
 | 
					      if x is not None: | 
				
			||||||
 | 
					        f = io.BytesIO() | 
				
			||||||
 | 
					        jpeg_write(f, x) | 
				
			||||||
 | 
					        return base64.b64encode(f.getvalue()).decode("utf-8") | 
				
			||||||
 | 
					      else: | 
				
			||||||
 | 
					        return None | 
				
			||||||
 | 
					    return {'jpegBack': b64jpeg(ret[0]), | 
				
			||||||
 | 
					            'jpegFront': b64jpeg(ret[1])} | 
				
			||||||
 | 
					  else: | 
				
			||||||
 | 
					    raise Exception("not available while camerad is started") | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def ws_proxy_recv(ws, local_sock, ssock, end_event, global_end_event): | 
				
			||||||
 | 
					  while not (end_event.is_set() or global_end_event.is_set()): | 
				
			||||||
 | 
					    try: | 
				
			||||||
 | 
					      data = ws.recv() | 
				
			||||||
 | 
					      local_sock.sendall(data) | 
				
			||||||
 | 
					    except WebSocketTimeoutException: | 
				
			||||||
 | 
					      pass | 
				
			||||||
 | 
					    except Exception: | 
				
			||||||
 | 
					      cloudlog.exception("athenad.ws_proxy_recv.exception") | 
				
			||||||
 | 
					      break | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  ssock.close() | 
				
			||||||
 | 
					  local_sock.close() | 
				
			||||||
 | 
					  end_event.set() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def ws_proxy_send(ws, local_sock, signal_sock, end_event): | 
				
			||||||
 | 
					  while not end_event.is_set(): | 
				
			||||||
 | 
					    try: | 
				
			||||||
 | 
					      r, _, _ = select.select((local_sock, signal_sock), (), ()) | 
				
			||||||
 | 
					      if r: | 
				
			||||||
 | 
					        if r[0].fileno() == signal_sock.fileno(): | 
				
			||||||
 | 
					          # got end signal from ws_proxy_recv | 
				
			||||||
 | 
					          end_event.set() | 
				
			||||||
 | 
					          break | 
				
			||||||
 | 
					        data = local_sock.recv(4096) | 
				
			||||||
 | 
					        if not data: | 
				
			||||||
 | 
					          # local_sock is dead | 
				
			||||||
 | 
					          end_event.set() | 
				
			||||||
 | 
					          break | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ws.send(data, ABNF.OPCODE_BINARY) | 
				
			||||||
 | 
					    except Exception: | 
				
			||||||
 | 
					      cloudlog.exception("athenad.ws_proxy_send.exception") | 
				
			||||||
 | 
					      end_event.set() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def ws_recv(ws, end_event): | 
				
			||||||
 | 
					  while not end_event.is_set(): | 
				
			||||||
 | 
					    try: | 
				
			||||||
 | 
					      data = ws.recv() | 
				
			||||||
 | 
					      payload_queue.put_nowait(data) | 
				
			||||||
 | 
					    except WebSocketTimeoutException: | 
				
			||||||
 | 
					      pass | 
				
			||||||
 | 
					    except Exception: | 
				
			||||||
 | 
					      cloudlog.exception("athenad.ws_recv.exception") | 
				
			||||||
 | 
					      end_event.set() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def ws_send(ws, end_event): | 
				
			||||||
 | 
					  while not end_event.is_set(): | 
				
			||||||
 | 
					    try: | 
				
			||||||
 | 
					      response = response_queue.get(timeout=1) | 
				
			||||||
 | 
					      ws.send(response.json) | 
				
			||||||
 | 
					    except queue.Empty: | 
				
			||||||
 | 
					      pass | 
				
			||||||
 | 
					    except Exception: | 
				
			||||||
 | 
					      cloudlog.exception("athenad.ws_send.exception") | 
				
			||||||
 | 
					      end_event.set() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def backoff(retries): | 
				
			||||||
 | 
					  return random.randrange(0, min(128, int(2 ** retries))) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def main(gctx=None): | 
				
			||||||
 | 
					  params = Params() | 
				
			||||||
 | 
					  dongle_id = params.get("DongleId").decode('utf-8') | 
				
			||||||
 | 
					  ws_uri = ATHENA_HOST + "/ws/v2/" + dongle_id | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  api = Api(dongle_id) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  conn_retries = 0 | 
				
			||||||
 | 
					  while 1: | 
				
			||||||
 | 
					    try: | 
				
			||||||
 | 
					      ws = create_connection(ws_uri, | 
				
			||||||
 | 
					                             cookie="jwt=" + api.get_token(), | 
				
			||||||
 | 
					                             enable_multithread=True) | 
				
			||||||
 | 
					      cloudlog.event("athenad.main.connected_ws", ws_uri=ws_uri) | 
				
			||||||
 | 
					      ws.settimeout(1) | 
				
			||||||
 | 
					      conn_retries = 0 | 
				
			||||||
 | 
					      handle_long_poll(ws) | 
				
			||||||
 | 
					    except (KeyboardInterrupt, SystemExit): | 
				
			||||||
 | 
					      break | 
				
			||||||
 | 
					    except Exception: | 
				
			||||||
 | 
					      cloudlog.exception("athenad.main.exception") | 
				
			||||||
 | 
					      conn_retries += 1 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    time.sleep(backoff(conn_retries)) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__": | 
				
			||||||
 | 
					  main() | 
				
			||||||
@ -0,0 +1,36 @@ | 
				
			|||||||
 | 
					#!/usr/bin/env python3 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import time | 
				
			||||||
 | 
					from multiprocessing import Process | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import selfdrive.crash as crash | 
				
			||||||
 | 
					from common.params import Params | 
				
			||||||
 | 
					from selfdrive.launcher import launcher | 
				
			||||||
 | 
					from selfdrive.swaglog import cloudlog | 
				
			||||||
 | 
					from selfdrive.version import version, dirty | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					ATHENA_MGR_PID_PARAM = "AthenadPid" | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def main(): | 
				
			||||||
 | 
					  params = Params() | 
				
			||||||
 | 
					  dongle_id = params.get("DongleId").decode('utf-8') | 
				
			||||||
 | 
					  cloudlog.bind_global(dongle_id=dongle_id, version=version, dirty=dirty, is_eon=True) | 
				
			||||||
 | 
					  crash.bind_user(id=dongle_id) | 
				
			||||||
 | 
					  crash.bind_extra(version=version, dirty=dirty, is_eon=True) | 
				
			||||||
 | 
					  crash.install() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  try: | 
				
			||||||
 | 
					    while 1: | 
				
			||||||
 | 
					      cloudlog.info("starting athena daemon") | 
				
			||||||
 | 
					      proc = Process(name='athenad', target=launcher, args=('selfdrive.athena.athenad',)) | 
				
			||||||
 | 
					      proc.start() | 
				
			||||||
 | 
					      proc.join() | 
				
			||||||
 | 
					      cloudlog.event("athenad exited", exitcode=proc.exitcode) | 
				
			||||||
 | 
					      time.sleep(5) | 
				
			||||||
 | 
					  except: | 
				
			||||||
 | 
					    cloudlog.exception("manage_athenad.exception") | 
				
			||||||
 | 
					  finally: | 
				
			||||||
 | 
					    params.delete(ATHENA_MGR_PID_PARAM) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == '__main__': | 
				
			||||||
 | 
					  main() | 
				
			||||||
@ -0,0 +1,193 @@ | 
				
			|||||||
 | 
					#!/usr/bin/env python3 | 
				
			||||||
 | 
					import json | 
				
			||||||
 | 
					import os | 
				
			||||||
 | 
					import requests | 
				
			||||||
 | 
					import tempfile | 
				
			||||||
 | 
					import time | 
				
			||||||
 | 
					import threading | 
				
			||||||
 | 
					import queue | 
				
			||||||
 | 
					import unittest | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from multiprocessing import Process | 
				
			||||||
 | 
					from pathlib import Path | 
				
			||||||
 | 
					from unittest import mock | 
				
			||||||
 | 
					from websocket import ABNF | 
				
			||||||
 | 
					from websocket._exceptions import WebSocketConnectionClosedException | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from selfdrive.athena import athenad | 
				
			||||||
 | 
					from selfdrive.athena.athenad import dispatcher | 
				
			||||||
 | 
					from selfdrive.athena.test_helpers import MockWebsocket, MockParams, MockApi, EchoSocket, with_http_server | 
				
			||||||
 | 
					from cereal import messaging | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestAthenadMethods(unittest.TestCase): | 
				
			||||||
 | 
					  @classmethod | 
				
			||||||
 | 
					  def setUpClass(cls): | 
				
			||||||
 | 
					    cls.SOCKET_PORT = 45454 | 
				
			||||||
 | 
					    athenad.ROOT = tempfile.mkdtemp() | 
				
			||||||
 | 
					    athenad.Params = MockParams | 
				
			||||||
 | 
					    athenad.Api = MockApi | 
				
			||||||
 | 
					    athenad.LOCAL_PORT_WHITELIST = set([cls.SOCKET_PORT]) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_echo(self): | 
				
			||||||
 | 
					    assert dispatcher["echo"]("bob") == "bob" | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_getMessage(self): | 
				
			||||||
 | 
					    with self.assertRaises(TimeoutError) as _: | 
				
			||||||
 | 
					      dispatcher["getMessage"]("controlsState") | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def send_thermal(): | 
				
			||||||
 | 
					      messaging.context = messaging.Context() | 
				
			||||||
 | 
					      pub_sock = messaging.pub_sock("thermal") | 
				
			||||||
 | 
					      start = time.time() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      while time.time() - start < 1: | 
				
			||||||
 | 
					        msg = messaging.new_message() | 
				
			||||||
 | 
					        msg.init('thermal') | 
				
			||||||
 | 
					        pub_sock.send(msg.to_bytes()) | 
				
			||||||
 | 
					        time.sleep(0.01) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    p = Process(target=send_thermal) | 
				
			||||||
 | 
					    p.start() | 
				
			||||||
 | 
					    time.sleep(0.1) | 
				
			||||||
 | 
					    try: | 
				
			||||||
 | 
					      thermal = dispatcher["getMessage"]("thermal") | 
				
			||||||
 | 
					      assert thermal['thermal'] | 
				
			||||||
 | 
					    finally: | 
				
			||||||
 | 
					      p.terminate() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_listDataDirectory(self): | 
				
			||||||
 | 
					    print(dispatcher["listDataDirectory"]()) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @with_http_server | 
				
			||||||
 | 
					  def test_do_upload(self, host): | 
				
			||||||
 | 
					    fn = os.path.join(athenad.ROOT, 'qlog.bz2') | 
				
			||||||
 | 
					    Path(fn).touch() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    try: | 
				
			||||||
 | 
					      item = athenad.UploadItem(path=fn, url="http://localhost:1238", headers={}, created_at=int(time.time()*1000), id='') | 
				
			||||||
 | 
					      try: | 
				
			||||||
 | 
					        athenad._do_upload(item) | 
				
			||||||
 | 
					      except requests.exceptions.ConnectionError: | 
				
			||||||
 | 
					        pass | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      item = athenad.UploadItem(path=fn, url=f"{host}/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='') | 
				
			||||||
 | 
					      resp = athenad._do_upload(item) | 
				
			||||||
 | 
					      self.assertEqual(resp.status_code, 201) | 
				
			||||||
 | 
					    finally: | 
				
			||||||
 | 
					      os.unlink(fn) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @with_http_server | 
				
			||||||
 | 
					  def test_uploadFileToUrl(self, host): | 
				
			||||||
 | 
					    not_exists_resp = dispatcher["uploadFileToUrl"]("does_not_exist.bz2", "http://localhost:1238", {}) | 
				
			||||||
 | 
					    self.assertEqual(not_exists_resp, 404) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    fn = os.path.join(athenad.ROOT, 'qlog.bz2') | 
				
			||||||
 | 
					    Path(fn).touch() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    try: | 
				
			||||||
 | 
					      resp = dispatcher["uploadFileToUrl"]("qlog.bz2", f"{host}/qlog.bz2", {}) | 
				
			||||||
 | 
					      self.assertEqual(resp['enqueued'], 1) | 
				
			||||||
 | 
					      self.assertDictContainsSubset({"path": fn, "url": f"{host}/qlog.bz2", "headers": {}}, resp['item']) | 
				
			||||||
 | 
					      self.assertIsNotNone(resp['item'].get('id')) | 
				
			||||||
 | 
					      self.assertEqual(athenad.upload_queue.qsize(), 1) | 
				
			||||||
 | 
					    finally: | 
				
			||||||
 | 
					      athenad.upload_queue = queue.Queue() | 
				
			||||||
 | 
					      os.unlink(fn) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @with_http_server | 
				
			||||||
 | 
					  def test_upload_handler(self, host): | 
				
			||||||
 | 
					    fn = os.path.join(athenad.ROOT, 'qlog.bz2') | 
				
			||||||
 | 
					    Path(fn).touch() | 
				
			||||||
 | 
					    item = athenad.UploadItem(path=fn, url=f"{host}/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='') | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    end_event = threading.Event() | 
				
			||||||
 | 
					    thread = threading.Thread(target=athenad.upload_handler, args=(end_event,)) | 
				
			||||||
 | 
					    thread.start() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    athenad.upload_queue.put_nowait(item) | 
				
			||||||
 | 
					    try: | 
				
			||||||
 | 
					      now = time.time() | 
				
			||||||
 | 
					      while time.time() - now < 5: | 
				
			||||||
 | 
					        if athenad.upload_queue.qsize() == 0: | 
				
			||||||
 | 
					          break | 
				
			||||||
 | 
					      self.assertEqual(athenad.upload_queue.qsize(), 0) | 
				
			||||||
 | 
					    finally: | 
				
			||||||
 | 
					      end_event.set() | 
				
			||||||
 | 
					      athenad.upload_queue = queue.Queue() | 
				
			||||||
 | 
					      os.unlink(fn) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_cancelUpload(self): | 
				
			||||||
 | 
					    item = athenad.UploadItem(path="qlog.bz2", url="http://localhost:44444/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='id') | 
				
			||||||
 | 
					    athenad.upload_queue.put_nowait(item) | 
				
			||||||
 | 
					    dispatcher["cancelUpload"](item.id) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    self.assertIn(item.id, athenad.cancelled_uploads) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    end_event = threading.Event() | 
				
			||||||
 | 
					    thread = threading.Thread(target=athenad.upload_handler, args=(end_event,)) | 
				
			||||||
 | 
					    thread.start() | 
				
			||||||
 | 
					    try: | 
				
			||||||
 | 
					      now = time.time() | 
				
			||||||
 | 
					      while time.time() - now < 5: | 
				
			||||||
 | 
					        if athenad.upload_queue.qsize() == 0 and len(athenad.cancelled_uploads) == 0: | 
				
			||||||
 | 
					          break | 
				
			||||||
 | 
					      self.assertEqual(athenad.upload_queue.qsize(), 0) | 
				
			||||||
 | 
					      self.assertEqual(len(athenad.cancelled_uploads), 0) | 
				
			||||||
 | 
					    finally: | 
				
			||||||
 | 
					      end_event.set() | 
				
			||||||
 | 
					      athenad.upload_queue = queue.Queue() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_listUploadQueue(self): | 
				
			||||||
 | 
					    item = athenad.UploadItem(path="qlog.bz2", url="http://localhost:44444/qlog.bz2", headers={}, created_at=int(time.time()*1000), id='id') | 
				
			||||||
 | 
					    athenad.upload_queue.put_nowait(item) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    try: | 
				
			||||||
 | 
					      items = dispatcher["listUploadQueue"]() | 
				
			||||||
 | 
					      self.assertEqual(len(items), 1) | 
				
			||||||
 | 
					      self.assertDictEqual(items[0], item._asdict()) | 
				
			||||||
 | 
					    finally: | 
				
			||||||
 | 
					      athenad.upload_queue = queue.Queue() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @mock.patch('selfdrive.athena.athenad.create_connection') | 
				
			||||||
 | 
					  def test_startLocalProxy(self, mock_create_connection): | 
				
			||||||
 | 
					    end_event = threading.Event() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ws_recv = queue.Queue() | 
				
			||||||
 | 
					    ws_send = queue.Queue() | 
				
			||||||
 | 
					    mock_ws = MockWebsocket(ws_recv, ws_send) | 
				
			||||||
 | 
					    mock_create_connection.return_value = mock_ws | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    echo_socket = EchoSocket(self.SOCKET_PORT) | 
				
			||||||
 | 
					    socket_thread = threading.Thread(target=echo_socket.run) | 
				
			||||||
 | 
					    socket_thread.start() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    athenad.startLocalProxy(end_event, 'ws://localhost:1234', self.SOCKET_PORT) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ws_recv.put_nowait(b'ping') | 
				
			||||||
 | 
					    try: | 
				
			||||||
 | 
					      recv = ws_send.get(timeout=5) | 
				
			||||||
 | 
					      assert recv == (b'ping', ABNF.OPCODE_BINARY), recv | 
				
			||||||
 | 
					    finally: | 
				
			||||||
 | 
					      # signal websocket close to athenad.ws_proxy_recv | 
				
			||||||
 | 
					      ws_recv.put_nowait(WebSocketConnectionClosedException()) | 
				
			||||||
 | 
					      socket_thread.join() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_getSshAuthorizedKeys(self): | 
				
			||||||
 | 
					    keys = dispatcher["getSshAuthorizedKeys"]() | 
				
			||||||
 | 
					    self.assertEqual(keys, MockParams().params["GithubSshKeys"].decode('utf-8')) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_jsonrpc_handler(self): | 
				
			||||||
 | 
					    end_event = threading.Event() | 
				
			||||||
 | 
					    thread = threading.Thread(target=athenad.jsonrpc_handler, args=(end_event,)) | 
				
			||||||
 | 
					    thread.daemon = True | 
				
			||||||
 | 
					    thread.start() | 
				
			||||||
 | 
					    athenad.payload_queue.put_nowait(json.dumps({"method": "echo", "params": ["hello"], "jsonrpc": "2.0", "id": 0})) | 
				
			||||||
 | 
					    try: | 
				
			||||||
 | 
					      resp = athenad.response_queue.get(timeout=3) | 
				
			||||||
 | 
					      self.assertDictEqual(resp.data, {'result': 'hello', 'id': 0, 'jsonrpc': '2.0'}) | 
				
			||||||
 | 
					    finally: | 
				
			||||||
 | 
					      end_event.set() | 
				
			||||||
 | 
					      thread.join() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == '__main__': | 
				
			||||||
 | 
					  unittest.main() | 
				
			||||||
@ -0,0 +1,114 @@ | 
				
			|||||||
 | 
					import http.server | 
				
			||||||
 | 
					import multiprocessing | 
				
			||||||
 | 
					import queue | 
				
			||||||
 | 
					import random | 
				
			||||||
 | 
					import requests | 
				
			||||||
 | 
					import socket | 
				
			||||||
 | 
					import time | 
				
			||||||
 | 
					from functools import wraps | 
				
			||||||
 | 
					from multiprocessing import Process | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class EchoSocket(): | 
				
			||||||
 | 
					  def __init__(self, port): | 
				
			||||||
 | 
					    self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | 
				
			||||||
 | 
					    self.socket.bind(('127.0.0.1', port)) | 
				
			||||||
 | 
					    self.socket.listen(1) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def run(self): | 
				
			||||||
 | 
					    conn, client_address = self.socket.accept() | 
				
			||||||
 | 
					    conn.settimeout(5.0) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    try: | 
				
			||||||
 | 
					      while True: | 
				
			||||||
 | 
					        data = conn.recv(4096) | 
				
			||||||
 | 
					        if data: | 
				
			||||||
 | 
					          print(f'EchoSocket got {data}') | 
				
			||||||
 | 
					          conn.sendall(data) | 
				
			||||||
 | 
					        else: | 
				
			||||||
 | 
					          break | 
				
			||||||
 | 
					    finally: | 
				
			||||||
 | 
					      conn.shutdown(0) | 
				
			||||||
 | 
					      conn.close() | 
				
			||||||
 | 
					      self.socket.shutdown(0) | 
				
			||||||
 | 
					      self.socket.close() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MockApi(): | 
				
			||||||
 | 
					  def __init__(self, dongle_id): | 
				
			||||||
 | 
					    pass | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def get_token(self): | 
				
			||||||
 | 
					    return "fake-token" | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MockParams(): | 
				
			||||||
 | 
					  def __init__(self): | 
				
			||||||
 | 
					    self.params = { | 
				
			||||||
 | 
					      "DongleId": b"0000000000000000", | 
				
			||||||
 | 
					      "GithubSshKeys": b"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQC307aE+nuHzTAgaJhzSf5v7ZZQW9gaperjhCmyPyl4PzY7T1mDGenTlVTN7yoVFZ9UfO9oMQqo0n1OwDIiqbIFxqnhrHU0cYfj88rI85m5BEKlNu5RdaVTj1tcbaPpQc5kZEolaI1nDDjzV0lwS7jo5VYDHseiJHlik3HH1SgtdtsuamGR2T80q1SyW+5rHoMOJG73IH2553NnWuikKiuikGHUYBd00K1ilVAK2xSiMWJp55tQfZ0ecr9QjEsJ+J/efL4HqGNXhffxvypCXvbUYAFSddOwXUPo5BTKevpxMtH+2YrkpSjocWA04VnTYFiPG6U4ItKmbLOTFZtPzoez private" | 
				
			||||||
 | 
					    } | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def get(self, k, encoding=None): | 
				
			||||||
 | 
					    ret = self.params.get(k) | 
				
			||||||
 | 
					    if ret is not None and encoding is not None: | 
				
			||||||
 | 
					      ret = ret.decode(encoding) | 
				
			||||||
 | 
					    return ret | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MockWebsocket(): | 
				
			||||||
 | 
					  def __init__(self, recv_queue, send_queue): | 
				
			||||||
 | 
					    self.recv_queue = recv_queue | 
				
			||||||
 | 
					    self.send_queue = send_queue | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def recv(self): | 
				
			||||||
 | 
					    data = self.recv_queue.get() | 
				
			||||||
 | 
					    if isinstance(data, Exception): | 
				
			||||||
 | 
					      raise data | 
				
			||||||
 | 
					    return data | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def send(self, data, opcode): | 
				
			||||||
 | 
					    self.send_queue.put_nowait((data, opcode)) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class HTTPRequestHandler(http.server.SimpleHTTPRequestHandler): | 
				
			||||||
 | 
					  def do_PUT(self): | 
				
			||||||
 | 
					    length = int(self.headers['Content-Length']) | 
				
			||||||
 | 
					    self.rfile.read(length) | 
				
			||||||
 | 
					    self.send_response(201, "Created") | 
				
			||||||
 | 
					    self.end_headers() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def http_server(port_queue, **kwargs): | 
				
			||||||
 | 
					  while 1: | 
				
			||||||
 | 
					    try: | 
				
			||||||
 | 
					      port = random.randrange(40000, 50000) | 
				
			||||||
 | 
					      port_queue.put(port) | 
				
			||||||
 | 
					      http.server.test(**kwargs, port=port) | 
				
			||||||
 | 
					    except OSError as e: | 
				
			||||||
 | 
					      if e.errno == 98: | 
				
			||||||
 | 
					        continue | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def with_http_server(func): | 
				
			||||||
 | 
					  @wraps(func) | 
				
			||||||
 | 
					  def inner(*args, **kwargs): | 
				
			||||||
 | 
					    port_queue = multiprocessing.Queue() | 
				
			||||||
 | 
					    host = '127.0.0.1' | 
				
			||||||
 | 
					    p = Process(target=http_server, | 
				
			||||||
 | 
					                args=(port_queue,), | 
				
			||||||
 | 
					                kwargs={ | 
				
			||||||
 | 
					                  'HandlerClass': HTTPRequestHandler, | 
				
			||||||
 | 
					                  'bind': host}) | 
				
			||||||
 | 
					    p.start() | 
				
			||||||
 | 
					    now = time.time() | 
				
			||||||
 | 
					    port = None | 
				
			||||||
 | 
					    while 1: | 
				
			||||||
 | 
					      if time.time() - now > 5: | 
				
			||||||
 | 
					        raise Exception('HTTP Server did not start') | 
				
			||||||
 | 
					      try: | 
				
			||||||
 | 
					        port = port_queue.get(timeout=0.1) | 
				
			||||||
 | 
					        requests.put(f'http://{host}:{port}/qlog.bz2', data='') | 
				
			||||||
 | 
					        break | 
				
			||||||
 | 
					      except (requests.exceptions.ConnectionError, queue.Empty): | 
				
			||||||
 | 
					        time.sleep(0.1) | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    try: | 
				
			||||||
 | 
					      return func(*args, f'http://{host}:{port}', **kwargs) | 
				
			||||||
 | 
					    finally: | 
				
			||||||
 | 
					      p.terminate() | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return inner | 
				
			||||||
					Loading…
					
					
				
		Reference in new issue