Skip to content

Commit

Permalink
Refactor ESSOptimizer (#1182)
Browse files Browse the repository at this point in the history
Fixes two log messages. The rest is just some cleanup and shuffling some code around for readability / reducing duplications.
  • Loading branch information
dweindl authored Nov 19, 2023
1 parent 0d83ed9 commit 4a953cd
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 135 deletions.
126 changes: 66 additions & 60 deletions pypesto/optimize/ess/ess.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,7 @@
from pypesto import OptimizerResult, Problem
from pypesto.startpoint import StartpointMethod

from .function_evaluator import (
FunctionEvaluator,
FunctionEvaluatorMP,
FunctionEvaluatorMT,
)
from .function_evaluator import FunctionEvaluator, create_function_evaluator
from .refset import RefSet

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -147,9 +143,7 @@ def __init__(
self.max_iter: int = max_iter
self.max_eval: int = max_eval
self.dim_refset: int = dim_refset
self.local_optimizer: Optional[
'pypesto.optimize.Optimizer'
] = local_optimizer
self.local_optimizer = local_optimizer
self.n_diverse: int = n_diverse
if n_procs is not None and n_threads is not None:
raise ValueError(
Expand Down Expand Up @@ -189,23 +183,15 @@ def _initialize(self):
self.evaluator: Optional[FunctionEvaluator] = None
self.starttime: Optional[float] = None

def minimize(
def _initialize_minimize(
self,
problem: Problem = None,
startpoint_method: StartpointMethod = None,
refset: Optional[RefSet] = None,
) -> pypesto.Result:
"""Minimize the given objective.
):
"""Initialize for optimizations.
Parameters
----------
problem:
Problem to run ESS on.
startpoint_method:
Method for choosing starting points.
**Deprecated. Use ``problem.startpoint_method`` instead.**
refset:
The initial RefSet or ``None`` to auto-generate.
Create initial refset, start timer, ... .
"""
if startpoint_method is not None:
warn(
Expand All @@ -220,8 +206,9 @@ def minimize(
refset is not None and problem is not None
):
raise ValueError(
"Either `refset` or `problem` has to be provided."
"Exactly one of `problem` or `refset` has to be provided."
)

# generate initial RefSet if not provided
if refset is None:
if self.dim_refset is None:
Expand All @@ -230,66 +217,85 @@ def minimize(
)
# [EgeaMar2010]_ 2.1
self.n_diverse = self.n_diverse or 10 * problem.dim
if self.n_procs:
self.evaluator = FunctionEvaluatorMP(
problem=problem,
startpoint_method=startpoint_method,
n_procs=self.n_procs,
)
else:
self.evaluator = FunctionEvaluatorMT(
problem=problem,
startpoint_method=startpoint_method,
n_threads=self.n_threads or 1,
)
self.evaluator = create_function_evaluator(
problem,
startpoint_method,
n_threads=self.n_threads,
n_procs=self.n_procs,
)

self.refset = RefSet(dim=self.dim_refset, evaluator=self.evaluator)
# Initial RefSet generation
self.refset.initialize_random(n_diverse=self.n_diverse)
refset = self.refset
else:
self.refset = refset

self.evaluator = refset.evaluator
self.evaluator = self.refset.evaluator
self.x_best = np.full(
shape=(self.evaluator.problem.dim,), fill_value=np.nan
)
# initialize global best from initial refset
for x, fx in zip(refset.x, refset.fx):
for x, fx in zip(self.refset.x, self.refset.fx):
self._maybe_update_global_best(x, fx)

def minimize(
self,
problem: Problem = None,
startpoint_method: StartpointMethod = None,
refset: Optional[RefSet] = None,
) -> pypesto.Result:
"""Minimize the given objective.
Parameters
----------
problem:
Problem to run ESS on.
startpoint_method:
Method for choosing starting points.
**Deprecated. Use ``problem.startpoint_method`` instead.**
refset:
The initial RefSet or ``None`` to auto-generate.
"""
self._initialize_minimize(
problem=problem, startpoint_method=startpoint_method, refset=refset
)

# [PenasGon2017]_ Algorithm 1
while self._keep_going():
self.x_best_has_changed = False
self._do_iteration()

refset.sort()
self._report_iteration()
refset.prune_too_close()
self._report_final()
return self._create_result()

# Apply combination method to update the RefSet
x_best_children, fx_best_children = self._combine_solutions()
def _do_iteration(self):
"""Perform an ESS iteration."""
self.x_best_has_changed = False

# Go-beyond strategy to further improve the new combinations
self._go_beyond(x_best_children, fx_best_children)
self.refset.sort()
self._report_iteration()
self.refset.prune_too_close()

# Maybe perform a local search
if self.local_optimizer is not None and self._keep_going():
self._do_local_search(x_best_children, fx_best_children)
# Apply combination method to update the RefSet
x_best_children, fx_best_children = self._combine_solutions()

# Replace RefSet members by best children where an improvement
# was made. replace stuck members by random points.
for i in range(refset.dim):
if fx_best_children[i] < refset.fx[i]:
refset.update(i, x_best_children[i], fx_best_children[i])
else:
refset.n_stuck[i] += 1
if refset.n_stuck[i] > self.n_change:
refset.replace_by_random(i)
# Go-beyond strategy to further improve the new combinations
self._go_beyond(x_best_children, fx_best_children)

self.n_iter += 1
# Maybe perform a local search
if self.local_optimizer is not None and self._keep_going():
self._do_local_search(x_best_children, fx_best_children)

self._report_final()
return self._create_result()
# Replace RefSet members by best children where an improvement
# was made. replace stuck members by random points.
for i in range(self.refset.dim):
if fx_best_children[i] < self.refset.fx[i]:
self.refset.update(i, x_best_children[i], fx_best_children[i])
else:
self.refset.n_stuck[i] += 1
if self.refset.n_stuck[i] > self.n_change:
self.refset.replace_by_random(i)

self.n_iter += 1

def _create_result(self) -> pypesto.Result:
"""Create the result object.
Expand Down Expand Up @@ -515,7 +521,7 @@ def _do_local_search(

self.logger.info(
f"Local search: {local_search_fx0} -> {optimizer_result.fval} "
f" took {optimizer_result.time:.3g}s, finished with "
f"took {optimizer_result.time:.3g}s, finished with "
f"{optimizer_result.exitflag}: {optimizer_result.message}"
)
self.local_solutions.append(optimizer_result.x)
Expand Down
32 changes: 32 additions & 0 deletions pypesto/optimize/ess/function_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,35 @@ def multiple(self, xs: Sequence[np.ndarray]) -> np.array:
self.n_eval += len(xs)
self.n_eval_round += len(xs)
return res


def create_function_evaluator(
problem: Problem = None,
startpoint_method: StartpointMethod = None,
n_procs: int = None,
n_threads: int = None,
):
"""Create a FunctionEvaluator.
Based on multiprocessing or multithreading, depending on whether
``n_procs`` (number of processes) or ``n_threads`` (number of threads)
is specified. If neither is specified, a single-threaded evaluator is
returned.
"""
if n_procs and n_threads:
raise ValueError(
"Only one of `n_procs` and `n_threads` may be specified."
)

if n_procs:
return FunctionEvaluatorMP(
problem=problem,
startpoint_method=startpoint_method,
n_procs=n_procs,
)

return FunctionEvaluatorMT(
problem=problem,
startpoint_method=startpoint_method,
n_threads=n_threads or 1,
)
Loading

0 comments on commit 4a953cd

Please sign in to comment.