Feature: Adds probability of improvement as an acquisition function #458

Merged
Changes from 10 commits
6 changes: 6 additions & 0 deletions docs/examples/decision_making.py
@@ -237,6 +237,12 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
    num_initial_samples=100, num_restarts=1
)

# %% [markdown]

# It is worth noting that `ThompsonSampling` is not the only utility function we could use,
# since our module also provides e.g. `ProbabilityOfImprovement`,
# which was briefly discussed in [our previous introduction to Bayesian optimisation](https://docs.jaxgaussianprocesses.com/examples/bayesian_optimisation/).


# %% [markdown]
# ## Putting it All Together with the Decision Maker
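For concreteness, here is a minimal sketch of using `ProbabilityOfImprovement` as the docs cell above suggests. It assumes a fitted `ConjugatePosterior` named `posterior`, its training `Dataset` named `dataset`, and candidate points `test_X` — these names are illustrative, not part of the diff; the call pattern mirrors the tests added later in this PR:

    import jax.random as jr

    from gpjax.decision_making.utility_functions import ProbabilityOfImprovement
    from gpjax.decision_making.utils import OBJECTIVE

    key = jr.key(0)
    pi_builder = ProbabilityOfImprovement()
    pi_utility = pi_builder.build_utility_function(
        posteriors={OBJECTIVE: posterior},
        datasets={OBJECTIVE: dataset},
        key=key,  # unused by PI; kept for a uniform builder interface
    )
    utility_values = pi_utility(test_X)  # shape (N, 1): P[f(x) < best observed y]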
4 changes: 4 additions & 0 deletions gpjax/decision_making/utility_functions/__init__.py
@@ -18,6 +18,9 @@
    SinglePointUtilityFunction,
    UtilityFunction,
)
from gpjax.decision_making.utility_functions.probability_of_improvement import (
    ProbabilityOfImprovement,
)
from gpjax.decision_making.utility_functions.thompson_sampling import ThompsonSampling

__all__ = [
@@ -26,4 +29,5 @@
"AbstractSinglePointUtilityFunctionBuilder",
"SinglePointUtilityFunction",
"ThompsonSampling",
"ProbabilityOfImprovement",
]
110 changes: 110 additions & 0 deletions gpjax/decision_making/utility_functions/probability_of_improvement.py
@@ -0,0 +1,110 @@
# Copyright 2024 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from dataclasses import dataclass

from beartype.typing import Mapping
from jaxtyping import Num

from gpjax.dataset import Dataset
from gpjax.decision_making.utility_functions.base import (
    AbstractSinglePointUtilityFunctionBuilder,
    SinglePointUtilityFunction,
)
from gpjax.decision_making.utils import (
    OBJECTIVE,
    gaussian_cdf,
)
from gpjax.gps import ConjugatePosterior
from gpjax.typing import (
    Array,
    KeyArray,
)


@dataclass
class ProbabilityOfImprovement(AbstractSinglePointUtilityFunctionBuilder):
    r"""
    An acquisition function which returns the probability of improvement
    of the objective function over the best observed value.

    More precisely, given a predictive posterior distribution of the objective
    function $`f`$, the probability of improvement at a test point $`x`$ is defined as:
    $$`\text{PI}(x) = \text{Prob}[f(x) < f(x_{\text{best}})]`$$
    where $`x_{\text{best}}`$ is the minimizer of $`f`$ in the dataset.
Contributor: Could update to something like "where $x_{\text{best}}$ is the minimiser of the posterior mean at previously observed values", to handle noisy observations.

Contributor Author: Updated! Forgot to change this.

    The probability of improvement can be easily computed using the
    cumulative distribution function of the standard normal distribution $`\Phi`$:
    $$`\text{PI}(x) = \Phi\left(\frac{f(x_{\text{best}}) - \mu}{\sigma}\right)`$$
    where $`\mu`$ and $`\sigma`$ are the mean and standard deviation of the
    predictive distribution of the objective function at $`x`$.

    References
    ----------
    [1] Kushner, H. J. (1964).
    A new method of locating the maximum point of an arbitrary multipeak curve in the presence of noise.
    Journal of Basic Engineering, 86(1), 97-106.

    [2] Shahriari, B., Swersky, K., Wang, Z., Adams, R. P., & de Freitas, N. (2016).
    Taking the human out of the loop: A review of Bayesian optimization.
    Proceedings of the IEEE, 104(1), 148-175. doi: 10.1109/JPROC.2015.2494218
    """

    def build_utility_function(
        self,
        posteriors: Mapping[str, ConjugatePosterior],
        datasets: Mapping[str, Dataset],
        key: KeyArray,
    ) -> SinglePointUtilityFunction:
        """
        Constructs the probability of improvement utility function
        using the predictive posterior of the objective function.

        Args:
            posteriors (Mapping[str, ConjugatePosterior]): Dictionary of posteriors to
                be used to form the utility function. One of the posteriors must
                correspond to the `OBJECTIVE` key, as we use the objective posterior's
                predictive distribution to form the utility function.
            datasets (Mapping[str, Dataset]): Dictionary of datasets which may be used
                to form the utility function. Keys in `datasets` should correspond to
                keys in `posteriors`. One of the datasets must correspond
                to the `OBJECTIVE` key.
            key (KeyArray): JAX PRNG key used for random number generation. Since
                the probability of improvement is computed deterministically
                from the predictive posterior, the key is not used.

        Returns:
            SinglePointUtilityFunction: the probability of improvement utility function.
        """
        self.check_objective_present(posteriors, datasets)

        objective_posterior = posteriors[OBJECTIVE]
        if not isinstance(objective_posterior, ConjugatePosterior):
            raise ValueError(
                "Objective posterior must be a ConjugatePosterior to compute the "
                "probability of improvement."
            )

        objective_dataset = datasets[OBJECTIVE]
Contributor: Probably best to have something along the lines of

    if objective_dataset.X is None or objective_dataset.n == 0:
        raise ValueError("Objective dataset must contain at least one item")

given that we use the objective dataset to find best_y.

Contributor Author: Addressed!

        def probability_of_improvement(x_test: Num[Array, "N D"]):
            predictive_dist = objective_posterior.predict(x_test, objective_dataset)

            # Assuming that the goal is to minimize the objective function
            best_y = objective_dataset.y.min()
Contributor: It could be useful to define best_y as the minimum posterior mean value at any of the observed points, which would also handle the case where observations could be noisy.

Contributor Author: Sounds good. I've updated how we compute best_y. It is the first time I see it being computed like that, though. If I'm not mistaken, other frameworks use the min of the dataset.

Contributor: Yeah - for instance, this approach is mentioned in "A benchmark of kriging-based infill criteria for noisy optimization", section 3.2.
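A minimal sketch of the variant discussed in this thread, reusing the diff's own names (objective_posterior, objective_dataset); an editor's illustration, not the committed code:

    # Hypothetical alternative best_y: the minimum of the posterior predictive
    # mean at the observed inputs, rather than the minimum observed y value.
    # This is more robust when observations are corrupted by noise.
    predictive_at_observed = objective_posterior.predict(
        objective_dataset.X, objective_dataset
    )
    best_y = predictive_at_observed.mean().min()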


            return gaussian_cdf(
                (best_y - predictive_dist.mean()) / predictive_dist.stddev()
            ).reshape(-1, 1)

Contributor: Instead of implementing our own gaussian_cdf function, IMO it makes sense to use the existing normal distribution provided by TensorFlow Probability, i.e. normal = tfp.distributions.Normal(predictive_dist.mean(), predictive_dist.stddev()), and use the cdf defined on this.

Contributor Author: Another alternative would be to update GaussianDistribution, right? For now, I'll create a normal as you state.

        return probability_of_improvement
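For reference, a self-contained sketch of the TensorFlow Probability route suggested in the thread above — an editor's illustration with made-up numbers, not the committed code:

    import jax.numpy as jnp
    from tensorflow_probability.substrates import jax as tfp

    # Normal(mu, sigma).cdf(best_y) equals Phi((best_y - mu) / sigma),
    # the PI formula from the docstring above.
    mu, sigma, best_y = jnp.array([1.0]), jnp.array([0.5]), 0.5
    normal = tfp.distributions.Normal(mu, sigma)
    pi = normal.cdf(best_y)  # Phi((0.5 - 1.0) / 0.5) = Phi(-1.0) ≈ 0.1587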
10 changes: 10 additions & 0 deletions gpjax/decision_making/utils.py
@@ -17,6 +17,8 @@
    Dict,
    Final,
)
import jax
import jax.numpy as jnp

from gpjax.dataset import Dataset
from gpjax.typing import (
@@ -48,3 +50,11 @@ def build_function_evaluator(
        dictionary of datasets storing the evaluated points.
    """
    return lambda x: {tag: Dataset(x, f(x)) for tag, f in functions.items()}


def gaussian_cdf(x: Float[Array, " N"]) -> Float[Array, " N"]:
    """
    Compute the cumulative distribution function of the standard normal distribution at
    the points `x`.
    """
    return 0.5 * (1 + jax.scipy.special.erf(x / jnp.sqrt(2)))
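A quick sanity check of this helper against the standard normal CDF already shipped in jax.scipy — a small sketch, not part of the diff:

    import jax.numpy as jnp
    from jax.scipy.stats import norm

    x = jnp.array([-1.96, 0.0, 1.96])
    # Phi(0) = 0.5 and Phi(1.96) ≈ 0.975; gaussian_cdf should agree with norm.cdf.
    assert jnp.allclose(gaussian_cdf(x), norm.cdf(x))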
@@ -0,0 +1,59 @@
# Copyright 2023 The GPJax Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from jax import config

config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jax.random as jr

from gpjax.decision_making.test_functions.continuous_functions import Forrester
from gpjax.decision_making.utility_functions.probability_of_improvement import (
    ProbabilityOfImprovement,
)
from gpjax.decision_making.utils import (
    OBJECTIVE,
    gaussian_cdf,
)
from tests.test_decision_making.utils import generate_dummy_conjugate_posterior


def test_probability_of_improvement_gives_correct_value_for_a_seed():
    key = jr.key(42)
    forrester = Forrester()
    dataset = forrester.generate_dataset(num_points=10, key=key)
    posterior = generate_dummy_conjugate_posterior(dataset)
    posteriors = {OBJECTIVE: posterior}
    datasets = {OBJECTIVE: dataset}

    pi_utility_builder = ProbabilityOfImprovement()
    pi_utility = pi_utility_builder.build_utility_function(
        posteriors=posteriors, datasets=datasets, key=key
    )

    test_X = forrester.generate_test_points(num_points=10, key=key)
    utility_values = pi_utility(test_X)

    # Computing the expected utility values
    predictive_dist = posterior.predict(test_X, train_data=dataset)
    predictive_mean = predictive_dist.mean()
    predictive_std = predictive_dist.stddev()

    expected_utility_values = gaussian_cdf(
        (dataset.y.min() - predictive_mean) / predictive_std
    ).reshape(-1, 1)

    assert utility_values.shape == (10, 1)
    assert jnp.isclose(utility_values, expected_utility_values).all()
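Since PI values are probabilities, they must lie in the unit interval; a small additional assertion one could append to this test (an editor's suggestion, not in the diff):

    # Probabilities of improvement are bounded between 0 and 1.
    assert (utility_values >= 0.0).all() and (utility_values <= 1.0).all()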
@@ -17,99 +17,18 @@
config.update("jax_enable_x64", True)

from beartype.typing import Callable
import jax.numpy as jnp
import jax.random as jr
import pytest

from gpjax.dataset import Dataset
from gpjax.decision_making.test_functions.continuous_functions import (
    AbstractContinuousTestFunction,
    Forrester,
    LogarithmicGoldsteinPrice,
)
from gpjax.decision_making.utility_functions.thompson_sampling import ThompsonSampling
from gpjax.decision_making.utils import OBJECTIVE
from gpjax.gps import (
    ConjugatePosterior,
    NonConjugatePosterior,
    Prior,
)
from gpjax.kernels import RBF
from gpjax.likelihoods import (
    Gaussian,
    Poisson,
)
from gpjax.mean_functions import Zero
from gpjax.typing import KeyArray


def generate_dummy_conjugate_posterior(dataset: Dataset) -> ConjugatePosterior:
    kernel = RBF(lengthscale=jnp.ones(dataset.X.shape[1]))
    mean_function = Zero()
    prior = Prior(kernel=kernel, mean_function=mean_function)
    likelihood = Gaussian(num_datapoints=dataset.n)
    posterior = prior * likelihood
    return posterior


def generate_dummy_non_conjugate_posterior(dataset: Dataset) -> NonConjugatePosterior:
    kernel = RBF(lengthscale=jnp.ones(dataset.X.shape[1]))
    mean_function = Zero()
    prior = Prior(kernel=kernel, mean_function=mean_function)
    likelihood = Poisson(num_datapoints=dataset.n)
    posterior = prior * likelihood
    return posterior


@pytest.mark.filterwarnings(
    "ignore::UserWarning"
)  # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort
def test_thompson_sampling_no_objective_posterior_raises_error():
    key = jr.key(42)
    forrester = Forrester()
    dataset = forrester.generate_dataset(num_points=10, key=key)
    posterior = generate_dummy_conjugate_posterior(dataset)
    posteriors = {"CONSTRAINT": posterior}
    datasets = {OBJECTIVE: dataset}
    with pytest.raises(ValueError):
        ts_utility_builder = ThompsonSampling(num_features=100)
        ts_utility_builder.build_utility_function(
            posteriors=posteriors, datasets=datasets, key=key
        )


@pytest.mark.filterwarnings(
    "ignore::UserWarning"
)  # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort
def test_thompson_sampling_no_objective_dataset_raises_error():
    key = jr.key(42)
    forrester = Forrester()
    dataset = forrester.generate_dataset(num_points=10, key=key)
    posterior = generate_dummy_conjugate_posterior(dataset)
    posteriors = {OBJECTIVE: posterior}
    datasets = {"CONSTRAINT": dataset}
    with pytest.raises(ValueError):
        ts_utility_builder = ThompsonSampling(num_features=100)
        ts_utility_builder.build_utility_function(
            posteriors=posteriors, datasets=datasets, key=key
        )


@pytest.mark.filterwarnings(
    "ignore::UserWarning"
)  # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort
def test_thompson_sampling_non_conjugate_posterior_raises_error():
    key = jr.key(42)
    forrester = Forrester()
    dataset = forrester.generate_dataset(num_points=10, key=key)
    posterior = generate_dummy_non_conjugate_posterior(dataset)
    posteriors = {OBJECTIVE: posterior}
    datasets = {OBJECTIVE: dataset}
    with pytest.raises(ValueError):
        ts_utility_builder = ThompsonSampling(num_features=100)
        ts_utility_builder.build_utility_function(
            posteriors=posteriors, datasets=datasets, key=key
        )
from tests.test_decision_making.utils import generate_dummy_conjugate_posterior


@pytest.mark.parametrize("num_rff_features", [0, -1, -10])
@@ -130,34 +49,6 @@ def test_thompson_sampling_invalid_rff_num_raises_error(num_rff_features: int):
)


@pytest.mark.parametrize(
    "test_target_function",
    [(Forrester()), (LogarithmicGoldsteinPrice())],
)
@pytest.mark.parametrize("num_test_points", [50, 100])
@pytest.mark.parametrize("key", [jr.key(42), jr.key(10)])
@pytest.mark.filterwarnings(
    "ignore::UserWarning"
)  # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort
def test_thompson_sampling_utility_function_correct_shapes(
    test_target_function: AbstractContinuousTestFunction,
    num_test_points: int,
    key: KeyArray,
):
    dataset = test_target_function.generate_dataset(num_points=10, key=key)
    posterior = generate_dummy_conjugate_posterior(dataset)
    posteriors = {OBJECTIVE: posterior}
    datasets = {OBJECTIVE: dataset}
    ts_utility_builder = ThompsonSampling(num_features=100)
    ts_utility_function = ts_utility_builder.build_utility_function(
        posteriors=posteriors, datasets=datasets, key=key
    )
    test_key, _ = jr.split(key)
    test_X = test_target_function.generate_test_points(num_test_points, test_key)
    ts_utility_function_values = ts_utility_function(test_X)
    assert ts_utility_function_values.shape == (num_test_points, 1)


@pytest.mark.parametrize(
"test_target_function",
[(Forrester()), (LogarithmicGoldsteinPrice())],
Expand Down
Loading