# the REMOTE=1 device is a process boundary between the frontend/runtime # normally tinygrad is frontend <-> middleware <-> runtime <-> hardware # with REMOTE tinygrad is frontend <-> middleware <-> RemoteDevice ///HTTP/// remote_server <-> runtime <-> hardware # this client and server can be on the same machine, same network, or just same internet # it should be a secure (example: no use of pickle) boundary. HTTP is used for RPC from __future__ import annotations from typing import Callable, Optional, Any, cast from collections import defaultdict from dataclasses import dataclass, field, replace import multiprocessing, functools, asyncio, http, http.client, hashlib, time, os, binascii, struct, ast, contextlib from tinygrad.renderer import Renderer, ProgramSpec from tinygrad.dtype import DTYPES_DICT, dtypes from tinygrad.ops import UOp, Ops, Variable, sint from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, Timing from tinygrad.engine.jit import GraphRunner, ExecItem, graph_class from tinygrad.engine.realize import CompiledRunner from tinygrad.device import Compiled, Buffer, Allocator, Compiler, Device, BufferSpec from tinygrad.runtime.graph.cpu import CPUGraph # ***** API ***** @dataclass(frozen=True) class RemoteRequest: session: tuple[str, int]|None = field(default=None, kw_only=True) @dataclass(frozen=True) class RemoteProperties: real_device: str renderer: tuple[str, str, tuple[Any, ...]] graph_supported: bool transfer_supported: bool @dataclass(frozen=True) class GetProperties(RemoteRequest): pass @dataclass(frozen=True) class BufferAlloc(RemoteRequest): buffer_num: int; size: int; options: BufferSpec # noqa: E702 @dataclass(frozen=True) class BufferFree(RemoteRequest): buffer_num: int # noqa: E702 @dataclass(frozen=True) class CopyIn(RemoteRequest): buffer_num: int; datahash: str # noqa: E702 @dataclass(frozen=True) class CopyOut(RemoteRequest): buffer_num: int @dataclass(frozen=True) class Transfer(RemoteRequest): buffer_num: int; ssession: tuple[str, int]; sbuffer_num: int # noqa: E702 @dataclass(frozen=True) class ProgramAlloc(RemoteRequest): name: str; datahash: str # noqa: E702 @dataclass(frozen=True) class ProgramFree(RemoteRequest): name: str; datahash: str # noqa: E702 @dataclass(frozen=True) class ProgramExec(RemoteRequest): name: str; datahash: str; bufs: tuple[int, ...]; vals: tuple[int, ...] # noqa: E702 global_size: Optional[tuple[int, ...]]; local_size: Optional[tuple[int, ...]]; wait: bool # noqa: E702 @dataclass(frozen=True) class GraphComputeItem: name: str datahash: str bufs: tuple[int, ...] vars: tuple[Variable, ...] global_size: tuple[sint, ...]|None local_size: tuple[sint, ...]|None @dataclass(frozen=True) class GraphAlloc(RemoteRequest): graph_num: int jit_cache: tuple[GraphComputeItem, ...] bufs: tuple[int, ...] var_vals: dict[Variable, int] @dataclass(frozen=True) class GraphFree(RemoteRequest): graph_num: int @dataclass(frozen=True) class GraphExec(RemoteRequest): graph_num: int bufs: tuple[int, ...] var_vals: dict[Variable, int] wait: bool # for safe deserialization eval_globals = {x.__name__:x for x in [RemoteProperties, GetProperties, BufferAlloc, BufferFree, CopyIn, CopyOut, Transfer, ProgramAlloc, ProgramFree, ProgramExec, GraphComputeItem, GraphAlloc, GraphFree, GraphExec, BufferSpec, UOp, Ops, dtypes]} attribute_whitelist: dict[Any, set[str]] = {dtypes: {*DTYPES_DICT.keys(), 'imagef', 'imageh'}, Ops: {x.name for x in Ops}} eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)), ast.Dict: lambda x: {safe_eval(k):safe_eval(v) for k,v in zip(x.keys, x.values)}, ast.Call: lambda x: safe_eval(x.func)(*[safe_eval(arg) for arg in x.args], **{kwarg.arg: safe_eval(kwarg.value) for kwarg in x.keywords}), ast.Name: lambda x: eval_globals[x.id], ast.Attribute: lambda x: safe_getattr(safe_eval(x.value), x.attr)} def safe_getattr(value, attr): assert attr in attribute_whitelist.get(value, set()), f'getattr({value}, {repr(attr)}) is not whitelisted' return getattr(value, attr) def safe_eval(node): return eval_fxns[node.__class__](node) class BatchRequest: def __init__(self): self._q: list[RemoteRequest] = [] self._h: dict[str, bytes] = {} def h(self, d:bytes) -> str: binhash = hashlib.sha256(d).digest() self._h[datahash:=binascii.hexlify(binhash).decode()] = binhash+struct.pack(" bytes: self.h(repr(self._q).encode()) return b''.join(self._h.values()) def deserialize(self, dat:bytes) -> BatchRequest: ptr = 0 while ptr < len(dat): datahash, datalen = binascii.hexlify(dat[ptr:ptr+0x20]).decode(), struct.unpack(" tuple[http.HTTPStatus, bytes]: status, ret = http.HTTPStatus.OK, b"" if path == "/batch" and method == "POST": # TODO: streaming deserialize? req = BatchRequest().deserialize(body) # the cmds are always last (currently in datahash) for c in req._q: if DEBUG >= 1: print(c) session, dev = self.sessions[unwrap(c.session)], Device[f"{self.base_device}:{unwrap(c.session)[1]}"] match c: case GetProperties(): cls, args = dev.renderer.__reduce__() # CPUGraph re-renders kernel from uops specified in CompiledRunner, this is not supported graph_cls = gt if (gt:=graph_class(Device[self.base_device])) is not CPUGraph else None rp = RemoteProperties(dev.device, (cls.__module__, cls.__name__, args), graph_cls is not None, hasattr(dev.allocator, '_transfer')) ret = repr(rp).encode() case BufferAlloc(): assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated" session.buffers[c.buffer_num] = Buffer(dev.device, c.size, dtypes.uint8, options=c.options, preallocate=True) case BufferFree(): del session.buffers[c.buffer_num] case CopyIn(): session.buffers[c.buffer_num].copyin(memoryview(bytearray(req._h[c.datahash]))) case CopyOut(): session.buffers[c.buffer_num].copyout(memoryview(ret:=bytearray(session.buffers[c.buffer_num].nbytes))) case Transfer(): ssession, sdev = self.sessions[c.ssession], Device[f"{self.base_device}:{unwrap(c.ssession)[1]}"] dbuf, sbuf = session.buffers[c.buffer_num], ssession.buffers[c.sbuffer_num] assert dbuf.nbytes == sbuf.nbytes, f"{dbuf.nbytes} != {sbuf.nbytes}" assert hasattr(dev.allocator, '_transfer'), f"Device {dev.device} doesn't support transfers" dev.allocator._transfer(dbuf._buf, sbuf._buf, dbuf.nbytes, dest_dev=dev, src_dev=sdev) case ProgramAlloc(): lib = dev.compiler.compile_cached(req._h[c.datahash].decode()) session.programs[(c.name, c.datahash)] = dev.runtime(c.name, lib) case ProgramFree(): del session.programs[(c.name, c.datahash)] case ProgramExec(): bufs = [session.buffers[x]._buf for x in c.bufs] extra_args = {k:v for k,v in [("global_size", c.global_size), ("local_size", c.local_size)] if v is not None} r = session.programs[(c.name, c.datahash)](*bufs, vals=c.vals, wait=c.wait, **extra_args) if r is not None: ret = str(r).encode() case GraphAlloc(): graph_fn: Callable = unwrap(dev.graph) def _parse_ji(gi: GraphComputeItem): prg = session.programs[(gi.name, gi.datahash)] ps = ProgramSpec(gi.name, '', dev.device, UOp(Ops.NOOP), vars=list(gi.vars), global_size=list(cast(tuple[int], gi.global_size)) if gi.global_size is not None else None, local_size=list(cast(tuple[int], gi.local_size)) if gi.local_size is not None else None) return ExecItem(CompiledRunner(ps, precompiled=b'', prg=prg), [session.buffers[buf] for buf in gi.bufs]) assert c.graph_num not in session.graphs, f"graph {c.graph_num} already allocated" session.graphs[c.graph_num] = graph_fn(list(map(_parse_ji, c.jit_cache)), [session.buffers[buf] for buf in c.bufs], c.var_vals) case GraphFree(): del session.graphs[c.graph_num] case GraphExec(): r = session.graphs[c.graph_num]([session.buffers[buf] for buf in c.bufs], c.var_vals, wait=c.wait) if r is not None: ret = str(r).encode() else: status, ret = http.HTTPStatus.NOT_FOUND, b"Not Found" return status, ret def remote_server(port:int): device = getenv("REMOTEDEV", next(Device.get_available_devices()) if Device.DEFAULT == "REMOTE" else Device.DEFAULT) async def _inner_async(port:int, device:str): print(f"start remote server on {port} with device {device}") await (await asyncio.start_server(RemoteHandler(device), host='', port=port)).serve_forever() asyncio.run(_inner_async(port, device)) # ***** frontend ***** class RemoteAllocator(Allocator['RemoteDevice']): # TODO: ideally we shouldn't have to deal with images here def _alloc(self, size:int, options:BufferSpec) -> int: self.dev.buffer_num += 1 self.dev.q(BufferAlloc(self.dev.buffer_num, size, options)) return self.dev.buffer_num # TODO: options should not be here in any Allocator def _free(self, opaque:int, options): self.dev.q(BufferFree(opaque)) def _copyin(self, dest:int, src:memoryview): self.dev.q(CopyIn(dest, self.dev.conn.req.h(bytes(src)))) def _copyout(self, dest:memoryview, src:int): self.dev.q(CopyOut(src)) resp = self.dev.conn.batch_submit() assert len(resp) == len(dest), f"buffer length mismatch {len(resp)} != {len(dest)}" dest[:] = resp def _transfer(self, dest, src, sz, src_dev, dest_dev): if dest_dev.properties.transfer_supported and src_dev.conn == dest_dev.conn: dest_dev.q(Transfer(dest, src_dev.session, src)) else: src_dev.allocator._copyout(tmp:=memoryview(bytearray(sz)), src) dest_dev.allocator._copyin(dest, tmp) class RemoteProgram: def __init__(self, dev:RemoteDevice, name:str, lib:bytes): self.dev, self.name = dev, name self.datahash = self.dev.conn.req.h(lib) self.dev.q(ProgramAlloc(self.name, self.datahash)) super().__init__() def __del__(self): self.dev.q(ProgramFree(self.name, self.datahash)) def __call__(self, *bufs, global_size=None, local_size=None, vals:tuple[int, ...]=(), wait=False): self.dev.q(ProgramExec(self.name, self.datahash, bufs, vals, global_size, local_size, wait)) if wait: return float(self.dev.conn.batch_submit()) @functools.cache class RemoteConnection: def __init__(self, host:str): if DEBUG >= 1: print(f"remote with host {host}") while 1: try: self.conn = http.client.HTTPConnection(host, timeout=60.0) self.conn.connect() break except Exception as e: print(e) time.sleep(0.1) self.req: BatchRequest = BatchRequest() def batch_submit(self): data = self.req.serialize() with Timing(f"*** send {len(self.req._q):-3d} requests {len(self.req._h):-3d} hashes with len {len(data)/1024:.2f} kB in ", enabled=DEBUG>=1): self.conn.request("POST", "/batch", data) response = self.conn.getresponse() assert response.status == 200, f"POST /batch failed: {response}" ret = response.read() self.req = BatchRequest() return ret class RemoteDevice(Compiled): def __init__(self, device:str): self.conn: RemoteConnection = RemoteConnection(getenv("HOST", "") or RemoteDevice.local_server()) # state for the connection self.session = (binascii.hexlify(os.urandom(0x10)).decode(), int(device.split(":")[1]) if ":" in device else 0) self.buffer_num: int = 0 self.graph_num: int = 0 self.q(GetProperties()) self.properties = safe_eval(ast.parse(self.conn.batch_submit(), mode="eval").body) if DEBUG >= 1: print(f"remote has device {self.properties.real_device}") # TODO: how to we have BEAM be cached on the backend? this should just send a specification of the compute. rethink what goes in Renderer renderer = self.properties.renderer if not renderer[0].startswith("tinygrad.renderer.") or not renderer[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {renderer}") renderer_class = fromimport(renderer[0], renderer[1]) # TODO: is this secure? if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {renderer}") graph = fromimport('tinygrad.runtime.graph.remote', 'RemoteGraph') if self.properties.graph_supported else None super().__init__(device, RemoteAllocator(self), renderer_class(*renderer[2]), Compiler(), functools.partial(RemoteProgram, self), graph) def __del__(self): # TODO: this is never being called # TODO: should close the whole session with contextlib.suppress(ConnectionRefusedError, http.client.CannotSendRequest, http.client.RemoteDisconnected): self.conn.batch_submit() def q(self, x:RemoteRequest): self.conn.req.q(replace(x, session=self.session)) @functools.cache @staticmethod def local_server(): multiprocessing.Process(target=remote_server, args=(6667,), name="MainProcess", daemon=True).start() return "127.0.0.1:6667" if __name__ == "__main__": remote_server(getenv("PORT", 6667))