diff --git a/docs/examples/decision_making.py b/docs/examples/decision_making.py index 1460fc25a..ee9df229f 100644 --- a/docs/examples/decision_making.py +++ b/docs/examples/decision_making.py @@ -237,6 +237,12 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]: num_initial_samples=100, num_restarts=1 ) +# %% [markdown] + +# It is worth noting that `ThompsonSampling` is not the only utility function we could use, +# since our module also provides e.g. `ProbabilityOfImprovement`, +# which was briefly discussed in [our previous introduction to Bayesian optimisation](https://docs.jaxgaussianprocesses.com/examples/bayesian_optimisation/). + # %% [markdown] # ## Putting it All Together with the Decision Maker diff --git a/gpjax/decision_making/utility_functions/__init__.py b/gpjax/decision_making/utility_functions/__init__.py index 18868949d..5186b1075 100644 --- a/gpjax/decision_making/utility_functions/__init__.py +++ b/gpjax/decision_making/utility_functions/__init__.py @@ -18,6 +18,9 @@ SinglePointUtilityFunction, UtilityFunction, ) +from gpjax.decision_making.utility_functions.probability_of_improvement import ( + ProbabilityOfImprovement, +) from gpjax.decision_making.utility_functions.thompson_sampling import ThompsonSampling __all__ = [ @@ -26,4 +29,5 @@ "AbstractSinglePointUtilityFunctionBuilder", "SinglePointUtilityFunction", "ThompsonSampling", + "ProbabilityOfImprovement", ] diff --git a/gpjax/decision_making/utility_functions/probability_of_improvement.py b/gpjax/decision_making/utility_functions/probability_of_improvement.py new file mode 100644 index 000000000..30aa3986f --- /dev/null +++ b/gpjax/decision_making/utility_functions/probability_of_improvement.py @@ -0,0 +1,127 @@ +# Copyright 2024 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from dataclasses import dataclass + +from beartype.typing import Mapping +from jaxtyping import Num +import tensorflow_probability.substrates.jax as tfp + +from gpjax.dataset import Dataset +from gpjax.decision_making.utility_functions.base import ( + AbstractSinglePointUtilityFunctionBuilder, + SinglePointUtilityFunction, +) +from gpjax.decision_making.utils import OBJECTIVE +from gpjax.gps import ConjugatePosterior +from gpjax.typing import ( + Array, + KeyArray, +) + + +@dataclass +class ProbabilityOfImprovement(AbstractSinglePointUtilityFunctionBuilder): + r""" + An acquisition function which returns the probability of improvement + of the objective function over the best observed value. + + More precisely, given a predictive posterior distribution of the objective + function $`f`$, the probability of improvement at a test point $`x`$ is defined as: + $$`\text{PI}(x) = \text{Prob}[f(x) < f(x_{\text{best}})]`$$ + where $`x_{\text{best}}`$ is the minimiser of the posterior mean + at previously observed values (to handle noisy observations). + + The probability of improvement can be easily computed using the + cumulative distribution function of the standard normal distribution $`\Phi`$: + $$`\text{PI}(x) = \Phi\left(\frac{f(x_{\text{best}}) - \mu}{\sigma}\right)`$$ + where $`\mu`$ and $`\sigma`$ are the mean and standard deviation of the + predictive distribution of the objective function at $`x`$. + + References + ---------- + [1] Kushner, H. J. (1964). + A new method of locating the maximum point of an arbitrary multipeak curve in the presence of noise. + Journal of Basic Engineering, 86(1), 97-106. + + [2] Shahriari, B., Swersky, K., Wang, Z., Adams, R. P., & de Freitas, N. (2016). + Taking the human out of the loop: A review of Bayesian optimization. + Proceedings of the IEEE, 104(1), 148-175. doi: 10.1109/JPROC.2015.2494218 + """ + + def build_utility_function( + self, + posteriors: Mapping[str, ConjugatePosterior], + datasets: Mapping[str, Dataset], + key: KeyArray, + ) -> SinglePointUtilityFunction: + """ + Constructs the probability of improvement utility function + using the predictive posterior of the objective function. + + Args: + posteriors (Mapping[str, AbstractPosterior]): Dictionary of posteriors to be + used to form the utility function. One of the posteriors must correspond + to the `OBJECTIVE` key, as we sample from the objective posterior to form + the utility function. + datasets (Mapping[str, Dataset]): Dictionary of datasets which may be used + to form the utility function. Keys in `datasets` should correspond to + keys in `posteriors`. One of the datasets must correspond + to the `OBJECTIVE` key. + key (KeyArray): JAX PRNG key used for random number generation. Since + the probability of improvement is computed deterministically + from the predictive posterior, the key is not used. + + Returns: + SinglePointUtilityFunction: the probability of improvement utility function. + """ + self.check_objective_present(posteriors, datasets) + + objective_posterior = posteriors[OBJECTIVE] + if not isinstance(objective_posterior, ConjugatePosterior): + raise ValueError( + "Objective posterior must be a ConjugatePosterior to compute the Probability of Improvement using a Gaussian CDF." + ) + + objective_dataset = datasets[OBJECTIVE] + if ( + objective_dataset.X is None + or objective_dataset.n == 0 + or objective_dataset.y is None + ): + raise ValueError( + "Objective dataset must be non-empty to compute the " + "Probability of Improvement (since we need a " + "`best_y` value)." + ) + + def probability_of_improvement(x_test: Num[Array, "N D"]): + # Computing the posterior mean for the training dataset + # for computing the best_y value (as the minimum + # posterior mean of the objective function) + predictive_dist_for_training = objective_posterior.predict( + objective_dataset.X, objective_dataset + ) + best_y = predictive_dist_for_training.mean().min() + + predictive_dist = objective_posterior.predict(x_test, objective_dataset) + + normal_dist = tfp.distributions.Normal( + loc=predictive_dist.mean(), + scale=predictive_dist.stddev(), + ) + + return normal_dist.cdf(best_y).reshape(-1, 1) + + return probability_of_improvement diff --git a/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py b/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py new file mode 100644 index 000000000..417f5b2f6 --- /dev/null +++ b/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py @@ -0,0 +1,64 @@ +# Copyright 2023 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from jax import config + +config.update("jax_enable_x64", True) + +import jax +import jax.numpy as jnp +import jax.random as jr + +from gpjax.decision_making.test_functions.continuous_functions import Forrester +from gpjax.decision_making.utility_functions.probability_of_improvement import ( + ProbabilityOfImprovement, +) +from gpjax.decision_making.utils import OBJECTIVE +from tests.test_decision_making.utils import generate_dummy_conjugate_posterior + + +def test_probability_of_improvement_gives_correct_value_for_a_seed(): + key = jr.key(42) + forrester = Forrester() + dataset = forrester.generate_dataset(num_points=10, key=key) + posterior = generate_dummy_conjugate_posterior(dataset) + posteriors = {OBJECTIVE: posterior} + datasets = {OBJECTIVE: dataset} + + pi_utility_builder = ProbabilityOfImprovement() + pi_utility = pi_utility_builder.build_utility_function( + posteriors=posteriors, datasets=datasets, key=key + ) + + test_X = forrester.generate_test_points(num_points=10, key=key) + utility_values = pi_utility(test_X) + + # Computing the expected utility values + predictive_dist = posterior.predict(test_X, train_data=dataset) + predictive_mean = predictive_dist.mean() + predictive_std = predictive_dist.stddev() + + # Computing best_y as the min. of the posterior predictive mean + # over the training set. + predictive_dist_for_training_data = posterior.predict(dataset.X, train_data=dataset) + best_y = predictive_dist_for_training_data.mean().min() + + # Gaussian CDF computed "by hand" + x_ = (best_y - predictive_mean) / predictive_std + expected_utility_values = 0.5 * ( + 1 + jax.scipy.special.erf(x_ / jnp.sqrt(2)) + ).reshape(-1, 1) + + assert utility_values.shape == (10, 1) + assert jnp.isclose(utility_values, expected_utility_values).all() diff --git a/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py b/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py index f2a726274..8986a4cf6 100644 --- a/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py +++ b/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py @@ -17,11 +17,9 @@ config.update("jax_enable_x64", True) from beartype.typing import Callable -import jax.numpy as jnp import jax.random as jr import pytest -from gpjax.dataset import Dataset from gpjax.decision_making.test_functions.continuous_functions import ( AbstractContinuousTestFunction, Forrester, @@ -29,87 +27,8 @@ ) from gpjax.decision_making.utility_functions.thompson_sampling import ThompsonSampling from gpjax.decision_making.utils import OBJECTIVE -from gpjax.gps import ( - ConjugatePosterior, - NonConjugatePosterior, - Prior, -) -from gpjax.kernels import RBF -from gpjax.likelihoods import ( - Gaussian, - Poisson, -) -from gpjax.mean_functions import Zero from gpjax.typing import KeyArray - - -def generate_dummy_conjugate_posterior(dataset: Dataset) -> ConjugatePosterior: - kernel = RBF(lengthscale=jnp.ones(dataset.X.shape[1])) - mean_function = Zero() - prior = Prior(kernel=kernel, mean_function=mean_function) - likelihood = Gaussian(num_datapoints=dataset.n) - posterior = prior * likelihood - return posterior - - -def generate_dummy_non_conjugate_posterior(dataset: Dataset) -> NonConjugatePosterior: - kernel = RBF(lengthscale=jnp.ones(dataset.X.shape[1])) - mean_function = Zero() - prior = Prior(kernel=kernel, mean_function=mean_function) - likelihood = Poisson(num_datapoints=dataset.n) - posterior = prior * likelihood - return posterior - - -@pytest.mark.filterwarnings( - "ignore::UserWarning" -) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort -def test_thompson_sampling_no_objective_posterior_raises_error(): - key = jr.key(42) - forrester = Forrester() - dataset = forrester.generate_dataset(num_points=10, key=key) - posterior = generate_dummy_conjugate_posterior(dataset) - posteriors = {"CONSTRAINT": posterior} - datasets = {OBJECTIVE: dataset} - with pytest.raises(ValueError): - ts_utility_builder = ThompsonSampling(num_features=100) - ts_utility_builder.build_utility_function( - posteriors=posteriors, datasets=datasets, key=key - ) - - -@pytest.mark.filterwarnings( - "ignore::UserWarning" -) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort -def test_thompson_sampling_no_objective_dataset_raises_error(): - key = jr.key(42) - forrester = Forrester() - dataset = forrester.generate_dataset(num_points=10, key=key) - posterior = generate_dummy_conjugate_posterior(dataset) - posteriors = {OBJECTIVE: posterior} - datasets = {"CONSTRAINT": dataset} - with pytest.raises(ValueError): - ts_utility_builder = ThompsonSampling(num_features=100) - ts_utility_builder.build_utility_function( - posteriors=posteriors, datasets=datasets, key=key - ) - - -@pytest.mark.filterwarnings( - "ignore::UserWarning" -) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort -def test_thompson_sampling_non_conjugate_posterior_raises_error(): - key = jr.key(42) - forrester = Forrester() - dataset = forrester.generate_dataset(num_points=10, key=key) - posterior = generate_dummy_non_conjugate_posterior(dataset) - posteriors = {OBJECTIVE: posterior} - datasets = {OBJECTIVE: dataset} - with pytest.raises(ValueError): - ts_utility_builder = ThompsonSampling(num_features=100) - ts_utility_builder.build_utility_function( - posteriors=posteriors, datasets=datasets, key=key - ) +from tests.test_decision_making.utils import generate_dummy_conjugate_posterior @pytest.mark.parametrize("num_rff_features", [0, -1, -10]) @@ -130,34 +49,6 @@ def test_thompson_sampling_invalid_rff_num_raises_error(num_rff_features: int): ) -@pytest.mark.parametrize( - "test_target_function", - [(Forrester()), (LogarithmicGoldsteinPrice())], -) -@pytest.mark.parametrize("num_test_points", [50, 100]) -@pytest.mark.parametrize("key", [jr.key(42), jr.key(10)]) -@pytest.mark.filterwarnings( - "ignore::UserWarning" -) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort -def test_thompson_sampling_utility_function_correct_shapes( - test_target_function: AbstractContinuousTestFunction, - num_test_points: int, - key: KeyArray, -): - dataset = test_target_function.generate_dataset(num_points=10, key=key) - posterior = generate_dummy_conjugate_posterior(dataset) - posteriors = {OBJECTIVE: posterior} - datasets = {OBJECTIVE: dataset} - ts_utility_builder = ThompsonSampling(num_features=100) - ts_utility_function = ts_utility_builder.build_utility_function( - posteriors=posteriors, datasets=datasets, key=key - ) - test_key, _ = jr.split(key) - test_X = test_target_function.generate_test_points(num_test_points, test_key) - ts_utility_function_values = ts_utility_function(test_X) - assert ts_utility_function_values.shape == (num_test_points, 1) - - @pytest.mark.parametrize( "test_target_function", [(Forrester()), (LogarithmicGoldsteinPrice())], diff --git a/tests/test_decision_making/test_utility_functions/test_utility_functions.py b/tests/test_decision_making/test_utility_functions/test_utility_functions.py new file mode 100644 index 000000000..658d5b88a --- /dev/null +++ b/tests/test_decision_making/test_utility_functions/test_utility_functions.py @@ -0,0 +1,158 @@ +# Copyright 2023 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from jax import config + +config.update("jax_enable_x64", True) + +from beartype.typing import Type +import jax.random as jr +import pytest + +from gpjax.decision_making.test_functions.continuous_functions import ( + AbstractContinuousTestFunction, + Forrester, + LogarithmicGoldsteinPrice, +) +from gpjax.decision_making.utility_functions.base import ( + AbstractSinglePointUtilityFunctionBuilder, +) +from gpjax.decision_making.utility_functions.probability_of_improvement import ( + ProbabilityOfImprovement, +) +from gpjax.decision_making.utility_functions.thompson_sampling import ThompsonSampling +from gpjax.decision_making.utils import OBJECTIVE +from gpjax.typing import KeyArray +from tests.test_decision_making.utils import ( + generate_dummy_conjugate_posterior, + generate_dummy_non_conjugate_posterior, +) + + +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +@pytest.mark.parametrize( + "utility_function_builder, utility_function_kwargs", + [ + (ProbabilityOfImprovement, {}), + (ThompsonSampling, {"num_features": 100}), + ], +) +def test_utility_function_no_objective_posterior_raises_error( + utility_function_builder: Type[AbstractSinglePointUtilityFunctionBuilder], + utility_function_kwargs: dict, +): + key = jr.key(42) + forrester = Forrester() + dataset = forrester.generate_dataset(num_points=10, key=key) + posterior = generate_dummy_conjugate_posterior(dataset) + posteriors = {"CONSTRAINT": posterior} + datasets = {OBJECTIVE: dataset} + with pytest.raises(ValueError): + utility_function = utility_function_builder(**utility_function_kwargs) + utility_function.build_utility_function( + posteriors=posteriors, datasets=datasets, key=key + ) + + +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +@pytest.mark.parametrize( + "utility_function_builder, utility_function_kwargs", + [ + (ProbabilityOfImprovement, {}), + (ThompsonSampling, {"num_features": 100}), + ], +) +def test_utility_function_no_objective_dataset_raises_error( + utility_function_builder: Type[AbstractSinglePointUtilityFunctionBuilder], + utility_function_kwargs: dict, +): + key = jr.key(42) + forrester = Forrester() + dataset = forrester.generate_dataset(num_points=10, key=key) + posterior = generate_dummy_conjugate_posterior(dataset) + posteriors = {OBJECTIVE: posterior} + datasets = {"CONSTRAINT": dataset} + with pytest.raises(ValueError): + utility_function = utility_function_builder(**utility_function_kwargs) + utility_function.build_utility_function( + posteriors=posteriors, datasets=datasets, key=key + ) + + +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +@pytest.mark.parametrize( + "utility_function_builder, utility_function_kwargs", + [ + (ProbabilityOfImprovement, {}), + (ThompsonSampling, {"num_features": 100}), + ], +) +def test_non_conjugate_posterior_raises_error( + utility_function_builder: Type[AbstractSinglePointUtilityFunctionBuilder], + utility_function_kwargs: dict, +): + key = jr.key(42) + forrester = Forrester() + dataset = forrester.generate_dataset(num_points=10, key=key) + posterior = generate_dummy_non_conjugate_posterior(dataset) + posteriors = {OBJECTIVE: posterior} + datasets = {OBJECTIVE: dataset} + with pytest.raises(ValueError): + utility_function = utility_function_builder(**utility_function_kwargs) + utility_function.build_utility_function( + posteriors=posteriors, datasets=datasets, key=key + ) + + +@pytest.mark.parametrize( + "utility_function_builder, utility_function_kwargs", + [ + (ProbabilityOfImprovement, {}), + (ThompsonSampling, {"num_features": 100}), + ], +) +@pytest.mark.parametrize( + "test_target_function", + [(Forrester()), (LogarithmicGoldsteinPrice())], +) +@pytest.mark.parametrize("num_test_points", [50, 100]) +@pytest.mark.parametrize("key", [jr.key(42), jr.key(10)]) +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +def test_utility_functions_have_correct_shapes( + utility_function_builder: Type[AbstractSinglePointUtilityFunctionBuilder], + utility_function_kwargs: dict, + test_target_function: AbstractContinuousTestFunction, + num_test_points: int, + key: KeyArray, +): + dataset = test_target_function.generate_dataset(num_points=10, key=key) + posterior = generate_dummy_conjugate_posterior(dataset) + posteriors = {OBJECTIVE: posterior} + datasets = {OBJECTIVE: dataset} + utility_builder = utility_function_builder(**utility_function_kwargs) + utility_function = utility_builder.build_utility_function( + posteriors=posteriors, datasets=datasets, key=key + ) + test_key, _ = jr.split(key) + test_X = test_target_function.generate_test_points(num_test_points, test_key) + ts_utility_function_values = utility_function(test_X) + assert ts_utility_function_values.shape == (num_test_points, 1) diff --git a/tests/test_decision_making/utils.py b/tests/test_decision_making/utils.py index 4ea3e9dc3..e4f083832 100644 --- a/tests/test_decision_making/utils.py +++ b/tests/test_decision_making/utils.py @@ -14,6 +14,7 @@ # ============================================================================== from beartype.typing import Mapping +import jax.numpy as jnp from gpjax.dataset import Dataset from gpjax.decision_making.test_functions import Quadratic @@ -21,7 +22,17 @@ AbstractSinglePointUtilityFunctionBuilder, SinglePointUtilityFunction, ) -from gpjax.gps import ConjugatePosterior +from gpjax.gps import ( + ConjugatePosterior, + NonConjugatePosterior, + Prior, +) +from gpjax.kernels import RBF +from gpjax.likelihoods import ( + Gaussian, + Poisson, +) +from gpjax.mean_functions import Zero from gpjax.typing import KeyArray @@ -45,3 +56,21 @@ def build_utility_function( return lambda x: -1.0 * test_function.evaluate( x ) # Utility functions are *maximised* + + +def generate_dummy_conjugate_posterior(dataset: Dataset) -> ConjugatePosterior: + kernel = RBF(lengthscale=jnp.ones(dataset.X.shape[1])) + mean_function = Zero() + prior = Prior(kernel=kernel, mean_function=mean_function) + likelihood = Gaussian(num_datapoints=dataset.n) + posterior = prior * likelihood + return posterior + + +def generate_dummy_non_conjugate_posterior(dataset: Dataset) -> NonConjugatePosterior: + kernel = RBF(lengthscale=jnp.ones(dataset.X.shape[1])) + mean_function = Zero() + prior = Prior(kernel=kernel, mean_function=mean_function) + likelihood = Poisson(num_datapoints=dataset.n) + posterior = prior * likelihood + return posterior