From 85a1a748a1a72e154ab3a16d97ef87063dff4d66 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Thu, 18 Jul 2024 21:00:57 +0200 Subject: [PATCH 01/17] Add backend doc --- docs/examples/backend.py | 159 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 159 insertions(+) create mode 100644 docs/examples/backend.py diff --git a/docs/examples/backend.py b/docs/examples/backend.py new file mode 100644 index 000000000..35aacbd67 --- /dev/null +++ b/docs/examples/backend.py @@ -0,0 +1,159 @@ +# -*- coding: utf-8 -*- +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# custom_cell_magics: kql +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.11.2 +# kernelspec: +# display_name: gpjax +# language: python +# name: python3 +# --- + +# %% [markdown] +# # Backend Module Design +# +# Since v0.9, GPJax is built upon Flax's [NNX](https://flax.readthedocs.io/en/latest/nnx/index.html) module. This transition allows for more efficient parameter handling, improved integration with Flax and Flax-based libraries, and enhanced flexibility in model design. This notebook provides a high-level overview of the backend module design in GPJax. For an introduction to NNX, please refer to the [official documentation](https://flax.readthedocs.io/en/latest/nnx/index.html). +# + +# %% +import jax.numpy as jnp +import matplotlib as mpl +import matplotlib.pyplot as plt +import gpjax as gpx + +plt.style.use( + "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" +) +cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] + +# %% [markdown] +# ## Parameters +# +# The biggest change bought about by the transition to an NNX backend is the increased support we now provide for handling parameters. As discussed in our [Sharp Bits - Bijectors Doc](https://docs.jaxgaussianprocesses.com/sharp_bits/#bijectors), GPJax uses bijectors to transform constrained parameters to unconstrained parameters during optimisation. You may now register the support of a parameter using our `Parameter` class. To see this, consider the constant mean function who contains a single constant parameter whose value ordinarily exists on the real line. We can register this parameter as follows: + +# %% +from gpjax.mean_functions import Constant +from gpjax.parameters import Real + +constant_param = Real(value=1.0) +meanf = Constant(constant_param) +meanf + +# %% [markdown] +# However, suppose you wish your mean function's constant parameter to be strictly positive. This is easy to achieve by using the correct Parameter type. + +# %% +from gpjax.parameters import PositiveReal + +constant_param = PositiveReal(value=1.0) +meanf = Constant(constant_param) +meanf + +# %% [markdown] +# Were we to try and instantiate the `PositiveReal` class with a negative value, then an explicit error would be raised. + +# %% +try: + PositiveReal(value=-1.0) +except ValueError as e: + print(e) + +# %% [markdown] +# ### Parameter Transforms +# +# With a parameter instantiated, you likely wish to transform the parameter's value from its constrained support onto the entire real line. To do this, you can apply the `transform` function to the parameter. To control the bijector used to transform the parameter, you may pass a set of bijectors into the transform function. Under-the-hood, the `transform` function is looking up the bijector of a parameter using it's `_tag` field in the bijector dictionary, and then applying the bijector to the parameter's value using a tree map operation. + +# %% +print(constant_param._tag) + +# %% [markdown] +# For most users, you will not need to worry about this as we provide a set of default bijectors that are defined for all the parameter types we support. However, see our [Kernel Guide Notebook](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/) to see how you can define your own bijectors and parameter types. + +# %% +from gpjax.parameters import DEFAULT_BIJECTION, transform + +print(DEFAULT_BIJECTION[constant_param._tag]) + +# %% [markdown] +# We see here that the Softplus bijector is specified as the default for strictly positive parameters. To apply this, we may invoke the following + +# %% +transform(constant_param, DEFAULT_BIJECTION) + +# %% [markdown] +# ### Transforming Multiple Parameters +# +# In the above, we transformed a single parameter. However, in practice your parameters may be nested within several functions e.g., a kernel function within a GP model. Fortunately, transforming several parameters is a simple operation that we here demonstrate for a regular GP poster + +# %% +kernel = gpx.kernels.Matern32() +meanf = gpx.mean_functions.Constant() + +prior = gpx.gps.Prior(meanf, kernel) + + +likelihood = gpx.likelihoods.Gaussian(100) +posterior = likelihood * prior + +# %% [markdown] +# # Backend Module Design +# +# Since v0.9, GPJax is built upon Flax's [NNX](https://flax.readthedocs.io/en/latest/nnx/index.html) module. This transition allows for more efficient parameter handling, improved integration with Flax and Flax-based libraries, and enhanced flexibility in model design. This notebook provides a high-level overview of the backend module design in GPJax. For an introduction to NNX, please refer to the [official documentation](https://flax.readthedocs.io/en/latest/nnx/index.html). +# + +# %% [markdown] +# ## NNX Modules +# +# To demonstrate the ease of use and flexibility of NNX modules, we will implement a linear mean function using the existing abstractions in GPJax. For inputs $x_n \in \mathbb{R}^d$, the linear mean function $m(x): \mathbb{R}^d \to \mathbb{R}$ is defined as: +# $$ +# m(x) = \alpha + \sum_{i=1}^d \beta_i x_i +# $$ +# where $\alpha \in \mathbb{R}$ and $\beta_i \in \mathbb{R}$ are the parameters of the mean function. Let's now implement that using the new NNX backend. + +# %% +import typing as tp + +from jaxtyping import Float, Num + +from gpjax.mean_functions import AbstractMeanFunction +from gpjax.parameters import Real, Parameter +from gpjax.typing import ScalarFloat, Array + + +class LinearMeanFunction(AbstractMeanFunction): + def __init__( + self, + intercept: tp.Union[ScalarFloat, Float[Array, " O"], Parameter] = 0.0, + slope: tp.Union[ScalarFloat, Float[Array, " D O"], Parameter] = 0.0, + ): + if isinstance(intercept, Parameter): + self.intercept = intercept + else: + self.intercept = Real(jnp.array(intercept)) + + if isinstance(slope, Parameter): + self.slope = slope + else: + self.slope = Real(jnp.array(slope)) + + def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N O"]: + return self.intercept.value + jnp.dot(x, self.slope.value) + + +# %% [markdown] +# As we can see, the implementation is straightforward and concise. The `AbstractMeanFunction` module is a subclass of `nnx.Module`. From here, we inform the module about the parameters +# + +# %% +X = jnp.linspace(-5.0, 5.0, 100)[:, None] + +meanf = LinearMeanFunction(intercept=1.0, slope=2.0) +plt.plot(X, meanf(X)) + +# %% From abfeec6718e4a281d5d6019be86626f312080dbc Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Fri, 19 Jul 2024 08:15:14 +0200 Subject: [PATCH 02/17] Add backend doc --- docs/examples/backend.py | 122 +++++++++++++++++++++++++++++++++++---- mkdocs.yml | 1 + 2 files changed, 111 insertions(+), 12 deletions(-) diff --git a/docs/examples/backend.py b/docs/examples/backend.py index 35aacbd67..7102d2e2b 100644 --- a/docs/examples/backend.py +++ b/docs/examples/backend.py @@ -22,6 +22,11 @@ # # %% +# Enable Float64 for more stable matrix inversions. +from jax import config, grad + +config.update("jax_enable_x64", True) + import jax.numpy as jnp import matplotlib as mpl import matplotlib.pyplot as plt @@ -84,33 +89,88 @@ # We see here that the Softplus bijector is specified as the default for strictly positive parameters. To apply this, we may invoke the following # %% -transform(constant_param, DEFAULT_BIJECTION) +transform(constant_param, DEFAULT_BIJECTION, inverse=True) + +# %% [markdown] +# The parameter's value was changed here from 1. to 0.54132485. This is the result of applying the Softplus bijector to the parameter's value and projecting its value onto the real line. Were the parameter's value to be closer to 0, then the transformation would be more pronounced. + +# %% +transform(PositiveReal(value=1e-6), DEFAULT_BIJECTION, inverse=True) # %% [markdown] -# ### Transforming Multiple Parameters +# ### Transforming Multiple Parameters # -# In the above, we transformed a single parameter. However, in practice your parameters may be nested within several functions e.g., a kernel function within a GP model. Fortunately, transforming several parameters is a simple operation that we here demonstrate for a regular GP poster +# In the above, we transformed a single parameter. However, in practice your parameters may be nested within several functions e.g., a kernel function within a GP model. Fortunately, transforming several parameters is a simple operation that we here demonstrate for a conjugate GP posterior (see our [Regression Notebook](https://docs.jaxgaussianprocesses.com/examples/regression/) for detailed explanation of this model.). # %% kernel = gpx.kernels.Matern32() meanf = gpx.mean_functions.Constant() -prior = gpx.gps.Prior(meanf, kernel) - +prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel) likelihood = gpx.likelihoods.Gaussian(100) posterior = likelihood * prior +posterior # %% [markdown] -# # Backend Module Design -# -# Since v0.9, GPJax is built upon Flax's [NNX](https://flax.readthedocs.io/en/latest/nnx/index.html) module. This transition allows for more efficient parameter handling, improved integration with Flax and Flax-based libraries, and enhanced flexibility in model design. This notebook provides a high-level overview of the backend module design in GPJax. For an introduction to NNX, please refer to the [official documentation](https://flax.readthedocs.io/en/latest/nnx/index.html). +# Now contained within the posterior PyGraph here there are four parameters: the kernel's lengthscale and variance, the noise variance of the likelihood, and the constant of the mean function. Using NNX, we may realise these parameters through the `nnx.split` function. The `split` function deomposes a PyGraph into a `GraphDef` and `State` object. As the name suggests, `State` contains information on the parameters' state, whilst `GraphDef` contains the information required to reconstruct a PyGraph from a give `State`. + +# %% +from flax import nnx + +graphdef, state = nnx.split(posterior) +state + +# %% [markdown] +# The `State` object behaves just like a PyTree and, consequently, we may use JAX's `tree_map` function to alter the values of the `State`. The updated `State` can then be used to reconstruct our posterior. In the below, we simply increment each parameter's value by 1. + +# %% +import jax.tree_util as jtu + +updated_state = jtu.tree_map(lambda x: x + 1, state) +updated_state + +# %% [markdown] +# Let us now use NNX's `merge` function to reconstruct the posterior distribution using the updated state. + +# %% +updated_posterior = nnx.merge(graphdef, updated_state) +updated_posterior + +# %% [markdown] +# However, we begun this point of conversation with bijectors in mind, so let us now see how bijectors may be applied to a collection of parameters in GPJax. Fortunately, this is very straightforward, and we may simply use the `trasnform` function as before. + +# %% +transformed_state = transform(state, DEFAULT_BIJECTION, inverse=True) +transformed_state + +# %% [markdown] +# We may also (re-)constrain the parameters' values by setting the `inverse` argument of `transform` to False. + +# %% +retransformed_state = transform(transformed_state, DEFAULT_BIJECTION, inverse=False) +retransformed_state == transformed_state + +# %% [markdown] +# ### Fine-Scale Control # +# One of the advantages of being able to split and re-merge the PyGraph is that we are able to gain fine-scale control over the parameters' whose state we wish to realise. This is by virtue of the fact that each of our parameters now inherit from `gpjax.parameters.Parameter`. In the former, we were simply extracting and `Parameter` from the posterior. However, suppose we only wish to extract those parameters whose support is the positive real line. This is easily achieved by altering the way in which we invoke `nnx.split`. + +# %% +from gpjax.parameters import PositiveReal + +graphdef, positive_reals, other_params = nnx.split(posterior, PositiveReal, ...) +print(positive_reals) + +# %% [markdown] +# Now we see that we have two state objects: one containing the positive real parameters and the other containing the remaining parameters. This functionality is exceptionally useful as it allows us to efficiently operate on a subset of the parameters whilst leaving the others untouched. Looking forward, we hope to use this functionality in our [Variational Inference Approximations](https://docs.jaxgaussianprocesses.com/examples/uncollapsed_vi/) to perform more efficient updates of the variational parameters and then the model's hyperparameters. # %% [markdown] # ## NNX Modules # -# To demonstrate the ease of use and flexibility of NNX modules, we will implement a linear mean function using the existing abstractions in GPJax. For inputs $x_n \in \mathbb{R}^d$, the linear mean function $m(x): \mathbb{R}^d \to \mathbb{R}$ is defined as: +# To conclude this notebook, we will now demonstrate the ease of use and flexibility offered by NNX modules. To do this, we will implement a linear mean function using the existing abstractions in GPJax. +# +# For inputs $x_n \in \mathbb{R}^d$, the linear mean function $m(x): \mathbb{R}^d \to \mathbb{R}$ is defined as: # $$ # m(x) = \alpha + \sum_{i=1}^d \beta_i x_i # $$ @@ -122,7 +182,7 @@ from jaxtyping import Float, Num from gpjax.mean_functions import AbstractMeanFunction -from gpjax.parameters import Real, Parameter +from gpjax.parameters import Parameter, Real from gpjax.typing import ScalarFloat, Array @@ -147,13 +207,51 @@ def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N O"]: # %% [markdown] -# As we can see, the implementation is straightforward and concise. The `AbstractMeanFunction` module is a subclass of `nnx.Module`. From here, we inform the module about the parameters +# As we can see, the implementation is straightforward and concise. The `AbstractMeanFunction` module is a subclass of `nnx.Module` and may, therefore, be used in any `split` or `merge` call. Further, we have registered the intercept and slope parameters as `Real` parameter types. This registers their value in the PyGraph and means that they will be part of any operation applied to the PyGraph e.g., transforming and differentiation. # +# To check our implementation worked, let's now plot the value of our mean function for a linearly spaced set of inputs. # %% -X = jnp.linspace(-5.0, 5.0, 100)[:, None] +N = 100 +X = jnp.linspace(-5.0, 5.0, N)[:, None] meanf = LinearMeanFunction(intercept=1.0, slope=2.0) plt.plot(X, meanf(X)) +# %% [markdown] +# Looks good! To conclude this section, let's now parameterise a GP with our new mean function and see how gradients may be computed. + +# %% +y = jnp.sin(X) +D = gpx.Dataset(X, y) + +prior = gpx.gps.Prior(mean_function=meanf, kernel=gpx.kernels.Matern32()) +likelihood = gpx.likelihoods.Gaussian(D.n) +posterior = likelihood * prior + +# %% [markdown] +# We'll compute derivatives of the conjugate marginal log-likelihood, with respect to the unconstrained state of the kernel, mean function, and likelihood parameters. + +# %% +graphdef, params, others = nnx.split(posterior, Parameter, ...) +params = transform(params, DEFAULT_BIJECTION) + + +def loss_fn(params: nnx.State, data: gpx.Dataset) -> ScalarFloat: + params = transform(params, DEFAULT_BIJECTION) + model = nnx.merge(graphdef, params, *others) + return -gpx.objectives.conjugate_mll(model, data) + + +grad(loss_fn)(params, D) + +# %% [markdown] +# ## Conclusions +# +# In this notebook we have explored how GPJax's Flax-based backend may be easily manipulated and extended. For a more applied look at this, see how we construct a kernel on polar coordinated in our [Kernel Guide](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel) notebook. +# +# ## System configuration + # %% +# %reload_ext watermark +# %watermark -n -u -v -iv -w -a 'Thomas Pinder' diff --git a/mkdocs.yml b/mkdocs.yml index f9c1a95f9..03492cb39 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -33,6 +33,7 @@ nav: - 📖 Guides for customisation: - Kernels: examples/constructing_new_kernels.md - Likelihoods: examples/likelihoods_guide.md + - Model Guide: examples/backend.md - UCI regression: examples/yacht.md - 💻 Raw tutorial code: give_me_the_code.md - Community: From 0d9e41ef75d93f95202e2a0724f901d553b4f999 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Fri, 19 Jul 2024 08:17:31 +0200 Subject: [PATCH 03/17] Add backend doc --- docs/examples/backend.py | 124 +++++++++++++++++++++++++++++++-------- 1 file changed, 100 insertions(+), 24 deletions(-) diff --git a/docs/examples/backend.py b/docs/examples/backend.py index 7102d2e2b..6190feb72 100644 --- a/docs/examples/backend.py +++ b/docs/examples/backend.py @@ -18,7 +18,13 @@ # %% [markdown] # # Backend Module Design # -# Since v0.9, GPJax is built upon Flax's [NNX](https://flax.readthedocs.io/en/latest/nnx/index.html) module. This transition allows for more efficient parameter handling, improved integration with Flax and Flax-based libraries, and enhanced flexibility in model design. This notebook provides a high-level overview of the backend module design in GPJax. For an introduction to NNX, please refer to the [official documentation](https://flax.readthedocs.io/en/latest/nnx/index.html). +# Since v0.9, GPJax is built upon Flax's +# [NNX](https://flax.readthedocs.io/en/latest/nnx/index.html) module. This transition +# allows for more efficient parameter handling, improved integration with Flax and +# Flax-based libraries, and enhanced flexibility in model design. This notebook provides +# a high-level overview of the backend module design in GPJax. For an introduction to +# NNX, please refer to the [official +# documentation](https://flax.readthedocs.io/en/latest/nnx/index.html). # # %% @@ -40,7 +46,14 @@ # %% [markdown] # ## Parameters # -# The biggest change bought about by the transition to an NNX backend is the increased support we now provide for handling parameters. As discussed in our [Sharp Bits - Bijectors Doc](https://docs.jaxgaussianprocesses.com/sharp_bits/#bijectors), GPJax uses bijectors to transform constrained parameters to unconstrained parameters during optimisation. You may now register the support of a parameter using our `Parameter` class. To see this, consider the constant mean function who contains a single constant parameter whose value ordinarily exists on the real line. We can register this parameter as follows: +# The biggest change bought about by the transition to an NNX backend is the increased +# support we now provide for handling parameters. As discussed in our [Sharp Bits - +# Bijectors Doc](https://docs.jaxgaussianprocesses.com/sharp_bits/#bijectors), GPJax +# uses bijectors to transform constrained parameters to unconstrained parameters during +# optimisation. You may now register the support of a parameter using our `Parameter` +# class. To see this, consider the constant mean function who contains a single constant +# parameter whose value ordinarily exists on the real line. We can register this +# parameter as follows: # %% from gpjax.mean_functions import Constant @@ -51,7 +64,8 @@ meanf # %% [markdown] -# However, suppose you wish your mean function's constant parameter to be strictly positive. This is easy to achieve by using the correct Parameter type. +# However, suppose you wish your mean function's constant parameter to be strictly +# positive. This is easy to achieve by using the correct Parameter type. # %% from gpjax.parameters import PositiveReal @@ -61,7 +75,8 @@ meanf # %% [markdown] -# Were we to try and instantiate the `PositiveReal` class with a negative value, then an explicit error would be raised. +# Were we to try and instantiate the `PositiveReal` class with a negative value, then an +# explicit error would be raised. # %% try: @@ -72,13 +87,23 @@ # %% [markdown] # ### Parameter Transforms # -# With a parameter instantiated, you likely wish to transform the parameter's value from its constrained support onto the entire real line. To do this, you can apply the `transform` function to the parameter. To control the bijector used to transform the parameter, you may pass a set of bijectors into the transform function. Under-the-hood, the `transform` function is looking up the bijector of a parameter using it's `_tag` field in the bijector dictionary, and then applying the bijector to the parameter's value using a tree map operation. +# With a parameter instantiated, you likely wish to transform the parameter's value from +# its constrained support onto the entire real line. To do this, you can apply the +# `transform` function to the parameter. To control the bijector used to transform the +# parameter, you may pass a set of bijectors into the transform function. +# Under-the-hood, the `transform` function is looking up the bijector of a parameter +# using it's `_tag` field in the bijector dictionary, and then applying the bijector to +# the parameter's value using a tree map operation. # %% print(constant_param._tag) # %% [markdown] -# For most users, you will not need to worry about this as we provide a set of default bijectors that are defined for all the parameter types we support. However, see our [Kernel Guide Notebook](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/) to see how you can define your own bijectors and parameter types. +# For most users, you will not need to worry about this as we provide a set of default +# bijectors that are defined for all the parameter types we support. However, see our +# [Kernel Guide +# Notebook](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/) to +# see how you can define your own bijectors and parameter types. # %% from gpjax.parameters import DEFAULT_BIJECTION, transform @@ -86,13 +111,17 @@ print(DEFAULT_BIJECTION[constant_param._tag]) # %% [markdown] -# We see here that the Softplus bijector is specified as the default for strictly positive parameters. To apply this, we may invoke the following +# We see here that the Softplus bijector is specified as the default for strictly +# positive parameters. To apply this, we may invoke the following # %% transform(constant_param, DEFAULT_BIJECTION, inverse=True) # %% [markdown] -# The parameter's value was changed here from 1. to 0.54132485. This is the result of applying the Softplus bijector to the parameter's value and projecting its value onto the real line. Were the parameter's value to be closer to 0, then the transformation would be more pronounced. +# The parameter's value was changed here from 1. to 0.54132485. This is the result of +# applying the Softplus bijector to the parameter's value and projecting its value onto +# the real line. Were the parameter's value to be closer to 0, then the transformation +# would be more pronounced. # %% transform(PositiveReal(value=1e-6), DEFAULT_BIJECTION, inverse=True) @@ -100,7 +129,12 @@ # %% [markdown] # ### Transforming Multiple Parameters # -# In the above, we transformed a single parameter. However, in practice your parameters may be nested within several functions e.g., a kernel function within a GP model. Fortunately, transforming several parameters is a simple operation that we here demonstrate for a conjugate GP posterior (see our [Regression Notebook](https://docs.jaxgaussianprocesses.com/examples/regression/) for detailed explanation of this model.). +# In the above, we transformed a single parameter. However, in practice your parameters +# may be nested within several functions e.g., a kernel function within a GP model. +# Fortunately, transforming several parameters is a simple operation that we here +# demonstrate for a conjugate GP posterior (see our [Regression +# Notebook](https://docs.jaxgaussianprocesses.com/examples/regression/) for detailed +# explanation of this model.). # %% kernel = gpx.kernels.Matern32() @@ -113,7 +147,13 @@ posterior # %% [markdown] -# Now contained within the posterior PyGraph here there are four parameters: the kernel's lengthscale and variance, the noise variance of the likelihood, and the constant of the mean function. Using NNX, we may realise these parameters through the `nnx.split` function. The `split` function deomposes a PyGraph into a `GraphDef` and `State` object. As the name suggests, `State` contains information on the parameters' state, whilst `GraphDef` contains the information required to reconstruct a PyGraph from a give `State`. +# Now contained within the posterior PyGraph here there are four parameters: the +# kernel's lengthscale and variance, the noise variance of the likelihood, and the +# constant of the mean function. Using NNX, we may realise these parameters through the +# `nnx.split` function. The `split` function deomposes a PyGraph into a `GraphDef` and +# `State` object. As the name suggests, `State` contains information on the parameters' +# state, whilst `GraphDef` contains the information required to reconstruct a PyGraph +# from a give `State`. # %% from flax import nnx @@ -122,7 +162,10 @@ state # %% [markdown] -# The `State` object behaves just like a PyTree and, consequently, we may use JAX's `tree_map` function to alter the values of the `State`. The updated `State` can then be used to reconstruct our posterior. In the below, we simply increment each parameter's value by 1. +# The `State` object behaves just like a PyTree and, consequently, we may use JAX's +# `tree_map` function to alter the values of the `State`. The updated `State` can then +# be used to reconstruct our posterior. In the below, we simply increment each +# parameter's value by 1. # %% import jax.tree_util as jtu @@ -131,21 +174,25 @@ updated_state # %% [markdown] -# Let us now use NNX's `merge` function to reconstruct the posterior distribution using the updated state. +# Let us now use NNX's `merge` function to reconstruct the posterior distribution using +# the updated state. # %% updated_posterior = nnx.merge(graphdef, updated_state) updated_posterior # %% [markdown] -# However, we begun this point of conversation with bijectors in mind, so let us now see how bijectors may be applied to a collection of parameters in GPJax. Fortunately, this is very straightforward, and we may simply use the `trasnform` function as before. +# However, we begun this point of conversation with bijectors in mind, so let us now see +# how bijectors may be applied to a collection of parameters in GPJax. Fortunately, this +# is very straightforward, and we may simply use the `trasnform` function as before. # %% transformed_state = transform(state, DEFAULT_BIJECTION, inverse=True) transformed_state # %% [markdown] -# We may also (re-)constrain the parameters' values by setting the `inverse` argument of `transform` to False. +# We may also (re-)constrain the parameters' values by setting the `inverse` argument of +# `transform` to False. # %% retransformed_state = transform(transformed_state, DEFAULT_BIJECTION, inverse=False) @@ -154,7 +201,13 @@ # %% [markdown] # ### Fine-Scale Control # -# One of the advantages of being able to split and re-merge the PyGraph is that we are able to gain fine-scale control over the parameters' whose state we wish to realise. This is by virtue of the fact that each of our parameters now inherit from `gpjax.parameters.Parameter`. In the former, we were simply extracting and `Parameter` from the posterior. However, suppose we only wish to extract those parameters whose support is the positive real line. This is easily achieved by altering the way in which we invoke `nnx.split`. +# One of the advantages of being able to split and re-merge the PyGraph is that we are +# able to gain fine-scale control over the parameters' whose state we wish to realise. +# This is by virtue of the fact that each of our parameters now inherit from +# `gpjax.parameters.Parameter`. In the former, we were simply extracting and `Parameter` +# from the posterior. However, suppose we only wish to extract those parameters whose +# support is the positive real line. This is easily achieved by altering the way in +# which we invoke `nnx.split`. # %% from gpjax.parameters import PositiveReal @@ -163,18 +216,29 @@ print(positive_reals) # %% [markdown] -# Now we see that we have two state objects: one containing the positive real parameters and the other containing the remaining parameters. This functionality is exceptionally useful as it allows us to efficiently operate on a subset of the parameters whilst leaving the others untouched. Looking forward, we hope to use this functionality in our [Variational Inference Approximations](https://docs.jaxgaussianprocesses.com/examples/uncollapsed_vi/) to perform more efficient updates of the variational parameters and then the model's hyperparameters. +# Now we see that we have two state objects: one containing the positive real parameters +# and the other containing the remaining parameters. This functionality is exceptionally +# useful as it allows us to efficiently operate on a subset of the parameters whilst +# leaving the others untouched. Looking forward, we hope to use this functionality in +# our [Variational Inference +# Approximations](https://docs.jaxgaussianprocesses.com/examples/uncollapsed_vi/) to +# perform more efficient updates of the variational parameters and then the model's +# hyperparameters. # %% [markdown] # ## NNX Modules # -# To conclude this notebook, we will now demonstrate the ease of use and flexibility offered by NNX modules. To do this, we will implement a linear mean function using the existing abstractions in GPJax. +# To conclude this notebook, we will now demonstrate the ease of use and flexibility +# offered by NNX modules. To do this, we will implement a linear mean function using the +# existing abstractions in GPJax. # -# For inputs $x_n \in \mathbb{R}^d$, the linear mean function $m(x): \mathbb{R}^d \to \mathbb{R}$ is defined as: +# For inputs $x_n \in \mathbb{R}^d$, the linear mean function $m(x): \mathbb{R}^d \to +# \mathbb{R}$ is defined as: # $$ # m(x) = \alpha + \sum_{i=1}^d \beta_i x_i # $$ -# where $\alpha \in \mathbb{R}$ and $\beta_i \in \mathbb{R}$ are the parameters of the mean function. Let's now implement that using the new NNX backend. +# where $\alpha \in \mathbb{R}$ and $\beta_i \in \mathbb{R}$ are the parameters of the +# mean function. Let's now implement that using the new NNX backend. # %% import typing as tp @@ -207,9 +271,15 @@ def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N O"]: # %% [markdown] -# As we can see, the implementation is straightforward and concise. The `AbstractMeanFunction` module is a subclass of `nnx.Module` and may, therefore, be used in any `split` or `merge` call. Further, we have registered the intercept and slope parameters as `Real` parameter types. This registers their value in the PyGraph and means that they will be part of any operation applied to the PyGraph e.g., transforming and differentiation. +# As we can see, the implementation is straightforward and concise. The +# `AbstractMeanFunction` module is a subclass of `nnx.Module` and may, therefore, be +# used in any `split` or `merge` call. Further, we have registered the intercept and +# slope parameters as `Real` parameter types. This registers their value in the PyGraph +# and means that they will be part of any operation applied to the PyGraph e.g., +# transforming and differentiation. # -# To check our implementation worked, let's now plot the value of our mean function for a linearly spaced set of inputs. +# To check our implementation worked, let's now plot the value of our mean function for +# a linearly spaced set of inputs. # %% N = 100 @@ -219,7 +289,8 @@ def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N O"]: plt.plot(X, meanf(X)) # %% [markdown] -# Looks good! To conclude this section, let's now parameterise a GP with our new mean function and see how gradients may be computed. +# Looks good! To conclude this section, let's now parameterise a GP with our new mean +# function and see how gradients may be computed. # %% y = jnp.sin(X) @@ -230,7 +301,8 @@ def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N O"]: posterior = likelihood * prior # %% [markdown] -# We'll compute derivatives of the conjugate marginal log-likelihood, with respect to the unconstrained state of the kernel, mean function, and likelihood parameters. +# We'll compute derivatives of the conjugate marginal log-likelihood, with respect to +# the unconstrained state of the kernel, mean function, and likelihood parameters. # %% graphdef, params, others = nnx.split(posterior, Parameter, ...) @@ -248,7 +320,11 @@ def loss_fn(params: nnx.State, data: gpx.Dataset) -> ScalarFloat: # %% [markdown] # ## Conclusions # -# In this notebook we have explored how GPJax's Flax-based backend may be easily manipulated and extended. For a more applied look at this, see how we construct a kernel on polar coordinated in our [Kernel Guide](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel) notebook. +# In this notebook we have explored how GPJax's Flax-based backend may be easily +# manipulated and extended. For a more applied look at this, see how we construct a +# kernel on polar coordinated in our [Kernel +# Guide](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel) +# notebook. # # ## System configuration From 031b0a6cee0e2449b1016cbbcd5fa2c07359175a Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Thu, 15 Aug 2024 22:03:36 +0200 Subject: [PATCH 04/17] Respond to comments --- docs/examples/backend.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/docs/examples/backend.py b/docs/examples/backend.py index 6190feb72..231a09a54 100644 --- a/docs/examples/backend.py +++ b/docs/examples/backend.py @@ -61,7 +61,7 @@ constant_param = Real(value=1.0) meanf = Constant(constant_param) -meanf +print(meanf) # %% [markdown] # However, suppose you wish your mean function's constant parameter to be strictly @@ -72,7 +72,7 @@ constant_param = PositiveReal(value=1.0) meanf = Constant(constant_param) -meanf +print(meanf) # %% [markdown] # Were we to try and instantiate the `PositiveReal` class with a negative value, then an @@ -144,7 +144,7 @@ likelihood = gpx.likelihoods.Gaussian(100) posterior = likelihood * prior -posterior +print(posterior) # %% [markdown] # Now contained within the posterior PyGraph here there are four parameters: the @@ -159,7 +159,7 @@ from flax import nnx graphdef, state = nnx.split(posterior) -state +print(state) # %% [markdown] # The `State` object behaves just like a PyTree and, consequently, we may use JAX's @@ -171,7 +171,7 @@ import jax.tree_util as jtu updated_state = jtu.tree_map(lambda x: x + 1, state) -updated_state +print(updated_state) # %% [markdown] # Let us now use NNX's `merge` function to reconstruct the posterior distribution using @@ -179,16 +179,16 @@ # %% updated_posterior = nnx.merge(graphdef, updated_state) -updated_posterior +print(updated_posterior) # %% [markdown] # However, we begun this point of conversation with bijectors in mind, so let us now see # how bijectors may be applied to a collection of parameters in GPJax. Fortunately, this -# is very straightforward, and we may simply use the `trasnform` function as before. +# is very straightforward, and we may simply use the `transform` function as before. # %% transformed_state = transform(state, DEFAULT_BIJECTION, inverse=True) -transformed_state +print(transformed_state) # %% [markdown] # We may also (re-)constrain the parameters' values by setting the `inverse` argument of @@ -196,7 +196,7 @@ # %% retransformed_state = transform(transformed_state, DEFAULT_BIJECTION, inverse=False) -retransformed_state == transformed_state +print(retransformed_state == transformed_state) # %% [markdown] # ### Fine-Scale Control @@ -204,10 +204,10 @@ # One of the advantages of being able to split and re-merge the PyGraph is that we are # able to gain fine-scale control over the parameters' whose state we wish to realise. # This is by virtue of the fact that each of our parameters now inherit from -# `gpjax.parameters.Parameter`. In the former, we were simply extracting and `Parameter` -# from the posterior. However, suppose we only wish to extract those parameters whose -# support is the positive real line. This is easily achieved by altering the way in -# which we invoke `nnx.split`. +# `gpjax.parameters.Parameter`. In the former, we were simply extracting any +# `Parameter`subclass from the posterior. However, suppose we only wish to extract those +# parameters whose support is the positive real line. This is easily achieved by +# altering the way in which we invoke `nnx.split`. # %% from gpjax.parameters import PositiveReal @@ -306,7 +306,7 @@ def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N O"]: # %% graphdef, params, others = nnx.split(posterior, Parameter, ...) -params = transform(params, DEFAULT_BIJECTION) +params = transform(params, DEFAULT_BIJECTION, inverse=True) def loss_fn(params: nnx.State, data: gpx.Dataset) -> ScalarFloat: @@ -322,7 +322,7 @@ def loss_fn(params: nnx.State, data: gpx.Dataset) -> ScalarFloat: # # In this notebook we have explored how GPJax's Flax-based backend may be easily # manipulated and extended. For a more applied look at this, see how we construct a -# kernel on polar coordinated in our [Kernel +# kernel on polar coordinates in our [Kernel # Guide](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel) # notebook. # From 0c073ac2c9ea365c89c74b078a0e448d3efe28e5 Mon Sep 17 00:00:00 2001 From: Francesco Zanetta Date: Fri, 16 Aug 2024 16:46:36 +0200 Subject: [PATCH 05/17] add scikit-learn dependency for docs --- poetry.lock | 72 ++++++++++++++++++++++++++++++++++++++++++++++++-- pyproject.toml | 1 + 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/poetry.lock b/poetry.lock index 1141e1264..3477e7fd4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "absl-py" @@ -1158,6 +1158,7 @@ description = "Python AST that abstracts the underlying Python version" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" files = [ + {file = "gast-0.6.0-py3-none-any.whl", hash = "sha256:52b182313f7330389f72b069ba00f174cfe2a06411099547288839c6cbafbd54"}, {file = "gast-0.6.0.tar.gz", hash = "sha256:88fc5300d32c7ac6ca7b515310862f71e6fdf2c029bbec7c66c0f5dd47b6b1fb"}, ] @@ -1515,6 +1516,17 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "joblib" +version = "1.4.2" +description = "Lightweight pipelining with Python functions" +optional = false +python-versions = ">=3.8" +files = [ + {file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"}, + {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"}, +] + [[package]] name = "json5" version = "0.9.25" @@ -3723,6 +3735,51 @@ files = [ {file = "ruff-0.6.0.tar.gz", hash = "sha256:272a81830f68f9bd19d49eaf7fa01a5545c5a2e86f32a9935bb0e4bb9a1db5b8"}, ] +[[package]] +name = "scikit-learn" +version = "1.5.1" +description = "A set of python modules for machine learning and data mining" +optional = false +python-versions = ">=3.9" +files = [ + {file = "scikit_learn-1.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:781586c414f8cc58e71da4f3d7af311e0505a683e112f2f62919e3019abd3745"}, + {file = "scikit_learn-1.5.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:f5b213bc29cc30a89a3130393b0e39c847a15d769d6e59539cd86b75d276b1a7"}, + {file = "scikit_learn-1.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1ff4ba34c2abff5ec59c803ed1d97d61b036f659a17f55be102679e88f926fac"}, + {file = "scikit_learn-1.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:161808750c267b77b4a9603cf9c93579c7a74ba8486b1336034c2f1579546d21"}, + {file = "scikit_learn-1.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:10e49170691514a94bb2e03787aa921b82dbc507a4ea1f20fd95557862c98dc1"}, + {file = "scikit_learn-1.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:154297ee43c0b83af12464adeab378dee2d0a700ccd03979e2b821e7dd7cc1c2"}, + {file = "scikit_learn-1.5.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b5e865e9bd59396220de49cb4a57b17016256637c61b4c5cc81aaf16bc123bbe"}, + {file = "scikit_learn-1.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:909144d50f367a513cee6090873ae582dba019cb3fca063b38054fa42704c3a4"}, + {file = "scikit_learn-1.5.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:689b6f74b2c880276e365fe84fe4f1befd6a774f016339c65655eaff12e10cbf"}, + {file = "scikit_learn-1.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:9a07f90846313a7639af6a019d849ff72baadfa4c74c778821ae0fad07b7275b"}, + {file = "scikit_learn-1.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5944ce1faada31c55fb2ba20a5346b88e36811aab504ccafb9f0339e9f780395"}, + {file = "scikit_learn-1.5.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:0828673c5b520e879f2af6a9e99eee0eefea69a2188be1ca68a6121b809055c1"}, + {file = "scikit_learn-1.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:508907e5f81390e16d754e8815f7497e52139162fd69c4fdbd2dfa5d6cc88915"}, + {file = "scikit_learn-1.5.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97625f217c5c0c5d0505fa2af28ae424bd37949bb2f16ace3ff5f2f81fb4498b"}, + {file = "scikit_learn-1.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:da3f404e9e284d2b0a157e1b56b6566a34eb2798205cba35a211df3296ab7a74"}, + {file = "scikit_learn-1.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:88e0672c7ac21eb149d409c74cc29f1d611d5158175846e7a9c2427bd12b3956"}, + {file = "scikit_learn-1.5.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:7b073a27797a283187a4ef4ee149959defc350b46cbf63a84d8514fe16b69855"}, + {file = "scikit_learn-1.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b59e3e62d2be870e5c74af4e793293753565c7383ae82943b83383fdcf5cc5c1"}, + {file = "scikit_learn-1.5.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1bd8d3a19d4bd6dc5a7d4f358c8c3a60934dc058f363c34c0ac1e9e12a31421d"}, + {file = "scikit_learn-1.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:5f57428de0c900a98389c4a433d4a3cf89de979b3aa24d1c1d251802aa15e44d"}, + {file = "scikit_learn-1.5.1.tar.gz", hash = "sha256:0ea5d40c0e3951df445721927448755d3fe1d80833b0b7308ebff5d2a45e6414"}, +] + +[package.dependencies] +joblib = ">=1.2.0" +numpy = ">=1.19.5" +scipy = ">=1.6.0" +threadpoolctl = ">=3.1.0" + +[package.extras] +benchmark = ["matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "pandas (>=1.1.5)"] +build = ["cython (>=3.0.10)", "meson-python (>=0.16.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)"] +docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.23)", "pooch (>=1.6.0)", "pydata-sphinx-theme (>=0.15.3)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=7.3.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-gallery (>=0.16.0)", "sphinx-prompt (>=1.4.0)", "sphinx-remove-toctrees (>=1.0.0.post1)", "sphinxcontrib-sass (>=0.3.4)", "sphinxext-opengraph (>=0.9.1)"] +examples = ["matplotlib (>=3.3.4)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)"] +install = ["joblib (>=1.2.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)", "threadpoolctl (>=3.1.0)"] +maintenance = ["conda-lock (==2.5.6)"] +tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.23)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.2.1)", "scikit-image (>=0.17.2)"] + [[package]] name = "scipy" version = "1.14.0" @@ -3915,6 +3972,17 @@ files = [ ml-dtypes = ">=0.3.1" numpy = ">=1.22.0" +[[package]] +name = "threadpoolctl" +version = "3.5.0" +description = "threadpoolctl" +optional = false +python-versions = ">=3.8" +files = [ + {file = "threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467"}, + {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"}, +] + [[package]] name = "tinycss2" version = "1.3.0" @@ -4213,4 +4281,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "4b2ae7bad45029e7becc027913e0adbf40d21a2632971f4eb50c3a4096f20766" +content-hash = "99d22602c5c323f3ea78b4a80ca493069b946cea47b23d0a6e932c2900c385a4" diff --git a/pyproject.toml b/pyproject.toml index 82a757eda..cc1b64125 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ pandas = "^1.5.3" pymdown-extensions = "^10.7.1" nbconvert = "^7.16.2" markdown-katex = "^202406.1035" +scikit-learn = "^1.5.1" [build-system] requires = ["poetry-core"] From 91022b987afd8e334ac5d9f0d0760eb810270d16 Mon Sep 17 00:00:00 2001 From: Francesco Zanetta Date: Fri, 16 Aug 2024 16:48:49 +0200 Subject: [PATCH 06/17] bugfix: change directory before running jupytext --- docs/scripts/gen_examples.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/scripts/gen_examples.py b/docs/scripts/gen_examples.py index 8fbe71b46..6f4bcd513 100644 --- a/docs/scripts/gen_examples.py +++ b/docs/scripts/gen_examples.py @@ -28,7 +28,7 @@ def process_file(file: Path, out_file: Path | None = None, execute: bool = False f"| jupyter nbconvert --to markdown --execute --stdin --output {out_file}" ) else: - command = f"jupytext --to markdown {file} --output {out_file}" + command += f"jupytext --to markdown {file} --output {out_file}" subprocess.run(command, shell=True, check=False) From 7d73c274477a08ec5e8c023a5c88fbb1a1583dd2 Mon Sep 17 00:00:00 2001 From: Francesco Zanetta Date: Fri, 16 Aug 2024 16:50:47 +0200 Subject: [PATCH 07/17] use local mpl style file --- examples/backend.py | 8 +++++--- examples/barycentres.py | 5 ++--- examples/bayesian_optimisation.py | 9 ++++++--- examples/classification.py | 15 +++++++-------- examples/collapsed_vi.py | 11 +++++++---- examples/constructing_new_kernels.py | 12 ++++++++---- examples/decision_making.py | 11 ++++++++--- examples/deep_kernels.py | 12 ++++++++---- examples/graph_kernels.py | 12 ++++++++---- examples/intro_to_gps.py | 11 +++++++---- examples/intro_to_kernels.py | 8 +++++--- examples/likelihoods_guide.py | 10 ++++++---- examples/oceanmodelling.py | 12 +++++++----- examples/poisson.py | 11 +++++++---- examples/regression.py | 6 +++--- examples/uncollapsed_vi.py | 12 ++++++++---- examples/utils.py | 5 +++++ examples/yacht.py | 11 ++++++----- 18 files changed, 113 insertions(+), 68 deletions(-) diff --git a/examples/backend.py b/examples/backend.py index 231a09a54..48a8dd8af 100644 --- a/examples/backend.py +++ b/examples/backend.py @@ -38,9 +38,11 @@ import matplotlib.pyplot as plt import gpjax as gpx -plt.style.use( - "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) +from examples.utils import use_mpl_style + +# set the default style for plotting +use_mpl_style() + cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] diff --git a/examples/barycentres.py b/examples/barycentres.py index 83e664de0..62e06753b 100644 --- a/examples/barycentres.py +++ b/examples/barycentres.py @@ -48,13 +48,12 @@ with install_import_hook("gpjax", "beartype.beartype"): import gpjax as gpx +from examples.utils import use_mpl_style key = jr.key(123) # set the default style for plotting -plt.style.use( - "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) +use_mpl_style() cols = plt.rcParams["axes.prop_cycle"].by_key()["color"] diff --git a/examples/bayesian_optimisation.py b/examples/bayesian_optimisation.py index fb4e6a475..ac660a693 100644 --- a/examples/bayesian_optimisation.py +++ b/examples/bayesian_optimisation.py @@ -44,10 +44,13 @@ from gpjax.typing import Array, FunctionalSample, ScalarFloat from jaxopt import ScipyBoundedMinimize +from examples.utils import use_mpl_style + +# set the default style for plotting +use_mpl_style() + key = jr.key(42) -plt.style.use( - "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) + cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] diff --git a/examples/classification.py b/examples/classification.py index 837c37237..7a97528cb 100644 --- a/examples/classification.py +++ b/examples/classification.py @@ -15,10 +15,6 @@ # name: python3 # --- -# %% -# %load_ext autoreload -# %autoreload 2 - # %% [markdown] # # Classification # @@ -54,12 +50,15 @@ with install_import_hook("gpjax", "beartype.beartype"): import gpjax as gpx +from examples.utils import use_mpl_style + tfd = tfp.distributions identity_matrix = jnp.eye -key = jr.key(123) -plt.style.use( - "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) + +# set the default style for plotting +use_mpl_style() + +key = jr.key(42) cols = plt.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] diff --git a/examples/collapsed_vi.py b/examples/collapsed_vi.py index b397e4c21..959515920 100644 --- a/examples/collapsed_vi.py +++ b/examples/collapsed_vi.py @@ -42,10 +42,13 @@ with install_import_hook("gpjax", "beartype.beartype"): import gpjax as gpx -key = jr.key(123) -plt.style.use( - "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) +from examples.utils import use_mpl_style + +# set the default style for plotting +use_mpl_style() + +key = jr.key(42) + cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] diff --git a/examples/constructing_new_kernels.py b/examples/constructing_new_kernels.py index 7936d7774..96d085259 100644 --- a/examples/constructing_new_kernels.py +++ b/examples/constructing_new_kernels.py @@ -40,11 +40,15 @@ with install_import_hook("gpjax", "beartype.beartype"): import gpjax as gpx -key = jr.key(123) +from examples.utils import use_mpl_style + tfb = tfp.bijectors -plt.style.use( - "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) + +# set the default style for plotting +use_mpl_style() + +key = jr.key(42) + cols = plt.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] diff --git a/examples/decision_making.py b/examples/decision_making.py index 7cc97f13e..e281e55d4 100644 --- a/examples/decision_making.py +++ b/examples/decision_making.py @@ -65,10 +65,15 @@ Float, ) + +from examples.utils import use_mpl_style + +# set the default style for plotting +use_mpl_style() + key = jr.key(42) -plt.style.use( - "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) + + cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] diff --git a/examples/deep_kernels.py b/examples/deep_kernels.py index 3370e9d16..b41538c39 100644 --- a/examples/deep_kernels.py +++ b/examples/deep_kernels.py @@ -58,12 +58,16 @@ import gpjax as gpx from gpjax.kernels.base import AbstractKernel -key = jr.key(123) -plt.style.use( - "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) + +from examples.utils import use_mpl_style + +# set the default style for plotting +use_mpl_style() cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] +key = jr.key(42) + + # %% [markdown] # ## Dataset # diff --git a/examples/graph_kernels.py b/examples/graph_kernels.py index 4cd397795..e9a28b9a7 100644 --- a/examples/graph_kernels.py +++ b/examples/graph_kernels.py @@ -42,10 +42,14 @@ with install_import_hook("gpjax", "beartype.beartype"): import gpjax as gpx -key = jr.key(123) -plt.style.use( - "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) + +from examples.utils import use_mpl_style + +# set the default style for plotting +use_mpl_style() + +key = jr.key(42) + cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] diff --git a/examples/intro_to_gps.py b/examples/intro_to_gps.py index e04ec9fff..9fc0eeae9 100644 --- a/examples/intro_to_gps.py +++ b/examples/intro_to_gps.py @@ -121,11 +121,14 @@ import pandas as pd import seaborn as sns import tensorflow_probability.substrates.jax as tfp -from docs.examples.utils import confidence_ellipse +from examples.utils import confidence_ellipse, use_mpl_style + +# set the default style for plotting +use_mpl_style() + +key = jr.key(42) + -plt.style.use( - "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] tfd = tfp.distributions diff --git a/examples/intro_to_kernels.py b/examples/intro_to_kernels.py index 8387b1c8b..396aa9361 100644 --- a/examples/intro_to_kernels.py +++ b/examples/intro_to_kernels.py @@ -40,10 +40,12 @@ from gpjax.typing import Array from sklearn.preprocessing import StandardScaler +from examples.utils import use_mpl_style + key = jr.key(42) -plt.style.use( - "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) + +# set the default style for plotting +use_mpl_style() cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] diff --git a/examples/likelihoods_guide.py b/examples/likelihoods_guide.py index 2bff2fdfe..10597a265 100644 --- a/examples/likelihoods_guide.py +++ b/examples/likelihoods_guide.py @@ -78,13 +78,15 @@ import matplotlib.pyplot as plt import tensorflow_probability.substrates.jax as tfp +from examples.utils import use_mpl_style + tfd = tfp.distributions -plt.style.use( - "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) + +# set the default style for plotting +use_mpl_style() cols = plt.rcParams["axes.prop_cycle"].by_key()["color"] -key = jr.key(123) +key = jr.key(42) n = 50 x = jnp.sort(jr.uniform(key=key, shape=(n, 1), minval=-3.0, maxval=3.0), axis=0) diff --git a/examples/oceanmodelling.py b/examples/oceanmodelling.py index 7422a8f0a..d46e89af8 100644 --- a/examples/oceanmodelling.py +++ b/examples/oceanmodelling.py @@ -45,11 +45,13 @@ with install_import_hook("gpjax", "beartype.beartype"): import gpjax as gpx -# Enable Float64 for more stable matrix inversions. -key = jr.key(123) -plt.style.use( - "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) +from examples.utils import use_mpl_style + +# set the default style for plotting +use_mpl_style() + +key = jr.key(42) + colors = rcParams["axes.prop_cycle"].by_key()["color"] diff --git a/examples/poisson.py b/examples/poisson.py index 1c59b0fef..284cb54c8 100644 --- a/examples/poisson.py +++ b/examples/poisson.py @@ -38,15 +38,18 @@ with install_import_hook("gpjax", "beartype.beartype"): import gpjax as gpx +from examples.utils import use_mpl_style + # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) tfd = tfp.distributions -key = jr.key(123) -plt.style.use( - "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) + +# set the default style for plotting +use_mpl_style() cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] +key = jr.key(42) + # %% [markdown] # ## Dataset # diff --git a/examples/regression.py b/examples/regression.py index bf777b1e4..c5ef2d50e 100644 --- a/examples/regression.py +++ b/examples/regression.py @@ -35,12 +35,12 @@ with install_import_hook("gpjax", "beartype.beartype"): import gpjax as gpx +from examples.utils import use_mpl_style + key = jr.key(123) # set the default style for plotting -plt.style.use( - "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) +use_mpl_style() cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] diff --git a/examples/uncollapsed_vi.py b/examples/uncollapsed_vi.py index 21d51f4c2..073819a6d 100644 --- a/examples/uncollapsed_vi.py +++ b/examples/uncollapsed_vi.py @@ -48,13 +48,17 @@ import gpjax as gpx import gpjax.kernels as jk -key = jr.key(123) +from examples.utils import use_mpl_style + tfb = tfp.bijectors -plt.style.use( - "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) + +key = jr.key(123) + +# set the default style for plotting +use_mpl_style() cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] + # %% [markdown] # ## Dataset # diff --git a/examples/utils.py b/examples/utils.py index b9ccaa65f..b8d5a81f8 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -73,3 +73,8 @@ def clean_legend(ax): by_label = dict(zip(labels, handles)) ax.legend(by_label.values(), by_label.keys()) return ax + + +def use_mpl_style(): + style_file = Path(__file__).parent / "gpjax.mplstyle" + plt.style.use(style_file) diff --git a/examples/yacht.py b/examples/yacht.py index 5e9ef0e72..940dff153 100644 --- a/examples/yacht.py +++ b/examples/yacht.py @@ -46,13 +46,14 @@ with install_import_hook("gpjax", "beartype.beartype"): import gpjax as gpx -# Enable Float64 for more stable matrix inversions. -key = jr.key(123) -plt.style.use( - "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) +from examples.utils import use_mpl_style + +# set the default style for plotting +use_mpl_style() cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] +key = jr.key(42) + # %% [markdown] # ## Data Loading # From c07ead4f5bb7342670d673ae43cee6a11ce29fae Mon Sep 17 00:00:00 2001 From: Francesco Zanetta Date: Fri, 16 Aug 2024 16:54:38 +0200 Subject: [PATCH 08/17] do not use MCMC for classification (it is *very* slow) --- README.md | 3 +- examples/classification.py | 184 +------------------------------------ 2 files changed, 2 insertions(+), 185 deletions(-) diff --git a/README.md b/README.md index 8ebf160da..721d4d747 100644 --- a/README.md +++ b/README.md @@ -72,10 +72,9 @@ helped to shape GPJax into the package it is today. ## Notebook examples > - [**Conjugate Inference**](https://docs.jaxgaussianprocesses.com/examples/regression/) -> - [**Classification with MCMC**](https://docs.jaxgaussianprocesses.com/examples/classification/) +> - [**Classification**](https://docs.jaxgaussianprocesses.com/examples/classification/) > - [**Sparse Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/collapsed_vi/) > - [**Stochastic Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/uncollapsed_vi/) -> - [**BlackJax Integration**](https://docs.jaxgaussianprocesses.com/examples/classification/#mcmc-inference) > - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/examples/classification/#laplace-approximation) > - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel) > - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/examples/graph_kernels/) diff --git a/examples/classification.py b/examples/classification.py index 7a97528cb..c53cf4354 100644 --- a/examples/classification.py +++ b/examples/classification.py @@ -19,9 +19,7 @@ # # Classification # # In this notebook we demonstrate how to perform inference for Gaussian process models -# with non-Gaussian likelihoods via maximum a posteriori (MAP) and Markov chain Monte -# Carlo (MCMC). We focus on a classification task here and use -# [BlackJax](https://github.com/blackjax-devs/blackjax/) for sampling. +# with non-Gaussian likelihoods via maximum a posteriori (MAP). We focus on a classification task here. # %% # Enable Float64 for more stable matrix inversions. @@ -319,186 +317,6 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma ) ax.legend() -# %% [markdown] -# However, the Laplace approximation is still limited by considering information about -# the posterior at a single location. On the other hand, through approximate sampling, -# MCMC methods allow us to learn all information about the posterior distribution. - -# %% [markdown] -# ## MCMC inference -# -# An MCMC sampler works by starting at an initial position and -# drawing a sample from a cheap-to-simulate distribution known as the _proposal_. The -# next step is to determine whether this sample could be considered a draw from the -# posterior. We accomplish this using an _acceptance probability_ determined via the -# sampler's _transition kernel_ which depends on the current position and the -# unnormalised target posterior distribution. If the new sample is more _likely_, we -# accept it; otherwise, we reject it and stay in our current position. Repeating these -# steps results in a Markov chain (a random sequence that depends only on the last -# state) whose stationary distribution (the long-run empirical distribution of the -# states visited) is the posterior. For a gentle introduction, see the first chapter -# of [A Handbook of Markov Chain Monte Carlo](https://www.mcmchandbook.net/HandbookChapter1.pdf). -# -# ### MCMC through BlackJax -# -# Rather than implementing a suite of MCMC samplers, GPJax relies on MCMC-specific -# libraries for sampling functionality. We focus on -# [BlackJax](https://github.com/blackjax-devs/blackjax/) in this notebook, which we -# recommend adopting for general applications. -# -# We'll use the No U-Turn Sampler (NUTS) implementation given in BlackJax for sampling. -# For the interested reader, NUTS is a Hamiltonian Monte Carlo sampling scheme where -# the number of leapfrog integration steps is computed at each step of the change -# according to the NUTS algorithm. In general, samplers constructed under this -# framework are very efficient. -# -# We begin by generating _sensible_ initial positions for our sampler before defining -# an inference loop and sampling 500 values from our Markov chain. In practice, -# drawing more samples will be necessary. - -# %% -num_adapt = 600 -num_samples = 600 - -graphdef, params, *static_state = nnx.split(posterior, gpx.parameters.Parameter, ...) -params_bijection = gpx.parameters.DEFAULT_BIJECTION - -# Transform the parameters to the unconstrained space -params = gpx.parameters.transform(params, params_bijection, inverse=True) - - -def logprob_fn(params): - params = gpx.parameters.transform(params, params_bijection) - model = nnx.merge(graphdef, params, *static_state) - return gpx.objectives.log_posterior_density(model, D) - - -# jit compile -logprob_fn = jax.jit(logprob_fn) -_ = logprob_fn(params) - -adapt = blackjax.window_adaptation( - blackjax.nuts, logprob_fn, num_adapt, target_acceptance_rate=0.65, progress_bar=True -) - -# Initialise the chain -start = time() -last_state, kernel, _ = adapt.run(key, params) -print(f"Adaption time taken: {time() - start: .1f} seconds") - - -def inference_loop(rng_key, kernel, initial_state, num_samples): - def one_step(state, rng_key): - state, info = kernel(rng_key, state) - return state, (state, info) - - keys = jax.random.split(rng_key, num_samples) - _, (states, infos) = jax.lax.scan(one_step, initial_state, keys, unroll=10) - - return states, infos - - -# Sample from the posterior distribution -start = time() -states, infos = inference_loop(key, kernel, last_state, num_samples) -print(f"Sampling time taken: {time() - start: .1f} seconds") - -# %% [markdown] -# ### Sampler efficiency -# -# BlackJax gives us easy access to our sampler's efficiency through metrics such as the -# sampler's _acceptance probability_ (the number of times that our chain accepted a -# proposed sample, divided by the total number of steps run by the chain). For NUTS and -# Hamiltonian Monte Carlo sampling, we typically seek an acceptance rate of 60-70% to -# strike the right balance between having a chain which is _stuck_ and rarely moves -# versus a chain that is too jumpy with frequent small steps. - -# %% -acceptance_rate = jnp.mean(infos.acceptance_probability) -print(f"Acceptance rate: {acceptance_rate:.2f}") - -# %% [markdown] -# Our acceptance rate is slightly too large, prompting an examination of the chain's -# trace plots. A well-mixing chain will have very few (if any) flat spots in its trace -# plot whilst also not having too many steps in the same direction. In addition to -# the model's hyperparameters, there will be 500 samples for each of the 100 latent -# function values in the `states.position` dictionary. We depict the chains that -# correspond to the model hyperparameters and the first value of the latent function -# for brevity. - -# %% -fig, (ax0, ax1, ax2) = plt.subplots(ncols=3, figsize=(10, 3)) -ax0.plot(states.position.prior.kernel.lengthscale.value) -ax1.plot(states.position.prior.kernel.variance.value) -ax2.plot(states.position.latent.value[:, 1, :]) -ax0.set_title("Kernel Lengthscale") -ax1.set_title("Kernel Variance") -ax2.set_title("Latent Function (index = 1)") - -# %% [markdown] -# ## Prediction -# -# Having obtained samples from the posterior, we draw ten instances from our model's -# predictive distribution per MCMC sample. Using these draws, we will be able to -# compute credible values and expected values under our posterior distribution. -# -# An ideal Markov chain would have samples completely uncorrelated with their -# neighbours after a single lag. However, in practice, correlations often exist -# within our chain's sample set. A commonly used technique to try and reduce this -# correlation is _thinning_ whereby we select every $n$th sample where $n$ is the -# minimum lag length at which we believe the samples are uncorrelated. Although further -# analysis of the chain's autocorrelation is required to find appropriate thinning -# factors, we employ a thin factor of 10 for demonstration purposes. - -# %% -thin_factor = 20 -posterior_samples = [] - -for i in trange(0, num_samples, thin_factor, desc="Drawing posterior samples"): - sample_params = jtu.tree_map(lambda samples, i=i: samples[i], states.position) - sample_params = gpx.parameters.transform(sample_params, params_bijection) - model = nnx.merge(graphdef, sample_params, *static_state) - latent_dist = model.predict(xtest, train_data=D) - predictive_dist = model.likelihood(latent_dist) - posterior_samples.append(predictive_dist.sample(seed=key, sample_shape=(10,))) - -posterior_samples = jnp.vstack(posterior_samples) -lower_ci, upper_ci = jnp.percentile(posterior_samples, jnp.array([2.5, 97.5]), axis=0) -expected_val = jnp.mean(posterior_samples, axis=0) - -# %% [markdown] -# -# Finally, we end this tutorial by plotting the predictions obtained from our model -# against the observed data. - -# %% -fig, ax = plt.subplots() -ax.scatter(x, y, color=cols[0], label="Observations", zorder=2, alpha=0.7) -ax.plot(xtest, expected_val, color=cols[1], label="Predicted mean", zorder=1) -ax.fill_between( - xtest.flatten(), - lower_ci.flatten(), - upper_ci.flatten(), - alpha=0.2, - color=cols[1], - label="95\\% CI", -) -ax.plot( - xtest, - lower_ci.flatten(), - color=cols[1], - linestyle="--", - linewidth=1, -) -ax.plot( - xtest, - upper_ci.flatten(), - color=cols[1], - linestyle="--", - linewidth=1, -) -ax.legend() - # %% [markdown] # ## System configuration From 4ebc673debe7085d01a5dede88ed66f64ee02b0b Mon Sep 17 00:00:00 2001 From: Francesco Zanetta Date: Fri, 16 Aug 2024 17:00:08 +0200 Subject: [PATCH 09/17] [skip-ci] update github workflows for docs --- .github/workflows/build_docs.yml | 7 +------ .github/workflows/test_docs.yml | 16 +--------------- 2 files changed, 2 insertions(+), 21 deletions(-) diff --git a/.github/workflows/build_docs.yml b/.github/workflows/build_docs.yml index 433d4e345..265cc536c 100644 --- a/.github/workflows/build_docs.yml +++ b/.github/workflows/build_docs.yml @@ -52,16 +52,11 @@ jobs: virtualenvs-in-project: false installer-parallel: true - - name: Install LaTex - run: | - sudo apt-get update - sudo apt-get install texlive-fonts-recommended texlive-fonts-extra texlive-latex-extra dvipng cm-super - - name: Build the documentation with MKDocs run: | poetry install --all-extras --with docs conda install pandoc - poetry run mkdocs build + poetry run python docs/scripts/gen_examples.py --execute && poetry run mkdocs build - name: Deploy Page 🚀 uses: JamesIves/github-pages-deploy-action@v4.4.1 diff --git a/.github/workflows/test_docs.yml b/.github/workflows/test_docs.yml index 9dc9c55d5..129de0e39 100644 --- a/.github/workflows/test_docs.yml +++ b/.github/workflows/test_docs.yml @@ -33,20 +33,6 @@ jobs: auto-update-conda: true python-version: ${{ matrix.python-version }} - # Install katex for math support - - name: Install NPM - uses: actions/setup-node@v3 - with: - node-version: 16 - - name: Install KaTeX - run: | - npm install katex - - - name: Install LaTex - run: | - sudo apt-get update - sudo apt-get install texlive-fonts-recommended texlive-fonts-extra texlive-latex-extra dvipng cm-super - # Install Poetry and build the documentation - name: Install and configure Poetry uses: snok/install-poetry@v1 @@ -60,4 +46,4 @@ jobs: run: | poetry install --all-extras --with docs conda install pandoc - poetry run python docs/scripts/gen_examples.py && poetry run mkdocs build + poetry run python docs/scripts/gen_examples.py --execute && poetry run mkdocs build From 6450603852d3cb8ed10bfc32534eb3a700eaee8f Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Fri, 16 Aug 2024 17:42:39 +0200 Subject: [PATCH 10/17] Fix split --- gpjax/parameters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 0aa746148..93a356b6c 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -58,7 +58,7 @@ def _inner(param): param = param.replace(transformed_value) return param - gp_params, *other_params = params.split(Parameter, ...) + gp_params, *other_params = nnx.split(Parameter, ...) transformed_gp_params: nnx.State = jtu.tree_map( lambda x: _inner(x), gp_params, From 4ee615ecd65a98bf229da3630c38d35570e6bd34 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Fri, 16 Aug 2024 17:48:34 +0200 Subject: [PATCH 11/17] Fix split --- gpjax/parameters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 93a356b6c..c57dbee39 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -58,7 +58,7 @@ def _inner(param): param = param.replace(transformed_value) return param - gp_params, *other_params = nnx.split(Parameter, ...) + gp_params, *other_params = nnx.split(params, Parameter, ...) transformed_gp_params: nnx.State = jtu.tree_map( lambda x: _inner(x), gp_params, From 2d323b2c00ced99c56363600754ccb472264fe60 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Fri, 16 Aug 2024 17:52:33 +0200 Subject: [PATCH 12/17] Fix split --- gpjax/parameters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gpjax/parameters.py b/gpjax/parameters.py index c57dbee39..3f7144498 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -58,13 +58,13 @@ def _inner(param): param = param.replace(transformed_value) return param - gp_params, *other_params = nnx.split(params, Parameter, ...) + graphdef, gp_params, *other_params = nnx.split(params, Parameter, ...) transformed_gp_params: nnx.State = jtu.tree_map( lambda x: _inner(x), gp_params, is_leaf=lambda x: isinstance(x, nnx.VariableState), ) - return nnx.State.merge(transformed_gp_params, *other_params) + return nnx.merge(graphdef, transformed_gp_params, *other_params) class Parameter(nnx.Variable[T]): From 5b9f0581f7a107e1a4e4846bf6644a6830d7993c Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Fri, 16 Aug 2024 18:09:07 +0200 Subject: [PATCH 13/17] Fix xdoctest --- gpjax/parameters.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 3f7144498..0bb55f4b7 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -31,11 +31,8 @@ def transform( >>> ) >>> params_bijection = {'positive': tfb.Softplus()} >>> transformed_params = transform(params, params_bijection) - >>> transformed_params["a"] - PositiveReal( - value=Array([1.3132617], dtype=float32), - _tag='positive' - ) + >>> print(transformed_params["a"].value) + [1.3132617] ``` From 6fcb83d6fcf1e87b0555f18aea01393d1ad25c92 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Fri, 16 Aug 2024 18:27:01 +0200 Subject: [PATCH 14/17] Fix doc --- examples/backend.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/backend.py b/examples/backend.py index 48a8dd8af..118098d28 100644 --- a/examples/backend.py +++ b/examples/backend.py @@ -117,7 +117,7 @@ # positive parameters. To apply this, we may invoke the following # %% -transform(constant_param, DEFAULT_BIJECTION, inverse=True) +transform(meanf, DEFAULT_BIJECTION, inverse=True) # %% [markdown] # The parameter's value was changed here from 1. to 0.54132485. This is the result of @@ -126,7 +126,7 @@ # would be more pronounced. # %% -transform(PositiveReal(value=1e-6), DEFAULT_BIJECTION, inverse=True) +transform(Constant(PositiveReal(value=1e-6)), DEFAULT_BIJECTION, inverse=True) # %% [markdown] # ### Transforming Multiple Parameters @@ -198,7 +198,6 @@ # %% retransformed_state = transform(transformed_state, DEFAULT_BIJECTION, inverse=False) -print(retransformed_state == transformed_state) # %% [markdown] # ### Fine-Scale Control From 02bfd011b1cb1aaceab1508d3120d762634a2edc Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Fri, 16 Aug 2024 19:05:43 +0200 Subject: [PATCH 15/17] Add serial build --- docs/scripts/gen_examples.py | 32 +++++++++++++++++++------------- mkdocs.yml | 2 +- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/docs/scripts/gen_examples.py b/docs/scripts/gen_examples.py index 6f4bcd513..632fad7fe 100644 --- a/docs/scripts/gen_examples.py +++ b/docs/scripts/gen_examples.py @@ -64,21 +64,26 @@ def main(args): print(files) # process files in parallel - with ThreadPoolExecutor(max_workers=args.max_workers) as executor: - futures = [] - for file in files: - out_file = out_dir / f"{file.stem}.md" - futures.append( - executor.submit( - process_file, file, out_file=out_file, execute=args.execute + if args.parallel: + with ThreadPoolExecutor(max_workers=args.max_workers) as executor: + futures = [] + for file in files: + out_file = out_dir / f"{file.stem}.md" + futures.append( + executor.submit( + process_file, file, out_file=out_file, execute=args.execute + ) ) - ) - for future in as_completed(futures): - try: - future.result() - except Exception as e: - print(f"Error processing file: {e}") + for future in as_completed(futures): + try: + future.result() + except Exception as e: + print(f"Error processing file: {e}") + else: + for file in files: + out_file = out_dir / f"{file.stem}.md" + process_file(file, out_file=out_file, execute=args.execute) if __name__ == "__main__": @@ -91,6 +96,7 @@ def main(args): parser.add_argument( "--outdir", type=Path, default=project_root / "docs" / "_examples" ) + parser.add_argument("--parallel", type=bool, default=False) args = parser.parse_args() main(args) diff --git a/mkdocs.yml b/mkdocs.yml index c7a91fd5d..a1b4728b1 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -32,7 +32,7 @@ nav: - 📖 Guides for customisation: - Kernels: _examples/constructing_new_kernels.md - Likelihoods: _examples/likelihoods_guide.md - - Model Guide: examples/backend.md + - Model Guide: _examples/backend.md - UCI regression: _examples/yacht.md # - 💻 Raw tutorial code: give_me_the_code.md - Community: From 7e628f004c2d29caba764c3bff7fd5509b56ba08 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Fri, 16 Aug 2024 21:09:30 +0200 Subject: [PATCH 16/17] Update parameters transform and backend doc --- examples/backend.py | 69 ++++++++++++++++++++++++++++++++++-------- gpjax/parameters.py | 5 +-- tests/test_markdown.py | 4 +-- 3 files changed, 61 insertions(+), 17 deletions(-) diff --git a/examples/backend.py b/examples/backend.py index 118098d28..fb40f59b3 100644 --- a/examples/backend.py +++ b/examples/backend.py @@ -29,14 +29,31 @@ # %% # Enable Float64 for more stable matrix inversions. -from jax import config, grad +from jax import ( + config, + grad, +) config.update("jax_enable_x64", True) import jax.numpy as jnp +from jaxtyping import ( + Float, + install_import_hook, +) import matplotlib as mpl import matplotlib.pyplot as plt -import gpjax as gpx + +from gpjax.mean_functions import Constant +from gpjax.parameters import ( + Parameter, + Real, +) + +with install_import_hook("gpjax", "beartype.beartype"): + import gpjax as gpx + +from flax import nnx from examples.utils import use_mpl_style @@ -58,20 +75,23 @@ # parameter as follows: # %% -from gpjax.mean_functions import Constant -from gpjax.parameters import Real - -constant_param = Real(value=1.0) +constant_param = Parameter(value=1.0, tag=None) meanf = Constant(constant_param) print(meanf) # %% [markdown] # However, suppose you wish your mean function's constant parameter to be strictly -# positive. This is easy to achieve by using the correct Parameter type. +# positive. This is easy to achieve by using the correct Parameter type which, in this case, will be the `PositiveReal`. However, any Parameter that subclasses from `Parameter` will be transformed by GPJax. # %% from gpjax.parameters import PositiveReal +issubclass(PositiveReal, Parameter) + +# %% [markdown] +# Injecting this newly constrained parameter into our mean function is then identical to before. + +# %% constant_param = PositiveReal(value=1.0) meanf = Constant(constant_param) print(meanf) @@ -114,10 +134,12 @@ # %% [markdown] # We see here that the Softplus bijector is specified as the default for strictly -# positive parameters. To apply this, we may invoke the following +# positive parameters. To apply this, we must first realise the _state_ of our model. This is achieved using the `split` function provided by `nnx`. # %% -transform(meanf, DEFAULT_BIJECTION, inverse=True) +_, _params = nnx.split(meanf, Parameter) + +tranformed_params = transform(_params, DEFAULT_BIJECTION, inverse=True) # %% [markdown] # The parameter's value was changed here from 1. to 0.54132485. This is the result of @@ -126,7 +148,9 @@ # would be more pronounced. # %% -transform(Constant(PositiveReal(value=1e-6)), DEFAULT_BIJECTION, inverse=True) +_, _close_to_zero_state = nnx.split(Constant(PositiveReal(value=1e-6)), Parameter) + +transform(_close_to_zero_state, DEFAULT_BIJECTION, inverse=True) # %% [markdown] # ### Transforming Multiple Parameters @@ -158,8 +182,6 @@ # from a give `State`. # %% -from flax import nnx - graphdef, state = nnx.split(posterior) print(state) @@ -316,7 +338,28 @@ def loss_fn(params: nnx.State, data: gpx.Dataset) -> ScalarFloat: return -gpx.objectives.conjugate_mll(model, data) -grad(loss_fn)(params, D) +param_grads = grad(loss_fn)(params, D) + +# %% [markdown] +# In practice, you would wish to perform multiple iterations of gradient descent to learn the optimal parameter values. However, for the purposes of illustration, we use another `tree_map` in the below to update the parameters' state using their previously computed gradients. As you can see, the really beauty in having access to the model's state is that we have full control over the operations that we perform to the state. + +# %% +LEARNING_RATE = 0.01 +optimised_params = jtu.tree_map( + lambda _params, _grads: _params + LEARNING_RATE * _grads, params, param_grads +) + +# %% [markdown] +# Now we will plot the updated mean function alongside its initial form. To achieve this, we first merge the state back into the model using `merge`, and we then simply invoke the model as normal. + +# %% +optimised_posterior = nnx.merge(graphdef, optimised_params, *others) + +fig, ax = plt.subplots() +ax.plot(X, optimised_posterior.prior.mean_function(X), label="Updated mean function") +ax.plot(X, meanf(X), label="Initial mean function") +ax.legend() +ax.set(xlabel="x", ylabel="m(x)") # %% [markdown] # ## Conclusions diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 0bb55f4b7..c54676dd1 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -55,13 +55,14 @@ def _inner(param): param = param.replace(transformed_value) return param - graphdef, gp_params, *other_params = nnx.split(params, Parameter, ...) + gp_params, *other_params = params.split(Parameter, ...) + transformed_gp_params: nnx.State = jtu.tree_map( lambda x: _inner(x), gp_params, is_leaf=lambda x: isinstance(x, nnx.VariableState), ) - return nnx.merge(graphdef, transformed_gp_params, *other_params) + return nnx.State.merge(transformed_gp_params, *other_params) class Parameter(nnx.Variable[T]): diff --git a/tests/test_markdown.py b/tests/test_markdown.py index d28e90783..e543db8e5 100644 --- a/tests/test_markdown.py +++ b/tests/test_markdown.py @@ -5,12 +5,12 @@ # Ensure that code chunks within any markdown files execute without error -@pytest.mark.parametrize("fpath", pathlib.Path("gpjax/").glob("**/*.md"), ids=str) +@pytest.mark.parametrize("fpath", pathlib.Path("gpjax/").glob("*.md"), ids=str) def test_source_good(fpath): check_md_file(fpath=fpath, memory=True) -@pytest.mark.parametrize("fpath", pathlib.Path("docs").glob("**/*.md"), ids=str) +@pytest.mark.parametrize("fpath", pathlib.Path("docs").glob("*.md"), ids=str) def test_docs_good(fpath): check_md_file(fpath=fpath, memory=True) From 172e61e2d0bb2cdb53cb7e511d7de75d81fb5789 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Fri, 16 Aug 2024 21:09:55 +0200 Subject: [PATCH 17/17] Update parameters transform and backend doc --- examples/backend.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/examples/backend.py b/examples/backend.py index fb40f59b3..8e8bbe3a9 100644 --- a/examples/backend.py +++ b/examples/backend.py @@ -81,7 +81,9 @@ # %% [markdown] # However, suppose you wish your mean function's constant parameter to be strictly -# positive. This is easy to achieve by using the correct Parameter type which, in this case, will be the `PositiveReal`. However, any Parameter that subclasses from `Parameter` will be transformed by GPJax. +# positive. This is easy to achieve by using the correct Parameter type which, in this +# case, will be the `PositiveReal`. However, any Parameter that subclasses from +# `Parameter` will be transformed by GPJax. # %% from gpjax.parameters import PositiveReal @@ -134,7 +136,8 @@ # %% [markdown] # We see here that the Softplus bijector is specified as the default for strictly -# positive parameters. To apply this, we must first realise the _state_ of our model. This is achieved using the `split` function provided by `nnx`. +# positive parameters. To apply this, we must first realise the _state_ of our model. +# This is achieved using the `split` function provided by `nnx`. # %% _, _params = nnx.split(meanf, Parameter) @@ -341,7 +344,11 @@ def loss_fn(params: nnx.State, data: gpx.Dataset) -> ScalarFloat: param_grads = grad(loss_fn)(params, D) # %% [markdown] -# In practice, you would wish to perform multiple iterations of gradient descent to learn the optimal parameter values. However, for the purposes of illustration, we use another `tree_map` in the below to update the parameters' state using their previously computed gradients. As you can see, the really beauty in having access to the model's state is that we have full control over the operations that we perform to the state. +# In practice, you would wish to perform multiple iterations of gradient descent to +# learn the optimal parameter values. However, for the purposes of illustration, we use +# another `tree_map` in the below to update the parameters' state using their previously +# computed gradients. As you can see, the really beauty in having access to the model's +# state is that we have full control over the operations that we perform to the state. # %% LEARNING_RATE = 0.01 @@ -350,7 +357,9 @@ def loss_fn(params: nnx.State, data: gpx.Dataset) -> ScalarFloat: ) # %% [markdown] -# Now we will plot the updated mean function alongside its initial form. To achieve this, we first merge the state back into the model using `merge`, and we then simply invoke the model as normal. +# Now we will plot the updated mean function alongside its initial form. To achieve +# this, we first merge the state back into the model using `merge`, and we then simply +# invoke the model as normal. # %% optimised_posterior = nnx.merge(graphdef, optimised_params, *others)