Keep initialization of H for all-weights and last-layer separate #72
Looks great, and much simpler than the other solution indeed! I also tested it with WILDS and it works well. Is it a problem that for last-layer with no [...]
Thanks for testing it with WILDS! I don't think the two exceptions are a problem:
Yes, good points!
We could additionally prohibit trying to do CL with these classes by adjusting the [...]
I think that's also ok. I don't really have any use case in mind where one might want to use [...]
Now a descriptive error gets raised when [...]
Great, I agree that a descriptive error is more useful/clear than just not offering the option at all (and we might still add the feature at some point if we think it's useful at all). Happy to merge this in.
I think @aleximmer wanted to take a closer look today. After that we can merge it.
lgtm
This is an alternative to #71: it addresses a bug resulting from #62, first reported here (note: links to PR discussion in private repository). For the last-layer LA flavors, the Hessian approximation `H` was first initialized for all weights, which leads to out-of-memory errors for larger models.

The advantage of the previous fix is that it keeps the classes for all-weights and last-layer flavors more strictly separated, which might make it harder to introduce similar bugs in the future. However, it requires many more classes. @aleximmer and I agreed that this additional complexity is probably not worth it.
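To give a rough sense of why initializing `H` over all weights can exhaust memory, here is a back-of-the-envelope sketch. It assumes a dense (full) Hessian approximation stored in float32; the parameter counts and the helper `dense_hessian_bytes` are hypothetical and not taken from the library or from the bug report.

```python
def dense_hessian_bytes(n_params: int, bytes_per_entry: int = 4) -> int:
    """Memory needed for a dense n_params x n_params matrix (float32 by default)."""
    return n_params * n_params * bytes_per_entry


# Hypothetical model sizes, chosen only for illustration.
all_weights = 1_000_000        # every parameter of a small network
last_layer = 512 * 10 + 10     # 512 features -> 10 classes, plus biases

print(f"all-weights H: {dense_hessian_bytes(all_weights) / 1e12:.1f} TB")
print(f"last-layer H:  {dense_hessian_bytes(last_layer) / 1e6:.1f} MB")
```

Under these assumptions, the all-weights matrix is several orders of magnitude larger than the last-layer one, so allocating it first defeats the purpose of a last-layer flavor.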
Changes:
- Fixes the `H` initialization bug.
- In most cases, `posterior_precision` falls back to the prior before `fit()` is called for the first time. Exceptions: a last-layer flavor that doesn't get `last_layer_name` passed as an argument, and low-rank Laplace. For these two cases, `H` will be `None`, and trying to access `posterior_precision` raises a descriptive error.