You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
			
				
					353 lines
				
				11 KiB
			
		
		
			
		
	
	
					353 lines
				
				11 KiB
			| 
											6 years ago
										 | import functools
 | ||
|  | import threading
 | ||
|  | import inspect
 | ||
|  | import sys
 | ||
|  | import select
 | ||
|  | import struct
 | ||
|  | from math import sqrt
 | ||
|  | from collections import OrderedDict, deque
 | ||
|  | from time import time
 | ||
|  | 
 | ||
|  | from tools.lib.pollable_queue import PollableQueue, Empty, Full, ExistentialError
 | ||
|  | 
 | ||
|  | EndSentinel = object()
 | ||
|  | 
 | ||
|  | 
 | ||
|  | def _sync_inner_generator(input_queue, *args, **kwargs):
 | ||
|  |   func = args[0]
 | ||
|  |   args = args[1:]
 | ||
|  | 
 | ||
|  |   get = input_queue.get
 | ||
|  |   while True:
 | ||
|  |     item = get()
 | ||
|  |     if item is EndSentinel:
 | ||
|  |       return
 | ||
|  | 
 | ||
|  |     cookie, value = item
 | ||
|  |     yield cookie, func(value, *args, **kwargs)
 | ||
|  | 
 | ||
|  | 
 | ||
|  | def _async_streamer_async_inner(input_queue, output_queue, generator_func, args, kwargs):
 | ||
|  |   put = output_queue.put
 | ||
|  |   put_end = True
 | ||
|  |   try:
 | ||
|  |     g = generator_func(input_queue, *args, **kwargs)
 | ||
|  |     for item in g:
 | ||
|  |       put((time(), item))
 | ||
|  |     g.close()
 | ||
|  |   except ExistentialError:
 | ||
|  |     put_end = False
 | ||
|  |     raise
 | ||
|  |   finally:
 | ||
|  |     if put_end:
 | ||
|  |       put((None, EndSentinel))
 | ||
|  | 
 | ||
|  | def _running_mean_var(ltc_stats, x):
 | ||
|  |   old_mean, var = ltc_stats
 | ||
|  |   mean = min(600., 0.98 * old_mean + 0.02 * x)
 | ||
|  |   var = min(5., max(0.1, 0.98 * var + 0.02 * (mean - x) * (old_mean - x)))
 | ||
|  |   return mean, var
 | ||
|  | 
 | ||
|  | def _find_next_resend(sent_messages, ltc_stats):
 | ||
|  |   if not sent_messages:
 | ||
|  |     return None, None
 | ||
|  | 
 | ||
|  |   oldest_sent_idx = sent_messages._OrderedDict__root[1][2]
 | ||
|  |   send_time, _ = sent_messages[oldest_sent_idx]
 | ||
|  | 
 | ||
|  |   # Assume message has been lost if it is >10 standard deviations from mean.
 | ||
|  |   mean, var = ltc_stats
 | ||
|  |   next_resend_time = send_time + mean + 40. * sqrt(var)
 | ||
|  | 
 | ||
|  |   return oldest_sent_idx, next_resend_time
 | ||
|  | 
 | ||
|  | 
 | ||
|  | def _do_cleanup(input_queue, output_queue, num_workers, sentinels_received, num_outstanding):
 | ||
|  |   input_fd = input_queue.put_fd()
 | ||
|  |   output_fd = output_queue.get_fd()
 | ||
|  | 
 | ||
|  |   poller = select.epoll()
 | ||
|  |   poller.register(input_fd, select.EPOLLOUT)
 | ||
|  |   poller.register(output_fd, select.EPOLLIN)
 | ||
|  | 
 | ||
|  |   remaining_outputs = []
 | ||
|  |   end_sentinels_to_send = num_workers - sentinels_received
 | ||
|  |   while sentinels_received < num_workers:
 | ||
|  |     evts = dict(poller.poll(-1 if num_outstanding > 0 else 10.))
 | ||
|  |     if not evts:
 | ||
|  |       # Workers aren't responding, crash.
 | ||
|  |       break
 | ||
|  | 
 | ||
|  |     if output_fd in evts:
 | ||
|  |       _, maybe_sentinel = output_queue.get()
 | ||
|  |       if maybe_sentinel is EndSentinel:
 | ||
|  |         sentinels_received += 1
 | ||
|  |       else:
 | ||
|  |         remaining_outputs.append(maybe_sentinel[1])
 | ||
|  |         num_outstanding -= 1
 | ||
|  | 
 | ||
|  |     if input_fd in evts:
 | ||
|  |       if end_sentinels_to_send > 0:
 | ||
|  |         input_queue.put_nowait(EndSentinel)
 | ||
|  |         end_sentinels_to_send -= 1
 | ||
|  |       else:
 | ||
|  |         poller.modify(input_fd, 0)
 | ||
|  | 
 | ||
|  |   # TODO: Raise an exception when a queue thread raises one.
 | ||
|  |   assert sentinels_received == num_workers, (sentinels_received, num_workers)
 | ||
|  |   assert output_queue.empty()
 | ||
|  |   return remaining_outputs
 | ||
|  | 
 | ||
|  | def _generate_results(input_stream, input_queue, worker_output_queue, output_queue,
 | ||
|  |                       num_workers, max_outstanding):
 | ||
|  |   pack_cookie = struct.pack
 | ||
|  | 
 | ||
|  |   # Maps idx -> (send_time, input)
 | ||
|  |   sent_messages = OrderedDict()
 | ||
|  |   oldest_sent_idx = None
 | ||
|  |   next_resend_time = None
 | ||
|  |   ltc_stats = 5., 10.
 | ||
|  | 
 | ||
|  |   # Maps idx -> result
 | ||
|  |   received_messages = {}
 | ||
|  |   next_out = 0
 | ||
|  | 
 | ||
|  |   # Start things off by pulling the first value.
 | ||
|  |   next_in_item = next(input_stream, EndSentinel)
 | ||
|  |   inputs_remain = next_in_item is not EndSentinel
 | ||
|  |   sentinels_received = 0
 | ||
|  | 
 | ||
|  |   input_fd = input_queue.put_fd()
 | ||
|  |   worker_output_fd = worker_output_queue.get_fd()
 | ||
|  |   output_fd = output_queue.put_fd()
 | ||
|  | 
 | ||
|  |   poller = select.epoll()
 | ||
|  |   poller.register(input_fd, select.EPOLLOUT)
 | ||
|  |   poller.register(worker_output_fd, select.EPOLLIN)
 | ||
|  |   poller.register(output_fd, 0)
 | ||
|  | 
 | ||
|  |   # Keep sending/retrying until the input stream and sent messages are all done.
 | ||
|  |   while sentinels_received < num_workers and (inputs_remain or sent_messages):
 | ||
|  |     if max_outstanding:
 | ||
|  |       can_send_new = (len(sent_messages) < max_outstanding and
 | ||
|  |                       len(received_messages) < max_outstanding and inputs_remain)
 | ||
|  |     else:
 | ||
|  |       can_send_new = inputs_remain
 | ||
|  | 
 | ||
|  |     if (next_resend_time and now >= next_resend_time) or can_send_new:
 | ||
|  |       poller.modify(input_fd, select.EPOLLOUT)
 | ||
|  |     else:
 | ||
|  |       poller.modify(input_fd, 0)
 | ||
|  | 
 | ||
|  |     if next_resend_time:
 | ||
|  |       t = max(0, next_resend_time - now)
 | ||
|  |       evts = dict(poller.poll(t))
 | ||
|  |     else:
 | ||
|  |       evts = dict(poller.poll())
 | ||
|  |     now = time()
 | ||
|  | 
 | ||
|  |     if output_fd in evts:
 | ||
|  |       output_queue.put_nowait(received_messages.pop(next_out))
 | ||
|  |       next_out += 1
 | ||
|  | 
 | ||
|  |       if next_out not in received_messages:
 | ||
|  |         poller.modify(output_fd, 0)
 | ||
|  | 
 | ||
|  |     if worker_output_fd in evts:
 | ||
|  |       for receive_time, maybe_sentinel in worker_output_queue.get_multiple_nowait():
 | ||
|  |         # Check for EndSentinel in case of worker crash.
 | ||
|  |         if maybe_sentinel is EndSentinel:
 | ||
|  |           sentinels_received += 1
 | ||
|  |           continue
 | ||
|  |         idx_bytes, value = maybe_sentinel
 | ||
|  |         idx = struct.unpack("<Q", idx_bytes)[0]
 | ||
|  | 
 | ||
|  |         # Don't store duplicates.
 | ||
|  |         sent_message = sent_messages.pop(idx, None)
 | ||
|  |         if sent_message is not None:
 | ||
|  |           received_messages[idx] = value
 | ||
|  |           ltc_stats = _running_mean_var(ltc_stats, receive_time - sent_message[0])
 | ||
|  |           if idx == oldest_sent_idx:
 | ||
|  |             oldest_sent_idx, next_resend_time = _find_next_resend(sent_messages, ltc_stats)
 | ||
|  | 
 | ||
|  |           if idx == next_out:
 | ||
|  |             poller.modify(output_fd, select.EPOLLOUT)
 | ||
|  |         else:
 | ||
|  |           if oldest_sent_idx is not None:
 | ||
|  |             # The message was resent, so it is at least old as the oldest tracked message.
 | ||
|  |             ltc_stats = _running_mean_var(ltc_stats, now - sent_messages[oldest_sent_idx][0])
 | ||
|  |     elif input_fd in evts:
 | ||
|  |       if can_send_new:
 | ||
|  |         # We're under the limit, get the next input.
 | ||
|  |         send_idx, send_value = next_in_item
 | ||
|  |         input_queue.put_nowait((pack_cookie("<Q", send_idx), send_value))
 | ||
|  |         sent_messages[next_in_item[0]] = now, next_in_item[1]
 | ||
|  |         next_in_item = next(input_stream, EndSentinel)
 | ||
|  |         inputs_remain = next_in_item is not EndSentinel
 | ||
|  | 
 | ||
|  |         if oldest_sent_idx is None:
 | ||
|  |           oldest_sent_idx, next_resend_time = _find_next_resend(sent_messages, ltc_stats)
 | ||
|  |       else:
 | ||
|  |         # Move the resent item to the end.
 | ||
|  |         send_time, resend_input = sent_messages.pop(oldest_sent_idx)
 | ||
|  | 
 | ||
|  |         sys.stdout.write("Resending {} (ltc, mean, var) = ({}, {}, {})\n".format(
 | ||
|  |           oldest_sent_idx, now - send_time, ltc_stats[0], ltc_stats[1]))
 | ||
|  |         input_queue.put_nowait((pack_cookie("<Q", oldest_sent_idx), resend_input))
 | ||
|  | 
 | ||
|  |         sent_messages[oldest_sent_idx] = now, resend_input
 | ||
|  |         oldest_sent_idx, next_resend_time = _find_next_resend(sent_messages, ltc_stats)
 | ||
|  | 
 | ||
|  |   # Return remaining messages.
 | ||
|  |   while next_out in received_messages:
 | ||
|  |     output_queue.put(received_messages.pop(next_out))
 | ||
|  |     next_out += 1
 | ||
|  | 
 | ||
|  |   _do_cleanup(input_queue, worker_output_queue, num_workers, sentinels_received, 0)
 | ||
|  |   output_queue.put(EndSentinel)
 | ||
|  | 
 | ||
|  | 
 | ||
|  | def _generate_results_unreliable(input_stream, input_queue, worker_output_queue,
 | ||
|  |                                  output_queue, num_workers, max_outstanding_unused):
 | ||
|  |   # Start things off by pulling the first value.
 | ||
|  |   next_in_item = next(input_stream, EndSentinel)
 | ||
|  |   inputs_remain = next_in_item is not EndSentinel
 | ||
|  | 
 | ||
|  |   # TODO: Use heapq to return oldest message.
 | ||
|  |   received_messages = deque()
 | ||
|  |   pack_cookie = struct.pack
 | ||
|  | 
 | ||
|  |   input_fd = input_queue.put_fd()
 | ||
|  |   worker_output_fd = worker_output_queue.get_fd()
 | ||
|  |   output_fd = output_queue.put_fd()
 | ||
|  | 
 | ||
|  |   poller = select.epoll()
 | ||
|  |   poller.register(input_fd, select.EPOLLOUT)
 | ||
|  |   poller.register(worker_output_fd, select.EPOLLIN)
 | ||
|  |   poller.register(output_fd, 0)
 | ||
|  | 
 | ||
|  |   # Keep sending/retrying until the input stream and sent messages are all done.
 | ||
|  |   num_outstanding = 0
 | ||
|  |   sentinels_received = 0
 | ||
|  |   while sentinels_received < num_workers and (inputs_remain or received_messages):
 | ||
|  |     # Avoid poll() if we can easily detect that there is work to do.
 | ||
|  |     evts = (input_fd if inputs_remain and not input_queue.full() else 0, output_fd
 | ||
|  |             if not output_queue.full() and len(received_messages) else 0, worker_output_fd
 | ||
|  |             if not worker_output_queue.empty() else 0)
 | ||
|  |     if all(evt == 0 for evt in evts):
 | ||
|  |       evts = dict(poller.poll())
 | ||
|  | 
 | ||
|  |     if output_fd in evts:
 | ||
|  |       output_queue.put(received_messages.pop())
 | ||
|  | 
 | ||
|  |       if len(received_messages) == 0:
 | ||
|  |         poller.modify(output_fd, 0)
 | ||
|  | 
 | ||
|  |     if worker_output_fd in evts:
 | ||
|  |       for receive_time, maybe_sentinel in worker_output_queue.get_multiple():
 | ||
|  |         # Check for EndSentinel in case of worker crash.
 | ||
|  |         if maybe_sentinel is EndSentinel:
 | ||
|  |           sentinels_received += 1
 | ||
|  |           continue
 | ||
|  |         received_messages.appendleft(maybe_sentinel[1])
 | ||
|  |         num_outstanding -= 1
 | ||
|  |       poller.modify(output_fd, select.EPOLLOUT)
 | ||
|  | 
 | ||
|  |     if input_fd in evts:
 | ||
|  |       # We're under the limit, get the next input.
 | ||
|  |       send_idx, send_value = next_in_item
 | ||
|  |       input_queue.put((pack_cookie("<Q", send_idx), send_value))
 | ||
|  |       next_in_item = next(input_stream, EndSentinel)
 | ||
|  |       inputs_remain = next_in_item is not EndSentinel
 | ||
|  |       num_outstanding += 1
 | ||
|  | 
 | ||
|  |       if not inputs_remain:
 | ||
|  |         poller.modify(input_fd, 0)
 | ||
|  | 
 | ||
|  |   # TODO: Track latency even though we don't retry.
 | ||
|  |   for value in _do_cleanup(input_queue, worker_output_queue, num_workers,
 | ||
|  |                            sentinels_received, num_outstanding):
 | ||
|  |     output_queue.put(value)
 | ||
|  |   output_queue.put(EndSentinel)
 | ||
|  | 
 | ||
|  | 
 | ||
|  | def _async_generator(func, max_workers, in_q_size, out_q_size, max_outstanding,
 | ||
|  |                      async_inner, reliable):
 | ||
|  |   if async_inner:
 | ||
|  |     assert inspect.isgeneratorfunction(
 | ||
|  |       func), "async_inner == True but {} is not a generator".format(func)
 | ||
|  | 
 | ||
|  |   @functools.wraps(func)
 | ||
|  |   def wrapper(input_sequence_or_self, *args, **kwargs):
 | ||
|  |     # HACK: Determine whether the first arg is "self". ismethod returns False here.
 | ||
|  |     if inspect.getargspec(func).args[0] == "self":
 | ||
|  |       inner_func = func.__get__(input_sequence_or_self, type(input_sequence_or_self))
 | ||
|  |       input_sequence = args[0]
 | ||
|  |       args = args[1:]
 | ||
|  |     else:
 | ||
|  |       inner_func = func
 | ||
|  |       input_sequence = input_sequence_or_self
 | ||
|  |     input_stream = enumerate(iter(input_sequence))
 | ||
|  | 
 | ||
|  |     if reliable:
 | ||
|  |       generate_func = _generate_results
 | ||
|  |     else:
 | ||
|  |       generate_func = _generate_results_unreliable
 | ||
|  | 
 | ||
|  |     input_queue = PollableQueue(in_q_size)
 | ||
|  |     worker_output_queue = PollableQueue(8 * max_workers)
 | ||
|  |     output_queue = PollableQueue(out_q_size)
 | ||
|  | 
 | ||
|  |     # Start the worker threads.
 | ||
|  |     if async_inner:
 | ||
|  |       generator_func = inner_func
 | ||
|  |     else:
 | ||
|  |       args = (inner_func,) + args
 | ||
|  |       generator_func = _sync_inner_generator
 | ||
|  | 
 | ||
|  |     worker_threads = []
 | ||
|  |     for _ in range(max_workers):
 | ||
|  |       t = threading.Thread(
 | ||
|  |         target=_async_streamer_async_inner,
 | ||
|  |         args=(input_queue, worker_output_queue, generator_func, args, kwargs))
 | ||
|  |       t.daemon = True
 | ||
|  |       t.start()
 | ||
|  |       worker_threads.append(t)
 | ||
|  | 
 | ||
|  |     master_thread = threading.Thread(
 | ||
|  |       target=generate_func,
 | ||
|  |       args=(input_stream, input_queue, worker_output_queue, output_queue, max_workers,
 | ||
|  |             max_outstanding))
 | ||
|  |     master_thread.daemon = True
 | ||
|  |     master_thread.start()
 | ||
|  | 
 | ||
|  |     try:
 | ||
|  |       while True:
 | ||
|  |         for value in output_queue.get_multiple():
 | ||
|  |           if value is EndSentinel:
 | ||
|  |             return
 | ||
|  |           else:
 | ||
|  |             yield value
 | ||
|  |     finally:
 | ||
|  |       # Make sure work is done and the threads are stopped.
 | ||
|  |       for t in worker_threads:
 | ||
|  |         t.join(1)
 | ||
|  |       master_thread.join(1)
 | ||
|  | 
 | ||
|  |       input_queue.close()
 | ||
|  |       worker_output_queue.close()
 | ||
|  |       output_queue.close()
 | ||
|  | 
 | ||
|  |   return wrapper
 | ||
|  | 
 | ||
|  | 
 | ||
|  | def async_generator(max_workers=1,
 | ||
|  |                     in_q_size=10,
 | ||
|  |                     out_q_size=12,
 | ||
|  |                     max_outstanding=10000,
 | ||
|  |                     async_inner=False,
 | ||
|  |                     reliable=True):
 | ||
|  |   return (
 | ||
|  |     lambda f: _async_generator(f, max_workers, in_q_size, out_q_size, max_outstanding, async_inner, reliable)
 | ||
|  |   )
 |