openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

220 lines
10 KiB

# the CLOUD=1 device is a process boundary between the frontend/runtime
# normally tinygrad is frontend <-> middleware <-> runtime <-> hardware
# with CLOUD tinygrad is frontend <-> middleware <-> CloudDevice ///HTTP/// cloud_server <-> runtime <-> hardware
# this client and server can be on the same machine, same network, or just same internet
# it should be a secure (example: no use of pickle) boundary. HTTP is used for RPC
from __future__ import annotations
from typing import Optional, Any
from collections import defaultdict
from dataclasses import dataclass, field
import multiprocessing, functools, http.client, hashlib, json, time, os, binascii, struct, ast, contextlib
from http.server import HTTPServer, BaseHTTPRequestHandler
from tinygrad.renderer import Renderer
from tinygrad.dtype import dtypes
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, Timing
from tinygrad.device import Compiled, Allocator, Compiler, Device, BufferSpec
# ***** API *****
class CloudRequest: pass
@dataclass(frozen=True)
class BufferAlloc(CloudRequest): buffer_num: int; size: int; options: BufferSpec # noqa: E702
@dataclass(frozen=True)
class BufferFree(CloudRequest): buffer_num: int # noqa: E702
@dataclass(frozen=True)
class CopyIn(CloudRequest): buffer_num: int; datahash: str # noqa: E702
@dataclass(frozen=True)
class CopyOut(CloudRequest): buffer_num: int
@dataclass(frozen=True)
class ProgramAlloc(CloudRequest): name: str; datahash: str # noqa: E702
@dataclass(frozen=True)
class ProgramFree(CloudRequest): name: str; datahash: str # noqa: E702
@dataclass(frozen=True)
class ProgramExec(CloudRequest):
name: str; datahash: str; bufs: tuple[int, ...]; vals: tuple[int, ...] # noqa: E702
global_size: Optional[tuple[int, ...]]; local_size: Optional[tuple[int, ...]]; wait: bool # noqa: E702
# for safe deserialization
whitelist = {x.__name__:x for x in [BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree, ProgramExec, BufferSpec]}
eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)),
ast.Call: lambda x: safe_eval(x.func)(*[safe_eval(arg) for arg in x.args], **{kwarg.arg: safe_eval(kwarg.value) for kwarg in x.keywords}),
ast.Name: lambda x: whitelist[x.id], ast.Attribute: lambda x: {"imagef": dtypes.imagef, "imageh": dtypes.imageh}[x.attr]}
def safe_eval(node): return eval_fxns[node.__class__](node)
class BatchRequest:
def __init__(self):
self._q: list[CloudRequest] = []
self._h: dict[str, bytes] = {}
def h(self, d:bytes) -> str:
binhash = hashlib.sha256(d).digest()
self._h[datahash:=binascii.hexlify(binhash).decode()] = binhash+struct.pack("<Q", len(d))+d
return datahash
def q(self, x:CloudRequest): self._q.append(x)
def serialize(self) -> bytes:
self.h(repr(self._q).encode())
return b''.join(self._h.values())
def deserialize(self, dat:bytes) -> BatchRequest:
ptr = 0
while ptr < len(dat):
datahash, datalen = binascii.hexlify(dat[ptr:ptr+0x20]).decode(), struct.unpack("<Q", dat[ptr+0x20:ptr+0x28])[0]
self._h[datahash] = dat[ptr+0x28:ptr+0x28+datalen]
ptr += 0x28+datalen
self._q = safe_eval(ast.parse(self._h[datahash], mode="eval").body)
return self
# ***** backend *****
@dataclass
class CloudSession:
programs: dict[tuple[str, str], Any] = field(default_factory=dict)
# TODO: the buffer should track this internally
buffers: dict[int, tuple[Any, int, Optional[BufferSpec]]] = field(default_factory=dict)
class CloudHandler(BaseHTTPRequestHandler):
protocol_version = 'HTTP/1.1'
device: str
sessions: defaultdict[str, CloudSession] = defaultdict(CloudSession)
def setup(self):
super().setup()
print(f"connection established with {self.client_address}, socket: {self.connection.fileno()}")
def _do(self, method):
session = CloudHandler.sessions[unwrap(self.headers.get("Cookie")).split("session=")[1]]
ret, status_code = b"", 200
if self.path == "/batch" and method == "POST":
# TODO: streaming deserialize?
req = BatchRequest().deserialize(self.rfile.read(int(unwrap(self.headers.get('Content-Length')))))
# the cmds are always last (currently in datahash)
for c in req._q:
if DEBUG >= 1: print(c)
match c:
case BufferAlloc():
assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated"
session.buffers[c.buffer_num] = (Device[CloudHandler.device].allocator.alloc(c.size, c.options), c.size, c.options)
case BufferFree():
buf,sz,buffer_options = session.buffers[c.buffer_num]
Device[CloudHandler.device].allocator.free(buf,sz,buffer_options)
del session.buffers[c.buffer_num]
case CopyIn(): Device[CloudHandler.device].allocator._copyin(session.buffers[c.buffer_num][0], memoryview(bytearray(req._h[c.datahash])))
case CopyOut():
buf,sz,_ = session.buffers[c.buffer_num]
Device[CloudHandler.device].allocator._copyout(memoryview(ret:=bytearray(sz)), buf)
case ProgramAlloc():
lib = Device[CloudHandler.device].compiler.compile_cached(req._h[c.datahash].decode())
session.programs[(c.name, c.datahash)] = Device[CloudHandler.device].runtime(c.name, lib)
case ProgramFree(): del session.programs[(c.name, c.datahash)]
case ProgramExec():
bufs = [session.buffers[x][0] for x in c.bufs]
extra_args = {k:v for k,v in [("global_size", c.global_size), ("local_size", c.local_size)] if v is not None}
r = session.programs[(c.name, c.datahash)](*bufs, vals=c.vals, wait=c.wait, **extra_args)
if r is not None: ret = str(r).encode()
elif self.path == "/renderer" and method == "GET":
cls, args = Device[CloudHandler.device].renderer.__reduce__()
ret = json.dumps((cls.__module__, cls.__name__, args)).encode()
else: status_code = 404
self.send_response(status_code)
self.send_header('Content-Length', str(len(ret)))
self.end_headers()
return self.wfile.write(ret)
def do_GET(self): return self._do("GET")
def do_POST(self): return self._do("POST")
def cloud_server(port:int):
multiprocessing.current_process().name = "MainProcess"
CloudHandler.device = getenv("CLOUDDEV", "METAL") if Device.DEFAULT == "CLOUD" else Device.DEFAULT
print(f"start cloud server on {port} with device {CloudHandler.device}")
server = HTTPServer(('', port), CloudHandler)
server.serve_forever()
# ***** frontend *****
class CloudAllocator(Allocator):
def __init__(self, dev:CloudDevice):
self.device = dev
super().__init__()
# TODO: ideally we shouldn't have to deal with images here
def _alloc(self, size:int, options:BufferSpec) -> int:
self.device.buffer_num += 1
self.device.req.q(BufferAlloc(self.device.buffer_num, size, options))
return self.device.buffer_num
# TODO: options should not be here in any Allocator
def _free(self, opaque:int, options): self.device.req.q(BufferFree(opaque))
def _copyin(self, dest:int, src:memoryview): self.device.req.q(CopyIn(dest, self.device.req.h(bytes(src))))
def _copyout(self, dest:memoryview, src:int):
self.device.req.q(CopyOut(src))
resp = self.device.batch_submit()
assert len(resp) == len(dest), f"buffer length mismatch {len(resp)} != {len(dest)}"
dest[:] = resp
class CloudProgram:
def __init__(self, dev:CloudDevice, name:str, lib:bytes):
self.dev, self.name = dev, name
self.datahash = self.dev.req.h(lib)
self.dev.req.q(ProgramAlloc(self.name, self.datahash))
super().__init__()
def __del__(self): self.dev.req.q(ProgramFree(self.name, self.datahash))
def __call__(self, *bufs, global_size=None, local_size=None, vals:tuple[int, ...]=(), wait=False):
self.dev.req.q(ProgramExec(self.name, self.datahash, bufs, vals, global_size, local_size, wait))
if wait: return float(self.dev.batch_submit())
class CloudDevice(Compiled):
def __init__(self, device:str):
if (host:=getenv("HOST", "")) != "": self.host = host
else:
p = multiprocessing.Process(target=cloud_server, args=(6667,))
p.daemon = True
p.start()
self.host = "127.0.0.1:6667"
# state for the connection
self.session = binascii.hexlify(os.urandom(0x10)).decode()
self.buffer_num = 0
self.req: BatchRequest = BatchRequest()
if DEBUG >= 1: print(f"cloud with host {self.host}")
while 1:
try:
self.conn = http.client.HTTPConnection(self.host, timeout=60.0)
clouddev = json.loads(self.send("GET", "renderer").decode())
break
except Exception as e:
print(e)
time.sleep(0.1)
if DEBUG >= 1: print(f"remote has device {clouddev}")
# TODO: how to we have BEAM be cached on the backend? this should just send a specification of the compute. rethink what goes in Renderer
if not clouddev[0].startswith("tinygrad.renderer.") or not clouddev[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {clouddev}")
renderer_class = fromimport(clouddev[0], clouddev[1]) # TODO: is this secure?
if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {clouddev}")
super().__init__(device, CloudAllocator(self), renderer_class(*clouddev[2]), Compiler(), functools.partial(CloudProgram, self))
def __del__(self):
# TODO: this is never being called
# TODO: should close the whole session
with contextlib.suppress(ConnectionRefusedError, http.client.CannotSendRequest, http.client.RemoteDisconnected): self.batch_submit()
def batch_submit(self):
data = self.req.serialize()
with Timing(f"*** send {len(self.req._q):-3d} requests {len(self.req._h):-3d} hashes with len {len(data)/1024:.2f} kB in ", enabled=DEBUG>=1):
ret = self.send("POST", "batch", data)
self.req = BatchRequest()
return ret
def send(self, method, path, data:Optional[bytes]=None) -> bytes:
# TODO: retry logic
self.conn.request(method, "/"+path, data, headers={"Cookie": f"session={self.session}"})
response = self.conn.getresponse()
assert response.status == 200, f"failed on {method} {path}"
return response.read()
if __name__ == "__main__": cloud_server(getenv("PORT", 6667))