diff --git a/pkg/suggestion/v1alpha3/nas/enas/AlgorithmSettings.py b/pkg/suggestion/v1alpha3/nas/enas/AlgorithmSettings.py index b678b95a744..064b503a533 100644 --- a/pkg/suggestion/v1alpha3/nas/enas/AlgorithmSettings.py +++ b/pkg/suggestion/v1alpha3/nas/enas/AlgorithmSettings.py @@ -11,9 +11,10 @@ "controller_train_steps": [int, [1, 'inf']], "controller_log_every_steps": [int, [1, 'inf']], } +enableNoneSettingsList = [ + "controller_temperature", "controller_tanh_const", "controller_entropy_weight", "controller_skip_weight"] -# TODO: Enable to add None values, e.g in controller_temperature parameter def parseAlgorithmSettings(settings_raw): algorithm_settings_default = { @@ -32,7 +33,10 @@ def parseAlgorithmSettings(settings_raw): for setting in settings_raw: s_name = setting.name s_value = setting.value - s_type = algorithmSettingsValidator[s_name][0] - algorithm_settings_default[s_name] = s_type(s_value) + if s_value == "None": + algorithm_settings_default[s_name] = None + else: + s_type = algorithmSettingsValidator[s_name][0] + algorithm_settings_default[s_name] = s_type(s_value) return algorithm_settings_default diff --git a/pkg/suggestion/v1alpha3/nas/enas/Controller.py b/pkg/suggestion/v1alpha3/nas/enas/Controller.py index a6a76caf3b9..cb3ec4021b7 100755 --- a/pkg/suggestion/v1alpha3/nas/enas/Controller.py +++ b/pkg/suggestion/v1alpha3/nas/enas/Controller.py @@ -190,7 +190,7 @@ def build_trainer(self): normalize = tf.dtypes.cast((self.num_layers * (self.num_layers - 1) / 2), tf.float32) self.skip_rate = tf.dtypes.cast((self.skip_count / normalize), tf.float32) - if self.controller_entropy_weight: + if self.controller_entropy_weight is not None: self.reward += self.controller_entropy_weight * self.sample_entropy self.sample_log_probs = tf.reduce_sum(self.sample_log_probs) diff --git a/pkg/suggestion/v1alpha3/nas/enas_service.py b/pkg/suggestion/v1alpha3/nas/enas_service.py index 96aff334488..0419610c41b 100644 --- a/pkg/suggestion/v1alpha3/nas/enas_service.py +++ b/pkg/suggestion/v1alpha3/nas/enas_service.py @@ -9,7 +9,8 @@ from pkg.apis.manager.v1alpha3.python import api_pb2_grpc from pkg.suggestion.v1alpha3.nas.enas.Controller import Controller from pkg.suggestion.v1alpha3.nas.enas.Operation import SearchSpace -from pkg.suggestion.v1alpha3.nas.enas.AlgorithmSettings import parseAlgorithmSettings, algorithmSettingsValidator +from pkg.suggestion.v1alpha3.nas.enas.AlgorithmSettings import ( + parseAlgorithmSettings, algorithmSettingsValidator, enableNoneSettingsList) from pkg.suggestion.v1alpha3.base_health_service import HealthServicer @@ -122,8 +123,6 @@ def print_algorithm_settings(self): spec, self.algorithm_settings[spec])) self.logger.info("") - self.logger.info("RequestNumber:\t\t\t{}".format(self.num_trials)) - self.logger.info("") class EnasService(api_pb2_grpc.SuggestionServicer, HealthServicer): @@ -205,6 +204,8 @@ def ValidateAlgorithmSettings(self, request, context): settings_raw = request.experiment.spec.algorithm.algorithm_setting for setting in settings_raw: if setting.name in algorithmSettingsValidator.keys(): + if setting.name in enableNoneSettingsList and setting.value == "None": + continue setting_type = algorithmSettingsValidator[setting.name][0] setting_range = algorithmSettingsValidator[setting.name][1] try: @@ -243,6 +244,10 @@ def GetSuggestions(self, request, context): self.logger.info("-" * 100 + "\nSuggestion Step {} for Experiment {}\n".format( experiment.suggestion_step, experiment.experiment_name) + "-" * 100) + self.logger.info("") + self.logger.info(">>> RequestNumber:\t\t{}".format(experiment.num_trials)) + self.logger.info("") + with experiment.tf_graph.as_default(): saver = tf.compat.v1.train.Saver() ctrl = experiment.controller