Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flax/nnx backend #440

Merged
merged 212 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
212 commits
Select commit Hold shift + click to select a range
f7ff6b1
add flax v0.8.0 to deps, temporarily from github main branch
frazane Jan 6, 2024
04ce459
main gps objects as nnx modules
frazane Jan 6, 2024
2dc9479
integrators as nnx dataclasses and some static typing refactoring
frazane Jan 6, 2024
7efbd20
likelihoods as nnx dataclasses modules and some static typing refacto…
frazane Jan 6, 2024
ae034e2
small refactoring
frazane Jan 6, 2024
c530481
mean functions as nnx dataclasses modules and some refactoring
frazane Jan 6, 2024
7d14eaf
bugfix
frazane Jan 6, 2024
40891f7
objectives as nnx dataclasses modules
frazane Jan 6, 2024
e749bc6
variational families with nnx
frazane Jan 6, 2024
3568d9c
kernels base with nnx
frazane Jan 6, 2024
28c3b9e
wip stationary kernels
frazane Jan 6, 2024
49cc6e0
wip nonstationary kernels
frazane Jan 6, 2024
9d2567e
wip non euclidean kernels
frazane Jan 6, 2024
f18701c
computations with nnx
frazane Jan 7, 2024
61bf22b
rff with nnx
frazane Jan 7, 2024
f0736f7
bugfix
frazane Jan 7, 2024
64f4f3b
Merge branch 'main' into flax-backend
frazane Feb 25, 2024
6038458
stationary kernels as normal classes
frazane Feb 25, 2024
447dad1
nonstationary kernels as normal classes
frazane Feb 25, 2024
2d343c2
noneuclidean kernels as normal classes
frazane Feb 25, 2024
af33559
rff as standard class + stationary kernel abstract class for static t…
frazane Feb 25, 2024
cfabc4c
started work on parameters
frazane Feb 25, 2024
645c111
more objects as normal classes
frazane Feb 25, 2024
c536b1a
gps as normal classes
frazane Feb 25, 2024
cac712e
integrators as normal classes
frazane Feb 25, 2024
f50ffe7
dataset is not a pytree
frazane Feb 25, 2024
7e1a23e
removed superfluous inits
frazane Feb 25, 2024
5d21926
register dataset as pytree
frazane Feb 26, 2024
8133850
use parameters here and there
frazane Feb 26, 2024
4333f0b
set active_dims default to 1
frazane Feb 26, 2024
acd72a5
start working on tests
frazane Feb 26, 2024
e748711
active_dims defaults to None
frazane Mar 1, 2024
003f380
rewrite objectives as functions
frazane Mar 1, 2024
1a132ac
black + isort
frazane Mar 1, 2024
0167d74
remove objective from cite
frazane Mar 2, 2024
135e130
fix dataset repr
frazane Mar 2, 2024
4f26690
pass tests for variational families
frazane Mar 2, 2024
95ab9cd
active_dims defaults to None
frazane Mar 4, 2024
174e9a6
use generic Objective type
frazane Mar 4, 2024
8820991
small fixes
frazane Mar 4, 2024
6aadcd6
make 'active_dims' required parameter, fix static typing and beartype…
frazane Mar 5, 2024
eb5c0de
pass tests/test_kernels/test_computation.py
frazane Mar 5, 2024
5a70fc4
rewrite tests for nonstationary kernels + pass tests
frazane Mar 5, 2024
a3ca7f4
adapt to nnx's explicit variables + miscellaneous fixes
frazane Mar 5, 2024
7bbfcfb
rewrite of objectives as simple functions, [WIP] started rewriting tests
frazane Mar 5, 2024
22077b7
rewrite and pass tests for objectives
frazane Mar 5, 2024
d8011b9
rewrite fit function
frazane Mar 6, 2024
36bc8fd
remove gpjax.base module
frazane Mar 6, 2024
77cfd88
remove base module tests
frazane Mar 6, 2024
c8b70bd
rewrite and pass tests for fit
frazane Mar 6, 2024
e6ff0e1
finish kernels and pass all tests
frazane Mar 6, 2024
7748ba6
pass all tests except decision making
frazane Mar 6, 2024
fc9a1a7
pass all tests 🚀
frazane Mar 6, 2024
130698c
update and run classification notebook (python cells)
frazane Mar 6, 2024
4d0f7f8
pass doctests
frazane Mar 8, 2024
f2dfabe
pass integration tests, more checks to parameters
frazane Mar 8, 2024
e90869c
linting and formatting
frazane Mar 8, 2024
ddfffab
update barycentres and classification examples
frazane Mar 8, 2024
62844d9
update project files
frazane Mar 8, 2024
8782915
update ruff and make it happy
frazane Mar 8, 2024
d52530e
lint + format all doc examples
frazane Mar 8, 2024
15132dd
[skip ci] change how dimensions are specified for kernels, update ker…
frazane Mar 12, 2024
466e748
[skip ci] api reference looks pretty now, implemented template patter…
frazane Mar 13, 2024
700a406
[skip ci] wip - fixing math rendering in documentation - almost there
frazane Mar 13, 2024
30ed414
Update notebooks. (#447)
daniel-dodd Mar 22, 2024
a08d690
Update likelihoods.py (#446)
daniel-dodd Mar 22, 2024
7c2a450
Adding tagged parameters and updated notebooks
thomaspinder Jun 5, 2024
d6b86ef
Update likelihoods.py (#446)
daniel-dodd Mar 22, 2024
890d054
Merge
thomaspinder Jun 14, 2024
1b4d7f6
Update notebooks
thomaspinder Jun 14, 2024
e2e1fd6
Merge branch 'tagged-params' of github.com:JaxGaussianProcesses/GPJax…
thomaspinder Jun 14, 2024
32488a2
Fix linting
thomaspinder Jun 14, 2024
15c84db
Fix missing dep.
thomaspinder Jun 14, 2024
fdbf378
Fix integration test
thomaspinder Jun 14, 2024
9cd1ac5
Readd docs deps
thomaspinder Jun 14, 2024
63d2c32
Fix docstrings
thomaspinder Jun 14, 2024
3a7ee7f
Update lockfile
thomaspinder Jun 14, 2024
b503ead
Update parameter refs
thomaspinder Jun 14, 2024
037ec0a
Fix broken tests
thomaspinder Jun 14, 2024
740510e
Remove PyTrees doc
thomaspinder Jun 14, 2024
306190b
Failing split order
thomaspinder Jun 25, 2024
c359936
Merge pull request #452 from JaxGaussianProcesses/tagged-params
thomaspinder Jun 26, 2024
65bd81a
NNX update
thomaspinder Jun 27, 2024
9495947
Merge pull request #457 from JaxGaussianProcesses/nnx_update
thomaspinder Jul 1, 2024
adfc3f0
add flax v0.8.0 to deps, temporarily from github main branch
frazane Jan 6, 2024
01c2fd7
main gps objects as nnx modules
frazane Jan 6, 2024
fe2c214
integrators as nnx dataclasses and some static typing refactoring
frazane Jan 6, 2024
ba7f175
likelihoods as nnx dataclasses modules and some static typing refacto…
frazane Jan 6, 2024
3502bb7
small refactoring
frazane Jan 6, 2024
e399709
mean functions as nnx dataclasses modules and some refactoring
frazane Jan 6, 2024
9b1a2a4
bugfix
frazane Jan 6, 2024
6406809
objectives as nnx dataclasses modules
frazane Jan 6, 2024
4ad04d1
variational families with nnx
frazane Jan 6, 2024
63ca736
kernels base with nnx
frazane Jan 6, 2024
f45c9c9
wip stationary kernels
frazane Jan 6, 2024
bfbcf6d
wip nonstationary kernels
frazane Jan 6, 2024
e3a0c5e
wip non euclidean kernels
frazane Jan 6, 2024
6faf852
computations with nnx
frazane Jan 7, 2024
1e0a5a9
rff with nnx
frazane Jan 7, 2024
1c25f23
bugfix
frazane Jan 7, 2024
a4a9410
stationary kernels as normal classes
frazane Feb 25, 2024
cce9646
nonstationary kernels as normal classes
frazane Feb 25, 2024
e9d4d45
noneuclidean kernels as normal classes
frazane Feb 25, 2024
975f2df
rff as standard class + stationary kernel abstract class for static t…
frazane Feb 25, 2024
7e82af6
started work on parameters
frazane Feb 25, 2024
af09dc7
more objects as normal classes
frazane Feb 25, 2024
76fc0bf
gps as normal classes
frazane Feb 25, 2024
ca9fc24
integrators as normal classes
frazane Feb 25, 2024
e69b84c
dataset is not a pytree
frazane Feb 25, 2024
58e036c
removed superfluous inits
frazane Feb 25, 2024
79df530
register dataset as pytree
frazane Feb 26, 2024
bd8e5ae
use parameters here and there
frazane Feb 26, 2024
24e1aa6
set active_dims default to 1
frazane Feb 26, 2024
c889cbe
start working on tests
frazane Feb 26, 2024
86a25ea
active_dims defaults to None
frazane Mar 1, 2024
39aaace
rewrite objectives as functions
frazane Mar 1, 2024
651f070
black + isort
frazane Mar 1, 2024
49758b3
remove objective from cite
frazane Mar 2, 2024
49c2c6d
fix dataset repr
frazane Mar 2, 2024
91e063d
pass tests for variational families
frazane Mar 2, 2024
bedee48
active_dims defaults to None
frazane Mar 4, 2024
87e2302
use generic Objective type
frazane Mar 4, 2024
8a04aa0
small fixes
frazane Mar 4, 2024
e72ba7b
make 'active_dims' required parameter, fix static typing and beartype…
frazane Mar 5, 2024
7a3be54
pass tests/test_kernels/test_computation.py
frazane Mar 5, 2024
6237548
rewrite tests for nonstationary kernels + pass tests
frazane Mar 5, 2024
fb0acdb
adapt to nnx's explicit variables + miscellaneous fixes
frazane Mar 5, 2024
6d0ead9
rewrite of objectives as simple functions, [WIP] started rewriting tests
frazane Mar 5, 2024
192dbc4
rewrite and pass tests for objectives
frazane Mar 5, 2024
63dc0a2
rewrite fit function
frazane Mar 6, 2024
44fce75
remove gpjax.base module
frazane Mar 6, 2024
4f32d19
remove base module tests
frazane Mar 6, 2024
5e1b54c
rewrite and pass tests for fit
frazane Mar 6, 2024
e720333
finish kernels and pass all tests
frazane Mar 6, 2024
b60a5e4
pass all tests except decision making
frazane Mar 6, 2024
840947e
pass all tests 🚀
frazane Mar 6, 2024
dd31205
update and run classification notebook (python cells)
frazane Mar 6, 2024
99969d5
pass doctests
frazane Mar 8, 2024
87af55c
pass integration tests, more checks to parameters
frazane Mar 8, 2024
a08baac
linting and formatting
frazane Mar 8, 2024
5275608
update barycentres and classification examples
frazane Mar 8, 2024
ed7a513
update project files
frazane Mar 8, 2024
1cf1e40
update ruff and make it happy
frazane Mar 8, 2024
c333e51
lint + format all doc examples
frazane Mar 8, 2024
39948a8
[skip ci] change how dimensions are specified for kernels, update ker…
frazane Mar 12, 2024
c16e7ff
[skip ci] api reference looks pretty now, implemented template patter…
frazane Mar 13, 2024
2d3951a
[skip ci] wip - fixing math rendering in documentation - almost there
frazane Mar 13, 2024
1783be0
Update notebooks. (#447)
daniel-dodd Mar 22, 2024
6b21121
Update likelihoods.py (#446)
daniel-dodd Mar 22, 2024
f360071
Update notebooks
thomaspinder Jun 14, 2024
62f9edf
Adding tagged parameters and updated notebooks
thomaspinder Jun 5, 2024
8ec9c6c
Fix linting
thomaspinder Jun 14, 2024
b00f4d7
Fix missing dep.
thomaspinder Jun 14, 2024
ad6d031
Fix integration test
thomaspinder Jun 14, 2024
be93d76
Readd docs deps
thomaspinder Jun 14, 2024
6549a51
Fix docstrings
thomaspinder Jun 14, 2024
4c190ad
Update lockfile
thomaspinder Jun 14, 2024
aa31eb6
Update parameter refs
thomaspinder Jun 14, 2024
7056e37
Fix broken tests
thomaspinder Jun 14, 2024
41b17f3
Remove PyTrees doc
thomaspinder Jun 14, 2024
52fc221
Failing split order
thomaspinder Jun 25, 2024
69a2475
NNX update
thomaspinder Jun 27, 2024
5b63449
Merge branch 'flax-backend-normalclasses' of github.com:JaxGaussianPr…
thomaspinder Jul 9, 2024
225c395
rename static dir
frazane Jul 15, 2024
512eefd
move examples dir in top level
frazane Jul 15, 2024
36791a0
add _examples generated dir to gitignore
frazane Jul 15, 2024
0883abf
update pyproject deps
frazane Jul 15, 2024
165e1b8
update mkdocs config
frazane Jul 15, 2024
ad6f12a
add examples generation script
frazane Jul 15, 2024
d25538f
adapt relative paths in md files
frazane Jul 15, 2024
33c806a
Merge branch 'main' into fix-doc-build
frazane Jul 16, 2024
741f7b4
Update Ruff and incorporate changes
Thomas-Christie Jul 15, 2024
d50f015
update github workflow for building doc, without executing notebookf …
frazane Jul 18, 2024
85a1a74
Add backend doc
thomaspinder Jul 18, 2024
abfeec6
Add backend doc
thomaspinder Jul 19, 2024
0d9e41e
Add backend doc
thomaspinder Jul 19, 2024
1cc2962
Add replace to transform
thomaspinder Aug 9, 2024
684cd8e
Merge pull request #464 from JaxGaussianProcesses/fix-transform
thomaspinder Aug 14, 2024
be8fad5
Merge with main
thomaspinder Aug 15, 2024
e6b83f9
Merge with main
thomaspinder Aug 15, 2024
dac4553
Update parameters docstring
thomaspinder Aug 15, 2024
031b0a6
Respond to comments
thomaspinder Aug 15, 2024
e00aeda
Merge branch 'flax-backend-normalclasses' into backend_doc
thomaspinder Aug 15, 2024
5430b07
Merge with flax backend
thomaspinder Aug 15, 2024
357ffb1
Fix e2e tests
thomaspinder Aug 15, 2024
7a0ad3b
Fix mplstyle refs
thomaspinder Aug 15, 2024
939f903
bump deps
thomaspinder Aug 15, 2024
8ddee1f
Update poetry
thomaspinder Aug 15, 2024
767cef2
Update poetry
thomaspinder Aug 15, 2024
46be074
Merge branch 'fix-doc-build' into backend_doc
thomaspinder Aug 15, 2024
81a6b98
Fix shutil
thomaspinder Aug 15, 2024
e38d5bb
Drop flax base
thomaspinder Aug 15, 2024
0d21a31
Merge pull request #461 from JaxGaussianProcesses/fix-doc-build
thomaspinder Aug 15, 2024
382098c
Rebase
thomaspinder Aug 15, 2024
a5bca61
Merge branch 'backend_doc' of github.com:JaxGaussianProcesses/GPJax i…
thomaspinder Aug 15, 2024
4383d1d
Merge pull request #462 from JaxGaussianProcesses/backend_doc
thomaspinder Aug 15, 2024
0c073ac
add scikit-learn dependency for docs
frazane Aug 16, 2024
91022b9
bugfix: change directory before running jupytext
frazane Aug 16, 2024
7d73c27
use local mpl style file
frazane Aug 16, 2024
c07ead4
do not use MCMC for classification (it is *very* slow)
frazane Aug 16, 2024
4ebc673
[skip-ci] update github workflows for docs
frazane Aug 16, 2024
6450603
Fix split
thomaspinder Aug 16, 2024
4ee615e
Fix split
thomaspinder Aug 16, 2024
2d323b2
Fix split
thomaspinder Aug 16, 2024
5b9f058
Fix xdoctest
thomaspinder Aug 16, 2024
6fcb83d
Fix doc
thomaspinder Aug 16, 2024
02bfd01
Add serial build
thomaspinder Aug 16, 2024
7e628f0
Update parameters transform and backend doc
thomaspinder Aug 16, 2024
172e61e
Update parameters transform and backend doc
thomaspinder Aug 16, 2024
ab86ac5
Merge pull request #465 from JaxGaussianProcesses/fix-doc-build
thomaspinder Aug 16, 2024
cc62521
Bump Python
thomaspinder Aug 16, 2024
20953c6
Merge branch 'flax-backend-normalclasses' of github.com:JaxGaussianPr…
thomaspinder Aug 16, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions .github/workflows/build_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,22 +47,16 @@ jobs:
- name: Install and configure Poetry
uses: snok/install-poetry@v1
with:
version: 1.2.2
version: 1.5.1
virtualenvs-create: false
virtualenvs-in-project: false
installer-parallel: true

- name: Install LaTex
run: |
sudo apt-get update
sudo apt-get install texlive-fonts-recommended texlive-fonts-extra texlive-latex-extra dvipng cm-super

- name: Build the documentation with MKDocs
run: |
cp docs/examples/gpjax.mplstyle .
poetry install --all-extras --with docs
conda install pandoc
poetry run mkdocs build
poetry run python docs/scripts/gen_examples.py --execute && poetry run mkdocs build

- name: Deploy Page 🚀
uses: JamesIves/github-pages-deploy-action@v4.4.1
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
- name: Install Poetry
uses: snok/install-poetry@v1.3.3
with:
version: 1.4.0
version: 1.5.1

# Configure Poetry to use the virtual environment in the project
- name: Setup Poetry
Expand All @@ -39,7 +39,7 @@ jobs:
# Install the dependencies
- name: Install Package
run: |
poetry install --all-extras --with docs
poetry install --with docs

# Run the unit tests and build the coverage report
- name: Run Integration Tests
Expand Down
19 changes: 2 additions & 17 deletions .github/workflows/test_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,32 +33,17 @@ jobs:
auto-update-conda: true
python-version: ${{ matrix.python-version }}

# Install katex for math support
- name: Install NPM
uses: actions/setup-node@v3
with:
node-version: 16
- name: Install KaTeX
run: |
npm install katex

- name: Install LaTex
run: |
sudo apt-get update
sudo apt-get install texlive-fonts-recommended texlive-fonts-extra texlive-latex-extra dvipng cm-super

# Install Poetry and build the documentation
- name: Install and configure Poetry
uses: snok/install-poetry@v1
with:
version: 1.2.2
version: 1.5.1
virtualenvs-create: false
virtualenvs-in-project: false
installer-parallel: true

- name: Build the documentation with MKDocs
run: |
cp docs/examples/gpjax.mplstyle .
poetry install --all-extras --with docs
conda install pandoc
poetry run mkdocs build
poetry run python docs/scripts/gen_examples.py --execute && poetry run mkdocs build
11 changes: 7 additions & 4 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,13 @@ jobs:
python-version: ${{ matrix.python-version }}

# Install Poetry
- name: Install Poetry
uses: snok/install-poetry@v1.3.3
- name: Install and configure Poetry
uses: snok/install-poetry@v1
with:
version: 1.4.0
version: 1.5.1
virtualenvs-create: false
virtualenvs-in-project: false
installer-parallel: true

# Configure Poetry to use the virtual environment in the project
- name: Setup Poetry
Expand All @@ -39,7 +42,7 @@ jobs:
# Install the dependencies
- name: Install Package
run: |
poetry install --with tests
poetry install --with dev

- name: Check docstrings
run: |
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,4 @@ package-lock.json
node_modules/

docs/api
docs/_examples
22 changes: 11 additions & 11 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ repos:
language: system
types: [python]
exclude: examples/
- repo: https://github.com/econchick/interrogate
rev: 1.5.0
hooks:
- id: interrogate
args:
[
"gpjax",
"--config",
"pyproject.toml",
]
pass_filenames: false
# - repo: https://github.com/econchick/interrogate
# rev: 1.5.0
# hooks:
# - id: interrogate
# args:
# [
# "gpjax",
# "--config",
# "pyproject.toml",
# ]
# pass_filenames: false
8 changes: 2 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,9 @@ helped to shape GPJax into the package it is today.
## Notebook examples

> - [**Conjugate Inference**](https://docs.jaxgaussianprocesses.com/examples/regression/)
> - [**Classification with MCMC**](https://docs.jaxgaussianprocesses.com/examples/classification/)
> - [**Classification**](https://docs.jaxgaussianprocesses.com/examples/classification/)
> - [**Sparse Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/collapsed_vi/)
> - [**Stochastic Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/uncollapsed_vi/)
> - [**BlackJax Integration**](https://docs.jaxgaussianprocesses.com/examples/classification/#mcmc-inference)
> - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/examples/classification/#laplace-approximation)
> - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel)
> - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/examples/graph_kernels/)
Expand Down Expand Up @@ -146,13 +145,10 @@ posterior = prior * likelihood
# Define an optimiser
optimiser = ox.adam(learning_rate=1e-2)

# Define the marginal log-likelihood
negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True))

# Obtain Type 2 MLEs of the hyperparameters
opt_posterior, history = gpx.fit(
model=posterior,
objective=negative_mll,
objective=gpx.objectives.conjugate_mll,
train_data=D,
optim=optimiser,
num_iters=500,
Expand Down
Empty file removed benchmarks/__init__.py
Empty file.
25 changes: 0 additions & 25 deletions benchmarks/asv.conf.json

This file was deleted.

99 changes: 0 additions & 99 deletions benchmarks/kernels.py

This file was deleted.

87 changes: 0 additions & 87 deletions benchmarks/objectives.py

This file was deleted.

Loading
Loading