
Add native serialization support #148

Merged: 11 commits from serialization into main on Mar 15, 2024
Conversation

@wiseodd (Collaborator) commented Mar 9, 2024

Addressing #45. This is very useful for large models like LLMs, where even doing forward passes over the training data (for fit()) is expensive.

The API basically follows PyTorch's.

la = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='full')
la.fit(train_loader)
la.optimize_prior_precision()  # Or via marglik, which also optimizes sigma_noise

# Serialize the fitted quantities
state_dict = la.state_dict()
torch.save(state_dict, 'state_dict.bin')

la2 = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='full')
# Load the serialized, fitted quantities
la2.load_state_dict(torch.load('state_dict.bin'))
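
The restored instance can then be used for prediction as usual. A minimal sketch (x_test is a hypothetical test batch; for 'regression', calling the instance returns the predictive mean and variance):

# Hypothetical usage of the restored instance; x_test is an assumed test batch
f_mu, f_var = la2(x_test)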

@wiseodd added the enhancement label on Mar 9, 2024
@wiseodd requested review from aleximmer and runame on March 9, 2024 at 17:40
@wiseodd (Collaborator, Author) commented Mar 9, 2024

By the way, I didn't implement or test SubnetLaplace. I'm not sure it's still popular nowadays; the same effect can be achieved by setting requires_grad = False on selected parameters anyway (from #144).

Nevertheless, @edaxberger, feel free to implement serialization for SubnetLaplace. It should be very straightforward for you.

Though I must say that a quick test on SubnetLaplace works without problems. But I don't know about edge cases, like different subnetwork_indices with the same len(subnetwork_indices), etc.
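
For reference, a minimal sketch of the requires_grad alternative mentioned above (the selection of which parameters to freeze is hypothetical; only the still-trainable parameters are treated probabilistically, per the behavior from #144):

# Freeze everything except a chosen subset of parameters
for name, p in model.named_parameters():
    if 'output_layer' not in name:  # hypothetical selection criterion
        p.requires_grad = False

la = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='full')
la.fit(train_loader)  # covers only the parameters with requires_grad=True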

@runame (Collaborator) left a comment

Thanks for tackling this, will be very useful!

I think the current assumption is that model, likelihood, subset_of_weights, and hessian_structure of la and la2 (in the README example) match. It might make sense to check for this explicitly in the load_state_dict method and throw descriptive errors in case of a mismatch. The necessary information for this also has to be stored (see my comment below).

Additionally, it would be nice to allow creating a new Laplace class instance just from the saved state dict. A classmethod would make sense for this, but Laplace is a function, not a class, which makes this option impossible (without other changes). Alternatively, we could add a function like LaplaceFromStateDict that takes the state dict as an argument and returns a Laplace class instance that is (hopefully completely) equivalent to the previously saved one. I could add this functionality in a follow-up PR, though.
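
A hypothetical sketch of that idea (not part of this PR; it assumes the state dict would additionally store the constructor arguments under keys like 'likelihood', 'subset_of_weights', and 'hessian_structure'):

import torch
from laplace import Laplace

def laplace_from_state_dict(model, path):
    # Hypothetical helper; the metadata keys below are assumptions
    state_dict = torch.load(path)
    la = Laplace(
        model,
        state_dict['likelihood'],
        subset_of_weights=state_dict['subset_of_weights'],
        hessian_structure=state_dict['hessian_structure'],
    )
    la.load_state_dict(state_dict)
    return la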

(Two inline review threads on laplace/baselaplace.py, outdated and resolved.)
@wiseodd (Collaborator, Author) commented Mar 14, 2024

@runame updated. Please check and merge if everything's good.

I decided not to include backend and backend_kwargs, since Laplace doesn't care about them once H is obtained. That is, H is just a tensor (or a Kron object), and you can use whatever backend you want for glm_predictive and for another fit.

@wiseodd (Collaborator, Author) commented Mar 14, 2024

I also ignore model.state_dict(), since the user already has it anyway. Even marglik training outputs the resulting torch model. Including it in the serialized Laplace would be redundant.
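
In other words, a user who wants a fully self-contained checkpoint can simply save both state dicts side by side. A minimal sketch:

# The torch model and the Laplace quantities are saved separately
torch.save(model.state_dict(), 'model.bin')
torch.save(la.state_dict(), 'la_state_dict.bin')

# Later: restore the model first, then the Laplace quantities on top of it
model.load_state_dict(torch.load('model.bin'))
la2 = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='full')
la2.load_state_dict(torch.load('la_state_dict.bin'))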

@runame (Collaborator) left a comment

Thanks, will review the update soon!

I should have elaborated on this, but there is a very concrete use case where backend and backend_kwargs are needed: whenever you want to call fit(train_loader, override=False). For example, this is used to implement continual learning; see our experiment for the Redux paper. (I guess some mismatch between backends is fine, since only the shape of self.H matters, or both instances being Kron.)

I'm fine with not including model.state_dict() for now; this will only become important for the LaplaceFromStateDict functionality that I might add in a follow-up PR.
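
For context, a minimal sketch of that use case (task1_loader and task2_loader are hypothetical per-task loaders): with override=False, a second fit() accumulates curvature on top of the loaded H instead of overwriting it.

la = Laplace(model, 'classification', subset_of_weights='all', hessian_structure='kron')
la.fit(task1_loader)
torch.save(la.state_dict(), 'state_dict.bin')

# Later, possibly in a fresh process: restore and keep accumulating
la2 = Laplace(model, 'classification', subset_of_weights='all', hessian_structure='kron')
la2.load_state_dict(torch.load('state_dict.bin'))
la2.fit(task2_loader, override=False)  # adds task-2 curvature to the loaded H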

@wiseodd (Collaborator, Author) commented Mar 14, 2024

I think the current checks are sufficient to handle continual learning:

This will check whether the hessian_structure and subset_of_weights are compatible with the loaded instance:

if self.__class__.__name__ != state_dict['cls_name']:
    raise ValueError(
        'Loading a wrong Laplace type. Make sure `subset_of_weights` and'
        + ' `hessian_structure` are correct!'
    )

This will check that the network (torch model) is correct:

if self.n_params is not None and len(state_dict['mean']) != self.n_params:
    raise ValueError(
        'Attempting to load Laplace with different number of parameters than the model.'
        + ' Make sure that you use the same `subset_of_weights` value and the same `.requires_grad`'
        + ' switch on `model.parameters()`.'
    )

As for model.state_dict(), we could add an option to Laplace.state_dict(), say save_model: bool = False. But this should be part of your future PR.
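
A hypothetical sketch of that option as a standalone helper (deferred to the follow-up PR; the 'model_state_dict' key is an assumption):

def save_laplace(la, path, save_model: bool = False):
    # Hypothetical convenience wrapper around the serialization from this PR
    state_dict = la.state_dict()
    if save_model:
        state_dict['model_state_dict'] = la.model.state_dict()
    torch.save(state_dict, path)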

@runame (Collaborator) left a comment

OK, makes sense. I only have one comment left.

(Inline review thread on laplace/baselaplace.py, resolved.)
@wiseodd (Collaborator, Author) commented Mar 14, 2024

Added a test to check fit(override=False)!
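
Roughly, such a test can verify that a second fit on the same data doubles the accumulated Hessian instead of replacing it. A hypothetical sketch (reusing model and train_loader from above; this is not the actual test in this PR):

import copy
import torch

la = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='full')
la.fit(train_loader)
H_single = la.H.clone()

la2 = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='full')
la2.load_state_dict(copy.deepcopy(la.state_dict()))
la2.fit(train_loader, override=False)  # accumulate instead of overwrite

assert torch.allclose(la2.H, 2 * H_single)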

@runame (Collaborator) left a comment

Sorry, last questions.

(Two inline review threads on tests/test_serialization.py, outdated and resolved.)
@runame (Collaborator) left a comment

Thanks for incorporating all feedback!

@runame merged commit 2543274 into main on Mar 15, 2024
@runame deleted the serialization branch on March 15, 2024 at 00:41