openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

551 lines
19 KiB

/**
* @file broker.cuh
* @brief Utility for multiprocess data exchange and synchronization.
*
* This file provides the KittensBroker class, which enables efficient inter-process
* communication and synchronization using POSIX shared memory, semaphores, and sockets.
* The broker is designed to work in multi-GPU environments where processes need to
* exchange data and synchronize execution across different local ranks.
*
* @note This implementation relies on POSIX IPC mechanisms and is intended for
* Unix-like systems. All processes must be running on the same node.
*/
#pragma once
#include <cerrno>
#include <cstdint>
#include <cstring>
#include <fcntl.h>
#include <semaphore.h>
#include <stdexcept>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/un.h>
#include <sys/uio.h>
#include <unistd.h>
#include <vector>
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
#error "KittensBroker is not supported on Windows"
#endif
namespace kittens {
namespace detail {
namespace broker {
static constexpr int MAX_LOCAL_WORLD_SIZE = 72;
static constexpr int VAULT_SIZE_PER_RANK = 64; // sizeof(cudaIpcMemHandle_t)
struct KittensVault {
static constexpr int INIT_CODE = 0x43617473; // "Cats"
int init;
int barrier;
int sense;
uint8_t data[MAX_LOCAL_WORLD_SIZE * VAULT_SIZE_PER_RANK];
};
static constexpr int SHM_SIZE = (sizeof(KittensVault) + 4095) / 4096 * 4096;
__host__ inline static void init_sync(
int local_rank,
volatile KittensVault *vault
) {
if (local_rank == 0) {
// initialize barrier resources
vault->barrier = 0;
vault->sense = 0;
__sync_synchronize(); // make previous writes visible
vault->init = KittensVault::INIT_CODE;
} else {
while (vault->init != KittensVault::INIT_CODE) usleep(1);
__sync_synchronize(); // see leader's previous writes
}
}
__host__ inline static void sync(
int local_world_size,
volatile KittensVault *vault
) {
if (vault->init != KittensVault::INIT_CODE)
throw std::runtime_error("KittensBroker: KittensVault not initialized");
// Phase 1
int arrived = __sync_add_and_fetch(&vault->barrier, 1);
if (arrived == local_world_size) vault->sense = 1;
while (!vault->sense) usleep(1);
// Make previous writes visible
__sync_synchronize();
// Phase 2
arrived = __sync_add_and_fetch(&vault->barrier, -1);
if (arrived == 0) vault->sense = 0;
while (vault->sense) usleep(1);
}
__host__ inline void *create_shm(const char *key, size_t size) {
int shm_fd;
shm_fd = shm_open(key, O_RDWR | O_CREAT | O_EXCL | O_CLOEXEC, 0600);
if (shm_fd < 0) {
if (errno == EEXIST)
throw std::runtime_error("KittensBroker: Named shared memory already exists");
throw std::runtime_error("KittensBroker: Failed to create shared memory");
}
if (ftruncate(shm_fd, size) != 0) {
shm_unlink(key);
close(shm_fd);
throw std::runtime_error("KittensBroker: Failed to truncate shared memory");
}
void *addr = mmap(0, size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0);
close(shm_fd);
if (addr == MAP_FAILED) {
shm_unlink(key);
throw std::runtime_error("KittensBroker: Failed to map to shared memory");
}
return addr;
}
__host__ inline void *open_shm(const char *key, size_t size) {
int shm_fd;
while (true) {
shm_fd = shm_open(key, O_RDWR | O_CLOEXEC, 0);
if (shm_fd >= 0)
break;
if (errno != ENOENT)
throw std::runtime_error("KittensBroker: Failed to open shared memory");
usleep(1);
}
struct stat shm_st;
do {
if (fstat(shm_fd, &shm_st) != 0) {
shm_unlink(key);
close(shm_fd);
throw std::runtime_error("KittensBroker: Failed to open shared memory stats");
}
usleep(1);
} while ((size_t)shm_st.st_size < size);
void *addr = mmap(0, size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0);
close(shm_fd);
if (addr == MAP_FAILED) {
shm_unlink(key);
throw std::runtime_error("KittensBroker: Failed to map to shared memory");
}
return addr;
}
__host__ inline void unlink_shm(const char *key) {
shm_unlink(key);
}
__host__ inline void unmap_shm(void *addr, size_t size) {
munmap(addr, size);
}
__host__ inline int create_socket(const char *key, int local_rank) {
int sock_fd;
if ((sock_fd = socket(AF_UNIX, SOCK_DGRAM | SOCK_CLOEXEC, 0)) < 0)
throw std::runtime_error("KittensBroker: Socket creation error");
struct sockaddr_un addr;
memset(&addr, 0, sizeof(addr));
addr.sun_family = AF_UNIX;
char unique_key[64];
int n = snprintf(unique_key, sizeof(unique_key), "%s%d", key, local_rank);
if (n < 0 || n >= (int)sizeof(unique_key)) {
close(sock_fd);
throw std::runtime_error("KittensBroker: Socket name too long");
}
size_t len = strnlen(unique_key, sizeof(addr.sun_path));
if (len > (sizeof(addr.sun_path) - 1)) {
close(sock_fd);
throw std::runtime_error("KittensBroker: Socket name too long");
}
strcpy(addr.sun_path, unique_key);
unlink(unique_key);
if (bind(sock_fd, (struct sockaddr *)&addr, SUN_LEN(&addr)) < 0) {
close(sock_fd);
throw std::runtime_error("KittensBroker: Failed to bind socket");
}
return sock_fd;
}
__host__ inline void send_fd(
int sock_fd,
int data_fd,
const char *dst_key,
int dst_local_rank,
int src_local_rank
) {
union {
struct cmsghdr cm;
char* control;
} control_un;
size_t sizeof_control = CMSG_SPACE(sizeof(int));
control_un.control = reinterpret_cast<char *>(malloc(sizeof_control));
if (!control_un.control) {
close(sock_fd);
close(data_fd);
throw std::runtime_error("KittensBroker: Failed to allocate a control buffer");
}
struct msghdr msg {};
msg.msg_control = control_un.control;
msg.msg_controllen = sizeof_control;
struct cmsghdr *cmptr = CMSG_FIRSTHDR(&msg);
cmptr->cmsg_len = CMSG_LEN(sizeof(int));
cmptr->cmsg_level = SOL_SOCKET;
cmptr->cmsg_type = SCM_RIGHTS;
memmove(CMSG_DATA(cmptr), &data_fd, sizeof(data_fd));
struct sockaddr_un addr {};
addr.sun_family = AF_UNIX;
char dst_unique_key[64];
int n = snprintf(dst_unique_key, sizeof(dst_unique_key), "%s%d", dst_key, dst_local_rank);
if (n < 0 || n >= (int)sizeof(dst_unique_key)) {
free(control_un.control);
close(sock_fd);
close(data_fd);
throw std::runtime_error("KittensBroker: dst path too long");
}
strcpy(addr.sun_path, dst_unique_key);
msg.msg_name = (void *)&addr;
msg.msg_namelen = sizeof(struct sockaddr_un);
int payload = src_local_rank;
struct iovec iov[1];
iov[0].iov_base = &payload;
iov[0].iov_len = sizeof(payload);
msg.msg_iov = iov;
msg.msg_iovlen = 1;
while (true) {
ssize_t sent = sendmsg(sock_fd, &msg, 0);
if (sent <= 0) {
if (errno == EINTR) continue;
close(sock_fd);
close(data_fd);
free(control_un.control);
throw std::runtime_error("KittensBroker: Failed to send FD over socket");
}
break;
}
free(control_un.control);
}
__host__ inline void recv_fd(int sock_fd, int *data_fd, int *src_local_rank) {
union {
struct cmsghdr cm;
char* control;
} control_un;
size_t sizeof_control = CMSG_SPACE(sizeof(int));
control_un.control = reinterpret_cast<char *>(malloc(sizeof_control));
if (!control_un.control) {
close(sock_fd);
throw std::runtime_error("KittensBroker: Failed to allocate a control buffer");
}
struct msghdr msg {};
msg.msg_control = control_un.control;
msg.msg_controllen = sizeof_control;
int payload = -1;
struct iovec iov[1];
iov[0].iov_base = &payload;
iov[0].iov_len = sizeof(payload);
msg.msg_iov = iov;
msg.msg_iovlen = 1;
while (true) {
ssize_t received = recvmsg(sock_fd, &msg, 0);
if (received < 0 && errno == EINTR) {
msg.msg_controllen = sizeof_control;
msg.msg_iovlen = 1;
continue;
}
if (received < static_cast<ssize_t>(sizeof(*data_fd))) {
free(control_un.control);
close(sock_fd);
throw std::runtime_error("KittensBroker: Failed to receive data over socket");
}
break;
}
if (msg.msg_flags & MSG_CTRUNC) {
free(control_un.control);
close(sock_fd);
throw std::runtime_error("KittensBroker: Control data truncated");
}
struct cmsghdr *cmptr = CMSG_FIRSTHDR(&msg);
if (!cmptr ||
cmptr->cmsg_len != CMSG_LEN(sizeof(int)) ||
cmptr->cmsg_level != SOL_SOCKET ||
cmptr->cmsg_type != SCM_RIGHTS) {
free(control_un.control);
close(sock_fd);
throw std::runtime_error("KittensBroker: Failed to receive data over socket");
}
memmove(data_fd, CMSG_DATA(cmptr), sizeof(*data_fd));
free(control_un.control);
*src_local_rank = payload;
}
__host__ inline void unlink_socket(const char *key, int local_rank) {
char unique_key[64];
int n = snprintf(unique_key, sizeof(unique_key), "%s%d", key, local_rank);
if (n < 0 || n >= (int)sizeof(unique_key))
throw std::runtime_error("KittensBroker: Socket name too long");
unlink(unique_key);
}
__host__ inline void close_socket(int sock_fd) {
close(sock_fd);
}
} // namespace broker
} // namespace detail
/**
@brief KittensBroker utility for multiprocess data exchange.
Note that the code relies on POSIX sockets/shared memory/semaphores for
inter-process communication and synchronization.
The main functions meant to be used by the user are:
KittensBroker broker(local_rank, local_world_size);
broker.exchange_data(dst, src, size); // exchange data between all processes
broker.exchange_fds(dst, src_fd); // exchange file descriptors between all processes
broker.broadcast_fd(dst, src_fd, src_rank); // broadcast file descriptor from src_rank to all processes
broker.sync(); // wait until all processes reach here
*/
struct KittensBroker {
// TODO: make unique per process group
static inline constexpr const char *SHM_KEY_ = "/kittens_broker_shm";
static inline constexpr const char *SOCK_KEY_ = "/tmp/kittens_broker.sock";
int local_rank_;
int local_world_size_;
void *shm_raw_;
volatile detail::broker::KittensVault *shm_;
int sock_;
__host__ inline KittensBroker(int local_rank, int local_world_size)
: local_rank_(local_rank),
local_world_size_(local_world_size),
shm_raw_(nullptr),
shm_(nullptr),
sock_(-1) {
if (local_rank_ < 0)
throw std::runtime_error("KittensBroker: Local rank must be non-negative");
if (local_rank_ >= local_world_size_)
throw std::runtime_error("KittensBroker: Local rank is greater than local world size");
if (local_world_size_ > detail::broker::MAX_LOCAL_WORLD_SIZE)
throw std::runtime_error("KittensBroker: Local world size is greater than MAX_LOCAL_WORLD_SIZE");
if (local_rank_ == 0) {
shm_raw_ = detail::broker::create_shm(SHM_KEY_, sizeof(detail::broker::KittensVault));
shm_ = reinterpret_cast<volatile detail::broker::KittensVault *>(shm_raw_);
memset(shm_raw_, 0, sizeof(detail::broker::KittensVault));
} else {
shm_raw_ = detail::broker::open_shm(SHM_KEY_, sizeof(detail::broker::KittensVault));
shm_ = reinterpret_cast<volatile detail::broker::KittensVault *>(shm_raw_);
}
detail::broker::init_sync(local_rank_, shm_);
detail::broker::sync(local_world_size_, shm_);
if (local_rank_ ==0)
detail::broker::unlink_shm(SHM_KEY_);
detail::broker::sync(local_world_size_, shm_);
sock_ = detail::broker::create_socket(SOCK_KEY_, local_rank_);
detail::broker::sync(local_world_size_, shm_);
}
KittensBroker(const KittensBroker&) = delete;
KittensBroker& operator=(const KittensBroker&) = delete;
__host__ inline KittensBroker(KittensBroker&& other) noexcept
: local_rank_(other.local_rank_),
local_world_size_(other.local_world_size_),
shm_raw_(other.shm_raw_),
shm_(other.shm_),
sock_(other.sock_) {
other.local_rank_ = -1;
other.local_world_size_ = -1;
other.shm_raw_ = nullptr;
other.shm_ = nullptr;
other.sock_ = -1;
}
__host__ inline void destroy() {
if (shm_raw_) {
detail::broker::unmap_shm(shm_raw_, sizeof(detail::broker::KittensVault));
shm_raw_ = nullptr;
shm_ = nullptr;
}
if (sock_ >= 0) {
detail::broker::unlink_socket(SOCK_KEY_, local_rank_);
detail::broker::close_socket(sock_);
sock_ = -1;
}
local_rank_ = -1;
local_world_size_ = -1;
}
__host__ inline KittensBroker& operator=(KittensBroker&& other) noexcept {
if (this != &other) {
destroy();
local_rank_ = other.local_rank_;
local_world_size_ = other.local_world_size_;
shm_raw_ = other.shm_raw_;
shm_ = other.shm_;
sock_ = other.sock_;
other.local_rank_ = -1;
other.local_world_size_ = -1;
other.shm_raw_ = nullptr;
other.shm_ = nullptr;
other.sock_ = -1;
}
return *this;
}
__host__ inline ~KittensBroker() {
destroy();
}
__host__ inline void sync(int num_ranks = -1) {
if (num_ranks == -1)
num_ranks = local_world_size_;
else if (num_ranks < 0 || num_ranks > local_world_size_)
throw std::runtime_error("KittensBroker: Invalid number of ranks");
detail::broker::sync(num_ranks, shm_);
}
__host__ inline void exchange_data(void *dst_, const void *src_, size_t size) {
if (size > detail::broker::VAULT_SIZE_PER_RANK)
throw std::runtime_error("KittensBroker: Size is greater than VAULT_SIZE_PER_RANK");
uint8_t *dst = reinterpret_cast<uint8_t *>(dst_);
const uint8_t *src = reinterpret_cast<const uint8_t *>(src_);
// Exchange data
sync(); // ensure all processes enter together
memcpy(const_cast<uint8_t *>(shm_->data) + local_rank_ * detail::broker::VAULT_SIZE_PER_RANK, src, size);
sync(); // ensure all processes exit together
// Pack and copy back to destination
for (int i = 0; i < local_world_size_; i++)
memcpy(dst + i * size, const_cast<uint8_t *>(shm_->data) + i * detail::broker::VAULT_SIZE_PER_RANK, size);
}
__host__ inline void exchange_fds(int *dst, const int data_fd) {
if (dst == nullptr)
throw std::runtime_error("KittensBroker: dst is null");
if (data_fd < 0)
throw std::runtime_error("KittensBroker: source fd is negative");
// Initialize dst buffer
for (int i = 0; i < local_world_size_; ++i)
dst[i] = -1;
// Ensure all processes enter together
sync();
if (local_rank_ == 0) {
// Rank 0 receives all FDs from and distributes them to other ranks
dst[0] = data_fd;
for (int i = 0; i < local_world_size_ - 1; i++) {
int received_fd;
int src_local_rank;
detail::broker::recv_fd(sock_, &received_fd, &src_local_rank);
if (received_fd < 0)
throw std::runtime_error("KittensBroker: Failed to receive FD over socket");
if (src_local_rank == local_rank_)
throw std::runtime_error("KittensBroker: Invalid source rank");
dst[src_local_rank] = received_fd;
}
for (int dst_local_rank = 1; dst_local_rank < local_world_size_; dst_local_rank++) {
for (int src_local_rank = 0; src_local_rank < local_world_size_; src_local_rank++) {
if (dst_local_rank == src_local_rank)
continue;
detail::broker::send_fd(sock_, dst[src_local_rank], SOCK_KEY_, dst_local_rank, src_local_rank);
}
}
close(dst[0]); // no longer needed
dst[0] = -1;
} else {
// The rest sends its FD to and receives the other FDs from rank 0
detail::broker::send_fd(sock_, data_fd, SOCK_KEY_, 0, local_rank_);
close(data_fd); // no longer needed
for (int i = 0; i < local_world_size_ - 1; i++) {
int received_fd;
int src_local_rank;
detail::broker::recv_fd(sock_, &received_fd, &src_local_rank);
if (received_fd < 0)
throw std::runtime_error("KittensBroker: Failed to receive FD over socket");
if (src_local_rank == local_rank_)
throw std::runtime_error("KittensBroker: Invalid source rank");
dst[src_local_rank] = received_fd;
}
}
// Ensure all processes exit together
sync();
}
__host__ inline void broadcast_fd(int *dst, const int data_fd, const int src_local_rank) {
if (src_local_rank < 0 || src_local_rank >= local_world_size_)
throw std::runtime_error("KittensBroker: Invalid source rank");
// Ensure all processes enter together
sync();
if (local_rank_ == src_local_rank) {
if (data_fd < 0)
throw std::runtime_error("KittensBroker: Source rank has invalid FD");
for (int dst_local_rank = 0; dst_local_rank < local_world_size_; dst_local_rank++) {
if (dst_local_rank == src_local_rank)
continue;
detail::broker::send_fd(sock_, data_fd, SOCK_KEY_, dst_local_rank, src_local_rank);
}
close(data_fd); // no longer needed
} else {
if (!dst)
throw std::runtime_error("KittensBroker: Destination rank has invalid buffer");
int _src_local_rank;
detail::broker::recv_fd(sock_, dst, &_src_local_rank);
if (*dst < 0)
throw std::runtime_error("KittensBroker: Failed to receive valid FD over socket");
if (_src_local_rank != src_local_rank)
throw std::runtime_error("KittensBroker: Invalid source rank");
}
// Ensure all processes exit together
sync();
}
};
} // namespace kittens