@@ -4,7 +4,7 @@ import time
 import copy
 import heapq
 import signal
-from collections import Counter
+from collections import Counter, defaultdict
 from dataclasses import dataclass, field
 from itertools import islice
 from typing import Any
@@ -274,13 +274,13 @@ class ProcessContainer:
     assert self.rc and self.pm and self.sockets and self.process.proc
 
     output_msgs = []
-    with self.prefix, Timeout(self.cfg.timeout, error_msg=f"timed out testing process {repr(self.cfg.proc_name)}"):
-      end_of_cycle = True
-      if self.cfg.should_recv_callback is not None:
-        end_of_cycle = self.cfg.should_recv_callback(msg, self.cfg, self.cnt)
+    end_of_cycle = True
+    if self.cfg.should_recv_callback is not None:
+      end_of_cycle = self.cfg.should_recv_callback(msg, self.cfg, self.cnt)
 
-      self.msg_queue.append(msg)
-      if end_of_cycle:
+    self.msg_queue.append(msg)
+    if end_of_cycle:
+      with self.prefix, Timeout(self.cfg.timeout, error_msg=f"timed out testing process {repr(self.cfg.proc_name)}"):
         # call recv to let sub-sockets reconnect, after we know the process is ready
         if self.cnt == 0:
          for s in self.sockets:
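Note: as reconstructed above, this hunk narrows the scope of the Timeout so it only guards the recv cycle that can actually block on the process; queuing the incoming message and evaluating should_recv_callback now happen outside it. A minimal standalone sketch of that pattern, assuming a SIGALRM-style timeout context manager like openpilot's Timeout (the helper and function names below are illustrative, not the project's API):

    import signal
    from contextlib import contextmanager

    @contextmanager
    def step_timeout(seconds: int, error_msg: str = "timed out"):
      # Raise if the guarded block runs longer than `seconds` (SIGALRM-based, Unix main thread only).
      def handler(signum, frame):
        raise TimeoutError(error_msg)
      old_handler = signal.signal(signal.SIGALRM, handler)
      signal.alarm(seconds)
      try:
        yield
      finally:
        signal.alarm(0)
        signal.signal(signal.SIGALRM, old_handler)

    def run_step(msg, msg_queue, exchange_with_process, timeout_s=30):
      # Cheap bookkeeping stays outside the watchdog ...
      msg_queue.append(msg)
      # ... and only the exchange that can actually block on the process is guarded.
      with step_timeout(timeout_s, error_msg="timed out testing process"):
        return exchange_with_process(msg_queue)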
@@ -716,7 +716,9 @@ def _replay_multi_process(
   internal_pub_index_heap: list[tuple[int, int]] = []
 
   pbar = tqdm(total=len(external_pub_queue), disable=disable_progress)
+  times = defaultdict(list)
   while len(external_pub_queue) != 0 or (len(internal_pub_index_heap) != 0 and not all(c.has_empty_queue for c in containers)):
+    t = time.monotonic()
     if len(internal_pub_index_heap) == 0 or (len(external_pub_queue) != 0 and external_pub_queue[0].logMonoTime < internal_pub_index_heap[0][0]):
       msg = external_pub_queue.pop(0)
       pbar.update(1)
@@ -724,15 +726,26 @@
       _, index = heapq.heappop(internal_pub_index_heap)
       msg = internal_pub_queue[index]
 
+    # print(f'get msg took {time.monotonic() - t}s')
+    t = time.monotonic()
     target_containers = pubs_to_containers[msg.which()]
     for container in target_containers:
+      t1 = time.monotonic()
       output_msgs = container.run_step(msg, frs)
+      times[container.cfg.proc_name].append(time.monotonic() - t1)
       for m in output_msgs:
         if m.which() in all_pubs:
           internal_pub_queue.append(m)
           heapq.heappush(internal_pub_index_heap, (m.logMonoTime, len(internal_pub_queue) - 1))
       log_msgs.extend(output_msgs)
+      # print(f'run_step for {container.cfg.proc_name} took {time.monotonic() - t1}s')
+    # print(f'all run_steps took {time.monotonic() - t}s')
+
+  print("Average run_step times:")
+  for container, time_list in times.items():
+    print(f"  {container}: {sum(time_list)}s")
+  print('Total run_step time: {:.2f}s'.format(sum(sum(time_list) for time_list in times.values())))
 
   # flush last set of messages from each process
   for container in containers:
     last_time = log_msgs[-1].logMonoTime if len(log_msgs) > 0 else int(time.monotonic() * 1e9)
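For reference, the timing instrumentation added in the last two hunks follows a simple accumulate-then-summarize pattern: per-process run_step durations are appended to a defaultdict(list) around each call and summed once the replay loop finishes. A self-contained sketch of that pattern (the step functions and process names here are hypothetical, not openpilot's):

    import time
    from collections import defaultdict

    def replay(steps):
      # steps: iterable of (proc_name, step_fn) pairs, standing in for container.run_step
      times = defaultdict(list)
      for proc_name, step_fn in steps:
        t1 = time.monotonic()
        step_fn()
        times[proc_name].append(time.monotonic() - t1)

      print("Average run_step times:")
      for proc_name, time_list in times.items():
        # like the patch, this prints the per-process total; a true average
        # would be sum(time_list) / len(time_list)
        print(f"  {proc_name}: {sum(time_list):.2f}s")
      print('Total run_step time: {:.2f}s'.format(sum(sum(tl) for tl in times.values())))

    replay([("procA", lambda: time.sleep(0.01)), ("procB", lambda: time.sleep(0.02))])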