diff --git a/bofire/data_models/features/continuous.py b/bofire/data_models/features/continuous.py index 5ddd019a..52a773af 100644 --- a/bofire/data_models/features/continuous.py +++ b/bofire/data_models/features/continuous.py @@ -3,11 +3,12 @@ import numpy as np import pandas as pd -from pydantic import Field, field_validator, model_validator +from pydantic import Field, model_validator from bofire.data_models.features.feature import Output, TTransform from bofire.data_models.features.numerical import NumericalInput from bofire.data_models.objectives.api import AnyObjective, MaximizeObjective +from bofire.data_models.types import Bounds class ContinuousInput(NumericalInput): @@ -23,7 +24,7 @@ class ContinuousInput(NumericalInput): type: Literal["ContinuousInput"] = "ContinuousInput" order_id: ClassVar[int] = 1 - bounds: Tuple[float, float] + bounds: Bounds local_relative_bounds: Optional[ Tuple[Annotated[float, Field(gt=0)], Annotated[float, Field(gt=0)]] ] = None @@ -78,26 +79,6 @@ def round(self, values: pd.Series) -> pd.Series: data=self.lower_bound + idx * self.stepsize, index=values.index ) - @field_validator("bounds") - @classmethod - def validate_lower_upper(cls, bounds): - """Validates that the lower bound is lower than the upper bound - - Args: - values (Dict): Dictionary with attributes key, lower and upper bound - - Raises: - ValueError: when the lower bound is higher than the upper bound - - Returns: - Dict: The attributes as dictionary - """ - if bounds[0] > bounds[1]: - raise ValueError( - f"lower bound must be <= upper bound, got {bounds[0]} > {bounds[1]}" - ) - return bounds - def validate_candidental(self, values: pd.Series) -> pd.Series: """Method to validate the suggested candidates diff --git a/bofire/data_models/types.py b/bofire/data_models/types.py index cbc33080..d4e44323 100644 --- a/bofire/data_models/types.py +++ b/bofire/data_models/types.py @@ -1,4 +1,4 @@ -from typing import Annotated, Dict, List, Union +from typing import Annotated, Dict, List, Tuple, Union from pydantic import AfterValidator, Field, PositiveInt @@ -53,6 +53,25 @@ def is_power_of_two(n): return value +def validate_bounds(bounds: Tuple[float, float]) -> Tuple[float, float]: + """Validate that the lower bound is less than or equal to the upper bound. + + Args: + bounds: Tuple of lower and upper bounds. + + Raises: + ValueError: If lower bound is greater than upper bound. + + Returns: + Validated bounds. + """ + if bounds[0] > bounds[1]: + raise ValueError( + f"lower bound must be <= upper bound, got {bounds[0]} > {bounds[1]}" + ) + return bounds + + FeatureKeys = Annotated[ List[str], Field(min_length=2), AfterValidator(make_unique_validator("Features")) ] @@ -65,6 +84,8 @@ def is_power_of_two(n): List[str], Field(min_length=1), AfterValidator(make_unique_validator("Descriptors")) ] +Bounds = Annotated[Tuple[float, float], AfterValidator(validate_bounds)] + DiscreteVals = Annotated[List[float], Field(min_length=1)] InputTransformSpecs = Dict[str, Union[CategoricalEncodingEnum, AnyMolFeatures]] diff --git a/bofire/surrogates/mlp.py b/bofire/surrogates/mlp.py index 83e9d052..e99b90ab 100644 --- a/bofire/surrogates/mlp.py +++ b/bofire/surrogates/mlp.py @@ -152,7 +152,7 @@ def fit_mlp( """ mlp.train() train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle) - optimizer = torch.optim.Adam(mlp.parameters(), lr=lr, weight_decay=weight_decay) + optimizer = torch.optim.Adam(mlp.parameters(), lr=lr, weight_decay=weight_decay) # type: ignore loss_function = loss_function() for _ in range(n_epoches): current_loss = 0.0 diff --git a/tests/bofire/data_models/specs/features.py b/tests/bofire/data_models/specs/features.py index 6889b09c..683e776c 100644 --- a/tests/bofire/data_models/specs/features.py +++ b/tests/bofire/data_models/specs/features.py @@ -45,6 +45,14 @@ "stepsize": None, }, ) + +specs.add_invalid( + features.ContinuousInput, + lambda: {"key": "a", "bounds": (5, 3)}, + error=ValueError, + message="lower bound must be <= upper bound, got 5.0 > 3.0", +) + specs.add_valid( features.ContinuousDescriptorInput, lambda: { diff --git a/tests/bofire/data_models/test_types.py b/tests/bofire/data_models/test_types.py index 501cb627..97eec5fd 100644 --- a/tests/bofire/data_models/test_types.py +++ b/tests/bofire/data_models/test_types.py @@ -1,7 +1,12 @@ import pytest from bofire.data_models.base import BaseModel -from bofire.data_models.types import CategoryVals, FeatureKeys, make_unique_validator +from bofire.data_models.types import ( + CategoryVals, + FeatureKeys, + make_unique_validator, + validate_bounds, +) def test_make_unique_validator(): @@ -23,3 +28,11 @@ class Bla(BaseModel): with pytest.raises(ValueError, match="Categories must be unique"): Bla(features=["a", "b"], categories=["a", "a"]) + + +def test_validate_bounds(): + with pytest.raises( + ValueError, match="lower bound must be <= upper bound, got 2.0 > 1.0" + ): + validate_bounds((2.0, 1.0)) + assert validate_bounds((1.0, 2.0)) == (1.0, 2.0)