# include <iostream>
# include <cassert>
# include <cerrno>
# include <cmath>
# include <cstring>
# include <cstdint>
# include <chrono>
# include <algorithm>
# include <cstdlib>
# include <csignal>
# include <random>
# include <poll.h>
# include <sys/ioctl.h>
# include <sys/mman.h>
# include <sys/stat.h>
# include <sys/types.h>
# include <sys/syscall.h>
# include <fcntl.h>
# include <unistd.h>
# include <stdio.h>
# include "msgq.hpp"
void sigusr2_handler ( int signal ) {
assert ( signal = = SIGUSR2 ) ;
}
uint64_t msgq_get_uid ( void ) {
std : : random_device rd ( " /dev/urandom " ) ;
std : : uniform_int_distribution < uint64_t > distribution ( 0 , std : : numeric_limits < uint32_t > : : max ( ) ) ;
uint64_t uid = distribution ( rd ) < < 32 | syscall ( SYS_gettid ) ;
return uid ;
}
int msgq_msg_init_size ( msgq_msg_t * msg , size_t size ) {
msg - > size = size ;
msg - > data = new ( std : : nothrow ) char [ size ] ;
return ( msg - > data = = NULL ) ? - 1 : 0 ;
}
int msgq_msg_init_data ( msgq_msg_t * msg , char * data , size_t size ) {
int r = msgq_msg_init_size ( msg , size ) ;
if ( r = = 0 )
memcpy ( msg - > data , data , size ) ;
return r ;
}
int msgq_msg_close ( msgq_msg_t * msg ) {
if ( msg - > size > 0 )
delete [ ] msg - > data ;
msg - > size = 0 ;
return 0 ;
}
void msgq_reset_reader ( msgq_queue_t * q ) {
int id = q - > reader_id ;
q - > read_valids [ id ] - > store ( true ) ;
q - > read_pointers [ id ] - > store ( * q - > write_pointer ) ;
}
void msgq_wait_for_subscriber ( msgq_queue_t * q ) {
while ( * q - > num_readers = = 0 ) {
;
}
return ;
}
int msgq_new_queue ( msgq_queue_t * q , const char * path , size_t size ) {
assert ( size < 0xFFFFFFFF ) ; // Buffer must be smaller than 2^32 bytes
std : : signal ( SIGUSR2 , sigusr2_handler ) ;
const char * prefix = " /dev/shm/ " ;
char * full_path = new char [ strlen ( path ) + strlen ( prefix ) + 1 ] ;
strcpy ( full_path , prefix ) ;
strcat ( full_path , path ) ;
auto fd = open ( full_path , O_RDWR | O_CREAT , 0777 ) ;
delete [ ] full_path ;
if ( fd < 0 )
return - 1 ;
int rc = ftruncate ( fd , size + sizeof ( msgq_header_t ) ) ;
if ( rc < 0 )
return - 1 ;
char * mem = ( char * ) mmap ( NULL , size + sizeof ( msgq_header_t ) , PROT_READ | PROT_WRITE , MAP_SHARED , fd , 0 ) ;
close ( fd ) ;
if ( mem = = NULL )
return - 1 ;
q - > mmap_p = mem ;
msgq_header_t * header = ( msgq_header_t * ) mem ;
// Setup pointers to header segment
q - > num_readers = reinterpret_cast < std : : atomic < uint64_t > * > ( & header - > num_readers ) ;
q - > write_pointer = reinterpret_cast < std : : atomic < uint64_t > * > ( & header - > write_pointer ) ;
q - > write_uid = reinterpret_cast < std : : atomic < uint64_t > * > ( & header - > write_uid ) ;
for ( size_t i = 0 ; i < NUM_READERS ; i + + ) {
q - > read_pointers [ i ] = reinterpret_cast < std : : atomic < uint64_t > * > ( & header - > read_pointers [ i ] ) ;
q - > read_valids [ i ] = reinterpret_cast < std : : atomic < uint64_t > * > ( & header - > read_valids [ i ] ) ;
q - > read_uids [ i ] = reinterpret_cast < std : : atomic < uint64_t > * > ( & header - > read_uids [ i ] ) ;
}
q - > data = mem + sizeof ( msgq_header_t ) ;
q - > size = size ;
q - > reader_id = - 1 ;
q - > endpoint = path ;
q - > read_conflate = false ;
return 0 ;
}
void msgq_close_queue ( msgq_queue_t * q ) {
if ( q - > mmap_p ! = NULL ) {
munmap ( q - > mmap_p , q - > size + sizeof ( msgq_header_t ) ) ;
}
}
void msgq_init_publisher ( msgq_queue_t * q ) {
//std::cout << "Starting publisher" << std::endl;
uint64_t uid = msgq_get_uid ( ) ;
* q - > write_uid = uid ;
* q - > num_readers = 0 ;
for ( size_t i = 0 ; i < NUM_READERS ; i + + ) {
* q - > read_valids [ i ] = false ;
* q - > read_uids [ i ] = 0 ;
}
q - > write_uid_local = uid ;
}
static void thread_signal ( uint32_t tid ) {
# ifndef SYS_tkill
// TODO: this won't work for multithreaded programs
kill ( tid , SIGUSR2 ) ;
# else
syscall ( SYS_tkill , tid , SIGUSR2 ) ;
# endif
}
void msgq_init_subscriber ( msgq_queue_t * q ) {
assert ( q ! = NULL ) ;
assert ( q - > num_readers ! = NULL ) ;
uint64_t uid = msgq_get_uid ( ) ;
// Get reader id
while ( true ) {
uint64_t cur_num_readers = * q - > num_readers ;
uint64_t new_num_readers = cur_num_readers + 1 ;
// No more slots available. Reset all subscribers to kick out inactive ones
if ( new_num_readers > NUM_READERS ) {
std : : cout < < " Warning, evicting all subscribers! " < < std : : endl ;
* q - > num_readers = 0 ;
for ( size_t i = 0 ; i < NUM_READERS ; i + + ) {
* q - > read_valids [ i ] = false ;
uint64_t old_uid = * q - > read_uids [ i ] ;
* q - > read_uids [ i ] = 0 ;
// Wake up reader in case they are in a poll
thread_signal ( old_uid & 0xFFFFFFFF ) ;
}
continue ;
}
// Use atomic compare and swap to handle race condition
// where two subscribers start at the same time
if ( std : : atomic_compare_exchange_strong ( q - > num_readers ,
& cur_num_readers ,
new_num_readers ) ) {
q - > reader_id = cur_num_readers ;
q - > read_uid_local = uid ;
// We start with read_valid = false,
// on the first read the read pointer will be synchronized with the write pointer
* q - > read_valids [ cur_num_readers ] = false ;
* q - > read_pointers [ cur_num_readers ] = 0 ;
* q - > read_uids [ cur_num_readers ] = uid ;
break ;
}
}
//std::cout << "New subscriber id: " << q->reader_id << " uid: " << q->read_uid_local << " " << q->endpoint << std::endl;
msgq_reset_reader ( q ) ;
}
int msgq_msg_send ( msgq_msg_t * msg , msgq_queue_t * q ) {
// Die if we are no longer the active publisher
if ( q - > write_uid_local ! = * q - > write_uid ) {
std : : cout < < " Killing old publisher: " < < q - > endpoint < < std : : endl ;
errno = EADDRINUSE ;
return - 1 ;
}
uint64_t total_msg_size = ALIGN ( msg - > size + sizeof ( int64_t ) ) ;
// We need to fit at least three messages in the queue,
// then we can always safely access the last message
assert ( 3 * total_msg_size < = q - > size ) ;
uint64_t num_readers = * q - > num_readers ;
uint32_t write_cycles , write_pointer ;
UNPACK64 ( write_cycles , write_pointer , * q - > write_pointer ) ;
char * p = q - > data + write_pointer ; // add base offset
// Check remaining space
// Always leave space for a wraparound tag for the next message, including alignment
int64_t remaining_space = q - > size - write_pointer - total_msg_size - sizeof ( int64_t ) ;
if ( remaining_space < = 0 ) {
// Write -1 size tag indicating wraparound
* ( int64_t * ) p = - 1 ;
// Invalidate all readers that are beyond the write pointer
// TODO: should we handle the case where a new reader shows up while this is running?
for ( uint64_t i = 0 ; i < num_readers ; i + + ) {
uint64_t read_pointer = * q - > read_pointers [ i ] ;
uint64_t read_cycles = read_pointer > > 32 ;
read_pointer & = 0xFFFFFFFF ;
if ( ( read_pointer > write_pointer ) & & ( read_cycles ! = write_cycles ) ) {
* q - > read_valids [ i ] = false ;
}
}
// Update global and local copies of write pointer and write_cycles
write_pointer = 0 ;
write_cycles = write_cycles + 1 ;
PACK64 ( * q - > write_pointer , write_cycles , write_pointer ) ;
// Set actual pointer to the beginning of the data segment
p = q - > data ;
}
// Invalidate readers that are in the area that will be written
uint64_t start = write_pointer ;
uint64_t end = ALIGN ( start + sizeof ( int64_t ) + msg - > size ) ;
for ( uint64_t i = 0 ; i < num_readers ; i + + ) {
uint32_t read_cycles , read_pointer ;
UNPACK64 ( read_cycles , read_pointer , * q - > read_pointers [ i ] ) ;
if ( ( read_pointer > = start ) & & ( read_pointer < end ) & & ( read_cycles ! = write_cycles ) ) {
* q - > read_valids [ i ] = false ;
}
}
// Write size tag
std : : atomic < int64_t > * size_p = reinterpret_cast < std : : atomic < int64_t > * > ( p ) ;
* size_p = msg - > size ;
// Copy data
memcpy ( p + sizeof ( int64_t ) , msg - > data , msg - > size ) ;
__sync_synchronize ( ) ;
// Update write pointer
uint32_t new_ptr = ALIGN ( write_pointer + msg - > size + sizeof ( int64_t ) ) ;
PACK64 ( * q - > write_pointer , write_cycles , new_ptr ) ;
// Notify readers
for ( uint64_t i = 0 ; i < num_readers ; i + + ) {
uint64_t reader_uid = * q - > read_uids [ i ] ;
thread_signal ( reader_uid & 0xFFFFFFFF ) ;
}
return msg - > size ;
}
int msgq_msg_ready ( msgq_queue_t * q ) {
start :
int id = q - > reader_id ;
assert ( id > = 0 ) ; // Make sure subscriber is initialized
if ( q - > read_uid_local ! = * q - > read_uids [ id ] ) {
std : : cout < < q - > endpoint < < " : Reader was evicted, reconnecting " < < std : : endl ;
msgq_init_subscriber ( q ) ;
goto start ;
}
// Check valid
if ( ! * q - > read_valids [ id ] ) {
msgq_reset_reader ( q ) ;
goto start ;
}
uint32_t read_cycles , read_pointer ;
UNPACK64 ( read_cycles , read_pointer , * q - > read_pointers [ id ] ) ;
uint32_t write_cycles , write_pointer ;
UNPACK64 ( write_cycles , write_pointer , * q - > write_pointer ) ;
// Check if new message is available
return ( read_pointer ! = write_pointer ) ;
}
int msgq_msg_recv ( msgq_msg_t * msg , msgq_queue_t * q ) {
start :
int id = q - > reader_id ;
assert ( id > = 0 ) ; // Make sure subscriber is initialized
if ( q - > read_uid_local ! = * q - > read_uids [ id ] ) {
std : : cout < < q - > endpoint < < " : Reader was evicted, reconnecting " < < std : : endl ;
msgq_init_subscriber ( q ) ;
goto start ;
}
// Check valid
if ( ! * q - > read_valids [ id ] ) {
msgq_reset_reader ( q ) ;
goto start ;
}
uint32_t read_cycles , read_pointer ;
UNPACK64 ( read_cycles , read_pointer , * q - > read_pointers [ id ] ) ;
uint32_t write_cycles , write_pointer ;
UNPACK64 ( write_cycles , write_pointer , * q - > write_pointer ) ;
char * p = q - > data + read_pointer ;
// Check if new message is available
if ( read_pointer = = write_pointer ) {
msg - > size = 0 ;
return 0 ;
}
// Read potential message size
std : : atomic < int64_t > * size_p = reinterpret_cast < std : : atomic < int64_t > * > ( p ) ;
std : : int64_t size = * size_p ;
// Check if the size that was read is valid
if ( ! * q - > read_valids [ id ] ) {
msgq_reset_reader ( q ) ;
goto start ;
}
// If size is -1 the buffer was full, and we need to wrap around
if ( size = = - 1 ) {
read_cycles + + ;
PACK64 ( * q - > read_pointers [ id ] , read_cycles , 0 ) ;
goto start ;
}
// crashing is better than passing garbage data to the consumer
// the size will have weird value if it was overwritten by data accidentally
assert ( ( uint64_t ) size < q - > size ) ;
assert ( size > 0 ) ;
uint32_t new_read_pointer = ALIGN ( read_pointer + sizeof ( std : : int64_t ) + size ) ;
// If conflate is true, check if this is the latest message, else start over
if ( q - > read_conflate ) {
if ( new_read_pointer ! = write_pointer ) {
// Update read pointer
PACK64 ( * q - > read_pointers [ id ] , read_cycles , new_read_pointer ) ;
goto start ;
}
}
// Copy message
if ( msgq_msg_init_size ( msg , size ) < 0 )
return - 1 ;
__sync_synchronize ( ) ;
memcpy ( msg - > data , p + sizeof ( int64_t ) , size ) ;
__sync_synchronize ( ) ;
// Update read pointer
PACK64 ( * q - > read_pointers [ id ] , read_cycles , new_read_pointer ) ;
// Check if the actual data that was copied is valid
if ( ! * q - > read_valids [ id ] ) {
msgq_msg_close ( msg ) ;
msgq_reset_reader ( q ) ;
goto start ;
}
return msg - > size ;
}
int msgq_poll ( msgq_pollitem_t * items , size_t nitems , int timeout ) {
assert ( timeout > = 0 ) ;
int num = 0 ;
// Check if messages ready
for ( size_t i = 0 ; i < nitems ; i + + ) {
items [ i ] . revents = msgq_msg_ready ( items [ i ] . q ) ;
if ( items [ i ] . revents ) num + + ;
}
int ms = ( timeout = = - 1 ) ? 100 : timeout ;
struct timespec ts ;
ts . tv_sec = ms / 1000 ;
ts . tv_nsec = ( ms % 1000 ) * 1000 * 1000 ;
while ( num = = 0 ) {
int ret ;
ret = nanosleep ( & ts , & ts ) ;
// Check if messages ready
for ( size_t i = 0 ; i < nitems ; i + + ) {
if ( items [ i ] . revents = = 0 & & msgq_msg_ready ( items [ i ] . q ) ) {
num + = 1 ;
items [ i ] . revents = 1 ;
}
}
// exit if we had a timeout and the sleep finished
if ( timeout ! = - 1 & & ret = = 0 ) {
break ;
}
}
return num ;
}