From 31b2b8867e2f6f230684081e56639ea16ffd9715 Mon Sep 17 00:00:00 2001 From: liuzhe-lz <40699903+liuzhe-lz@users.noreply.github.com> Date: Mon, 25 Nov 2019 14:00:55 +0800 Subject: [PATCH] Fix bug introduced in customized trial (#1779) --- src/sdk/pynni/nni/msg_dispatcher.py | 2 +- test/config_test/multi_thread/multi_thread_tuner.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/sdk/pynni/nni/msg_dispatcher.py b/src/sdk/pynni/nni/msg_dispatcher.py index 420db8a566..27cd692ff5 100644 --- a/src/sdk/pynni/nni/msg_dispatcher.py +++ b/src/sdk/pynni/nni/msg_dispatcher.py @@ -184,7 +184,7 @@ def _handle_final_metric_data(self, data): """ id_ = data['parameter_id'] value = data['value'] - if not id_ or id_ in _customized_parameter_ids: + if id_ is None or id_ in _customized_parameter_ids: if not hasattr(self.tuner, '_accept_customized'): self.tuner._accept_customized = False if not self.tuner._accept_customized: diff --git a/test/config_test/multi_thread/multi_thread_tuner.py b/test/config_test/multi_thread/multi_thread_tuner.py index e6db0d39a4..216b696aea 100644 --- a/test/config_test/multi_thread/multi_thread_tuner.py +++ b/test/config_test/multi_thread/multi_thread_tuner.py @@ -1,3 +1,4 @@ +import logging import time from nni.tuner import Tuner @@ -7,14 +8,18 @@ def __init__(self): self.parent_done = False def generate_parameters(self, parameter_id, **kwargs): + logging.debug('generate_parameters: %s %s', parameter_id, kwargs) if parameter_id == 0: return {'x': 0} else: while not self.parent_done: + logging.debug('parameter_id %s sleeping', parameter_id) time.sleep(2) + logging.debug('parameter_id %s waked up', parameter_id) return {'x': 1} def receive_trial_result(self, parameter_id, parameters, value, **kwargs): + logging.debug('receive_trial_result: %s %s %s %s', parameter_id, parameters, value, kwargs) if parameter_id == 0: self.parent_done = True