Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Solve bug caused by scientific calculation errors #828

Merged
merged 2 commits into from
Mar 12, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

_next_parameter_id = 0
_KEY = 'STEPS'
_epsilon = 1e-6

@unique
class OptimizeMode(Enum):
Expand Down Expand Up @@ -141,8 +142,8 @@ def __init__(self, s, s_max, eta, R, optimize_mode):
self.bracket_id = s
self.s_max = s_max
self.eta = eta
self.n = math.ceil((s_max + 1) * (eta**s) / (s + 1)) # pylint: disable=invalid-name
self.r = math.ceil(R / eta**s) # pylint: disable=invalid-name
self.n = math.ceil((s_max + 1) * (eta**s) / (s + 1) - _epsilon) # pylint: disable=invalid-name
self.r = math.ceil(R / eta**s - _epsilon) # pylint: disable=invalid-name
self.i = 0
self.hyper_configs = [] # [ {id: params}, {}, ... ]
self.configs_perf = [] # [ {id: [seq, acc]}, {}, ... ]
Expand All @@ -157,7 +158,7 @@ def is_completed(self):

def get_n_r(self):
"""return the values of n and r for the next round"""
return math.floor(self.n / self.eta**self.i), self.r * self.eta**self.i
return math.floor(self.n / self.eta**self.i + _epsilon), self.r * self.eta**self.i

def increase_i(self):
"""i means the ith round. Increase i by 1"""
Expand Down Expand Up @@ -305,7 +306,7 @@ def __init__(self, R, eta=3, optimize_mode='maximize'):
self.brackets = dict() # dict of Bracket
self.generated_hyper_configs = [] # all the configs waiting for run
self.completed_hyper_configs = [] # all the completed configs
self.s_max = math.floor(math.log(self.R, self.eta))
self.s_max = math.floor(math.log(self.R, self.eta) + _epsilon)
self.curr_s = self.s_max

self.searchspace_json = None
Expand Down