Skip to content

Commit

Permalink
Add startpoint_method to Problem
Browse files Browse the repository at this point in the history
It's not intuitive that `x_guesses` is part of `Problem`, but `startpoint_method`s are handled separately. See also discussion in #1017.

This patch
* adds `startpoint_method` to `Problem`
* during PEtab import, sets `Problem.startpoint_method` based on the PEtab problem

To be discussed: Do we want to keep the existing `startpoint_method` argument in the long term or should it be removed (deprecated)?
  • Loading branch information
dweindl committed Jul 11, 2023
1 parent d6723dd commit 2ee7726
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 19 deletions.
5 changes: 3 additions & 2 deletions pypesto/optimize/ess/cess.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ def _initialize(self):
def minimize(
self,
problem: Problem,
startpoint_method: StartpointMethod,
# TODO: deprecate?
startpoint_method: StartpointMethod = None,
) -> pypesto.Result:
"""Minimize the given objective using CESS.
Expand All @@ -122,7 +123,7 @@ def minimize(

evaluator = FunctionEvaluator(
problem=problem,
startpoint_method=startpoint_method,
startpoint_method=startpoint_method or problem.startpoint_method,
)

refsets = [
Expand Down
14 changes: 5 additions & 9 deletions pypesto/optimize/ess/ess.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,15 +189,11 @@ def minimize(
self._initialize()
self.starttime = time.time()

if (
refset is None and (problem is None or startpoint_method is None)
) or (
refset is not None
and (problem is not None or startpoint_method is not None)
if (refset is None and problem is None) or (
refset is not None and problem is not None
):
raise ValueError(
"Either `refset` or `problem` and `startpoint_method` "
"has to be provided."
"Either `refset` or `problem` has to be provided."
)
# generate initial RefSet if not provided
if refset is None:
Expand All @@ -210,13 +206,13 @@ def minimize(
if self.n_procs:
self.evaluator = FunctionEvaluatorMP(
problem=problem,
startpoint_method=startpoint_method,
startpoint_method=startpoint_method or problem.startpoint,
n_procs=self.n_procs,
)
else:
self.evaluator = FunctionEvaluatorMT(
problem=problem,
startpoint_method=startpoint_method,
startpoint_method=startpoint_method or problem.startpoint,
n_threads=self.n_threads or 1,
)

Expand Down
7 changes: 5 additions & 2 deletions pypesto/optimize/ess/function_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ class FunctionEvaluator:
def __init__(
self,
problem: Problem,
startpoint_method: StartpointMethod,
# TODO deprecate?
startpoint_method: StartpointMethod = None,
):
"""Construct.
Expand All @@ -41,7 +42,9 @@ def __init__(
startpoint_method: Method for choosing feasible parameters
"""
self.problem: Problem = problem
self.startpoint_method: StartpointMethod = startpoint_method
self.startpoint_method: StartpointMethod = (
startpoint_method or problem.startpoint_method
)
self.n_eval: int = 0
self.n_eval_round: int = 0

Expand Down
5 changes: 3 additions & 2 deletions pypesto/optimize/ess/sacess.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def __init__(
def minimize(
self,
problem: Problem,
startpoint_method: StartpointMethod,
# TODO: deprecate?
startpoint_method: StartpointMethod = None,
):
"""Solve the given optimization problem."""
start_time = time.time()
Expand Down Expand Up @@ -135,7 +136,7 @@ def minimize(
args=(
worker,
problem,
startpoint_method,
startpoint_method or problem.startpoint_method,
self.sacess_loglevel,
),
)
Expand Down
6 changes: 5 additions & 1 deletion pypesto/optimize/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,12 @@ def minimize(
n_starts = bound_n_starts_from_env(n_starts)

# startpoint method
# TODO: deprecate or support both, `problem.startpoint_method` and `startpoint_method`?
if startpoint_method is None:
startpoint_method = uniform
if problem.startpoint_method is None:
startpoint_method = uniform
else:
startpoint_method = problem.startpoint_method
# convert startpoint method to class instance
startpoint_method = to_startpoint_method(startpoint_method)

Expand Down
1 change: 1 addition & 0 deletions pypesto/petab/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,7 @@ def create_problem(
x_names=x_ids,
x_scales=x_scales,
x_priors_defs=prior,
startpoint_method=self.create_startpoint_method(),
**problem_kwargs,
)

Expand Down
18 changes: 17 additions & 1 deletion pypesto/problem/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
import copy
import logging
from typing import Iterable, List, Optional, SupportsFloat, SupportsInt, Union
from typing import (
Callable,
Iterable,
List,
Optional,
SupportsFloat,
SupportsInt,
Union,
)

import numpy as np
import pandas as pd

from ..objective import ObjectiveBase
from ..objective.priors import NegLogParameterPriors
from ..startpoint import StartpointMethod, to_startpoint_method, uniform

SupportsFloatIterableOrValue = Union[Iterable[SupportsFloat], SupportsFloat]
SupportsIntIterableOrValue = Union[Iterable[SupportsInt], SupportsInt]
Expand Down Expand Up @@ -90,6 +99,7 @@ def __init__(
lb_init: Union[np.ndarray, List[float], None] = None,
ub_init: Union[np.ndarray, List[float], None] = None,
copy_objective: bool = True,
startpoint_method: Union[StartpointMethod, Callable, bool] = None,
):
if copy_objective:
objective = copy.deepcopy(objective)
Expand Down Expand Up @@ -147,6 +157,12 @@ def __init__(
self.normalize()
self._check_x_guesses()

# startpoint method
if startpoint_method is None:
startpoint_method = uniform
# convert startpoint method to class instance
self.startpoint_method = to_startpoint_method(startpoint_method)

@property
def lb(self) -> np.ndarray:
"""Return lower bounds of free parameters."""
Expand Down
7 changes: 5 additions & 2 deletions pypesto/startpoint/base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""Startpoint base classes."""
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Callable, Union
from typing import TYPE_CHECKING, Callable, Union

import numpy as np

from ..C import FVAL, GRAD
from ..objective import ObjectiveBase
from ..problem import Problem

if TYPE_CHECKING:
from ..problem import Problem


class StartpointMethod(ABC):
Expand Down

0 comments on commit 2ee7726

Please sign in to comment.