diff --git a/examples/nasjob-example-RL.yaml b/examples/nasjob-example-RL.yaml
index 3942aed635e..a64215857d9 100644
--- a/examples/nasjob-example-RL.yaml
+++ b/examples/nasjob-example-RL.yaml
@@ -149,6 +149,7 @@ spec:
       restartPolicy: Never
   suggestionSpec:
     suggestionAlgorithm: "nasrl"
+    requestNumber: 3
     suggestionParameters:
       - name: "lstm_num_cells"
         value: "64"
diff --git a/pkg/suggestion/NAS_Reinforcement_Learning/Controller.py b/pkg/suggestion/NAS_Reinforcement_Learning/Controller.py
index 2c87868a573..937dc130b1a 100755
--- a/pkg/suggestion/NAS_Reinforcement_Learning/Controller.py
+++ b/pkg/suggestion/NAS_Reinforcement_Learning/Controller.py
@@ -31,7 +31,7 @@ def __init__(self,
                  logger=None):
 
         self.logger = logger
-        self.logger.info("Building Controller")
+        self.logger.info(">>> Building Controller")
 
         self.num_layers = num_layers
         self.num_operations = num_operations
@@ -87,7 +87,7 @@ def _create_params(self):
 
     def _build_sampler(self):
         """Build the sampler ops and the log_prob ops."""
-        self.logger.info("Building Controller Sampler")
+        self.logger.info(">>> Building Controller Sampler")
 
         anchors = []
         anchors_w_1 = []
diff --git a/pkg/suggestion/NAS_Reinforcement_Learning/README.md b/pkg/suggestion/NAS_Reinforcement_Learning/README.md
index 89b9735f1ff..25be4cbd88f 100644
--- a/pkg/suggestion/NAS_Reinforcement_Learning/README.md
+++ b/pkg/suggestion/NAS_Reinforcement_Learning/README.md
@@ -25,9 +25,8 @@ If n = 12, m = 6, the definition of an architecture will be like:
 
 There are n rows, the ith row has i elements and describes the ith layer.
 Please notice that layer 0 is the input and is not included in this definition.
-In each row:
-The first integer ranges from 0 to m-1, indicates the operation in this layer.
-The next (i-1) integers is either 0 or 1. The kth (k>=2) integer indicates whether (k-2)th layer has a skip connection with this layer. (There will always be a connection from (k-1)th layer to kth layer)
+In each row, the first integer ranges from 0 to m-1 and indicates the operation in this layer.
+Starting from the second position, the kth integer is a boolean value that indicates whether the (k-2)th layer has a skip connection with this layer. (There will always be a connection from the (k-1)th layer to the kth layer.)
 
 ## Output of `GetSuggestion()`
 The output of `GetSuggestion()` consists of two parts: `architecture` and `nn_config`.
@@ -122,6 +121,6 @@ This neural architecture can be visualized as
 ![a neural netowrk architecure example](example.png)
 
 ## To Do
-1. Add support for multiple trials
-2. Change LSTM cell from self defined functions in LSTM.py to `tf.nn.rnn_cell.LSTMCell`
-3. Store the suggestion checkpoint to PVC to protect against unexpected nasrl service pod restarts
+1. Change the LSTM cell from self-defined functions in LSTM.py to `tf.nn.rnn_cell.LSTMCell`
+2. Store the suggestion checkpoint to a PVC to protect against unexpected nasrl service pod restarts
+3. Add `RequestCount` to the API so that the suggestion can clean up the information of completed studies.
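Note: the encoding described in the README hunk above maps directly onto the `organized_arc` slicing in `nasrl_service.py` (row i carries one operation index plus one 0/1 skip flag per layer older than the immediately preceding one, which is always connected). A minimal decoding sketch for reference; the helper name and the sample values are made up for illustration and are not part of this patch:

```python
# Illustrative sketch only: decode the flat NAS-RL architecture encoding
# described in the README (layer 0 is the input and has no row of its own).
def decode_architecture(flat_arc, num_layers):
    layers = []
    pos = 0
    for i in range(num_layers):
        row = flat_arc[pos: pos + i + 1]   # same slicing as organized_arc in nasrl_service.py
        layers.append({
            "operation": row[0],           # index into the search space, 0..m-1
            "skip_connections": row[1:],   # 0/1 flags; the preceding layer is always connected
        })
        pos += i + 1
    return layers

# Hypothetical encoding for a 3-layer network: 1 + 2 + 3 = 6 integers.
print(decode_architecture([2, 4, 1, 0, 1, 0], num_layers=3))
```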
diff --git a/pkg/suggestion/NAS_Reinforcement_Learning/SuggestionParam.py b/pkg/suggestion/NAS_Reinforcement_Learning/SuggestionParam.py
index ae9f1d19777..ace456782b3 100644
--- a/pkg/suggestion/NAS_Reinforcement_Learning/SuggestionParam.py
+++ b/pkg/suggestion/NAS_Reinforcement_Learning/SuggestionParam.py
@@ -1,34 +1,34 @@
 def parseSuggestionParam(params_raw):
     param_standard = {
-        "lstm_num_cells":       ['value', int, [1, 'inf']],
-        "lstm_num_layers":      ['value', int, [1, 'inf']],
-        "lstm_keep_prob":       ['value', float, [0.0, 1.0]],
-        "optimizer":            ['categorical', str, ["adam", "momentum", "sgd"]],
-        "init_learning_rate":   ['value', float, [1e-6, 1.0]],
-        "lr_decay_start":       ['value', int, [0, 'inf']],
-        "lr_decay_every":       ['value', int, [1, 'inf']],
-        "lr_decay_rate":        ['value', float, [0.0, 1.0]],
-        "skip-target":          ['value', float, [0.0, 1.0]],
-        "skip-weight":          ['value', float, [0.0, 'inf']],
-        "l2_reg":               ['value', float, [0.0, 'inf']],
-        "entropy_weight":       ['value', float, [0.0, 'inf']],
-        "baseline_decay":       ['value', float, [0.0, 1.0]],
+        "lstm_num_cells": ['value', int, [1, 'inf']],
+        "lstm_num_layers": ['value', int, [1, 'inf']],
+        "lstm_keep_prob": ['value', float, [0.0, 1.0]],
+        "optimizer": ['categorical', str, ["adam", "momentum", "sgd"]],
+        "init_learning_rate": ['value', float, [1e-6, 1.0]],
+        "lr_decay_start": ['value', int, [0, 'inf']],
+        "lr_decay_every": ['value', int, [1, 'inf']],
+        "lr_decay_rate": ['value', float, [0.0, 1.0]],
+        "skip-target": ['value', float, [0.0, 1.0]],
+        "skip-weight": ['value', float, [0.0, 'inf']],
+        "l2_reg": ['value', float, [0.0, 'inf']],
+        "entropy_weight": ['value', float, [0.0, 'inf']],
+        "baseline_decay": ['value', float, [0.0, 1.0]],
     }
 
     suggestion_params = {
-        "lstm_num_cells":       64,
-        "lstm_num_layers":      1,
-        "lstm_keep_prob":       1.0,
-        "optimizer":            "adam",
-        "init_learning_rate":   1e-3,
-        "lr_decay_start":       0,
-        "lr_decay_every":       1000,
-        "lr_decay_rate":        0.9,
-        "skip-target":          0.4,
-        "skip-weight":          0.8,
-        "l2_reg":               0,
-        "entropy_weight":       1e-4,
-        "baseline_decay":       0.9999
+        "lstm_num_cells": 64,
+        "lstm_num_layers": 1,
+        "lstm_keep_prob": 1.0,
+        "optimizer": "adam",
+        "init_learning_rate": 1e-3,
+        "lr_decay_start": 0,
+        "lr_decay_every": 1000,
+        "lr_decay_rate": 0.9,
+        "skip-target": 0.4,
+        "skip-weight": 0.8,
+        "l2_reg": 0,
+        "entropy_weight": 1e-4,
+        "baseline_decay": 0.9999
     }
 
     def checktype(param_name, param_value, check_mode, supposed_type, supposed_range=None):
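Note: the hunk above only re-indents `SuggestionParam.py`; the defaults and valid ranges are unchanged. Each `suggestionParameters` entry from the StudyJob spec (name/value pairs, as in the YAML example at the top of this diff) overrides the matching default and is cast to the type listed in `param_standard`. A rough sketch of that override step, assuming the raw parameters arrive as objects with string `name` and `value` attributes (an assumption for illustration, not the exact proto API):

```python
# Rough sketch, not the patch code: override defaults with user-supplied
# suggestion parameters and cast the string values to their expected types.
from collections import namedtuple

RawParam = namedtuple("RawParam", ["name", "value"])  # stand-in for the gRPC message

defaults = {"lstm_num_cells": 64, "init_learning_rate": 1e-3, "optimizer": "adam"}
types = {"lstm_num_cells": int, "init_learning_rate": float, "optimizer": str}

def apply_overrides(params_raw, defaults, types):
    merged = dict(defaults)
    for p in params_raw:
        if p.name in merged:
            merged[p.name] = types[p.name](p.value)  # e.g. "64" -> 64
        # the real parser additionally validates against the ranges in param_standard
    return merged

print(apply_overrides([RawParam("lstm_num_cells", "64")], defaults, types))
```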
diff --git a/pkg/suggestion/nasrl_service.py b/pkg/suggestion/nasrl_service.py
index d571784ad5e..ebba3e13524 100644
--- a/pkg/suggestion/nasrl_service.py
+++ b/pkg/suggestion/nasrl_service.py
@@ -9,18 +9,22 @@
 from logging import getLogger, StreamHandler, INFO, DEBUG
 import json
 import os
+import time
+
 
 MANAGER_ADDRESS = "vizier-core"
 MANAGER_PORT = 6789
 
+
 class NAS_RL_StudyJob(object):
     def __init__(self, request, logger):
         self.logger = logger
         self.study_id = request.study_id
         self.param_id = request.param_id
+        self.num_trials = request.request_number
         self.study_name = None
         self.tf_graph = tf.Graph()
-        self.prev_trial_id = None
+        self.prev_trial_ids = list()
         self.ctrl_cache_file = "ctrl_cache/{}/{}.ckpt".format(request.study_id, request.study_id)
         self.ctrl_step = 0
         self.is_first_run = True
@@ -38,7 +42,7 @@ def __init__(self, request, logger):
         self._get_study_param()
         self._get_suggestion_param()
         self._setup_controller()
-        self.logger.info("Suggestion for StudyJob {} (ID: {}) has been initialized.\n".format(self.study_name, self.study_id))
+        self.logger.info(">>> Suggestion for StudyJob {} (ID: {}) has been initialized.\n".format(self.study_name, self.study_id))
 
     def _get_study_param(self):
         # this function need to
@@ -111,7 +115,7 @@ def print_search_space(self):
             self.logger.warning("Error! The Suggestion has not yet been initialized!")
             return
 
-        self.logger.info("Search Space for StudyJob {} (ID: {}):".format(self.study_name, self.study_id))
+        self.logger.info(">>> Search Space for StudyJob {} (ID: {}):".format(self.study_name, self.study_id))
         for opt in self.search_space:
             opt.print_op(self.logger)
         self.logger.info("There are {} operations in total.\n".format(self.num_operations))
@@ -121,12 +125,13 @@ def print_suggestion_params(self):
             self.logger.warning("Error! The Suggestion has not yet been initialized!")
             return
 
-        self.logger.info("Parameters of LSTM Controller for StudyJob {} (ID: {}):".format(self.study_name, self.study_id))
+        self.logger.info(">>> Parameters of LSTM Controller for StudyJob {} (ID: {}):".format(self.study_name, self.study_id))
         for spec in self.suggestion_config:
             if len(spec) > 13:
                 self.logger.info("{}: \t{}".format(spec, self.suggestion_config[spec]))
             else:
                 self.logger.info("{}: \t\t{}".format(spec, self.suggestion_config[spec]))
+        self.logger.info("RequestNumber:\t\t{}".format(self.num_trials))
         self.logger.info("")
 
@@ -185,10 +190,13 @@ def GetSuggestions(self, request, context):
                        controller_ops["train_op"]]
 
             if study.is_first_run:
-                self.logger.info("First time running suggestion for {}. Random architecture will be given.".format(study.study_name))
+                self.logger.info(">>> First time running suggestion for {}. Random architecture will be given.".format(study.study_name))
                 with tf.Session() as sess:
                     sess.run(tf.global_variables_initializer())
 
-                    arc = sess.run(controller_ops["sample_arc"])
+                    candidates = list()
+                    for _ in range(study.num_trials):
+                        candidates.append(sess.run(controller_ops["sample_arc"]))
+
                     # TODO: will use PVC to store the checkpoint to protect against unexpected suggestion pod restart
                     saver.save(sess, study.ctrl_cache_file)
@@ -201,88 +209,112 @@ def GetSuggestions(self, request, context):
                     valid_acc = ctrl.reward
                     result = self.GetEvaluationResult(study)
 
-                    # This lstm cell is designed to maximize the metrics
-                    # However, if the user want to minimize the metrics, we can take the negative of the result
+                    # In some rare cases, GetEvaluationResult() may return None
+                    # if GetSuggestions() is called before all the trials are completed
+                    while result is None:
+                        self.logger.warning(">>> GetEvaluationResult() returns None")
+                        time.sleep(20)
+                        result = self.GetEvaluationResult(study)
+
+                    # This LSTM network is designed to maximize the metrics
+                    # However, if the user wants to minimize the metrics, we can take the negative of the result
                     if study.opt_direction == api_pb2.MINIMIZE:
                         result = -result
 
                     loss, entropy, lr, gn, bl, skip, _ = sess.run(
                         fetches=run_ops,
                         feed_dict={valid_acc: result})
-                    self.logger.info("Suggetion updated. LSTM Controller Reward: {}".format(loss))
+                    self.logger.info(">>> Suggestion updated. LSTM Controller Reward: {}".format(loss))
 
-                    arc = sess.run(controller_ops["sample_arc"])
+                    candidates = list()
+                    for _ in range(study.num_trials):
+                        candidates.append(sess.run(controller_ops["sample_arc"]))
 
                     saver.save(sess, study.ctrl_cache_file)
-
-        arc = arc.tolist()
-        organized_arc = [0 for _ in range(study.num_layers)]
-        record = 0
-        for l in range(study.num_layers):
-            organized_arc[l] = arc[record: record + l + 1]
-            record += l + 1
-
-        nn_config = dict()
-        nn_config['num_layers'] = study.num_layers
-        nn_config['input_size'] = study.input_size
-        nn_config['output_size'] = study.output_size
-        nn_config['embedding'] = dict()
-        for l in range(study.num_layers):
-            opt = organized_arc[l][0]
-            nn_config['embedding'][opt] = study.search_space[opt].get_dict()
-
-        organized_arc_json = json.dumps(organized_arc)
-        nn_config_json = json.dumps(nn_config)
-
-        organized_arc_str = str(organized_arc_json).replace('\"', '\'')
-        nn_config_str = str(nn_config_json).replace('\"', '\'')
-
-        self.logger.info("\nNew Neural Network Architecture (internal representation):")
-        self.logger.info(organized_arc_json)
-        self.logger.info("\nCorresponding Seach Space Description:")
-        self.logger.info(nn_config_str)
-        self.logger.info("")
-
-        trials = []
-        trials.append(api_pb2.Trial(
-                study_id=request.study_id,
-                parameter_set=[
-                    api_pb2.Parameter(
-                        name="architecture",
-                        value=organized_arc_str,
-                        parameter_type= api_pb2.CATEGORICAL),
-                    api_pb2.Parameter(
-                        name="nn_config",
-                        value=nn_config_str,
-                        parameter_type= api_pb2.CATEGORICAL)
-                ],
+
+        organized_candidates = list()
+        trials = list()
+
+        for i in range(study.num_trials):
+            arc = candidates[i].tolist()
+            organized_arc = [0 for _ in range(study.num_layers)]
+            record = 0
+            for l in range(study.num_layers):
+                organized_arc[l] = arc[record: record + l + 1]
+                record += l + 1
+            organized_candidates.append(organized_arc)
+
+            nn_config = dict()
+            nn_config['num_layers'] = study.num_layers
+            nn_config['input_size'] = study.input_size
+            nn_config['output_size'] = study.output_size
+            nn_config['embedding'] = dict()
+            for l in range(study.num_layers):
+                opt = organized_arc[l][0]
+                nn_config['embedding'][opt] = study.search_space[opt].get_dict()
+
+            organized_arc_json = json.dumps(organized_arc)
+            nn_config_json = json.dumps(nn_config)
+
+            organized_arc_str = str(organized_arc_json).replace('\"', '\'')
+            nn_config_str = str(nn_config_json).replace('\"', '\'')
+
+            self.logger.info("\n>>> New Neural Network Architecture Candidate #{} (internal representation):".format(i))
+            self.logger.info(organized_arc_json)
+            self.logger.info("\n>>> Corresponding Search Space Description:")
+            self.logger.info(nn_config_str)
+
+            trials.append(api_pb2.Trial(
+                    study_id=request.study_id,
+                    parameter_set=[
+                        api_pb2.Parameter(
+                            name="architecture",
+                            value=organized_arc_str,
+                            parameter_type= api_pb2.CATEGORICAL),
+                        api_pb2.Parameter(
+                            name="nn_config",
+                            value=nn_config_str,
+                            parameter_type= api_pb2.CATEGORICAL)
+                    ],
+                )
             )
-        )
 
+        self.prev_trial_ids = list()
+        self.logger.info("")
         channel = grpc.beta.implementations.insecure_channel(MANAGER_ADDRESS, MANAGER_PORT)
         with api_pb2.beta_create_Manager_stub(channel) as client:
             for i, t in enumerate(trials):
                 ctrep = client.CreateTrial(api_pb2.CreateTrialRequest(trial=t), 10)
                 trials[i].trial_id = ctrep.trial_id
-                self.logger.info("Trial {} Created\n".format(ctrep.trial_id))
-                study.prev_trial_id = ctrep.trial_id
+                self.prev_trial_ids.append(ctrep.trial_id)
+        self.logger.info(">>> {} Trials were created:".format(study.num_trials))
+        for t in self.prev_trial_ids:
+            self.logger.info(t)
+        self.logger.info("")
+
         study.ctrl_step += 1
         return api_pb2.GetSuggestionsReply(trials=trials)
 
     def GetEvaluationResult(self, study):
-        worker_list = []
         channel = grpc.beta.implementations.insecure_channel(MANAGER_ADDRESS, MANAGER_PORT)
         with api_pb2.beta_create_Manager_stub(channel) as client:
-            gwfrep = client.GetWorkerFullInfo(api_pb2.GetWorkerFullInfoRequest(study_id=study.study_id, trial_id=study.prev_trial_id, only_latest_log=True), 10)
-            worker_list = gwfrep.worker_full_infos
-
-        for w in worker_list:
-            if w.Worker.status == api_pb2.COMPLETED:
-                for ml in w.metrics_logs:
+            gwfrep = client.GetWorkerFullInfo(api_pb2.GetWorkerFullInfoRequest(study_id=study.study_id, only_latest_log=True), 10)
+            trials_list = gwfrep.worker_full_infos
+
+        completed_trials = dict()
+        for t in trials_list:
+            if t.Worker.trial_id in self.prev_trial_ids and t.Worker.status == api_pb2.COMPLETED:
+                for ml in t.metrics_logs:
                     if ml.name == study.objective_name:
-                        self.logger.info("Evaluation result of previous candidate: {}".format(ml.values[-1].value))
-                        return float(ml.values[-1].value)
-
-        # TODO: add support for multiple trials
+                        completed_trials[t.Worker.trial_id] = float(ml.values[-1].value)
+
+        if len(completed_trials) == study.num_trials:
+            self.logger.info(">>> Evaluation results of previous trials:")
+            for k in completed_trials:
+                self.logger.info("{}: {}".format(k, completed_trials[k]))
+            avg_metrics = sum(completed_trials.values()) / study.num_trials
+            self.logger.info("The average is {}\n".format(avg_metrics))
+
+            return avg_metrics
\ No newline at end of file
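Note: with `request_number` candidates per round, the controller is now updated with a single scalar reward: the average objective value over the batch of trials created in the previous round, and `GetSuggestions()` keeps polling while `GetEvaluationResult()` returns `None` (i.e. while some of those trials are still running). A stripped-down sketch of that aggregation; the function and the sample data are hypothetical, not part of the patch:

```python
# Stripped-down sketch of the new reward aggregation: return the average
# objective only once every trial from the previous round has completed,
# otherwise return None so the caller keeps polling (as GetSuggestions() does).
import time

def average_reward(completed_metrics, prev_trial_ids):
    """completed_metrics: {trial_id: objective value} for COMPLETED trials only."""
    reported = {tid: completed_metrics[tid] for tid in prev_trial_ids if tid in completed_metrics}
    if len(reported) < len(prev_trial_ids):
        return None
    return sum(reported.values()) / len(prev_trial_ids)

# Hypothetical round of three trials (request_number = 3).
prev_trial_ids = ["trial-a", "trial-b", "trial-c"]
metrics = {"trial-a": 0.91, "trial-b": 0.89, "trial-c": 0.93}

reward = average_reward(metrics, prev_trial_ids)
while reward is None:          # mirrors the 20-second polling loop in GetSuggestions()
    time.sleep(20)
    reward = average_reward(metrics, prev_trial_ids)
print(reward)                  # 0.91; negated before the update when the study minimizes
```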