Skip to content

Commit

Permalink
Types for bounds (#423)
Browse files Browse the repository at this point in the history
* add new types for bounds

* fix pyright
  • Loading branch information
jduerholt authored Jul 31, 2024
1 parent caaeda1 commit 645a943
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 25 deletions.
25 changes: 3 additions & 22 deletions bofire/data_models/features/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

import numpy as np
import pandas as pd
from pydantic import Field, field_validator, model_validator
from pydantic import Field, model_validator

from bofire.data_models.features.feature import Output, TTransform
from bofire.data_models.features.numerical import NumericalInput
from bofire.data_models.objectives.api import AnyObjective, MaximizeObjective
from bofire.data_models.types import Bounds


class ContinuousInput(NumericalInput):
Expand All @@ -23,7 +24,7 @@ class ContinuousInput(NumericalInput):
type: Literal["ContinuousInput"] = "ContinuousInput"
order_id: ClassVar[int] = 1

bounds: Tuple[float, float]
bounds: Bounds
local_relative_bounds: Optional[
Tuple[Annotated[float, Field(gt=0)], Annotated[float, Field(gt=0)]]
] = None
Expand Down Expand Up @@ -78,26 +79,6 @@ def round(self, values: pd.Series) -> pd.Series:
data=self.lower_bound + idx * self.stepsize, index=values.index
)

@field_validator("bounds")
@classmethod
def validate_lower_upper(cls, bounds):
"""Validates that the lower bound is lower than the upper bound
Args:
values (Dict): Dictionary with attributes key, lower and upper bound
Raises:
ValueError: when the lower bound is higher than the upper bound
Returns:
Dict: The attributes as dictionary
"""
if bounds[0] > bounds[1]:
raise ValueError(
f"lower bound must be <= upper bound, got {bounds[0]} > {bounds[1]}"
)
return bounds

def validate_candidental(self, values: pd.Series) -> pd.Series:
"""Method to validate the suggested candidates
Expand Down
23 changes: 22 additions & 1 deletion bofire/data_models/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Annotated, Dict, List, Union
from typing import Annotated, Dict, List, Tuple, Union

from pydantic import AfterValidator, Field, PositiveInt

Expand Down Expand Up @@ -53,6 +53,25 @@ def is_power_of_two(n):
return value


def validate_bounds(bounds: Tuple[float, float]) -> Tuple[float, float]:
"""Validate that the lower bound is less than or equal to the upper bound.
Args:
bounds: Tuple of lower and upper bounds.
Raises:
ValueError: If lower bound is greater than upper bound.
Returns:
Validated bounds.
"""
if bounds[0] > bounds[1]:
raise ValueError(
f"lower bound must be <= upper bound, got {bounds[0]} > {bounds[1]}"
)
return bounds


FeatureKeys = Annotated[
List[str], Field(min_length=2), AfterValidator(make_unique_validator("Features"))
]
Expand All @@ -65,6 +84,8 @@ def is_power_of_two(n):
List[str], Field(min_length=1), AfterValidator(make_unique_validator("Descriptors"))
]

Bounds = Annotated[Tuple[float, float], AfterValidator(validate_bounds)]

DiscreteVals = Annotated[List[float], Field(min_length=1)]

InputTransformSpecs = Dict[str, Union[CategoricalEncodingEnum, AnyMolFeatures]]
Expand Down
2 changes: 1 addition & 1 deletion bofire/surrogates/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def fit_mlp(
"""
mlp.train()
train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)
optimizer = torch.optim.Adam(mlp.parameters(), lr=lr, weight_decay=weight_decay)
optimizer = torch.optim.Adam(mlp.parameters(), lr=lr, weight_decay=weight_decay) # type: ignore
loss_function = loss_function()
for _ in range(n_epoches):
current_loss = 0.0
Expand Down
8 changes: 8 additions & 0 deletions tests/bofire/data_models/specs/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@
"stepsize": None,
},
)

specs.add_invalid(
features.ContinuousInput,
lambda: {"key": "a", "bounds": (5, 3)},
error=ValueError,
message="lower bound must be <= upper bound, got 5.0 > 3.0",
)

specs.add_valid(
features.ContinuousDescriptorInput,
lambda: {
Expand Down
15 changes: 14 additions & 1 deletion tests/bofire/data_models/test_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import pytest

from bofire.data_models.base import BaseModel
from bofire.data_models.types import CategoryVals, FeatureKeys, make_unique_validator
from bofire.data_models.types import (
CategoryVals,
FeatureKeys,
make_unique_validator,
validate_bounds,
)


def test_make_unique_validator():
Expand All @@ -23,3 +28,11 @@ class Bla(BaseModel):

with pytest.raises(ValueError, match="Categories must be unique"):
Bla(features=["a", "b"], categories=["a", "a"])


def test_validate_bounds():
with pytest.raises(
ValueError, match="lower bound must be <= upper bound, got 2.0 > 1.0"
):
validate_bounds((2.0, 1.0))
assert validate_bounds((1.0, 2.0)) == (1.0, 2.0)

0 comments on commit 645a943

Please sign in to comment.