Last Layer Laplace predictions could be computed much faster; they become problematic for classification with many labels #138

Closed
charlio23 opened this issue Nov 13, 2023 · 3 comments · Fixed by #145
Labels: enhancement (New feature or request)

@charlio23

When experimenting with the Last Layer Laplace approximation on a classification task with a large number of labels (e.g. 1K for ImageNet), memory usage rapidly jumps to 10-20 GB and inference becomes slow.

For context, with a last layer of size 128 -> 3100 (where 3100 is the number of labels), each inference step takes ~10 seconds and ~15 GB of memory for a single sample (batch_size=1) using the Diagonal Laplace. The Kronecker Laplace does not fit on the GPU with this setup.

When inspecting the code, I noticed that Last Layer Laplace computes the Jacobian $\phi(x)^\top \otimes I$ explicitly, which is very costly with >1K labels since the resulting matrix has shape (num_labels, num_features * num_labels).

Note: $\phi(x)$ denotes the features of the penultimate (L-1) layer for input x, and $I$ is the (num_labels, num_labels) identity matrix.
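
To put a rough number on that cost, here is a tiny sketch of the explicit Jacobian's shape and the memory it implies at the scale above (the sizes are the ones from my example; the large matrix is only counted, not allocated):

    import torch

    # Tiny example of the explicit Jacobian phi(x)^T (x) I, just to show its shape
    phi = torch.randn(1, 4)                       # 4 features, one sample
    J = torch.kron(phi, torch.eye(3))             # 3 labels
    print(J.shape)                                # torch.Size([3, 12])

    # At the scale reported above (128 features, 3100 labels) the same matrix holds
    # 3100 * (128 * 3100) ~ 1.2e9 entries, i.e. roughly 4.9 GB in float32, per sample,
    # before any GGN/covariance products are even formed.
    num_features, num_labels = 128, 3100
    print(num_labels * num_features * num_labels * 4 / 1e9, "GB")
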

For Last Layer Laplace, I believe this could be implemented more efficiently by exploiting the factorisation of the Jacobian. For example, the posterior predictive variance for the Kronecker Laplace could be implemented as follows:

  • Assuming the inverse posterior precision factorises as $H^{-1} = V \otimes U$, the following identity is well established in the Last Layer Laplace literature (a quick numerical check is sketched below):
    $$\Sigma = \big(\phi(x)^\top V \phi(x)\big)\, U$$
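
As a quick numerical sanity check of that identity, here is a standalone sketch with arbitrary SPD factors standing in for V and U (illustrative only, not library code):

    import torch

    torch.manual_seed(0)
    num_features, num_labels = 5, 3
    phi = torch.randn(num_features)

    # Arbitrary SPD Kronecker factors standing in for H^{-1} = V (x) U
    A = torch.randn(num_features, num_features); V = A @ A.T + torch.eye(num_features)
    B = torch.randn(num_labels, num_labels);     U = B @ B.T + torch.eye(num_labels)

    # Explicit route: J = phi^T (x) I, then Sigma = J (V (x) U) J^T
    J = torch.kron(phi.unsqueeze(0), torch.eye(num_labels))
    Sigma_explicit = J @ torch.kron(V, U) @ J.T

    # Factored route: Sigma = (phi^T V phi) U, i.e. a scalar times U
    Sigma_factored = (phi @ V @ phi) * U

    print(torch.allclose(Sigma_explicit, Sigma_factored, atol=1e-4))  # True
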

Below is some code I used to speed up Last Layer Laplace by modifying the code from the repository. It would be nice if you could consider adopting this modification, as I am sure other users of your package would benefit from it.

Diagonal Laplace

    def _glm_predictive_distribution(self, X):
        f_mu, phi = self.model.forward_with_features(X)
        emb_size = phi.shape[-1]
        # Per-label variance sum_k phi_k^2 * sigma_{c,k}^2, embedded as a diagonal covariance
        f_var = torch.diag_embed(
            torch.matmul(self.posterior_variance.reshape(-1, emb_size), (phi * phi).transpose(0, 1)).transpose(0, 1)
        )
        return f_mu.detach(), f_var.detach()
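
For reference, a minimal standalone check of the same diagonal shortcut with plain tensors (the names phi and posterior_variance here are illustrative; I assume the variance vector flattens the last-layer weight as (num_labels, num_features) row-major, matching the reshape above, which makes the explicit per-sample Jacobian come out as $I \otimes \phi^\top$ under that convention):

    import torch

    torch.manual_seed(0)
    batch, num_features, num_labels = 4, 8, 6
    phi = torch.randn(batch, num_features)                       # last-layer features
    posterior_variance = torch.rand(num_labels * num_features)   # diagonal of H^{-1}

    # Fast path: per-label sums of posterior variances weighted by phi^2
    f_var_fast = torch.diag_embed(
        (posterior_variance.reshape(num_labels, num_features) @ (phi * phi).T).T
    )

    # Explicit path for one sample: J = I (x) phi^T, Sigma = J diag(v) J^T
    J0 = torch.kron(torch.eye(num_labels), phi[0].unsqueeze(0))
    Sigma0 = J0 @ torch.diag(posterior_variance) @ J0.T

    print(torch.allclose(f_var_fast[0], Sigma0, atol=1e-5))  # True
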

Kronecker Laplace

    def _glm_predictive_distribution(self, X):
        f_mu, phi = self.model.forward_with_features(X)

        # Eigendecompositions of the Kronecker factors of the last-layer posterior precision
        eig_U, eig_V = self.posterior_precision.eigenvalues[0]
        vec_U, vec_V = self.posterior_precision.eigenvectors[0]
        delta = self.posterior_precision.deltas.sqrt()
        inv_U_eig, inv_V_eig = torch.pow(eig_U + delta, -1), torch.pow(eig_V + delta, -1)

        # phi^T (V + delta I)^{-1} phi, computed in the eigenbasis of V
        phiT_Q = torch.matmul(phi[:, None, :], vec_V[None, :, :])
        phiTVphi = torch.matmul(torch.matmul(phiT_Q, torch.diag(inv_V_eig)[None, :, :]), phiT_Q.transpose(1, 2))

        # f_var = (phi^T (V + delta I)^{-1} phi) * (U + delta I)^{-1}, cf. the identity above
        f_var = phiTVphi * ((vec_U @ torch.diag(inv_U_eig) @ vec_U.T)[None, :, :])

        return f_mu.detach(), f_var.detach()
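
And a similar standalone check that the eigenbasis manipulation above computes $\big(\phi^\top (V+\delta I)^{-1}\phi\big)\,(U+\delta I)^{-1}$ (arbitrary factors and a single scalar delta; this is illustrative and not tied to the library's KronDecomposed internals):

    import torch

    torch.manual_seed(0)
    num_features, num_labels = 8, 6
    phi = torch.randn(num_features)
    delta = torch.tensor(0.5)   # stand-in for the sqrt-prior-precision damping term

    # Arbitrary symmetric PSD stand-ins for the Kronecker factors of the posterior precision
    A = torch.randn(num_labels, num_labels);     U = A @ A.T
    B = torch.randn(num_features, num_features); V = B @ B.T

    eig_U, vec_U = torch.linalg.eigh(U)
    eig_V, vec_V = torch.linalg.eigh(V)
    inv_U_eig, inv_V_eig = (eig_U + delta).reciprocal(), (eig_V + delta).reciprocal()

    # Eigenbasis route, mirroring the snippet above (single sample, so no batch dims)
    phiT_Q = phi @ vec_V
    scalar = phiT_Q @ torch.diag(inv_V_eig) @ phiT_Q
    f_var = scalar * (vec_U @ torch.diag(inv_U_eig) @ vec_U.T)

    # Direct route: (phi^T (V + delta I)^{-1} phi) * (U + delta I)^{-1}
    direct = (phi @ torch.linalg.inv(V + delta * torch.eye(num_features)) @ phi) \
        * torch.linalg.inv(U + delta * torch.eye(num_labels))

    print(torch.allclose(f_var, direct, atol=1e-4))  # True
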

Note 1: The speedup is considerable; in my case, the Diagonal Laplace went from 3 hours on the test set to 5 seconds.
Note 2: The Kronecker Laplace gives a slightly different result from the original implementation, but as far as I can tell the math is implemented correctly, and empirically the results are reasonable.

Best,

Carles

@wiseodd
Collaborator

wiseodd commented Nov 13, 2023

Hi Carles, these are indeed good suggestions! I wrote the formula for KronLLLaplace in App. B.1 of my paper; not sure why I haven't gotten around to implementing it here yet!

Would you like to submit a pull request for this? No problem if you can't---I can also do this rather quickly.

@charlio23
Author

charlio23 commented Nov 13, 2023

Thank you very much for your fast reply.

I am not sure whether my pull request would meet the coding standards of the repo, so since you already have experience with it, I would appreciate it if you could do it instead. The only thing to note is that I noticed a slight difference in the resulting covariance matrix when using the new formula, and I am not sure whether the discrepancy comes from numerical precision.

PS: In fact the formula I used is from your paper 😄.

@wiseodd
Collaborator

wiseodd commented Feb 24, 2024

@aleximmer what's the best way to go about this? For KronLLLaplace, it seems like this should be implemented in KronDecomposed in matrix.py.

This change would be very useful in relation to #144, e.g. for LLMs where the number of classes is on the order of $10^4$.
