Add possibility to compute feature importance over the lengthscales in a SingleTaskGPSurrogate (#293)

* add lengthscale feature importance

* add test

* fix test
jduerholt authored Sep 28, 2023
1 parent c478e6a commit 50caea8
Showing 5 changed files with 323 additions and 278 deletions.
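The new `lengthscale_importance` helper reads the fitted ARD lengthscales out of a `SingleTaskGPSurrogate` and reports their inverses as importance scores: the shorter the lengthscale of an input, the more the GP varies along it. A minimal usage sketch follows; the two-input setup, the `bounds` argument, and the synthetic data are illustrative assumptions modeled on the test fixtures further down, not part of this commit.

import numpy as np
import pandas as pd

import bofire.surrogates.api as surrogates
from bofire.data_models.domain.api import Inputs, Outputs
from bofire.data_models.features.api import ContinuousInput, ContinuousOutput
from bofire.data_models.surrogates.api import SingleTaskGPSurrogate

from bofire.surrogates.feature_importance import lengthscale_importance

# Hypothetical two-input, one-output setup; the default kernel of
# SingleTaskGPSurrogate is ARD based, so one lengthscale is fitted per input.
inputs = Inputs(
    features=[ContinuousInput(key=f"x_{i}", bounds=(0, 1)) for i in (1, 2)]
)
outputs = Outputs(features=[ContinuousOutput(key="y")])
surrogate = surrogates.map(SingleTaskGPSurrogate(inputs=inputs, outputs=outputs))

# Synthetic training data: y depends strongly on x_1 and only weakly on x_2.
rng = np.random.default_rng(42)
experiments = pd.DataFrame(
    {"x_1": rng.uniform(size=20), "x_2": rng.uniform(size=20)}
)
experiments["y"] = np.sin(6 * experiments["x_1"]) + 0.1 * experiments["x_2"]
experiments["valid_y"] = 1  # validity flag column, per BoFire convention

surrogate.fit(experiments)

# One importance value per input, computed as 1 / lengthscale;
# x_1 should come out clearly more important than x_2.
importance = lengthscale_importance(surrogate=surrogate)
print(importance)

Since the importances are plain inverse lengthscales, they are only comparable across inputs when the inputs live on a common scale, which the surrogate's input scaler typically takes care of.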
72 changes: 61 additions & 11 deletions bofire/surrogates/feature_importance.py
@@ -1,15 +1,62 @@
-from typing import Dict, Sequence
+from typing import Dict, Optional, Sequence

import numpy as np
import pandas as pd

from bofire.data_models.enum import RegressionMetricsEnum
from bofire.surrogates.diagnostics import metrics
+from bofire.surrogates.single_task_gp import SingleTaskGPSurrogate
from bofire.surrogates.surrogate import Surrogate


+def lengthscale_importance(surrogate: SingleTaskGPSurrogate) -> pd.Series:
+    """Compute the lengthscale importance based on ARD.
+
+    Args:
+        surrogate (SingleTaskGPSurrogate): Surrogate from which to extract the importances.
+
+    Returns:
+        pd.Series: The importance values (inverse of the individual lengthscales).
+    """
+    try:
+        scales = surrogate.model.covar_module.base_kernel.lengthscale  # type: ignore
+    except AttributeError:
+        raise ValueError("No lengthscale based kernel found.")
+    scales = 1.0 / scales.squeeze().detach().numpy()  # type: ignore
+    if isinstance(scales, float):
+        raise ValueError("Only one lengthscale found, use `ard=True`.")
+    if len(scales) != len(surrogate.inputs):
+        raise ValueError(
+            "Number of lengthscale parameters does not match the number of inputs."
+        )
+    return pd.Series(data=scales, index=surrogate.inputs.get_keys())
+
+
+def lengthscale_importance_hook(
+    surrogate: SingleTaskGPSurrogate,
+    X_train: Optional[pd.DataFrame] = None,
+    y_train: Optional[pd.DataFrame] = None,
+    X_test: Optional[pd.DataFrame] = None,
+    y_test: Optional[pd.DataFrame] = None,
+):
+    """Hook that can be used within `cross_validate` to compute the lengthscale based feature importance per fold."""
+    return lengthscale_importance(surrogate=surrogate)
+
+
+def combine_lengthscale_importances(importances: Sequence[pd.Series]) -> pd.DataFrame:
+    """Combine the importance values from each fold into one dataframe.
+
+    Args:
+        importances (Sequence[pd.Series]): List of importance values per fold.
+
+    Returns:
+        pd.DataFrame: Dataframe with feature keys as columns, and one row per fold.
+    """
+    return pd.concat(importances, axis=1).T
+

def permutation_importance(
-    model: Surrogate,
+    surrogate: Surrogate,
X: pd.DataFrame,
y: pd.DataFrame,
n_repeats: int = 5,
@@ -28,7 +75,7 @@ def permutation_importance(
        Dict[str, pd.DataFrame]: keys are the metrics for which the model is evaluated and value is a dataframe
            with the feature keys as columns and the mean and std of the respective permutation importances as rows.
"""
-    assert len(model.outputs) == 1, "Only single output model supported so far."
+    assert len(surrogate.outputs) == 1, "Only single output model supported so far."
assert n_repeats > 1, "Number of repeats has to be larger than 1."
assert seed > 0, "Seed has to be larger than zero."

@@ -42,12 +89,13 @@
RegressionMetricsEnum.SPEARMAN: 1.0,
}

-    output_key = model.outputs[0].key
+    output_key = surrogate.outputs[0].key
rng = np.random.default_rng(seed)
prelim_results = {
-        k.name: {feature.key: [] for feature in model.inputs} for k in metrics.keys()
+        k.name: {feature.key: [] for feature in surrogate.inputs}
+        for k in metrics.keys()
}
-    pred = model.predict(X)
+    pred = surrogate.predict(X)
if len(pred) >= 2:
original_metrics = {
k.name: metrics[k](y[output_key].values, pred[output_key + "_pred"].values) # type: ignore
@@ -56,13 +104,13 @@
else:
original_metrics = {k.name: np.nan for k in metrics.keys()}

-    for feature in model.inputs:
+    for feature in surrogate.inputs:
for _ in range(n_repeats):
# shuffle
X_i = X.copy()
X_i[feature.key] = rng.permutation(X_i[feature.key].values) # type: ignore
# predict
-            pred = model.predict(X_i)
+            pred = surrogate.predict(X_i)
# compute scores
for metricenum, metric in metrics.items():
if len(pred) >= 2:
@@ -83,7 +131,7 @@
- np.mean(prelim_results[k.name][feature.key]),
np.std(prelim_results[k.name][feature.key]),
]
-                for feature in model.inputs
+                for feature in surrogate.inputs
},
index=["mean", "std"],
)
@@ -93,7 +141,7 @@


def permutation_importance_hook(
-    model: Surrogate,
+    surrogate: Surrogate,
X_train: pd.DataFrame,
y_train: pd.DataFrame,
X_test: pd.DataFrame,
@@ -125,7 +173,9 @@ def permutation_importance_hook(
else:
X = X_train
y = y_train
-    return permutation_importance(model=model, X=X, y=y, n_repeats=n_repeats, seed=seed)
+    return permutation_importance(
+        surrogate=surrogate, X=X, y=y, n_repeats=n_repeats, seed=seed
+    )


def combine_permutation_importances(
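Combined with the `cross_validate` change in `trainable.py` below, the hook makes per-fold lengthscale importances available during cross validation. A hedged sketch, reusing `surrogate` and `experiments` from the example above; it assumes `cross_validate` takes `folds` and `hooks` arguments and returns train results, test results, and a dict of per-hook outputs, as suggested by the hook loop shown below:

from bofire.surrogates.feature_importance import (
    combine_lengthscale_importances,
    lengthscale_importance_hook,
)

# Each fold refits the GP; the hook then records that fold's inverse lengthscales.
train_cv, test_cv, hook_results = surrogate.cross_validate(
    experiments,
    folds=5,
    hooks={"lengthscales": lengthscale_importance_hook},
)

# One row per fold, one column per input; the fold-to-fold spread indicates
# how stable the importance ranking is.
importances = combine_lengthscale_importances(hook_results["lengthscales"])
print(importances.mean(axis=0))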
2 changes: 1 addition & 1 deletion bofire/surrogates/trainable.py
@@ -206,7 +206,7 @@ def cross_validate(
for hookname, hook in hooks.items():
hook_results[hookname].append(
hook(
-                        model=self,  # type: ignore
+                        surrogate=self,  # type: ignore
X_train=X_train,
y_train=y_train,
X_test=X_test,
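Because `cross_validate` now invokes every hook with `surrogate=self` instead of `model=self`, the rename is breaking for user-defined hooks: their first parameter must be called `surrogate`, as the updated tests below demonstrate. A minimal sketch of a conforming custom hook (the name and body are illustrative):

import pandas as pd

from bofire.surrogates.surrogate import Surrogate


def fold_size_hook(
    surrogate: Surrogate,  # previously had to be named `model`
    X_train: pd.DataFrame,
    y_train: pd.DataFrame,
    X_test: pd.DataFrame,
    y_test: pd.DataFrame,
):
    # Illustrative body: record the training set size of each fold.
    return len(X_train)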
6 changes: 3 additions & 3 deletions tests/bofire/surrogates/test_cross_validate.py
@@ -121,13 +121,13 @@ def test_model_cross_validate_include_X(include_X, include_labcodes):


def test_model_cross_validate_hooks():
-    def hook1(model, X_train, y_train, X_test, y_test):
-        assert isinstance(model, surrogates.SingleTaskGPSurrogate)
+    def hook1(surrogate, X_train, y_train, X_test, y_test):
+        assert isinstance(surrogate, surrogates.SingleTaskGPSurrogate)
assert y_train.shape == (8, 1)
assert y_test.shape == (2, 1)
return X_train.shape

-    def hook2(model, X_train, y_train, X_test, y_test, return_test=True):
+    def hook2(surrogate, X_train, y_train, X_test, y_test, return_test=True):
if return_test:
return X_test.shape
return X_train.shape
57 changes: 52 additions & 5 deletions tests/bofire/surrogates/test_feature_importance.py
@@ -1,13 +1,18 @@
import numpy as np
import pandas as pd
import pytest

import bofire.surrogates.api as surrogates
from bofire.data_models.domain.api import Inputs, Outputs
from bofire.data_models.features.api import ContinuousInput, ContinuousOutput
+from bofire.data_models.kernels.api import RBFKernel, ScaleKernel
from bofire.data_models.surrogates.api import SingleTaskGPSurrogate
from bofire.surrogates.diagnostics import metrics
from bofire.surrogates.feature_importance import (
+    combine_lengthscale_importances,
    combine_permutation_importances,
+    lengthscale_importance,
+    lengthscale_importance_hook,
    permutation_importance,
    permutation_importance_hook,
)
@@ -35,23 +40,65 @@ def get_model_and_data():
return model, experiments


+def test_lengthscale_importance_invalid():
+    model, experiments = get_model_and_data()
+    surrogate_data = SingleTaskGPSurrogate(
+        inputs=model.inputs, outputs=model.outputs, kernel=RBFKernel()
+    )
+    surrogate = surrogates.map(surrogate_data)
+    surrogate.fit(experiments)
+    with pytest.raises(ValueError, match="No lengthscale based kernel found."):
+        lengthscale_importance(surrogate=surrogate)
+    surrogate_data = SingleTaskGPSurrogate(
+        inputs=model.inputs,
+        outputs=model.outputs,
+        kernel=ScaleKernel(base_kernel=RBFKernel(ard=False)),
+    )
+    surrogate = surrogates.map(surrogate_data)
+    surrogate.fit(experiments)
+    with pytest.raises(ValueError, match="Only one lengthscale found, use `ard=True`."):
+        lengthscale_importance(surrogate=surrogate)
+
+
+def test_lengthscale_importance():
+    surrogate, experiments = get_model_and_data()
+    surrogate.fit(experiments)
+    importance = lengthscale_importance(surrogate=surrogate)
+    assert isinstance(importance, pd.Series)
+    assert list(importance.index) == surrogate.inputs.get_keys()
+    importance = lengthscale_importance_hook(surrogate=surrogate)
+    assert isinstance(importance, pd.Series)
+    assert list(importance.index) == surrogate.inputs.get_keys()
+
+
+def test_combine_lengthscale_importances():
+    importances = [
+        pd.Series(index=["x_1", "x_2", "x_3"], data=np.random.uniform(size=3))
+        for _ in range(5)
+    ]
+    combined = combine_lengthscale_importances(importances=importances)
+    assert isinstance(combined, pd.DataFrame)
+    assert combined.shape == (5, 3)
+    assert list(combined.columns) == ["x_1", "x_2", "x_3"]


def test_permutation_importance_invalid():
model, experiments = get_model_and_data()
X = experiments[model.inputs.get_keys()]
y = experiments[["y"]]
model.fit(experiments=experiments)
with pytest.raises(AssertionError):
-        permutation_importance(model=model, X=X, y=y, n_repeats=1)
+        permutation_importance(surrogate=model, X=X, y=y, n_repeats=1)
with pytest.raises(AssertionError):
-        permutation_importance(model=model, X=X, y=y, n_repeats=2, seed=-1)
+        permutation_importance(surrogate=model, X=X, y=y, n_repeats=2, seed=-1)


def test_permutation_importance():
model, experiments = get_model_and_data()
X = experiments[model.inputs.get_keys()]
y = experiments[["y"]]
model.fit(experiments=experiments)
-    results = permutation_importance(model=model, X=X, y=y, n_repeats=5)
+    results = permutation_importance(surrogate=model, X=X, y=y, n_repeats=5)
assert isinstance(results, dict)
assert len(results) == len(metrics)
for m in metrics.keys():
@@ -66,7 +113,7 @@ def test_permutation_importance_nan():
X = experiments[model.inputs.get_keys()][:1]
y = experiments[["y"]][:1]
model.fit(experiments=experiments)
-    results = permutation_importance(model=model, X=X, y=y, n_repeats=5)
+    results = permutation_importance(surrogate=model, X=X, y=y, n_repeats=5)
assert isinstance(results, dict)
assert len(results) == len(metrics)
for m in metrics.keys():
@@ -84,7 +131,7 @@ def test_permutation_importance_hook(use_test):
y = experiments[["y"]]
model.fit(experiments=experiments)
results = permutation_importance_hook(
-        model=model, X_train=X, y_train=y, X_test=X, y_test=y, use_test=use_test
+        surrogate=model, X_train=X, y_train=y, X_test=X, y_test=y, use_test=use_test
)
assert isinstance(results, dict)
assert len(results) == len(metrics)