From 4b19843b63730bec50a143f45006a82eff0f8497 Mon Sep 17 00:00:00 2001 From: Thomas-Christie Date: Fri, 25 Aug 2023 14:02:49 +0100 Subject: [PATCH 1/2] Add logic for running a decision making loop Added an `AbstractDecisionMaker` abstract base class which can be implemented in order to define a decision making loop. A concrete implementation, in the form of a `DecisionMaker` class, has been added. At its heart it has two core methods: 1. `ask` which is used to get a point to be queried next. 2. `tell` which is used to tell the `DecisionMaker` about new observations. In a typical decision making setup this will result in the datsets and posteriors being updated. In addition to this, in the `DecisionMaker` a `run` method is provided, which will automatically run the decision making loop for n steps. After the `ask` step, the functions in the `post_ask` list will be executed, taking as arguments the decision maker and the point chosen to be queried next. Similarly, after the `tell` step, the functions in the `post_tell` list are executed, taking the decision maker as the sole argument. --- gpjax/decision_making/__init__.py | 8 + gpjax/decision_making/decision_maker.py | 254 ++++++++++++++ gpjax/decision_making/utils.py | 35 +- pyproject.toml | 2 - .../test_decision_maker.py | 330 ++++++++++++++++++ tests/test_decision_making/test_utils.py | 46 +++ tests/test_decision_making/utils.py | 45 +++ 7 files changed, 717 insertions(+), 3 deletions(-) create mode 100644 gpjax/decision_making/decision_maker.py create mode 100644 tests/test_decision_making/test_decision_maker.py create mode 100644 tests/test_decision_making/test_utils.py create mode 100644 tests/test_decision_making/utils.py diff --git a/gpjax/decision_making/__init__.py b/gpjax/decision_making/__init__.py index c3ff92262..50c7d551b 100644 --- a/gpjax/decision_making/__init__.py +++ b/gpjax/decision_making/__init__.py @@ -21,6 +21,10 @@ AbstractAcquisitionMaximizer, ContinuousAcquisitionMaximizer, ) +from gpjax.decision_making.decision_maker import ( + AbstractDecisionMaker, + DecisionMaker, +) from gpjax.decision_making.posterior_handler import PosteriorHandler from gpjax.decision_making.search_space import ( AbstractSearchSpace, @@ -32,14 +36,18 @@ LogarithmicGoldsteinPrice, Quadratic, ) +from gpjax.decision_making.utils import build_function_evaluator __all__ = [ "AbstractAcquisitionFunctionBuilder", "AbstractAcquisitionMaximizer", + "AbstractDecisionMaker", "AbstractSearchSpace", "AcquisitionFunction", + "build_function_evaluator", "ContinuousAcquisitionMaximizer", "ContinuousSearchSpace", + "DecisionMaker", "AbstractContinuousTestFunction", "Forrester", "LogarithmicGoldsteinPrice", diff --git a/gpjax/decision_making/decision_maker.py b/gpjax/decision_making/decision_maker.py new file mode 100644 index 000000000..ddfff18b0 --- /dev/null +++ b/gpjax/decision_making/decision_maker.py @@ -0,0 +1,254 @@ +# Copyright 2023 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from abc import ( + ABC, + abstractmethod, +) +from dataclasses import ( + dataclass, + field, +) + +from beartype.typing import ( + Callable, + Dict, + List, + Mapping, +) +import jax.random as jr + +from gpjax.dataset import Dataset +from gpjax.decision_making.acquisition_functions import ( + AbstractAcquisitionFunctionBuilder, +) +from gpjax.decision_making.acquisition_maximizer import AbstractAcquisitionMaximizer +from gpjax.decision_making.posterior_handler import PosteriorHandler +from gpjax.decision_making.search_space import AbstractSearchSpace +from gpjax.decision_making.utils import FunctionEvaluator +from gpjax.gps import AbstractPosterior +from gpjax.typing import ( + Array, + Float, + KeyArray, +) + + +@dataclass +class AbstractDecisionMaker(ABC): + """ + AbstractDecisionMaker abstract base class which handles the core decision loop. The + decision making loop is split into two key steps, `ask` and `tell`. The `ask` + step is typically used to decide which point to query next. The `tell` step is + typically used to update models and datasets with newly queried points. + + Attributes: + search_space (AbstractSearchSpace): Search space which is being queried + posterior_handlers (Dict[str, PosteriorHandler]): Dictionary of posterior + handlers, which are used to update posteriors throughout the decision making + loop. Tags are used to distinguish between posteriors. In a typical Bayesian + optimisation setup one of the tags will be `OBJECTIVE`, defined in + decision_making.utils. + datasets (Dict[str, Dataset]): Dictionary of datasets, which are augmented with + observations throughout the decision making loop. In a typical setup they are + also used to fit the posteriors, using the `posterior_handlers`. Tags are used + to distinguish datasets, and correspond to tags in `posterior_handlers`. + acquisition_function_builder (AbstractAcquisitionFunctionBuilder): Object which + builds acquisition functions from posteriors and datasets, to decide where + to query next. In a typical Bayesian optimisation setup the point chosen to + be queried next is the point which maximizes the acquisition function. + acquisition_maximizer (AbstractAcquisitionMaximizer): Object which maximizes + acquisition functions over the search space. + key (KeyArray): JAX random key, used to generate random numbers. + post_ask (List[Callable]): List of functions to be executed after each ask step. + post_tell (List[Callable]): List of functions to be executed after each tell + step. + """ + + search_space: AbstractSearchSpace + posterior_handlers: Dict[str, PosteriorHandler] + datasets: Dict[str, Dataset] + acquisition_function_builder: AbstractAcquisitionFunctionBuilder + acquisition_maximizer: AbstractAcquisitionMaximizer + key: KeyArray + post_ask: List[Callable] = field( + default_factory=list + ) # Specific type is List[Callable[[DecisionMaker, Float[Array, ["1 D"]]], None]] but causes Beartype issues + post_tell: List[Callable] = field( + default_factory=list + ) # Specific type is List[Callable[[DecisionMaker], None]] but causes Beartype issues + + @abstractmethod + def ask(self, key: KeyArray) -> Float[Array, "1 D"]: + """ + In a typical decision making setup this will use the + `acquisition_function_builder` to form an acquisition function and then return + the point which maximizes the acquisition function using the + `acquisition_maximizer` as the point to be queried next. + + Args: + key (KeyArray): JAX PRNG key for controlling random state. + + Returns: + Float[Array, "1 D"]: Point to be queried next + """ + raise NotImplementedError + + @abstractmethod + def tell(self, observation_datasets: Mapping[str, Dataset], key: KeyArray): + """ + Tell decision maker about new observations. In a typical decision making setup + we will update the datasets and posteriors with the new observations. + + Args: + observation_datasets (Mapping[str, Dataset]): Dictionary of datasets + containing new observations. Tags are used to distinguish datasets, and + correspond to tags in `posterior_handlers` in a typical setup. + key (KeyArray): JAX PRNG key for controlling random state. + """ + raise NotImplementedError + + +@dataclass +class DecisionMaker(AbstractDecisionMaker): + """ + DecisionMaker class which handles the core decision making loop in a typical setup. The + decision making loop is split into two key steps, `ask` and `tell`. The `ask` + step forms an `AcquisitionFunction` from the current `posteriors` and `datasets` and + returns the point which maximises it. It also stores the formed acquisition function + under the attribute `self.current_acquisition_function` so that it can be called, + for instance for plotting, after the `ask` function has been called. The `tell` step + adds a newly queried point to the `datasets` and updates the `posteriors`. + + This can be run as a typical ask-tell loop, or the `run` method can be used to run + the decision making loop for a fixed number of steps. Moreover, the `run` method executes + the functions in `post_ask` and `post_tell` after each ask and tell step + respectively. This enables the user to add custom functionality, such as the ability + to plot values of interest during the optimization process. + """ + + def __post_init__(self): + """ + At initialisation we check that the posterior handlers and datasets are + consistent (i.e. have the same tags), and then initialise the posteriors, optimizing them using the + corresponding datasets. + """ + # Check that posterior handlers and datasets are consistent + if self.posterior_handlers.keys() != self.datasets.keys(): + raise ValueError( + "Posterior handlers and datasets must have the same keys. " + f"Got posterior handlers keys {self.posterior_handlers.keys()} and " + f"datasets keys {self.datasets.keys()}." + ) + + # Initialize posteriors + self.posteriors: Dict[str, AbstractPosterior] = {} + for tag, posterior_handler in self.posterior_handlers.items(): + self.posteriors[tag] = posterior_handler.get_posterior( + self.datasets[tag], optimize=True, key=self.key + ) + + def ask(self, key: KeyArray) -> Float[Array, "1 D"]: + """ + Get updated acquisition function and return the point which maximises it. This + method also stores the acquisition function in + `self.current_acquisition_function` so that it can be accessed after the ask + function has been called. This is useful for non-deterministic acquisition + functions, which will differ between calls to `ask` due to the splitting of + `self.key`. + + Args: + key (KeyArray): JAX PRNG key for controlling random state. + + Returns: + Float[Array, "1 D"]: Point to be queried next. + """ + self.current_acquisition_function = ( + self.acquisition_function_builder.build_acquisition_function( + self.posteriors, self.datasets, key + ) + ) + + key, _ = jr.split(key) + return self.acquisition_maximizer.maximize( + self.current_acquisition_function, self.search_space, key + ) + + def tell(self, observation_datasets: Mapping[str, Dataset], key: KeyArray): + """ + Add newly observed data to datasets and update the corresponding posteriors. + + Args: + observation_datasets (Mapping[str, Dataset]): Dictionary of datasets + containing new observations. Tags are used to distinguish datasets, and + correspond to tags in `posterior_handlers` and `self.datasets`. + key (KeyArray): JAX PRNG key for controlling random state. + """ + if observation_datasets.keys() != self.datasets.keys(): + raise ValueError( + "Observation datasets and existing datasets must have the same keys. " + f"Got observation datasets keys {observation_datasets.keys()} and " + f"existing datasets keys {self.datasets.keys()}." + ) + + for tag, observation_dataset in observation_datasets.items(): + self.datasets[tag] += observation_dataset + + for tag, posterior_handler in self.posterior_handlers.items(): + key, _ = jr.split(key) + self.posteriors[tag] = posterior_handler.update_posterior( + self.datasets[tag], self.posteriors[tag], optimize=True, key=key + ) + + def run( + self, n_steps: int, black_box_function_evaluator: FunctionEvaluator + ) -> Mapping[str, Dataset]: + """ + Run the decision making loop continuously for for `n_steps`. This is broken down + into three main steps: + 1. Call the `ask` method to get the point to be queried next. + 2. Call the `black_box_function_evaluator` to evaluate the black box functions + of interest at the point chosen to be queried. + 3. Call the `tell` method to update the datasets and posteriors with the newly + observed data. + + In addition to this, after the `ask` step, the functions in the `post_ask` list + are executed, taking as arguments the decision maker and the point chosen to be + queried next. Similarly, after the `tell` step, the functions in the `post_tell` + list are executed, taking the decision maker as the sole argument. + + Args: + n_steps (int): Number of steps to run the decision making loop for. + black_box_function_evaluator (FunctionEvaluator): Function evaluator which + evaluates the black box functions of interest at supplied points. + + Returns: + Mapping[str, Dataset]: Dictionary of datasets containing the observations + made throughout the decision making loop, as well as the initial data + supplied when initialising the `DecisionMaker`. + """ + for _ in range(n_steps): + query_point = self.ask(self.key) + + for post_ask_method in self.post_ask: + post_ask_method(self, query_point) + + self.key, _ = jr.split(self.key) + observation_datasets = black_box_function_evaluator(query_point) + self.tell(observation_datasets, self.key) + + for post_tell_method in self.post_tell: + post_tell_method(self) + + return self.datasets diff --git a/gpjax/decision_making/utils.py b/gpjax/decision_making/utils.py index 91fa9a924..55087c172 100644 --- a/gpjax/decision_making/utils.py +++ b/gpjax/decision_making/utils.py @@ -12,6 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from beartype.typing import Final +from beartype.typing import ( + Callable, + Dict, + Final, +) + +from gpjax.dataset import Dataset +from gpjax.typing import ( + Array, + Float, +) OBJECTIVE: Final[str] = "OBJECTIVE" +""" +Tag for the objective dataset/function in standard acquisition functions. +""" + + +FunctionEvaluator = Callable[[Float[Array, "N D"]], Dict[str, Dataset]] +""" +Type alias for function evaluators, which take an array of points of shape $`[N, D]`$ +and evaluate a set of functions at each point, returning a mapping from function tags +to datasets of the evaluated points. This is the same as the `Observer` in Trieste: +https://github.com/secondmind-labs/trieste/blob/develop/trieste/observer.py +""" + + +def build_function_evaluator( + functions: Dict[str, Callable[[Float[Array, "N D"]], Float[Array, "N 1"]]] +) -> FunctionEvaluator: + """ + Takes a dictionary of functions and returns a `FunctionEvaluator` which can be + used to evaluate each of the functions at a supplied set of points and return a + dictionary of datasets storing the evaluated points. + """ + return lambda x: {tag: Dataset(x, f(x)) for tag, f in functions.items()} diff --git a/pyproject.toml b/pyproject.toml index a20ecb254..b673e198f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -133,8 +133,6 @@ select = [ "TID", # implicit string concatenation "ISC", - # type-checking imports - "TCH", ] ignore = [ # space before : (needed for how black formats slicing) diff --git a/tests/test_decision_making/test_decision_maker.py b/tests/test_decision_making/test_decision_maker.py new file mode 100644 index 000000000..235b5bc35 --- /dev/null +++ b/tests/test_decision_making/test_decision_maker.py @@ -0,0 +1,330 @@ +# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from jax import config + +config.update("jax_enable_x64", True) + +import jax.numpy as jnp +import jax.random as jr +import optax as ox +import pytest + +import gpjax as gpx +from gpjax.dataset import Dataset +from gpjax.decision_making.acquisition_functions import ( + AbstractAcquisitionFunctionBuilder, +) +from gpjax.decision_making.acquisition_maximizer import ( + AbstractAcquisitionMaximizer, + ContinuousAcquisitionMaximizer, +) +from gpjax.decision_making.decision_maker import ( + AbstractDecisionMaker, + DecisionMaker, +) +from gpjax.decision_making.posterior_handler import PosteriorHandler +from gpjax.decision_making.search_space import ( + AbstractSearchSpace, + ContinuousSearchSpace, +) +from gpjax.decision_making.test_functions import Quadratic +from gpjax.decision_making.utils import ( + OBJECTIVE, + build_function_evaluator, +) +from gpjax.typing import KeyArray +from tests.test_decision_making.utils import QuadraticAcquisitionFunctionBuilder + +CONSTRAINT = "CONSTRAINT" + + +@pytest.fixture +def search_space() -> ContinuousSearchSpace: + return ContinuousSearchSpace( + lower_bounds=jnp.array([0.0], dtype=jnp.float64), + upper_bounds=jnp.array([1.0], dtype=jnp.float64), + ) + + +@pytest.fixture +def posterior_handler() -> PosteriorHandler: + mean = gpx.Zero() + kernel = gpx.Matern52(lengthscale=jnp.array(1.0), variance=jnp.array(1.0)) + prior = gpx.Prior(mean_function=mean, kernel=kernel) + likelihood_builder = lambda x: gpx.Gaussian( + num_datapoints=x, obs_noise=jnp.array(1e-6) + ) + posterior_handler = PosteriorHandler( + prior=prior, + likelihood_builder=likelihood_builder, + optimization_objective=gpx.ConjugateMLL(negative=True), + optimizer=ox.adam(learning_rate=0.01), + num_optimization_iters=100, + ) + return posterior_handler + + +@pytest.fixture +def acquisition_function_builder() -> AbstractAcquisitionFunctionBuilder: + return QuadraticAcquisitionFunctionBuilder() + + +@pytest.fixture +def acquisition_maximizer() -> AbstractAcquisitionMaximizer: + return ContinuousAcquisitionMaximizer(num_initial_samples=1000, num_restarts=1) + + +def get_dataset(num_points: int, key: KeyArray) -> Dataset: + test_function = Quadratic() + dataset = test_function.generate_dataset(num_points=num_points, key=key) + return dataset + + +def test_abstract_decision_maker_raises_error(): + with pytest.raises(TypeError): + AbstractDecisionMaker() + + +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +def test_invalid_tags_raises_error( + search_space: AbstractSearchSpace, + posterior_handler: PosteriorHandler, + acquisition_function_builder: AbstractAcquisitionFunctionBuilder, + acquisition_maximizer: AbstractAcquisitionMaximizer, +): + key = jr.PRNGKey(42) + posterior_handlers = {OBJECTIVE: posterior_handler} + dataset = get_dataset(num_points=5, key=jr.PRNGKey(42)) + datasets = {"CONSTRAINT": dataset} # Dataset tag doesn't match posterior tag + with pytest.raises(ValueError): + DecisionMaker( + search_space=search_space, + posterior_handlers=posterior_handlers, + datasets=datasets, + acquisition_function_builder=acquisition_function_builder, + acquisition_maximizer=acquisition_maximizer, + key=key, + ) + + +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +def test_initialisation_optimizes_posterior_hyperparameters( + search_space: AbstractSearchSpace, + posterior_handler: PosteriorHandler, + acquisition_function_builder: AbstractAcquisitionFunctionBuilder, + acquisition_maximizer: AbstractAcquisitionMaximizer, +): + key = jr.PRNGKey(42) + posterior_handlers = {OBJECTIVE: posterior_handler, CONSTRAINT: posterior_handler} + objective_dataset = get_dataset(num_points=5, key=jr.PRNGKey(42)) + constraint_dataset = get_dataset(num_points=5, key=jr.PRNGKey(10)) + datasets = {"OBJECTIVE": objective_dataset, CONSTRAINT: constraint_dataset} + decision_maker = DecisionMaker( + search_space=search_space, + posterior_handlers=posterior_handlers, + datasets=datasets, + acquisition_function_builder=acquisition_function_builder, + acquisition_maximizer=acquisition_maximizer, + key=key, + ) + # Assert kernel hyperparameters get changed from their initial values + assert decision_maker.posteriors[OBJECTIVE].prior.kernel.lengthscale != jnp.array( + 1.0 + ) + assert decision_maker.posteriors[OBJECTIVE].prior.kernel.variance != jnp.array(1.0) + assert decision_maker.posteriors[CONSTRAINT].prior.kernel.lengthscale != jnp.array( + 1.0 + ) + assert decision_maker.posteriors[CONSTRAINT].prior.kernel.variance != jnp.array(1.0) + assert ( + decision_maker.posteriors[CONSTRAINT].prior.kernel.lengthscale + != decision_maker.posteriors[OBJECTIVE].prior.kernel.lengthscale + ) + assert ( + decision_maker.posteriors[CONSTRAINT].prior.kernel.variance + != decision_maker.posteriors[OBJECTIVE].prior.kernel.variance + ) + + +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +def test_decision_maker_ask( + search_space: AbstractSearchSpace, + posterior_handler: PosteriorHandler, + acquisition_function_builder: AbstractAcquisitionFunctionBuilder, + acquisition_maximizer: AbstractAcquisitionMaximizer, +): + key = jr.PRNGKey(42) + posterior_handlers = {OBJECTIVE: posterior_handler} + objective_dataset = get_dataset(num_points=5, key=jr.PRNGKey(42)) + datasets = {"OBJECTIVE": objective_dataset} + decision_maker = DecisionMaker( + search_space=search_space, + posterior_handlers=posterior_handlers, + datasets=datasets, + acquisition_function_builder=acquisition_function_builder, + acquisition_maximizer=acquisition_maximizer, + key=key, + ) + initial_decision_maker_key = decision_maker.key + query_point = decision_maker.ask(key=key) + assert query_point.shape == (1, 1) + assert jnp.allclose(query_point, jnp.array([[0.5]]), atol=1e-5) + assert decision_maker.current_acquisition_function is not None + assert ( + decision_maker.key == initial_decision_maker_key + ).all() # Ensure decision maker key is unchanged + + +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +def test_decision_maker_tell_with_inconsistent_observations_raises_error( + search_space: AbstractSearchSpace, + posterior_handler: PosteriorHandler, + acquisition_function_builder: AbstractAcquisitionFunctionBuilder, + acquisition_maximizer: AbstractAcquisitionMaximizer, +): + key = jr.PRNGKey(42) + posterior_handlers = {OBJECTIVE: posterior_handler, CONSTRAINT: posterior_handler} + initial_objective_dataset = get_dataset(num_points=5, key=jr.PRNGKey(42)) + initial_constraint_dataset = get_dataset(num_points=5, key=jr.PRNGKey(10)) + datasets = { + "OBJECTIVE": initial_objective_dataset, + CONSTRAINT: initial_constraint_dataset, + } + decision_maker = DecisionMaker( + search_space=search_space, + posterior_handlers=posterior_handlers, + datasets=datasets, + acquisition_function_builder=acquisition_function_builder, + acquisition_maximizer=acquisition_maximizer, + key=key, + ) + mock_objective_observation = get_dataset(num_points=1, key=jr.PRNGKey(1)) + mock_constraint_observation = get_dataset(num_points=1, key=jr.PRNGKey(2)) + observations = { + OBJECTIVE: mock_objective_observation, + "CONSTRAINT_ONE": mock_constraint_observation, # Deliberately incorrect tag + } + with pytest.raises(ValueError): + decision_maker.tell(observation_datasets=observations, key=key) + + +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +def test_decision_maker_tell_updates_datasets_and_models( + search_space: AbstractSearchSpace, + posterior_handler: PosteriorHandler, + acquisition_function_builder: AbstractAcquisitionFunctionBuilder, + acquisition_maximizer: AbstractAcquisitionMaximizer, +): + key = jr.PRNGKey(42) + posterior_handlers = {OBJECTIVE: posterior_handler, CONSTRAINT: posterior_handler} + initial_objective_dataset = get_dataset(num_points=5, key=jr.PRNGKey(42)) + initial_constraint_dataset = get_dataset(num_points=5, key=jr.PRNGKey(10)) + datasets = { + "OBJECTIVE": initial_objective_dataset, + CONSTRAINT: initial_constraint_dataset, + } + decision_maker = DecisionMaker( + search_space=search_space, + posterior_handlers=posterior_handlers, + datasets=datasets, + acquisition_function_builder=acquisition_function_builder, + acquisition_maximizer=acquisition_maximizer, + key=key, + ) + initial_decision_maker_key = decision_maker.key + initial_objective_posterior = decision_maker.posteriors[OBJECTIVE] + initial_constraint_posterior = decision_maker.posteriors[CONSTRAINT] + mock_objective_observation = get_dataset(num_points=1, key=jr.PRNGKey(1)) + mock_constraint_observation = get_dataset(num_points=1, key=jr.PRNGKey(2)) + observations = { + OBJECTIVE: mock_objective_observation, + CONSTRAINT: mock_constraint_observation, + } + decision_maker.tell(observation_datasets=observations, key=key) + assert decision_maker.datasets[OBJECTIVE].n == 6 + assert decision_maker.datasets[CONSTRAINT].n == 6 + assert decision_maker.datasets[OBJECTIVE].X[-1] == mock_objective_observation.X[0] + assert decision_maker.datasets[CONSTRAINT].X[-1] == mock_constraint_observation.X[0] + assert ( + decision_maker.posteriors[OBJECTIVE].prior.kernel.lengthscale + != initial_objective_posterior.prior.kernel.lengthscale + ) + assert ( + decision_maker.posteriors[OBJECTIVE].prior.kernel.variance + != initial_objective_posterior.prior.kernel.variance + ) + assert ( + decision_maker.posteriors[CONSTRAINT].prior.kernel.lengthscale + != initial_constraint_posterior.prior.kernel.lengthscale + ) + assert ( + decision_maker.posteriors[CONSTRAINT].prior.kernel.variance + != initial_constraint_posterior.prior.kernel.variance + ) + assert ( + decision_maker.key == initial_decision_maker_key + ).all() # Ensure decision maker key has not been updated + + +@pytest.mark.parametrize("n_steps", [1, 3]) +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +def test_decision_maker_run( + search_space: AbstractSearchSpace, + posterior_handler: PosteriorHandler, + acquisition_function_builder: AbstractAcquisitionFunctionBuilder, + acquisition_maximizer: AbstractAcquisitionMaximizer, + n_steps: int, +): + key = jr.PRNGKey(42) + posterior_handlers = {OBJECTIVE: posterior_handler} + initial_objective_dataset = get_dataset(num_points=5, key=jr.PRNGKey(42)) + datasets = { + "OBJECTIVE": initial_objective_dataset, + } + decision_maker = DecisionMaker( + search_space=search_space, + posterior_handlers=posterior_handlers, + datasets=datasets, + acquisition_function_builder=acquisition_function_builder, + acquisition_maximizer=acquisition_maximizer, + key=key, + ) + initial_decision_maker_key = decision_maker.key + black_box_fn = Quadratic() + black_box_function_evaluator = build_function_evaluator( + {OBJECTIVE: black_box_fn.evaluate} + ) + query_datasets = decision_maker.run( + n_steps=n_steps, black_box_function_evaluator=black_box_function_evaluator + ) + assert query_datasets[OBJECTIVE].n == 5 + n_steps + assert ( + jnp.abs(query_datasets[OBJECTIVE].X[-n_steps:] - jnp.array([[0.5]])) < 1e-5 + ).all() # Ensure we're querying the correct point in our dummy acquisition function at each step + assert ( + decision_maker.key != initial_decision_maker_key + ).all() # Ensure decision maker key gets updated diff --git a/tests/test_decision_making/test_utils.py b/tests/test_decision_making/test_utils.py new file mode 100644 index 000000000..f77bc6cab --- /dev/null +++ b/tests/test_decision_making/test_utils.py @@ -0,0 +1,46 @@ +# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from jax import config + +config.update("jax_enable_x64", True) + +import jax.numpy as jnp + +from gpjax.decision_making.utils import ( + OBJECTIVE, + build_function_evaluator, +) +from gpjax.typing import ( + Array, + Float, +) + + +def test_build_function_evaluator(): + def _square(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]: + return x**2 + + def _cube(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]: + return x**3 + + functions = {OBJECTIVE: _square, "CONSTRAINT": _cube} + fn_evaluator = build_function_evaluator(functions) + x = jnp.array([[2.0, 3.0]]) + datasets = fn_evaluator(x) + assert datasets.keys() == functions.keys() + assert jnp.equal(datasets[OBJECTIVE].X, x).all() + assert jnp.equal(datasets[OBJECTIVE].y, _square(x)).all() + assert jnp.equal(datasets["CONSTRAINT"].X, x).all() + assert jnp.equal(datasets["CONSTRAINT"].y, _cube(x)).all() diff --git a/tests/test_decision_making/utils.py b/tests/test_decision_making/utils.py new file mode 100644 index 000000000..f92b793d8 --- /dev/null +++ b/tests/test_decision_making/utils.py @@ -0,0 +1,45 @@ +# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from beartype.typing import Mapping + +from gpjax.dataset import Dataset +from gpjax.decision_making.acquisition_functions import ( + AbstractAcquisitionFunctionBuilder, + AcquisitionFunction, +) +from gpjax.decision_making.test_functions import Quadratic +from gpjax.gps import ConjugatePosterior +from gpjax.typing import KeyArray + + +class QuadraticAcquisitionFunctionBuilder(AbstractAcquisitionFunctionBuilder): + """ + Dummy acquisition function builder for testing purposes, which returns the negative + of the value of a quadratic test function at the input points. This is because + acquisition functions are *maximised*, and we wish to *minimise* the quadratic test + function. + """ + + def build_acquisition_function( + self, + posteriors: Mapping[str, ConjugatePosterior], + datasets: Mapping[str, Dataset], + key: KeyArray, + ) -> AcquisitionFunction: + test_function = Quadratic() + return lambda x: -1.0 * test_function.evaluate( + x + ) # Acquisition functions are *maximised* From f6aa6d4e586447b56dcb8841c3513cc439eff43d Mon Sep 17 00:00:00 2001 From: Thomas-Christie Date: Fri, 1 Sep 2023 13:20:51 +0100 Subject: [PATCH 2/2] Implement CR feedback and change acquisition function naming Implemented minor CR feedback and renamed acquisition functions to utility functions to reflect the fact that the package will be used for general Bayesian decision making and not just Bayesian optimisation. --- gpjax/citation.py | 2 +- gpjax/decision_making/__init__.py | 30 +-- gpjax/decision_making/decision_maker.py | 185 +++++++++--------- .../__init__.py | 14 +- .../base.py | 26 +-- .../thompson_sampling.py | 30 ++- ...tion_maximizer.py => utility_maximizer.py} | 73 ++++--- gpjax/decision_making/utils.py | 2 +- .../test_decision_maker.py | 100 +++++----- .../test_decision_making/test_search_space.py | 2 +- .../test_continuous_functions.py | 2 +- .../test_non_conjugate_functions.py | 2 +- .../__init__.py | 0 .../test_base.py | 10 +- .../test_thompson_sampling.py | 74 ++++--- ...maximizer.py => test_utility_maximizer.py} | 64 +++--- tests/test_decision_making/utils.py | 20 +- 17 files changed, 311 insertions(+), 325 deletions(-) rename gpjax/decision_making/{acquisition_functions => utility_functions}/__init__.py (71%) rename gpjax/decision_making/{acquisition_functions => utility_functions}/base.py (75%) rename gpjax/decision_making/{acquisition_functions => utility_functions}/thompson_sampling.py (76%) rename gpjax/decision_making/{acquisition_maximizer.py => utility_maximizer.py} (60%) rename tests/test_decision_making/{test_acquisition_functions => test_utility_functions}/__init__.py (100%) rename tests/test_decision_making/{test_acquisition_functions => test_utility_functions}/test_base.py (77%) rename tests/test_decision_making/{test_acquisition_functions => test_utility_functions}/test_thompson_sampling.py (74%) rename tests/test_decision_making/{test_acquisition_maximizer.py => test_utility_maximizer.py} (73%) diff --git a/gpjax/citation.py b/gpjax/citation.py index 56202d2ab..2c29a6175 100644 --- a/gpjax/citation.py +++ b/gpjax/citation.py @@ -10,11 +10,11 @@ from jaxlib.xla_extension import PjitFunction from plum import dispatch -from gpjax.decision_making.acquisition_functions import ThompsonSampling from gpjax.decision_making.test_functions import ( Forrester, LogarithmicGoldsteinPrice, ) +from gpjax.decision_making.utility_functions import ThompsonSampling from gpjax.kernels import ( RFF, ArcCosine, diff --git a/gpjax/decision_making/__init__.py b/gpjax/decision_making/__init__.py index 50c7d551b..5c782a3a4 100644 --- a/gpjax/decision_making/__init__.py +++ b/gpjax/decision_making/__init__.py @@ -12,18 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from gpjax.decision_making.acquisition_functions import ( - AbstractAcquisitionFunctionBuilder, - AcquisitionFunction, - ThompsonSampling, -) -from gpjax.decision_making.acquisition_maximizer import ( - AbstractAcquisitionMaximizer, - ContinuousAcquisitionMaximizer, -) from gpjax.decision_making.decision_maker import ( AbstractDecisionMaker, - DecisionMaker, + UtilityDrivenDecisionMaker, ) from gpjax.decision_making.posterior_handler import PosteriorHandler from gpjax.decision_making.search_space import ( @@ -36,18 +27,27 @@ LogarithmicGoldsteinPrice, Quadratic, ) +from gpjax.decision_making.utility_functions import ( + AbstractUtilityFunctionBuilder, + ThompsonSampling, + UtilityFunction, +) +from gpjax.decision_making.utility_maximizer import ( + AbstractUtilityMaximizer, + ContinuousUtilityMaximizer, +) from gpjax.decision_making.utils import build_function_evaluator __all__ = [ - "AbstractAcquisitionFunctionBuilder", - "AbstractAcquisitionMaximizer", + "AbstractUtilityFunctionBuilder", + "AbstractUtilityMaximizer", "AbstractDecisionMaker", "AbstractSearchSpace", - "AcquisitionFunction", + "UtilityFunction", "build_function_evaluator", - "ContinuousAcquisitionMaximizer", + "ContinuousUtilityMaximizer", "ContinuousSearchSpace", - "DecisionMaker", + "UtilityDrivenDecisionMaker", "AbstractContinuousTestFunction", "Forrester", "LogarithmicGoldsteinPrice", diff --git a/gpjax/decision_making/decision_maker.py b/gpjax/decision_making/decision_maker.py index ddfff18b0..255a6b23a 100644 --- a/gpjax/decision_making/decision_maker.py +++ b/gpjax/decision_making/decision_maker.py @@ -16,10 +16,7 @@ ABC, abstractmethod, ) -from dataclasses import ( - dataclass, - field, -) +from dataclasses import dataclass from beartype.typing import ( Callable, @@ -30,12 +27,10 @@ import jax.random as jr from gpjax.dataset import Dataset -from gpjax.decision_making.acquisition_functions import ( - AbstractAcquisitionFunctionBuilder, -) -from gpjax.decision_making.acquisition_maximizer import AbstractAcquisitionMaximizer from gpjax.decision_making.posterior_handler import PosteriorHandler from gpjax.decision_making.search_space import AbstractSearchSpace +from gpjax.decision_making.utility_functions import AbstractUtilityFunctionBuilder +from gpjax.decision_making.utility_maximizer import AbstractUtilityMaximizer from gpjax.decision_making.utils import FunctionEvaluator from gpjax.gps import AbstractPosterior from gpjax.typing import ( @@ -48,28 +43,29 @@ @dataclass class AbstractDecisionMaker(ABC): """ - AbstractDecisionMaker abstract base class which handles the core decision loop. The - decision making loop is split into two key steps, `ask` and `tell`. The `ask` + AbstractDecisionMaker abstract base class which handles the core decision making + loop, where we sequentially decide on points to query our function of interest at. + The decision making loop is split into two key steps, `ask` and `tell`. The `ask` step is typically used to decide which point to query next. The `tell` step is - typically used to update models and datasets with newly queried points. + typically used to update models and datasets with newly queried points. These steps + can be combined in a 'run' loop which alternates between asking which point to query + next and telling the decision maker about the newly queried point having evaluated + the black-box function of interest at this point. Attributes: - search_space (AbstractSearchSpace): Search space which is being queried + search_space (AbstractSearchSpace): Search space over which we can evaluate the + function(s) of interest. posterior_handlers (Dict[str, PosteriorHandler]): Dictionary of posterior handlers, which are used to update posteriors throughout the decision making - loop. Tags are used to distinguish between posteriors. In a typical Bayesian - optimisation setup one of the tags will be `OBJECTIVE`, defined in + loop. Note that the word `posteriors` is used for consistency with GPJax, but these + objects are typically referred to as `models` in the model-based decision + making literature. Tags are used to distinguish between posteriors. In a typical + Bayesian optimisation setup one of the tags will be `OBJECTIVE`, defined in decision_making.utils. datasets (Dict[str, Dataset]): Dictionary of datasets, which are augmented with observations throughout the decision making loop. In a typical setup they are - also used to fit the posteriors, using the `posterior_handlers`. Tags are used + also used to update the posteriors, using the `posterior_handlers`. Tags are used to distinguish datasets, and correspond to tags in `posterior_handlers`. - acquisition_function_builder (AbstractAcquisitionFunctionBuilder): Object which - builds acquisition functions from posteriors and datasets, to decide where - to query next. In a typical Bayesian optimisation setup the point chosen to - be queried next is the point which maximizes the acquisition function. - acquisition_maximizer (AbstractAcquisitionMaximizer): Object which maximizes - acquisition functions over the search space. key (KeyArray): JAX random key, used to generate random numbers. post_ask (List[Callable]): List of functions to be executed after each ask step. post_tell (List[Callable]): List of functions to be executed after each tell @@ -79,64 +75,13 @@ class AbstractDecisionMaker(ABC): search_space: AbstractSearchSpace posterior_handlers: Dict[str, PosteriorHandler] datasets: Dict[str, Dataset] - acquisition_function_builder: AbstractAcquisitionFunctionBuilder - acquisition_maximizer: AbstractAcquisitionMaximizer key: KeyArray - post_ask: List[Callable] = field( - default_factory=list - ) # Specific type is List[Callable[[DecisionMaker, Float[Array, ["1 D"]]], None]] but causes Beartype issues - post_tell: List[Callable] = field( - default_factory=list - ) # Specific type is List[Callable[[DecisionMaker], None]] but causes Beartype issues - - @abstractmethod - def ask(self, key: KeyArray) -> Float[Array, "1 D"]: - """ - In a typical decision making setup this will use the - `acquisition_function_builder` to form an acquisition function and then return - the point which maximizes the acquisition function using the - `acquisition_maximizer` as the point to be queried next. - - Args: - key (KeyArray): JAX PRNG key for controlling random state. - - Returns: - Float[Array, "1 D"]: Point to be queried next - """ - raise NotImplementedError - - @abstractmethod - def tell(self, observation_datasets: Mapping[str, Dataset], key: KeyArray): - """ - Tell decision maker about new observations. In a typical decision making setup - we will update the datasets and posteriors with the new observations. - - Args: - observation_datasets (Mapping[str, Dataset]): Dictionary of datasets - containing new observations. Tags are used to distinguish datasets, and - correspond to tags in `posterior_handlers` in a typical setup. - key (KeyArray): JAX PRNG key for controlling random state. - """ - raise NotImplementedError - - -@dataclass -class DecisionMaker(AbstractDecisionMaker): - """ - DecisionMaker class which handles the core decision making loop in a typical setup. The - decision making loop is split into two key steps, `ask` and `tell`. The `ask` - step forms an `AcquisitionFunction` from the current `posteriors` and `datasets` and - returns the point which maximises it. It also stores the formed acquisition function - under the attribute `self.current_acquisition_function` so that it can be called, - for instance for plotting, after the `ask` function has been called. The `tell` step - adds a newly queried point to the `datasets` and updates the `posteriors`. - - This can be run as a typical ask-tell loop, or the `run` method can be used to run - the decision making loop for a fixed number of steps. Moreover, the `run` method executes - the functions in `post_ask` and `post_tell` after each ask and tell step - respectively. This enables the user to add custom functionality, such as the ability - to plot values of interest during the optimization process. - """ + post_ask: List[ + Callable + ] # Specific type is List[Callable[[AbstractDecisionMaker, Float[Array, ["1 D"]]], None]] but causes Beartype issues + post_tell: List[ + Callable + ] # Specific type is List[Callable[[AbstractDecisionMaker], None]] but causes Beartype issues def __post_init__(self): """ @@ -159,31 +104,18 @@ def __post_init__(self): self.datasets[tag], optimize=True, key=self.key ) + @abstractmethod def ask(self, key: KeyArray) -> Float[Array, "1 D"]: """ - Get updated acquisition function and return the point which maximises it. This - method also stores the acquisition function in - `self.current_acquisition_function` so that it can be accessed after the ask - function has been called. This is useful for non-deterministic acquisition - functions, which will differ between calls to `ask` due to the splitting of - `self.key`. + Get the point to be queried next. Args: key (KeyArray): JAX PRNG key for controlling random state. Returns: - Float[Array, "1 D"]: Point to be queried next. + Float[Array, "1 D"]: Point to be queried next """ - self.current_acquisition_function = ( - self.acquisition_function_builder.build_acquisition_function( - self.posteriors, self.datasets, key - ) - ) - - key, _ = jr.split(key) - return self.acquisition_maximizer.maximize( - self.current_acquisition_function, self.search_space, key - ) + raise NotImplementedError def tell(self, observation_datasets: Mapping[str, Dataset], key: KeyArray): """ @@ -252,3 +184,66 @@ def run( post_tell_method(self) return self.datasets + + +@dataclass +class UtilityDrivenDecisionMaker(AbstractDecisionMaker): + """ + UtilityDrivenDecisionMaker class which handles the core decision making loop in a + typical model-based decision making setup. In this setup we use surrogate model(s) + for the function(s) of interest, and define a utility function (often called the + 'acquisition function' in the context of Bayesian optimisation) which characterises + how useful it would be to query a given point within the search space given the data + we have observed so far. This can then be used to decide which point(s) to query + next. + + The decision making loop is split into two key steps, `ask` and `tell`. The `ask` + step forms a `UtilityFunction` from the current `posteriors` and `datasets` and + returns the point which maximises it. It also stores the formed utility function + under the attribute `self.current_utility_function` so that it can be called, + for instance for plotting, after the `ask` function has been called. The `tell` step + adds a newly queried point to the `datasets` and updates the `posteriors`. + + This can be run as a typical ask-tell loop, or the `run` method can be used to run + the decision making loop for a fixed number of steps. Moreover, the `run` method executes + the functions in `post_ask` and `post_tell` after each ask and tell step + respectively. This enables the user to add custom functionality, such as the ability + to plot values of interest during the optimization process. + + Attributes: + utility_function_builder (AbstractUtilityFunctionBuilder): Object which + builds utility functions from posteriors and datasets, to decide where + to query next. In a typical Bayesian optimisation setup the point chosen to + be queried next is the point which maximizes the utility function. + utility_maximizer (AbstractUtilityMaximizer): Object which maximizes + utility functions over the search space. + """ + + utility_function_builder: AbstractUtilityFunctionBuilder + utility_maximizer: AbstractUtilityMaximizer + + def ask(self, key: KeyArray) -> Float[Array, "1 D"]: + """ + Get updated utility function and return the point which maximises it. This + method also stores the utility function in + `self.current_utility_function` so that it can be accessed after the ask + function has been called. This is useful for non-deterministic utility + functions, which will differ between calls to `ask` due to the splitting of + `self.key`. + + Args: + key (KeyArray): JAX PRNG key for controlling random state. + + Returns: + Float[Array, "1 D"]: Point to be queried next. + """ + self.current_utility_function = ( + self.utility_function_builder.build_utility_function( + self.posteriors, self.datasets, key + ) + ) + + key, _ = jr.split(key) + return self.utility_maximizer.maximize( + self.current_utility_function, self.search_space, key + ) diff --git a/gpjax/decision_making/acquisition_functions/__init__.py b/gpjax/decision_making/utility_functions/__init__.py similarity index 71% rename from gpjax/decision_making/acquisition_functions/__init__.py rename to gpjax/decision_making/utility_functions/__init__.py index 9c986fff6..dbf8a9577 100644 --- a/gpjax/decision_making/acquisition_functions/__init__.py +++ b/gpjax/decision_making/utility_functions/__init__.py @@ -12,16 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from gpjax.decision_making.acquisition_functions.base import ( - AbstractAcquisitionFunctionBuilder, - AcquisitionFunction, -) -from gpjax.decision_making.acquisition_functions.thompson_sampling import ( - ThompsonSampling, +from gpjax.decision_making.utility_functions.base import ( + AbstractUtilityFunctionBuilder, + UtilityFunction, ) +from gpjax.decision_making.utility_functions.thompson_sampling import ThompsonSampling __all__ = [ - "AcquisitionFunction", - "AbstractAcquisitionFunctionBuilder", + "UtilityFunction", + "AbstractUtilityFunctionBuilder", "ThompsonSampling", ] diff --git a/gpjax/decision_making/acquisition_functions/base.py b/gpjax/decision_making/utility_functions/base.py similarity index 75% rename from gpjax/decision_making/acquisition_functions/base.py rename to gpjax/decision_making/utility_functions/base.py index 59cfcbf8e..c06ffc9db 100644 --- a/gpjax/decision_making/acquisition_functions/base.py +++ b/gpjax/decision_making/utility_functions/base.py @@ -32,17 +32,17 @@ KeyArray, ) -AcquisitionFunction = Callable[[Float[Array, "N D"]], Float[Array, "N 1"]] +UtilityFunction = Callable[[Float[Array, "N D"]], Float[Array, "N 1"]] """ -Type alias for acquisition functions, which take an array of points of shape $`[N, D]`$ -and return the value of the acquisition function at each point in an array of shape $`[N, 1]`$. +Type alias for utility functions, which take an array of points of shape $`[N, D]`$ +and return the value of the utility function at each point in an array of shape $`[N, 1]`$. """ @dataclass -class AbstractAcquisitionFunctionBuilder(ABC): +class AbstractUtilityFunctionBuilder(ABC): """ - Abstract class for building acquisition functions. + Abstract class for building utility functions. """ def check_objective_present( @@ -56,9 +56,9 @@ def check_objective_present( Args: posteriors (Mapping[str, AbstractPosterior]): Dictionary of posteriors to be - used to form the acquisition function. + used to form the utility function. datasets (Mapping[str, Dataset]): Dictionary of datasets which may be used - to form the acquisition function. + to form the utility function. Raises: ValueError: If the objective posterior or dataset are not present in the @@ -70,24 +70,24 @@ def check_objective_present( raise ValueError("Objective dataset not found in datasets") @abstractmethod - def build_acquisition_function( + def build_utility_function( self, posteriors: Mapping[str, AbstractPosterior], datasets: Mapping[str, Dataset], key: KeyArray, - ) -> AcquisitionFunction: + ) -> UtilityFunction: """ - Build an `AcquisitionFunction` from a set of posteriors and datasets. + Build a `UtilityFunction` from a set of posteriors and datasets. Args: posteriors (Mapping[str, AbstractPosterior]): Dictionary of posteriors to be - used to form the acquisition function. + used to form the utility function. datasets (Mapping[str, Dataset]): Dictionary of datasets which may be used - to form the acquisition function. + to form the utility function. key (KeyArray): JAX PRNG key used for random number generation. Returns: - AcquisitionFunction: Acquisition function to be *maximised* in order to + UtilityFunction: Utility function to be *maximised* in order to decide which point to query next. """ raise NotImplementedError diff --git a/gpjax/decision_making/acquisition_functions/thompson_sampling.py b/gpjax/decision_making/utility_functions/thompson_sampling.py similarity index 76% rename from gpjax/decision_making/acquisition_functions/thompson_sampling.py rename to gpjax/decision_making/utility_functions/thompson_sampling.py index 129cb4e97..6ef765f5f 100644 --- a/gpjax/decision_making/acquisition_functions/thompson_sampling.py +++ b/gpjax/decision_making/utility_functions/thompson_sampling.py @@ -17,9 +17,9 @@ from beartype.typing import Mapping from gpjax.dataset import Dataset -from gpjax.decision_making.acquisition_functions.base import ( - AbstractAcquisitionFunctionBuilder, - AcquisitionFunction, +from gpjax.decision_making.utility_functions.base import ( + AbstractUtilityFunctionBuilder, + UtilityFunction, ) from gpjax.decision_making.utils import OBJECTIVE from gpjax.gps import ConjugatePosterior @@ -27,12 +27,12 @@ @dataclass -class ThompsonSampling(AbstractAcquisitionFunctionBuilder): +class ThompsonSampling(AbstractUtilityFunctionBuilder): """ - Form an acquisition function by drawing an approximate sample from the posterior, + Form a utility function by drawing an approximate sample from the posterior, using decoupled sampling as introduced in [Wilson et. al. (2020)](https://arxiv.org/abs/2002.09309). Note that we return the *negative* of the - sample as the acquisition function, as acquisition functions are *maximised*. + sample as the utility function, as utility functions are *maximised*. Attributes: num_features (int): The number of random Fourier features to use when drawing @@ -47,30 +47,30 @@ def __post_init__(self): "The number of random Fourier features must be a positive integer." ) - def build_acquisition_function( + def build_utility_function( self, posteriors: Mapping[str, ConjugatePosterior], datasets: Mapping[str, Dataset], key: KeyArray, - ) -> AcquisitionFunction: + ) -> UtilityFunction: """ Draw an approximate sample from the posterior of the objective model and return - the *negative* of this sample as an acquisition function, as acquisition functions + the *negative* of this sample as a utility function, as utility functions are *maximised*. Args: posteriors (Mapping[str, AbstractPosterior]): Dictionary of posteriors to be - used to form the acquisition function. One of the posteriors must correspond + used to form the utility function. One of the posteriors must correspond to the `OBJECTIVE` key, as we sample from the objective posterior to form - the acquisition function. + the utility function. datasets (Mapping[str, Dataset]): Dictionary of datasets which may be used - to form the acquisition function. Keys in `datasets` should correspond to + to form the utility function. Keys in `datasets` should correspond to keys in `posteriors`. One of the datasets must correspond to the `OBJECTIVE` key. key (KeyArray): JAX PRNG key used for random number generation. Returns: - AcquisitionFunction: An appproximate sample from the objective model + UtilityFunction: An appproximate sample from the objective model posterior to to be *maximised* in order to decide which point to query next. """ @@ -90,6 +90,4 @@ def build_acquisition_function( num_features=self.num_features, ) - return lambda x: -1.0 * thompson_sample( - x - ) # Acquisition functions are *maximised* + return lambda x: -1.0 * thompson_sample(x) # Utility functions are *maximised* diff --git a/gpjax/decision_making/acquisition_maximizer.py b/gpjax/decision_making/utility_maximizer.py similarity index 60% rename from gpjax/decision_making/acquisition_maximizer.py rename to gpjax/decision_making/utility_maximizer.py index 1f75f5ced..ff0871f5d 100644 --- a/gpjax/decision_making/acquisition_maximizer.py +++ b/gpjax/decision_making/utility_maximizer.py @@ -22,11 +22,11 @@ import jax.random as jr from jaxopt import ScipyBoundedMinimize -from gpjax.decision_making.acquisition_functions import AcquisitionFunction from gpjax.decision_making.search_space import ( AbstractSearchSpace, ContinuousSearchSpace, ) +from gpjax.decision_making.utility_functions import UtilityFunction from gpjax.typing import ( Array, Float, @@ -36,60 +36,58 @@ def _get_discrete_maximizer( - query_points: Float[Array, "N D"], acquisition_function: AcquisitionFunction + query_points: Float[Array, "N D"], utility_function: UtilityFunction ) -> Float[Array, "1 D"]: - """Get the point which maximises the acquisition function evaluated at a given set of points. + """Get the point which maximises the utility function evaluated at a given set of points. Args: query_points (Float[Array, "N D"]): Set of points at which to evaluate the - acquisition function. - acquisition_function (AcquisitionFunction): Acquisition function - to evaluate at `query_points`. + utility function. + utility_function (UtilityFunction): Utility function to evaluate at `query_points`. Returns: - Float[Array, "1 D"]: Point in `query_points` which maximises the acquisition - function. + Float[Array, "1 D"]: Point in `query_points` which maximises the utility function. """ - acquisition_function_values = acquisition_function(query_points) - max_acquisition_function_value_idx = jnp.argmax( - acquisition_function_values, axis=0, keepdims=True + utility_function_values = utility_function(query_points) + max_utility_function_value_idx = jnp.argmax( + utility_function_values, axis=0, keepdims=True ) best_sample_point = jnp.take_along_axis( - query_points, max_acquisition_function_value_idx, axis=0 + query_points, max_utility_function_value_idx, axis=0 ) return best_sample_point @dataclass -class AbstractAcquisitionMaximizer(ABC): - """Abstract base class for acquisition function maximizers.""" +class AbstractUtilityMaximizer(ABC): + """Abstract base class for utility function maximizers.""" @abstractmethod def maximize( self, - acquisition_function: AcquisitionFunction, + utility_function: UtilityFunction, search_space: AbstractSearchSpace, key: KeyArray, ) -> Float[Array, "1 D"]: - """Maximize the given acquisition function over the search space provided. + """Maximize the given utility function over the search space provided. Args: - acquisition_function (AcquisitionFunction): Acquisition function to be + utility_function (UtilityFunction): Utility function to be maximized. search_space (AbstractSearchSpace): Search space over which to maximize - the acquisition function. + the utility function. key (KeyArray): JAX PRNG key. Returns: - Float[Array, "1 D"]: Point at which the acquisition function is maximized. + Float[Array, "1 D"]: Point at which the utility function is maximized. """ raise NotImplementedError @dataclass -class ContinuousAcquisitionMaximizer(AbstractAcquisitionMaximizer): - """The `ContinuousAcquisitionMaximizer` class is used to maximize acquisition - functions over the continuous domain with L-BFGS-B. First we sample the acquisition +class ContinuousUtilityMaximizer(AbstractUtilityMaximizer): + """The `ContinuousUtilityMaximizer` class is used to maximize utility + functions over the continuous domain with L-BFGS-B. First we sample the utility function at `num_initial_samples` points from the search space, and then we run L-BFGS-B from the best of these initial points. We run this process `num_restarts` number of times, each time sampling a different random set of @@ -111,11 +109,11 @@ def __post_init__(self): def maximize( self, - acquisition_function: AcquisitionFunction, + utility_function: UtilityFunction, search_space: ContinuousSearchSpace, key: KeyArray, ) -> Float[Array, "1 D"]: - max_observed_acquisition_function_value = None + max_observed_utility_function_value = None maximizer = None for _ in range(self.num_restarts): @@ -124,36 +122,31 @@ def maximize( self.num_initial_samples, key=key ) best_initial_sample_point = _get_discrete_maximizer( - initial_sample_points, acquisition_function + initial_sample_points, utility_function ) - def _scalar_acquisition_function(x: Float[Array, "1 D"]) -> ScalarFloat: + def _scalar_utility_function(x: Float[Array, "1 D"]) -> ScalarFloat: """ The Jaxopt minimizer requires a function which returns a scalar. It calls the - acquisition function with one point at a time, so the acquisition function + utility function with one point at a time, so the utility function returns an array of shape [1, 1], so we index to return a scalar. Note that - we also return the negative of the acquisition function - this is because - acquisition functions should be *maximimized* but the Jaxopt minimizer + we also return the negative of the utility function - this is because + utility functions should be *maximimized* but the Jaxopt minimizer minimizes functions. """ - return -acquisition_function(x)[0][0] + return -utility_function(x)[0][0] lbfgsb = ScipyBoundedMinimize( - fun=_scalar_acquisition_function, method="l-bfgs-b" + fun=_scalar_utility_function, method="l-bfgs-b" ) bounds = (search_space.lower_bounds, search_space.upper_bounds) optimized_point = lbfgsb.run( best_initial_sample_point, bounds=bounds ).params - optimized_acquisition_function_value = _scalar_acquisition_function( - optimized_point - ) - if (max_observed_acquisition_function_value is None) or ( - optimized_acquisition_function_value - > max_observed_acquisition_function_value + optimized_utility_function_value = _scalar_utility_function(optimized_point) + if (max_observed_utility_function_value is None) or ( + optimized_utility_function_value > max_observed_utility_function_value ): - max_observed_acquisition_function_value = ( - optimized_acquisition_function_value - ) + max_observed_utility_function_value = optimized_utility_function_value maximizer = optimized_point return maximizer diff --git a/gpjax/decision_making/utils.py b/gpjax/decision_making/utils.py index 55087c172..9af19de32 100644 --- a/gpjax/decision_making/utils.py +++ b/gpjax/decision_making/utils.py @@ -26,7 +26,7 @@ OBJECTIVE: Final[str] = "OBJECTIVE" """ -Tag for the objective dataset/function in standard acquisition functions. +Tag for the objective dataset/function in standard utility functions. """ diff --git a/tests/test_decision_making/test_decision_maker.py b/tests/test_decision_making/test_decision_maker.py index 235b5bc35..dad2340e1 100644 --- a/tests/test_decision_making/test_decision_maker.py +++ b/tests/test_decision_making/test_decision_maker.py @@ -23,16 +23,9 @@ import gpjax as gpx from gpjax.dataset import Dataset -from gpjax.decision_making.acquisition_functions import ( - AbstractAcquisitionFunctionBuilder, -) -from gpjax.decision_making.acquisition_maximizer import ( - AbstractAcquisitionMaximizer, - ContinuousAcquisitionMaximizer, -) from gpjax.decision_making.decision_maker import ( AbstractDecisionMaker, - DecisionMaker, + UtilityDrivenDecisionMaker, ) from gpjax.decision_making.posterior_handler import PosteriorHandler from gpjax.decision_making.search_space import ( @@ -40,12 +33,17 @@ ContinuousSearchSpace, ) from gpjax.decision_making.test_functions import Quadratic +from gpjax.decision_making.utility_functions import AbstractUtilityFunctionBuilder +from gpjax.decision_making.utility_maximizer import ( + AbstractUtilityMaximizer, + ContinuousUtilityMaximizer, +) from gpjax.decision_making.utils import ( OBJECTIVE, build_function_evaluator, ) from gpjax.typing import KeyArray -from tests.test_decision_making.utils import QuadraticAcquisitionFunctionBuilder +from tests.test_decision_making.utils import QuadraticUtilityFunctionBuilder CONSTRAINT = "CONSTRAINT" @@ -77,13 +75,13 @@ def posterior_handler() -> PosteriorHandler: @pytest.fixture -def acquisition_function_builder() -> AbstractAcquisitionFunctionBuilder: - return QuadraticAcquisitionFunctionBuilder() +def utility_function_builder() -> AbstractUtilityFunctionBuilder: + return QuadraticUtilityFunctionBuilder() @pytest.fixture -def acquisition_maximizer() -> AbstractAcquisitionMaximizer: - return ContinuousAcquisitionMaximizer(num_initial_samples=1000, num_restarts=1) +def utility_maximizer() -> AbstractUtilityMaximizer: + return ContinuousUtilityMaximizer(num_initial_samples=1000, num_restarts=1) def get_dataset(num_points: int, key: KeyArray) -> Dataset: @@ -103,21 +101,23 @@ def test_abstract_decision_maker_raises_error(): def test_invalid_tags_raises_error( search_space: AbstractSearchSpace, posterior_handler: PosteriorHandler, - acquisition_function_builder: AbstractAcquisitionFunctionBuilder, - acquisition_maximizer: AbstractAcquisitionMaximizer, + utility_function_builder: AbstractUtilityFunctionBuilder, + utility_maximizer: AbstractUtilityMaximizer, ): key = jr.PRNGKey(42) posterior_handlers = {OBJECTIVE: posterior_handler} dataset = get_dataset(num_points=5, key=jr.PRNGKey(42)) datasets = {"CONSTRAINT": dataset} # Dataset tag doesn't match posterior tag with pytest.raises(ValueError): - DecisionMaker( + UtilityDrivenDecisionMaker( search_space=search_space, posterior_handlers=posterior_handlers, datasets=datasets, - acquisition_function_builder=acquisition_function_builder, - acquisition_maximizer=acquisition_maximizer, + utility_function_builder=utility_function_builder, + utility_maximizer=utility_maximizer, key=key, + post_ask=[], + post_tell=[], ) @@ -127,21 +127,23 @@ def test_invalid_tags_raises_error( def test_initialisation_optimizes_posterior_hyperparameters( search_space: AbstractSearchSpace, posterior_handler: PosteriorHandler, - acquisition_function_builder: AbstractAcquisitionFunctionBuilder, - acquisition_maximizer: AbstractAcquisitionMaximizer, + utility_function_builder: AbstractUtilityFunctionBuilder, + utility_maximizer: AbstractUtilityMaximizer, ): key = jr.PRNGKey(42) posterior_handlers = {OBJECTIVE: posterior_handler, CONSTRAINT: posterior_handler} objective_dataset = get_dataset(num_points=5, key=jr.PRNGKey(42)) constraint_dataset = get_dataset(num_points=5, key=jr.PRNGKey(10)) datasets = {"OBJECTIVE": objective_dataset, CONSTRAINT: constraint_dataset} - decision_maker = DecisionMaker( + decision_maker = UtilityDrivenDecisionMaker( search_space=search_space, posterior_handlers=posterior_handlers, datasets=datasets, - acquisition_function_builder=acquisition_function_builder, - acquisition_maximizer=acquisition_maximizer, + utility_function_builder=utility_function_builder, + utility_maximizer=utility_maximizer, key=key, + post_ask=[], + post_tell=[], ) # Assert kernel hyperparameters get changed from their initial values assert decision_maker.posteriors[OBJECTIVE].prior.kernel.lengthscale != jnp.array( @@ -168,26 +170,28 @@ def test_initialisation_optimizes_posterior_hyperparameters( def test_decision_maker_ask( search_space: AbstractSearchSpace, posterior_handler: PosteriorHandler, - acquisition_function_builder: AbstractAcquisitionFunctionBuilder, - acquisition_maximizer: AbstractAcquisitionMaximizer, + utility_function_builder: AbstractUtilityFunctionBuilder, + utility_maximizer: AbstractUtilityMaximizer, ): key = jr.PRNGKey(42) posterior_handlers = {OBJECTIVE: posterior_handler} objective_dataset = get_dataset(num_points=5, key=jr.PRNGKey(42)) datasets = {"OBJECTIVE": objective_dataset} - decision_maker = DecisionMaker( + decision_maker = UtilityDrivenDecisionMaker( search_space=search_space, posterior_handlers=posterior_handlers, datasets=datasets, - acquisition_function_builder=acquisition_function_builder, - acquisition_maximizer=acquisition_maximizer, + utility_function_builder=utility_function_builder, + utility_maximizer=utility_maximizer, key=key, + post_ask=[], + post_tell=[], ) initial_decision_maker_key = decision_maker.key query_point = decision_maker.ask(key=key) assert query_point.shape == (1, 1) assert jnp.allclose(query_point, jnp.array([[0.5]]), atol=1e-5) - assert decision_maker.current_acquisition_function is not None + assert decision_maker.current_utility_function is not None assert ( decision_maker.key == initial_decision_maker_key ).all() # Ensure decision maker key is unchanged @@ -199,8 +203,8 @@ def test_decision_maker_ask( def test_decision_maker_tell_with_inconsistent_observations_raises_error( search_space: AbstractSearchSpace, posterior_handler: PosteriorHandler, - acquisition_function_builder: AbstractAcquisitionFunctionBuilder, - acquisition_maximizer: AbstractAcquisitionMaximizer, + utility_function_builder: AbstractUtilityFunctionBuilder, + utility_maximizer: AbstractUtilityMaximizer, ): key = jr.PRNGKey(42) posterior_handlers = {OBJECTIVE: posterior_handler, CONSTRAINT: posterior_handler} @@ -210,13 +214,15 @@ def test_decision_maker_tell_with_inconsistent_observations_raises_error( "OBJECTIVE": initial_objective_dataset, CONSTRAINT: initial_constraint_dataset, } - decision_maker = DecisionMaker( + decision_maker = UtilityDrivenDecisionMaker( search_space=search_space, posterior_handlers=posterior_handlers, datasets=datasets, - acquisition_function_builder=acquisition_function_builder, - acquisition_maximizer=acquisition_maximizer, + utility_function_builder=utility_function_builder, + utility_maximizer=utility_maximizer, key=key, + post_ask=[], + post_tell=[], ) mock_objective_observation = get_dataset(num_points=1, key=jr.PRNGKey(1)) mock_constraint_observation = get_dataset(num_points=1, key=jr.PRNGKey(2)) @@ -234,8 +240,8 @@ def test_decision_maker_tell_with_inconsistent_observations_raises_error( def test_decision_maker_tell_updates_datasets_and_models( search_space: AbstractSearchSpace, posterior_handler: PosteriorHandler, - acquisition_function_builder: AbstractAcquisitionFunctionBuilder, - acquisition_maximizer: AbstractAcquisitionMaximizer, + utility_function_builder: AbstractUtilityFunctionBuilder, + utility_maximizer: AbstractUtilityMaximizer, ): key = jr.PRNGKey(42) posterior_handlers = {OBJECTIVE: posterior_handler, CONSTRAINT: posterior_handler} @@ -245,13 +251,15 @@ def test_decision_maker_tell_updates_datasets_and_models( "OBJECTIVE": initial_objective_dataset, CONSTRAINT: initial_constraint_dataset, } - decision_maker = DecisionMaker( + decision_maker = UtilityDrivenDecisionMaker( search_space=search_space, posterior_handlers=posterior_handlers, datasets=datasets, - acquisition_function_builder=acquisition_function_builder, - acquisition_maximizer=acquisition_maximizer, + utility_function_builder=utility_function_builder, + utility_maximizer=utility_maximizer, key=key, + post_ask=[], + post_tell=[], ) initial_decision_maker_key = decision_maker.key initial_objective_posterior = decision_maker.posteriors[OBJECTIVE] @@ -295,8 +303,8 @@ def test_decision_maker_tell_updates_datasets_and_models( def test_decision_maker_run( search_space: AbstractSearchSpace, posterior_handler: PosteriorHandler, - acquisition_function_builder: AbstractAcquisitionFunctionBuilder, - acquisition_maximizer: AbstractAcquisitionMaximizer, + utility_function_builder: AbstractUtilityFunctionBuilder, + utility_maximizer: AbstractUtilityMaximizer, n_steps: int, ): key = jr.PRNGKey(42) @@ -305,13 +313,15 @@ def test_decision_maker_run( datasets = { "OBJECTIVE": initial_objective_dataset, } - decision_maker = DecisionMaker( + decision_maker = UtilityDrivenDecisionMaker( search_space=search_space, posterior_handlers=posterior_handlers, datasets=datasets, - acquisition_function_builder=acquisition_function_builder, - acquisition_maximizer=acquisition_maximizer, + utility_function_builder=utility_function_builder, + utility_maximizer=utility_maximizer, key=key, + post_ask=[], + post_tell=[], ) initial_decision_maker_key = decision_maker.key black_box_fn = Quadratic() @@ -324,7 +334,7 @@ def test_decision_maker_run( assert query_datasets[OBJECTIVE].n == 5 + n_steps assert ( jnp.abs(query_datasets[OBJECTIVE].X[-n_steps:] - jnp.array([[0.5]])) < 1e-5 - ).all() # Ensure we're querying the correct point in our dummy acquisition function at each step + ).all() # Ensure we're querying the correct point in our dummy utility function at each step assert ( decision_maker.key != initial_decision_maker_key ).all() # Ensure decision maker key gets updated diff --git a/tests/test_decision_making/test_search_space.py b/tests/test_decision_making/test_search_space.py index 49bac87ab..95c47f8bb 100644 --- a/tests/test_decision_making/test_search_space.py +++ b/tests/test_decision_making/test_search_space.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== from beartype.roar import BeartypeCallHintParamViolation -from jax.config import config +from jax import config import jax.numpy as jnp import jax.random as jr from jaxtyping import ( diff --git a/tests/test_decision_making/test_test_functions/test_continuous_functions.py b/tests/test_decision_making/test_test_functions/test_continuous_functions.py index 7b410f815..6272a6694 100644 --- a/tests/test_decision_making/test_test_functions/test_continuous_functions.py +++ b/tests/test_decision_making/test_test_functions/test_continuous_functions.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from jax.config import config +from jax import config config.update("jax_enable_x64", True) diff --git a/tests/test_decision_making/test_test_functions/test_non_conjugate_functions.py b/tests/test_decision_making/test_test_functions/test_non_conjugate_functions.py index 837abdbfd..f6d999b5c 100644 --- a/tests/test_decision_making/test_test_functions/test_non_conjugate_functions.py +++ b/tests/test_decision_making/test_test_functions/test_non_conjugate_functions.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from jax.config import config +from jax import config config.update("jax_enable_x64", True) diff --git a/tests/test_decision_making/test_acquisition_functions/__init__.py b/tests/test_decision_making/test_utility_functions/__init__.py similarity index 100% rename from tests/test_decision_making/test_acquisition_functions/__init__.py rename to tests/test_decision_making/test_utility_functions/__init__.py diff --git a/tests/test_decision_making/test_acquisition_functions/test_base.py b/tests/test_decision_making/test_utility_functions/test_base.py similarity index 77% rename from tests/test_decision_making/test_acquisition_functions/test_base.py rename to tests/test_decision_making/test_utility_functions/test_base.py index b71831ef3..c38a458f0 100644 --- a/tests/test_decision_making/test_acquisition_functions/test_base.py +++ b/tests/test_decision_making/test_utility_functions/test_base.py @@ -13,16 +13,14 @@ # limitations under the License. # ============================================================================== -from jax.config import config +from jax import config import pytest -from gpjax.decision_making.acquisition_functions.base import ( - AbstractAcquisitionFunctionBuilder, -) +from gpjax.decision_making.utility_functions.base import AbstractUtilityFunctionBuilder config.update("jax_enable_x64", True) -def test_abstract_acquisition_function_builder(): +def test_abstract_utility_function_builder(): with pytest.raises(TypeError): - AbstractAcquisitionFunctionBuilder() + AbstractUtilityFunctionBuilder() diff --git a/tests/test_decision_making/test_acquisition_functions/test_thompson_sampling.py b/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py similarity index 74% rename from tests/test_decision_making/test_acquisition_functions/test_thompson_sampling.py rename to tests/test_decision_making/test_utility_functions/test_thompson_sampling.py index 0d1980fe5..2ffd41ffe 100644 --- a/tests/test_decision_making/test_acquisition_functions/test_thompson_sampling.py +++ b/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from jax.config import config +from jax import config config.update("jax_enable_x64", True) @@ -22,14 +22,12 @@ import pytest from gpjax.dataset import Dataset -from gpjax.decision_making.acquisition_functions.thompson_sampling import ( - ThompsonSampling, -) from gpjax.decision_making.test_functions.continuous_functions import ( AbstractContinuousTestFunction, Forrester, LogarithmicGoldsteinPrice, ) +from gpjax.decision_making.utility_functions.thompson_sampling import ThompsonSampling from gpjax.decision_making.utils import OBJECTIVE from gpjax.gps import ( ConjugatePosterior, @@ -74,8 +72,8 @@ def test_thompson_sampling_no_objective_posterior_raises_error(): posteriors = {"CONSTRAINT": posterior} datasets = {OBJECTIVE: dataset} with pytest.raises(ValueError): - ts_acquisition_builder = ThompsonSampling(num_features=100) - ts_acquisition_builder.build_acquisition_function( + ts_utility_builder = ThompsonSampling(num_features=100) + ts_utility_builder.build_utility_function( posteriors=posteriors, datasets=datasets, key=key ) @@ -91,8 +89,8 @@ def test_thompson_sampling_no_objective_dataset_raises_error(): posteriors = {OBJECTIVE: posterior} datasets = {"CONSTRAINT": dataset} with pytest.raises(ValueError): - ts_acquisition_builder = ThompsonSampling(num_features=100) - ts_acquisition_builder.build_acquisition_function( + ts_utility_builder = ThompsonSampling(num_features=100) + ts_utility_builder.build_utility_function( posteriors=posteriors, datasets=datasets, key=key ) @@ -108,8 +106,8 @@ def test_thompson_sampling_non_conjugate_posterior_raises_error(): posteriors = {OBJECTIVE: posterior} datasets = {OBJECTIVE: dataset} with pytest.raises(ValueError): - ts_acquisition_builder = ThompsonSampling(num_features=100) - ts_acquisition_builder.build_acquisition_function( + ts_utility_builder = ThompsonSampling(num_features=100) + ts_utility_builder.build_utility_function( posteriors=posteriors, datasets=datasets, key=key ) @@ -126,8 +124,8 @@ def test_thompson_sampling_invalid_rff_num_raises_error(num_rff_features: int): posteriors = {OBJECTIVE: posterior} datasets = {OBJECTIVE: dataset} with pytest.raises(ValueError): - ts_acquisition_builder = ThompsonSampling(num_features=num_rff_features) - ts_acquisition_builder.build_acquisition_function( + ts_utility_builder = ThompsonSampling(num_features=num_rff_features) + ts_utility_builder.build_utility_function( posteriors=posteriors, datasets=datasets, key=key ) @@ -141,7 +139,7 @@ def test_thompson_sampling_invalid_rff_num_raises_error(num_rff_features: int): @pytest.mark.filterwarnings( "ignore::UserWarning" ) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort -def test_thompson_sampling_acquisition_function_correct_shapes( +def test_thompson_sampling_utility_function_correct_shapes( test_target_function: AbstractContinuousTestFunction, num_test_points: int, key: KeyArray, @@ -150,14 +148,14 @@ def test_thompson_sampling_acquisition_function_correct_shapes( posterior = generate_dummy_conjugate_posterior(dataset) posteriors = {OBJECTIVE: posterior} datasets = {OBJECTIVE: dataset} - ts_acquisition_builder = ThompsonSampling(num_features=100) - ts_acquisition_function = ts_acquisition_builder.build_acquisition_function( + ts_utility_builder = ThompsonSampling(num_features=100) + ts_utility_function = ts_utility_builder.build_utility_function( posteriors=posteriors, datasets=datasets, key=key ) test_key, _ = jr.split(key) test_X = test_target_function.generate_test_points(num_test_points, test_key) - ts_acquisition_function_values = ts_acquisition_function(test_X) - assert ts_acquisition_function_values.shape == (num_test_points, 1) + ts_utility_function_values = ts_utility_function(test_X) + assert ts_utility_function_values.shape == (num_test_points, 1) @pytest.mark.parametrize( @@ -169,7 +167,7 @@ def test_thompson_sampling_acquisition_function_correct_shapes( @pytest.mark.filterwarnings( "ignore::UserWarning" ) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort -def test_thompson_sampling_acquisition_function_same_key_same_function( +def test_thompson_sampling_utility_function_same_key_same_function( test_target_function: AbstractContinuousTestFunction, num_test_points: int, key: KeyArray, @@ -178,23 +176,21 @@ def test_thompson_sampling_acquisition_function_same_key_same_function( posterior = generate_dummy_conjugate_posterior(dataset) posteriors = {OBJECTIVE: posterior} datasets = {OBJECTIVE: dataset} - ts_acquisition_builder_one = ThompsonSampling(num_features=100) - ts_acquisition_builder_two = ThompsonSampling(num_features=100) - ts_acquisition_function_one = ts_acquisition_builder_one.build_acquisition_function( + ts_utility_builder_one = ThompsonSampling(num_features=100) + ts_utility_builder_two = ThompsonSampling(num_features=100) + ts_utility_function_one = ts_utility_builder_one.build_utility_function( posteriors=posteriors, datasets=datasets, key=key ) - ts_acquisition_function_two = ts_acquisition_builder_two.build_acquisition_function( + ts_utility_function_two = ts_utility_builder_two.build_utility_function( posteriors=posteriors, datasets=datasets, key=key ) test_key, _ = jr.split(key) test_X = test_target_function.generate_test_points(num_test_points, test_key) - ts_acquisition_function_one_values = ts_acquisition_function_one(test_X) - ts_acquisition_function_two_values = ts_acquisition_function_two(test_X) - assert isinstance(ts_acquisition_function_one, Callable) - assert isinstance(ts_acquisition_function_two, Callable) - assert ( - ts_acquisition_function_one_values == ts_acquisition_function_two_values - ).all() + ts_utility_function_one_values = ts_utility_function_one(test_X) + ts_utility_function_two_values = ts_utility_function_two(test_X) + assert isinstance(ts_utility_function_one, Callable) + assert isinstance(ts_utility_function_two, Callable) + assert (ts_utility_function_one_values == ts_utility_function_two_values).all() @pytest.mark.parametrize( @@ -206,7 +202,7 @@ def test_thompson_sampling_acquisition_function_same_key_same_function( @pytest.mark.filterwarnings( "ignore::UserWarning" ) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort -def test_thompson_sampling_acquisition_function_different_key_different_function( +def test_thompson_sampling_utility_function_different_key_different_function( test_target_function: AbstractContinuousTestFunction, num_test_points: int, key: KeyArray, @@ -217,19 +213,17 @@ def test_thompson_sampling_acquisition_function_different_key_different_function datasets = {OBJECTIVE: dataset} sample_one_key = key sample_two_key, _ = jr.split(key) - ts_acquisition_builder = ThompsonSampling(num_features=100) - ts_acquisition_function_one = ts_acquisition_builder.build_acquisition_function( + ts_utility_builder = ThompsonSampling(num_features=100) + ts_utility_function_one = ts_utility_builder.build_utility_function( posteriors=posteriors, datasets=datasets, key=sample_one_key ) - ts_acquisition_function_two = ts_acquisition_builder.build_acquisition_function( + ts_utility_function_two = ts_utility_builder.build_utility_function( posteriors=posteriors, datasets=datasets, key=sample_two_key ) test_key, _ = jr.split(sample_two_key) test_X = test_target_function.generate_test_points(num_test_points, test_key) - ts_acquisition_function_one_values = ts_acquisition_function_one(test_X) - ts_acquisition_function_two_values = ts_acquisition_function_two(test_X) - assert isinstance(ts_acquisition_function_one, Callable) - assert isinstance(ts_acquisition_function_two, Callable) - assert not ( - ts_acquisition_function_one_values == ts_acquisition_function_two_values - ).all() + ts_utility_function_one_values = ts_utility_function_one(test_X) + ts_utility_function_two_values = ts_utility_function_two(test_X) + assert isinstance(ts_utility_function_one, Callable) + assert isinstance(ts_utility_function_two, Callable) + assert not (ts_utility_function_one_values == ts_utility_function_two_values).all() diff --git a/tests/test_decision_making/test_acquisition_maximizer.py b/tests/test_decision_making/test_utility_maximizer.py similarity index 73% rename from tests/test_decision_making/test_acquisition_maximizer.py rename to tests/test_decision_making/test_utility_maximizer.py index 939a6f80e..980add289 100644 --- a/tests/test_decision_making/test_acquisition_maximizer.py +++ b/tests/test_decision_making/test_utility_maximizer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from jax.config import config +from jax import config config.update("jax_enable_x64", True) @@ -20,23 +20,23 @@ import jax.random as jr import pytest -from gpjax.decision_making.acquisition_maximizer import ( - AbstractAcquisitionMaximizer, - ContinuousAcquisitionMaximizer, - _get_discrete_maximizer, -) from gpjax.decision_making.test_functions.continuous_functions import ( AbstractContinuousTestFunction, Forrester, LogarithmicGoldsteinPrice, Quadratic, ) +from gpjax.decision_making.utility_maximizer import ( + AbstractUtilityMaximizer, + ContinuousUtilityMaximizer, + _get_discrete_maximizer, +) from gpjax.typing import KeyArray -def test_abstract_acquisition_maximizer(): +def test_abstract_utility_maximizer(): with pytest.raises(TypeError): - AbstractAcquisitionMaximizer() + AbstractUtilityMaximizer() @pytest.mark.parametrize( @@ -53,13 +53,13 @@ def test_discrete_maximizer_returns_correct_point( key: KeyArray, ): query_points = test_function.generate_test_points(1000, key=key) - acquisition_function = lambda x: -1.0 * test_function.evaluate(x) - acquisition_vals = acquisition_function(query_points) - true_max_acquisition_val = jnp.max(acquisition_vals) - discrete_maximizer = _get_discrete_maximizer(query_points, acquisition_function) + utility_function = lambda x: -1.0 * test_function.evaluate(x) + utility_vals = utility_function(query_points) + true_max_utility_val = jnp.max(utility_vals) + discrete_maximizer = _get_discrete_maximizer(query_points, utility_function) assert discrete_maximizer.shape == (1, dimensionality) assert discrete_maximizer.dtype == jnp.float64 - assert acquisition_function(discrete_maximizer)[0][0] == true_max_acquisition_val + assert utility_function(discrete_maximizer)[0][0] == true_max_utility_val @pytest.mark.parametrize("num_initial_samples", [0, -1, -10]) @@ -67,7 +67,7 @@ def test_continuous_maximizer_raises_error_with_erroneous_num_initial_samples( num_initial_samples: int, ): with pytest.raises(ValueError): - ContinuousAcquisitionMaximizer( + ContinuousUtilityMaximizer( num_initial_samples=num_initial_samples, num_restarts=1 ) @@ -77,7 +77,7 @@ def test_continuous_maximizer_raises_error_with_erroneous_num_restarts( num_restarts: int, ): with pytest.raises(ValueError): - ContinuousAcquisitionMaximizer(num_initial_samples=1, num_restarts=num_restarts) + ContinuousUtilityMaximizer(num_initial_samples=1, num_restarts=num_restarts) @pytest.mark.parametrize( @@ -95,20 +95,20 @@ def test_continous_maximizer_returns_same_point_with_same_key( key: KeyArray, num_restarts: int, ): - continuous_maximizer_one = ContinuousAcquisitionMaximizer( + continuous_maximizer_one = ContinuousUtilityMaximizer( num_initial_samples=1000, num_restarts=num_restarts ) - continuous_maximizer_two = ContinuousAcquisitionMaximizer( + continuous_maximizer_two = ContinuousUtilityMaximizer( num_initial_samples=1000, num_restarts=num_restarts ) - acquisition_function = lambda x: -1.0 * test_function.evaluate(x) + utility_function = lambda x: -1.0 * test_function.evaluate(x) maximizer_one = continuous_maximizer_one.maximize( - acquisition_function=acquisition_function, + utility_function=utility_function, search_space=test_function.search_space, key=key, ) maximizer_two = continuous_maximizer_two.maximize( - acquisition_function=acquisition_function, + utility_function=utility_function, search_space=test_function.search_space, key=key, ) @@ -137,19 +137,19 @@ def test_continuous_maximizer_finds_correct_point( key: KeyArray, num_restarts: int, ): - continuous_acquisition_maximizer = ContinuousAcquisitionMaximizer( + continuous_utility_maximizer = ContinuousUtilityMaximizer( num_initial_samples=1000, num_restarts=num_restarts ) - acquisition_function = lambda x: -1.0 * test_function.evaluate(x) - true_acquisition_maximizer = test_function.minimizer - maximizer = continuous_acquisition_maximizer.maximize( - acquisition_function=acquisition_function, + utility_function = lambda x: -1.0 * test_function.evaluate(x) + true_utility_maximizer = test_function.minimizer + maximizer = continuous_utility_maximizer.maximize( + utility_function=utility_function, search_space=test_function.search_space, key=key, ) assert maximizer.shape == (1, dimensionality) assert maximizer.dtype == jnp.float64 - assert jnp.allclose(maximizer, true_acquisition_maximizer, atol=1e-6).all() + assert jnp.allclose(maximizer, true_utility_maximizer, atol=1e-6).all() @pytest.mark.parametrize("key", [jr.PRNGKey(42), jr.PRNGKey(10), jr.PRNGKey(1)]) @@ -159,17 +159,17 @@ def test_continuous_maximizer_finds_correct_point( ) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort def test_continuous_maximizer_jaxopt_component(key: KeyArray, num_restarts: int): quadratic = Quadratic() - continuous_acquisition_maximizer = ContinuousAcquisitionMaximizer( + continuous_utility_maximizer = ContinuousUtilityMaximizer( num_initial_samples=1, # Force JaxOpt L-GFBS-B to do the heavy lifting num_restarts=num_restarts, ) - acquisition_function = lambda x: -1.0 * quadratic.evaluate(x) - true_acquisition_maximizer = quadratic.minimizer - maximizer = continuous_acquisition_maximizer.maximize( - acquisition_function=acquisition_function, + utility_function = lambda x: -1.0 * quadratic.evaluate(x) + true_utility_maximizer = quadratic.minimizer + maximizer = continuous_utility_maximizer.maximize( + utility_function=utility_function, search_space=quadratic.search_space, key=key, ) assert maximizer.shape == (1, 1) assert maximizer.dtype == jnp.float64 - assert jnp.allclose(maximizer, true_acquisition_maximizer, atol=1e-6).all() + assert jnp.allclose(maximizer, true_utility_maximizer, atol=1e-6).all() diff --git a/tests/test_decision_making/utils.py b/tests/test_decision_making/utils.py index f92b793d8..90ad2d75a 100644 --- a/tests/test_decision_making/utils.py +++ b/tests/test_decision_making/utils.py @@ -16,30 +16,30 @@ from beartype.typing import Mapping from gpjax.dataset import Dataset -from gpjax.decision_making.acquisition_functions import ( - AbstractAcquisitionFunctionBuilder, - AcquisitionFunction, -) from gpjax.decision_making.test_functions import Quadratic +from gpjax.decision_making.utility_functions import ( + AbstractUtilityFunctionBuilder, + UtilityFunction, +) from gpjax.gps import ConjugatePosterior from gpjax.typing import KeyArray -class QuadraticAcquisitionFunctionBuilder(AbstractAcquisitionFunctionBuilder): +class QuadraticUtilityFunctionBuilder(AbstractUtilityFunctionBuilder): """ - Dummy acquisition function builder for testing purposes, which returns the negative + Dummy utility function builder for testing purposes, which returns the negative of the value of a quadratic test function at the input points. This is because - acquisition functions are *maximised*, and we wish to *minimise* the quadratic test + utility functions are *maximised*, and we wish to *minimise* the quadratic test function. """ - def build_acquisition_function( + def build_utility_function( self, posteriors: Mapping[str, ConjugatePosterior], datasets: Mapping[str, Dataset], key: KeyArray, - ) -> AcquisitionFunction: + ) -> UtilityFunction: test_function = Quadratic() return lambda x: -1.0 * test_function.evaluate( x - ) # Acquisition functions are *maximised* + ) # Utility functions are *maximised*