Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
fix bug in smac search space convert (#940)
Browse files Browse the repository at this point in the history
* fix bug in smac search space convert

* update docstring
  • Loading branch information
QuanluZhang authored and yds05 committed Apr 1, 2019
1 parent 151013a commit c49c24c
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions src/sdk/pynni/nni/smac_tuner/smac_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,17 +192,20 @@ def convert_loguniform_categorical(self, challenger_dict):
Returns
-------
dict
challenger dict
dict which stores copy of challengers
"""
converted_dict = {}
for key, value in challenger_dict.items():
# convert to loguniform
if key in self.loguniform_key:
challenger_dict[key] = np.exp(challenger_dict[key])
converted_dict[key] = np.exp(challenger_dict[key])
# convert categorical back to original value
if key in self.categorical_dict:
elif key in self.categorical_dict:
idx = challenger_dict[key]
challenger_dict[key] = self.categorical_dict[key][idx]
return challenger_dict
converted_dict[key] = self.categorical_dict[key][idx]
else:
converted_dict[key] = value
return converted_dict

def generate_parameters(self, parameter_id):
"""generate one instance of hyperparameters
Expand All @@ -220,13 +223,11 @@ def generate_parameters(self, parameter_id):
if self.first_one:
init_challenger = self.smbo_solver.nni_smac_start()
self.total_data[parameter_id] = init_challenger
json_tricks.dumps(init_challenger.get_dictionary())
return self.convert_loguniform_categorical(init_challenger.get_dictionary())
else:
challengers = self.smbo_solver.nni_smac_request_challengers()
for challenger in challengers:
self.total_data[parameter_id] = challenger
json_tricks.dumps(challenger.get_dictionary())
return self.convert_loguniform_categorical(challenger.get_dictionary())

def generate_multiple_parameters(self, parameter_id_list):
Expand All @@ -247,7 +248,6 @@ def generate_multiple_parameters(self, parameter_id_list):
for one_id in parameter_id_list:
init_challenger = self.smbo_solver.nni_smac_start()
self.total_data[one_id] = init_challenger
json_tricks.dumps(init_challenger.get_dictionary())
params.append(self.convert_loguniform_categorical(init_challenger.get_dictionary()))
else:
challengers = self.smbo_solver.nni_smac_request_challengers()
Expand All @@ -257,7 +257,6 @@ def generate_multiple_parameters(self, parameter_id_list):
if cnt >= len(parameter_id_list):
break
self.total_data[parameter_id_list[cnt]] = challenger
json_tricks.dumps(challenger.get_dictionary())
params.append(self.convert_loguniform_categorical(challenger.get_dictionary()))
cnt += 1
return params

0 comments on commit c49c24c

Please sign in to comment.