Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add startpoint_method to Problem #1093

Merged
merged 2 commits into from
Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion pypesto/optimize/ess/cess.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import time
from typing import Dict, List, Optional
from warnings import warn

import numpy as np

Expand Down Expand Up @@ -107,7 +108,7 @@ def _initialize(self):
def minimize(
self,
problem: Problem,
startpoint_method: StartpointMethod,
startpoint_method: StartpointMethod = None,
) -> pypesto.Result:
"""Minimize the given objective using CESS.

Expand All @@ -117,7 +118,14 @@ def minimize(
Problem to run ESS on.
startpoint_method:
Method for choosing starting points.
**Deprecated. Use ``problem.startpoint_method`` instead.**
"""
if startpoint_method is not None:
warn(
"Passing `startpoint_method` directly is deprecated, use `problem.startpoint_method` instead.",
DeprecationWarning,
)

self._initialize()

evaluator = FunctionEvaluator(
Expand Down
18 changes: 11 additions & 7 deletions pypesto/optimize/ess/ess.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import logging
import time
from typing import List, Optional, Tuple
from warnings import warn

import numpy as np

Expand Down Expand Up @@ -183,21 +184,24 @@ def minimize(
Problem to run ESS on.
startpoint_method:
Method for choosing starting points.
**Deprecated. Use ``problem.startpoint_method`` instead.**
refset:
The initial RefSet or ``None`` to auto-generate.
"""
if startpoint_method is not None:
warn(
"Passing `startpoint_method` directly is deprecated, use `problem.startpoint_method` instead.",
DeprecationWarning,
)

self._initialize()
self.starttime = time.time()

if (
refset is None and (problem is None or startpoint_method is None)
) or (
refset is not None
and (problem is not None or startpoint_method is not None)
if (refset is None and problem is None) or (
refset is not None and problem is not None
):
raise ValueError(
"Either `refset` or `problem` and `startpoint_method` "
"has to be provided."
"Either `refset` or `problem` has to be provided."
)
# generate initial RefSet if not provided
if refset is None:
Expand Down
20 changes: 16 additions & 4 deletions pypesto/optimize/ess/function_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from typing import Optional, Sequence, Tuple
from warnings import warn

import numpy as np

Expand All @@ -31,17 +32,28 @@ class FunctionEvaluator:
def __init__(
self,
problem: Problem,
startpoint_method: StartpointMethod,
startpoint_method: StartpointMethod = None,
):
"""Construct.

Parameters
----------
problem: The problem
startpoint_method: Method for choosing feasible parameters
problem:
The problem
startpoint_method:
Method for choosing feasible parameters
**Deprecated. Use ``problem.startpoint_method`` instead.**
"""
if startpoint_method is not None:
warn(
"Passing `startpoint_method` directly is deprecated, use `problem.startpoint_method` instead.",
DeprecationWarning,
)

self.problem: Problem = problem
self.startpoint_method: StartpointMethod = startpoint_method
self.startpoint_method: StartpointMethod = (
startpoint_method or problem.startpoint_method
)
self.n_eval: int = 0
self.n_eval_round: int = 0

Expand Down
9 changes: 8 additions & 1 deletion pypesto/optimize/ess/sacess.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from multiprocessing import Manager, Process
from multiprocessing.managers import SyncManager
from typing import Any, Dict, List, Optional, Tuple
from warnings import warn

import numpy as np

Expand Down Expand Up @@ -96,9 +97,15 @@ def __init__(
def minimize(
self,
problem: Problem,
startpoint_method: StartpointMethod,
startpoint_method: StartpointMethod = None,
):
"""Solve the given optimization problem."""
if startpoint_method is not None:
warn(
"Passing `startpoint_method` directly is deprecated, use `problem.startpoint_method` instead.",
DeprecationWarning,
)

start_time = time.time()
logger.debug(
f"Running sacess with {self.num_workers} "
Expand Down
13 changes: 12 additions & 1 deletion pypesto/optimize/optimize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from typing import Callable, Iterable, Union
from warnings import warn

from ..engine import Engine, SingleCoreEngine
from ..history import HistoryOptions
Expand Down Expand Up @@ -50,6 +51,7 @@ def minimize(
startpoint_method:
Method for how to choose start points. False means the optimizer does
not require start points, e.g. for the 'PyswarmOptimizer'.
**Deprecated. Use ``problem.startpoint_method`` instead.**
result:
A result object to append the optimization results to. For example,
one might append more runs to a previous optimization. If None,
Expand Down Expand Up @@ -88,7 +90,16 @@ def minimize(

# startpoint method
if startpoint_method is None:
startpoint_method = uniform
if problem.startpoint_method is None:
startpoint_method = uniform
else:
startpoint_method = problem.startpoint_method
else:
warn(
"Passing `startpoint_method` directly is deprecated, use `problem.startpoint_method` instead.",
DeprecationWarning,
)

# convert startpoint method to class instance
startpoint_method = to_startpoint_method(startpoint_method)

Expand Down
1 change: 1 addition & 0 deletions pypesto/petab/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,7 @@ def create_problem(
x_names=x_ids,
x_scales=x_scales,
x_priors_defs=prior,
startpoint_method=self.create_startpoint_method(),
**problem_kwargs,
)

Expand Down
21 changes: 20 additions & 1 deletion pypesto/problem/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
import copy
import logging
from typing import Iterable, List, Optional, SupportsFloat, SupportsInt, Union
from typing import (
Callable,
Iterable,
List,
Optional,
SupportsFloat,
SupportsInt,
Union,
)

import numpy as np
import pandas as pd

from ..objective import ObjectiveBase
from ..objective.priors import NegLogParameterPriors
from ..startpoint import StartpointMethod, to_startpoint_method, uniform

SupportsFloatIterableOrValue = Union[Iterable[SupportsFloat], SupportsFloat]
SupportsIntIterableOrValue = Union[Iterable[SupportsInt], SupportsInt]
Expand Down Expand Up @@ -60,6 +69,9 @@ class Problem:
copy_objective:
Whethter to generate a deep copy of the objective function before
potential modification the problem class performs on it.
startpoint_method:
Method for how to choose start points. ``False`` means the optimizer
does not require start points, e.g. for the ``PyswarmOptimizer``.

Notes
-----
Expand Down Expand Up @@ -90,6 +102,7 @@ def __init__(
lb_init: Union[np.ndarray, List[float], None] = None,
ub_init: Union[np.ndarray, List[float], None] = None,
copy_objective: bool = True,
startpoint_method: Union[StartpointMethod, Callable, bool] = None,
dweindl marked this conversation as resolved.
Show resolved Hide resolved
):
if copy_objective:
objective = copy.deepcopy(objective)
Expand Down Expand Up @@ -147,6 +160,12 @@ def __init__(
self.normalize()
self._check_x_guesses()

# startpoint method
if startpoint_method is None:
startpoint_method = uniform
# convert startpoint method to class instance
self.startpoint_method = to_startpoint_method(startpoint_method)

@property
def lb(self) -> np.ndarray:
"""Return lower bounds of free parameters."""
Expand Down
13 changes: 8 additions & 5 deletions pypesto/startpoint/base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""Startpoint base classes."""
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Callable, Union
from typing import TYPE_CHECKING, Callable, Union

import numpy as np

from ..C import FVAL, GRAD
from ..objective import ObjectiveBase
from ..problem import Problem

if TYPE_CHECKING:
import pypesto


class StartpointMethod(ABC):
Expand All @@ -21,7 +24,7 @@ class StartpointMethod(ABC):
def __call__(
self,
n_starts: int,
problem: Problem,
problem: pypesto.problem.Problem,
) -> np.ndarray:
"""Generate startpoints.

Expand All @@ -42,7 +45,7 @@ class NoStartpoints(StartpointMethod):
def __call__(
self,
n_starts: int,
problem: Problem,
problem: pypesto.problem.Problem,
) -> np.ndarray:
"""Generate a (n_starts, dim) nan matrix."""
startpoints = np.full(shape=(n_starts, problem.dim), fill_value=np.nan)
Expand Down Expand Up @@ -78,7 +81,7 @@ def __init__(
def __call__(
self,
n_starts: int,
problem: Problem,
problem: pypesto.problem.Problem,
) -> np.ndarray:
"""Generate checked startpoints."""
# shape: (n_guesses, dim)
Expand Down
Loading