You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							172 lines
						
					
					
						
							9.9 KiB
						
					
					
				
			
		
		
	
	
							172 lines
						
					
					
						
							9.9 KiB
						
					
					
				| from __future__ import annotations
 | |
| import resource, ctypes, weakref, functools, itertools, tinygrad.runtime.autogen.ib as ib
 | |
| from typing import Iterator
 | |
| from dataclasses import dataclass
 | |
| from weakref import WeakKeyDictionary
 | |
| from tinygrad.device import Buffer, DMACPURef, DMAFdRef
 | |
| from tinygrad.helpers import getenv, round_up, DEBUG
 | |
| 
 | |
# Port number and GID table index used for every query and queue pair below; both can be overridden through getenv
DEFAULT_PORT = getenv("DEFAULT_PORT", 1)
DEFAULT_GID = getenv("DEFAULT_GID", 3) # DEFAULT_GID=0 for RXE
# Alignment applied by IBCtx.alloc_iova when handing out I/O virtual addresses: the system page size
IOVA_ALIGN = resource.getpagesize()
 | |
| 
 | |
def checkz(x, ret=None):
  """Check that a C-style status code is zero and pass `ret` through.

  Args:
    x: status value returned by a call; zero means success.
    ret: value handed back to the caller on success (None by default).

  Returns:
    `ret`, unchanged.

  Raises:
    AssertionError: `x` is not zero. The message carries `x` and the current ctypes errno.
  """
  # Raised explicitly instead of through `assert` so the check survives `python -O`;
  # exception type and message are the same as the assert produced.
  if x != 0: raise AssertionError(f'{x} != 0 (errno {ctypes.get_errno()})')
  return ret
 | |
| 
 | |
@dataclass(frozen=True)
class SGE:
  """One contiguous span for IBConn.rdma_write: `size` bytes going from the local (src) side to the remote (dst) side."""
  # address written to on the remote side (passed as `remote_addr`)
  dst_iova: int
  # key of the remote memory region (passed as `rkey`)
  dst_key: int
  # address read from on the local side (passed as the scatter-gather entry `addr`)
  src_iova: int
  # key of the local memory region (passed as `lkey`)
  src_key: int
  # number of bytes to transfer; rdma_write splits it into pieces of at most port_attr.max_msg_sz
  size: int
 | |
| 
 | |
| class IBCtx:
 | |
|   def __init__(self, idx:int):
 | |
|     # Open the device (aka Host Channel Adapter in ib-speak)
 | |
|     devs = ib.ibv_get_device_list(ctypes.byref(ndevs:=ctypes.c_int32()))
 | |
|     if idx >= ndevs.value: raise IndexError(f"{idx} > {ndevs.value}")
 | |
|     self.ctx = ib.ibv_open_device(devs[idx])
 | |
|     ib.ibv_free_device_list(devs)
 | |
| 
 | |
|     # HACK: remove this (and all usage of `ctx.contents.ops`) when clang2py can deal with `static inline` wrapper-functions
 | |
|     self.vctx = ctypes.cast(ctypes.addressof(self.ctx.contents) - ib.struct_verbs_context.context.offset, ctypes.POINTER(ib.struct_verbs_context))
 | |
| 
 | |
|     # Get attributes. Something like port_attr.max_msg_sz sound like it might requre taking the min of host's and remote's attributes if they differ
 | |
|     self.device_attr = checkz(ib.ibv_query_device(self.ctx, ctypes.byref(da:=ib.struct_ibv_device_attr())), da)
 | |
|     self.port_attr = checkz(self.vctx.contents.query_port(self.ctx, DEFAULT_PORT, ctypes.byref(pa:=ib.struct_ibv_port_attr()), ctypes.sizeof(pa)), pa)
 | |
|     self.gid_attr = checkz(ib.ibv_query_gid(self.ctx, DEFAULT_PORT, DEFAULT_GID, ctypes.byref(ga:=ib.union_ibv_gid())), ga)
 | |
| 
 | |
|     # Allocate protection domain
 | |
|     self.pd = ib.ibv_alloc_pd(self.ctx)
 | |
|     self.next_iova: int = IOVA_ALIGN # don't start at zero (nullptr)
 | |
| 
 | |
|     # weakref(buf) => (iova, mr, mr_dealloc). mr_dealloc is kept here to avoid double freeing mrs that are deallocated in __del__
 | |
|     self.mrs: WeakKeyDictionary[Buffer, tuple[int, ctypes._Pointer[ib.struct_ibv_mr], weakref.finalize]] = WeakKeyDictionary()
 | |
| 
 | |
|     # Default soft fd limit is 1024, which is not enough, set soft to hard (maximum allowed by the os)
 | |
|     IBCtx.rlimit_fix()
 | |
| 
 | |
|   def __del__(self):
 | |
|     # must deallocate all mrs in protection domain before deallocating the protection domain
 | |
|     if hasattr(self, "mrs"): [fin() for _,_,fin in self.mrs.values()]
 | |
|     if hasattr(self, "pd"): ib.ibv_dealloc_pd(self.pd)
 | |
|     if hasattr(self, "ctx"): ib.ibv_close_device(self.ctx)
 | |
| 
 | |
|   @functools.cache # run once
 | |
|   @staticmethod
 | |
|   def rlimit_fix():
 | |
|     soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
 | |
|     resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
 | |
|     if DEBUG>=2: print(f"IB: Increased fd limit from {soft} to {hard}")
 | |
| 
 | |
|   def alloc_iova(self, size:int, required_offset:int):
 | |
|     iova = round_up(self.next_iova - required_offset, IOVA_ALIGN) + required_offset
 | |
|     self.next_iova = iova + size
 | |
|     return iova
 | |
| 
 | |
|   def reg(self, buf:Buffer) -> tuple[int, ctypes._Pointer[ib.struct_ibv_mr]]:
 | |
|     buf = buf.base
 | |
|     if buf not in self.mrs:
 | |
|       if buf.nbytes > self.device_attr.max_mr_size: raise RuntimeError(f"Buffer too big: {buf.nbytes:#x} > {self.device_attr.max_mr_size:#x}")
 | |
|       if len(self.mrs) >= self.device_attr.max_mr: raise RuntimeError(f"Out of memory region cap: {len(self.mrs)} >= {self.device_attr.max_mr}")
 | |
|       # Local read is implied (but still have to create the memory region, except for short sends/writes with IBV_SEND_INLINE that are inlined by cpu)
 | |
|       mr_flags = ib.IBV_ACCESS_LOCAL_WRITE | ib.IBV_ACCESS_REMOTE_READ | ib.IBV_ACCESS_REMOTE_WRITE
 | |
|       match (dmaref:=buf.as_dmaref()):
 | |
|         case DMACPURef():
 | |
|           iova = self.alloc_iova(dmaref.size, dmaref.addr % IOVA_ALIGN)
 | |
|           mr = ib.ibv_reg_mr_iova2(self.pd, ctypes.c_void_p(dmaref.addr), dmaref.size, iova, mr_flags)
 | |
|         case DMAFdRef():
 | |
|           iova = self.alloc_iova(dmaref.size, dmaref.offset % IOVA_ALIGN)
 | |
|           mr = ib.ibv_reg_dmabuf_mr(self.pd, dmaref.offset, dmaref.size, iova, dmaref.fd, mr_flags)
 | |
|         case _: raise RuntimeError(f"Unknown type of dma ref: {dmaref}")
 | |
|       if not mr: raise RuntimeError(f"Couldn't register memory region for {buf} {dmaref} (errno={ctypes.get_errno()})")
 | |
|       self.mrs[buf] = (iova, mr, weakref.finalize(buf, ib.ibv_dereg_mr, mr))
 | |
|     return self.mrs[buf][0:2]
 | |
| 
 | |
class IBConn:
  """A Reliable Connection queue pair on an IBCtx, with its completion channel and completion queue.

  After construction the queue pair is in the INIT state and `gid` / `qp_num` hold the local identity to send to
  the remote side; `connect` then moves it through RTR to RTS and `rdma_write` posts RDMA writes on it.
  `pending_wrids` holds the ids of signaled work requests whose completion has not been retrieved yet.
  """
  def __init__(self, ctx:IBCtx):
    """Create the completion channel, completion queue and queue pair on `ctx` and move the queue pair to INIT.

    Raises:
      RuntimeError: one of the create calls returned NULL.
      AssertionError: the INIT transition returned non-zero (raised by checkz).
    """
    self.ctx = ctx

    # Handles are only stored once they are known to be non-NULL, so __del__ destroys exactly what was created

    # Create Completion Channel. It is a file descriptor that kernel sends notifications through, not a thing in infiniband spec, just linux-ism
    comp_channel = ib.ibv_create_comp_channel(self.ctx.ctx)
    if not comp_channel: raise RuntimeError(f"Couldn't create completion channel (errno={ctypes.get_errno()})")
    self.comp_channel = comp_channel
    # Create Completion Queue. When a Work Request with signaled flag is completed a Completion Queue Entry is pushed onto this queue
    cq = ib.ibv_create_cq(self.ctx.ctx, _capacity:=256, _cq_context:=None, self.comp_channel, _comp_vector:=0)
    if not cq: raise RuntimeError(f"Couldn't create completion queue (errno={ctypes.get_errno()})")
    self.cq = cq
    self.pending_wrids: set[int] = set()
    self.wrid_num: Iterator[int] = itertools.count(0) # wc_id is uint64, this will never overflow

    # Create Queue Pair. It's the closest thing to a socket in infiniband with QP num being the closest thing to a port, except it's allocated by hca
    qp_init_attrs_cap = ib.struct_ibv_qp_cap(max_send_wr=1024, max_recv_wr=64, max_send_sge=8, max_recv_sge=8, max_inline_data=64)
    qp_init_attrs = ib.struct_ibv_qp_init_attr(send_cq=self.cq, recv_cq=self.cq, cap=qp_init_attrs_cap, qp_type=ib.IBV_QPT_RC) # Reliable Connection
    qp = ib.ibv_create_qp(self.ctx.pd, ctypes.byref(qp_init_attrs))
    if not qp: raise RuntimeError(f"Couldn't create queue pair (errno={ctypes.get_errno()})")
    self.qp = qp
    # read back from the init attrs after creation; rdma_write uses max_send_wr from here as its flush threshold
    self.qp_cap = qp_init_attrs.cap

    # The most important thing about QPs is their state, when a new QP is created it's in the RESET state, before it can be properly used it has to go
    # through Init, Ready To Receive, Ready To Send. A good docs on QP state machine: https://www.rdmamojo.com/2012/05/05/qp-state-machine/

    # INIT
    qp_access_flags = ib.IBV_ACCESS_REMOTE_WRITE | ib.IBV_ACCESS_REMOTE_READ
    qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_INIT, port_num=DEFAULT_PORT, qp_access_flags=qp_access_flags)
    checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_PORT | ib.IBV_QP_ACCESS_FLAGS | ib.IBV_QP_PKEY_INDEX))

    self.gid, self.qp_num = bytes(self.ctx.gid_attr.raw), self.qp.contents.qp_num

  # Exchange GID and QP num with remote. At least in RoCEv2 gid can be guessed from remote's ip, QP num can't.

  def connect(self, remote_gid:bytes, remote_qp_num:int):
    """Move the queue pair to Ready To Receive and then Ready To Send, addressed at the remote's gid and QP num.

    Raises:
      AssertionError: one of the state transitions returned non-zero (raised by checkz).
    """
    # RTR
    qp_ah_attr_grh = ib.struct_ibv_global_route(hop_limit=1, dgid=ib.union_ibv_gid(raw=(ctypes.c_ubyte * 16)(*remote_gid)), sgid_index=DEFAULT_GID)
    qp_ah_attr = ib.struct_ibv_ah_attr(is_global=1, port_num=DEFAULT_PORT, grh=qp_ah_attr_grh)
    qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_RTR, path_mtu=ib.IBV_MTU_4096, dest_qp_num=remote_qp_num, rq_psn=0, max_dest_rd_atomic=1,
                                min_rnr_timer=12, ah_attr=qp_ah_attr)
    checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_PATH_MTU | ib.IBV_QP_DEST_QPN | ib.IBV_QP_RQ_PSN | \
                                          ib.IBV_QP_MAX_DEST_RD_ATOMIC | ib.IBV_QP_MIN_RNR_TIMER | ib.IBV_QP_AV))

    # RTS
    qpa = ib.struct_ibv_qp_attr(qp_state=ib.IBV_QPS_RTS, timeout=14, retry_cnt=7, rnr_retry=7, sq_psn=0, max_rd_atomic=1)
    checkz(ib.ibv_modify_qp(self.qp, qpa, ib.IBV_QP_STATE | ib.IBV_QP_TIMEOUT | ib.IBV_QP_RETRY_CNT | ib.IBV_QP_RNR_RETRY | ib.IBV_QP_SQ_PSN | \
                                          ib.IBV_QP_MAX_QP_RD_ATOMIC))

  def __del__(self):
    """Wait for outstanding work, then destroy the queue pair, completion queue and completion channel.

    Only handles that __init__ managed to create are destroyed, and they are destroyed even if the wait raises.
    """
    try:
      # need to wait for **everything** to complete before it's safe to dealloc queues and stuff
      # (nothing can have been posted unless the queue pair exists)
      if hasattr(self, "qp"): self.wait_cq()
    finally:
      if hasattr(self, "qp"): ib.ibv_destroy_qp(self.qp)
      if hasattr(self, "cq"): ib.ibv_destroy_cq(self.cq)
      if hasattr(self, "comp_channel"): ib.ibv_destroy_comp_channel(self.comp_channel)

  def next_wrid(self):
    """Take the next work request id, record it as pending and return it."""
    self.pending_wrids.add(wrid:=next(self.wrid_num))
    return wrid

  def wait_cq(self, wr_id: int|None=None):
    """Busy-poll the completion queue until `wr_id` is no longer pending, or until nothing is pending when `wr_id` is None.

    Raises:
      RuntimeError: polling itself failed, or a work request completed with a status other than IBV_WC_SUCCESS.
    """
    while (wr_id in self.pending_wrids) if wr_id is not None else self.pending_wrids:
      polled = self.ctx.ctx.contents.ops.poll_cq(self.cq, _num_entries:=1, ctypes.byref(wc:=ib.struct_ibv_wc()))
      # a negative count is a polling failure, `wc` holds no completion in that case
      if polled < 0: raise RuntimeError(f'Polling completion queue failed: {polled}')
      if polled == 0: continue
      if wc.status != ib.IBV_WC_SUCCESS:
        # this completion has been consumed, keeping its id pending would make every later wait_cq() spin forever.
        # discard rather than remove: rdma_write gives unsignaled work requests wr_id=0, which need not be pending
        self.pending_wrids.discard(wc.wr_id)
        raise RuntimeError(f'Work Request completed with error: wr_id={wc.wr_id} status={ib.ibv_wc_status__enumvalues.get(wc.status, wc.status)}')
      self.pending_wrids.remove(wc.wr_id)

  def rdma_write(self, sgl:list[SGE]):
    """Perform an RDMA write for every entry of `sgl` and return once all of them have completed.

    Entries larger than port_attr.max_msg_sz are split into several work requests. Work requests are chained and
    posted in batches of at most qp_cap.max_send_wr; only the last request of a batch is signaled and waited on.

    Raises:
      AssertionError: posting a batch returned non-zero (raised by checkz).
      RuntimeError: propagated from wait_cq.
    """
    swr: ctypes._Pointer[ib.struct_ibv_send_wr]|None = None
    swr_cnt = 0
    # id of the signaled work request of the chain being built. It is reserved only when that request is created,
    # so an id never ends up in pending_wrids without a matching posted request
    wr_id: int|None = None
    def _post():
      nonlocal swr, swr_cnt, wr_id
      if swr is None: return
      # The swr can be freed when this returns, the memory that sge points to can be unmapped after work completion is retrieved from cq
      ret = self.ctx.ctx.contents.ops.post_send(self.qp, swr, ctypes.byref(_bad_wr:=ctypes.POINTER(ib.struct_ibv_send_wr)()))
      # the signaled request is the last one in the chain, so a failed post means it was never queued and will never complete
      if ret != 0: self.pending_wrids.discard(wr_id)
      checkz(ret)
      # TODO: async
      self.wait_cq(wr_id)
      swr, swr_cnt, wr_id = None, 0, None
    # Everything is in reverse for elegant chaining
    for sg in reversed(sgl):
      # Message size limit (max 2GB per ib spec, 1GB on tinybox mellanoxes) applies to both scatter-gather entries and entire wrs
      for off in reversed(range(0, sg.size, self.ctx.port_attr.max_msg_sz)):
        # Scatter-Gather Entry for local memory
        sge = ctypes.pointer(ib.struct_ibv_sge(addr=sg.src_iova+off, length=min(sg.size-off, self.ctx.port_attr.max_msg_sz), lkey=sg.src_key))
        # RDMA struct for remote memory
        wr = ib.union_ibv_send_wr_wr(rdma=ib.struct_ibv_send_wr_1_rdma(remote_addr=sg.dst_iova+off, rkey=sg.dst_key))
        # Signal (with a freshly reserved work request id) if it's the last wr (first in the loop since it's reversed)
        if swr is None:
          wr_id = self.next_wrid()
          wid, flags = wr_id, ib.IBV_SEND_SIGNALED
        else:
          wid, flags = 0, 0
        # Create Send Request
        swr = ctypes.pointer(ib.struct_ibv_send_wr(opcode=ib.IBV_WR_RDMA_WRITE, sg_list=sge, num_sge=1, wr=wr, wr_id=wid, send_flags=flags, next=swr))
        # Flush if queue is being overrun
        if (swr_cnt:=swr_cnt + 1) >= self.qp_cap.max_send_wr: _post()
    _post()
 | |
| 
 |