-
Notifications
You must be signed in to change notification settings - Fork 71
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for diagonal Kronecker factors in Kron
matrix class
#136
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just some suggestions and questions. The clone()
before was, if I remember correctly, important or necessary to maintain differentiability when it's needed. I am not sure whether torch.tensor(h)
guarantees the same.
else: | ||
# Diagonal Kronecker factor. | ||
l = Hi | ||
# This might be too memory intensive since len(Hi) can be large. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove this comment that is from sorting previously?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No it is due to the torch.eye
, since the Kronecker factor is diagonal it might be too large to explicitly build a len(H_i) x len(H_i)
matrix.
Thanks! The initialisation is now not using tensors but just float scalars, so |
If it passes the tests, I guess we have to find out later. |
To support diagonal Kronecker factors in this library we have to adopt the methods of the
Kron
matrix class to be able to handle this correctly.