You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
352 lines
11 KiB
352 lines
11 KiB
import functools
|
|
import threading
|
|
import inspect
|
|
import sys
|
|
import select
|
|
import struct
|
|
from math import sqrt
|
|
from collections import OrderedDict, deque
|
|
from time import time
|
|
|
|
from tools.lib.pollable_queue import PollableQueue, Empty, Full, ExistentialError
|
|
|
|
EndSentinel = object()
|
|
|
|
|
|
def _sync_inner_generator(input_queue, *args, **kwargs):
|
|
func = args[0]
|
|
args = args[1:]
|
|
|
|
get = input_queue.get
|
|
while True:
|
|
item = get()
|
|
if item is EndSentinel:
|
|
return
|
|
|
|
cookie, value = item
|
|
yield cookie, func(value, *args, **kwargs)
|
|
|
|
|
|
def _async_streamer_async_inner(input_queue, output_queue, generator_func, args, kwargs):
|
|
put = output_queue.put
|
|
put_end = True
|
|
try:
|
|
g = generator_func(input_queue, *args, **kwargs)
|
|
for item in g:
|
|
put((time(), item))
|
|
g.close()
|
|
except ExistentialError:
|
|
put_end = False
|
|
raise
|
|
finally:
|
|
if put_end:
|
|
put((None, EndSentinel))
|
|
|
|
def _running_mean_var(ltc_stats, x):
|
|
old_mean, var = ltc_stats
|
|
mean = min(600., 0.98 * old_mean + 0.02 * x)
|
|
var = min(5., max(0.1, 0.98 * var + 0.02 * (mean - x) * (old_mean - x)))
|
|
return mean, var
|
|
|
|
def _find_next_resend(sent_messages, ltc_stats):
|
|
if not sent_messages:
|
|
return None, None
|
|
|
|
oldest_sent_idx = sent_messages._OrderedDict__root[1][2]
|
|
send_time, _ = sent_messages[oldest_sent_idx]
|
|
|
|
# Assume message has been lost if it is >10 standard deviations from mean.
|
|
mean, var = ltc_stats
|
|
next_resend_time = send_time + mean + 40. * sqrt(var)
|
|
|
|
return oldest_sent_idx, next_resend_time
|
|
|
|
|
|
def _do_cleanup(input_queue, output_queue, num_workers, sentinels_received, num_outstanding):
|
|
input_fd = input_queue.put_fd()
|
|
output_fd = output_queue.get_fd()
|
|
|
|
poller = select.epoll()
|
|
poller.register(input_fd, select.EPOLLOUT)
|
|
poller.register(output_fd, select.EPOLLIN)
|
|
|
|
remaining_outputs = []
|
|
end_sentinels_to_send = num_workers - sentinels_received
|
|
while sentinels_received < num_workers:
|
|
evts = dict(poller.poll(-1 if num_outstanding > 0 else 10.))
|
|
if not evts:
|
|
# Workers aren't responding, crash.
|
|
break
|
|
|
|
if output_fd in evts:
|
|
_, maybe_sentinel = output_queue.get()
|
|
if maybe_sentinel is EndSentinel:
|
|
sentinels_received += 1
|
|
else:
|
|
remaining_outputs.append(maybe_sentinel[1])
|
|
num_outstanding -= 1
|
|
|
|
if input_fd in evts:
|
|
if end_sentinels_to_send > 0:
|
|
input_queue.put_nowait(EndSentinel)
|
|
end_sentinels_to_send -= 1
|
|
else:
|
|
poller.modify(input_fd, 0)
|
|
|
|
# TODO: Raise an exception when a queue thread raises one.
|
|
assert sentinels_received == num_workers, (sentinels_received, num_workers)
|
|
assert output_queue.empty()
|
|
return remaining_outputs
|
|
|
|
def _generate_results(input_stream, input_queue, worker_output_queue, output_queue,
|
|
num_workers, max_outstanding):
|
|
pack_cookie = struct.pack
|
|
|
|
# Maps idx -> (send_time, input)
|
|
sent_messages = OrderedDict()
|
|
oldest_sent_idx = None
|
|
next_resend_time = None
|
|
ltc_stats = 5., 10.
|
|
|
|
# Maps idx -> result
|
|
received_messages = {}
|
|
next_out = 0
|
|
|
|
# Start things off by pulling the first value.
|
|
next_in_item = next(input_stream, EndSentinel)
|
|
inputs_remain = next_in_item is not EndSentinel
|
|
sentinels_received = 0
|
|
|
|
input_fd = input_queue.put_fd()
|
|
worker_output_fd = worker_output_queue.get_fd()
|
|
output_fd = output_queue.put_fd()
|
|
|
|
poller = select.epoll()
|
|
poller.register(input_fd, select.EPOLLOUT)
|
|
poller.register(worker_output_fd, select.EPOLLIN)
|
|
poller.register(output_fd, 0)
|
|
|
|
# Keep sending/retrying until the input stream and sent messages are all done.
|
|
while sentinels_received < num_workers and (inputs_remain or sent_messages):
|
|
if max_outstanding:
|
|
can_send_new = (len(sent_messages) < max_outstanding and
|
|
len(received_messages) < max_outstanding and inputs_remain)
|
|
else:
|
|
can_send_new = inputs_remain
|
|
|
|
if (next_resend_time and now >= next_resend_time) or can_send_new:
|
|
poller.modify(input_fd, select.EPOLLOUT)
|
|
else:
|
|
poller.modify(input_fd, 0)
|
|
|
|
if next_resend_time:
|
|
t = max(0, next_resend_time - now)
|
|
evts = dict(poller.poll(t))
|
|
else:
|
|
evts = dict(poller.poll())
|
|
now = time()
|
|
|
|
if output_fd in evts:
|
|
output_queue.put_nowait(received_messages.pop(next_out))
|
|
next_out += 1
|
|
|
|
if next_out not in received_messages:
|
|
poller.modify(output_fd, 0)
|
|
|
|
if worker_output_fd in evts:
|
|
for receive_time, maybe_sentinel in worker_output_queue.get_multiple_nowait():
|
|
# Check for EndSentinel in case of worker crash.
|
|
if maybe_sentinel is EndSentinel:
|
|
sentinels_received += 1
|
|
continue
|
|
idx_bytes, value = maybe_sentinel
|
|
idx = struct.unpack("<Q", idx_bytes)[0]
|
|
|
|
# Don't store duplicates.
|
|
sent_message = sent_messages.pop(idx, None)
|
|
if sent_message is not None:
|
|
received_messages[idx] = value
|
|
ltc_stats = _running_mean_var(ltc_stats, receive_time - sent_message[0])
|
|
if idx == oldest_sent_idx:
|
|
oldest_sent_idx, next_resend_time = _find_next_resend(sent_messages, ltc_stats)
|
|
|
|
if idx == next_out:
|
|
poller.modify(output_fd, select.EPOLLOUT)
|
|
else:
|
|
if oldest_sent_idx is not None:
|
|
# The message was resent, so it is at least old as the oldest tracked message.
|
|
ltc_stats = _running_mean_var(ltc_stats, now - sent_messages[oldest_sent_idx][0])
|
|
elif input_fd in evts:
|
|
if can_send_new:
|
|
# We're under the limit, get the next input.
|
|
send_idx, send_value = next_in_item
|
|
input_queue.put_nowait((pack_cookie("<Q", send_idx), send_value))
|
|
sent_messages[next_in_item[0]] = now, next_in_item[1]
|
|
next_in_item = next(input_stream, EndSentinel)
|
|
inputs_remain = next_in_item is not EndSentinel
|
|
|
|
if oldest_sent_idx is None:
|
|
oldest_sent_idx, next_resend_time = _find_next_resend(sent_messages, ltc_stats)
|
|
else:
|
|
# Move the resent item to the end.
|
|
send_time, resend_input = sent_messages.pop(oldest_sent_idx)
|
|
|
|
sys.stdout.write("Resending {} (ltc, mean, var) = ({}, {}, {})\n".format(
|
|
oldest_sent_idx, now - send_time, ltc_stats[0], ltc_stats[1]))
|
|
input_queue.put_nowait((pack_cookie("<Q", oldest_sent_idx), resend_input))
|
|
|
|
sent_messages[oldest_sent_idx] = now, resend_input
|
|
oldest_sent_idx, next_resend_time = _find_next_resend(sent_messages, ltc_stats)
|
|
|
|
# Return remaining messages.
|
|
while next_out in received_messages:
|
|
output_queue.put(received_messages.pop(next_out))
|
|
next_out += 1
|
|
|
|
_do_cleanup(input_queue, worker_output_queue, num_workers, sentinels_received, 0)
|
|
output_queue.put(EndSentinel)
|
|
|
|
|
|
def _generate_results_unreliable(input_stream, input_queue, worker_output_queue,
|
|
output_queue, num_workers, max_outstanding_unused):
|
|
# Start things off by pulling the first value.
|
|
next_in_item = next(input_stream, EndSentinel)
|
|
inputs_remain = next_in_item is not EndSentinel
|
|
|
|
# TODO: Use heapq to return oldest message.
|
|
received_messages = deque()
|
|
pack_cookie = struct.pack
|
|
|
|
input_fd = input_queue.put_fd()
|
|
worker_output_fd = worker_output_queue.get_fd()
|
|
output_fd = output_queue.put_fd()
|
|
|
|
poller = select.epoll()
|
|
poller.register(input_fd, select.EPOLLOUT)
|
|
poller.register(worker_output_fd, select.EPOLLIN)
|
|
poller.register(output_fd, 0)
|
|
|
|
# Keep sending/retrying until the input stream and sent messages are all done.
|
|
num_outstanding = 0
|
|
sentinels_received = 0
|
|
while sentinels_received < num_workers and (inputs_remain or received_messages):
|
|
# Avoid poll() if we can easily detect that there is work to do.
|
|
evts = (input_fd if inputs_remain and not input_queue.full() else 0, output_fd
|
|
if not output_queue.full() and len(received_messages) else 0, worker_output_fd
|
|
if not worker_output_queue.empty() else 0)
|
|
if all(evt == 0 for evt in evts):
|
|
evts = dict(poller.poll())
|
|
|
|
if output_fd in evts:
|
|
output_queue.put(received_messages.pop())
|
|
|
|
if len(received_messages) == 0:
|
|
poller.modify(output_fd, 0)
|
|
|
|
if worker_output_fd in evts:
|
|
for receive_time, maybe_sentinel in worker_output_queue.get_multiple():
|
|
# Check for EndSentinel in case of worker crash.
|
|
if maybe_sentinel is EndSentinel:
|
|
sentinels_received += 1
|
|
continue
|
|
received_messages.appendleft(maybe_sentinel[1])
|
|
num_outstanding -= 1
|
|
poller.modify(output_fd, select.EPOLLOUT)
|
|
|
|
if input_fd in evts:
|
|
# We're under the limit, get the next input.
|
|
send_idx, send_value = next_in_item
|
|
input_queue.put((pack_cookie("<Q", send_idx), send_value))
|
|
next_in_item = next(input_stream, EndSentinel)
|
|
inputs_remain = next_in_item is not EndSentinel
|
|
num_outstanding += 1
|
|
|
|
if not inputs_remain:
|
|
poller.modify(input_fd, 0)
|
|
|
|
# TODO: Track latency even though we don't retry.
|
|
for value in _do_cleanup(input_queue, worker_output_queue, num_workers,
|
|
sentinels_received, num_outstanding):
|
|
output_queue.put(value)
|
|
output_queue.put(EndSentinel)
|
|
|
|
|
|
def _async_generator(func, max_workers, in_q_size, out_q_size, max_outstanding,
|
|
async_inner, reliable):
|
|
if async_inner:
|
|
assert inspect.isgeneratorfunction(
|
|
func), "async_inner == True but {} is not a generator".format(func)
|
|
|
|
@functools.wraps(func)
|
|
def wrapper(input_sequence_or_self, *args, **kwargs):
|
|
# HACK: Determine whether the first arg is "self". ismethod returns False here.
|
|
if inspect.getargspec(func).args[0] == "self":
|
|
inner_func = func.__get__(input_sequence_or_self, type(input_sequence_or_self))
|
|
input_sequence = args[0]
|
|
args = args[1:]
|
|
else:
|
|
inner_func = func
|
|
input_sequence = input_sequence_or_self
|
|
input_stream = enumerate(iter(input_sequence))
|
|
|
|
if reliable:
|
|
generate_func = _generate_results
|
|
else:
|
|
generate_func = _generate_results_unreliable
|
|
|
|
input_queue = PollableQueue(in_q_size)
|
|
worker_output_queue = PollableQueue(8 * max_workers)
|
|
output_queue = PollableQueue(out_q_size)
|
|
|
|
# Start the worker threads.
|
|
if async_inner:
|
|
generator_func = inner_func
|
|
else:
|
|
args = (inner_func,) + args
|
|
generator_func = _sync_inner_generator
|
|
|
|
worker_threads = []
|
|
for _ in range(max_workers):
|
|
t = threading.Thread(
|
|
target=_async_streamer_async_inner,
|
|
args=(input_queue, worker_output_queue, generator_func, args, kwargs))
|
|
t.daemon = True
|
|
t.start()
|
|
worker_threads.append(t)
|
|
|
|
master_thread = threading.Thread(
|
|
target=generate_func,
|
|
args=(input_stream, input_queue, worker_output_queue, output_queue, max_workers,
|
|
max_outstanding))
|
|
master_thread.daemon = True
|
|
master_thread.start()
|
|
|
|
try:
|
|
while True:
|
|
for value in output_queue.get_multiple():
|
|
if value is EndSentinel:
|
|
return
|
|
else:
|
|
yield value
|
|
finally:
|
|
# Make sure work is done and the threads are stopped.
|
|
for t in worker_threads:
|
|
t.join(1)
|
|
master_thread.join(1)
|
|
|
|
input_queue.close()
|
|
worker_output_queue.close()
|
|
output_queue.close()
|
|
|
|
return wrapper
|
|
|
|
|
|
def async_generator(max_workers=1,
|
|
in_q_size=10,
|
|
out_q_size=12,
|
|
max_outstanding=10000,
|
|
async_inner=False,
|
|
reliable=True):
|
|
return (
|
|
lambda f: _async_generator(f, max_workers, in_q_size, out_q_size, max_outstanding, async_inner, reliable)
|
|
)
|
|
|