Add native serialization support #148
Conversation
By the way, I didn't really implement nor test … Nevertheless, @edaxberger, feel free to implement serialization on | Though I must say that a quick test on |
Thanks for tackling this, will be very useful!
I think currently the assumption is that model, likelihood, subset_of_weights, and hessian_structure of la and la2 (in the README example) match. It might make sense to explicitly check for this in the load_state_dict method and throw descriptive errors in case of a mismatch. The necessary information for this also has to be stored (see my comment below).
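A minimal sketch of what such a check could look like. This is illustrative only: the class name, the stored keys, and the method bodies are assumptions, not the library's actual implementation.

```python
class LaplaceSketch:
    """Toy stand-in for a Laplace approximation class (illustrative only)."""

    def __init__(self, likelihood, subset_of_weights, hessian_structure):
        self.likelihood = likelihood
        self.subset_of_weights = subset_of_weights
        self.hessian_structure = hessian_structure

    def state_dict(self):
        # Store the configuration alongside the fitted quantities so that
        # a mismatch can be detected on load.
        return {
            "likelihood": self.likelihood,
            "subset_of_weights": self.subset_of_weights,
            "hessian_structure": self.hessian_structure,
        }

    def load_state_dict(self, state_dict):
        # Throw descriptive errors on mismatch instead of silently
        # loading incompatible quantities.
        for key in ("likelihood", "subset_of_weights", "hessian_structure"):
            if state_dict[key] != getattr(self, key):
                raise ValueError(
                    f"Mismatch in '{key}': saved {state_dict[key]!r}, "
                    f"current instance has {getattr(self, key)!r}."
                )
```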
Additionally, it would be nice to allow for the creation of a new Laplace class instance just based on the saved state dict. A classmethod would make sense for this, but Laplace is a function and not a class, making this option impossible (without other changes). So alternatively, we could add a function like LaplaceFromStateDict that takes the state dict as an argument and returns a Laplace class instance which is (hopefully completely) equivalent to the previously saved one. I could add this functionality in a follow-up PR though.
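A rough sketch of what such a factory function could look like, assuming the configuration keys are stored in the state dict. All names and keys here are hypothetical; this is not the library's API.

```python
from types import SimpleNamespace


def laplace_from_state_dict(state_dict, model):
    """Hypothetical factory: rebuild a Laplace-like object purely from a
    saved state dict, assuming the configuration was stored in it."""
    return SimpleNamespace(
        model=model,
        likelihood=state_dict["likelihood"],
        subset_of_weights=state_dict["subset_of_weights"],
        hessian_structure=state_dict["hessian_structure"],
        # The fitted Hessian approximation, if one was saved.
        H=state_dict.get("H"),
    )
```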
@runame updated. Please check and merge if everything's good. I decided not to include | I also ignore |
Thanks, will review the update soon!
I should have elaborated on this, but there is a very concrete use case where the backend and backend_kwargs are needed, i.e. whenever you want to set fit(train_loader, override=False). For example, this is used to implement continual learning; see our experiment for the Redux paper. (I guess some mismatch between backends is fine, since only the shape (or both Kron) of self.H matters.)
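To illustrate why this matters, here is a toy sketch of the override semantics: with override=False the new curvature is accumulated onto the existing self.H, so the loaded H must have a compatible shape/structure. The class and its diagonal representation are stand-ins, not the library's implementation.

```python
class DiagLaplaceSketch:
    """Toy diagonal-Hessian stand-in (illustrative only)."""

    def __init__(self, n_params):
        # Diagonal Hessian approximation, one entry per parameter.
        self.H = [0.0] * n_params

    def fit(self, h_batch, override=True):
        if override:
            # Default: discard previous curvature and start fresh.
            self.H = list(h_batch)
        else:
            # Continual learning: accumulate onto the existing curvature,
            # which only works if the shapes/structures match.
            if len(h_batch) != len(self.H):
                raise ValueError("Hessian shape mismatch; cannot accumulate.")
            self.H = [a + b for a, b in zip(self.H, h_batch)]
```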
I'm fine with not including model.state_dict() for now; this will only be important for the LaplaceFromStateDict functionality that I might add in a follow-up PR.
I think the current checks are sufficient to handle continual learning. This will check whether the configurations match (Laplace/laplace/baselaplace.py, lines 792 to 796 in 30e28f2). This will check that the network (torch model) is correct (Laplace/laplace/baselaplace.py, lines 797 to 802 in 30e28f2).
As for |
OK, makes sense; I only have one comment left.
Added a test to check |
Sorry, last questions.
Thanks for incorporating all feedback!
Addressing #45. Very useful for large models like LLMs, where even doing forward passes over the training data (for fit()) is expensive. The API basically follows PyTorch.
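A toy illustration of the PyTorch-style round trip this PR enables, using pickle as a stand-in for torch.save/torch.load so the sketch stays self-contained. The class and its state dict contents are hypothetical, not the library's implementation.

```python
import os
import pickle
import tempfile


class ToyLaplace:
    """Minimal stand-in with a PyTorch-style state-dict interface."""

    def __init__(self):
        self.H = None  # fitted Hessian approximation (placeholder)

    def state_dict(self):
        return {"H": self.H}

    def load_state_dict(self, sd):
        self.H = sd["H"]


la = ToyLaplace()
la.H = [1.0, 2.0]  # pretend fit() produced this

path = os.path.join(tempfile.mkdtemp(), "state_dict.bin")
with open(path, "wb") as f:
    pickle.dump(la.state_dict(), f)  # analogous to torch.save(...)

la2 = ToyLaplace()
with open(path, "rb") as f:
    la2.load_state_dict(pickle.load(f))  # analogous to torch.load(...)
```

The point of the pattern: only the state dict is serialized, so deserialization requires constructing an equivalently configured instance first, exactly as with torch.nn.Module.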