Subnetwork Laplace #58
Conversation
**1.-7. Minor comments**
Looks quite nice already. I made two comments on the code directly. Here are the answers/discussion for the more minor points:
**8. Choice of subnetwork**
One question regarding the choice of subnetwork in your paper: would it be possible to use the diagonal Laplace-GGN to choose the subnetwork? In fact, I thought that's what you did. Is there any reason that you (need to) use SWAG instead of Laplace to choose the subnetwork? Because I thought that the subnetwork would be chosen using the diagonal GGN, I was thinking of a base class along these lines:

```python
class SubnetMask:
    def __init__(self, model, likelihood, **hyperparams):
        self.model = model
        self.likelihood = likelihood
        self.hyperparams = hyperparams

    def __call__(self, train_loader, **kwargs):
        # Subclasses select the mask based on the train loader, the model,
        # and other hyperparameters, similar to a subclass of CurvatureInterface.
        raise NotImplementedError
```

Then, different policies can subclass this mask, and one could potentially separately implement a SWAG mask, a random mask, and a last-layer mask (for testing?). Also, the method using the diagonal GGN would be easy to implement. What do you think about this?
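A random-mask policy could then look roughly like this; a minimal sketch building on the base class above, where `RandomSubnetMask` and `n_params_subnet` are illustrative names rather than anything from this PR:

```python
import torch


class RandomSubnetMask(SubnetMask):
    """Hypothetical policy: select a uniformly random subnetwork."""

    def __init__(self, model, likelihood, n_params_subnet):
        super().__init__(model, likelihood)
        self.n_params_subnet = n_params_subnet

    def __call__(self, train_loader, **kwargs):
        n_params = sum(p.numel() for p in self.model.parameters())
        # Return indices (into the flattened parameter vector) of the
        # randomly selected subnetwork parameters.
        return torch.randperm(n_params)[:self.n_params_subnet]
```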
I now addressed most of the remaining issues discussed. Most notably, I followed your suggestion of adding a `SubnetMask` base class.
I also finished implementing subnet selection based on the largest marginal variances (using diagonal Laplace for variance estimation).
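The selection rule itself is simple. A minimal sketch, assuming a flat tensor of per-parameter marginal variances from a diagonal Laplace fit is available (the function name is illustrative):

```python
import torch


def largest_variance_indices(marginal_variances, n_params_subnet):
    # marginal_variances: flat tensor of per-parameter posterior variances,
    # e.g. the diagonal of a diagonal Laplace posterior covariance.
    # Returns the indices of the n_params_subnet largest-variance parameters.
    return torch.topk(marginal_variances, n_params_subnet).indices
```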
Looks good, I just left some minor comments. I can help with them.
I have now implemented the remaining things we had discussed. Let me know what you think and if there is anything else that I should take care of!

I also just added an example for using `SubnetLaplace`.
Great work, looks good to me now. I just left a few comments/questions. Other than that:
- It would be great if you could make sure that the line width does not exceed the limit (90? 100?). Maybe we haven't agreed on a line width yet; I'd suggest going with 100 for now.
- Did you recompile the documentation as indicated in the readme? Otherwise I can do that, if you want.
laplace/utils/swag.py (outdated):

```python
    return parameters_to_vector(model.parameters()).detach()


def fit_diagonal_swag(model, train_loader, criterion, n_snapshots_total=40,
                      snapshot_freq=1, lr=0.01, momentum=0.9,
                      weight_decay=3e-4, min_var=1e-30):
```
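For context, diagonal SWAG maintains running first and second moments of the flattened weights over periodic SGD snapshots and returns the per-parameter variance E[theta^2] - E[theta]^2. The following is a simplified sketch of that procedure, not the exact implementation in this file:

```python
import torch
from torch.nn.utils import parameters_to_vector


def diagonal_swag_var_sketch(model, train_loader, criterion, n_snapshots_total=40,
                             snapshot_freq=1, lr=0.01, momentum=0.9,
                             weight_decay=3e-4, min_var=1e-30):
    # Simplified sketch: run SGD and keep running first and second moments of
    # the flattened parameter vector, snapshotted every snapshot_freq epochs.
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum,
                          weight_decay=weight_decay)
    mean, sq_mean, n_snapshots = 0.0, 0.0, 0
    for epoch in range(n_snapshots_total * snapshot_freq):
        for X, y in train_loader:
            opt.zero_grad()
            criterion(model(X), y).backward()
            opt.step()
        if (epoch + 1) % snapshot_freq == 0:
            theta = parameters_to_vector(model.parameters()).detach()
            mean = (n_snapshots * mean + theta) / (n_snapshots + 1)
            sq_mean = (n_snapshots * sq_mean + theta ** 2) / (n_snapshots + 1)
            n_snapshots += 1
    # Diagonal SWAG variance: E[theta^2] - E[theta]^2, clamped for stability.
    return torch.clamp(sq_mean - mean ** 2, min=min_var)
```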
Just out of interest: is the mean just discarded and not useful for the final LA, or could it help to get a better mean, following the idea of standard SWA? The function name currently indicates that it does the full SWAG (mean + var). If you just need the variance, maybe that can be indicated in the function name, like `fit_diagonal_swag_var`. But it's not necessary, just a suggestion.
Good point! IIRC I have tried using the SWA mean with Laplace, which indeed improved performance. But I haven't done extensive experiments with it, as I thought it's not as principled (although perhaps one could view the SWA mean as some kind of approximation to the MAP?). It would be interesting to revisit the idea (and perhaps even add it as a feature if it works well). The only downside I see is that it requires us to update the BatchNorm parameters of the model, which is one of the main reasons why SWA(G) is so slow.

I've changed the name to `fit_diagonal_swag_var` for clarity, thanks!
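For reference, reusing the SWA mean would roughly amount to the following sketch, where `swa_mean`, `model`, and `train_loader` are placeholders; `torch.optim.swa_utils.update_bn` recomputes the BatchNorm statistics with a full pass over the data, which is the cost mentioned above:

```python
from torch.nn.utils import vector_to_parameters
from torch.optim import swa_utils

# Load the (flattened) SWA mean back into the model...
vector_to_parameters(swa_mean, model.parameters())
# ...and refresh the BatchNorm running statistics, which requires a full
# pass over the training data.
swa_utils.update_bn(train_loader, model)
```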
Looks great! See my last comment for my only concern.
Minor notes:
- Regarding the line length, I also think 100 is fine for now. I think we have always used brackets (`([{`) for line breaks, so it would be nice to be consistent with that.
- We could also consider moving `marglik_training.py` to the new `utils` directory (or just leave it where it is, or in the future add something like a `methods` directory).
Thanks a lot for your comments, Alex and Runa, which I have now addressed! I also recompiled the documentation. Let me know if there is anything else that needs to be resolved.
Just realised that, for some reason, Travis has not been running the tests here anymore for the latest commits (I ran the tests locally, where everything worked). Any idea why it's not doing this check anymore?
We have limited compute available on Travis, so we only run the tests when merging PRs from now on. If you ran the tests locally, it should be fine.
Ah I see, makes sense! I just merged the base branch into `subnetlaplace`.
This PR implements subnetwork Laplace, addressing issue #16.
As you can see from the code, this is fairly straightforward. We basically just need to ask the user to specify/pass a definition of a subnetwork. Internally, we then store the subnetwork as a vector of indices into the flattened/vectorized model parameters that form the subnetwork (for the user's convenience, we also allow passing the subnetwork as a binary mask, i.e. a vector of the size of the parameter vector, where 1s indicate the subnetwork parameters). All we then need to do is index the Jacobians/gradients to extract the part corresponding to the subnetwork, both at inference time (i.e. for constructing the GGN/Fisher) and at prediction time (i.e. for constructing the predictive covariance matrix).
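To illustrate the indexing, here is a hedged sketch; the helper name and shapes are illustrative and not the code in this PR:

```python
import torch


def indices_from_binary_mask(mask):
    # Convert a binary mask over the flattened parameter vector
    # (1s mark subnetwork parameters) into a vector of indices.
    return mask.nonzero(as_tuple=True)[0]


mask = torch.tensor([1, 0, 0, 1, 1])
subnetwork_indices = indices_from_binary_mask(mask)  # tensor([0, 3, 4])
Js = torch.randn(2, 5)                 # full Jacobian (n_outputs x n_params)
Js_subnet = Js[:, subnetwork_indices]  # (n_outputs x n_params_subnet)
```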
Some comments/questions:
- One part I'm not yet fully confident about is the `subnetwork_mask` setter of `SubnetLaplace`, which needs to correctly check that the passed `subnetwork_mask` is in the right format (we currently support the two formats mentioned above), as the code would break otherwise (see the sketch after this list). I'll double-check the logic and also add sufficient test cases for this.
- We could rename some attributes to make explicit that they refer to the subnetwork (e.g. `parameters` would become `parameters_subnet`). It's not hard to change, but it might make the documentation more cluttered / harder to read, as we'd have to include conditional clauses, I guess.
- Is there a reason why `jacobians()` is static? Just wondering, as it makes the code and documentation a bit uglier/longer with the subnetwork option, since we always have to pass the `subnetwork_indices` when calling it. Compare this with `gradients()`, which is not static and can therefore simply access `self.subnetwork_indices`.

Thanks a lot for your efforts in looking at this (there's no rush, of course)!
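As referenced in the first point above, the format check could look roughly like this; a sketch rather than the actual setter, which also has to deal with the inherent ambiguity of an index vector of length `n_params` that contains only 0s and 1s:

```python
import torch


def check_subnetwork_mask(subnetwork_mask, n_params):
    # Sketch of a validation that accepts either a binary mask over the full
    # parameter vector or a vector of parameter indices, returning indices.
    mask = torch.as_tensor(subnetwork_mask)
    if mask.dim() != 1:
        raise ValueError('subnetwork_mask must be a 1D tensor.')
    is_binary = (mask.numel() == n_params
                 and torch.all((mask == 0) | (mask == 1)))
    if is_binary:
        # Binary mask: 1s mark the subnetwork parameters.
        return mask.nonzero(as_tuple=True)[0]
    if torch.any(mask < 0) or torch.any(mask >= n_params):
        raise ValueError('subnetwork_mask contains invalid parameter indices.')
    return mask.long()
```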