
Subnetwork Laplace #58

Merged (54 commits, Jan 12, 2022)

Conversation

@edaxberger (Collaborator)

This PR implements subnetwork Laplace, addressing issue #16.

As you can see from the code, this is fairly straightforward. We ask the user to specify/pass a definition of a subnetwork, which we internally store as a vector of indices into the flattened/vectorized model parameters. (For the user's convenience, we also allow passing the subnetwork as a binary mask of the same size as the parameter vector, where 1s indicate the subnetwork parameters.) All we then need to do is index the Jacobians/gradients to extract the part corresponding to the subnetwork, both at inference time (i.e. when constructing the GGN/Fisher) and at prediction time (i.e. when constructing the predictive covariance matrix).
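As a minimal sketch of the indexing (illustrative shapes and names, not the PR's actual code):

import torch

# hypothetical binary mask over the flattened parameter vector (1s mark the subnetwork)
mask = torch.tensor([1, 0, 0, 1, 1, 0])
subnetwork_indices = mask.nonzero(as_tuple=True)[0]   # tensor([0, 3, 4])

# Jacobians of shape (batch, outputs, n_params): keep only the subnetwork columns
Js = torch.randn(8, 2, 6)
Js_subnet = Js[:, :, subnetwork_indices]              # shape (8, 2, 3)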

Some comments/questions:

  1. I'm aware that I still need to write tests, which I'll do (I ran some tests locally to ensure that the code generally works).
  2. The code also doesn't support the NN+MC predictive yet, which I'll add; it shouldn't be hard to implement, as we just need to sample the subnetwork weights and combine them with the remaining MAP weights to form a full parameter vector (see the sketch after this list).
  3. Should we allow the user to change the subnetwork after instantiation, or should the subnetwork remain fixed? In either case I'd need to adapt the code accordingly. For us it would be easier not to allow changing the subnetwork, as that would require re-fitting the Hessian etc.
  4. We'll have to be especially careful with the logic of the subnetwork_mask setter of SubnetLaplace, which needs to correctly check that the passed subnetwork_mask is in the right format (we currently support the two formats mentioned above), as the code would break otherwise. I'll double-check the logic and add sufficient test cases for this.
  5. Should we update all the internal documentation to reflect the subnetwork Laplace option? As the subnetwork Laplace implementation is integrated into the full Laplace code (and just adds lines to index the Jacobians/gradients), we'd have to change everything that refers to the full parameter vector (incl. where we define tensor shapes, e.g. parameters would become parameters_subnet). It's not hard to change, but it might make the documentation more cluttered and harder to read, as we'd have to include conditional clauses, I guess.
  6. Is there a reason why jacobians() is static? Just wondering, as it makes the code and documentation a bit uglier/longer with the subnetwork option, since we always have to pass the subnetwork_indices when calling it. Compare this with gradients(), which is not static and can therefore simply access self.subnetwork_indices.
  7. The code currently only supports using subnetwork Laplace with a full Hessian (and throws an error otherwise). In principle it would also be possible to use it with diagonal or KFAC Hessians, but that makes less sense conceptually, so I'd rather not support it. Then again, we also support options like diagonal last-layer, which are arguably not too sensible either.
  8. Finally, the current implementation requires the user to specify/pass the subnetwork. We could consider also adding functionality for automatically selecting the subnetwork. Simple choices like random subnetworks or subnetworks of the largest-magnitude weights would be easy to implement. However, the approach we ended up using in the subnetwork inference paper (i.e. selecting the weights with the largest marginal variances, using SWAG for variance estimation) is a bit more involved and might make the codebase too bloated. Happy to discuss this.
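A minimal sketch of the NN+MC idea from point 2 (all names and the Gaussian posterior form are illustrative, not the PR's code):

import torch

def sample_full_params(map_params, subnet_mean, subnet_scale_tril,
                       subnetwork_indices, n_samples=10):
    # draw parameter vectors in which only the subnetwork entries vary;
    # all other entries stay at their MAP values
    dist = torch.distributions.MultivariateNormal(subnet_mean, scale_tril=subnet_scale_tril)
    samples = map_params.repeat(n_samples, 1)                   # (n_samples, n_params), MAP everywhere
    samples[:, subnetwork_indices] = dist.sample((n_samples,))  # overwrite subnetwork weights
    return samples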

Thanks a lot for your efforts in looking at this (there's no rush of course)!

@edaxberger edaxberger added the enhancement New feature or request label Nov 30, 2021
@edaxberger edaxberger added this to the NeurIPS Prerelease milestone Nov 30, 2021
@edaxberger edaxberger linked an issue Nov 30, 2021 that may be closed by this pull request
@aleximmer (Owner)

1.-7. Minor comments

Looks quite nice already. I made two comments on the code directly. Here are answers/discussion for the more minor points:

  1. Suggestions for tests: with all parameters retained, everything should be equivalent to FullLaplace; with only the last layer retained, everything should match last-layer Laplace as well (see the test sketch after this list).
  2. The subnetwork should be static, as you suggest, because changing it really requires rerunning the entire .fit() method, which is the only setup function with non-negligible runtime.
  3. It could make sense to implement the code separately for SubnetLaplace to make it clearer and easier to change in the future. Otherwise, one could add @property values for all class parameters that need to be masked and then override these in the Subnet class to mask correspondingly.
  4. Yes, we could probably make the change and have jacobians() be an instance method instead of static. I don't remember why it is like this right now. In fact, we have a similar change right now in the FunctionalLaplace PR: Functional laplace #55. If you can make the change to jacobians() in the backends on this branch, that would also work. I don't see any problem with this right now, but there might be some.
  5. I agree that diagonal or KFAC subnet Hessians don't make sense, because KFAC would require a certain subnet structure to be sensible, and the full diagonal is usually what's used to select the parameters of interest. I think diagonal last-layer makes a little more sense, but is probably also not something that should be used much.
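A sketch of the first equivalence test (assuming this codebase's FullLaplace/SubnetLaplace classes and posterior_precision property; details may differ):

import torch
from laplace import FullLaplace, SubnetLaplace

def test_subnet_with_all_params_equals_full(model, train_loader):
    # with every parameter retained, SubnetLaplace should reproduce FullLaplace
    n_params = sum(p.numel() for p in model.parameters())
    all_indices = torch.arange(n_params)

    full_la = FullLaplace(model, 'classification')
    full_la.fit(train_loader)

    subnet_la = SubnetLaplace(model, 'classification', subnetwork_indices=all_indices)
    subnet_la.fit(train_loader)

    assert torch.allclose(full_la.posterior_precision, subnet_la.posterior_precision)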

8. Choice of subnetwork

One question regarding the choice of subnetwork in your paper: would it be possible to use the diagonal Laplace-GGN to choose the subnetwork? In fact, I thought that's what you did. Is there any reason that you (need to) use SWAG instead of Laplace to choose the subnetwork?

Because I thought the subnetwork would be chosen using the diagonal GGN, I was thinking of just overriding the .fit() method with a separate pass over the data_loader to select and set the subnetwork indices, and then running the actual .fit() to compute the masked full GGN. If the method does not work with a diagonal GGN as the selection criterion, and we anyway want to include different policies for selecting subnetworks, it might make sense to design a class that can be passed to SubnetLaplace. One could implement it as follows:

class SubnetMask:
    def __init__(self, model, likelihood, ...):
        ...

    def __call__(self, train_loader, ...):
        # select the mask based on the train loader, the model, and other
        # hyperparameters, similar to a subclass of CurvatureInterface
        return mask

Then, different policies can subclass this mask: one could implement a SWAG mask, a random mask, and a last-layer mask (for testing?). The method using the diagonal GGN would also be easy to implement. What do you think about this?
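For instance, a random-mask policy could look roughly like this (an illustrative sketch building on the SubnetMask base class above; names are not the PR's actual code):

import torch
from torch.nn.utils import parameters_to_vector

class RandomSubnetMask(SubnetMask):
    def __init__(self, model, n_params_subnet):
        self.model = model
        self.n_params_subnet = n_params_subnet

    def __call__(self, train_loader=None):
        # score-free policy: select n_params_subnet parameters uniformly at random
        n_params = parameters_to_vector(self.model.parameters()).numel()
        idx = torch.randperm(n_params)[:self.n_params_subnet]
        mask = torch.zeros(n_params, dtype=torch.bool)
        mask[idx] = True
        return mask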

@edaxberger (Collaborator, Author)

I have now addressed most of the remaining issues discussed. Most notably, I followed your suggestion of adding a SubnetMask class whose subclasses implement different subnetwork selection strategies. For now, random, largest-magnitude, and last-layer subnetworks are supported. I am also considering adding the diag-Laplace subnet selection strategy you mentioned, which should indeed be easy to implement. I am currently finalising the tests and will push those soon as well.

@edaxberger (Collaborator, Author)

I also finished implementing subnet selection based on the largest marginal variances (using diagonal Laplace for variance estimation).
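In sketch form, the idea is the following (assuming the library's DiagLaplace class and a posterior_variance property; exact names in the PR may differ):

import torch
from laplace import DiagLaplace

def largest_variance_indices(model, likelihood, train_loader, n_params_subnet):
    # fit a cheap diagonal Laplace approximation, then keep the parameters
    # with the largest marginal posterior variances as the subnetwork
    la = DiagLaplace(model, likelihood)
    la.fit(train_loader)
    return torch.topk(la.posterior_variance, n_params_subnet).indices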

@aleximmer (Owner) left a comment

Looks good, I just left some minor comments. I can help with them.

@edaxberger (Collaborator, Author)

I have now implemented the remaining things we had discussed, namely:

  • I moved all utility files (feature_extractor.py, matrix.py, subnetmask.py, swag.py, utils.py) into a new laplace/utils folder
  • I changed the logic of SubnetLaplace to take subnetwork_indices instead of a subclass of SubnetMask; this includes careful checks of the validity of the passed subnetwork_indices (the sketch below illustrates the kind of checks involved)
  • I updated the tests accordingly (i.e. by adapting existing tests to the new structure and adding new tests for custom subnetwork_indices not derived from any SubnetMask subclass)
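The validity checks might look roughly like this (an illustrative sketch, not the PR's exact logic):

import torch

def check_subnetwork_indices(subnetwork_indices, n_params):
    # reject anything that would break indexing into the flattened parameter vector
    if not (torch.is_tensor(subnetwork_indices) and subnetwork_indices.dtype == torch.long):
        raise ValueError('subnetwork_indices must be a tensor of dtype torch.long')
    if subnetwork_indices.dim() != 1 or subnetwork_indices.numel() == 0:
        raise ValueError('subnetwork_indices must be a non-empty 1D tensor')
    if (subnetwork_indices < 0).any() or (subnetwork_indices >= n_params).any():
        raise ValueError('subnetwork_indices must lie in [0, n_params)')
    if subnetwork_indices.unique().numel() != subnetwork_indices.numel():
        raise ValueError('subnetwork_indices must not contain duplicate entries')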

Let me know what you think and if there is anything else that I should take care of!

@edaxberger (Collaborator, Author)

I also just added an example for using SubnetLaplace (with different ways to choose the subnetwork_indices) to the README.
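For reference, the usage pattern is roughly the following (a sketch assuming a model and train_loader are already defined; exact class and argument names may differ from the final README):

from laplace import Laplace
from laplace.utils import LargestMagnitudeSubnetMask

# choose the 128 largest-magnitude weights as the subnetwork
subnetwork_mask = LargestMagnitudeSubnetMask(model, n_params_subnet=128)
subnetwork_indices = subnetwork_mask.select()

# fit a full-Hessian Laplace approximation over just those weights
la = Laplace(model, 'classification',
             subset_of_weights='subnetwork',
             hessian_structure='full',
             subnetwork_indices=subnetwork_indices)
la.fit(train_loader)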

@aleximmer (Owner) left a comment

Great work, looks good to me now. I just left a few comments/questions. Other than that:

  • It would be great if you could make sure that the line width does not exceed the limit (90? 100?). Maybe we haven't agreed on a line width yet; I'd suggest going with 100 for now.
  • Did you recompile the documentation as indicated in the README? Otherwise I can do that, if you want.

return parameters_to_vector(model.parameters()).detach()


def fit_diagonal_swag(model, train_loader, criterion, n_snapshots_total=40,
                      snapshot_freq=1, lr=0.01, momentum=0.9,
                      weight_decay=3e-4, min_var=1e-30):
@aleximmer (Owner):

Just out of interest: the mean is just discarded and not useful for the final LA or could it help to get a better mean following the idea of standard SWA? The function name indicates it does the full swag (mean + var) currently. If you just need the variance, maybe that can be indicated in the function name like fit_diagonal_swag_var. But it's not necessary, just a suggestion.

@edaxberger (Collaborator, Author):

Good point! IIRC I tried using the SWA mean with Laplace, which indeed improved performance. But I haven't done extensive experiments with it, as I thought it's not as principled (although perhaps one could view the SWA mean as some kind of approximation to the MAP?). It would be interesting to revisit the idea (and perhaps even add it as a feature if it works well). The only downside I see is that it requires updating the BatchNorm parameters of the model, which is one of the main reasons why SWA(G) is so slow.

I've changed the name to fit_diagonal_swag_var for clarity, thanks!

@runame (Collaborator) left a comment

Looks great! See my last comment for my only concern.

Minor notes:

  • Regarding the line length, I also think 100 is fine for now. I think we have always used brackets (([{) for line breaks, so it would be nice to be consistent with that.
  • We could consider also moving marglik_training.py to the new utils directory (or just leave it where it is, or in the future add something like a methods directory).

@aleximmer (Owner)

> • Regarding the line length, I also think 100 is fine for now. I think we have always used brackets (([{) for line breaks, so it would be nice to be consistent with that.
> • We could consider also moving marglik_training.py to the new utils directory (or just leave it where it is, or in the future add something like a methods directory).

  • Agree with the first comment. I think there are plenty of backslash line breaks in the PR.
  • I wouldn't put it into utils. I think utils should only contain utilities required by the package itself, not utilities for users. The latter we can call methods later, as discussed.

@edaxberger (Collaborator, Author)

Thanks a lot for your comments, Alex and Runa, which I have now addressed! I also recompiled the documentation. Let me know if there is anything else that needs to be resolved.

@edaxberger (Collaborator, Author)

Just realised that for some reason Travis hasn't been running the tests here for the latest commits (I ran the tests locally, where everything worked). Any idea why it's not doing this check anymore?

@runame (Collaborator) commented Jan 12, 2022

> Just realised that for some reason Travis hasn't been running the tests here for the latest commits (I ran the tests locally, where everything worked). Any idea why it's not doing this check anymore?

We have limited compute available on Travis, so we only run the tests when merging PRs from now on. If you ran the tests locally it should be fine.

@edaxberger (Collaborator, Author)

> We have limited compute available on Travis, so we only run the tests when merging PRs from now on. If you ran the tests locally it should be fine.

Ah I see, makes sense!

I just merged main into subnetlaplace and re-ran the tests locally, and everything still works. So feel free to merge this into main!

@runame runame merged commit b0ddf5f into main Jan 12, 2022
@runame runame deleted the subnetlaplace branch January 12, 2022 12:34
Labels: enhancement (New feature or request)
Closes: Subnetwork inference (#16)
4 participants