-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Route tuner and assessor commands to 2 seperate queues #891
Changes from all commits
fa40ef7
79fc529
2ba38a2
dcc19d7
43bb9e8
9074f84
7846d93
74584e1
8af3f8b
ffa6623
0d94a8e
d005790
5321ed3
7f2ab19
f2a2f2e
556f8ee
bd5084a
b17b15f
f56c765
2ca2fc6
fc1ece0
8a21c97
96fceb0
cc425e5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,72 +19,123 @@ | |
# ================================================================================================== | ||
|
||
#import json_tricks | ||
import logging | ||
import os | ||
from queue import Queue | ||
import sys | ||
|
||
import threading | ||
import logging | ||
from multiprocessing.dummy import Pool as ThreadPool | ||
|
||
from queue import Queue, Empty | ||
import json_tricks | ||
|
||
from .common import init_logger, multi_thread_enabled | ||
from .recoverable import Recoverable | ||
from .protocol import CommandType, receive | ||
|
||
init_logger('dispatcher.log') | ||
_logger = logging.getLogger(__name__) | ||
|
||
QUEUE_LEN_WARNING_MARK = 20 | ||
_worker_fast_exit_on_terminate = True | ||
|
||
class MsgDispatcherBase(Recoverable): | ||
def __init__(self): | ||
if multi_thread_enabled(): | ||
self.pool = ThreadPool() | ||
self.thread_results = [] | ||
else: | ||
self.stopping = False | ||
self.default_command_queue = Queue() | ||
self.assessor_command_queue = Queue() | ||
self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,)) | ||
self.assessor_worker = threading.Thread(target=self.command_queue_worker, args=(self.assessor_command_queue,)) | ||
self.default_worker.start() | ||
self.assessor_worker.start() | ||
self.worker_exceptions = [] | ||
|
||
def run(self): | ||
"""Run the tuner. | ||
This function will never return unless raise. | ||
""" | ||
_logger.info('Start dispatcher') | ||
mode = os.getenv('NNI_MODE') | ||
if mode == 'resume': | ||
self.load_checkpoint() | ||
|
||
while True: | ||
_logger.debug('waiting receive_message') | ||
command, data = receive() | ||
if data: | ||
data = json_tricks.loads(data) | ||
|
||
if command is None or command is CommandType.Terminate: | ||
break | ||
if multi_thread_enabled(): | ||
result = self.pool.map_async(self.handle_request_thread, [(command, data)]) | ||
result = self.pool.map_async(self.process_command_thread, [(command, data)]) | ||
self.thread_results.append(result) | ||
if any([thread_result.ready() and not thread_result.successful() for thread_result in self.thread_results]): | ||
_logger.debug('Caught thread exception') | ||
break | ||
else: | ||
self.handle_request((command, data)) | ||
self.enqueue_command(command, data) | ||
|
||
_logger.info('Dispatcher exiting...') | ||
self.stopping = True | ||
if multi_thread_enabled(): | ||
self.pool.close() | ||
self.pool.join() | ||
else: | ||
self.default_worker.join() | ||
self.assessor_worker.join() | ||
|
||
_logger.info('Terminated by NNI manager') | ||
|
||
def handle_request_thread(self, request): | ||
def command_queue_worker(self, command_queue): | ||
"""Process commands in command queues. | ||
""" | ||
while True: | ||
try: | ||
# set timeout to ensure self.stopping is checked periodically | ||
command, data = command_queue.get(timeout=3) | ||
try: | ||
self.process_command(command, data) | ||
except Exception as e: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shall we only catch the expected exception? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The process_command calls Tuner or Assessor code, we can not know the exception types it will raise, if we do not capture them, they will be raise to nowhere because this code path is in a worker thread. |
||
_logger.exception(e) | ||
self.worker_exceptions.append(e) | ||
break | ||
except Empty: | ||
pass | ||
if self.stopping and (_worker_fast_exit_on_terminate or command_queue.empty()): | ||
break | ||
|
||
def enqueue_command(self, command, data): | ||
"""Enqueue command into command queues | ||
""" | ||
if command == CommandType.TrialEnd or (command == CommandType.ReportMetricData and data['type'] == 'PERIODICAL'): | ||
self.assessor_command_queue.put((command, data)) | ||
else: | ||
self.default_command_queue.put((command, data)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please make sure the queue put is thread-safe, because both main thread and assessor thread will put record to this queue There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, merge for integration test, will check this later. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just checked: Reference: https://github.com/python/cpython/blob/3.7/Lib/queue.py self.not_full lock is used with put method. |
||
|
||
qsize = self.default_command_queue.qsize() | ||
if qsize >= QUEUE_LEN_WARNING_MARK: | ||
_logger.warning('default queue length: %d', qsize) | ||
|
||
qsize = self.assessor_command_queue.qsize() | ||
if qsize >= QUEUE_LEN_WARNING_MARK: | ||
_logger.warning('assessor queue length: %d', qsize) | ||
|
||
def process_command_thread(self, request): | ||
"""Worker thread to process a command. | ||
""" | ||
command, data = request | ||
if multi_thread_enabled(): | ||
try: | ||
self.handle_request(request) | ||
self.process_command(command, data) | ||
except Exception as e: | ||
_logger.exception(str(e)) | ||
raise | ||
else: | ||
pass | ||
|
||
def handle_request(self, request): | ||
command, data = request | ||
|
||
_logger.debug('handle request: command: [{}], data: [{}]'.format(command, data)) | ||
|
||
if data: | ||
data = json_tricks.loads(data) | ||
def process_command(self, command, data): | ||
_logger.debug('process_command: command: [{}], data: [{}]'.format(command, data)) | ||
|
||
command_handlers = { | ||
# Tunner commands: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -76,8 +76,8 @@ def run(dispatch_type): | |
dipsatcher_list = TUNER_LIST if dispatch_type == 'Tuner' else ASSESSOR_LIST | ||
for dispatcher_name in dipsatcher_list: | ||
try: | ||
# sleep 5 seconds here, to make sure previous stopped exp has enough time to exit to avoid port conflict | ||
time.sleep(5) | ||
# Sleep here to make sure previous stopped exp has enough time to exit to avoid port conflict | ||
time.sleep(6) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why change this? shall we change the comments at the same time? "# sleep 5 seconds here, to make sure previous stopped exp has enough time to exit to avoid port conflict".... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, updated the comments. |
||
test_builtin_dispatcher(dispatch_type, dispatcher_name) | ||
print(GREEN + 'Test %s %s: TEST PASS' % (dispatcher_name, dispatch_type) + CLEAR) | ||
except Exception as error: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why change this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same reason as above to accommodate the queue change.