Skip to content

Commit

Permalink
Add startpoint_method to Problem (#1093)
Browse files Browse the repository at this point in the history
It's not intuitive that `x_guesses` is part of `Problem`, but `startpoint_method`s are handled separately. See also discussion in #1017.

This patch
* adds `startpoint_method` to `Problem`
* during PEtab import, sets `Problem.startpoint_method` based on the PEtab problem
* deprecates `startpoint_method` arguments where also `Problem` is passed

Closes #1035
  • Loading branch information
dweindl authored Jul 13, 2023
1 parent daf47dd commit 660862b
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 20 deletions.
10 changes: 9 additions & 1 deletion pypesto/optimize/ess/cess.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import time
from typing import Dict, List, Optional
from warnings import warn

import numpy as np

Expand Down Expand Up @@ -107,7 +108,7 @@ def _initialize(self):
def minimize(
self,
problem: Problem,
startpoint_method: StartpointMethod,
startpoint_method: StartpointMethod = None,
) -> pypesto.Result:
"""Minimize the given objective using CESS.
Expand All @@ -117,7 +118,14 @@ def minimize(
Problem to run ESS on.
startpoint_method:
Method for choosing starting points.
**Deprecated. Use ``problem.startpoint_method`` instead.**
"""
if startpoint_method is not None:
warn(
"Passing `startpoint_method` directly is deprecated, use `problem.startpoint_method` instead.",
DeprecationWarning,
)

self._initialize()

evaluator = FunctionEvaluator(
Expand Down
18 changes: 11 additions & 7 deletions pypesto/optimize/ess/ess.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import logging
import time
from typing import List, Optional, Tuple
from warnings import warn

import numpy as np

Expand Down Expand Up @@ -183,21 +184,24 @@ def minimize(
Problem to run ESS on.
startpoint_method:
Method for choosing starting points.
**Deprecated. Use ``problem.startpoint_method`` instead.**
refset:
The initial RefSet or ``None`` to auto-generate.
"""
if startpoint_method is not None:
warn(
"Passing `startpoint_method` directly is deprecated, use `problem.startpoint_method` instead.",
DeprecationWarning,
)

self._initialize()
self.starttime = time.time()

if (
refset is None and (problem is None or startpoint_method is None)
) or (
refset is not None
and (problem is not None or startpoint_method is not None)
if (refset is None and problem is None) or (
refset is not None and problem is not None
):
raise ValueError(
"Either `refset` or `problem` and `startpoint_method` "
"has to be provided."
"Either `refset` or `problem` has to be provided."
)
# generate initial RefSet if not provided
if refset is None:
Expand Down
20 changes: 16 additions & 4 deletions pypesto/optimize/ess/function_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from typing import Optional, Sequence, Tuple
from warnings import warn

import numpy as np

Expand All @@ -31,17 +32,28 @@ class FunctionEvaluator:
def __init__(
self,
problem: Problem,
startpoint_method: StartpointMethod,
startpoint_method: StartpointMethod = None,
):
"""Construct.
Parameters
----------
problem: The problem
startpoint_method: Method for choosing feasible parameters
problem:
The problem
startpoint_method:
Method for choosing feasible parameters
**Deprecated. Use ``problem.startpoint_method`` instead.**
"""
if startpoint_method is not None:
warn(
"Passing `startpoint_method` directly is deprecated, use `problem.startpoint_method` instead.",
DeprecationWarning,
)

self.problem: Problem = problem
self.startpoint_method: StartpointMethod = startpoint_method
self.startpoint_method: StartpointMethod = (
startpoint_method or problem.startpoint_method
)
self.n_eval: int = 0
self.n_eval_round: int = 0

Expand Down
9 changes: 8 additions & 1 deletion pypesto/optimize/ess/sacess.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from multiprocessing import Manager, Process
from multiprocessing.managers import SyncManager
from typing import Any, Dict, List, Optional, Tuple
from warnings import warn

import numpy as np

Expand Down Expand Up @@ -96,9 +97,15 @@ def __init__(
def minimize(
self,
problem: Problem,
startpoint_method: StartpointMethod,
startpoint_method: StartpointMethod = None,
):
"""Solve the given optimization problem."""
if startpoint_method is not None:
warn(
"Passing `startpoint_method` directly is deprecated, use `problem.startpoint_method` instead.",
DeprecationWarning,
)

start_time = time.time()
logger.debug(
f"Running sacess with {self.num_workers} "
Expand Down
13 changes: 12 additions & 1 deletion pypesto/optimize/optimize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from typing import Callable, Iterable, Union
from warnings import warn

from ..engine import Engine, SingleCoreEngine
from ..history import HistoryOptions
Expand Down Expand Up @@ -50,6 +51,7 @@ def minimize(
startpoint_method:
Method for how to choose start points. False means the optimizer does
not require start points, e.g. for the 'PyswarmOptimizer'.
**Deprecated. Use ``problem.startpoint_method`` instead.**
result:
A result object to append the optimization results to. For example,
one might append more runs to a previous optimization. If None,
Expand Down Expand Up @@ -88,7 +90,16 @@ def minimize(

# startpoint method
if startpoint_method is None:
startpoint_method = uniform
if problem.startpoint_method is None:
startpoint_method = uniform
else:
startpoint_method = problem.startpoint_method
else:
warn(
"Passing `startpoint_method` directly is deprecated, use `problem.startpoint_method` instead.",
DeprecationWarning,
)

# convert startpoint method to class instance
startpoint_method = to_startpoint_method(startpoint_method)

Expand Down
1 change: 1 addition & 0 deletions pypesto/petab/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,7 @@ def create_problem(
x_names=x_ids,
x_scales=x_scales,
x_priors_defs=prior,
startpoint_method=self.create_startpoint_method(),
**problem_kwargs,
)

Expand Down
21 changes: 20 additions & 1 deletion pypesto/problem/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
import copy
import logging
from typing import Iterable, List, Optional, SupportsFloat, SupportsInt, Union
from typing import (
Callable,
Iterable,
List,
Optional,
SupportsFloat,
SupportsInt,
Union,
)

import numpy as np
import pandas as pd

from ..objective import ObjectiveBase
from ..objective.priors import NegLogParameterPriors
from ..startpoint import StartpointMethod, to_startpoint_method, uniform

SupportsFloatIterableOrValue = Union[Iterable[SupportsFloat], SupportsFloat]
SupportsIntIterableOrValue = Union[Iterable[SupportsInt], SupportsInt]
Expand Down Expand Up @@ -60,6 +69,9 @@ class Problem:
copy_objective:
Whethter to generate a deep copy of the objective function before
potential modification the problem class performs on it.
startpoint_method:
Method for how to choose start points. ``False`` means the optimizer
does not require start points, e.g. for the ``PyswarmOptimizer``.
Notes
-----
Expand Down Expand Up @@ -90,6 +102,7 @@ def __init__(
lb_init: Union[np.ndarray, List[float], None] = None,
ub_init: Union[np.ndarray, List[float], None] = None,
copy_objective: bool = True,
startpoint_method: Union[StartpointMethod, Callable, bool] = None,
):
if copy_objective:
objective = copy.deepcopy(objective)
Expand Down Expand Up @@ -147,6 +160,12 @@ def __init__(
self.normalize()
self._check_x_guesses()

# startpoint method
if startpoint_method is None:
startpoint_method = uniform
# convert startpoint method to class instance
self.startpoint_method = to_startpoint_method(startpoint_method)

@property
def lb(self) -> np.ndarray:
"""Return lower bounds of free parameters."""
Expand Down
13 changes: 8 additions & 5 deletions pypesto/startpoint/base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""Startpoint base classes."""
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Callable, Union
from typing import TYPE_CHECKING, Callable, Union

import numpy as np

from ..C import FVAL, GRAD
from ..objective import ObjectiveBase
from ..problem import Problem

if TYPE_CHECKING:
import pypesto


class StartpointMethod(ABC):
Expand All @@ -21,7 +24,7 @@ class StartpointMethod(ABC):
def __call__(
self,
n_starts: int,
problem: Problem,
problem: pypesto.problem.Problem,
) -> np.ndarray:
"""Generate startpoints.
Expand All @@ -42,7 +45,7 @@ class NoStartpoints(StartpointMethod):
def __call__(
self,
n_starts: int,
problem: Problem,
problem: pypesto.problem.Problem,
) -> np.ndarray:
"""Generate a (n_starts, dim) nan matrix."""
startpoints = np.full(shape=(n_starts, problem.dim), fill_value=np.nan)
Expand Down Expand Up @@ -78,7 +81,7 @@ def __init__(
def __call__(
self,
n_starts: int,
problem: Problem,
problem: pypesto.problem.Problem,
) -> np.ndarray:
"""Generate checked startpoints."""
# shape: (n_guesses, dim)
Expand Down

0 comments on commit 660862b

Please sign in to comment.