Add fast computation of functional_variance for DiagLLLaplace and KronLLLaplace #145

Merged: 10 commits into main on Jun 30, 2024

Conversation

@wiseodd (Collaborator) commented on Feb 24, 2024

PR for #138. This is very useful for LLMs, diffusion models, or any model with many outputs.

TODO: implementation for KronLLLaplace. I'd like input from @aleximmer, who is the author of matrix.py.
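
To illustrate the idea for DiagLLLaplace, here is a minimal sketch (hypothetical helper names, not the PR's actual code), assuming a last-layer weight W of shape (C, D) with a diagonal posterior variance of the same shape:

```python
import torch

# Hypothetical sketch, not the PR's implementation. With f = W @ phi and
# an independent (diagonal) posterior over W, the predictive variance per
# class is Var(f_c) = sum_d phi_d^2 * var_W[c, d] -- no Jacobian of shape
# (batch, num_classes, num_params) is ever materialized.
def functional_variance_fast_diag(phi: torch.Tensor, var_W: torch.Tensor) -> torch.Tensor:
    # phi: (batch, D) last-layer features; var_W: (C, D) posterior variances
    return phi.pow(2) @ var_W.T  # (batch, C)
```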

@wiseodd added the enhancement (New feature or request) label on Feb 24, 2024
Review thread on laplace/lllaplace.py (outdated, resolved):
@@ -201,6 +208,40 @@ def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
def _init_H(self):
self.H = Kron.init_from_model(self.model.last_layer, self._device)

def _functional_variance_fast(self, X):
@wiseodd (Collaborator, Author) commented:
@aleximmer here's the initial implementation for KronLLLaplace. However, the test comparing this f_var against f_var = la.posterior_precision.inv_square_form(Js) fails...

Could you please check this? Feel free to propose a more elegant solution, since you know the implementation of Kron better.

@aleximmer (Owner) commented:

What's the advantage you're trying to achieve with this? Is it faster because you use the damping formulation of the posterior update (eigenvalues + sqrt(delta))? I suspect that approximation is what makes the test fail.

@wiseodd (Collaborator, Author) commented:

The idea is to use this identity of the matrix-normal distribution (see https://arxiv.org/pdf/2002.10118.pdf, Appendix B.1):

[Screenshot: the matrix-normal identity from Appendix B.1]
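
A hedged reconstruction of that identity (the factor ordering depends on the vectorization convention, so double-check against the paper): if the last-layer weight $W \in \mathbb{R}^{C \times D}$ has a matrix-normal posterior with $\mathrm{vec}(W) \sim \mathcal{N}(\mathrm{vec}(M), V \otimes U)$, then for $f = W \phi(x)$,

$$\mathrm{Cov}[f] = \left( \phi(x)^\top V \, \phi(x) \right) U.$$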

This is much faster than the naive functional_variance, since we don't need to compute the Jacobian, which has shape (batch_size, num_classes, num_params). We only need to multiply the inverse Kronecker factors with the last-layer features $\phi(x)$.

Let me know your thoughts on the best way to achieve this.

@wiseodd (Collaborator, Author) commented:

PS: the rationale for the sqrt damping: I follow KFAC-Laplace (https://openreview.net/pdf?id=Skdvd2xAZ). Let me know what you think.
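
For readers following along, a rough sketch of what that sqrt damping looks like (illustrative names, not this PR's code); A and G denote the Kronecker curvature factors and delta the prior precision:

```python
import torch

# Illustrative sketch of KFAC-Laplace-style damping. The posterior
# precision A ⊗ G + delta * I is approximated by the Kronecker product
# (A + sqrt(delta) I) ⊗ (G + sqrt(delta) I), which stays Kronecker-
# factored and can therefore be inverted factor by factor.
def damped_inverse_factors(A, G, delta):
    sqrt_delta = delta ** 0.5
    A_inv = torch.inverse(A + sqrt_delta * torch.eye(A.shape[0], device=A.device))
    G_inv = torch.inverse(G + sqrt_delta * torch.eye(G.shape[0], device=G.device))
    return A_inv, G_inv
```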

@aleximmer (Owner) commented:

I looked into this a bit more now. It's a very important change to have, since it's significantly faster. It should probably be implemented via the damping=True/False flag. The exact inversion with a prior has to be done with an eigendecomposition, which is what we do right now; this can be avoided in the fast predictive by using the damping formulation instead. That would also avoid recomputing U and V from the eigendecomposition, and we could add the method to matrix.py. I could look into how best to do this. One thing I'm wondering: do you know if this is also possible for the joint posterior predictive?
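
Combining the matrix-normal identity with the damping formulation, the fast predictive could look roughly like this (a sketch assuming Cov(vec(W)) ≈ A_inv ⊗ G_inv, with A_inv over the D features and G_inv over the C outputs; the actual matrix.py integration may order the factors differently):

```python
import torch

# Hedged sketch of the fast functional variance under a Kronecker-
# factored last-layer posterior: Cov[f(x)] = (phi^T A_inv phi) * G_inv.
def functional_variance_fast_kron(phi, A_inv, G_inv):
    # phi: (batch, D) last-layer features
    s = torch.einsum('bi,ij,bj->b', phi, A_inv, phi)  # (batch,) scalars
    return s[:, None, None] * G_inv                   # (batch, C, C)
```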

@wiseodd (Collaborator, Author) commented:

OK, great. I'll leave it to you then to do this implicitly via damping. The corresponding test case is tests/test_lllaplace.py -k "test_functional_variance_fast[KronLLLaplace]".

@wiseodd (Collaborator, Author) commented:

I think the joint predictive can also benefit from this (see the code for the naive version). However, this needs more thought, so let's do it in a separate PR.
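
For reference, the identity appears to extend to the joint case as well; a hypothetical sketch under the same assumptions as above (my extrapolation, not code from this PR):

```python
import torch

# Hypothetical: joint covariance between two batches of inputs under the
# same Kronecker posterior; block (m, n) couples f(x_m) and f(x_n).
def joint_functional_covariance(phi1, phi2, A_inv, G_inv):
    S = phi1 @ A_inv @ phi2.T            # (M, N) feature-space similarities
    return S[:, :, None, None] * G_inv   # (M, N, C, C) covariance blocks
```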

@runame changed the base branch from main to mc-subset2 on March 1, 2024
@runame changed the base branch from mc-subset2 to main on March 1, 2024
@wiseodd marked this pull request as ready for review on March 3, 2024
@wiseodd merged commit e3ca2c6 into main on June 30, 2024
3 checks passed
@wiseodd deleted the speedup_llla branch on June 30, 2024
Labels: enhancement (New feature or request)
Projects: none yet
Participants: 3