Revamp Cholesky implementation #311
This rule looks more sane than the previous one. Had a few suggestions, the most substantial being tweaks to also support complex Hermitian matrices.
This PR does two different things: 1) implements an improved cholesky rule for a subset of matrices (also adding support for keyword arguments!) and 2) removes a generic rule that would be applied to all matrices. I wonder if it is better to split those two things into separate PRs. The 2nd one is non-breaking, but the 1st one could potentially be a breaking change if users decorated some custom matrix type with Symmetric.
It seems like it would be straightforward to add the frules for cholesky too. Do you think that's worth adding to this PR, or do you think ADing directly through cholesky with forward mode would be more performant?
@nmheim it would be great if you could look at this too, since this solves some of the issues raised in TuringLang/DistributionsAD.jl#95, and it's also removing the internal function that DistributionsAD.jl relies on (but shouldn't).
function _cholesky_pullback_shared_code(C, Δ)
    issuccess(C) || throw(PosDefException(C.info))
    U = C.U
    Ū = Δ.U
What about the case where the user accessed L instead of U, or both of them?
This should be handled by getproperty, I believe -- it flips stuff appropriately for the uplo property of the cholesky. Since the user-facing interface no longer provides an option to set uplo, I think we might be fine. Does this seem reasonable to you?
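To make the getproperty point concrete, here is a minimal sketch using only the LinearAlgebra standard library. The 2×2 matrix and the variable names are illustrative assumptions, not code from this PR. It shows that .U and .L agree with each other regardless of which triangle the factorization stores:

```julia
using LinearAlgebra

A = [4.0 2.0; 2.0 3.0]               # small symmetric positive-definite example (assumed)

C_up = cholesky(A)                   # a plain matrix is factorized with uplo == 'U'
C_lo = cholesky(Symmetric(A, :L))    # the wrapper's uplo is carried into the factorization

# Whichever triangle is stored, getproperty returns consistent factors.
same_within = C_up.L ≈ C_up.U'
same_across = C_lo.U ≈ C_up.U
```

A pullback that only reads the U field of the cotangent therefore depends on how the cotangent fields are populated, which is what the question above is probing.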
Co-authored-by: Seth Axen <seth.axen@gmail.com>
@sethaxen thanks for the careful review. I've adopted most of your recommendations, but I'm reluctant to extend the implementation to Complex-valued things in this PR, mainly because testing it will be a bit awkward and I don't want to get it wrong. Also I was planning to limit the scope of this PR to maintaining / correcting existing functionality present in ChainRules/Zygote, rather than extending it.
This would be a good thing to do, but my preference would again be to leave it for later.
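Hmmm good point. I wonder whether it would have worked for custom matrices before though? For example, I think that the previous implementation probably implicitly restricted users to StridedMatrixs and their HermOrSym counterparts anyway because of the BLAS calls. Does this address your concerns, or do you think there's a possibility that I'm missing?
edit: no longer depends on the FiniteDifferences PR!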
Just a few more minor suggestions.
> I'm reluctant to extend the implementation to Complex-valued things in this PR, mainly because testing it will be a bit awkward and I don't want to get it wrong. Also I was planning to limit the scope of this PR to maintaining / correcting existing functionality present in ChainRules/Zygote, rather than extending it.
Technically Zygote's rules catch complex StridedMatrixes. Looks like Zygote's rule was never tested for complex numbers, and I can guarantee they are not right, so I agree it is fine to only support real in this PR.
> Hmmm good point. I wonder whether it would have worked for custom matrices before though? For example, I think that the previous implementation probably implicitly restricted users to StridedMatrixs and their HermOrSym counterparts anyway because of the BLAS calls. Does this address your concerns, or do you think there's a possibility that I'm missing?
It looks like the old rules supported generic matrices by converting them to Matrixes before passing to the function with the BLAS calls. So it looks like it would have worked for a custom matrix before, so long as the number type was a BlasFloat.
In the past, when restricting types of rules, have we marked those as breaking changes? If so, we should do the same here.
    chol_blocked_rev(Ȳ.L, F.L, 25, false)
end
return (NO_FIELDS, ∂X)
function rrule(::typeof(cholesky), A::Real, uplo::Symbol=:U)
This should be safe.
- function rrule(::typeof(cholesky), A::Real, uplo::Symbol=:U)
+ function rrule(::typeof(cholesky), A::Real, uplo::Symbol)
function rrule(
    ::typeof(cholesky),
    A::StridedMatrix{<:LinearAlgebra.BlasReal},
    ::Val{false}=Val(false);
-     ::Val{false}=Val(false);
+     ::Val{false};
end

function _cholesky_pullback_shared_code(C, ΔC)
    issuccess(C) || throw(PosDefException(C.info))
As with diagonal, could you move this out of the shared function and then only throw this error if the user did not specify check = false?
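For context on why the check keyword matters here, a small sketch of how the primal cholesky behaves with and without it. The indefinite test matrix is an assumption chosen for illustration:

```julia
using LinearAlgebra

A_bad = [1.0 2.0; 2.0 1.0]            # symmetric but indefinite (eigenvalues 3 and -1)

C = cholesky(A_bad; check = false)    # does not throw; the failure is recorded in C.info
failed = !issuccess(C)

# With the default check = true, the same input throws.
threw = try
    cholesky(A_bad)
    false
catch err
    err isa PosDefException
end
```

If the pullback throws unconditionally on a failed factorization, a caller who deliberately passed check = false would get an exception in the reverse pass that the forward pass was asked to suppress.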
    Ū = ΔC.U
    Ā = similar(U.data)
    Ā = mul!(Ā, Ū, U')
    Ā = LinearAlgebra.copytri!(Ā, 'U', true)
I know complex matrices aren't officially supported in this PR, but I tested locally that this last fix makes them work for me for complex Hermitian matrices, for when the type constraints are relaxed.
-     Ā = LinearAlgebra.copytri!(Ā, 'U', true)
+     Ā = LinearAlgebra.copytri!(Ā, 'U', true)
+     idx = diagind(Ā)
+     @views Ā[idx] .= real(Ā[idx])
Completely optional though, because some of the other rrules wouldn't support complex either.
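To illustrate what the suggested diagonal fix does in isolation, here is a minimal sketch on a made-up complex matrix. The values are hypothetical and this is not a test of the full rule. After conjugate-mirroring the upper triangle, the diagonal can still carry imaginary parts, and dropping them is what makes the result Hermitian:

```julia
using LinearAlgebra

Ā = ComplexF64[1+2im 3-1im; 0 4-5im]    # hypothetical cotangent; lower triangle not yet filled

LinearAlgebra.copytri!(Ā, 'U', true)     # conjugate-mirror the upper triangle into the lower
idx = diagind(Ā)
@views Ā[idx] .= real.(Ā[idx])           # discard imaginary parts on the diagonal
```

For real inputs the last two lines are a no-op, so adding them would not change the behaviour this PR targets.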
Feel free to merge whenever you're happy with version number and such. Thanks!
Thanks for the feedback @sethaxen. I'm going to stop development on this branch until CI is back up and running, then will process your feedback.
Codecov Report
@@            Coverage Diff             @@
##           master     #311      +/-   ##
- Coverage   97.46%   97.44%   -0.03%
  Files          18       18
  Lines         988      939      -49
- Hits          963      915      -48
+ Misses         25       24       -1
The current Cholesky implementation isn't fantastic, in that we're choosing fairly arbitrary block sizes based on code that I hacked together almost 4 years ago.
This is probably a better approach (due to Seeger). It basically copies what Zygote currently does, but with the important improvement that the type of the cotangent w.r.t. the matrix A whose Cholesky is taken is either a Symmetric / Hermitian or a StridedMatrix, depending on whether A is Symmetric / Hermitian / a StridedMatrix. As shown in the additional tests, this is important to ensure that the composition cholesky(Symmetric(X::StridedMatrix)) gives the same answer as cholesky(X::StridedMatrix).
I'll remove Zygote's implementation once this is in.
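For readers unfamiliar with the rule, the following is a rough sketch of the real-valued reverse-mode computation for A = U'U, checked against a finite difference. It is not the PR's code: the function name cholesky_pullback_sketch, the test matrices, and the step size are assumptions. The two triangular solves and the final halving come from the standard derivation rather than from the snippets quoted in this thread, which only show the mul! and copytri! steps.

```julia
using LinearAlgebra

# Sketch of a reverse-mode rule for A = U'U (real case only).
# `cholesky_pullback_sketch` is a hypothetical name, not the PR's function.
function cholesky_pullback_sketch(U::UpperTriangular, Ū::AbstractMatrix)
    P = Matrix(Ū * U')                      # only the upper triangle of Ū U' is used
    LinearAlgebra.copytri!(P, 'U', true)    # mirror it into the lower triangle
    return (U \ P / U') ./ 2                # two triangular solves, then halve
end

A = [4.0 2.0; 2.0 3.0]
W = [1.0 2.0; 0.0 3.0]                      # arbitrary upper-triangular cotangent for U
f(M) = sum(W .* cholesky(Symmetric(M)).U)   # scalar test function of the factor

Ā = cholesky_pullback_sketch(cholesky(A).U, W)

# Compare against a central finite difference along a symmetric direction.
dA = [1.0 0.5; 0.5 -1.0]
h = 1e-6
fd = (f(A + h * dA) - f(A - h * dA)) / (2h)
ad = sum(Ā .* dA)
```

The sketch returns a dense symmetric matrix. Whether that result is handed back as a Symmetric / Hermitian wrapper or as a StridedMatrix is the cotangent-type question the description above is concerned with.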