Revamp Cholesky implementation #311

Merged on Dec 9, 2020 (23 commits)

willtebbutt
Member

The current Cholesky implementation isn't fantastic, in that we're choosing fairly arbitrary blocksizes based on code that I hacked together almost 4 years ago.

This is probably a better approach (due to Seeger). It basically copies what Zygote currently does, with one important improvement: the cotangent w.r.t. the matrix A whose Cholesky is taken now has a type that matches A. It is a Symmetric / Hermitian if A is a Symmetric / Hermitian, and a StridedMatrix if A is a StridedMatrix. As shown in the additional tests, this is needed to ensure that the composition cholesky(Symmetric(X::StridedMatrix)) gives the same answer as cholesky(X::StridedMatrix).

I'll remove Zygote's implementation once this is in.
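
For reference, the reverse-mode rule under discussion can be sketched in a few lines of plain Julia. This is not the code in the PR: `cholesky_pullback_sketch` is a hypothetical name, it uses generic triangular solves rather than BLAS calls, and it assumes a real input factorised as A = U'U. Under those assumptions the symmetric cotangent is ½ U⁻¹ Q U⁻ᵀ, where Q mirrors the upper triangle of Ū Uᵀ into the lower one. The matrices are arbitrary illustrative values.

```julia
using LinearAlgebra

# Hypothetical helper, not the PR's implementation: symmetric cotangent of A
# for a real factorisation A = U'U, given the cotangent Ū of the factor U.
function cholesky_pullback_sketch(U::UpperTriangular, Ū::AbstractMatrix)
    P = Ū * U'                      # Ū Uᵀ; only its upper triangle matters
    Q = Matrix(Symmetric(P, :U))    # mirror the upper triangle into the lower one
    return (U \ Q) / U' / 2         # ½ U⁻¹ Q U⁻ᵀ
end

A = [4.0 1.0 0.5; 1.0 3.0 0.2; 0.5 0.2 2.0]    # symmetric positive definite
W = [1.0 2.0 3.0; 0.0 4.0 5.0; 0.0 0.0 6.0]    # arbitrary weights, used as Ū
D = [0.3 0.1 -0.2; 0.1 0.5 0.4; -0.2 0.4 0.7]  # symmetric perturbation direction

# Scalar test function of A whose gradient w.r.t. U is W.
f(M) = sum(W .* Matrix(cholesky(Symmetric(M)).U))

Ā = cholesky_pullback_sketch(cholesky(Symmetric(A)).U, W)

# Central finite difference of f along D, to compare against ⟨Ā, D⟩.
ε = 1e-6
fd = (f(A + ε * D) - f(A - ε * D)) / (2ε)
```

If the sketch is right, `dot(Ā, D)` should agree with `fd` up to finite-difference error for any symmetric direction `D`.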

@willtebbutt willtebbutt requested a review from sethaxen November 19, 2020 23:03
Member

@sethaxen sethaxen left a comment

This rule looks more sane than the previous one. Had a few suggestions, the most substantial being tweaks to also support complex Hermitian matrices.

This PR does two different things: 1) implements an improved cholesky rule for a subset of matrices (also adding support for keyword arguments!) and 2) removes a generic rule that would be applied to all matrices. I wonder if it is better to split those two things into separate PRs. The 2nd one is non-breaking, but the 1st one could potentially be a breaking change if users decorated some custom matrix type with Symmetric.

It seems like it would be straightforward to add the frules for cholesky too. Do you think that's worth adding to this PR, or do you think ADing directly through cholesky with forward mode would be more performant?

@nmheim it would be great if you could look at this too, since this solves some of the issues raised in TuringLang/DistributionsAD.jl#95, and it's also removing the internal function that DistributionsAD.jl relies on (but shouldn't).

src/rulesets/LinearAlgebra/factorization.jl
function _cholesky_pullback_shared_code(C, Δ)
    issuccess(C) || throw(PosDefException(C.info))
    U = C.U
    Ū = Δ.U
Member

What about the case where the user accessed L instead of U, or both of them?

Member Author

I believe this is handled by getproperty, which transposes as appropriate for the uplo field of the Cholesky. Since the user-facing interface no longer provides an option to set uplo, I think we might be fine. Does this seem reasonable to you?
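
A quick illustration of the behaviour referred to here, using only the LinearAlgebra standard library: a `Cholesky` object exposes both `U` and `L` through `getproperty`, whichever triangle the factorisation was computed from. The matrix is an arbitrary example, and the snippet says nothing about how the rule itself treats the two properties.

```julia
using LinearAlgebra

A = [4.0 2.0; 2.0 3.0]               # arbitrary symmetric positive definite matrix

C_up = cholesky(Symmetric(A, :U))    # factorisation computed from the upper triangle
C_lo = cholesky(Symmetric(A, :L))    # factorisation computed from the lower triangle

# Both properties are available on both objects and agree with each other.
same_U = C_up.U ≈ C_lo.U
same_L = C_up.L ≈ C_lo.L
adjoint_pair = C_lo.L ≈ C_lo.U'
```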

@willtebbutt
Member Author

willtebbutt commented Nov 20, 2020

This PR now depends on this FiniteDifferences PR.

@sethaxen thanks for the careful review. I've adopted most of your recommendations, but I'm reluctant to extend the implementation to Complex-valued things in this PR, mainly because testing it will be a bit awkward and I don't want to get it wrong. Also I was planning to limit the scope of this PR to maintaining / correcting existing functionality present in ChainRules / Zygote, rather than extending it.

> It seems like it would be straightforward to add the frules for cholesky too. Do you think that's worth adding to this PR, or do you think ADing directly through cholesky with forward mode would be more performant?

This would be a good thing to do, but my preference would again be to leave it for later.

> This PR does two different things: 1) implements an improved cholesky rule for a subset of matrices (also adding support for keyword arguments!) and 2) removes a generic rule that would be applied to all matrices. I wonder if it is better to split those two things into separate PRs. The 2nd one is non-breaking, but the 1st one could potentially be a breaking change if users decorated some custom matrix type with Symmetric.

Hmmm good point. I wonder whether it would have worked for custom matrices before though? For example, I think that the previous implementation probably implicitly restricted users to StridedMatrixs and their HermOrSym counterparts anyway because of the BLAS calls. Does this address your concerns, or do you think there's a possibility that I'm missing?

edit: no longer depends on the FiniteDifferences PR!

Member

@sethaxen sethaxen left a comment

Just a few more minor suggestions.

> I'm reluctant to extend the implementation to Complex-valued things in this PR, mainly because testing it will be a bit awkward and I don't want to get it wrong. Also I was planning to limit the scope of this PR to maintaining / correcting existing functionality present in ChainRules / Zygote, rather than extending it.

Technically Zygote's rules catch complex StridedMatrixes. It looks like Zygote's rules were never tested for complex numbers, and I can guarantee they are not right, so I agree it is fine to only support real in this PR.

> Hmmm good point. I wonder whether it would have worked for custom matrices before though? For example, I think that the previous implementation probably implicitly restricted users to StridedMatrixs and their HermOrSym counterparts anyway because of the BLAS calls. Does this address your concerns, or do you think there's a possibility that I'm missing?

It looks like the old rules supported generic matrices by converting them to Matrixes before passing to the function with the BLAS calls. So it looks like it would have worked for a custom matrix before, so long as the number type was a BlasFloat.

In the past, when restricting types of rules, have we marked those as breaking changes? If so, we should do the same here.

chol_blocked_rev(Ȳ.L, F.L, 25, false)
end
return (NO_FIELDS, ∂X)
function rrule(::typeof(cholesky), A::Real, uplo::Symbol=:U)
Member

This should be safe.

Suggested change
- function rrule(::typeof(cholesky), A::Real, uplo::Symbol=:U)
+ function rrule(::typeof(cholesky), A::Real, uplo::Symbol)

function rrule(
    ::typeof(cholesky),
    A::StridedMatrix{<:LinearAlgebra.BlasReal},
    ::Val{false}=Val(false);
Member

Suggested change
- ::Val{false}=Val(false);
+ ::Val{false};

end

function _cholesky_pullback_shared_code(C, ΔC)
    issuccess(C) || throw(PosDefException(C.info))
Member

As with diagonal, could you move this out of the shared function and then only throw this error if the user did not specify check = false?
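
To make the request concrete: `check = false` tells `cholesky` to return a factorisation object even when the factorisation fails, leaving it to the caller to inspect `issuccess`. A minimal illustration using only LinearAlgebra, with an arbitrary indefinite matrix (this shows the behaviour of `cholesky` itself, not of the rule):

```julia
using LinearAlgebra

A = [1.0 2.0; 2.0 1.0]    # symmetric but indefinite (eigenvalues 3 and -1)

# With the default check = true, the failed factorisation throws.
threw = try
    cholesky(A)
    false
catch err
    err isa PosDefException
end

# With check = false no error is thrown; the failure is recorded on the object.
C = cholesky(A; check=false)
failed = !issuccess(C)
```

A pullback that throws unconditionally on `!issuccess(C)` would therefore turn a failure the user explicitly chose to tolerate back into an error.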

    Ū = ΔC.U
    Ā = similar(U.data)
    Ā = mul!(Ā, Ū, U')
    Ā = LinearAlgebra.copytri!(Ā, 'U', true)
Member

I know complex matrices aren't officially supported in this PR, but I tested locally that, when the type constraints are relaxed, this last fix makes the rule work for complex Hermitian matrices.

Suggested change
- Ā = LinearAlgebra.copytri!(Ā, 'U', true)
+ Ā = LinearAlgebra.copytri!(Ā, 'U', true)
+ idx = diagind(Ā)
+ @views Ā[idx] .= real(Ā[idx])

Member

Completely optional though, because some of the other rrules wouldn't support complex either.
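
The effect of the suggested diagonal fix can be seen on a small made-up matrix. Note that `LinearAlgebra.copytri!` is an internal, unexported function (the same one the reviewed code calls), so its behaviour here is as observed rather than documented: with conjugation enabled it mirrors the upper triangle into the lower one and leaves the diagonal alone, so a complex diagonal keeps the result from being Hermitian until its imaginary parts are dropped.

```julia
using LinearAlgebra

Ā = ComplexF64[1+2im 3+4im; 0 5-1im]    # made-up values with a complex diagonal

LinearAlgebra.copytri!(Ā, 'U', true)    # lower triangle := conj(upper triangle)
hermitian_before = ishermitian(Ā)       # diagonal still has imaginary parts

idx = diagind(Ā)
@views Ā[idx] .= real(Ā[idx])           # drop the imaginary parts on the diagonal
hermitian_after = ishermitian(Ā)
```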

Member

@sethaxen sethaxen left a comment

Feel free to merge whenever you're happy with version number and such. Thanks!

@willtebbutt
Member Author

Thanks for the feedback @sethaxen . I'm going to stop development on this branch until CI is back up and running, then will process your feedback.

@codecov-io

codecov-io commented Dec 9, 2020

Codecov Report

Merging #311 (e98519b) into master (2dab3ab) will decrease coverage by 0.02%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master     #311      +/-   ##
==========================================
- Coverage   97.46%   97.44%   -0.03%     
==========================================
  Files          18       18              
  Lines         988      939      -49     
==========================================
- Hits          963      915      -48     
+ Misses         25       24       -1     
Impacted Files                                 Coverage Δ
src/rulesets/LinearAlgebra/factorization.jl    98.73% <100.00%> (+0.29%) ⬆️

Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 2dab3ab...69b4c5c.

@willtebbutt willtebbutt merged commit 80443b1 into master Dec 9, 2020
@willtebbutt willtebbutt deleted the wct/cholesky-improvements branch December 9, 2020 21:46