-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Revamp Cholesky implementation #311
Changes from 20 commits
a4b35da
4f15002
eed8ae2
86fd8dd
c6a7237
ce92c26
9f318ae
14ac3fd
6149066
f2fa1ce
026e01d
20a6fc7
e619576
83e44cf
eb6e2a8
d20e7e2
405ebb4
a3a7bdc
b396e50
fbeade4
32a42dd
e98519b
69b4c5c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -70,20 +70,73 @@ end | |||||||||
##### `cholesky` | ||||||||||
##### | ||||||||||
|
||||||||||
function rrule(::typeof(cholesky), X::AbstractMatrix{<:Real}) | ||||||||||
F = cholesky(X) | ||||||||||
function cholesky_pullback(Ȳ::Composite) | ||||||||||
∂X = if F.uplo === 'U' | ||||||||||
chol_blocked_rev(Ȳ.U, F.U, 25, true) | ||||||||||
else | ||||||||||
chol_blocked_rev(Ȳ.L, F.L, 25, false) | ||||||||||
end | ||||||||||
return (NO_FIELDS, ∂X) | ||||||||||
function rrule(::typeof(cholesky), A::Real, uplo::Symbol=:U) | ||||||||||
C = cholesky(A, uplo) | ||||||||||
function cholesky_pullback(ΔC::Composite) | ||||||||||
return NO_FIELDS, ΔC.factors[1, 1] / (2 * C.U[1, 1]), DoesNotExist() | ||||||||||
end | ||||||||||
return F, cholesky_pullback | ||||||||||
return C, cholesky_pullback | ||||||||||
end | ||||||||||
|
||||||||||
function rrule( | ||||||||||
::typeof(cholesky), A::Diagonal{<:Real}, ::Val{false}=Val(false); check::Bool=true, | ||||||||||
willtebbutt marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
) | ||||||||||
willtebbutt marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
C = cholesky(A, Val(false); check=check) | ||||||||||
function cholesky_pullback(ΔC::Composite) | ||||||||||
check && !issuccess(C) && throw(PosDefException(C.info)) | ||||||||||
Ā = Diagonal(diag(ΔC.factors) .* inv.(2 .* C.factors.diag)) | ||||||||||
return NO_FIELDS, Ā, DoesNotExist() | ||||||||||
end | ||||||||||
return C, cholesky_pullback | ||||||||||
end | ||||||||||
|
||||||||||
# The appropriate cotangent is different depending upon whether A is Symmetric / Hermitian, | ||||||||||
# or just a StridedMatrix. | ||||||||||
# Implementation due to Seeger, Matthias, et al. "Auto-differentiating linear algebra." | ||||||||||
function rrule( | ||||||||||
::typeof(cholesky), | ||||||||||
A::LinearAlgebra.HermOrSym{<:LinearAlgebra.BlasReal, <:StridedMatrix}, | ||||||||||
::Val{false}; | ||||||||||
check::Bool=true, | ||||||||||
) | ||||||||||
C = cholesky(A, Val(false); check=check) | ||||||||||
function cholesky_pullback(ΔC::Composite) | ||||||||||
Ā, U = _cholesky_pullback_shared_code(C, ΔC) | ||||||||||
Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)) / 2, U.data, Ā) | ||||||||||
return NO_FIELDS, _symhermtype(A)(Ā), DoesNotExist() | ||||||||||
end | ||||||||||
return C, cholesky_pullback | ||||||||||
end | ||||||||||
|
||||||||||
function rrule( | ||||||||||
::typeof(cholesky), | ||||||||||
A::StridedMatrix{<:LinearAlgebra.BlasReal}, | ||||||||||
::Val{false}=Val(false); | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
check::Bool=true, | ||||||||||
) | ||||||||||
C = cholesky(A, Val(false); check=check) | ||||||||||
function cholesky_pullback(ΔC::Composite) | ||||||||||
Ā, U = _cholesky_pullback_shared_code(C, ΔC) | ||||||||||
Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)), U.data, Ā) | ||||||||||
idx = diagind(Ā) | ||||||||||
@views Ā[idx] .= real.(Ā[idx]) ./ 2 | ||||||||||
return (NO_FIELDS, UpperTriangular(Ā), DoesNotExist()) | ||||||||||
end | ||||||||||
return C, cholesky_pullback | ||||||||||
end | ||||||||||
|
||||||||||
function _cholesky_pullback_shared_code(C, ΔC) | ||||||||||
issuccess(C) || throw(PosDefException(C.info)) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As with diagonal, could you move this out of the shared function and then only throw this error if the user did not specify |
||||||||||
U = C.U | ||||||||||
Ū = ΔC.U | ||||||||||
Ā = similar(U.data) | ||||||||||
Ā = mul!(Ā, Ū, U') | ||||||||||
Ā = LinearAlgebra.copytri!(Ā, 'U', true) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know complex matrices aren't officially supported in this PR, but I tested locally that this last fix makes them work for me for complex Hermitian matrices, for when the type constraints are relaxed.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Completely optional though, because some of the other |
||||||||||
Ā = ldiv!(U, Ā) | ||||||||||
return Ā, U | ||||||||||
end | ||||||||||
|
||||||||||
function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: Cholesky | ||||||||||
function rrule(::typeof(getproperty), F::T, x::Symbol) where {T <: Cholesky} | ||||||||||
function getproperty_cholesky_pullback(Ȳ) | ||||||||||
C = Composite{T} | ||||||||||
∂F = if x === :U | ||||||||||
|
@@ -103,161 +156,3 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: Cholesky | |||||||||
end | ||||||||||
return getproperty(F, x), getproperty_cholesky_pullback | ||||||||||
end | ||||||||||
|
||||||||||
# See "Differentiation of the Cholesky decomposition" (Murray 2016), pages 5-9 in particular, | ||||||||||
# for derivations. Here we're implementing the algorithms and their transposes. | ||||||||||
|
||||||||||
""" | ||||||||||
level2partition(A::AbstractMatrix, j::Integer, upper::Bool) | ||||||||||
|
||||||||||
Returns views to various bits of the lower triangle of `A` according to the | ||||||||||
`level2partition` procedure defined in [1] if `upper` is `false`. If `upper` is `true` then | ||||||||||
the transposed views are returned from the upper triangle of `A`. | ||||||||||
|
||||||||||
[1]: "Differentiation of the Cholesky decomposition", Murray 2016 | ||||||||||
""" | ||||||||||
function level2partition(A::AbstractMatrix, j::Integer, upper::Bool) | ||||||||||
n = checksquare(A) | ||||||||||
@boundscheck checkbounds(1:n, j) | ||||||||||
if upper | ||||||||||
r = view(A, 1:j-1, j) | ||||||||||
d = view(A, j, j) | ||||||||||
B = view(A, 1:j-1, j+1:n) | ||||||||||
c = view(A, j, j+1:n) | ||||||||||
else | ||||||||||
r = view(A, j, 1:j-1) | ||||||||||
d = view(A, j, j) | ||||||||||
B = view(A, j+1:n, 1:j-1) | ||||||||||
c = view(A, j+1:n, j) | ||||||||||
end | ||||||||||
return r, d, B, c | ||||||||||
end | ||||||||||
|
||||||||||
""" | ||||||||||
level3partition(A::AbstractMatrix, j::Integer, k::Integer, upper::Bool) | ||||||||||
|
||||||||||
Returns views to various bits of the lower triangle of `A` according to the | ||||||||||
`level3partition` procedure defined in [1] if `upper` is `false`. If `upper` is `true` then | ||||||||||
the transposed views are returned from the upper triangle of `A`. | ||||||||||
|
||||||||||
[1]: "Differentiation of the Cholesky decomposition", Murray 2016 | ||||||||||
""" | ||||||||||
function level3partition(A::AbstractMatrix, j::Integer, k::Integer, upper::Bool) | ||||||||||
n = checksquare(A) | ||||||||||
@boundscheck checkbounds(1:n, j) | ||||||||||
if upper | ||||||||||
R = view(A, 1:j-1, j:k) | ||||||||||
D = view(A, j:k, j:k) | ||||||||||
B = view(A, 1:j-1, k+1:n) | ||||||||||
C = view(A, j:k, k+1:n) | ||||||||||
else | ||||||||||
R = view(A, j:k, 1:j-1) | ||||||||||
D = view(A, j:k, j:k) | ||||||||||
B = view(A, k+1:n, 1:j-1) | ||||||||||
C = view(A, k+1:n, j:k) | ||||||||||
end | ||||||||||
return R, D, B, C | ||||||||||
end | ||||||||||
|
||||||||||
""" | ||||||||||
chol_unblocked_rev!(Ā::AbstractMatrix, L::AbstractMatrix, upper::Bool) | ||||||||||
|
||||||||||
Compute the reverse-mode sensitivities of the Cholesky factorization in an unblocked manner. | ||||||||||
If `upper` is `false`, then the sensitivites are computed from and stored in the lower triangle | ||||||||||
of `Ā` and `L` respectively. If `upper` is `true` then they are computed and stored in the | ||||||||||
upper triangles. If at input `upper` is `false` and `tril(Ā) = L̄`, at output | ||||||||||
`tril(Ā) = tril(Σ̄)`, where `Σ = LLᵀ`. Analogously, if at input `upper` is `true` and | ||||||||||
`triu(Ā) = triu(Ū)`, at output `triu(Ā) = triu(Σ̄)` where `Σ = UᵀU`. | ||||||||||
""" | ||||||||||
function chol_unblocked_rev!(Σ̄::AbstractMatrix{T}, L::AbstractMatrix{T}, upper::Bool) where T<:Real | ||||||||||
n = checksquare(Σ̄) | ||||||||||
j = n | ||||||||||
@inbounds for _ in 1:n | ||||||||||
r, d, B, c = level2partition(L, j, upper) | ||||||||||
r̄, d̄, B̄, c̄ = level2partition(Σ̄, j, upper) | ||||||||||
|
||||||||||
# d̄ <- d̄ - c'c̄ / d. | ||||||||||
d̄[1] -= dot(c, c̄) / d[1] | ||||||||||
|
||||||||||
# [d̄ c̄'] <- [d̄ c̄'] / d. | ||||||||||
d̄ ./= d | ||||||||||
c̄ ./= d | ||||||||||
|
||||||||||
# r̄ <- r̄ - [d̄ c̄'] [r' B']'. | ||||||||||
r̄ = axpy!(-Σ̄[j,j], r, r̄) | ||||||||||
r̄ = gemv!(upper ? 'n' : 'T', -one(T), B, c̄, one(T), r̄) | ||||||||||
|
||||||||||
# B̄ <- B̄ - c̄ r. | ||||||||||
B̄ = upper ? ger!(-one(T), r, c̄, B̄) : ger!(-one(T), c̄, r, B̄) | ||||||||||
d̄ ./= 2 | ||||||||||
j -= 1 | ||||||||||
end | ||||||||||
return (upper ? triu! : tril!)(Σ̄) | ||||||||||
end | ||||||||||
|
||||||||||
function chol_unblocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, upper::Bool) | ||||||||||
return chol_unblocked_rev!(copy(Σ̄), L, upper) | ||||||||||
end | ||||||||||
|
||||||||||
""" | ||||||||||
chol_blocked_rev!(Σ̄::StridedMatrix, L::StridedMatrix, nb::Integer, upper::Bool) | ||||||||||
|
||||||||||
Compute the sensitivities of the Cholesky factorization using a blocked, cache-friendly | ||||||||||
procedure. `Σ̄` are the sensitivities of `L`, and will be transformed into the sensitivities | ||||||||||
of `Σ`, where `Σ = LLᵀ`. `nb` is the block size to use. If the upper triangle has been used | ||||||||||
to represent the factorization, that is `Σ = UᵀU` where `U := Lᵀ`, then this should be | ||||||||||
indicated by passing `upper = true`. | ||||||||||
""" | ||||||||||
function chol_blocked_rev!(Σ̄::StridedMatrix{T}, L::StridedMatrix{T}, nb::Integer, upper::Bool) where T<:Real | ||||||||||
n = checksquare(Σ̄) | ||||||||||
tmp = Matrix{T}(undef, nb, nb) | ||||||||||
k = n | ||||||||||
if upper | ||||||||||
@inbounds for _ in 1:nb:n | ||||||||||
j = max(1, k - nb + 1) | ||||||||||
R, D, B, C = level3partition(L, j, k, true) | ||||||||||
R̄, D̄, B̄, C̄ = level3partition(Σ̄, j, k, true) | ||||||||||
|
||||||||||
C̄ = trsm!('L', 'U', 'N', 'N', one(T), D, C̄) | ||||||||||
gemm!('N', 'N', -one(T), R, C̄, one(T), B̄) | ||||||||||
gemm!('N', 'T', -one(T), C, C̄, one(T), D̄) | ||||||||||
chol_unblocked_rev!(D̄, D, true) | ||||||||||
gemm!('N', 'T', -one(T), B, C̄, one(T), R̄) | ||||||||||
if size(D̄, 1) == nb | ||||||||||
tmp = axpy!(one(T), D̄, transpose!(tmp, D̄)) | ||||||||||
gemm!('N', 'N', -one(T), R, tmp, one(T), R̄) | ||||||||||
else | ||||||||||
gemm!('N', 'N', -one(T), R, D̄ + D̄', one(T), R̄) | ||||||||||
end | ||||||||||
|
||||||||||
k -= nb | ||||||||||
end | ||||||||||
return triu!(Σ̄) | ||||||||||
else | ||||||||||
@inbounds for _ in 1:nb:n | ||||||||||
j = max(1, k - nb + 1) | ||||||||||
R, D, B, C = level3partition(L, j, k, false) | ||||||||||
R̄, D̄, B̄, C̄ = level3partition(Σ̄, j, k, false) | ||||||||||
|
||||||||||
C̄ = trsm!('R', 'L', 'N', 'N', one(T), D, C̄) | ||||||||||
gemm!('N', 'N', -one(T), C̄, R, one(T), B̄) | ||||||||||
gemm!('T', 'N', -one(T), C̄, C, one(T), D̄) | ||||||||||
chol_unblocked_rev!(D̄, D, false) | ||||||||||
gemm!('T', 'N', -one(T), C̄, B, one(T), R̄) | ||||||||||
if size(D̄, 1) == nb | ||||||||||
tmp = axpy!(one(T), D̄, transpose!(tmp, D̄)) | ||||||||||
gemm!('N', 'N', -one(T), tmp, R, one(T), R̄) | ||||||||||
else | ||||||||||
gemm!('N', 'N', -one(T), D̄ + D̄', R, one(T), R̄) | ||||||||||
end | ||||||||||
|
||||||||||
k -= nb | ||||||||||
end | ||||||||||
return tril!(Σ̄) | ||||||||||
end | ||||||||||
end | ||||||||||
|
||||||||||
function chol_blocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, nb::Integer, upper::Bool) | ||||||||||
# Convert to `Matrix`s because blas functions require StridedMatrix input. | ||||||||||
return chol_blocked_rev!(Matrix(Σ̄), Matrix(L), nb, upper) | ||||||||||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be safe.