Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Route tuner and assessor commands to 2 separate queues #891

Merged
merged 24 commits into from
Mar 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion docs/en_US/ExperimentConfig.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,17 @@ machineList:
* __classArgs__

__classArgs__ specifies the arguments of tuner algorithm.
* __gpuNum__

* __gpuNum__

__gpuNum__ specifies the gpu number to run the tuner process. The value of this field should be a positive number.

Note: users can specify the tuner in only one way, for example, by setting either {tunerName, optimizationMode} or {tunerCommand, tunerCwd}; they cannot set both.

* __includeIntermediateResults__

If __includeIntermediateResults__ is true, the last intermediate result of a trial that is early stopped by the assessor is sent to the tuner as the final result. The default value of __includeIntermediateResults__ is false.

* __assessor__

* Description
Expand Down
1 change: 1 addition & 0 deletions src/nni_manager/common/manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ interface ExperimentParams {
classFileName?: string;
checkpointDir: string;
gpuNum?: number;
includeIntermediateResults?: boolean;
};
assessor?: {
className: string;
Expand Down
8 changes: 7 additions & 1 deletion src/nni_manager/core/nnimanager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,17 @@ class NNIManager implements Manager {
newCwd = cwd;
}
// TO DO: add CUDA_VISIBLE_DEVICES
let includeIntermediateResultsEnv: boolean | undefined = false;
if (this.experimentProfile.params.tuner !== undefined) {
includeIntermediateResultsEnv = this.experimentProfile.params.tuner.includeIntermediateResults;
}

let nniEnv = {
NNI_MODE: mode,
NNI_CHECKPOINT_DIRECTORY: dataDirectory,
NNI_LOG_DIRECTORY: getLogDir(),
NNI_LOG_LEVEL: getLogLevel()
NNI_LOG_LEVEL: getLogLevel(),
NNI_INCLUDE_INTERMEDIATE_RESULTS: includeIntermediateResultsEnv
};
let newEnv = Object.assign({}, process.env, nniEnv);
const tunerProc: ChildProcess = spawn(command, [], {
Expand Down
2 changes: 1 addition & 1 deletion src/nni_manager/core/test/ipcInterfaceTerminate.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ describe('core/ipcInterface.terminate', (): void => {
assert.ok(!procError);
deferred.resolve();
},
2000);
5000);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why change this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same reason as above to accommodate the queue change.


return deferred.promise;
});
Expand Down
3 changes: 2 additions & 1 deletion src/nni_manager/rest_server/restValidationSchemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ export namespace ValidationSchemas {
className: joi.string(),
classArgs: joi.any(),
gpuNum: joi.number().min(0),
checkpointDir: joi.string().allow('')
checkpointDir: joi.string().allow(''),
includeIntermediateResults: joi.boolean()
}),
assessor: joi.object({
builtinAssessorName: joi.string().valid('Medianstop', 'Curvefitting'),
Expand Down
44 changes: 33 additions & 11 deletions src/sdk/pynni/nni/msg_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================

import os
import logging
from collections import defaultdict
import json_tricks
import threading

from .protocol import CommandType, send
from .msg_dispatcher_base import MsgDispatcherBase
from .assessor import AssessResult
from .common import multi_thread_enabled

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -70,7 +71,7 @@ def _pack_parameter(parameter_id, params, customized=False):

class MsgDispatcher(MsgDispatcherBase):
def __init__(self, tuner, assessor=None):
super().__init__()
super(MsgDispatcher, self).__init__()
self.tuner = tuner
self.assessor = assessor
if assessor is None:
Expand All @@ -87,9 +88,8 @@ def save_checkpoint(self):
self.assessor.save_checkpoint()

def handle_initialize(self, data):
'''
data is search space
'''
"""Data is search space
"""
self.tuner.update_search_space(data)
send(CommandType.Initialized, '')
return True
Expand Down Expand Up @@ -126,12 +126,7 @@ def handle_report_metric_data(self, data):
- 'type': report type, support {'FINAL', 'PERIODICAL'}
"""
if data['type'] == 'FINAL':
id_ = data['parameter_id']
value = data['value']
if id_ in _customized_parameter_ids:
self.tuner.receive_customized_trial_result(id_, _trial_params[id_], value)
else:
self.tuner.receive_trial_result(id_, _trial_params[id_], value)
self._handle_final_metric_data(data)
elif data['type'] == 'PERIODICAL':
if self.assessor is not None:
self._handle_intermediate_metric_data(data)
Expand All @@ -157,7 +152,19 @@ def handle_trial_end(self, data):
self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
return True

def _handle_final_metric_data(self, data):
    """Forward a trial's final metric to the tuner.

    `data` must provide the keys 'parameter_id' and 'value'. Parameter ids
    found in `_customized_parameter_ids` are reported through the tuner's
    `receive_customized_trial_result`; all others go through
    `receive_trial_result`.
    """
    parameter_id = data['parameter_id']
    metric = data['value']
    # Pick the tuner callback first, then make a single call.
    if parameter_id in _customized_parameter_ids:
        report = self.tuner.receive_customized_trial_result
    else:
        report = self.tuner.receive_trial_result
    report(parameter_id, _trial_params[parameter_id], metric)

def _handle_intermediate_metric_data(self, data):
"""Call assessor to process intermediate results
"""
if data['type'] != 'PERIODICAL':
return True
if self.assessor is None:
Expand Down Expand Up @@ -187,5 +194,20 @@ def _handle_intermediate_metric_data(self, data):
if result is AssessResult.Bad:
_logger.debug('BAD, kill %s', trial_job_id)
send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id))
# notify tuner
_logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]', os.environ.get('NNI_INCLUDE_INTERMEDIATE_RESULTS'))
if os.environ.get('NNI_INCLUDE_INTERMEDIATE_RESULTS') == 'true':
self._earlystop_notify_tuner(data)
else:
_logger.debug('GOOD')

def _earlystop_notify_tuner(self, data):
    """Report the last intermediate result of an early-stopped trial to the
    tuner as if it were the final result.

    The incoming `data` dict is relabelled in place as a 'FINAL' report.
    """
    _logger.debug('Early stop notify tuner data: [%s]', data)
    # Relabel before dispatching: enqueue_command routes on data['type'].
    data['type'] = 'FINAL'
    if not multi_thread_enabled():
        # Single-thread mode: hand the report over through the command queues.
        self.enqueue_command(CommandType.ReportMetricData, data)
        return
    # Multi-thread mode: call the tuner directly from the current thread.
    self._handle_final_metric_data(data)
85 changes: 68 additions & 17 deletions src/sdk/pynni/nni/msg_dispatcher_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,72 +19,123 @@
# ==================================================================================================

#import json_tricks
import logging
import os
from queue import Queue
import sys

import threading
import logging
from multiprocessing.dummy import Pool as ThreadPool

from queue import Queue, Empty
import json_tricks

from .common import init_logger, multi_thread_enabled
from .recoverable import Recoverable
from .protocol import CommandType, receive

init_logger('dispatcher.log')
_logger = logging.getLogger(__name__)

QUEUE_LEN_WARNING_MARK = 20
_worker_fast_exit_on_terminate = True

class MsgDispatcherBase(Recoverable):
def __init__(self):
    """Set up the state used to dispatch incoming commands.

    When multi-threading is enabled, commands are handed to a thread pool
    and the async results are kept in `thread_results`. Otherwise two
    queues are created (one for assessor commands, one for everything
    else), each drained by its own worker thread running
    `command_queue_worker`.
    """
    if multi_thread_enabled():
        self.pool = ThreadPool()
        self.thread_results = []
    else:
        # Initialise every attribute the workers touch BEFORE starting them,
        # so a worker can never observe a half-constructed dispatcher
        # (command_queue_worker reads `stopping` and appends to
        # `worker_exceptions`).
        self.stopping = False
        self.worker_exceptions = []
        self.default_command_queue = Queue()
        self.assessor_command_queue = Queue()
        self.default_worker = threading.Thread(
            target=self.command_queue_worker,
            args=(self.default_command_queue,))
        self.assessor_worker = threading.Thread(
            target=self.command_queue_worker,
            args=(self.assessor_command_queue,))
        self.default_worker.start()
        self.assessor_worker.start()

def run(self):
"""Run the tuner.
This function will never return unless raise.
"""
_logger.info('Start dispatcher')
mode = os.getenv('NNI_MODE')
if mode == 'resume':
self.load_checkpoint()

while True:
_logger.debug('waiting receive_message')
command, data = receive()
if data:
data = json_tricks.loads(data)

if command is None or command is CommandType.Terminate:
break
if multi_thread_enabled():
result = self.pool.map_async(self.handle_request_thread, [(command, data)])
result = self.pool.map_async(self.process_command_thread, [(command, data)])
self.thread_results.append(result)
if any([thread_result.ready() and not thread_result.successful() for thread_result in self.thread_results]):
_logger.debug('Caught thread exception')
break
else:
self.handle_request((command, data))
self.enqueue_command(command, data)

_logger.info('Dispatcher exiting...')
self.stopping = True
if multi_thread_enabled():
self.pool.close()
self.pool.join()
else:
self.default_worker.join()
self.assessor_worker.join()

_logger.info('Terminated by NNI manager')

def handle_request_thread(self, request):
def command_queue_worker(self, command_queue):
"""Process commands in command queues.
"""
while True:
try:
# set timeout to ensure self.stopping is checked periodically
command, data = command_queue.get(timeout=3)
try:
self.process_command(command, data)
except Exception as e:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we catch only the expected exceptions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

process_command calls Tuner or Assessor code, so we cannot know which exception types it will raise. If we do not catch them, they will be raised to nowhere, because this code path runs in a worker thread.

_logger.exception(e)
self.worker_exceptions.append(e)
break
except Empty:
pass
if self.stopping and (_worker_fast_exit_on_terminate or command_queue.empty()):
break

def enqueue_command(self, command, data):
"""Enqueue command into command queues
"""
if command == CommandType.TrialEnd or (command == CommandType.ReportMetricData and data['type'] == 'PERIODICAL'):
self.assessor_command_queue.put((command, data))
else:
self.default_command_queue.put((command, data))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please make sure the queue put is thread-safe, because both the main thread and the assessor thread will put records into this queue.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, merge for integration test, will check this later.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just checked:
Queue is thread safe.
https://docs.python.org/3/library/queue.html
The queue module implements multi-producer, multi-consumer queues. It is especially useful in threaded programming when information must be exchanged safely between multiple threads. The Queue class in this module implements all the required locking semantics. It depends on the availability of thread support in Python; see the threading module.

Reference: https://github.com/python/cpython/blob/3.7/Lib/queue.py self.not_full lock is used with put method.


qsize = self.default_command_queue.qsize()
if qsize >= QUEUE_LEN_WARNING_MARK:
_logger.warning('default queue length: %d', qsize)

qsize = self.assessor_command_queue.qsize()
if qsize >= QUEUE_LEN_WARNING_MARK:
_logger.warning('assessor queue length: %d', qsize)

def process_command_thread(self, request):
"""Worker thread to process a command.
"""
command, data = request
if multi_thread_enabled():
try:
self.handle_request(request)
self.process_command(command, data)
except Exception as e:
_logger.exception(str(e))
raise
else:
pass

def handle_request(self, request):
command, data = request

_logger.debug('handle request: command: [{}], data: [{}]'.format(command, data))

if data:
data = json_tricks.loads(data)
def process_command(self, command, data):
_logger.debug('process_command: command: [{}], data: [{}]'.format(command, data))

command_handlers = {
# Tuner commands:
Expand Down
2 changes: 1 addition & 1 deletion src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, p

class MultiPhaseMsgDispatcher(MsgDispatcherBase):
def __init__(self, tuner, assessor=None):
super()
super(MultiPhaseMsgDispatcher, self).__init__()
self.tuner = tuner
self.assessor = assessor
if assessor is None:
Expand Down
9 changes: 3 additions & 6 deletions src/sdk/pynni/nni/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,10 @@ class CommandType(Enum):
NoMoreTrialJobs = b'NO'
KillTrialJob = b'KI'


_lock = threading.Lock()
try:
_in_file = open(3, 'rb')
_out_file = open(4, 'wb')
_lock = threading.Lock()
except OSError:
_msg = 'IPC pipeline not exists, maybe you are importing tuner/assessor from trial code?'
import logging
Expand All @@ -60,17 +59,15 @@ def send(command, data):
"""
global _lock
try:
if multi_thread_enabled():
_lock.acquire()
_lock.acquire()
data = data.encode('utf8')
assert len(data) < 1000000, 'Command too long'
msg = b'%b%06d%b' % (command.value, len(data), data)
logging.getLogger(__name__).debug('Sending command, data: [%s]' % msg)
_out_file.write(msg)
_out_file.flush()
finally:
if multi_thread_enabled():
_lock.release()
_lock.release()


def receive():
Expand Down
13 changes: 7 additions & 6 deletions src/sdk/pynni/tests/test_assessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,12 @@ def test_assessor(self):

assessor = NaiveAssessor()
dispatcher = MsgDispatcher(None, assessor)
try:
dispatcher.run()
except Exception as e:
self.assertIs(type(e), AssertionError)
self.assertEqual(e.args[0], 'Unsupported command: CommandType.NewTrialJob')
nni.msg_dispatcher_base._worker_fast_exit_on_terminate = False

dispatcher.run()
e = dispatcher.worker_exceptions[0]
self.assertIs(type(e), AssertionError)
self.assertEqual(e.args[0], 'Unsupported command: CommandType.NewTrialJob')

self.assertEqual(_trials, ['A', 'B', 'A'])
self.assertEqual(_end_trials, [('A', False), ('B', True)])
Expand All @@ -90,4 +91,4 @@ def test_assessor(self):


if __name__ == '__main__':
main()
main()
11 changes: 6 additions & 5 deletions src/sdk/pynni/tests/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,12 @@ def test_tuner(self):

tuner = NaiveTuner()
dispatcher = MsgDispatcher(tuner)
try:
dispatcher.run()
except Exception as e:
self.assertIs(type(e), AssertionError)
self.assertEqual(e.args[0], 'Unsupported command: CommandType.KillTrialJob')
nni.msg_dispatcher_base._worker_fast_exit_on_terminate = False

dispatcher.run()
e = dispatcher.worker_exceptions[0]
self.assertIs(type(e), AssertionError)
self.assertEqual(e.args[0], 'Unsupported command: CommandType.KillTrialJob')

_reverse_io() # now we are receiving from Tuner's outgoing stream
self._assert_params(0, 2, [ ], None)
Expand Down
4 changes: 2 additions & 2 deletions test/tuner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def run(dispatch_type):
dipsatcher_list = TUNER_LIST if dispatch_type == 'Tuner' else ASSESSOR_LIST
for dispatcher_name in dipsatcher_list:
try:
# sleep 5 seconds here, to make sure previous stopped exp has enough time to exit to avoid port conflict
time.sleep(5)
# Sleep here to make sure previous stopped exp has enough time to exit to avoid port conflict
time.sleep(6)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why change this? shall we change the comments at the same time? "# sleep 5 seconds here, to make sure previous stopped exp has enough time to exit to avoid port conflict"....

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, updated the comments.
The reason to change this is: after this tuner/assessor queue change, if there are no commands in the queues, the dispatcher may still wait up to 3 seconds after it receives the TERMINATE command. So a normal experiment may take a little longer to end.

test_builtin_dispatcher(dispatch_type, dispatcher_name)
print(GREEN + 'Test %s %s: TEST PASS' % (dispatcher_name, dispatch_type) + CLEAR)
except Exception as error:
Expand Down
Loading