Linear Interpolation #443

Open · wants to merge 7 commits into main
2 changes: 2 additions & 0 deletions bofire/data_models/kernels/aggregation.py
@@ -4,6 +4,7 @@
from bofire.data_models.kernels.continuous import LinearKernel, MaternKernel, RBFKernel
from bofire.data_models.kernels.kernel import Kernel
from bofire.data_models.kernels.molecular import TanimotoKernel
from bofire.data_models.kernels.shape import WassersteinKernel
from bofire.data_models.priors.api import AnyGeneralPrior


@@ -51,6 +52,7 @@ class ScaleKernel(Kernel):
MultiplicativeKernel,
TanimotoKernel,
"ScaleKernel",
WassersteinKernel,
]
outputscale_prior: Optional[AnyGeneralPrior] = None

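On the data-model side, the new shape kernel can also be wrapped in a `ScaleKernel`, since it is now part of the allowed base kernels above. A minimal sketch (not part of the PR; field names as defined in these data models):

```python
from bofire.data_models.kernels.api import ScaleKernel, WassersteinKernel

# Learn an output scale on top of the Wasserstein-based similarity.
kernel = ScaleKernel(base_kernel=WassersteinKernel(squared=False))
```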
2 changes: 2 additions & 0 deletions bofire/data_models/kernels/api.py
@@ -19,6 +19,7 @@
)
from bofire.data_models.kernels.kernel import Kernel
from bofire.data_models.kernels.molecular import MolecularKernel, TanimotoKernel
from bofire.data_models.kernels.shape import WassersteinKernel

AbstractKernel = Union[Kernel, CategoricalKernel, ContinuousKernel, MolecularKernel]

@@ -41,4 +42,5 @@
RBFKernel,
TanimotoKernel,
InfiniteWidthBNNKernel,
WassersteinKernel,
]
10 changes: 10 additions & 0 deletions bofire/data_models/kernels/shape.py
@@ -0,0 +1,10 @@
from typing import Literal, Optional

from bofire.data_models.kernels.kernel import Kernel
from bofire.data_models.priors.api import AnyPrior


class WassersteinKernel(Kernel):
type: Literal["WassersteinKernel"] = "WassersteinKernel"
squared: bool = False
lengthscale_prior: Optional[AnyPrior] = None
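For reference, the new data model can be instantiated with a lengthscale prior like so (a sketch; the prior values mirror the default used by the surrogate further down and are otherwise arbitrary):

```python
from bofire.data_models.kernels.api import WassersteinKernel
from bofire.data_models.priors.api import LogNormalPrior

# Serializable spec of the shape kernel; squared=True selects the
# squared-distance variant.
shape_kernel = WassersteinKernel(
    squared=False,
    lengthscale_prior=LogNormalPrior(loc=1.0, scale=2.0),
)
```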
4 changes: 4 additions & 0 deletions bofire/data_models/surrogates/api.py
@@ -27,6 +27,7 @@
from bofire.data_models.surrogates.polynomial import PolynomialSurrogate
from bofire.data_models.surrogates.random_forest import RandomForestSurrogate
from bofire.data_models.surrogates.scaler import ScalerEnum
from bofire.data_models.surrogates.shape import PiecewiseLinearGPSurrogate
from bofire.data_models.surrogates.single_task_gp import (
SingleTaskGPHyperconfig,
SingleTaskGPSurrogate,
@@ -55,6 +56,7 @@
LinearDeterministicSurrogate,
MultiTaskGPSurrogate,
SingleTaskIBNNSurrogate,
PiecewiseLinearGPSurrogate,
]

AnyTrainableSurrogate = Union[
@@ -70,6 +72,7 @@
PolynomialSurrogate,
SingleTaskIBNNSurrogate,
TanimotoGPSurrogate,
PiecewiseLinearGPSurrogate,
]

AnyRegressionSurrogate = Union[
@@ -87,6 +90,7 @@
LinearDeterministicSurrogate,
MultiTaskGPSurrogate,
SingleTaskIBNNSurrogate,
PiecewiseLinearGPSurrogate,
]

AnyClassificationSurrogate = ClassificationMLPEnsemble
2 changes: 2 additions & 0 deletions bofire/data_models/surrogates/botorch_surrogates.py
@@ -20,6 +20,7 @@
from bofire.data_models.surrogates.multi_task_gp import MultiTaskGPSurrogate
from bofire.data_models.surrogates.polynomial import PolynomialSurrogate
from bofire.data_models.surrogates.random_forest import RandomForestSurrogate
from bofire.data_models.surrogates.shape import PiecewiseLinearGPSurrogate
from bofire.data_models.surrogates.single_task_gp import SingleTaskGPSurrogate
from bofire.data_models.surrogates.tanimoto_gp import TanimotoGPSurrogate
from bofire.data_models.types import InputTransformSpecs
@@ -38,6 +39,7 @@
PolynomialSurrogate,
LinearDeterministicSurrogate,
MultiTaskGPSurrogate,
PiecewiseLinearGPSurrogate,
]


80 changes: 80 additions & 0 deletions bofire/data_models/surrogates/shape.py
@@ -0,0 +1,80 @@
from typing import Annotated, List, Literal, Optional, Type, Union

from pydantic import AfterValidator, Field, PositiveInt, model_validator

# from bofire.data_models.strategies.api import FactorialStrategy
from bofire.data_models.features.api import AnyOutput, ContinuousOutput
from bofire.data_models.kernels.api import MaternKernel, RBFKernel, WassersteinKernel
from bofire.data_models.priors.api import (
BOTORCH_LENGTHCALE_PRIOR,
BOTORCH_NOISE_PRIOR,
BOTORCH_SCALE_PRIOR,
AnyPrior,
LogNormalPrior,
)
from bofire.data_models.surrogates.trainable_botorch import TrainableBotorchSurrogate
from bofire.data_models.types import Bounds, validate_monotonically_increasing


class PiecewiseLinearGPSurrogate(TrainableBotorchSurrogate):
type: Literal["PiecewiseLinearGPSurrogate"] = "PiecewiseLinearGPSurrogate"
interpolation_range: Bounds
n_interpolation_points: PositiveInt = 1000
x_keys: list[str]
y_keys: list[str]
continuous_keys: list[str]
prepend_x: Annotated[List[float], AfterValidator(validate_monotonically_increasing)]
append_x: Annotated[List[float], AfterValidator(validate_monotonically_increasing)]
prepend_y: Annotated[List[float], AfterValidator(validate_monotonically_increasing)]
append_y: Annotated[List[float], AfterValidator(validate_monotonically_increasing)]

shape_kernel: WassersteinKernel = Field(
default_factory=lambda: WassersteinKernel(
squared=False,
lengthscale_prior=LogNormalPrior(loc=1.0, scale=2.0),
)
)

continuous_kernel: Optional[Union[RBFKernel, MaternKernel]] = Field(
default_factory=lambda: RBFKernel(
lengthscale_prior=BOTORCH_LENGTHCALE_PRIOR(),
)
)

outputscale_prior: AnyPrior = Field(default_factory=lambda: BOTORCH_SCALE_PRIOR())
noise_prior: AnyPrior = Field(default_factory=lambda: BOTORCH_NOISE_PRIOR())

@model_validator(mode="after")
def validate_keys(self):
if (
sorted(self.x_keys + self.y_keys + self.continuous_keys)
!= self.inputs.get_keys()
):
raise ValueError("Feature keys do not match input keys.")
if len(self.x_keys) == 0 or len(self.y_keys) == 0:
raise ValueError(
"No features for interpolation. Please provide `x_keys` and `y_keys`."
)
if len(self.x_keys) + len(self.append_x) + len(self.prepend_x) != len(
self.y_keys
) + len(self.append_y) + len(self.prepend_y):
raise ValueError("Different number of x and y values for interpolation.")
return self

@model_validator(mode="after")
def validate_continuous_kernel(self):
if len(self.continuous_keys) == 0 and self.continuous_kernel is not None:
raise ValueError(
"Continuous kernel specified but no features for continuous kernel."
)
return self

@classmethod
def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool:
"""Abstract method to check output type for surrogate models
Args:
my_type: continuous or categorical output
Returns:
bool: True if the output type is valid for the surrogate chosen, False otherwise
"""
return isinstance(my_type, type(ContinuousOutput))
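A rough usage sketch of the new surrogate data model (all keys, bounds, and node values below are made up; they only need to satisfy the validators above, i.e. `x_keys + y_keys + continuous_keys` must cover the input keys, and the x and y node counts, including prepended and appended ones, must match):

```python
from bofire.data_models.domain.api import Inputs, Outputs
from bofire.data_models.features.api import ContinuousInput, ContinuousOutput
from bofire.data_models.surrogates.api import PiecewiseLinearGPSurrogate

surrogate_data = PiecewiseLinearGPSurrogate(
    inputs=Inputs(
        features=[
            ContinuousInput(key=k, bounds=(0, 1))
            for k in ["t", "x_1", "x_2", "y_1", "y_2"]
        ]
    ),
    outputs=Outputs(features=[ContinuousOutput(key="target")]),
    interpolation_range=(0, 1),
    n_interpolation_points=1000,
    x_keys=["x_1", "x_2"],            # x coordinates of the free nodes
    y_keys=["y_1", "y_2"],            # y coordinates of the free nodes
    continuous_keys=["t"],            # remaining inputs handled by the RBF kernel
    prepend_x=[0.0], append_x=[1.0],  # fixed boundary nodes in x ...
    prepend_y=[0.0], append_y=[1.0],  # ... and in y (4 == 4, counts match)
)
```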
23 changes: 12 additions & 11 deletions bofire/data_models/types.py
@@ -1,4 +1,4 @@
from typing import Annotated, Dict, List, Tuple, Union
from typing import Annotated, Dict, List, Sequence, Tuple, Union

from pydantic import AfterValidator, Field, PositiveInt

@@ -53,23 +53,22 @@ def is_power_of_two(n):
return value


def validate_bounds(bounds: Tuple[float, float]) -> Tuple[float, float]:
"""Validate that the lower bound is less than or equal to the upper bound.
def validate_monotonically_increasing(sequence: Sequence[float]) -> Sequence[float]:
"""Validate that the sequence is monotonically increasing.

Args:
bounds: Tuple of lower and upper bounds.
sequence: Sequence of values.

Raises:
        ValueError: If the sequence is not monotonically increasing.

Returns:
Validated bounds.
Validated sequence
"""
if bounds[0] > bounds[1]:
raise ValueError(
f"lower bound must be <= upper bound, got {bounds[0]} > {bounds[1]}"
)
return bounds
if len(sequence) > 1:
if not all(x <= y for x, y in zip(sequence, sequence[1:])):
raise ValueError("Sequence is not monotonically increasing.")
return sequence


FeatureKeys = Annotated[
@@ -84,7 +83,9 @@ def validate_bounds(bounds: Tuple[float, float]) -> Tuple[float, float]:
List[str], Field(min_length=1), AfterValidator(make_unique_validator("Descriptors"))
]

Bounds = Annotated[Tuple[float, float], AfterValidator(validate_bounds)]
Bounds = Annotated[
Tuple[float, float], AfterValidator(validate_monotonically_increasing)
]

DiscreteVals = Annotated[List[float], Field(min_length=1)]

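Since `Bounds` is now validated by the same helper, a quick behavioral sketch (values are illustrative):

```python
from bofire.data_models.types import validate_monotonically_increasing

validate_monotonically_increasing([0.0, 0.5, 1.0])  # returned unchanged
validate_monotonically_increasing((0.0, 1.0))       # lower <= upper, so Bounds still validate

try:
    validate_monotonically_increasing([1.0, 0.5])
except ValueError as err:
    print(err)  # Sequence is not monotonically increasing.
```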
19 changes: 19 additions & 0 deletions bofire/kernels/mapper.py
@@ -8,6 +8,7 @@
import bofire.data_models.kernels.api as data_models
import bofire.priors.api as priors
from bofire.kernels.fingerprint_kernels.tanimoto_kernel import TanimotoKernel
from bofire.kernels.shape import WassersteinKernel


def map_RBFKernel(
@@ -186,7 +187,25 @@ def map_HammingDistanceKernel(
)


def map_WassersteinKernel(
data_model: data_models.WassersteinKernel,
batch_shape: torch.Size,
ard_num_dims: int,
active_dims: List[int],
) -> WassersteinKernel:
return WassersteinKernel(
squared=data_model.squared,
lengthscale_prior=(
priors.map(data_model.lengthscale_prior, d=len(active_dims))
if data_model.lengthscale_prior is not None
else None
),
active_dims=active_dims,
)


KERNEL_MAP = {
data_models.WassersteinKernel: map_WassersteinKernel,
data_models.RBFKernel: map_RBFKernel,
data_models.MaternKernel: map_MaternKernel,
data_models.LinearKernel: map_LinearKernel,
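To see the mapper in action, the data model can be resolved to its gpytorch counterpart through the `KERNEL_MAP` table added above; the keyword names follow the mapper signature, while the dimensionality values are made up:

```python
import torch

import bofire.data_models.kernels.api as data_models
from bofire.kernels.mapper import KERNEL_MAP

data_model = data_models.WassersteinKernel(squared=False)
gpytorch_kernel = KERNEL_MAP[type(data_model)](
    data_model,
    batch_shape=torch.Size(),
    ard_num_dims=10,
    active_dims=list(range(10)),
)
```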
27 changes: 27 additions & 0 deletions bofire/kernels/shape.py
@@ -0,0 +1,27 @@
import torch
from gpytorch.kernels.kernel import Kernel
from torch import Tensor


class WassersteinKernel(Kernel):
has_lengthscale = True

def __init__(self, squared: bool = False, **kwargs):
super(WassersteinKernel, self).__init__(**kwargs)
self.squared = squared

def calc_wasserstein_distances(self, x1: Tensor, x2: Tensor):
return (torch.cdist(x1, x2, p=1) / x1.shape[-1]).clamp_min(1e-15)

def forward(
self,
x1: Tensor,
x2: Tensor,
diag: bool = False,
last_dim_is_batch: bool = False,
) -> Tensor:
dists = self.calc_wasserstein_distances(x1, x2)
dists = dists / self.lengthscale
if self.squared:
return torch.exp(-(dists**2))
return torch.exp(-dists)
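The kernel treats each row as a curve discretized on a shared grid and approximates the 1-Wasserstein distance between two monotone curves by the mean absolute difference of their values, which is exactly what `torch.cdist(..., p=1) / n` computes. A small smoke test (shapes and data are illustrative):

```python
import torch

from bofire.kernels.shape import WassersteinKernel

kernel = WassersteinKernel(squared=False)

# Two batches of monotone curves, each discretized on the same 1000-point grid.
x1 = torch.rand(5, 1000).sort(dim=-1).values
x2 = torch.rand(3, 1000).sort(dim=-1).values

K = kernel.forward(x1, x2)  # exp(-W1 / lengthscale), shape (5, 3)
print(K.shape)
```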
1 change: 1 addition & 0 deletions bofire/surrogates/api.py
@@ -11,6 +11,7 @@
)
from bofire.surrogates.multi_task_gp import MultiTaskGPSurrogate
from bofire.surrogates.random_forest import RandomForestSurrogate
from bofire.surrogates.shape import PiecewiseLinearGPSurrogate
from bofire.surrogates.single_task_gp import SingleTaskGPSurrogate
from bofire.surrogates.surrogate import Surrogate
from bofire.surrogates.trainable import TrainableSurrogate
2 changes: 2 additions & 0 deletions bofire/surrogates/mapper.py
@@ -9,6 +9,7 @@
from bofire.surrogates.mlp import ClassificationMLPEnsemble, RegressionMLPEnsemble
from bofire.surrogates.multi_task_gp import MultiTaskGPSurrogate
from bofire.surrogates.random_forest import RandomForestSurrogate
from bofire.surrogates.shape import PiecewiseLinearGPSurrogate
from bofire.surrogates.single_task_gp import SingleTaskGPSurrogate
from bofire.surrogates.surrogate import Surrogate
from bofire.surrogates.xgb import XGBoostSurrogate
@@ -29,6 +30,7 @@
data_models.LinearDeterministicSurrogate: LinearDeterministicSurrogate,
data_models.MultiTaskGPSurrogate: MultiTaskGPSurrogate,
data_models.SingleTaskIBNNSurrogate: SingleTaskGPSurrogate,
data_models.PiecewiseLinearGPSurrogate: PiecewiseLinearGPSurrogate,
}


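End to end, the new surrogate is then used like any other BoFire surrogate; a hedged sketch, assuming `surrogate_data` is a `PiecewiseLinearGPSurrogate` data model as constructed earlier and `experiments` / `candidates` are pandas DataFrames holding the corresponding input (and output) columns:

```python
import bofire.surrogates.api as surrogates

# Resolve the data model via SURROGATE_MAP and train it.
surrogate = surrogates.map(surrogate_data)
surrogate.fit(experiments)
predictions = surrogate.predict(candidates)
```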