Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jduerholt committed Oct 2, 2024
1 parent 4c2f7a8 commit d2de51b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
2 changes: 1 addition & 1 deletion bofire/utils/multiobjective.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def get_pareto_front(
df = domain.outputs.preprocess_experiments_all_valid_outputs(
experiments, output_feature_keys
)
objective = get_multiobjective_objective(outputs=outputs) # type: ignore
objective = get_multiobjective_objective(outputs=outputs, experiments=experiments) # type: ignore
pareto_mask = np.array(
is_non_dominated(
objective(
Expand Down
13 changes: 12 additions & 1 deletion tests/bofire/strategies/test_sobo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from itertools import chain

import numpy as np
import pandas as pd
import pytest
import torch
from botorch.acquisition import (
Expand Down Expand Up @@ -252,11 +253,13 @@ def f(samples, callables, weights, X):
return (samples[..., 0] + samples[..., 1]) * (samples[..., 0] * samples[..., 1])

benchmark = DTLZ2(3)
experiments = benchmark.f(benchmark.domain.inputs.sample(5), return_complete=True)
data_model = data_models.CustomSoboStrategy(
domain=benchmark.domain, acquisition_function=qNEI()
)
strategy = CustomSoboStrategy(data_model=data_model)
strategy.f = f
strategy._experiments = experiments
generic_objective, _, _ = strategy._get_objective_and_constraints()
assert isinstance(generic_objective, GenericMCObjective)

Expand All @@ -267,6 +270,8 @@ def test_custom_get_objective_invalid():
domain=benchmark.domain, acquisition_function=qNEI()
)
strategy = CustomSoboStrategy(data_model=data_model)
experiments = benchmark.f(benchmark.domain.inputs.sample(5), return_complete=True)
strategy._experiments = experiments

with pytest.raises(ValueError):
strategy._get_objective_and_constraints()
Expand All @@ -288,6 +293,8 @@ def f(samples, callables, weights, X):
use_output_constraints=False,
)
strategy1 = CustomSoboStrategy(data_model=data_model1)
experiments = benchmark.f(benchmark.domain.inputs.sample(5), return_complete=True)
strategy1._experiments = experiments
strategy1.f = f
f_str = strategy1.dumps()

Expand All @@ -298,11 +305,13 @@ def f(samples, callables, weights, X):
dump=f_str,
)
strategy2 = CustomSoboStrategy(data_model=data_model2)
strategy2._experiments = experiments

data_model3 = data_models.CustomSoboStrategy(
domain=benchmark.domain, acquisition_function=qNEI()
)
strategy3 = CustomSoboStrategy(data_model=data_model3)
strategy3._experiments = experiments
strategy3.loads(f_str)

assert isinstance(strategy2.f, type(f))
Expand Down Expand Up @@ -365,14 +374,16 @@ def test_sobo_fully_combinatorical(candidate_count):
),
],
)
def test_sobo_get_obective(outputs, expected_objective):
def test_sobo_get_objective(outputs, expected_objective):
strategy_data = data_models.SoboStrategy(
domain=Domain(
inputs=Inputs(features=[ContinuousInput(key="a", bounds=(0, 1))]),
outputs=outputs,
)
)
experiments = pd.DataFrame({"a": [0.5], "alpha": [0.5], "valid_alpha": [1]})
strategy = SoboStrategy(data_model=strategy_data)
strategy._experiments = experiments
obj, _, _ = strategy._get_objective_and_constraints()
assert isinstance(obj, expected_objective)

Expand Down

0 comments on commit d2de51b

Please sign in to comment.