from __future__ import annotations
import resource, ctypes, weakref, functools, itertools, tinygrad.runtime.autogen.ib as ib
from typing import Iterator
from dataclasses import dataclass
from weakref import WeakKeyDictionary
from tinygrad.device import Buffer, DMACPURef, DMAFdRef
from tinygrad.helpers import getenv, round_up, DEBUG

DEFAULT_PORT, DEFAULT_GID = getenv("DEFAULT_PORT", 1), getenv("DEFAULT_GID", 3) # DEFAULT_GID=0 for RXE
IOVA_ALIGN = resource.getpagesize()

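# Assert that a libibverbs call returned 0 (printing errno on failure) and pass `ret` through, so a query-and-return fits on one line.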
def checkz(x, ret=None):
  assert x == 0, f'{x} != 0 (errno {ctypes.get_errno()})'
  return ret

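# One entry of a software scatter-gather list for rdma_write: copy `size` bytes from (src_iova, src_key) on this node
# to (dst_iova, dst_key) on the remote.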
@dataclass(frozen=True)
class SGE:
  dst_iova: int
  dst_key: int
  src_iova: int
  src_key: int
  size: int

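# Wraps one ib device: device context, protection domain, and a cache of memory registrations keyed by Buffer.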
class IBCtx:
  def __init__(self, idx:int):
    # Open the device (aka Host Channel Adapter in ib-speak)
    devs = ib.ibv_get_device_list(ctypes.byref(ndevs:=ctypes.c_int32()))
    if idx >= ndevs.value: raise IndexError(f"{idx} >= {ndevs.value}")
    self.ctx = ib.ibv_open_device(devs[idx])
    ib.ibv_free_device_list(devs)

    # HACK: remove this (and all usage of `ctx.contents.ops`) when clang2py can deal with `static inline` wrapper-functions
    self.vctx = ctypes.cast(ctypes.addressof(self.ctx.contents) - ib.struct_verbs_context.context.offset, ctypes.POINTER(ib.struct_verbs_context))

    # Get attributes. Something like port_attr.max_msg_sz sounds like it might require taking the min of host's and remote's attributes if they differ
    self.device_attr = checkz(ib.ibv_query_device(self.ctx, ctypes.byref(da:=ib.struct_ibv_device_attr())), da)
    self.port_attr = checkz(self.vctx.contents.query_port(self.ctx, DEFAULT_PORT, ctypes.byref(pa:=ib.struct_ibv_port_attr()), ctypes.sizeof(pa)), pa)
    self.gid_attr = checkz(ib.ibv_query_gid(self.ctx, DEFAULT_PORT, DEFAULT_GID, ctypes.byref(ga:=ib.union_ibv_gid())), ga)

    # Allocate protection domain
    self.pd = ib.ibv_alloc_pd(self.ctx)
    self.next_iova: int = IOVA_ALIGN # don't start at zero (nullptr)

    # weakref(buf) => (iova, mr, mr_dealloc). mr_dealloc is kept here to avoid double-freeing mrs that are deallocated in __del__
    self.mrs: WeakKeyDictionary[Buffer, tuple[int, ctypes._Pointer[ib.struct_ibv_mr], weakref.finalize]] = WeakKeyDictionary()

    # Default soft fd limit is 1024, which is not enough; raise soft to hard (the maximum allowed by the os)
    IBCtx.rlimit_fix()

  def __del__(self):
    # must deallocate all mrs in the protection domain before deallocating the protection domain
    if hasattr(self, "mrs"): [fin() for _,_,fin in self.mrs.values()]
    if hasattr(self, "pd"): ib.ibv_dealloc_pd(self.pd)
    if hasattr(self, "ctx"): ib.ibv_close_device(self.ctx)

  @functools.cache # run once
  @staticmethod
  def rlimit_fix():
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
    if DEBUG>=2: print(f"IB: Increased fd limit from {soft} to {hard}")

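  # Bump-allocate a device virtual address for `size` bytes, congruent to `required_offset` mod page size so the
  # sub-page offset of the underlying mapping is preserved.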
  def alloc_iova(self, size:int, required_offset:int):
    iova = round_up(self.next_iova - required_offset, IOVA_ALIGN) + required_offset
    self.next_iova = iova + size
    return iova

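  # Register a Buffer as a memory region (memoized in self.mrs) and return its (iova, mr).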
  def reg(self, buf:Buffer) -> tuple[int, ctypes._Pointer[ib.struct_ibv_mr]]:
    buf = buf.base
    if buf not in self.mrs:
      if buf.nbytes > self.device_attr.max_mr_size: raise RuntimeError(f"Buffer too big: {buf.nbytes:#x} > {self.device_attr.max_mr_size:#x}")
      if len(self.mrs) >= self.device_attr.max_mr: raise RuntimeError(f"Out of memory region cap: {len(self.mrs)} >= {self.device_attr.max_mr}")
      # Local read is implied (but still have to create the memory region, except for short sends/writes with IBV_SEND_INLINE that are inlined by cpu)
      mr_flags = ib.IBV_ACCESS_LOCAL_WRITE | ib.IBV_ACCESS_REMOTE_READ | ib.IBV_ACCESS_REMOTE_WRITE
      match (dmaref:=buf.as_dmaref()):
        case DMACPURef():
          iova = self.alloc_iova(dmaref.size, dmaref.addr % IOVA_ALIGN)
          mr = ib.ibv_reg_mr_iova2(self.pd, ctypes.c_void_p(dmaref.addr), dmaref.size, iova, mr_flags)
        case DMAFdRef():
          iova = self.alloc_iova(dmaref.size, dmaref.offset % IOVA_ALIGN)
          mr = ib.ibv_reg_dmabuf_mr(self.pd, dmaref.offset, dmaref.size, iova, dmaref.fd, mr_flags)
        case _: raise RuntimeError(f"Unknown type of dma ref: {dmaref}")
      if not mr: raise RuntimeError(f"Couldn't register memory region for {buf} {dmaref} (errno={ctypes.get_errno()})")
      self.mrs[buf] = (iova, mr, weakref.finalize(buf, ib.ibv_dereg_mr, mr))
    return self.mrs[buf][0:2]

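# One side of a Reliable Connection: completion channel, completion queue, and a queue pair bound to an IBCtx.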
class IBConn:
  def __init__(self, ctx:IBCtx):
    self.ctx = ctx

    # Create Completion Channel. It's a file descriptor the kernel sends notifications through; not a thing in the infiniband spec, just a linux-ism
    self.comp_channel = ib.ibv_create_comp_channel(self.ctx.ctx)
    # Create Completion Queue. When a Work Request with the signaled flag completes, a Completion Queue Entry is pushed onto this queue
    self.cq = ib.ibv_create_cq(self.ctx.ctx, _capacity:=256, _cq_context:=None, self.comp_channel, _comp_vector:=0)
    self.pending_wrids: set[int] = set()
    self.wrid_num: Iterator[int] = itertools.count(0) # wc_id is uint64, this will never overflow

    # Create Queue Pair. It's the closest thing to a socket in infiniband, with the QP num being the closest thing to a port, except it's allocated by the hca
    qp_init_attrs_cap = ib.struct_ibv_qp_cap(max_send_wr=1024, max_recv_wr=64, max_send_sge=8, max_recv_sge=8, max_inline_data=64)
    qp_init_attrs = ib.struct_ibv_qp_init_attr(send_cq=self.cq, recv_cq=self.cq, cap=qp_init_attrs_cap, qp_type=ib.IBV_QPT_RC) # Reliable Connection
    self.qp = ib.ibv_create_qp(self.ctx.pd, ctypes.byref(qp_init_attrs))
    self.qp_cap = qp_init_attrs.cap

    # The most important thing about a QP is its state: a new QP starts in RESET, and before it can be used it has to be moved
    # through INIT, Ready To Receive, and Ready To Send. Good docs on the QP state machine: https://www.rdmamojo.com/2012/05/05/qp-state-machine/

    # INIT
    qp_access_flags = ib.IBV_ACCESS_REMOTE_WRITE | ib.IBV_ACCESS_REMOTE_READ
    qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_INIT, port_num=DEFAULT_PORT, qp_access_flags=qp_access_flags)
    checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_PORT | ib.IBV_QP_ACCESS_FLAGS | ib.IBV_QP_PKEY_INDEX))

    self.gid, self.qp_num = bytes(self.ctx.gid_attr.raw), self.qp.contents.qp_num

  # Exchange GID and QP num with the remote out of band. At least in RoCEv2 the gid can be guessed from the remote's ip; the QP num can't.
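  # A minimal sketch of that handshake, assuming some out-of-band channel `oob` (hypothetical, not part of this file):
  #   conn = IBConn(IBCtx(0))            # one per node
  #   oob.send((conn.gid, conn.qp_num))  # publish our GID and QP num
  #   conn.connect(*oob.recv())          # both sides do this, then rdma_write can flow
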
  def connect(self, remote_gid:bytes, remote_qp_num:int):
    # RTR (Ready To Receive)
    qp_ah_attr_grh = ib.struct_ibv_global_route(hop_limit=1, dgid=ib.union_ibv_gid(raw=(ctypes.c_ubyte * 16)(*remote_gid)), sgid_index=DEFAULT_GID)
    qp_ah_attr = ib.struct_ibv_ah_attr(is_global=1, port_num=DEFAULT_PORT, grh=qp_ah_attr_grh)
    qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_RTR, path_mtu=ib.IBV_MTU_4096, dest_qp_num=remote_qp_num, rq_psn=0, max_dest_rd_atomic=1,
                                min_rnr_timer=12, ah_attr=qp_ah_attr)
    checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_PATH_MTU | ib.IBV_QP_DEST_QPN | ib.IBV_QP_RQ_PSN | \
                            ib.IBV_QP_MAX_DEST_RD_ATOMIC | ib.IBV_QP_MIN_RNR_TIMER | ib.IBV_QP_AV))

    # RTS (Ready To Send)
    qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_RTS, timeout=14, retry_cnt=7, rnr_retry=7, sq_psn=0, max_rd_atomic=1)
    checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_TIMEOUT | ib.IBV_QP_RETRY_CNT | ib.IBV_QP_RNR_RETRY | ib.IBV_QP_SQ_PSN | \
                            ib.IBV_QP_MAX_QP_RD_ATOMIC))

  def __del__(self):
    self.wait_cq() # need to wait for **everything** to complete before it's safe to dealloc queues and stuff
    ib.ibv_destroy_qp(self.qp)
    ib.ibv_destroy_cq(self.cq)
    ib.ibv_destroy_comp_channel(self.comp_channel)

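  # Allocate a fresh work request id and mark it as pending until wait_cq sees its completion.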
  def next_wrid(self):
    self.pending_wrids.add(wrid:=next(self.wrid_num))
    return wrid

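  # Busy-poll the completion queue until `wr_id` completes (or, with no argument, until all pending work requests do).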
  def wait_cq(self, wr_id: int|None=None):
    while (wr_id in self.pending_wrids) if wr_id is not None else self.pending_wrids:
      if self.ctx.ctx.contents.ops.poll_cq(self.cq, _num_entries:=1, ctypes.byref(wc:=ib.struct_ibv_wc())):
        if wc.status != ib.IBV_WC_SUCCESS:
          raise RuntimeError(f'Work Request completed with error: wr_id={wc.wr_id} status={ib.ibv_wc_status__enumvalues.get(wc.status, wc.status)}')
        self.pending_wrids.remove(wc.wr_id)

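  # Write every SGE to the remote with RDMA WRITE work requests, splitting at max_msg_sz, and block until completion.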
  def rdma_write(self, sgl:list[SGE]):
    swr: ctypes._Pointer[ib.struct_ibv_send_wr]|None = None
    swr_cnt, wr_id = 0, self.next_wrid()
    def _post():
      nonlocal swr, swr_cnt, wr_id
      if swr is not None:
        # The swr can be freed when this returns; the memory the sge points to can be unmapped once the work completion is retrieved from the cq
        checkz(self.ctx.ctx.contents.ops.post_send(self.qp, swr, ctypes.byref(_bad_wr:=ctypes.POINTER(ib.struct_ibv_send_wr)())))
        # TODO: async
        self.wait_cq(wr_id)
      swr, swr_cnt, wr_id = None, 0, self.next_wrid()
    # Everything is built in reverse for elegant chaining
    for sg in reversed(sgl):
      # Message size limit (max 2GB per ib spec, 1GB on tinybox mellanoxes) applies to both scatter-gather entries and entire wrs
      for off in reversed(range(0, sg.size, self.ctx.port_attr.max_msg_sz)):
        # Scatter-Gather Entry for local memory
        sge = ctypes.pointer(ib.struct_ibv_sge(addr=sg.src_iova+off, length=min(sg.size-off, self.ctx.port_attr.max_msg_sz), lkey=sg.src_key))
        # RDMA struct for remote memory
        wr = ib.union_ibv_send_wr_wr(rdma=ib.struct_ibv_send_wr_1_rdma(remote_addr=sg.dst_iova+off, rkey=sg.dst_key))
        # Signal (with the chosen work request id) if it's the last wr (first in the loop since it's reversed)
        wid, flags = (wr_id, ib.IBV_SEND_SIGNALED) if swr is None else (0, 0)
        # Create Send Request
        swr = ctypes.pointer(ib.struct_ibv_send_wr(opcode=ib.IBV_WR_RDMA_WRITE, sg_list=sge, num_sge=1, wr=wr, wr_id=wid, send_flags=flags, next=swr))
        # Flush if the queue is about to be overrun
        if (swr_cnt:=swr_cnt + 1) >= self.qp_cap.max_send_wr: _post()
    _post()