Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Route tuner and assessor commands to 2 seperate queues (#891)
Browse files Browse the repository at this point in the history
1. Route tuner and assessor commands to 2 seperate queues  issue #841
2. Allow tuner to leverage intermediate result when trial is early stopped.  issue #843
  • Loading branch information
chicm-ms authored Mar 22, 2019
1 parent c297650 commit 63697ec
Show file tree
Hide file tree
Showing 13 changed files with 138 additions and 52 deletions.
7 changes: 6 additions & 1 deletion docs/en_US/ExperimentConfig.md
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,17 @@ machineList:
* __classArgs__

__classArgs__ specifies the arguments of tuner algorithm.
* __gpuNum__

* __gpuNum__

__gpuNum__ specifies the gpu number to run the tuner process. The value of this field should be a positive number.

Note: users could only specify one way to set tuner, for example, set {tunerName, optimizationMode} or {tunerCommand, tunerCwd}, and could not set them both.

* __includeIntermediateResults__

If __includeIntermediateResults__ is true, the last intermediate result of the trial that is early stopped by assessor is sent to tuner as final result. The default value of __includeIntermediateResults__ is false.

* __assessor__

* Description
Expand Down
1 change: 1 addition & 0 deletions src/nni_manager/common/manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ interface ExperimentParams {
classFileName?: string;
checkpointDir: string;
gpuNum?: number;
includeIntermediateResults?: boolean;
};
assessor?: {
className: string;
Expand Down
8 changes: 7 additions & 1 deletion src/nni_manager/core/nnimanager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -277,11 +277,17 @@ class NNIManager implements Manager {
newCwd = cwd;
}
// TO DO: add CUDA_VISIBLE_DEVICES
let includeIntermediateResultsEnv: boolean | undefined = false;
if (this.experimentProfile.params.tuner !== undefined) {
includeIntermediateResultsEnv = this.experimentProfile.params.tuner.includeIntermediateResults;
}

let nniEnv = {
NNI_MODE: mode,
NNI_CHECKPOINT_DIRECTORY: dataDirectory,
NNI_LOG_DIRECTORY: getLogDir(),
NNI_LOG_LEVEL: getLogLevel()
NNI_LOG_LEVEL: getLogLevel(),
NNI_INCLUDE_INTERMEDIATE_RESULTS: includeIntermediateResultsEnv
};
let newEnv = Object.assign({}, process.env, nniEnv);
const tunerProc: ChildProcess = spawn(command, [], {
Expand Down
2 changes: 1 addition & 1 deletion src/nni_manager/core/test/ipcInterfaceTerminate.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ describe('core/ipcInterface.terminate', (): void => {
assert.ok(!procError);
deferred.resolve();
},
2000);
5000);

return deferred.promise;
});
Expand Down
3 changes: 2 additions & 1 deletion src/nni_manager/rest_server/restValidationSchemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ export namespace ValidationSchemas {
className: joi.string(),
classArgs: joi.any(),
gpuNum: joi.number().min(0),
checkpointDir: joi.string().allow('')
checkpointDir: joi.string().allow(''),
includeIntermediateResults: joi.boolean()
}),
assessor: joi.object({
builtinAssessorName: joi.string().valid('Medianstop', 'Curvefitting'),
Expand Down
44 changes: 33 additions & 11 deletions src/sdk/pynni/nni/msg_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================

import os
import logging
from collections import defaultdict
import json_tricks
import threading

from .protocol import CommandType, send
from .msg_dispatcher_base import MsgDispatcherBase
from .assessor import AssessResult
from .common import multi_thread_enabled

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -70,7 +71,7 @@ def _pack_parameter(parameter_id, params, customized=False):

class MsgDispatcher(MsgDispatcherBase):
def __init__(self, tuner, assessor=None):
super().__init__()
super(MsgDispatcher, self).__init__()
self.tuner = tuner
self.assessor = assessor
if assessor is None:
Expand All @@ -87,9 +88,8 @@ def save_checkpoint(self):
self.assessor.save_checkpoint()

def handle_initialize(self, data):
'''
data is search space
'''
"""Data is search space
"""
self.tuner.update_search_space(data)
send(CommandType.Initialized, '')
return True
Expand Down Expand Up @@ -126,12 +126,7 @@ def handle_report_metric_data(self, data):
- 'type': report type, support {'FINAL', 'PERIODICAL'}
"""
if data['type'] == 'FINAL':
id_ = data['parameter_id']
value = data['value']
if id_ in _customized_parameter_ids:
self.tuner.receive_customized_trial_result(id_, _trial_params[id_], value)
else:
self.tuner.receive_trial_result(id_, _trial_params[id_], value)
self._handle_final_metric_data(data)
elif data['type'] == 'PERIODICAL':
if self.assessor is not None:
self._handle_intermediate_metric_data(data)
Expand All @@ -157,7 +152,19 @@ def handle_trial_end(self, data):
self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
return True

def _handle_final_metric_data(self, data):
"""Call tuner to process final results
"""
id_ = data['parameter_id']
value = data['value']
if id_ in _customized_parameter_ids:
self.tuner.receive_customized_trial_result(id_, _trial_params[id_], value)
else:
self.tuner.receive_trial_result(id_, _trial_params[id_], value)

def _handle_intermediate_metric_data(self, data):
"""Call assessor to process intermediate results
"""
if data['type'] != 'PERIODICAL':
return True
if self.assessor is None:
Expand Down Expand Up @@ -187,5 +194,20 @@ def _handle_intermediate_metric_data(self, data):
if result is AssessResult.Bad:
_logger.debug('BAD, kill %s', trial_job_id)
send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id))
# notify tuner
_logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]', os.environ.get('NNI_INCLUDE_INTERMEDIATE_RESULTS'))
if os.environ.get('NNI_INCLUDE_INTERMEDIATE_RESULTS') == 'true':
self._earlystop_notify_tuner(data)
else:
_logger.debug('GOOD')

def _earlystop_notify_tuner(self, data):
"""Send last intermediate result as final result to tuner in case the
trial is early stopped.
"""
_logger.debug('Early stop notify tuner data: [%s]', data)
data['type'] = 'FINAL'
if multi_thread_enabled():
self._handle_final_metric_data(data)
else:
self.enqueue_command(CommandType.ReportMetricData, data)
85 changes: 68 additions & 17 deletions src/sdk/pynni/nni/msg_dispatcher_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,72 +19,123 @@
# ==================================================================================================

#import json_tricks
import logging
import os
from queue import Queue
import sys

import threading
import logging
from multiprocessing.dummy import Pool as ThreadPool

from queue import Queue, Empty
import json_tricks

from .common import init_logger, multi_thread_enabled
from .recoverable import Recoverable
from .protocol import CommandType, receive

init_logger('dispatcher.log')
_logger = logging.getLogger(__name__)

QUEUE_LEN_WARNING_MARK = 20
_worker_fast_exit_on_terminate = True

class MsgDispatcherBase(Recoverable):
def __init__(self):
if multi_thread_enabled():
self.pool = ThreadPool()
self.thread_results = []
else:
self.stopping = False
self.default_command_queue = Queue()
self.assessor_command_queue = Queue()
self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,))
self.assessor_worker = threading.Thread(target=self.command_queue_worker, args=(self.assessor_command_queue,))
self.default_worker.start()
self.assessor_worker.start()
self.worker_exceptions = []

def run(self):
"""Run the tuner.
This function will never return unless raise.
"""
_logger.info('Start dispatcher')
mode = os.getenv('NNI_MODE')
if mode == 'resume':
self.load_checkpoint()

while True:
_logger.debug('waiting receive_message')
command, data = receive()
if data:
data = json_tricks.loads(data)

if command is None or command is CommandType.Terminate:
break
if multi_thread_enabled():
result = self.pool.map_async(self.handle_request_thread, [(command, data)])
result = self.pool.map_async(self.process_command_thread, [(command, data)])
self.thread_results.append(result)
if any([thread_result.ready() and not thread_result.successful() for thread_result in self.thread_results]):
_logger.debug('Caught thread exception')
break
else:
self.handle_request((command, data))
self.enqueue_command(command, data)

_logger.info('Dispatcher exiting...')
self.stopping = True
if multi_thread_enabled():
self.pool.close()
self.pool.join()
else:
self.default_worker.join()
self.assessor_worker.join()

_logger.info('Terminated by NNI manager')

def handle_request_thread(self, request):
def command_queue_worker(self, command_queue):
"""Process commands in command queues.
"""
while True:
try:
# set timeout to ensure self.stopping is checked periodically
command, data = command_queue.get(timeout=3)
try:
self.process_command(command, data)
except Exception as e:
_logger.exception(e)
self.worker_exceptions.append(e)
break
except Empty:
pass
if self.stopping and (_worker_fast_exit_on_terminate or command_queue.empty()):
break

def enqueue_command(self, command, data):
"""Enqueue command into command queues
"""
if command == CommandType.TrialEnd or (command == CommandType.ReportMetricData and data['type'] == 'PERIODICAL'):
self.assessor_command_queue.put((command, data))
else:
self.default_command_queue.put((command, data))

qsize = self.default_command_queue.qsize()
if qsize >= QUEUE_LEN_WARNING_MARK:
_logger.warning('default queue length: %d', qsize)

qsize = self.assessor_command_queue.qsize()
if qsize >= QUEUE_LEN_WARNING_MARK:
_logger.warning('assessor queue length: %d', qsize)

def process_command_thread(self, request):
"""Worker thread to process a command.
"""
command, data = request
if multi_thread_enabled():
try:
self.handle_request(request)
self.process_command(command, data)
except Exception as e:
_logger.exception(str(e))
raise
else:
pass

def handle_request(self, request):
command, data = request

_logger.debug('handle request: command: [{}], data: [{}]'.format(command, data))

if data:
data = json_tricks.loads(data)
def process_command(self, command, data):
_logger.debug('process_command: command: [{}], data: [{}]'.format(command, data))

command_handlers = {
# Tunner commands:
Expand Down
2 changes: 1 addition & 1 deletion src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, p

class MultiPhaseMsgDispatcher(MsgDispatcherBase):
def __init__(self, tuner, assessor=None):
super()
super(MultiPhaseMsgDispatcher, self).__init__()
self.tuner = tuner
self.assessor = assessor
if assessor is None:
Expand Down
9 changes: 3 additions & 6 deletions src/sdk/pynni/nni/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,10 @@ class CommandType(Enum):
NoMoreTrialJobs = b'NO'
KillTrialJob = b'KI'


_lock = threading.Lock()
try:
_in_file = open(3, 'rb')
_out_file = open(4, 'wb')
_lock = threading.Lock()
except OSError:
_msg = 'IPC pipeline not exists, maybe you are importing tuner/assessor from trial code?'
import logging
Expand All @@ -60,17 +59,15 @@ def send(command, data):
"""
global _lock
try:
if multi_thread_enabled():
_lock.acquire()
_lock.acquire()
data = data.encode('utf8')
assert len(data) < 1000000, 'Command too long'
msg = b'%b%06d%b' % (command.value, len(data), data)
logging.getLogger(__name__).debug('Sending command, data: [%s]' % msg)
_out_file.write(msg)
_out_file.flush()
finally:
if multi_thread_enabled():
_lock.release()
_lock.release()


def receive():
Expand Down
13 changes: 7 additions & 6 deletions src/sdk/pynni/tests/test_assessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,12 @@ def test_assessor(self):

assessor = NaiveAssessor()
dispatcher = MsgDispatcher(None, assessor)
try:
dispatcher.run()
except Exception as e:
self.assertIs(type(e), AssertionError)
self.assertEqual(e.args[0], 'Unsupported command: CommandType.NewTrialJob')
nni.msg_dispatcher_base._worker_fast_exit_on_terminate = False

dispatcher.run()
e = dispatcher.worker_exceptions[0]
self.assertIs(type(e), AssertionError)
self.assertEqual(e.args[0], 'Unsupported command: CommandType.NewTrialJob')

self.assertEqual(_trials, ['A', 'B', 'A'])
self.assertEqual(_end_trials, [('A', False), ('B', True)])
Expand All @@ -90,4 +91,4 @@ def test_assessor(self):


if __name__ == '__main__':
main()
main()
11 changes: 6 additions & 5 deletions src/sdk/pynni/tests/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,12 @@ def test_tuner(self):

tuner = NaiveTuner()
dispatcher = MsgDispatcher(tuner)
try:
dispatcher.run()
except Exception as e:
self.assertIs(type(e), AssertionError)
self.assertEqual(e.args[0], 'Unsupported command: CommandType.KillTrialJob')
nni.msg_dispatcher_base._worker_fast_exit_on_terminate = False

dispatcher.run()
e = dispatcher.worker_exceptions[0]
self.assertIs(type(e), AssertionError)
self.assertEqual(e.args[0], 'Unsupported command: CommandType.KillTrialJob')

_reverse_io() # now we are receiving from Tuner's outgoing stream
self._assert_params(0, 2, [ ], None)
Expand Down
4 changes: 2 additions & 2 deletions test/tuner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def run(dispatch_type):
dipsatcher_list = TUNER_LIST if dispatch_type == 'Tuner' else ASSESSOR_LIST
for dispatcher_name in dipsatcher_list:
try:
# sleep 5 seconds here, to make sure previous stopped exp has enough time to exit to avoid port conflict
time.sleep(5)
# Sleep here to make sure previous stopped exp has enough time to exit to avoid port conflict
time.sleep(6)
test_builtin_dispatcher(dispatch_type, dispatcher_name)
print(GREEN + 'Test %s %s: TEST PASS' % (dispatcher_name, dispatch_type) + CLEAR)
except Exception as error:
Expand Down
Loading

0 comments on commit 63697ec

Please sign in to comment.