openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

172 lines
9.9 KiB

from __future__ import annotations
import resource, ctypes, weakref, functools, itertools, tinygrad.runtime.autogen.ib as ib
from typing import Iterator
from dataclasses import dataclass
from weakref import WeakKeyDictionary
from tinygrad.device import Buffer, DMACPURef, DMAFdRef
from tinygrad.helpers import getenv, round_up, DEBUG
DEFAULT_PORT, DEFAULT_GID = getenv("DEFAULT_PORT", 1), getenv("DEFAULT_GID", 3) # DEFAULT_GID=0 for RXE
IOVA_ALIGN = resource.getpagesize()
def checkz(x, ret=None):
  """Check the return code of an ibverbs FFI call.

  `x` is the call's return value (0 means success); `ret` is passed through unchanged so a result
  can be extracted inline, e.g. `checkz(ib.ibv_query_device(ctx, byref(da)), da)`.

  Raises AssertionError (the same type the previous bare-assert version raised) on a nonzero return,
  with errno included for debugging. An explicit raise is used instead of `assert` so the check
  survives `python -O`, which strips asserts and would silently ignore FFI failures.
  """
  if x != 0: raise AssertionError(f'{x} != 0 (errno {ctypes.get_errno()})')
  return ret
@dataclass(frozen=True)
class SGE:
  """One scatter-gather entry for an RDMA transfer: copy `size` bytes from the local region
  (src_iova, src_key) to the remote region (dst_iova, dst_key).

  In the work requests built by IBConn.rdma_write, `src_key` is used as the lkey of the local
  sge and `dst_key` as the rkey of the remote write target.
  """
  dst_iova: int  # destination virtual address (iova the remote registered its buffer at)
  dst_key: int   # remote memory region key (rkey)
  src_iova: int  # source virtual address (iova from IBCtx.reg on this side)
  src_key: int   # local memory region key (lkey)
  size: int      # number of bytes to transfer
class IBCtx:
  """An opened ibverbs device context: device/port/gid attributes, a protection domain, and an
  iova bump-allocator plus a weak cache of registered memory regions keyed by Buffer."""
  def __init__(self, idx:int):
    # Open the device (aka Host Channel Adapter in ib-speak)
    devs = ib.ibv_get_device_list(ctypes.byref(ndevs:=ctypes.c_int32()))
    if idx >= ndevs.value: raise IndexError(f"{idx} > {ndevs.value}")
    self.ctx = ib.ibv_open_device(devs[idx])
    ib.ibv_free_device_list(devs)  # list can be freed once the chosen device is open
    # HACK: remove this (and all usage of `ctx.contents.ops`) when clang2py can deal with `static inline` wrapper-functions
    self.vctx = ctypes.cast(ctypes.addressof(self.ctx.contents) - ib.struct_verbs_context.context.offset, ctypes.POINTER(ib.struct_verbs_context))
    # Get attributes. Something like port_attr.max_msg_sz sounds like it might require taking the min of host's and remote's attributes if they differ
    self.device_attr = checkz(ib.ibv_query_device(self.ctx, ctypes.byref(da:=ib.struct_ibv_device_attr())), da)
    self.port_attr = checkz(self.vctx.contents.query_port(self.ctx, DEFAULT_PORT, ctypes.byref(pa:=ib.struct_ibv_port_attr()), ctypes.sizeof(pa)), pa)
    self.gid_attr = checkz(ib.ibv_query_gid(self.ctx, DEFAULT_PORT, DEFAULT_GID, ctypes.byref(ga:=ib.union_ibv_gid())), ga)
    # Allocate protection domain
    self.pd = ib.ibv_alloc_pd(self.ctx)
    self.next_iova: int = IOVA_ALIGN # don't start at zero (nullptr)
    # weakref(buf) => (iova, mr, mr_dealloc). mr_dealloc is kept here to avoid double freeing mrs that are deallocated in __del__
    self.mrs: WeakKeyDictionary[Buffer, tuple[int, ctypes._Pointer[ib.struct_ibv_mr], weakref.finalize]] = WeakKeyDictionary()
    # Default soft fd limit is 1024, which is not enough, set soft to hard (maximum allowed by the os)
    IBCtx.rlimit_fix()
  def __del__(self):
    # hasattr guards: __init__ may have raised partway through, leaving some attributes unset
    # must deallocate all mrs in protection domain before deallocating the protection domain
    if hasattr(self, "mrs"): [fin() for _,_,fin in self.mrs.values()]
    if hasattr(self, "pd"): ib.ibv_dealloc_pd(self.pd)
    if hasattr(self, "ctx"): ib.ibv_close_device(self.ctx)
  @functools.cache # run once
  @staticmethod
  def rlimit_fix():
    # Raise this process's soft RLIMIT_NOFILE to the hard limit (see __init__: default soft of 1024 is not enough)
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
    if DEBUG>=2: print(f"IB: Increased fd limit from {soft} to {hard}")
  def alloc_iova(self, size:int, required_offset:int) -> int:
    # Bump-allocate a device virtual address range. The result is page-aligned except for
    # `required_offset` (callers pass addr % IOVA_ALIGN), so the iova keeps the buffer's in-page offset.
    iova = round_up(self.next_iova - required_offset, IOVA_ALIGN) + required_offset
    self.next_iova = iova + size
    return iova
  def reg(self, buf:Buffer) -> tuple[int, ctypes._Pointer[ib.struct_ibv_mr]]:
    """Register `buf`'s base as a memory region (or return the cached one); returns (iova, mr)."""
    buf = buf.base
    if buf not in self.mrs:
      # Respect the device limits on region size and region count before attempting registration
      if buf.nbytes > self.device_attr.max_mr_size: raise RuntimeError(f"Buffer too big: {buf.nbytes:#x} > {self.device_attr.max_mr_size:#x}")
      if len(self.mrs) >= self.device_attr.max_mr: raise RuntimeError(f"Out of memory region cap: {len(self.mrs)} >= {self.device_attr.max_mr}")
      # Local read is implied (but still have to create the memory region, except for short sends/writes with IBV_SEND_INLINE that are inlined by cpu)
      mr_flags = ib.IBV_ACCESS_LOCAL_WRITE | ib.IBV_ACCESS_REMOTE_READ | ib.IBV_ACCESS_REMOTE_WRITE
      match (dmaref:=buf.as_dmaref()):
        case DMACPURef():
          # host memory: register by virtual address at a chosen iova
          iova = self.alloc_iova(dmaref.size, dmaref.addr % IOVA_ALIGN)
          mr = ib.ibv_reg_mr_iova2(self.pd, ctypes.c_void_p(dmaref.addr), dmaref.size, iova, mr_flags)
        case DMAFdRef():
          # dma-buf fd: register by (fd, offset) at a chosen iova
          iova = self.alloc_iova(dmaref.size, dmaref.offset % IOVA_ALIGN)
          mr = ib.ibv_reg_dmabuf_mr(self.pd, dmaref.offset, dmaref.size, iova, dmaref.fd, mr_flags)
        case _: raise RuntimeError(f"Unknown type of dma ref: {dmaref}")
      if not mr: raise RuntimeError(f"Couldn't register memory region for {buf} {dmaref} (errno={ctypes.get_errno()})")
      # finalizer deregisters the mr when buf is garbage collected; also invoked explicitly from __del__
      self.mrs[buf] = (iova, mr, weakref.finalize(buf, ib.ibv_dereg_mr, mr))
    return self.mrs[buf][0:2]
class IBConn:
def __init__(self, ctx:IBCtx):
  """Create a reliable-connection endpoint (completion channel + CQ + RC QP) on `ctx`.

  After construction the QP is in the INIT state; exchange (self.gid, self.qp_num) with the peer
  out of band, then call connect() with the peer's values to bring the QP up to RTS.
  """
  self.ctx = ctx
  # Create Completion Channel. It is a file descriptor that kernel sends notifications through, not a thing in infiniband spec, just linux-ism
  self.comp_channel = ib.ibv_create_comp_channel(self.ctx.ctx)
  # Create Completion Queue. When a Work Request with signaled flag is completed a Completion Queue Entry is pushed onto this queue
  self.cq = ib.ibv_create_cq(self.ctx.ctx, _capacity:=256, _cq_context:=None, self.comp_channel, _comp_vector:=0)
  self.pending_wrids: set[int] = set()  # work request ids issued but not yet seen completing on the cq
  self.wrid_num: Iterator[int] = itertools.count(0) # wc_id is uint64, this will never overflow
  # Create Queue Pair. It's the closest thing to a socket in infiniband with QP num being the closest thing to a port, except it's allocated by hca
  qp_init_attrs_cap = ib.struct_ibv_qp_cap(max_send_wr=1024, max_recv_wr=64, max_send_sge=8, max_recv_sge=8, max_inline_data=64)
  qp_init_attrs = ib.struct_ibv_qp_init_attr(send_cq=self.cq, recv_cq=self.cq, cap=qp_init_attrs_cap, qp_type=ib.IBV_QPT_RC) # Reliable Connection
  self.qp = ib.ibv_create_qp(self.ctx.pd, ctypes.byref(qp_init_attrs))
  # cap is read back after create: presumably the driver can adjust the requested values (see ibv_create_qp docs) — rdma_write flushes against this
  self.qp_cap = qp_init_attrs.cap
  # The most important thing about QPs is their state, when a new QP is created it's in the RESET state, before it can be properly used it has to go
  # through Init, Ready To Receive, Ready To Send. A good docs on QP state machine: https://www.rdmamojo.com/2012/05/05/qp-state-machine/
  # INIT
  qp_access_flags = ib.IBV_ACCESS_REMOTE_WRITE | ib.IBV_ACCESS_REMOTE_READ
  qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_INIT, port_num=DEFAULT_PORT, qp_access_flags=qp_access_flags)
  checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_PORT | ib.IBV_QP_ACCESS_FLAGS | ib.IBV_QP_PKEY_INDEX))
  self.gid, self.qp_num = bytes(self.ctx.gid_attr.raw), self.qp.contents.qp_num
  # Exchange GID and QP num with remote. At least in RoCEv2 gid can be guessed from remote's ip, QP num can't.
def connect(self, remote_gid:bytes, remote_qp_num:int):
  """Transition the QP through RTR and RTS so traffic can flow to/from the given remote QP.

  `remote_gid` is the peer's 16-byte gid (its IBConn.gid), `remote_qp_num` its IBConn.qp_num.
  Both sides must call connect with each other's values before posting work requests.
  """
  # RTR (Ready To Receive): needs the route to the remote (global route header + address handle), mtu, and the remote's qp number
  qp_ah_attr_grh = ib.struct_ibv_global_route(hop_limit=1, dgid=ib.union_ibv_gid(raw=(ctypes.c_ubyte * 16)(*remote_gid)), sgid_index=DEFAULT_GID)
  qp_ah_attr = ib.struct_ibv_ah_attr(is_global=1, port_num=DEFAULT_PORT, grh=qp_ah_attr_grh)
  qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_RTR, path_mtu=ib.IBV_MTU_4096, dest_qp_num=remote_qp_num, rq_psn=0, max_dest_rd_atomic=1,
                              min_rnr_timer=12, ah_attr=qp_ah_attr)
  checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_PATH_MTU | ib.IBV_QP_DEST_QPN | ib.IBV_QP_RQ_PSN | \
                          ib.IBV_QP_MAX_DEST_RD_ATOMIC | ib.IBV_QP_MIN_RNR_TIMER | ib.IBV_QP_AV))
  # RTS (Ready To Send): local send-side parameters — timeout/retry counts and the initial send packet sequence number
  qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_RTS, timeout=14, retry_cnt=7, rnr_retry=7, sq_psn=0, max_rd_atomic=1)
  checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_TIMEOUT | ib.IBV_QP_RETRY_CNT | ib.IBV_QP_RNR_RETRY | ib.IBV_QP_SQ_PSN | \
                          ib.IBV_QP_MAX_QP_RD_ATOMIC))
def __del__(self):
self.wait_cq() # need to wait for **everything** to complete before it's safe to dealloc queues and stuff
ib.ibv_destroy_qp(self.qp)
ib.ibv_destroy_cq(self.cq)
ib.ibv_destroy_comp_channel(self.comp_channel)
def next_wrid(self):
self.pending_wrids.add(wrid:=next(self.wrid_num))
return wrid
def wait_cq(self, wr_id: int|None=None):
  """Busy-poll the completion queue until `wr_id` completes, or until all pending work completes if None.

  Raises RuntimeError if any polled completion carries a non-success status. Every successful
  completion (not just the awaited one) is removed from pending_wrids as it is seen.
  """
  while (wr_id in self.pending_wrids) if wr_id is not None else self.pending_wrids:
    # poll_cq returns the number of completions retrieved; 0 means nothing finished yet, keep spinning
    if self.ctx.ctx.contents.ops.poll_cq(self.cq, _num_entries:=1, ctypes.byref(wc:=ib.struct_ibv_wc())):
      if wc.status != ib.IBV_WC_SUCCESS:
        raise RuntimeError(f'Work Request completed with error: wr_id={wc.wr_id} status={ib.ibv_wc_status__enumvalues.get(wc.status, wc.status)}')
      self.pending_wrids.remove(wc.wr_id)
def rdma_write(self, sgl:list[SGE]):
  """Issue RDMA WRITE work requests for every scatter-gather entry in `sgl` and block until done.

  Entries larger than the port's max message size are split into multiple work requests, and the
  chain is flushed (posted and waited on) whenever it would exceed the QP's max_send_wr capacity.
  """
  swr: ctypes._Pointer[ib.struct_ibv_send_wr]|None = None
  swr_cnt, wr_id = 0, self.next_wrid()
  def _post():
    # Post the currently chained work requests (if any) and wait for their signaled completion.
    nonlocal swr, swr_cnt, wr_id
    if swr is not None:
      # The swr can be freed when this returns, the memory that sge points to can be unmapped after work completion is retrieved from cq
      checkz(self.ctx.ctx.contents.ops.post_send(self.qp, swr, ctypes.byref(_bad_wr:=ctypes.POINTER(ib.struct_ibv_send_wr)())))
      # TODO: async
      self.wait_cq(wr_id)
      # NOTE(review): this next_wrid() registers an id in pending_wrids that is never posted if no
      # further chain follows; looks like wait_cq(None) could then spin forever — confirm against
      # the other users of pending_wrids elsewhere in the file.
      swr, swr_cnt, wr_id = None, 0, self.next_wrid()
  # Everything is in reverse for elegant chaining
  for sg in reversed(sgl):
    # Message size limit (max 2GB per ib spec, 1GB on tinybox mellanoxes) applies to both scatter-gather entries and entire wrs
    for off in reversed(range(0, sg.size, self.ctx.port_attr.max_msg_sz)):
      # Scatter-Gather Entry for local memory
      sge = ctypes.pointer(ib.struct_ibv_sge(addr=sg.src_iova+off, length=min(sg.size-off, self.ctx.port_attr.max_msg_sz), lkey=sg.src_key))
      # RDMA struct for remote memory
      wr = ib.union_ibv_send_wr_wr(rdma=ib.struct_ibv_send_wr_1_rdma(remote_addr=sg.dst_iova+off, rkey=sg.dst_key))
      # Signal (with chosen work request id) if it's the last wr (first in the loop since it's reversed)
      wid, flags = (wr_id, ib.IBV_SEND_SIGNALED) if swr is None else (0, 0)
      # Create Send Request
      swr = ctypes.pointer(ib.struct_ibv_send_wr(opcode=ib.IBV_WR_RDMA_WRITE, sg_list=sge, num_sge=1, wr=wr, wr_id=wid, send_flags=flags, next=swr))
      # Flush if queue is being overrun
      if (swr_cnt:=swr_cnt + 1) >= self.qp_cap.max_send_wr: _post()
  # Post whatever is still chained
  _post()