Last Layer Laplace predictions could be computed much faster and become problematic for large-label classification. #138
Labels: enhancement (New feature or request)
When experimenting with the Last Layer Laplace approximation on a classification task with many labels (1K or more, e.g. ImageNet), memory usage rapidly jumps to 10-20 GB and inference becomes slow.
For context, with a last layer of size 128 -> 3100 (where 3100 is the number of labels), each inference step takes ~10 seconds and 15 GB of memory for a single sample (batch_size=1) under the Diagonal Laplace. The Kronecker Laplace does not fit on the GPU with this setup.
When inspecting the code, I noticed that the Last Layer Laplace computes the Jacobian as $\phi(x)^T \otimes I$ explicitly, which is very costly with >1K labels, since the resulting matrix has shape (num_labels, num_features*num_labels).
Note: $\phi(x)$ are the features of the L-1 layer given input x, and $I$ is an identity matrix of size (num_labels, num_labels).
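To make the cost concrete (my own back-of-the-envelope estimate): for the 128 -> 3100 layer above, this explicit Jacobian alone has $3100 \times (128 \cdot 3100) \approx 1.2 \times 10^9$ entries per sample, i.e. roughly 5 GB in float32, before any covariance products are even formed.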
For the case of the Last Layer Laplace, I believe this could be implemented more efficiently by exploiting the factorisation of the Jacobian. For example, the posterior variance for the Kronecker Laplace could be implemented as follows:
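Sketch of the identity I have in mind (the factor names $V$ and $U$ are my own notation, not necessarily what the code uses): if the Kronecker-factored posterior covariance of the last-layer weights is $V \otimes U$, with $V$ of shape (num_features, num_features) and $U$ of shape (num_labels, num_labels), then the mixed-product property gives

$$J \Sigma J^T = (\phi(x)^T \otimes I)\,(V \otimes U)\,(\phi(x) \otimes I) = \big(\phi(x)^T V \phi(x)\big)\, U,$$

i.e. a per-sample scalar times a (num_labels, num_labels) matrix, so the Jacobian never needs to be materialised.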
Below is some code I used to speed up the Last Layer Laplace by modifying the code from the repository. It would be nice if you could consider making this modification, as I am sure other users of your package would benefit from it.
Diagonal Laplace
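The original snippet did not survive here, so the following is only a minimal sketch of the diagonal-case computation described above, written against plain PyTorch rather than the package internals; the function name, the column-major storage of the diagonal, and the omission of the bias term are my assumptions.

```python
import torch

def diag_llla_functional_variance(phi, sigma2_diag, num_labels):
    """Functional variance for a last-layer Laplace with a diagonal posterior.

    phi:         (batch, num_features) penultimate-layer features
    sigma2_diag: (num_features * num_labels,) diagonal of the posterior
                 covariance over the last-layer weights, assumed stored in the
                 (column-major) vec convention for which J = phi^T kron I.

    Since J diag(sigma2) J^T is itself diagonal, with entries
    sum_j phi_j^2 * sigma2[j, k] for each label k, the explicit Jacobian of
    shape (num_labels, num_features * num_labels) never has to be built.
    The bias term is omitted for brevity.
    """
    num_features = phi.shape[-1]
    s = sigma2_diag.view(num_features, num_labels)  # (num_features, num_labels)
    f_var_diag = phi.pow(2) @ s                     # (batch, num_labels)
    return torch.diag_embed(f_var_diag)             # (batch, num_labels, num_labels)
```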
Kronecker Laplace
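Again, the snippet itself is missing here; below is a sketch of the Kronecker-case computation using the identity above, assuming the two posterior covariance factors are available as dense matrices $V$ (features) and $U$ (labels), and ignoring the bias and any damping of the factors.

```python
import torch

def kron_llla_functional_variance(phi, V, U):
    """Functional variance for a last-layer Laplace with a Kronecker posterior.

    phi: (batch, num_features) penultimate-layer features
    V:   (num_features, num_features) covariance factor over input features
    U:   (num_labels, num_labels) covariance factor over outputs,
         so that the posterior covariance is V kron U, matching J = phi^T kron I.

    Uses (phi^T kron I)(V kron U)(phi kron I) = (phi^T V phi) * U,
    i.e. a per-sample scalar times U, so no explicit Jacobian is needed.
    """
    scale = torch.einsum('bi,ij,bj->b', phi, V, phi)  # (batch,) values phi^T V phi
    return scale[:, None, None] * U                   # (batch, num_labels, num_labels)
```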
Note 1: The speedup is considerable. In my case, evaluation on the test set went from 3 hours to 5 seconds using the Diagonal Laplace.
Note 2: The Kronecker Laplace gives a slightly different result from the original implementation, but the math is correctly implemented as far as I can tell; empirically, the results are reasonable.
Best,
Carles