Skip to content

Commit

Permalink
implement behrangs suggestion
Browse files Browse the repository at this point in the history
  • Loading branch information
jduerholt committed Jul 4, 2023
1 parent ae5288a commit bc2b246
Show file tree
Hide file tree
Showing 10 changed files with 71 additions and 45 deletions.
4 changes: 2 additions & 2 deletions bofire/data_models/strategies/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from bofire.data_models.strategies.samplers.sampler import SamplerStrategy
from bofire.data_models.strategies.stepwise.conditions import ( # noqa: F401
CombiCondition,
RequiredExperimentsCondition,
NumberOfExperimentsCondition,
)
from bofire.data_models.strategies.stepwise.stepwise import ( # noqa: F401
Step,
Expand Down Expand Up @@ -69,4 +69,4 @@
RejectionSampler,
]

AnyCondition = Union[RequiredExperimentsCondition, CombiCondition]
AnyCondition = Union[NumberOfExperimentsCondition, CombiCondition]
10 changes: 5 additions & 5 deletions bofire/data_models/strategies/stepwise/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,22 @@ class SingleCondition(BaseModel):
type: str


class RequiredExperimentsCondition(SingleCondition):
type: Literal["RequiredExperiments"] = "RequiredExperiments"
n_required_experiments: Annotated[int, Field(ge=0)]
class NumberOfExperimentsCondition(SingleCondition):
type: Literal["NumberOfExperimentsCondition"] = "NumberOfExperimentsCondition"
n_experiments: Annotated[int, Field(ge=1)]


class CombiCondition(Condition):
type: Literal["CombiCondition"] = "CombiCondition"
conditions: Annotated[
List[Union[RequiredExperimentsCondition, "CombiCondition"]], Field(min_items=2)
List[Union[NumberOfExperimentsCondition, "CombiCondition"]], Field(min_items=2)
]
n_required_conditions: Annotated[int, Field(ge=0)]

@validator("n_required_conditions")
def validate_n_required_conditions(cls, v, values):
if v > len(values["conditions"]):
raise ValueError(
"Number of required conditions largen than number of conditions."
"Number of required conditions larger than number of conditions."
)
return v
4 changes: 2 additions & 2 deletions bofire/data_models/strategies/stepwise/stepwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from bofire.data_models.strategies.samplers.rejection import RejectionSampler
from bofire.data_models.strategies.stepwise.conditions import (
CombiCondition,
RequiredExperimentsCondition,
NumberOfExperimentsCondition,
)
from bofire.data_models.strategies.strategy import Strategy

Expand All @@ -37,7 +37,7 @@
DoEStrategy,
]

AnyCondition = Union[RequiredExperimentsCondition, CombiCondition]
AnyCondition = Union[NumberOfExperimentsCondition, CombiCondition]


class Step(BaseModel):
Expand Down
1 change: 1 addition & 0 deletions bofire/strategies/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@
from bofire.strategies.samplers.polytope import PolytopeSampler # noqa: F401
from bofire.strategies.samplers.rejection import RejectionSampler # noqa: F401
from bofire.strategies.samplers.sampler import SamplerStrategy # noqa: F401
from bofire.strategies.stepwise.stepwise import StepwiseStrategy # noqa: F401
from bofire.strategies.strategy import Strategy # noqa: F401
2 changes: 1 addition & 1 deletion bofire/strategies/samplers/polytope.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _ask(self, n: int) -> pd.DataFrame:
return self.domain.inputs.sample(n, self.fallback_sampling_method)

# check if we have pseudo fixed features in the linear equality constraints
# a pseude fixed is a linear euquality constraint with only one feature included
# a pseudo fixed is a linear euquality constraint with only one feature included
# this can happen when fixing features when sampling with NChooseK constraints
eqs = get_linear_constraints(
domain=self.domain,
Expand Down
14 changes: 7 additions & 7 deletions bofire/strategies/stepwise/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ def evaluate(self, domain: Domain, experiments: pd.DataFrame) -> bool:
pass


class RequiredExperimentsCondition(Condition):
def __init__(self, data_model: data_models.RequiredExperimentsCondition):
self.n_required_experiments = data_model.n_required_experiments
class NumberOfExperimentsCondition(Condition):
def __init__(self, data_model: data_models.NumberOfExperimentsCondition):
self.n_experiments = data_model.n_experiments

def evaluate(self, domain: Domain, experiments: pd.DataFrame) -> bool:
n_experiments = len(
domain.outputs.preprocess_experiments_all_valid_outputs(experiments)
)
return n_experiments >= self.n_required_experiments
return n_experiments <= self.n_experiments


class CombiCondition(Condition):
Expand All @@ -41,13 +41,13 @@ def evaluate(self, domain: Domain, experiments: pd.DataFrame) -> bool:

CONDITION_MAP = {
data_models.CombiCondition: CombiCondition,
data_models.RequiredExperimentsCondition: RequiredExperimentsCondition,
data_models.NumberOfExperimentsCondition: NumberOfExperimentsCondition,
}


def map(
data_model: Union[
data_models.CombiCondition, data_models.RequiredExperimentsCondition
data_models.CombiCondition, data_models.NumberOfExperimentsCondition
],
) -> Union[CombiCondition, RequiredExperimentsCondition]:
) -> Union[CombiCondition, NumberOfExperimentsCondition]:
return CONDITION_MAP[data_model.__class__](data_model)
3 changes: 1 addition & 2 deletions bofire/strategies/stepwise/stepwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ def has_sufficient_experiments(self) -> bool:
return True

def _get_step(self) -> Tuple[int, Step]: # type: ignore
for i in range(len(self.steps) - 1, -1, -1):
step = self.steps[i]
for i, step in enumerate(self.steps):
condition = conditions.map(step.condition)
if condition.evaluate(self.domain, experiments=self.experiments):
return i, step
Expand Down
8 changes: 2 additions & 6 deletions tests/bofire/data_models/specs/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,12 @@
"steps": [
strategies.Step(
strategy_data=strategies.RandomStrategy(domain=tempdomain),
condition=strategies.RequiredExperimentsCondition(
n_required_experiments=0
),
condition=strategies.NumberOfExperimentsCondition(n_experiments=10),
max_parallelism=2,
).dict(),
strategies.Step(
strategy_data=strategies.QehviStrategy(domain=tempdomain),
condition=strategies.RequiredExperimentsCondition(
n_required_experiments=10
),
condition=strategies.NumberOfExperimentsCondition(n_experiments=30),
max_parallelism=2,
).dict(),
],
Expand Down
24 changes: 14 additions & 10 deletions tests/bofire/strategies/stepwise/test_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,39 @@
def test_RequiredExperimentsCondition():
benchmark = Himmelblau()
experiments = benchmark.f(benchmark.domain.inputs.sample(3), return_complete=True)
data_model = data_models.RequiredExperimentsCondition(n_required_experiments=3)
data_model = data_models.NumberOfExperimentsCondition(n_experiments=3)
condition = conditions.map(data_model=data_model)
assert condition.evaluate(benchmark.domain, experiments=experiments) is True
experiments = benchmark.f(benchmark.domain.inputs.sample(2), return_complete=True)
experiments = benchmark.f(benchmark.domain.inputs.sample(10), return_complete=True)
assert condition.evaluate(benchmark.domain, experiments=experiments) is False


def test_CombiCondition_invalid():
with pytest.raises(
ValueError,
match="Number of required conditions largen than number of conditions.",
match="Number of required conditions larger than number of conditions.",
):
data_models.CombiCondition(
conditions=[
data_models.RequiredExperimentsCondition(n_required_experiments=2),
data_models.RequiredExperimentsCondition(n_required_experiments=3),
data_models.NumberOfExperimentsCondition(n_experiments=2),
data_models.NumberOfExperimentsCondition(n_experiments=3),
],
n_required_conditions=3,
)


@pytest.mark.parametrize("n_required, expected", [(1, True), (2, False)])
def test_CombiCondition(n_required, expected):
@pytest.mark.parametrize(
"n_required, n_experiments, expected", [(1, 10, True), (2, 1, True)]
)
def test_CombiCondition(n_required, n_experiments, expected):
benchmark = Himmelblau()
experiments = benchmark.f(benchmark.domain.inputs.sample(10), return_complete=True)
experiments = benchmark.f(
benchmark.domain.inputs.sample(n_experiments), return_complete=True
)
data_model = data_models.CombiCondition(
conditions=[
data_models.RequiredExperimentsCondition(n_required_experiments=2),
data_models.RequiredExperimentsCondition(n_required_experiments=12),
data_models.NumberOfExperimentsCondition(n_experiments=2),
data_models.NumberOfExperimentsCondition(n_experiments=12),
],
n_required_conditions=n_required,
)
Expand Down
46 changes: 36 additions & 10 deletions tests/bofire/strategies/stepwise/test_stepwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from bofire.benchmarks.single import Himmelblau
from bofire.data_models.acquisition_functions.api import qNEI
from bofire.data_models.strategies.api import (
NumberOfExperimentsCondition,
RandomStrategy,
RequiredExperimentsCondition,
SoboStrategy,
Step,
StepwiseStrategy,
Expand All @@ -20,22 +20,22 @@ def test_StepwiseStrategy_invalid():
domain2.inputs[0].key = "mama"
with pytest.raises(
ValueError,
# match="Domain of step 0 is incompatible to domain of StepwiseStrategy.",
match="Domain of step 0 is incompatible to domain of StepwiseStrategy.",
# match=f"Domain of step {0} is incompatible to domain of StepwiseStrategy.",
):
StepwiseStrategy(
domain=benchmark.domain,
steps=[
Step(
strategy_data=RandomStrategy(domain=domain2),
condition=RequiredExperimentsCondition(n_required_experiments=0),
condition=NumberOfExperimentsCondition(n_experiments=5),
max_parallelism=-1,
),
Step(
strategy_data=SoboStrategy(
domain=benchmark.domain, acquisition_function=qNEI()
),
condition=RequiredExperimentsCondition(n_required_experiments=0),
condition=NumberOfExperimentsCondition(n_experiments=15),
max_parallelism=2,
),
],
Expand All @@ -56,14 +56,14 @@ def test_StepWiseStrategy_get_step(n_experiments, expected_strategy, expected_in
steps=[
Step(
strategy_data=RandomStrategy(domain=benchmark.domain),
condition=RequiredExperimentsCondition(n_required_experiments=0),
condition=NumberOfExperimentsCondition(n_experiments=6),
max_parallelism=-1,
),
Step(
strategy_data=SoboStrategy(
domain=benchmark.domain, acquisition_function=qNEI()
),
condition=RequiredExperimentsCondition(n_required_experiments=10),
condition=NumberOfExperimentsCondition(n_experiments=10),
max_parallelism=2,
),
],
Expand All @@ -75,21 +75,47 @@ def test_StepWiseStrategy_get_step(n_experiments, expected_strategy, expected_in
assert i == expected_index


def test_StepWiseStrategy_get_step_invalid():
benchmark = Himmelblau()
experiments = benchmark.f(benchmark.domain.inputs.sample(12), return_complete=True)
data_model = StepwiseStrategy(
domain=benchmark.domain,
steps=[
Step(
strategy_data=RandomStrategy(domain=benchmark.domain),
condition=NumberOfExperimentsCondition(n_experiments=6),
max_parallelism=-1,
),
Step(
strategy_data=SoboStrategy(
domain=benchmark.domain, acquisition_function=qNEI()
),
condition=NumberOfExperimentsCondition(n_experiments=10),
max_parallelism=2,
),
],
)
strategy = strategies.map(data_model)
strategy.tell(experiments)
with pytest.raises(ValueError, match="No condition could be satisfied."):
strategy._get_step()


def test_StepWiseStrategy_invalid_ask():
benchmark = Himmelblau()
data_model = StepwiseStrategy(
domain=benchmark.domain,
steps=[
Step(
strategy_data=RandomStrategy(domain=benchmark.domain),
condition=RequiredExperimentsCondition(n_required_experiments=0),
condition=NumberOfExperimentsCondition(n_experiments=8),
max_parallelism=2,
),
Step(
strategy_data=SoboStrategy(
domain=benchmark.domain, acquisition_function=qNEI()
),
condition=RequiredExperimentsCondition(n_required_experiments=10),
condition=NumberOfExperimentsCondition(n_experiments=10),
max_parallelism=2,
),
],
Expand All @@ -111,14 +137,14 @@ def test_StepWiseStrategy_ask():
steps=[
Step(
strategy_data=RandomStrategy(domain=benchmark.domain),
condition=RequiredExperimentsCondition(n_required_experiments=0),
condition=NumberOfExperimentsCondition(n_experiments=5),
max_parallelism=2,
),
Step(
strategy_data=SoboStrategy(
domain=benchmark.domain, acquisition_function=qNEI()
),
condition=RequiredExperimentsCondition(n_required_experiments=10),
condition=NumberOfExperimentsCondition(n_experiments=10),
max_parallelism=2,
),
],
Expand Down

0 comments on commit bc2b246

Please sign in to comment.