From eadb7bbe0245b7823cb9d4d3a3355f7339d184b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Gonz=C3=A1lez=20Duque?= Date: Sun, 30 Jun 2024 11:38:35 +0200 Subject: [PATCH 01/14] Adds probability of improvement, refactors common tests for utility functions --- .../probability_of_improvement.py | 84 +++++++++ gpjax/decision_making/utils.py | 11 ++ .../test_thompson_sampling.py | 113 +------------ .../test_utility_functions.py | 159 ++++++++++++++++++ tests/test_decision_making/utils.py | 28 ++- 5 files changed, 285 insertions(+), 110 deletions(-) create mode 100644 gpjax/decision_making/utility_functions/probability_of_improvement.py create mode 100644 tests/test_decision_making/test_utility_functions/test_utility_functions.py diff --git a/gpjax/decision_making/utility_functions/probability_of_improvement.py b/gpjax/decision_making/utility_functions/probability_of_improvement.py new file mode 100644 index 000000000..92d644a03 --- /dev/null +++ b/gpjax/decision_making/utility_functions/probability_of_improvement.py @@ -0,0 +1,84 @@ +# Copyright 2024 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from dataclasses import dataclass + +from beartype.typing import Mapping +from jaxtyping import Num + +from gpjax.dataset import Dataset +from gpjax.decision_making.utility_functions.base import ( + AbstractSinglePointUtilityFunctionBuilder, + SinglePointUtilityFunction, +) +from gpjax.decision_making.utils import OBJECTIVE, gaussian_cdf +from gpjax.gps import ConjugatePosterior +from gpjax.typing import KeyArray, Array + + +@dataclass +class ProbabilityOfImprovement(AbstractSinglePointUtilityFunctionBuilder): + """ + TODO: write. + """ + + def build_utility_function( + self, + posteriors: Mapping[str, ConjugatePosterior], + datasets: Mapping[str, Dataset], + key: KeyArray, + ) -> SinglePointUtilityFunction: + """ + Draw an approximate sample from the posterior of the objective model and return + the *negative* of this sample as a utility function, as utility functions + are *maximised*. + + Args: + posteriors (Mapping[str, AbstractPosterior]): Dictionary of posteriors to be + used to form the utility function. One of the posteriors must correspond + to the `OBJECTIVE` key, as we sample from the objective posterior to form + the utility function. + datasets (Mapping[str, Dataset]): Dictionary of datasets which may be used + to form the utility function. Keys in `datasets` should correspond to + keys in `posteriors`. One of the datasets must correspond + to the `OBJECTIVE` key. + key (KeyArray): JAX PRNG key used for random number generation. This can be + changed to draw different samples. + + Returns: + SinglePointUtilityFunction: An appproximate sample from the objective model + posterior to to be *maximised* in order to decide which point to query + next. 
+ """ + self.check_objective_present(posteriors, datasets) + + objective_posterior = posteriors[OBJECTIVE] + if not isinstance(objective_posterior, ConjugatePosterior): + raise ValueError( + "Objective posterior must be a ConjugatePosterior to draw an approximate sample." + ) + + objective_dataset = datasets[OBJECTIVE] + + def probability_of_improvement(x_test: Num[Array, "N D"]): + predictive_dist = objective_posterior.predict(x_test, objective_dataset) + + # Assuming that the goal is to minimize the objective function + best_y = objective_dataset.y.min() + + return gaussian_cdf( + (best_y - predictive_dist.mean()) / predictive_dist.stddev() + ).reshape(-1, 1) + + return probability_of_improvement # Utility functions are *maximised* diff --git a/gpjax/decision_making/utils.py b/gpjax/decision_making/utils.py index 9af19de32..bc8d3c332 100644 --- a/gpjax/decision_making/utils.py +++ b/gpjax/decision_making/utils.py @@ -18,6 +18,9 @@ Final, ) +import jax +import jax.numpy as jnp + from gpjax.dataset import Dataset from gpjax.typing import ( Array, @@ -48,3 +51,11 @@ def build_function_evaluator( dictionary of datasets storing the evaluated points. """ return lambda x: {tag: Dataset(x, f(x)) for tag, f in functions.items()} + + +def gaussian_cdf(x: Float[Array, "N"]) -> Float[Array, "N"]: + """ + Compute the cumulative distribution function of the standard normal distribution at + the points `x`. + """ + return 0.5 * (1 + jax.scipy.special.erf(x / jnp.sqrt(2))) diff --git a/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py b/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py index f2a726274..9fceb638e 100644 --- a/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py +++ b/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py @@ -17,11 +17,9 @@ config.update("jax_enable_x64", True) from beartype.typing import Callable -import jax.numpy as jnp import jax.random as jr import pytest -from gpjax.dataset import Dataset from gpjax.decision_making.test_functions.continuous_functions import ( AbstractContinuousTestFunction, Forrester, @@ -29,87 +27,12 @@ ) from gpjax.decision_making.utility_functions.thompson_sampling import ThompsonSampling from gpjax.decision_making.utils import OBJECTIVE -from gpjax.gps import ( - ConjugatePosterior, - NonConjugatePosterior, - Prior, -) -from gpjax.kernels import RBF -from gpjax.likelihoods import ( - Gaussian, - Poisson, -) -from gpjax.mean_functions import Zero from gpjax.typing import KeyArray - -def generate_dummy_conjugate_posterior(dataset: Dataset) -> ConjugatePosterior: - kernel = RBF(lengthscale=jnp.ones(dataset.X.shape[1])) - mean_function = Zero() - prior = Prior(kernel=kernel, mean_function=mean_function) - likelihood = Gaussian(num_datapoints=dataset.n) - posterior = prior * likelihood - return posterior - - -def generate_dummy_non_conjugate_posterior(dataset: Dataset) -> NonConjugatePosterior: - kernel = RBF(lengthscale=jnp.ones(dataset.X.shape[1])) - mean_function = Zero() - prior = Prior(kernel=kernel, mean_function=mean_function) - likelihood = Poisson(num_datapoints=dataset.n) - posterior = prior * likelihood - return posterior - - -@pytest.mark.filterwarnings( - "ignore::UserWarning" -) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort -def test_thompson_sampling_no_objective_posterior_raises_error(): - key = jr.key(42) - forrester = Forrester() - dataset = forrester.generate_dataset(num_points=10, 
key=key) - posterior = generate_dummy_conjugate_posterior(dataset) - posteriors = {"CONSTRAINT": posterior} - datasets = {OBJECTIVE: dataset} - with pytest.raises(ValueError): - ts_utility_builder = ThompsonSampling(num_features=100) - ts_utility_builder.build_utility_function( - posteriors=posteriors, datasets=datasets, key=key - ) - - -@pytest.mark.filterwarnings( - "ignore::UserWarning" -) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort -def test_thompson_sampling_no_objective_dataset_raises_error(): - key = jr.key(42) - forrester = Forrester() - dataset = forrester.generate_dataset(num_points=10, key=key) - posterior = generate_dummy_conjugate_posterior(dataset) - posteriors = {OBJECTIVE: posterior} - datasets = {"CONSTRAINT": dataset} - with pytest.raises(ValueError): - ts_utility_builder = ThompsonSampling(num_features=100) - ts_utility_builder.build_utility_function( - posteriors=posteriors, datasets=datasets, key=key - ) - - -@pytest.mark.filterwarnings( - "ignore::UserWarning" -) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort -def test_thompson_sampling_non_conjugate_posterior_raises_error(): - key = jr.key(42) - forrester = Forrester() - dataset = forrester.generate_dataset(num_points=10, key=key) - posterior = generate_dummy_non_conjugate_posterior(dataset) - posteriors = {OBJECTIVE: posterior} - datasets = {OBJECTIVE: dataset} - with pytest.raises(ValueError): - ts_utility_builder = ThompsonSampling(num_features=100) - ts_utility_builder.build_utility_function( - posteriors=posteriors, datasets=datasets, key=key - ) +from tests.test_decision_making.utils import ( + generate_dummy_conjugate_posterior, + generate_dummy_non_conjugate_posterior, +) @pytest.mark.parametrize("num_rff_features", [0, -1, -10]) @@ -130,34 +53,6 @@ def test_thompson_sampling_invalid_rff_num_raises_error(num_rff_features: int): ) -@pytest.mark.parametrize( - "test_target_function", - [(Forrester()), (LogarithmicGoldsteinPrice())], -) -@pytest.mark.parametrize("num_test_points", [50, 100]) -@pytest.mark.parametrize("key", [jr.key(42), jr.key(10)]) -@pytest.mark.filterwarnings( - "ignore::UserWarning" -) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort -def test_thompson_sampling_utility_function_correct_shapes( - test_target_function: AbstractContinuousTestFunction, - num_test_points: int, - key: KeyArray, -): - dataset = test_target_function.generate_dataset(num_points=10, key=key) - posterior = generate_dummy_conjugate_posterior(dataset) - posteriors = {OBJECTIVE: posterior} - datasets = {OBJECTIVE: dataset} - ts_utility_builder = ThompsonSampling(num_features=100) - ts_utility_function = ts_utility_builder.build_utility_function( - posteriors=posteriors, datasets=datasets, key=key - ) - test_key, _ = jr.split(key) - test_X = test_target_function.generate_test_points(num_test_points, test_key) - ts_utility_function_values = ts_utility_function(test_X) - assert ts_utility_function_values.shape == (num_test_points, 1) - - @pytest.mark.parametrize( "test_target_function", [(Forrester()), (LogarithmicGoldsteinPrice())], diff --git a/tests/test_decision_making/test_utility_functions/test_utility_functions.py b/tests/test_decision_making/test_utility_functions/test_utility_functions.py new file mode 100644 index 000000000..851923cb4 --- /dev/null +++ b/tests/test_decision_making/test_utility_functions/test_utility_functions.py @@ -0,0 +1,159 @@ +# Copyright 
2023 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from jax import config + +config.update("jax_enable_x64", True) + +from beartype.typing import Type +import jax.random as jr +import pytest + +from gpjax.decision_making.test_functions.continuous_functions import ( + AbstractContinuousTestFunction, + Forrester, + LogarithmicGoldsteinPrice, +) +from gpjax.decision_making.utility_functions.base import ( + AbstractSinglePointUtilityFunctionBuilder, +) +from gpjax.decision_making.utility_functions.thompson_sampling import ThompsonSampling +from gpjax.decision_making.utility_functions.probability_of_improvement import ( + ProbabilityOfImprovement, +) +from gpjax.decision_making.utils import OBJECTIVE +from gpjax.typing import KeyArray + +from tests.test_decision_making.utils import ( + generate_dummy_conjugate_posterior, + generate_dummy_non_conjugate_posterior, +) + + +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +@pytest.mark.parametrize( + "utility_function_builder, utility_function_kwargs", + [ + (ProbabilityOfImprovement, {}), + (ThompsonSampling, {"num_features": 100}), + ], +) +def test_utility_function_no_objective_posterior_raises_error( + utility_function_builder: Type[AbstractSinglePointUtilityFunctionBuilder], + utility_function_kwargs: dict, +): + key = jr.key(42) + forrester = Forrester() + dataset = forrester.generate_dataset(num_points=10, key=key) + posterior = generate_dummy_conjugate_posterior(dataset) + posteriors = {"CONSTRAINT": posterior} + datasets = {OBJECTIVE: dataset} + with pytest.raises(ValueError): + utility_function = utility_function_builder(**utility_function_kwargs) + utility_function.build_utility_function( + posteriors=posteriors, datasets=datasets, key=key + ) + + +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +@pytest.mark.parametrize( + "utility_function_builder, utility_function_kwargs", + [ + (ProbabilityOfImprovement, {}), + (ThompsonSampling, {"num_features": 100}), + ], +) +def test_utility_function_no_objective_dataset_raises_error( + utility_function_builder: Type[AbstractSinglePointUtilityFunctionBuilder], + utility_function_kwargs: dict, +): + key = jr.key(42) + forrester = Forrester() + dataset = forrester.generate_dataset(num_points=10, key=key) + posterior = generate_dummy_conjugate_posterior(dataset) + posteriors = {OBJECTIVE: posterior} + datasets = {"CONSTRAINT": dataset} + with pytest.raises(ValueError): + utility_function = utility_function_builder(**utility_function_kwargs) + utility_function.build_utility_function( + posteriors=posteriors, datasets=datasets, key=key + ) + + +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some 
internal logic around jnp.argsort +@pytest.mark.parametrize( + "utility_function_builder, utility_function_kwargs", + [ + (ProbabilityOfImprovement, {}), + (ThompsonSampling, {"num_features": 100}), + ], +) +def test_non_conjugate_posterior_raises_error( + utility_function_builder: Type[AbstractSinglePointUtilityFunctionBuilder], + utility_function_kwargs: dict, +): + key = jr.key(42) + forrester = Forrester() + dataset = forrester.generate_dataset(num_points=10, key=key) + posterior = generate_dummy_non_conjugate_posterior(dataset) + posteriors = {OBJECTIVE: posterior} + datasets = {OBJECTIVE: dataset} + with pytest.raises(ValueError): + utility_function = utility_function_builder(**utility_function_kwargs) + utility_function.build_utility_function( + posteriors=posteriors, datasets=datasets, key=key + ) + + +@pytest.mark.parametrize( + "utility_function_builder, utility_function_kwargs", + [ + (ProbabilityOfImprovement, {}), + (ThompsonSampling, {"num_features": 100}), + ], +) +@pytest.mark.parametrize( + "test_target_function", + [(Forrester()), (LogarithmicGoldsteinPrice())], +) +@pytest.mark.parametrize("num_test_points", [50, 100]) +@pytest.mark.parametrize("key", [jr.key(42), jr.key(10)]) +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +def test_utility_functions_have_correct_shapes( + utility_function_builder: Type[AbstractSinglePointUtilityFunctionBuilder], + utility_function_kwargs: dict, + test_target_function: AbstractContinuousTestFunction, + num_test_points: int, + key: KeyArray, +): + dataset = test_target_function.generate_dataset(num_points=10, key=key) + posterior = generate_dummy_conjugate_posterior(dataset) + posteriors = {OBJECTIVE: posterior} + datasets = {OBJECTIVE: dataset} + ts_utility_builder = utility_function_builder(**utility_function_kwargs) + ts_utility_function = ts_utility_builder.build_utility_function( + posteriors=posteriors, datasets=datasets, key=key + ) + test_key, _ = jr.split(key) + test_X = test_target_function.generate_test_points(num_test_points, test_key) + ts_utility_function_values = ts_utility_function(test_X) + assert ts_utility_function_values.shape == (num_test_points, 1) diff --git a/tests/test_decision_making/utils.py b/tests/test_decision_making/utils.py index 4ea3e9dc3..6b3fe74c6 100644 --- a/tests/test_decision_making/utils.py +++ b/tests/test_decision_making/utils.py @@ -13,6 +13,8 @@ # limitations under the License. 
# ============================================================================== +import jax.numpy as jnp + from beartype.typing import Mapping from gpjax.dataset import Dataset @@ -21,8 +23,14 @@ AbstractSinglePointUtilityFunctionBuilder, SinglePointUtilityFunction, ) -from gpjax.gps import ConjugatePosterior +from gpjax.gps import ConjugatePosterior, NonConjugatePosterior, Prior from gpjax.typing import KeyArray +from gpjax.kernels import RBF +from gpjax.likelihoods import ( + Gaussian, + Poisson, +) +from gpjax.mean_functions import Zero class QuadraticSinglePointUtilityFunctionBuilder( @@ -45,3 +53,21 @@ def build_utility_function( return lambda x: -1.0 * test_function.evaluate( x ) # Utility functions are *maximised* + + +def generate_dummy_conjugate_posterior(dataset: Dataset) -> ConjugatePosterior: + kernel = RBF(lengthscale=jnp.ones(dataset.X.shape[1])) + mean_function = Zero() + prior = Prior(kernel=kernel, mean_function=mean_function) + likelihood = Gaussian(num_datapoints=dataset.n) + posterior = prior * likelihood + return posterior + + +def generate_dummy_non_conjugate_posterior(dataset: Dataset) -> NonConjugatePosterior: + kernel = RBF(lengthscale=jnp.ones(dataset.X.shape[1])) + mean_function = Zero() + prior = Prior(kernel=kernel, mean_function=mean_function) + likelihood = Poisson(num_datapoints=dataset.n) + posterior = prior * likelihood + return posterior From 7ec1f896f8c3b43ff982f225a30cd3fc47ba7679 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Gonz=C3=A1lez=20Duque?= Date: Mon, 1 Jul 2024 14:12:08 +0200 Subject: [PATCH 02/14] Finishes a first draft of the tutorial --- docs/examples/bayesian_optimisation.py | 2 + ...n_optimisation_with_other_acq_functions.py | 506 ++++++++++++++++++ .../test_probability_of_improvement.py | 34 ++ .../test_thompson_sampling.py | 1 - 4 files changed, 542 insertions(+), 1 deletion(-) create mode 100644 docs/examples/bayesian_optimisation_with_other_acq_functions.py create mode 100644 tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py diff --git a/docs/examples/bayesian_optimisation.py b/docs/examples/bayesian_optimisation.py index 8182beae1..e8807085c 100644 --- a/docs/examples/bayesian_optimisation.py +++ b/docs/examples/bayesian_optimisation.py @@ -756,3 +756,5 @@ def obtain_log_regret_statistics( # %% # %reload_ext watermark # %watermark -n -u -v -iv -w -a 'Thomas Christie' + +# %% diff --git a/docs/examples/bayesian_optimisation_with_other_acq_functions.py b/docs/examples/bayesian_optimisation_with_other_acq_functions.py new file mode 100644 index 000000000..04d54f8f8 --- /dev/null +++ b/docs/examples/bayesian_optimisation_with_other_acq_functions.py @@ -0,0 +1,506 @@ +# %% [markdown] +# # Bayesian Optimisation beyond Thompson Sampling +# +# In [a previous guide](), we gave an introduction to Bayesian optimisation: +# a framework for optimising black-box function that leverages the +# uncertainty estimates that come from Gaussian processes. + +# %% +# Enable Float64 for more stable matrix inversions. 
+from jax import config + +config.update("jax_enable_x64", True) + +import jax +from jax import jit +import jax.numpy as jnp +import jax.random as jr +from jaxtyping import install_import_hook, Float, Int +import matplotlib as mpl +import matplotlib.pyplot as plt +from matplotlib import cm +import optax as ox +import tensorflow_probability.substrates.jax as tfp +from typing import List, Tuple + +with install_import_hook("gpjax", "beartype.beartype"): + import gpjax as gpx +from gpjax.typing import Array, FunctionalSample, ScalarFloat +from jaxopt import ScipyBoundedMinimize + +key = jr.key(1337) +plt.style.use( + "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" +) +cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] + +# %% [markdown] + +# In a few words, Bayesian optimisation starts by fitting a Gaussian +# process to the data we have collected so far about the objective function, +# and then uses this model to construct an **acquisition function** +# which tells us which parts of the input domain have the potential +# of improving our optimization. Unlike the black-box objective, the +# acquisition function is easy to query and optimize. + +# Acquisition functions come in many flavors. In the previous guide, +# we sampled from the Gaussian Process' predictive posterior and +# optimized said sample. This is known as Thompson Sampling, +# and it is a popular acquisition function due to its simplicity +# and ease of parallelization. + +# In this guide we will introduce a new acquisition function: +# probability of improvement [TODO:ADDCITE](). This acquisition function can be formally +# defined as + +# $$ \text{PI}(x) = \text{Prob}(f(x) < f(x_{\text{best}})) $$ + +# where $f(x)$ is the objective functionw we aim at **minimizing**, +# and $x_{\text{best}}$ is the best point we have seen so far (i.e. +# the point with the lowest value of $f$). The name is clear: it measures +# the probability that a given point $x$ will **improve** our +# optimization trace. + +# %% [markdown] +# ## Optimizing a 1D function: Forrester +# +# Just like in our previous guide, let's start by defining the +# [Forrester objective function](https://www.sfu.ca/~ssurjano/forretal08.html). + + +# %% +def standardised_forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]: + mean = 0.45321 + std = 4.4258 + return ((6 * x - 2) ** 2 * jnp.sin(12 * x - 4) - mean) / std + + +# %% +lower_bound = jnp.array([0.0]) +upper_bound = jnp.array([1.0]) +initial_sample_num = 5 + +initial_x = tfp.mcmc.sample_halton_sequence( + dim=1, num_results=initial_sample_num, seed=key, dtype=jnp.float64 +).reshape(-1, 1) +initial_y = standardised_forrester(initial_x) +D = gpx.Dataset(X=initial_x, y=initial_y) + +# %% [markdown] + +# ...defining the Gaussian Process model... 
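+# %% [markdown]
+# A brief aside before we fit the model: under a Gaussian predictive distribution
+# with mean $\mu(x)$ and standard deviation $\sigma(x)$, the probability of
+# improvement has the closed form
+# $\text{PI}(x) = \Phi\big((f(x_{\text{best}}) - \mu(x)) / \sigma(x)\big)$,
+# where $\Phi$ is the standard normal CDF. The cell below is purely illustrative
+# (the `mu`, `sigma` and `best_y` values are made up); later on we let GPJax's
+# `ProbabilityOfImprovement` builder compute the same quantity from a fitted
+# posterior.
+
+# %%
+from jax.scipy.stats import norm
+
+
+def probability_of_improvement_closed_form(
+    mu: Float[Array, " N"], sigma: Float[Array, " N"], best_y: float
+) -> Float[Array, " N"]:
+    # Phi((best_y - mu) / sigma): probability that f(x) falls below the incumbent.
+    return norm.cdf((best_y - mu) / sigma)
+
+
+# Toy values, chosen only to show how the formula behaves: the lower the predicted
+# mean relative to the incumbent, the higher the probability of improvement.
+print(
+    probability_of_improvement_closed_form(
+        mu=jnp.array([0.0, -1.0, 1.0]), sigma=jnp.array([1.0, 1.0, 1.0]), best_y=-0.5
+    )
+)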
+ + +# %% +def return_optimised_posterior( + data: gpx.Dataset, prior: gpx.base.Module, key: Array +) -> gpx.base.Module: + likelihood = gpx.likelihoods.Gaussian( + num_datapoints=data.n, obs_stddev=jnp.array(1e-6) + ) # Our function is noise-free, so we set the observation noise's standard deviation to a very small value + likelihood = likelihood.replace_trainable(obs_stddev=False) + + posterior = prior * likelihood + + negative_mll = gpx.objectives.ConjugateMLL(negative=True) + negative_mll(posterior, train_data=data) + negative_mll = jit(negative_mll) + + opt_posterior, _ = gpx.fit( + model=posterior, + objective=negative_mll, + train_data=data, + optim=ox.adam(learning_rate=0.01), + num_iters=1000, + safe=True, + key=key, + verbose=False, + ) + + return opt_posterior + + +mean = gpx.mean_functions.Zero() +kernel = gpx.kernels.Matern52() +prior = gpx.gps.Prior(mean_function=mean, kernel=kernel) +opt_posterior = return_optimised_posterior(D, prior, key) + + +from gpjax.decision_making.utility_functions.probability_of_improvement import ( + ProbabilityOfImprovement, +) + +utility_function_builder = ProbabilityOfImprovement() +utility_function = utility_function_builder.build_utility_function( + posteriors={"OBJECTIVE": opt_posterior}, datasets={"OBJECTIVE": D}, key=key +) + +from gpjax.decision_making.utility_maximizer import ( + ContinuousSinglePointUtilityMaximizer, +) +from gpjax.decision_making.utility_functions.base import SinglePointUtilityFunction +from gpjax.decision_making.search_space import ContinuousSearchSpace + + +def optimize_acquisition_function( + utility_function: SinglePointUtilityFunction, + key: Array, + lower_bounds: Float[Array, "D"], + upper_bounds: Float[Array, "D"], + num_initial_samples: int = 100, + num_restarts: int = 5, +): + optimizer = ContinuousSinglePointUtilityMaximizer( + num_initial_samples=num_initial_samples, num_restarts=num_restarts + ) + + search_space = ContinuousSearchSpace( + lower_bounds=lower_bounds, upper_bounds=upper_bounds + ) + + x_next_best = optimizer.maximize( + utility_function, search_space=search_space, key=key + ) + + return x_next_best + + +# %% +def construct_acquisition_function( + opt_posterior: gpx.base.Module, + dataset: gpx.Dataset, + key: Array, +) -> SinglePointUtilityFunction: + utility_function_builder = ProbabilityOfImprovement() + utility_function = utility_function_builder.build_utility_function( + posteriors={"OBJECTIVE": opt_posterior}, + datasets={"OBJECTIVE": dataset}, + key=key, + ) + + return utility_function + + +def propose_next_candidate( + utility_function: SinglePointUtilityFunction, + key: Array, + lower_bounds: Float[Array, "D"], + upper_bounds: Float[Array, "D"], +) -> Float[Array, "D 1"]: + queried_x = optimize_acquisition_function( + utility_function, key, lower_bounds, upper_bounds, num_initial_samples=100 + ) + + return queried_x + + +def run_one_bo_loop_in_1D( + objective_function: standardised_forrester, + opt_posterior: gpx.base.Module, + dataset: gpx.Dataset, + key: Array, + plot: bool = True, +) -> Float[Array, "D 1"]: + domain = jnp.linspace(0, 1, 1000).reshape(-1, 1) + objective_values = objective_function(domain) + + latent_distribution = opt_posterior.predict(domain, train_data=dataset) + predictive_distribution = opt_posterior.likelihood(latent_distribution) + + predictive_mean = predictive_distribution.mean() + predictive_std = predictive_distribution.stddev() + + # Building PI + utility_function = construct_acquisition_function(opt_posterior, dataset, key) + utility_function_values = 
utility_function(domain) + + # Optimizing the acq. function + lower_bound = jnp.array([0.0]) + upper_bound = jnp.array([1.0]) + queried_x = propose_next_candidate(utility_function, key, lower_bound, upper_bound) + + if plot: + fig, ax = plt.subplots() + ax.plot(domain, predictive_mean, label="Predictive Mean", color=cols[1]) + ax.fill_between( + domain.squeeze(), + predictive_mean - 2 * predictive_std, + predictive_mean + 2 * predictive_std, + alpha=0.2, + label="Two sigma", + color=cols[1], + ) + ax.plot( + domain, + predictive_mean - 2 * predictive_std, + linestyle="--", + linewidth=1, + color=cols[1], + ) + ax.plot( + domain, + predictive_mean + 2 * predictive_std, + linestyle="--", + linewidth=1, + color=cols[1], + ) + ax.plot(domain, utility_function_values, label="Probability of Improvement") + ax.plot( + domain, + objective_values, + label="Forrester Function", + color=cols[0], + linestyle="--", + linewidth=2, + ) + ax.axvline(x=0.757, linestyle=":", color=cols[3], label="True Optimum") + ax.scatter(dataset.X, dataset.y, label="Observations", color=cols[2], zorder=2) + ax.scatter( + queried_x, + utility_function(queried_x), + label="Probability of Improvement Optimum", + marker="*", + color=cols[3], + zorder=3, + ) + ax.legend(loc="center left", bbox_to_anchor=(0.975, 0.5)) + plt.show() + + return queried_x + + +# plot_bayes_opt_using_pi(standardised_forrester, opt_posterior, D, key) + +# %% +bo_iters = 5 + +# Set up initial dataset +initial_x = tfp.mcmc.sample_halton_sequence( + dim=1, num_results=initial_sample_num, seed=key, dtype=jnp.float64 +).reshape(-1, 1) +initial_y = standardised_forrester(initial_x) +D = gpx.Dataset(X=initial_x, y=initial_y) + +for i in range(bo_iters): + key, subkey = jr.split(key) + + # Generate optimised posterior using previously observed data + mean = gpx.mean_functions.Zero() + kernel = gpx.kernels.Matern52() + prior = gpx.gps.Prior(mean_function=mean, kernel=kernel) + opt_posterior = return_optimised_posterior(D, prior, subkey) + + queried_x = run_one_bo_loop_in_1D(standardised_forrester, opt_posterior, D, key) + + # Evaluate the black-box function at the best point observed so far, and add it to the dataset + y_star = standardised_forrester(queried_x) + print(f"Queried Point: {queried_x}, Black-Box Function Value: {y_star}") + D = D + gpx.Dataset(X=queried_x, y=y_star) + + +# %% +def standardised_six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]: + mean = 1.12767 + std = 1.17500 + x1 = x[..., :1] + x2 = x[..., 1:] + term1 = (4 - 2.1 * x1**2 + x1**4 / 3) * x1**2 + term2 = x1 * x2 + term3 = (-4 + 4 * x2**2) * x2**2 + return (term1 + term2 + term3 - mean) / std + + +# %% +x1 = jnp.linspace(-2, 2, 100) +x2 = jnp.linspace(-1, 1, 100) +x1, x2 = jnp.meshgrid(x1, x2) +x = jnp.stack([x1.flatten(), x2.flatten()], axis=1) +y = standardised_six_hump_camel(x) + +fig, ax = plt.subplots(subplot_kw={"projection": "3d"}) +surf = ax.plot_surface( + x1, + x2, + y.reshape(x1.shape[0], x2.shape[0]), + linewidth=0, + cmap=cm.coolwarm, + antialiased=False, +) +ax.set_xlabel("x1") +ax.set_ylabel("x2") +plt.show() + + +# %% +x_star_one = jnp.array([[0.0898, -0.7126]]) +x_star_two = jnp.array([[-0.0898, 0.7126]]) +fig, ax = plt.subplots() +contour_plot = ax.contourf( + x1, x2, y.reshape(x1.shape[0], x2.shape[0]), cmap=cm.coolwarm, levels=40 +) +ax.scatter( + x_star_one[0][0], x_star_one[0][1], marker="*", color=cols[2], label="Global Minima" +) +ax.scatter(x_star_two[0][0], x_star_two[0][1], marker="*", color=cols[2]) +ax.set_xlabel("x1") +ax.set_ylabel("x2") 
+fig.colorbar(contour_plot) +ax.legend() +plt.show() + +# %% +lower_bound = jnp.array([-2.0, -1.0]) +upper_bound = jnp.array([2.0, 1.0]) +initial_sample_num = 5 +bo_iters = 20 +num_experiments = 5 +bo_experiment_results = [] + +for experiment in range(num_experiments): + print(f"Starting Experiment: {experiment + 1}") + # Set up initial dataset + initial_x = tfp.mcmc.sample_halton_sequence( + dim=2, num_results=initial_sample_num, seed=key, dtype=jnp.float64 + ) + initial_x = jnp.array(lower_bound + (upper_bound - lower_bound) * initial_x) + initial_y = standardised_six_hump_camel(initial_x) + D = gpx.Dataset(X=initial_x, y=initial_y) + + for i in range(bo_iters): + key, subkey = jr.split(key) + + # Generate optimised posterior + mean = gpx.mean_functions.Zero() + kernel = gpx.kernels.Matern52( + active_dims=[0, 1], lengthscale=jnp.array([1.0, 1.0]), variance=2.0 + ) + prior = gpx.gps.Prior(mean_function=mean, kernel=kernel) + opt_posterior = return_optimised_posterior(D, prior, subkey) + + # Constructing the acq. function + utility_function = construct_acquisition_function(opt_posterior, D, key) + + # Draw a sample from the posterior, and find the minimiser of it + queried_x = propose_next_candidate( + utility_function, key, lower_bound, upper_bound + ) + + # Evaluate the black-box function at the best point observed so far, and add it to the dataset + y_star = standardised_six_hump_camel(queried_x) + print( + f"BO Iteration: {i + 1}, Queried Point: {queried_x}, Black-Box Function Value:" + f" {y_star}" + ) + D = D + gpx.Dataset(X=queried_x, y=y_star) + bo_experiment_results.append(D) + +# %% +random_experiment_results = [] +for i in range(num_experiments): + key, subkey = jr.split(key) + initial_x = bo_experiment_results[i].X[:5] + initial_y = bo_experiment_results[i].y[:5] + final_x = jr.uniform( + key, + shape=(bo_iters, 2), + dtype=jnp.float64, + minval=lower_bound, + maxval=upper_bound, + ) + final_y = standardised_six_hump_camel(final_x) + random_x = jnp.concatenate([initial_x, final_x], axis=0) + random_y = jnp.concatenate([initial_y, final_y], axis=0) + random_experiment_results.append(gpx.Dataset(X=random_x, y=random_y)) + + +# %% +def obtain_log_regret_statistics( + experiment_results: List[gpx.Dataset], + global_minimum: ScalarFloat, +) -> Tuple[Float[Array, "N 1"], Float[Array, "N 1"]]: + log_regret_results = [] + for exp_result in experiment_results: + observations = exp_result.y + cumulative_best_observations = jax.lax.associative_scan( + jax.numpy.minimum, observations + ) + regret = cumulative_best_observations - global_minimum + log_regret = jnp.log(regret) + log_regret_results.append(log_regret) + + log_regret_results = jnp.array(log_regret_results) + log_regret_mean = jnp.mean(log_regret_results, axis=0) + log_regret_std = jnp.std(log_regret_results, axis=0) + return log_regret_mean, log_regret_std + + +bo_log_regret_mean, bo_log_regret_std = obtain_log_regret_statistics( + bo_experiment_results, -1.8377 +) +( + random_log_regret_mean, + random_log_regret_std, +) = obtain_log_regret_statistics(random_experiment_results, -1.8377) + +# %% [markdown] +# Now, when we plot the mean and standard deviation of the log regret at each iteration, +# we can see that BO outperforms random sampling! 
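+# %% [markdown]
+# (A quick aside before the plot, with made-up numbers rather than the experiment
+# results above: for an observation trace $y_{1:4} = (2.1, 1.4, 1.7, 0.9)$ and
+# global minimum $f^* = 0.5$, the simple regret at step $t$ is
+# $\min_{i \leq t} y_i - f^*$, so the cumulative best values are
+# $(2.1, 1.4, 1.4, 0.9)$ and the regrets are $(1.6, 0.9, 0.9, 0.4)$.)
+
+# %%
+toy_observations = jnp.array([[2.1], [1.4], [1.7], [0.9]])
+toy_global_minimum = 0.5
+toy_cumulative_best = jax.lax.associative_scan(jnp.minimum, toy_observations)
+toy_log_regret = jnp.log(toy_cumulative_best - toy_global_minimum)
+print(toy_log_regret)  # Decreasing values indicate progress towards the minimum.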
+ +# %% +fig, ax = plt.subplots() +fn_evaluations = jnp.arange(1, bo_iters + initial_sample_num + 1) +ax.plot(fn_evaluations, bo_log_regret_mean, label="Bayesian Optimisation") +ax.fill_between( + fn_evaluations, + bo_log_regret_mean[:, 0] - bo_log_regret_std[:, 0], + bo_log_regret_mean[:, 0] + bo_log_regret_std[:, 0], + alpha=0.2, +) +ax.plot(fn_evaluations, random_log_regret_mean, label="Random Search") +ax.fill_between( + fn_evaluations, + random_log_regret_mean[:, 0] - random_log_regret_std[:, 0], + random_log_regret_mean[:, 0] + random_log_regret_std[:, 0], + alpha=0.2, +) +ax.axvline(x=initial_sample_num, linestyle=":") +ax.set_xlabel("Number of Black-Box Function Evaluations") +ax.set_ylabel("Log Regret") +ax.legend() +plt.show() + +# %% [markdown] +# It can also be useful to plot the queried points over the course of a single BO run, in +# order to gain some insight into how the algorithm queries the search space. Below +# we do this for one of the BO experiments, and can see that the algorithm initially +# performs some exploration of the search space whilst it is uncertain about the black-box +# function, but it then hones in one one of the global minima of the function, as we would hope! + +# %% +fig, ax = plt.subplots() +contour_plot = ax.contourf( + x1, x2, y.reshape(x1.shape[0], x2.shape[0]), cmap=cm.coolwarm, levels=40 +) +ax.scatter( + x_star_one[0][0], + x_star_one[0][1], + marker="*", + color=cols[2], + label="Global Minimum", + zorder=2, +) +ax.scatter(x_star_two[0][0], x_star_two[0][1], marker="*", color=cols[2], zorder=2) +ax.scatter( + bo_experiment_results[1].X[:, 0], + bo_experiment_results[1].X[:, 1], + marker="x", + color=cols[1], + label="Bayesian Optimisation Queries", +) +ax.set_xlabel("x1") +ax.set_ylabel("x2") +fig.colorbar(contour_plot) +ax.legend() +plt.show() + +# %% diff --git a/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py b/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py new file mode 100644 index 000000000..9b53cd553 --- /dev/null +++ b/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py @@ -0,0 +1,34 @@ +# Copyright 2023 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +from jax import config + +config.update("jax_enable_x64", True) + +from beartype.typing import Callable +import jax.random as jr +import pytest + +from gpjax.decision_making.test_functions.continuous_functions import ( + AbstractContinuousTestFunction, + Forrester, + LogarithmicGoldsteinPrice, +) +from gpjax.decision_making.utility_functions.thompson_sampling import ThompsonSampling +from gpjax.decision_making.utils import OBJECTIVE +from gpjax.typing import KeyArray + +from tests.test_decision_making.utils import ( + generate_dummy_conjugate_posterior, +) diff --git a/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py b/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py index 9fceb638e..bc2ab487c 100644 --- a/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py +++ b/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py @@ -31,7 +31,6 @@ from tests.test_decision_making.utils import ( generate_dummy_conjugate_posterior, - generate_dummy_non_conjugate_posterior, ) From c67083fd065bb12b516f08c521f59010ad5b1a30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Gonz=C3=A1lez=20Duque?= Date: Mon, 1 Jul 2024 14:29:40 +0200 Subject: [PATCH 03/14] Runs pre-commit hooks --- ...n_optimisation_with_other_acq_functions.py | 5 ++- .../probability_of_improvement.py | 10 ++++-- gpjax/decision_making/utils.py | 1 - .../test_probability_of_improvement.py | 34 ------------------- .../test_thompson_sampling.py | 5 +-- .../test_utility_functions.py | 3 +- tests/test_decision_making/utils.py | 11 +++--- 7 files changed, 19 insertions(+), 50 deletions(-) delete mode 100644 tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py diff --git a/docs/examples/bayesian_optimisation_with_other_acq_functions.py b/docs/examples/bayesian_optimisation_with_other_acq_functions.py index 04d54f8f8..46ef523d7 100644 --- a/docs/examples/bayesian_optimisation_with_other_acq_functions.py +++ b/docs/examples/bayesian_optimisation_with_other_acq_functions.py @@ -15,7 +15,7 @@ from jax import jit import jax.numpy as jnp import jax.random as jr -from jaxtyping import install_import_hook, Float, Int +from jaxtyping import install_import_hook, Float import matplotlib as mpl import matplotlib.pyplot as plt from matplotlib import cm @@ -25,8 +25,7 @@ with install_import_hook("gpjax", "beartype.beartype"): import gpjax as gpx -from gpjax.typing import Array, FunctionalSample, ScalarFloat -from jaxopt import ScipyBoundedMinimize +from gpjax.typing import Array, ScalarFloat key = jr.key(1337) plt.style.use( diff --git a/gpjax/decision_making/utility_functions/probability_of_improvement.py b/gpjax/decision_making/utility_functions/probability_of_improvement.py index 92d644a03..82ab69f0d 100644 --- a/gpjax/decision_making/utility_functions/probability_of_improvement.py +++ b/gpjax/decision_making/utility_functions/probability_of_improvement.py @@ -22,9 +22,15 @@ AbstractSinglePointUtilityFunctionBuilder, SinglePointUtilityFunction, ) -from gpjax.decision_making.utils import OBJECTIVE, gaussian_cdf +from gpjax.decision_making.utils import ( + OBJECTIVE, + gaussian_cdf, +) from gpjax.gps import ConjugatePosterior -from gpjax.typing import KeyArray, Array +from gpjax.typing import ( + Array, + KeyArray, +) @dataclass diff --git a/gpjax/decision_making/utils.py b/gpjax/decision_making/utils.py index bc8d3c332..a186dc1b1 100644 --- 
a/gpjax/decision_making/utils.py +++ b/gpjax/decision_making/utils.py @@ -17,7 +17,6 @@ Dict, Final, ) - import jax import jax.numpy as jnp diff --git a/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py b/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py deleted file mode 100644 index 9b53cd553..000000000 --- a/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2023 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -from jax import config - -config.update("jax_enable_x64", True) - -from beartype.typing import Callable -import jax.random as jr -import pytest - -from gpjax.decision_making.test_functions.continuous_functions import ( - AbstractContinuousTestFunction, - Forrester, - LogarithmicGoldsteinPrice, -) -from gpjax.decision_making.utility_functions.thompson_sampling import ThompsonSampling -from gpjax.decision_making.utils import OBJECTIVE -from gpjax.typing import KeyArray - -from tests.test_decision_making.utils import ( - generate_dummy_conjugate_posterior, -) diff --git a/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py b/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py index bc2ab487c..8986a4cf6 100644 --- a/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py +++ b/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py @@ -28,10 +28,7 @@ from gpjax.decision_making.utility_functions.thompson_sampling import ThompsonSampling from gpjax.decision_making.utils import OBJECTIVE from gpjax.typing import KeyArray - -from tests.test_decision_making.utils import ( - generate_dummy_conjugate_posterior, -) +from tests.test_decision_making.utils import generate_dummy_conjugate_posterior @pytest.mark.parametrize("num_rff_features", [0, -1, -10]) diff --git a/tests/test_decision_making/test_utility_functions/test_utility_functions.py b/tests/test_decision_making/test_utility_functions/test_utility_functions.py index 851923cb4..f8e888e82 100644 --- a/tests/test_decision_making/test_utility_functions/test_utility_functions.py +++ b/tests/test_decision_making/test_utility_functions/test_utility_functions.py @@ -28,13 +28,12 @@ from gpjax.decision_making.utility_functions.base import ( AbstractSinglePointUtilityFunctionBuilder, ) -from gpjax.decision_making.utility_functions.thompson_sampling import ThompsonSampling from gpjax.decision_making.utility_functions.probability_of_improvement import ( ProbabilityOfImprovement, ) +from gpjax.decision_making.utility_functions.thompson_sampling import ThompsonSampling from gpjax.decision_making.utils import OBJECTIVE from gpjax.typing import KeyArray - from tests.test_decision_making.utils import ( generate_dummy_conjugate_posterior, generate_dummy_non_conjugate_posterior, diff --git 
a/tests/test_decision_making/utils.py b/tests/test_decision_making/utils.py index 6b3fe74c6..e4f083832 100644 --- a/tests/test_decision_making/utils.py +++ b/tests/test_decision_making/utils.py @@ -13,9 +13,8 @@ # limitations under the License. # ============================================================================== -import jax.numpy as jnp - from beartype.typing import Mapping +import jax.numpy as jnp from gpjax.dataset import Dataset from gpjax.decision_making.test_functions import Quadratic @@ -23,14 +22,18 @@ AbstractSinglePointUtilityFunctionBuilder, SinglePointUtilityFunction, ) -from gpjax.gps import ConjugatePosterior, NonConjugatePosterior, Prior -from gpjax.typing import KeyArray +from gpjax.gps import ( + ConjugatePosterior, + NonConjugatePosterior, + Prior, +) from gpjax.kernels import RBF from gpjax.likelihoods import ( Gaussian, Poisson, ) from gpjax.mean_functions import Zero +from gpjax.typing import KeyArray class QuadraticSinglePointUtilityFunctionBuilder( From e0fc9d25a2323e5ae8b5814af4d9221497275434 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Gonz=C3=A1lez=20Duque?= Date: Mon, 1 Jul 2024 14:31:27 +0200 Subject: [PATCH 04/14] Fixes a Ruff error on jaxtyping annotation --- gpjax/decision_making/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpjax/decision_making/utils.py b/gpjax/decision_making/utils.py index a186dc1b1..873bf21de 100644 --- a/gpjax/decision_making/utils.py +++ b/gpjax/decision_making/utils.py @@ -52,7 +52,7 @@ def build_function_evaluator( return lambda x: {tag: Dataset(x, f(x)) for tag, f in functions.items()} -def gaussian_cdf(x: Float[Array, "N"]) -> Float[Array, "N"]: +def gaussian_cdf(x: Float[Array, " N"]) -> Float[Array, " N"]: """ Compute the cumulative distribution function of the standard normal distribution at the points `x`. From 623181ac9c5900d57ca69f7d6d4cdf15e454ad20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Gonz=C3=A1lez=20Duque?= Date: Mon, 1 Jul 2024 15:18:43 +0200 Subject: [PATCH 05/14] Adds docstrings to probability of improvement --- docs/examples/bayesian_optimisation.py | 2 - .../probability_of_improvement.py | 41 ++++++++++++++----- tests/test_decision_making/test_utils.py | 7 ++++ 3 files changed, 37 insertions(+), 13 deletions(-) diff --git a/docs/examples/bayesian_optimisation.py b/docs/examples/bayesian_optimisation.py index e8807085c..8182beae1 100644 --- a/docs/examples/bayesian_optimisation.py +++ b/docs/examples/bayesian_optimisation.py @@ -756,5 +756,3 @@ def obtain_log_regret_statistics( # %% # %reload_ext watermark # %watermark -n -u -v -iv -w -a 'Thomas Christie' - -# %% diff --git a/gpjax/decision_making/utility_functions/probability_of_improvement.py b/gpjax/decision_making/utility_functions/probability_of_improvement.py index 82ab69f0d..ee483d3ea 100644 --- a/gpjax/decision_making/utility_functions/probability_of_improvement.py +++ b/gpjax/decision_making/utility_functions/probability_of_improvement.py @@ -35,8 +35,30 @@ @dataclass class ProbabilityOfImprovement(AbstractSinglePointUtilityFunctionBuilder): - """ - TODO: write. + r""" + An acquisition function which returns the probability of improvement + of the objective function over the best observed value. + + More precisely, given a predictive posterior distribution of the objective + function, the probability of improvement at a test point $`x`$ is defined as: + $$`\text{PI}(x) = \text{Prob}[f(x) < f(x_{\text{best}})]`$$ + where $`x_{\text{best}}`$ is the minimizer of $`f`$ in the dataset. 
+ + The probability of improvement can be easily computed using the + cumulative distribution function of the standard normal distribution $`\Phi`$: + $$`\text{PI}(x) = \Phi\left(\frac{f(x_{\text{best}}) - \mu}{\sigma}\right)`$$ + where $`\mu`$ and $`\sigma`$ are the mean and standard deviation of the + predictive distribution of the objective function at $`x`$. + + References + ---------- + [1] Kushner, H. J. (1964). + A new method of locating the maximum point of an arbitrary multipeak curve in the presence of noise. + Journal of Basic Engineering, 86(1), 97-106. + + [2] Shahriari, B., Swersky, K., Wang, Z., Adams, R. P., & de Freitas, N. (2016). + Taking the human out of the loop: A review of Bayesian optimization. + Proceedings of the IEEE, 104(1), 148-175. doi: 10.1109/JPROC.2015.2494218 """ def build_utility_function( @@ -46,9 +68,8 @@ def build_utility_function( key: KeyArray, ) -> SinglePointUtilityFunction: """ - Draw an approximate sample from the posterior of the objective model and return - the *negative* of this sample as a utility function, as utility functions - are *maximised*. + Constructs the probability of improvement utility function + using the predictive posterior of the objective function. Args: posteriors (Mapping[str, AbstractPosterior]): Dictionary of posteriors to be @@ -59,13 +80,11 @@ def build_utility_function( to form the utility function. Keys in `datasets` should correspond to keys in `posteriors`. One of the datasets must correspond to the `OBJECTIVE` key. - key (KeyArray): JAX PRNG key used for random number generation. This can be - changed to draw different samples. + key (KeyArray): JAX PRNG key used for random number generation. Since the probability of improvement is computed deterministically + from the predictive posterior, the key is not used. Returns: - SinglePointUtilityFunction: An appproximate sample from the objective model - posterior to to be *maximised* in order to decide which point to query - next. + SinglePointUtilityFunction: the probability of improvement utility function. 
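+
+        Example:
+            A minimal usage sketch (it assumes `posterior` is a fitted
+            `ConjugatePosterior`, `dataset` is the data it was conditioned on,
+            `x_test` is an array of candidate points, and `key` is any JAX PRNG
+            key, which this utility does not use):
+
+                pi_builder = ProbabilityOfImprovement()
+                pi_utility = pi_builder.build_utility_function(
+                    posteriors={OBJECTIVE: posterior},
+                    datasets={OBJECTIVE: dataset},
+                    key=key,
+                )
+                pi_values = pi_utility(x_test)  # Shape [N 1]; larger is better.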
""" self.check_objective_present(posteriors, datasets) @@ -87,4 +106,4 @@ def probability_of_improvement(x_test: Num[Array, "N D"]): (best_y - predictive_dist.mean()) / predictive_dist.stddev() ).reshape(-1, 1) - return probability_of_improvement # Utility functions are *maximised* + return probability_of_improvement diff --git a/tests/test_decision_making/test_utils.py b/tests/test_decision_making/test_utils.py index f77bc6cab..d097ee23a 100644 --- a/tests/test_decision_making/test_utils.py +++ b/tests/test_decision_making/test_utils.py @@ -21,6 +21,7 @@ from gpjax.decision_making.utils import ( OBJECTIVE, build_function_evaluator, + gaussian_cdf, ) from gpjax.typing import ( Array, @@ -44,3 +45,9 @@ def _cube(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]: assert jnp.equal(datasets[OBJECTIVE].y, _square(x)).all() assert jnp.equal(datasets["CONSTRAINT"].X, x).all() assert jnp.equal(datasets["CONSTRAINT"].y, _cube(x)).all() + + +def test_gaussian_cdf(): + x = jnp.array([0.0, 1.0, 2.0]) + cdf = jnp.array([0.5, 0.84134475, 0.97724987]) + assert jnp.allclose(cdf, gaussian_cdf(x)) From 5d0a84c8f4e74ecc7e96a99ed13a54f8a319e3cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Gonz=C3=A1lez=20Duque?= Date: Mon, 1 Jul 2024 15:38:59 +0200 Subject: [PATCH 06/14] Adds a simple test for PI --- .../probability_of_improvement.py | 5 +- .../test_probability_of_improvement.py | 62 +++++++++++++++++++ 2 files changed, 65 insertions(+), 2 deletions(-) create mode 100644 tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py diff --git a/gpjax/decision_making/utility_functions/probability_of_improvement.py b/gpjax/decision_making/utility_functions/probability_of_improvement.py index ee483d3ea..a044110b1 100644 --- a/gpjax/decision_making/utility_functions/probability_of_improvement.py +++ b/gpjax/decision_making/utility_functions/probability_of_improvement.py @@ -40,7 +40,7 @@ class ProbabilityOfImprovement(AbstractSinglePointUtilityFunctionBuilder): of the objective function over the best observed value. More precisely, given a predictive posterior distribution of the objective - function, the probability of improvement at a test point $`x`$ is defined as: + function $`f`$, the probability of improvement at a test point $`x`$ is defined as: $$`\text{PI}(x) = \text{Prob}[f(x) < f(x_{\text{best}})]`$$ where $`x_{\text{best}}`$ is the minimizer of $`f`$ in the dataset. @@ -80,7 +80,8 @@ def build_utility_function( to form the utility function. Keys in `datasets` should correspond to keys in `posteriors`. One of the datasets must correspond to the `OBJECTIVE` key. - key (KeyArray): JAX PRNG key used for random number generation. Since the probability of improvement is computed deterministically + key (KeyArray): JAX PRNG key used for random number generation. Since + the probability of improvement is computed deterministically from the predictive posterior, the key is not used. Returns: diff --git a/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py b/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py new file mode 100644 index 000000000..5f18f3ca0 --- /dev/null +++ b/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py @@ -0,0 +1,62 @@ +# Copyright 2023 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from jax import config + +config.update("jax_enable_x64", True) + +import jax.random as jr +import jax.numpy as jnp + +from gpjax.decision_making.test_functions.continuous_functions import Forrester +from gpjax.decision_making.utility_functions.probability_of_improvement import ( + ProbabilityOfImprovement, +) +from gpjax.decision_making.utils import OBJECTIVE +from tests.test_decision_making.utils import generate_dummy_conjugate_posterior + + +def test_probability_of_improvement_gives_correct_value_for_a_seed(): + key = jr.key(42) + forrester = Forrester() + dataset = forrester.generate_dataset(num_points=10, key=key) + posterior = generate_dummy_conjugate_posterior(dataset) + posteriors = {OBJECTIVE: posterior} + datasets = {OBJECTIVE: dataset} + + pi_utility_builder = ProbabilityOfImprovement() + pi_utility = pi_utility_builder.build_utility_function( + posteriors=posteriors, datasets=datasets, key=key + ) + + test_X = forrester.generate_test_points(num_points=10, key=key) + utility_values = pi_utility(test_X) + + expected_utility_values = jnp.array( + [ + 7.30230451e-05, + 5.00322831e-05, + 1.06219741e-03, + 2.19520435e-03, + 3.49279363e-05, + 1.66031943e-04, + 2.78478912e-04, + 3.35871920e-04, + 1.38265233e-04, + 3.63297977e-05, + ] + ).reshape(-1, 1) + + assert utility_values.shape == (10, 1) + assert jnp.isclose(utility_values, expected_utility_values).all() From 3924fa9749f59900218a5d91f1e78bf3b257ee53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Gonz=C3=A1lez=20Duque?= Date: Mon, 1 Jul 2024 20:56:57 +0200 Subject: [PATCH 07/14] Changes the tutorial a bit before deleting it --- ...n_optimisation_with_other_acq_functions.py | 156 ++++++++++-------- 1 file changed, 91 insertions(+), 65 deletions(-) diff --git a/docs/examples/bayesian_optimisation_with_other_acq_functions.py b/docs/examples/bayesian_optimisation_with_other_acq_functions.py index 46ef523d7..de3d58bba 100644 --- a/docs/examples/bayesian_optimisation_with_other_acq_functions.py +++ b/docs/examples/bayesian_optimisation_with_other_acq_functions.py @@ -5,6 +5,31 @@ # a framework for optimising black-box function that leverages the # uncertainty estimates that come from Gaussian processes. +# %% [markdown] + +# In a few words, Bayesian optimisation starts by fitting a Gaussian +# process to the data we have collected so far about the objective function, +# and then uses this model to construct an **acquisition function** +# which tells us which parts of the input domain have the potential +# of improving our optimization. Unlike the black-box objective, the +# acquisition function is easy to query and optimize. +# +# Acquisition functions come in many flavors. In the previous guide, +# we sampled from the Gaussian Process' predictive posterior and +# optimized said sample. This is known as **Thompson Sampling**, +# and it is a popular acquisition function due to its simplicity +# and ease of parallelization. 
+# +# In this guide we will introduce a new acquisition function: +# *Probability of Improvement* [[Kushner, 1964]](https://asmedigitalcollection.asme.org/fluidsengineering/article/86/1/97/392213/A-New-Method-of-Locating-the-Maximum-Point-of-an). This acquisition function can be formally +# defined as +# $$ \text{PI}(x) = \text{Prob}[f(x) < f(x_{\text{best}})] $$ +# where $f(x)$ is the objective functionw we aim at **minimizing**, +# and $x_{\text{best}}$ is the best point we have seen so far (i.e. +# the point with the lowest value for $f$). The name is clear: it measures +# the probability that a given point $x$ will **improve** our +# optimization trace. + # %% # Enable Float64 for more stable matrix inversions. from jax import config @@ -34,60 +59,46 @@ cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] +# ## Optimizing Logarithmic Goldstein-Price +# +# Our module for `decision_making` includes a couple of test functions +# you can use to test whether your model fitting/optimization algorithms are +# working properly. One of these is the [Logarithmic Goldstein-Price function](https://www.sfu.ca/~ssurjano/goldpr.html). +# +# These test functions we provide contain a search space, minimizer, and minimum +# attributes, as well as methods for generating training and testing datasets. -# In a few words, Bayesian optimisation starts by fitting a Gaussian -# process to the data we have collected so far about the objective function, -# and then uses this model to construct an **acquisition function** -# which tells us which parts of the input domain have the potential -# of improving our optimization. Unlike the black-box objective, the -# acquisition function is easy to query and optimize. - -# Acquisition functions come in many flavors. In the previous guide, -# we sampled from the Gaussian Process' predictive posterior and -# optimized said sample. This is known as Thompson Sampling, -# and it is a popular acquisition function due to its simplicity -# and ease of parallelization. - -# In this guide we will introduce a new acquisition function: -# probability of improvement [TODO:ADDCITE](). This acquisition function can be formally -# defined as +# %% +from gpjax.decision_making.test_functions import LogarithmicGoldsteinPrice -# $$ \text{PI}(x) = \text{Prob}(f(x) < f(x_{\text{best}})) $$ +logarithmic_goldstein_price = LogarithmicGoldsteinPrice() -# where $f(x)$ is the objective functionw we aim at **minimizing**, -# and $x_{\text{best}}$ is the best point we have seen so far (i.e. -# the point with the lowest value of $f$). The name is clear: it measures -# the probability that a given point $x$ will **improve** our -# optimization trace. +example_dataset = logarithmic_goldstein_price.generate_dataset(num_points=5, key=key) +example_values = logarithmic_goldstein_price.evaluate(example_dataset.X) +print(f"Example Dataset: {example_dataset.X}") +print(f"Example Values: {example_values}") +print(f"Minimizer: {logarithmic_goldstein_price.minimizer}") +print(f"Minimum: {logarithmic_goldstein_price.minimum}") # %% [markdown] -# ## Optimizing a 1D function: Forrester -# -# Just like in our previous guide, let's start by defining the -# [Forrester objective function](https://www.sfu.ca/~ssurjano/forretal08.html). +# Let's plot this function to see what it looks like. 
As you might have noticed from the dataset, this function takes as input points in # %% -def standardised_forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]: - mean = 0.45321 - std = 4.4258 - return ((6 * x - 2) ** 2 * jnp.sin(12 * x - 4) - mean) / std +domain = jnp.linspace(0, 1, 1000).reshape(-1, 1) +objective_values = logarithmic_goldstein_price.evaluate(domain) -# %% -lower_bound = jnp.array([0.0]) -upper_bound = jnp.array([1.0]) -initial_sample_num = 5 - -initial_x = tfp.mcmc.sample_halton_sequence( - dim=1, num_results=initial_sample_num, seed=key, dtype=jnp.float64 -).reshape(-1, 1) -initial_y = standardised_forrester(initial_x) -D = gpx.Dataset(X=initial_x, y=initial_y) +fig, ax = plt.subplots() +ax.plot(domain, objective_values, label="Logarithmic Goldstein-Price Function") +ax.scatter( + example_dataset.X.flatten(), example_dataset.y, label="Observations", color=cols[2] +) # %% [markdown] -# ...defining the Gaussian Process model... +# Just like in the previous guide, we can fit a Gaussian Process to +# a dataset rendering an optimised posterior: # %% @@ -119,33 +130,61 @@ def return_optimised_posterior( return opt_posterior +D = logarithmic_goldstein_price.generate_dataset(num_points=10, key=key) + mean = gpx.mean_functions.Zero() kernel = gpx.kernels.Matern52() prior = gpx.gps.Prior(mean_function=mean, kernel=kernel) opt_posterior = return_optimised_posterior(D, prior, key) +# %% [markdown] + +# Using this optimised posterior, we can construct the Probability of Improvement +# acquisition function. + +# %% +from gpjax.decision_making.utility_functions.base import SinglePointUtilityFunction from gpjax.decision_making.utility_functions.probability_of_improvement import ( ProbabilityOfImprovement, ) -utility_function_builder = ProbabilityOfImprovement() -utility_function = utility_function_builder.build_utility_function( - posteriors={"OBJECTIVE": opt_posterior}, datasets={"OBJECTIVE": D}, key=key -) + +def construct_acquisition_function( + opt_posterior: gpx.base.Module, + dataset: gpx.Dataset, + key: Array, +) -> SinglePointUtilityFunction: + # ProbabilityOfImprovement is a builder of acquisition functions + utility_function_builder = ProbabilityOfImprovement() + + utility_function = utility_function_builder.build_utility_function( + posteriors={"OBJECTIVE": opt_posterior}, + datasets={"OBJECTIVE": dataset}, + key=key, + ) + + return utility_function + + +utility_function = construct_acquisition_function(opt_posterior, D, key) + +# %% [markdown] + +# Our module for decision making also provides a maximizer of utility functions. With it, we can construct a simple `optimize_acquisition_function` function that will return the point that maximizes the Probability of Improvement. 
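[Editor's note] A short, self-contained sketch of the utility maximiser mentioned just above, before the patch hunk that wires it into `optimize_acquisition_function`. The quadratic `toy_utility` is a hypothetical stand-in for the utility returned by `construct_acquisition_function`, and the unit-square bounds are an assumption standing in for the test function's own search space; the `ContinuousSearchSpace` and maximiser calls mirror the ones shown in the diff below.

```python
from jax import config

config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jax.random as jr

from gpjax.decision_making.search_space import ContinuousSearchSpace
from gpjax.decision_making.utility_maximizer import (
    ContinuousSinglePointUtilityMaximizer,
)


def toy_utility(x):
    # Hypothetical placeholder utility, peaked at (0.5, 0.5); returns shape (N, 1).
    return -jnp.sum((x - 0.5) ** 2, axis=-1, keepdims=True)


search_space = ContinuousSearchSpace(
    lower_bounds=jnp.array([0.0, 0.0]), upper_bounds=jnp.array([1.0, 1.0])
)
maximizer = ContinuousSinglePointUtilityMaximizer(
    num_initial_samples=100, num_restarts=5
)
x_next = maximizer.maximize(toy_utility, search_space=search_space, key=jr.key(0))
print(x_next)  # should land near [0.5, 0.5]
```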
+ +# %% from gpjax.decision_making.utility_maximizer import ( ContinuousSinglePointUtilityMaximizer, ) -from gpjax.decision_making.utility_functions.base import SinglePointUtilityFunction from gpjax.decision_making.search_space import ContinuousSearchSpace def optimize_acquisition_function( utility_function: SinglePointUtilityFunction, key: Array, - lower_bounds: Float[Array, "D"], - upper_bounds: Float[Array, "D"], + search_space: ContinuousSearchSpace, num_initial_samples: int = 100, num_restarts: int = 5, ): @@ -153,10 +192,6 @@ def optimize_acquisition_function( num_initial_samples=num_initial_samples, num_restarts=num_restarts ) - search_space = ContinuousSearchSpace( - lower_bounds=lower_bounds, upper_bounds=upper_bounds - ) - x_next_best = optimizer.maximize( utility_function, search_space=search_space, key=key ) @@ -164,20 +199,11 @@ def optimize_acquisition_function( return x_next_best -# %% -def construct_acquisition_function( - opt_posterior: gpx.base.Module, - dataset: gpx.Dataset, - key: Array, -) -> SinglePointUtilityFunction: - utility_function_builder = ProbabilityOfImprovement() - utility_function = utility_function_builder.build_utility_function( - posteriors={"OBJECTIVE": opt_posterior}, - datasets={"OBJECTIVE": dataset}, - key=key, - ) +# %% [markdown] - return utility_function +# ...remembering the basics of a BO loop. + +# %% def propose_next_candidate( From ba5bbdce83760d923784046b65017d4ff22df80e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Gonz=C3=A1lez=20Duque?= Date: Mon, 1 Jul 2024 21:05:49 +0200 Subject: [PATCH 08/14] Updates the tutorial on decision making by mentioning PI --- ...n_optimisation_with_other_acq_functions.py | 531 ------------------ docs/examples/decision_making.py | 6 + .../utility_functions/__init__.py | 4 + 3 files changed, 10 insertions(+), 531 deletions(-) delete mode 100644 docs/examples/bayesian_optimisation_with_other_acq_functions.py diff --git a/docs/examples/bayesian_optimisation_with_other_acq_functions.py b/docs/examples/bayesian_optimisation_with_other_acq_functions.py deleted file mode 100644 index de3d58bba..000000000 --- a/docs/examples/bayesian_optimisation_with_other_acq_functions.py +++ /dev/null @@ -1,531 +0,0 @@ -# %% [markdown] -# # Bayesian Optimisation beyond Thompson Sampling -# -# In [a previous guide](), we gave an introduction to Bayesian optimisation: -# a framework for optimising black-box function that leverages the -# uncertainty estimates that come from Gaussian processes. - -# %% [markdown] - -# In a few words, Bayesian optimisation starts by fitting a Gaussian -# process to the data we have collected so far about the objective function, -# and then uses this model to construct an **acquisition function** -# which tells us which parts of the input domain have the potential -# of improving our optimization. Unlike the black-box objective, the -# acquisition function is easy to query and optimize. -# -# Acquisition functions come in many flavors. In the previous guide, -# we sampled from the Gaussian Process' predictive posterior and -# optimized said sample. This is known as **Thompson Sampling**, -# and it is a popular acquisition function due to its simplicity -# and ease of parallelization. -# -# In this guide we will introduce a new acquisition function: -# *Probability of Improvement* [[Kushner, 1964]](https://asmedigitalcollection.asme.org/fluidsengineering/article/86/1/97/392213/A-New-Method-of-Locating-the-Maximum-Point-of-an). 
This acquisition function can be formally -# defined as -# $$ \text{PI}(x) = \text{Prob}[f(x) < f(x_{\text{best}})] $$ -# where $f(x)$ is the objective functionw we aim at **minimizing**, -# and $x_{\text{best}}$ is the best point we have seen so far (i.e. -# the point with the lowest value for $f$). The name is clear: it measures -# the probability that a given point $x$ will **improve** our -# optimization trace. - -# %% -# Enable Float64 for more stable matrix inversions. -from jax import config - -config.update("jax_enable_x64", True) - -import jax -from jax import jit -import jax.numpy as jnp -import jax.random as jr -from jaxtyping import install_import_hook, Float -import matplotlib as mpl -import matplotlib.pyplot as plt -from matplotlib import cm -import optax as ox -import tensorflow_probability.substrates.jax as tfp -from typing import List, Tuple - -with install_import_hook("gpjax", "beartype.beartype"): - import gpjax as gpx -from gpjax.typing import Array, ScalarFloat - -key = jr.key(1337) -plt.style.use( - "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) -cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] - -# %% [markdown] -# ## Optimizing Logarithmic Goldstein-Price -# -# Our module for `decision_making` includes a couple of test functions -# you can use to test whether your model fitting/optimization algorithms are -# working properly. One of these is the [Logarithmic Goldstein-Price function](https://www.sfu.ca/~ssurjano/goldpr.html). -# -# These test functions we provide contain a search space, minimizer, and minimum -# attributes, as well as methods for generating training and testing datasets. - -# %% -from gpjax.decision_making.test_functions import LogarithmicGoldsteinPrice - -logarithmic_goldstein_price = LogarithmicGoldsteinPrice() - -example_dataset = logarithmic_goldstein_price.generate_dataset(num_points=5, key=key) -example_values = logarithmic_goldstein_price.evaluate(example_dataset.X) -print(f"Example Dataset: {example_dataset.X}") -print(f"Example Values: {example_values}") -print(f"Minimizer: {logarithmic_goldstein_price.minimizer}") -print(f"Minimum: {logarithmic_goldstein_price.minimum}") - -# %% [markdown] - -# Let's plot this function to see what it looks like. 
As you might have noticed from the dataset, this function takes as input points in - -# %% - -domain = jnp.linspace(0, 1, 1000).reshape(-1, 1) -objective_values = logarithmic_goldstein_price.evaluate(domain) - -fig, ax = plt.subplots() -ax.plot(domain, objective_values, label="Logarithmic Goldstein-Price Function") -ax.scatter( - example_dataset.X.flatten(), example_dataset.y, label="Observations", color=cols[2] -) - -# %% [markdown] - -# Just like in the previous guide, we can fit a Gaussian Process to -# a dataset rendering an optimised posterior: - - -# %% -def return_optimised_posterior( - data: gpx.Dataset, prior: gpx.base.Module, key: Array -) -> gpx.base.Module: - likelihood = gpx.likelihoods.Gaussian( - num_datapoints=data.n, obs_stddev=jnp.array(1e-6) - ) # Our function is noise-free, so we set the observation noise's standard deviation to a very small value - likelihood = likelihood.replace_trainable(obs_stddev=False) - - posterior = prior * likelihood - - negative_mll = gpx.objectives.ConjugateMLL(negative=True) - negative_mll(posterior, train_data=data) - negative_mll = jit(negative_mll) - - opt_posterior, _ = gpx.fit( - model=posterior, - objective=negative_mll, - train_data=data, - optim=ox.adam(learning_rate=0.01), - num_iters=1000, - safe=True, - key=key, - verbose=False, - ) - - return opt_posterior - - -D = logarithmic_goldstein_price.generate_dataset(num_points=10, key=key) - -mean = gpx.mean_functions.Zero() -kernel = gpx.kernels.Matern52() -prior = gpx.gps.Prior(mean_function=mean, kernel=kernel) -opt_posterior = return_optimised_posterior(D, prior, key) - -# %% [markdown] - -# Using this optimised posterior, we can construct the Probability of Improvement -# acquisition function. - -# %% - -from gpjax.decision_making.utility_functions.base import SinglePointUtilityFunction -from gpjax.decision_making.utility_functions.probability_of_improvement import ( - ProbabilityOfImprovement, -) - - -def construct_acquisition_function( - opt_posterior: gpx.base.Module, - dataset: gpx.Dataset, - key: Array, -) -> SinglePointUtilityFunction: - # ProbabilityOfImprovement is a builder of acquisition functions - utility_function_builder = ProbabilityOfImprovement() - - utility_function = utility_function_builder.build_utility_function( - posteriors={"OBJECTIVE": opt_posterior}, - datasets={"OBJECTIVE": dataset}, - key=key, - ) - - return utility_function - - -utility_function = construct_acquisition_function(opt_posterior, D, key) - -# %% [markdown] - -# Our module for decision making also provides a maximizer of utility functions. With it, we can construct a simple `optimize_acquisition_function` function that will return the point that maximizes the Probability of Improvement. - -# %% - -from gpjax.decision_making.utility_maximizer import ( - ContinuousSinglePointUtilityMaximizer, -) -from gpjax.decision_making.search_space import ContinuousSearchSpace - - -def optimize_acquisition_function( - utility_function: SinglePointUtilityFunction, - key: Array, - search_space: ContinuousSearchSpace, - num_initial_samples: int = 100, - num_restarts: int = 5, -): - optimizer = ContinuousSinglePointUtilityMaximizer( - num_initial_samples=num_initial_samples, num_restarts=num_restarts - ) - - x_next_best = optimizer.maximize( - utility_function, search_space=search_space, key=key - ) - - return x_next_best - - -# %% [markdown] - -# ...remembering the basics of a BO loop. 
- -# %% - - -def propose_next_candidate( - utility_function: SinglePointUtilityFunction, - key: Array, - lower_bounds: Float[Array, "D"], - upper_bounds: Float[Array, "D"], -) -> Float[Array, "D 1"]: - queried_x = optimize_acquisition_function( - utility_function, key, lower_bounds, upper_bounds, num_initial_samples=100 - ) - - return queried_x - - -def run_one_bo_loop_in_1D( - objective_function: standardised_forrester, - opt_posterior: gpx.base.Module, - dataset: gpx.Dataset, - key: Array, - plot: bool = True, -) -> Float[Array, "D 1"]: - domain = jnp.linspace(0, 1, 1000).reshape(-1, 1) - objective_values = objective_function(domain) - - latent_distribution = opt_posterior.predict(domain, train_data=dataset) - predictive_distribution = opt_posterior.likelihood(latent_distribution) - - predictive_mean = predictive_distribution.mean() - predictive_std = predictive_distribution.stddev() - - # Building PI - utility_function = construct_acquisition_function(opt_posterior, dataset, key) - utility_function_values = utility_function(domain) - - # Optimizing the acq. function - lower_bound = jnp.array([0.0]) - upper_bound = jnp.array([1.0]) - queried_x = propose_next_candidate(utility_function, key, lower_bound, upper_bound) - - if plot: - fig, ax = plt.subplots() - ax.plot(domain, predictive_mean, label="Predictive Mean", color=cols[1]) - ax.fill_between( - domain.squeeze(), - predictive_mean - 2 * predictive_std, - predictive_mean + 2 * predictive_std, - alpha=0.2, - label="Two sigma", - color=cols[1], - ) - ax.plot( - domain, - predictive_mean - 2 * predictive_std, - linestyle="--", - linewidth=1, - color=cols[1], - ) - ax.plot( - domain, - predictive_mean + 2 * predictive_std, - linestyle="--", - linewidth=1, - color=cols[1], - ) - ax.plot(domain, utility_function_values, label="Probability of Improvement") - ax.plot( - domain, - objective_values, - label="Forrester Function", - color=cols[0], - linestyle="--", - linewidth=2, - ) - ax.axvline(x=0.757, linestyle=":", color=cols[3], label="True Optimum") - ax.scatter(dataset.X, dataset.y, label="Observations", color=cols[2], zorder=2) - ax.scatter( - queried_x, - utility_function(queried_x), - label="Probability of Improvement Optimum", - marker="*", - color=cols[3], - zorder=3, - ) - ax.legend(loc="center left", bbox_to_anchor=(0.975, 0.5)) - plt.show() - - return queried_x - - -# plot_bayes_opt_using_pi(standardised_forrester, opt_posterior, D, key) - -# %% -bo_iters = 5 - -# Set up initial dataset -initial_x = tfp.mcmc.sample_halton_sequence( - dim=1, num_results=initial_sample_num, seed=key, dtype=jnp.float64 -).reshape(-1, 1) -initial_y = standardised_forrester(initial_x) -D = gpx.Dataset(X=initial_x, y=initial_y) - -for i in range(bo_iters): - key, subkey = jr.split(key) - - # Generate optimised posterior using previously observed data - mean = gpx.mean_functions.Zero() - kernel = gpx.kernels.Matern52() - prior = gpx.gps.Prior(mean_function=mean, kernel=kernel) - opt_posterior = return_optimised_posterior(D, prior, subkey) - - queried_x = run_one_bo_loop_in_1D(standardised_forrester, opt_posterior, D, key) - - # Evaluate the black-box function at the best point observed so far, and add it to the dataset - y_star = standardised_forrester(queried_x) - print(f"Queried Point: {queried_x}, Black-Box Function Value: {y_star}") - D = D + gpx.Dataset(X=queried_x, y=y_star) - - -# %% -def standardised_six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]: - mean = 1.12767 - std = 1.17500 - x1 = x[..., :1] - x2 = x[..., 1:] - term1 = (4 - 
2.1 * x1**2 + x1**4 / 3) * x1**2 - term2 = x1 * x2 - term3 = (-4 + 4 * x2**2) * x2**2 - return (term1 + term2 + term3 - mean) / std - - -# %% -x1 = jnp.linspace(-2, 2, 100) -x2 = jnp.linspace(-1, 1, 100) -x1, x2 = jnp.meshgrid(x1, x2) -x = jnp.stack([x1.flatten(), x2.flatten()], axis=1) -y = standardised_six_hump_camel(x) - -fig, ax = plt.subplots(subplot_kw={"projection": "3d"}) -surf = ax.plot_surface( - x1, - x2, - y.reshape(x1.shape[0], x2.shape[0]), - linewidth=0, - cmap=cm.coolwarm, - antialiased=False, -) -ax.set_xlabel("x1") -ax.set_ylabel("x2") -plt.show() - - -# %% -x_star_one = jnp.array([[0.0898, -0.7126]]) -x_star_two = jnp.array([[-0.0898, 0.7126]]) -fig, ax = plt.subplots() -contour_plot = ax.contourf( - x1, x2, y.reshape(x1.shape[0], x2.shape[0]), cmap=cm.coolwarm, levels=40 -) -ax.scatter( - x_star_one[0][0], x_star_one[0][1], marker="*", color=cols[2], label="Global Minima" -) -ax.scatter(x_star_two[0][0], x_star_two[0][1], marker="*", color=cols[2]) -ax.set_xlabel("x1") -ax.set_ylabel("x2") -fig.colorbar(contour_plot) -ax.legend() -plt.show() - -# %% -lower_bound = jnp.array([-2.0, -1.0]) -upper_bound = jnp.array([2.0, 1.0]) -initial_sample_num = 5 -bo_iters = 20 -num_experiments = 5 -bo_experiment_results = [] - -for experiment in range(num_experiments): - print(f"Starting Experiment: {experiment + 1}") - # Set up initial dataset - initial_x = tfp.mcmc.sample_halton_sequence( - dim=2, num_results=initial_sample_num, seed=key, dtype=jnp.float64 - ) - initial_x = jnp.array(lower_bound + (upper_bound - lower_bound) * initial_x) - initial_y = standardised_six_hump_camel(initial_x) - D = gpx.Dataset(X=initial_x, y=initial_y) - - for i in range(bo_iters): - key, subkey = jr.split(key) - - # Generate optimised posterior - mean = gpx.mean_functions.Zero() - kernel = gpx.kernels.Matern52( - active_dims=[0, 1], lengthscale=jnp.array([1.0, 1.0]), variance=2.0 - ) - prior = gpx.gps.Prior(mean_function=mean, kernel=kernel) - opt_posterior = return_optimised_posterior(D, prior, subkey) - - # Constructing the acq. 
function - utility_function = construct_acquisition_function(opt_posterior, D, key) - - # Draw a sample from the posterior, and find the minimiser of it - queried_x = propose_next_candidate( - utility_function, key, lower_bound, upper_bound - ) - - # Evaluate the black-box function at the best point observed so far, and add it to the dataset - y_star = standardised_six_hump_camel(queried_x) - print( - f"BO Iteration: {i + 1}, Queried Point: {queried_x}, Black-Box Function Value:" - f" {y_star}" - ) - D = D + gpx.Dataset(X=queried_x, y=y_star) - bo_experiment_results.append(D) - -# %% -random_experiment_results = [] -for i in range(num_experiments): - key, subkey = jr.split(key) - initial_x = bo_experiment_results[i].X[:5] - initial_y = bo_experiment_results[i].y[:5] - final_x = jr.uniform( - key, - shape=(bo_iters, 2), - dtype=jnp.float64, - minval=lower_bound, - maxval=upper_bound, - ) - final_y = standardised_six_hump_camel(final_x) - random_x = jnp.concatenate([initial_x, final_x], axis=0) - random_y = jnp.concatenate([initial_y, final_y], axis=0) - random_experiment_results.append(gpx.Dataset(X=random_x, y=random_y)) - - -# %% -def obtain_log_regret_statistics( - experiment_results: List[gpx.Dataset], - global_minimum: ScalarFloat, -) -> Tuple[Float[Array, "N 1"], Float[Array, "N 1"]]: - log_regret_results = [] - for exp_result in experiment_results: - observations = exp_result.y - cumulative_best_observations = jax.lax.associative_scan( - jax.numpy.minimum, observations - ) - regret = cumulative_best_observations - global_minimum - log_regret = jnp.log(regret) - log_regret_results.append(log_regret) - - log_regret_results = jnp.array(log_regret_results) - log_regret_mean = jnp.mean(log_regret_results, axis=0) - log_regret_std = jnp.std(log_regret_results, axis=0) - return log_regret_mean, log_regret_std - - -bo_log_regret_mean, bo_log_regret_std = obtain_log_regret_statistics( - bo_experiment_results, -1.8377 -) -( - random_log_regret_mean, - random_log_regret_std, -) = obtain_log_regret_statistics(random_experiment_results, -1.8377) - -# %% [markdown] -# Now, when we plot the mean and standard deviation of the log regret at each iteration, -# we can see that BO outperforms random sampling! - -# %% -fig, ax = plt.subplots() -fn_evaluations = jnp.arange(1, bo_iters + initial_sample_num + 1) -ax.plot(fn_evaluations, bo_log_regret_mean, label="Bayesian Optimisation") -ax.fill_between( - fn_evaluations, - bo_log_regret_mean[:, 0] - bo_log_regret_std[:, 0], - bo_log_regret_mean[:, 0] + bo_log_regret_std[:, 0], - alpha=0.2, -) -ax.plot(fn_evaluations, random_log_regret_mean, label="Random Search") -ax.fill_between( - fn_evaluations, - random_log_regret_mean[:, 0] - random_log_regret_std[:, 0], - random_log_regret_mean[:, 0] + random_log_regret_std[:, 0], - alpha=0.2, -) -ax.axvline(x=initial_sample_num, linestyle=":") -ax.set_xlabel("Number of Black-Box Function Evaluations") -ax.set_ylabel("Log Regret") -ax.legend() -plt.show() - -# %% [markdown] -# It can also be useful to plot the queried points over the course of a single BO run, in -# order to gain some insight into how the algorithm queries the search space. Below -# we do this for one of the BO experiments, and can see that the algorithm initially -# performs some exploration of the search space whilst it is uncertain about the black-box -# function, but it then hones in one one of the global minima of the function, as we would hope! 
- -# %% -fig, ax = plt.subplots() -contour_plot = ax.contourf( - x1, x2, y.reshape(x1.shape[0], x2.shape[0]), cmap=cm.coolwarm, levels=40 -) -ax.scatter( - x_star_one[0][0], - x_star_one[0][1], - marker="*", - color=cols[2], - label="Global Minimum", - zorder=2, -) -ax.scatter(x_star_two[0][0], x_star_two[0][1], marker="*", color=cols[2], zorder=2) -ax.scatter( - bo_experiment_results[1].X[:, 0], - bo_experiment_results[1].X[:, 1], - marker="x", - color=cols[1], - label="Bayesian Optimisation Queries", -) -ax.set_xlabel("x1") -ax.set_ylabel("x2") -fig.colorbar(contour_plot) -ax.legend() -plt.show() - -# %% diff --git a/docs/examples/decision_making.py b/docs/examples/decision_making.py index 1460fc25a..ee9df229f 100644 --- a/docs/examples/decision_making.py +++ b/docs/examples/decision_making.py @@ -237,6 +237,12 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]: num_initial_samples=100, num_restarts=1 ) +# %% [markdown] + +# It is worth noting that `ThompsonSampling` is not the only utility function we could use, +# since our module also provides e.g. `ProbabilityOfImprovement`, +# which was briefly discussed in [our previous introduction to Bayesian optimisation](https://docs.jaxgaussianprocesses.com/examples/bayesian_optimisation/). + # %% [markdown] # ## Putting it All Together with the Decision Maker diff --git a/gpjax/decision_making/utility_functions/__init__.py b/gpjax/decision_making/utility_functions/__init__.py index 18868949d..eedb2feda 100644 --- a/gpjax/decision_making/utility_functions/__init__.py +++ b/gpjax/decision_making/utility_functions/__init__.py @@ -19,6 +19,9 @@ UtilityFunction, ) from gpjax.decision_making.utility_functions.thompson_sampling import ThompsonSampling +from gpjax.decision_making.utility_functions.probability_of_improvement import ( + ProbabilityOfImprovement, +) __all__ = [ "UtilityFunction", @@ -26,4 +29,5 @@ "AbstractSinglePointUtilityFunctionBuilder", "SinglePointUtilityFunction", "ThompsonSampling", + "ProbabilityOfImprovement", ] From b455c9a3be03071790f341b5ce79597b4cc8b8c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Gonz=C3=A1lez=20Duque?= Date: Mon, 1 Jul 2024 21:11:29 +0200 Subject: [PATCH 09/14] Makes a better test for PI by manually computing the CDF --- .../test_probability_of_improvement.py | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py b/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py index 5f18f3ca0..f66acd16b 100644 --- a/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py +++ b/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py @@ -23,7 +23,7 @@ from gpjax.decision_making.utility_functions.probability_of_improvement import ( ProbabilityOfImprovement, ) -from gpjax.decision_making.utils import OBJECTIVE +from gpjax.decision_making.utils import OBJECTIVE, gaussian_cdf from tests.test_decision_making.utils import generate_dummy_conjugate_posterior @@ -43,19 +43,13 @@ def test_probability_of_improvement_gives_correct_value_for_a_seed(): test_X = forrester.generate_test_points(num_points=10, key=key) utility_values = pi_utility(test_X) - expected_utility_values = jnp.array( - [ - 7.30230451e-05, - 5.00322831e-05, - 1.06219741e-03, - 2.19520435e-03, - 3.49279363e-05, - 1.66031943e-04, - 2.78478912e-04, - 3.35871920e-04, - 1.38265233e-04, - 3.63297977e-05, - ] + # Computing the expected utility 
values + predictive_dist = posterior.predict(test_X, train_data=dataset) + predictive_mean = predictive_dist.mean() + predictive_std = predictive_dist.stddev() + + expected_utility_values = gaussian_cdf( + (dataset.y.min() - predictive_mean) / predictive_std ).reshape(-1, 1) assert utility_values.shape == (10, 1) From 1f67bfb97023c25a4e7ec0b33e349351bb322164 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Gonz=C3=A1lez=20Duque?= Date: Mon, 1 Jul 2024 21:12:28 +0200 Subject: [PATCH 10/14] Lints according to Ruff --- gpjax/decision_making/utility_functions/__init__.py | 2 +- .../test_probability_of_improvement.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/gpjax/decision_making/utility_functions/__init__.py b/gpjax/decision_making/utility_functions/__init__.py index eedb2feda..5186b1075 100644 --- a/gpjax/decision_making/utility_functions/__init__.py +++ b/gpjax/decision_making/utility_functions/__init__.py @@ -18,10 +18,10 @@ SinglePointUtilityFunction, UtilityFunction, ) -from gpjax.decision_making.utility_functions.thompson_sampling import ThompsonSampling from gpjax.decision_making.utility_functions.probability_of_improvement import ( ProbabilityOfImprovement, ) +from gpjax.decision_making.utility_functions.thompson_sampling import ThompsonSampling __all__ = [ "UtilityFunction", diff --git a/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py b/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py index f66acd16b..65a7bbad8 100644 --- a/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py +++ b/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py @@ -16,14 +16,17 @@ config.update("jax_enable_x64", True) -import jax.random as jr import jax.numpy as jnp +import jax.random as jr from gpjax.decision_making.test_functions.continuous_functions import Forrester from gpjax.decision_making.utility_functions.probability_of_improvement import ( ProbabilityOfImprovement, ) -from gpjax.decision_making.utils import OBJECTIVE, gaussian_cdf +from gpjax.decision_making.utils import ( + OBJECTIVE, + gaussian_cdf, +) from tests.test_decision_making.utils import generate_dummy_conjugate_posterior From 64aeb2ced73677e01f359deecc36e49420c9691c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Gonz=C3=A1lez=20Duque?= Date: Wed, 3 Jul 2024 10:46:09 +0200 Subject: [PATCH 11/14] Removes the manual CDF and replaces it with tfp, improves error message on PI --- .../probability_of_improvement.py | 36 +++++++++++++------ gpjax/decision_making/utils.py | 8 ----- .../test_probability_of_improvement.py | 21 +++++++---- 3 files changed, 41 insertions(+), 24 deletions(-) diff --git a/gpjax/decision_making/utility_functions/probability_of_improvement.py b/gpjax/decision_making/utility_functions/probability_of_improvement.py index a044110b1..e04f1eaca 100644 --- a/gpjax/decision_making/utility_functions/probability_of_improvement.py +++ b/gpjax/decision_making/utility_functions/probability_of_improvement.py @@ -16,16 +16,14 @@ from beartype.typing import Mapping from jaxtyping import Num +import tensorflow_probability.substrates.jax as tfp from gpjax.dataset import Dataset from gpjax.decision_making.utility_functions.base import ( AbstractSinglePointUtilityFunctionBuilder, SinglePointUtilityFunction, ) -from gpjax.decision_making.utils import ( - OBJECTIVE, - gaussian_cdf, -) +from gpjax.decision_making.utils import OBJECTIVE from gpjax.gps import ConjugatePosterior 
from gpjax.typing import ( Array, @@ -92,19 +90,37 @@ def build_utility_function( objective_posterior = posteriors[OBJECTIVE] if not isinstance(objective_posterior, ConjugatePosterior): raise ValueError( - "Objective posterior must be a ConjugatePosterior to draw an approximate sample." + "Objective posterior must be a ConjugatePosterior to compute the Probability of Improvement using a Gaussian CDF." ) objective_dataset = datasets[OBJECTIVE] + if ( + objective_dataset.X is None + or objective_dataset.n == 0 + or objective_dataset.y is None + ): + raise ValueError( + "Objective dataset must be non-empty to compute the " + "Probability of Improvement (since we need a " + "`best_y` value)." + ) def probability_of_improvement(x_test: Num[Array, "N D"]): + # Computing the posterior mean for the training dataset + # for computing the best_y value (as the minimum + # posterior mean of the objective function) + predictive_dist_for_training = objective_posterior.predict( + objective_dataset.X, objective_dataset + ) + best_y = predictive_dist_for_training.mean().min() + predictive_dist = objective_posterior.predict(x_test, objective_dataset) - # Assuming that the goal is to minimize the objective function - best_y = objective_dataset.y.min() + normal_dist = tfp.distributions.Normal( + loc=predictive_dist.mean(), + scale=predictive_dist.stddev(), + ) - return gaussian_cdf( - (best_y - predictive_dist.mean()) / predictive_dist.stddev() - ).reshape(-1, 1) + return normal_dist.cdf(best_y).reshape(-1, 1) return probability_of_improvement diff --git a/gpjax/decision_making/utils.py b/gpjax/decision_making/utils.py index 873bf21de..dd81b0cfd 100644 --- a/gpjax/decision_making/utils.py +++ b/gpjax/decision_making/utils.py @@ -50,11 +50,3 @@ def build_function_evaluator( dictionary of datasets storing the evaluated points. """ return lambda x: {tag: Dataset(x, f(x)) for tag, f in functions.items()} - - -def gaussian_cdf(x: Float[Array, " N"]) -> Float[Array, " N"]: - """ - Compute the cumulative distribution function of the standard normal distribution at - the points `x`. - """ - return 0.5 * (1 + jax.scipy.special.erf(x / jnp.sqrt(2))) diff --git a/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py b/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py index 65a7bbad8..5b6be0698 100644 --- a/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py +++ b/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py @@ -16,6 +16,7 @@ config.update("jax_enable_x64", True) +import jax import jax.numpy as jnp import jax.random as jr @@ -23,10 +24,7 @@ from gpjax.decision_making.utility_functions.probability_of_improvement import ( ProbabilityOfImprovement, ) -from gpjax.decision_making.utils import ( - OBJECTIVE, - gaussian_cdf, -) +from gpjax.decision_making.utils import OBJECTIVE from tests.test_decision_making.utils import generate_dummy_conjugate_posterior @@ -51,9 +49,20 @@ def test_probability_of_improvement_gives_correct_value_for_a_seed(): predictive_mean = predictive_dist.mean() predictive_std = predictive_dist.stddev() - expected_utility_values = gaussian_cdf( - (dataset.y.min() - predictive_mean) / predictive_std + # Computing best_y as the min. of the posterior predictive mean + # over the training set. 
+ predictive_dist_for_training_data = posterior.predict(dataset.X, train_data=dataset) + best_y = predictive_dist_for_training_data.mean().min() + + # Gaussian CDF computed "by hand" + x_ = (best_y - predictive_mean) / predictive_std + expected_utility_values = 0.5 * ( + 1 + jax.scipy.special.erf(x_ / jnp.sqrt(2)) ).reshape(-1, 1) assert utility_values.shape == (10, 1) assert jnp.isclose(utility_values, expected_utility_values).all() + + +if __name__ == "__main__": + test_probability_of_improvement_gives_correct_value_for_a_seed() From b651e90b444ff51b31c53368c2f4c6318f314452 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Gonz=C3=A1lez=20Duque?= Date: Wed, 3 Jul 2024 10:46:38 +0200 Subject: [PATCH 12/14] Removes unused imports in utils --- gpjax/decision_making/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/gpjax/decision_making/utils.py b/gpjax/decision_making/utils.py index dd81b0cfd..9af19de32 100644 --- a/gpjax/decision_making/utils.py +++ b/gpjax/decision_making/utils.py @@ -17,8 +17,6 @@ Dict, Final, ) -import jax -import jax.numpy as jnp from gpjax.dataset import Dataset from gpjax.typing import ( From ceef6301252ef41fd3411a33448a85ac17316c53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Gonz=C3=A1lez=20Duque?= Date: Wed, 3 Jul 2024 10:50:25 +0200 Subject: [PATCH 13/14] Removes test of manual CDF --- .../test_utility_functions/test_utility_functions.py | 6 +++--- tests/test_decision_making/test_utils.py | 7 ------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/test_decision_making/test_utility_functions/test_utility_functions.py b/tests/test_decision_making/test_utility_functions/test_utility_functions.py index f8e888e82..658d5b88a 100644 --- a/tests/test_decision_making/test_utility_functions/test_utility_functions.py +++ b/tests/test_decision_making/test_utility_functions/test_utility_functions.py @@ -148,11 +148,11 @@ def test_utility_functions_have_correct_shapes( posterior = generate_dummy_conjugate_posterior(dataset) posteriors = {OBJECTIVE: posterior} datasets = {OBJECTIVE: dataset} - ts_utility_builder = utility_function_builder(**utility_function_kwargs) - ts_utility_function = ts_utility_builder.build_utility_function( + utility_builder = utility_function_builder(**utility_function_kwargs) + utility_function = utility_builder.build_utility_function( posteriors=posteriors, datasets=datasets, key=key ) test_key, _ = jr.split(key) test_X = test_target_function.generate_test_points(num_test_points, test_key) - ts_utility_function_values = ts_utility_function(test_X) + ts_utility_function_values = utility_function(test_X) assert ts_utility_function_values.shape == (num_test_points, 1) diff --git a/tests/test_decision_making/test_utils.py b/tests/test_decision_making/test_utils.py index d097ee23a..f77bc6cab 100644 --- a/tests/test_decision_making/test_utils.py +++ b/tests/test_decision_making/test_utils.py @@ -21,7 +21,6 @@ from gpjax.decision_making.utils import ( OBJECTIVE, build_function_evaluator, - gaussian_cdf, ) from gpjax.typing import ( Array, @@ -45,9 +44,3 @@ def _cube(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]: assert jnp.equal(datasets[OBJECTIVE].y, _square(x)).all() assert jnp.equal(datasets["CONSTRAINT"].X, x).all() assert jnp.equal(datasets["CONSTRAINT"].y, _cube(x)).all() - - -def test_gaussian_cdf(): - x = jnp.array([0.0, 1.0, 2.0]) - cdf = jnp.array([0.5, 0.84134475, 0.97724987]) - assert jnp.allclose(cdf, gaussian_cdf(x)) From 58d7bd2b651fd24c87f721678ef2e2e0cb2613cd Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Miguel=20Gonz=C3=A1lez=20Duque?= Date: Wed, 3 Jul 2024 22:39:14 +0200 Subject: [PATCH 14/14] Updates the documentation to reflect changes, removes if main --- .../utility_functions/probability_of_improvement.py | 3 ++- .../test_utility_functions/test_probability_of_improvement.py | 4 ---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/gpjax/decision_making/utility_functions/probability_of_improvement.py b/gpjax/decision_making/utility_functions/probability_of_improvement.py index e04f1eaca..30aa3986f 100644 --- a/gpjax/decision_making/utility_functions/probability_of_improvement.py +++ b/gpjax/decision_making/utility_functions/probability_of_improvement.py @@ -40,7 +40,8 @@ class ProbabilityOfImprovement(AbstractSinglePointUtilityFunctionBuilder): More precisely, given a predictive posterior distribution of the objective function $`f`$, the probability of improvement at a test point $`x`$ is defined as: $$`\text{PI}(x) = \text{Prob}[f(x) < f(x_{\text{best}})]`$$ - where $`x_{\text{best}}`$ is the minimizer of $`f`$ in the dataset. + where $`x_{\text{best}}`$ is the minimiser of the posterior mean + at previously observed values (to handle noisy observations). The probability of improvement can be easily computed using the cumulative distribution function of the standard normal distribution $`\Phi`$: diff --git a/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py b/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py index 5b6be0698..417f5b2f6 100644 --- a/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py +++ b/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py @@ -62,7 +62,3 @@ def test_probability_of_improvement_gives_correct_value_for_a_seed(): assert utility_values.shape == (10, 1) assert jnp.isclose(utility_values, expected_utility_values).all() - - -if __name__ == "__main__": - test_probability_of_improvement_gives_correct_value_for_a_seed()
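[Editor's note] With the series applied, `ProbabilityOfImprovement` is exported alongside `ThompsonSampling` and exercised by the rewritten test. As a closing illustration, here is a minimal end-to-end sketch mirroring that test. It assumes you run it from the repository root (so the `generate_dummy_conjugate_posterior` test helper is importable) and it evaluates the utility with the posterior's untrained default hyperparameters.

```python
from jax import config

config.update("jax_enable_x64", True)

import jax.random as jr

from gpjax.decision_making.test_functions.continuous_functions import Forrester
from gpjax.decision_making.utility_functions import ProbabilityOfImprovement
from gpjax.decision_making.utils import OBJECTIVE
from tests.test_decision_making.utils import generate_dummy_conjugate_posterior

key = jr.key(42)
forrester = Forrester()
dataset = forrester.generate_dataset(num_points=10, key=key)
posterior = generate_dummy_conjugate_posterior(dataset)

# Build the PI utility and evaluate it on candidate points; the candidate
# with the largest value would be queried next in a BO loop.
pi_utility = ProbabilityOfImprovement().build_utility_function(
    posteriors={OBJECTIVE: posterior}, datasets={OBJECTIVE: dataset}, key=key
)
test_X = forrester.generate_test_points(num_points=10, key=key)
print(pi_utility(test_X))  # shape (10, 1); each entry is a probability in [0, 1]
```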