Skip to content

Commit

Permalink
Infinite Width BNN Kernel and Surrogate (#405)
Browse files Browse the repository at this point in the history
add bnn kernel and surrogate
  • Loading branch information
jduerholt authored Jun 27, 2024
1 parent 6ca1586 commit 95e44ff
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 5 deletions.
7 changes: 3 additions & 4 deletions bofire/data_models/kernels/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from bofire.data_models.kernels.continuous import (
ContinuousKernel,
InfiniteWidthBNNKernel,
LinearKernel,
MaternKernel,
PolynomialKernel,
Expand All @@ -22,10 +23,7 @@
AbstractKernel = Union[Kernel, CategoricalKernel, ContinuousKernel, MolecularKernel]

AnyContinuousKernel = Union[
MaternKernel,
LinearKernel,
PolynomialKernel,
RBFKernel,
MaternKernel, LinearKernel, PolynomialKernel, RBFKernel, InfiniteWidthBNNKernel
]

AnyCategoricalKernel = HammingDistanceKernel
Expand All @@ -42,4 +40,5 @@
MaternKernel,
RBFKernel,
TanimotoKernel,
InfiniteWidthBNNKernel,
]
7 changes: 6 additions & 1 deletion bofire/data_models/kernels/continuous.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Literal, Optional

from pydantic import field_validator
from pydantic import PositiveInt, field_validator

from bofire.data_models.kernels.kernel import Kernel
from bofire.data_models.priors.api import AnyGeneralPrior, AnyPrior
Expand Down Expand Up @@ -38,3 +38,8 @@ class PolynomialKernel(ContinuousKernel):
type: Literal["PolynomialKernel"] = "PolynomialKernel"
offset_prior: Optional[AnyGeneralPrior] = None
power: int = 2


class InfiniteWidthBNNKernel(Kernel):
type: Literal["InfiniteWidthBNNKernel"] = "InfiniteWidthBNNKernel"
depth: PositiveInt = 3
4 changes: 4 additions & 0 deletions bofire/data_models/surrogates/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Union

from bofire.data_models.surrogates.bnn import SingleTaskIBNNSurrogate
from bofire.data_models.surrogates.botorch import BotorchSurrogate
from bofire.data_models.surrogates.botorch_surrogates import (
AnyBotorchSurrogate,
Expand Down Expand Up @@ -53,6 +54,7 @@
TanimotoGPSurrogate,
LinearDeterministicSurrogate,
MultiTaskGPSurrogate,
SingleTaskIBNNSurrogate,
]

AnyTrainableSurrogate = Union[
Expand All @@ -66,6 +68,7 @@
XGBoostSurrogate,
LinearSurrogate,
PolynomialSurrogate,
SingleTaskIBNNSurrogate,
TanimotoGPSurrogate,
]

Expand All @@ -83,6 +86,7 @@
TanimotoGPSurrogate,
LinearDeterministicSurrogate,
MultiTaskGPSurrogate,
SingleTaskIBNNSurrogate,
]

AnyClassificationSurrogate = ClassificationMLPEnsemble
11 changes: 11 additions & 0 deletions bofire/data_models/surrogates/bnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import Literal, Optional

from bofire.data_models.kernels.api import InfiniteWidthBNNKernel
from bofire.data_models.surrogates.single_task_gp import SingleTaskGPSurrogate
from bofire.data_models.surrogates.trainable import Hyperconfig


class SingleTaskIBNNSurrogate(SingleTaskGPSurrogate):
type: Literal["SingleTaskIBNNSurrogate"] = "SingleTaskIBNNSurrogate"
kernel: InfiniteWidthBNNKernel = InfiniteWidthBNNKernel()
hyperconfig: Optional[Hyperconfig] = None
20 changes: 20 additions & 0 deletions bofire/kernels/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,25 @@ def map_MaternKernel(
)


def map_InfiniteWidthBNNKernel(
data_model: data_models.InfiniteWidthBNNKernel,
batch_shape: torch.Size,
ard_num_dims: int,
active_dims: List[int],
) -> "InfiniteWidthBNNKernel": # noqa: F821 # type: ignore
try:
from botorch.models.kernels import InfiniteWidthBNNKernel
except ImportError:
raise ImportError(
"Please update to botorch development version to use this feature."
)
return InfiniteWidthBNNKernel(
batch_shape=batch_shape,
active_dims=tuple(active_dims),
depth=data_model.depth,
)


def map_LinearKernel(
data_model: data_models.LinearKernel,
batch_shape: torch.Size,
Expand Down Expand Up @@ -177,6 +196,7 @@ def map_HammingDistanceKernel(
data_models.ScaleKernel: map_ScaleKernel,
data_models.TanimotoKernel: map_TanimotoKernel,
data_models.HammingDistanceKernel: map_HammingDistanceKernel,
data_models.InfiniteWidthBNNKernel: map_InfiniteWidthBNNKernel,
}


Expand Down
1 change: 1 addition & 0 deletions bofire/surrogates/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
data_models.TanimotoGPSurrogate: SingleTaskGPSurrogate,
data_models.LinearDeterministicSurrogate: LinearDeterministicSurrogate,
data_models.MultiTaskGPSurrogate: MultiTaskGPSurrogate,
data_models.SingleTaskIBNNSurrogate: SingleTaskGPSurrogate,
}


Expand Down
6 changes: 6 additions & 0 deletions tests/bofire/data_models/specs/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@
error=ValueError,
message="nu expected to be 0.5, 1.5, or 2.5",
)
specs.add_valid(
kernels.InfiniteWidthBNNKernel,
lambda: {
"depth": 3,
},
)

specs.add_valid(
kernels.RBFKernel,
Expand Down
26 changes: 26 additions & 0 deletions tests/bofire/data_models/specs/surrogates.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from bofire.data_models.kernels.api import (
HammingDistanceKernel,
InfiniteWidthBNNKernel,
MaternKernel,
ScaleKernel,
TanimotoKernel,
Expand Down Expand Up @@ -72,6 +73,31 @@
},
)

specs.add_valid(
models.SingleTaskIBNNSurrogate,
lambda: {
"inputs": Inputs(
features=[
ContinuousInput(key="a", bounds=(0, 1)),
ContinuousInput(key="b", bounds=(0, 1)),
]
).model_dump(),
"outputs": Outputs(
features=[
features.valid(ContinuousOutput).obj(),
]
).model_dump(),
"scaler": ScalerEnum.NORMALIZE,
"output_scaler": ScalerEnum.STANDARDIZE,
"noise_prior": BOTORCH_NOISE_PRIOR().model_dump(),
"hyperconfig": None,
"input_preprocessing_specs": {},
"aggregations": None,
"dump": None,
"kernel": InfiniteWidthBNNKernel(depth=3).model_dump(),
},
)

specs.add_valid(
models.MixedSingleTaskGPSurrogate,
lambda: {
Expand Down
20 changes: 20 additions & 0 deletions tests/bofire/kernels/test_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from bofire.data_models.kernels.api import (
AdditiveKernel,
HammingDistanceKernel,
InfiniteWidthBNNKernel,
LinearKernel,
MaternKernel,
MultiplicativeKernel,
Expand All @@ -20,6 +21,14 @@
from bofire.data_models.priors.api import BOTORCH_SCALE_PRIOR, GammaPrior
from tests.bofire.data_models.specs.api import Spec

try:
from botorch.models.kernels import InfiniteWidthBNNKernel as BNNKernel
except ImportError:
BNN_AVAILABLE = False
else:
BNN_AVAILABLE = True


EQUIVALENTS = {
RBFKernel: gpytorch.kernels.RBFKernel,
MaternKernel: gpytorch.kernels.MaternKernel,
Expand All @@ -34,12 +43,23 @@

def test_map(kernel_spec: Spec):
kernel = kernel_spec.cls(**kernel_spec.typed_spec())
if isinstance(kernel, InfiniteWidthBNNKernel):
return
gkernel = kernels.map(
kernel, batch_shape=torch.Size(), ard_num_dims=10, active_dims=list(range(5))
)
assert isinstance(gkernel, EQUIVALENTS[kernel.__class__])


@pytest.mark.skipif(BNN_AVAILABLE is False, reason="requires latest botorch")
def test_map_infinite_width_bnn_kernel():
kernel = InfiniteWidthBNNKernel(depth=3)
gkernel = kernels.map(
kernel, batch_shape=torch.Size(), active_dims=list(range(5)), ard_num_dims=10
)
assert isinstance(gkernel, BNNKernel)


def test_map_scale_kernel():
kernel = ScaleKernel(
base_kernel=RBFKernel(), outputscale_prior=BOTORCH_SCALE_PRIOR()
Expand Down

0 comments on commit 95e44ff

Please sign in to comment.