openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

115 lines
6.2 KiB

import time, itertools
from tinygrad.uop.ops import Variable
from tinygrad.engine.jit import MultiGraphRunner
from tinygrad.engine.realize import CompiledRunner, BufferXfer, ExecItem
from tinygrad.device import Device, Compiled, Buffer
from tinygrad.runtime.ops_remote import RemoteDevice, RemoteConnection, RemoteRequest, GraphComputeItem, Transfer, GraphAlloc, GraphFree, GraphExec
from tinygrad.runtime.ops_remote import BatchTransfer, Event, Wait
from tinygrad.helpers import unwrap, flatten, dedup
from enum import Enum, auto
from dataclasses import replace
from collections import defaultdict
from typing import cast
class StagingType(Enum):
    """Kind of work being accumulated between flushes while a RemoteGraph template is built."""
    # Nothing is being staged.
    NONE = auto()
    # Items that end up in a GraphAlloc/GraphExec pair.
    GRAPH = auto()
    # Transfers that end up in BatchTransfer requests.
    TRANSFER = auto()
def rd(dev: Compiled) -> RemoteDevice:
    """Return ``dev`` unchanged, typed as a ``RemoteDevice`` for the type checker.

    This is a ``typing.cast`` only; nothing is checked or converted at runtime.
    """
    remote = cast(RemoteDevice, dev)
    return remote
def dev_key(dev: RemoteDevice):
    """Return the key used to group staged work for ``dev``.

    A device whose ``properties.graph_supports_multi`` is truthy is keyed by its
    connection (``dev.conn``); any other device is keyed by the device itself.
    """
    if dev.properties.graph_supports_multi:
        return dev.conn
    return dev
def map_rawbuf(rawbuf: Buffer):
    """Return the ``(session, handle)`` pair identifying ``rawbuf``.

    The session is taken from the device looked up by ``rawbuf.device``; the handle
    is ``rawbuf._buf``.
    """
    session = cast(RemoteDevice, Device[rawbuf.device]).session
    handle = rawbuf._buf
    return (session, handle)
class RemoteGraph(MultiGraphRunner):
def __init__(self, jit_cache: list[ExecItem], rawbufs: list[Buffer], var_vals: dict[Variable, int]):
super().__init__(jit_cache, rawbufs, var_vals)
devices = dedup(flatten([[Device[unwrap(buf).device] for buf in ji.bufs] for ji in jit_cache]))
c2d = {device.conn: device for device in devices}
self.handle_indexes = {map_rawbuf(rawbufs[i]): i for i in sorted(dedup(self.input_replace.values()))}
self.template: list[RemoteRequest] = []
stagings: dict[RemoteDevice|RemoteConnection, list[GraphComputeItem|Transfer]] = defaultdict(list)
clobbered_buffers: set[Buffer] = set()
cur_staging_type: StagingType = StagingType.NONE
def _flush(new_staging_type:StagingType, force_break:bool=False):
nonlocal cur_staging_type
if cur_staging_type == new_staging_type and not force_break: return
# Pre-sync
if cur_staging_type == StagingType.TRANSFER:
for sdev,ddev in itertools.permutations(c2d.values(), 2):
self.template.append(Event(ddev.session, event:=next(ddev.event_num), session=sdev.session))
self.template.append(Wait(event, session=ddev.session))
# Flush
for dev in devices:
dk = dev_key(dev)
staging = stagings[dk]
if not staging: continue
match cur_staging_type:
case StagingType.GRAPH:
bufs = tuple(map_rawbuf(rawbufs[i]) for i in sorted(dedup(self.input_replace.values())) if dev_key(rd(Device[rawbufs[i].device])) == dk)
dev.q(GraphAlloc(graph_num:=next(dev.graph_num), tuple(staging), tuple(bufs), var_vals))
self.template.append(GraphExec(graph_num, bufs, var_vals, wait=False, session=dev.session))
case StagingType.TRANSFER:
st = cast(list[Transfer], staging)
for host in dedup(t.dsession.host for t in st):
sbuffer_nums = [(unwrap(t.session), t.buffer_num) for t in st if t.dsession.host == host]
dbuffer_nums = [(t.dsession, t.dbuffer_num) for t in st if t.dsession.host == host]
self.template.append(BatchTransfer(sbuffer_nums, dbuffer_nums, session=dev.session))
staging.clear()
# Post-sync
if cur_staging_type == StagingType.TRANSFER:
for sdev,ddev in itertools.permutations(c2d.values(), 2):
self.template.append(Event(ddev.session, event:=next(ddev.event_num), session=sdev.session))
self.template.append(Wait(event, session=ddev.session))
cur_staging_type = new_staging_type
clobbered_buffers.clear()
for ji in jit_cache:
match ji.prg:
case CompiledRunner():
_flush(StagingType.GRAPH)
gi = GraphComputeItem(ji.prg.dev.session, ji.prg._prg.name, ji.prg._prg.datahash, tuple(unwrap(buf)._buf for buf in ji.bufs),
tuple(ji.prg.p.vars), ji.fixedvars, tuple(ji.prg.p.ins), tuple(ji.prg.p.outs),
tuple(ji.prg.p.global_size) if ji.prg.p.global_size is not None else None,
tuple(ji.prg.p.local_size) if ji.prg.p.local_size is not None else None)
stagings[dev_key(ji.prg.dev)].append(gi)
case BufferXfer():
dest, src = ji.bufs[0:2]
dest_dev, src_dev = cast(RemoteDevice, Device[unwrap(dest).device]), cast(RemoteDevice, Device[unwrap(src).device])
assert dest is not None and src is not None, ji
ti = Transfer(session=src_dev.session, buffer_num=src._buf, dsession=dest_dev.session, dbuffer_num=dest._buf)
if dev_key(dest_dev) == dev_key(src_dev):
_flush(StagingType.GRAPH)
stagings[dev_key(src_dev)].append(ti)
elif dest_dev.conn == src_dev.conn:
_flush(StagingType.NONE)
self.template.append(ti)
else:
_flush(StagingType.TRANSFER, force_break=src in clobbered_buffers)
clobbered_buffers.add(dest)
stagings[dev_key(src_dev)].append(ti)
case _: raise NotImplementedError(ji.prg)
_flush(StagingType.NONE)
def __del__(self):
for req in self.template:
match req:
case GraphExec(): RemoteConnection(unwrap(req.session).host).q(GraphFree(req.graph_num, session=req.session))
def __call__(self, rawbufs: list[Buffer], var_vals: dict[Variable, int], wait=False):
if wait: st = time.perf_counter()
rmap = {orig: map_rawbuf(rawbufs[replace_idx]) for orig,replace_idx in self.handle_indexes.items()}
for req in self.template:
match req:
case GraphExec():
req = replace(req, bufs=tuple(rmap[buf] for buf in req.bufs), var_vals=var_vals, wait=wait)
case Transfer():
if (req.session, req.buffer_num) in rmap: req = replace(req, buffer_num=rmap[(req.session, req.buffer_num)][1])
if (req.dsession, req.dbuffer_num) in rmap: req = replace(req, dbuffer_num=rmap[(req.dsession, req.dbuffer_num)][1])
case BatchTransfer():
req = replace(req, sbuffer_nums=[rmap.get(b, b) for b in req.sbuffer_nums], dbuffer_nums=[rmap.get(b, b) for b in req.dbuffer_nums])
case Event()|Wait():
pass # event number can be reused
case _: raise NotImplementedError(req)
RemoteConnection(unwrap(req.session).host).q(req)
if wait:
RemoteConnection(unwrap(req.session).host).batch_submit()
return time.perf_counter() - st