Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revamp Cholesky implementation #311

Merged
merged 23 commits into from
Dec 9, 2020
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.33"
version = "0.7.34"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
233 changes: 64 additions & 169 deletions src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,73 @@ end
##### `cholesky`
#####

function rrule(::typeof(cholesky), X::AbstractMatrix{<:Real})
F = cholesky(X)
function cholesky_pullback(Ȳ::Composite)
∂X = if F.uplo === 'U'
chol_blocked_rev(Ȳ.U, F.U, 25, true)
else
chol_blocked_rev(Ȳ.L, F.L, 25, false)
end
return (NO_FIELDS, ∂X)
function rrule(::typeof(cholesky), A::Real, uplo::Symbol=:U)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be safe.

Suggested change
function rrule(::typeof(cholesky), A::Real, uplo::Symbol=:U)
function rrule(::typeof(cholesky), A::Real, uplo::Symbol)

C = cholesky(A, uplo)
function cholesky_pullback(ΔC::Composite)
return NO_FIELDS, ΔC.factors[1, 1] / (2 * C.U[1, 1]), DoesNotExist()
end
return F, cholesky_pullback
return C, cholesky_pullback
end

function rrule(
::typeof(cholesky), A::Diagonal{<:Real}, ::Val{false}=Val(false); check::Bool=true,
willtebbutt marked this conversation as resolved.
Show resolved Hide resolved
)
willtebbutt marked this conversation as resolved.
Show resolved Hide resolved
C = cholesky(A, Val(false); check=check)
function cholesky_pullback(ΔC::Composite)
check && !issuccess(C) && throw(PosDefException(C.info))
Ā = Diagonal(diag(ΔC.factors) .* inv.(2 .* C.factors.diag))
return NO_FIELDS, Ā, DoesNotExist()
end
return C, cholesky_pullback
end

# The appropriate cotangent is different depending upon whether A is Symmetric / Hermitian,
# or just a StridedMatrix.
# Implementation due to Seeger, Matthias, et al. "Auto-differentiating linear algebra."
function rrule(
::typeof(cholesky),
A::LinearAlgebra.HermOrSym{<:LinearAlgebra.BlasReal, <:StridedMatrix},
::Val{false};
check::Bool=true,
)
C = cholesky(A, Val(false); check=check)
function cholesky_pullback(ΔC::Composite)
Ā, U = _cholesky_pullback_shared_code(C, ΔC)
Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)) / 2, U.data, Ā)
return NO_FIELDS, _symhermtype(A)(Ā), DoesNotExist()
end
return C, cholesky_pullback
end

function rrule(
::typeof(cholesky),
A::StridedMatrix{<:LinearAlgebra.BlasReal},
::Val{false}=Val(false);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
::Val{false}=Val(false);
::Val{false};

check::Bool=true,
)
C = cholesky(A, Val(false); check=check)
function cholesky_pullback(ΔC::Composite)
Ā, U = _cholesky_pullback_shared_code(C, ΔC)
Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)), U.data, Ā)
idx = diagind(Ā)
@views Ā[idx] .= real.(Ā[idx]) ./ 2
return (NO_FIELDS, UpperTriangular(Ā), DoesNotExist())
end
return C, cholesky_pullback
end

function _cholesky_pullback_shared_code(C, ΔC)
issuccess(C) || throw(PosDefException(C.info))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As with diagonal, could you move this out of the shared function and then only throw this error if the user did not specify check = false?

U = C.U
Ū = ΔC.U
Ā = similar(U.data)
Ā = mul!(Ā, Ū, U')
Ā = LinearAlgebra.copytri!(Ā, 'U', true)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know complex matrices aren't officially supported in this PR, but I tested locally that this last fix makes them work for me for complex Hermitian matrices, for when the type constraints are relaxed.

Suggested change
= LinearAlgebra.copytri!(Ā, 'U', true)
= LinearAlgebra.copytri!(Ā, 'U', true)
idx = diagind(Ā)
@views Ā[idx] .= real(Ā[idx])

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Completely optional though, because some of the other rrules wouldn't support complex either.

Ā = ldiv!(U, Ā)
return Ā, U
end

function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: Cholesky
function rrule(::typeof(getproperty), F::T, x::Symbol) where {T <: Cholesky}
function getproperty_cholesky_pullback(Ȳ)
C = Composite{T}
∂F = if x === :U
Expand All @@ -103,161 +156,3 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: Cholesky
end
return getproperty(F, x), getproperty_cholesky_pullback
end

# See "Differentiation of the Cholesky decomposition" (Murray 2016), pages 5-9 in particular,
# for derivations. Here we're implementing the algorithms and their transposes.

"""
level2partition(A::AbstractMatrix, j::Integer, upper::Bool)

Returns views to various bits of the lower triangle of `A` according to the
`level2partition` procedure defined in [1] if `upper` is `false`. If `upper` is `true` then
the transposed views are returned from the upper triangle of `A`.

[1]: "Differentiation of the Cholesky decomposition", Murray 2016
"""
function level2partition(A::AbstractMatrix, j::Integer, upper::Bool)
n = checksquare(A)
@boundscheck checkbounds(1:n, j)
if upper
r = view(A, 1:j-1, j)
d = view(A, j, j)
B = view(A, 1:j-1, j+1:n)
c = view(A, j, j+1:n)
else
r = view(A, j, 1:j-1)
d = view(A, j, j)
B = view(A, j+1:n, 1:j-1)
c = view(A, j+1:n, j)
end
return r, d, B, c
end

"""
level3partition(A::AbstractMatrix, j::Integer, k::Integer, upper::Bool)

Returns views to various bits of the lower triangle of `A` according to the
`level3partition` procedure defined in [1] if `upper` is `false`. If `upper` is `true` then
the transposed views are returned from the upper triangle of `A`.

[1]: "Differentiation of the Cholesky decomposition", Murray 2016
"""
function level3partition(A::AbstractMatrix, j::Integer, k::Integer, upper::Bool)
n = checksquare(A)
@boundscheck checkbounds(1:n, j)
if upper
R = view(A, 1:j-1, j:k)
D = view(A, j:k, j:k)
B = view(A, 1:j-1, k+1:n)
C = view(A, j:k, k+1:n)
else
R = view(A, j:k, 1:j-1)
D = view(A, j:k, j:k)
B = view(A, k+1:n, 1:j-1)
C = view(A, k+1:n, j:k)
end
return R, D, B, C
end

"""
chol_unblocked_rev!(Ā::AbstractMatrix, L::AbstractMatrix, upper::Bool)

Compute the reverse-mode sensitivities of the Cholesky factorization in an unblocked manner.
If `upper` is `false`, then the sensitivites are computed from and stored in the lower triangle
of `Ā` and `L` respectively. If `upper` is `true` then they are computed and stored in the
upper triangles. If at input `upper` is `false` and `tril(Ā) = L̄`, at output
`tril(Ā) = tril(Σ̄)`, where `Σ = LLᵀ`. Analogously, if at input `upper` is `true` and
`triu(Ā) = triu(Ū)`, at output `triu(Ā) = triu(Σ̄)` where `Σ = UᵀU`.
"""
function chol_unblocked_rev!(Σ̄::AbstractMatrix{T}, L::AbstractMatrix{T}, upper::Bool) where T<:Real
n = checksquare(Σ̄)
j = n
@inbounds for _ in 1:n
r, d, B, c = level2partition(L, j, upper)
r̄, d̄, B̄, c̄ = level2partition(Σ̄, j, upper)

# d̄ <- d̄ - c'c̄ / d.
d̄[1] -= dot(c, c̄) / d[1]

# [d̄ c̄'] <- [d̄ c̄'] / d.
d̄ ./= d
c̄ ./= d

# r̄ <- r̄ - [d̄ c̄'] [r' B']'.
r̄ = axpy!(-Σ̄[j,j], r, r̄)
r̄ = gemv!(upper ? 'n' : 'T', -one(T), B, c̄, one(T), r̄)

# B̄ <- B̄ - c̄ r.
B̄ = upper ? ger!(-one(T), r, c̄, B̄) : ger!(-one(T), c̄, r, B̄)
d̄ ./= 2
j -= 1
end
return (upper ? triu! : tril!)(Σ̄)
end

function chol_unblocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, upper::Bool)
return chol_unblocked_rev!(copy(Σ̄), L, upper)
end

"""
chol_blocked_rev!(Σ̄::StridedMatrix, L::StridedMatrix, nb::Integer, upper::Bool)

Compute the sensitivities of the Cholesky factorization using a blocked, cache-friendly
procedure. `Σ̄` are the sensitivities of `L`, and will be transformed into the sensitivities
of `Σ`, where `Σ = LLᵀ`. `nb` is the block size to use. If the upper triangle has been used
to represent the factorization, that is `Σ = UᵀU` where `U := Lᵀ`, then this should be
indicated by passing `upper = true`.
"""
function chol_blocked_rev!(Σ̄::StridedMatrix{T}, L::StridedMatrix{T}, nb::Integer, upper::Bool) where T<:Real
n = checksquare(Σ̄)
tmp = Matrix{T}(undef, nb, nb)
k = n
if upper
@inbounds for _ in 1:nb:n
j = max(1, k - nb + 1)
R, D, B, C = level3partition(L, j, k, true)
R̄, D̄, B̄, C̄ = level3partition(Σ̄, j, k, true)

C̄ = trsm!('L', 'U', 'N', 'N', one(T), D, C̄)
gemm!('N', 'N', -one(T), R, C̄, one(T), B̄)
gemm!('N', 'T', -one(T), C, C̄, one(T), D̄)
chol_unblocked_rev!(D̄, D, true)
gemm!('N', 'T', -one(T), B, C̄, one(T), R̄)
if size(D̄, 1) == nb
tmp = axpy!(one(T), D̄, transpose!(tmp, D̄))
gemm!('N', 'N', -one(T), R, tmp, one(T), R̄)
else
gemm!('N', 'N', -one(T), R, D̄ + D̄', one(T), R̄)
end

k -= nb
end
return triu!(Σ̄)
else
@inbounds for _ in 1:nb:n
j = max(1, k - nb + 1)
R, D, B, C = level3partition(L, j, k, false)
R̄, D̄, B̄, C̄ = level3partition(Σ̄, j, k, false)

C̄ = trsm!('R', 'L', 'N', 'N', one(T), D, C̄)
gemm!('N', 'N', -one(T), C̄, R, one(T), B̄)
gemm!('T', 'N', -one(T), C̄, C, one(T), D̄)
chol_unblocked_rev!(D̄, D, false)
gemm!('T', 'N', -one(T), C̄, B, one(T), R̄)
if size(D̄, 1) == nb
tmp = axpy!(one(T), D̄, transpose!(tmp, D̄))
gemm!('N', 'N', -one(T), tmp, R, one(T), R̄)
else
gemm!('N', 'N', -one(T), D̄ + D̄', R, one(T), R̄)
end

k -= nb
end
return tril!(Σ̄)
end
end

function chol_blocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, nb::Integer, upper::Bool)
# Convert to `Matrix`s because blas functions require StridedMatrix input.
return chol_blocked_rev!(Matrix(Σ̄), Matrix(L), nb, upper)
end
112 changes: 52 additions & 60 deletions test/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblocked_rev
function FiniteDifferences.to_vec(C::Cholesky)
C_vec, factors_from_vec = to_vec(C.factors)
function cholesky_from_vec(v)
return Cholesky(factors_from_vec(v), C.uplo, C.info)
end
return C_vec, cholesky_from_vec
end

@testset "Factorizations" begin
@testset "svd" begin
Expand Down Expand Up @@ -73,69 +79,55 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo
@test ChainRules._eyesubx!(copy(X)) ≈ I - X
end
end

# These tests are generally a bit tricky to write because FiniteDifferences doesn't
# have fantastic support for this stuff at the minute.
@testset "cholesky" begin
@testset "the thing" begin
X = generate_well_conditioned_matrix(10)
V = generate_well_conditioned_matrix(10)
F, dX_pullback = rrule(cholesky, X)
for p in [:U, :L]
Y, dF_pullback = rrule(getproperty, F, p)
Ȳ = (p === :U ? UpperTriangular : LowerTriangular)(randn(size(Y)))
(dself, dF, dp) = dF_pullback(Ȳ)
@test dself === NO_FIELDS
@test dp === DoesNotExist()
@testset "Real" begin
C = cholesky(rand() + 0.1)
ΔC = Composite{typeof(C)}((factors=rand_tangent(C.factors)))
rrule_test(cholesky, ΔC, (rand() + 0.1, randn()))
end
@testset "Diagonal{<:Real}" begin
D = Diagonal(rand(5) .+ 0.1)
C = cholesky(D)
ΔC = Composite{typeof(C)}((factors=Diagonal(randn(5))))
rrule_test(cholesky, ΔC, (D, Diagonal(randn(5))))
end


# NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp`
# machinery from FiniteDifferences because that isn't set up to respect
# necessary special properties of the input. In the case of the Cholesky
# factorization, we need the input to be Hermitian.
ΔF = unthunk(dF)
_, dX = dX_pullback(ΔF)
X̄_ad = dot(unthunk(dX), V)
X̄_fd = central_fdm(5, 1)(0.000_001) do ε
dot(Ȳ, getproperty(cholesky(X .+ ε .* V), p))
end
@test X̄_ad ≈ X̄_fd rtol=1e-4
X = generate_well_conditioned_matrix(10)
V = generate_well_conditioned_matrix(10)
F, dX_pullback = rrule(cholesky, X)
@testset "uplo=$p" for p in [:U, :L]
Y, dF_pullback = rrule(getproperty, F, p)
Ȳ = (p === :U ? UpperTriangular : LowerTriangular)(randn(size(Y)))
(dself, dF, dp) = dF_pullback(Ȳ)
@test dself === NO_FIELDS
@test dp === DoesNotExist()

# NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp`
# machinery from FiniteDifferences because that isn't set up to respect
# necessary special properties of the input. In the case of the Cholesky
# factorization, we need the input to be Hermitian.
ΔF = unthunk(dF)
_, dX = dX_pullback(ΔF)
X̄_ad = dot(unthunk(dX), V)
X̄_fd = central_fdm(5, 1)(0.000_001) do ε
dot(Ȳ, getproperty(cholesky(X .+ ε .* V), p))
end
@test X̄_ad ≈ X̄_fd rtol=1e-4
end
@testset "helper functions" begin
A = randn(5, 5)
r, d, B2, c = level2partition(A, 4, false)
R, D, B3, C = level3partition(A, 4, 4, false)
@test all(r .== R')
@test all(d .== D)
@test B2[1] == B3[1]
@test all(c .== C)

# Check that level 2 partition with `upper == true` is consistent with `false`
rᵀ, dᵀ, B2ᵀ, cᵀ = level2partition(transpose(A), 4, true)
@test r == rᵀ
@test d == dᵀ
@test B2' == B2ᵀ
@test c == cᵀ

# Check that level 3 partition with `upper == true` is consistent with `false`
R, D, B3, C = level3partition(A, 2, 4, false)
Rᵀ, Dᵀ, B3ᵀ, Cᵀ = level3partition(transpose(A), 2, 4, true)
@test transpose(R) == Rᵀ
@test transpose(D) == Dᵀ
@test transpose(B3) == B3ᵀ
@test transpose(C) == Cᵀ

A = Matrix(LowerTriangular(randn(10, 10)))
Ā = Matrix(LowerTriangular(randn(10, 10)))
# NOTE: BLAS gets angry if we don't materialize the Transpose objects first
B = Matrix(transpose(A))
B̄ = Matrix(transpose(Ā))
@test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 1, false)
@test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 3, false)
@test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 5, false)
@test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 10, false)
@test chol_unblocked_rev(Ā, A, false) ≈ transpose(chol_unblocked_rev(B̄, B, true))

@test chol_unblocked_rev(B̄, B, true) ≈ chol_blocked_rev(B̄, B, 1, true)
@test chol_unblocked_rev(B̄, B, true) ≈ chol_blocked_rev(B̄, B, 5, true)
@test chol_unblocked_rev(B̄, B, true) ≈ chol_blocked_rev(B̄, B, 10, true)

# Ensure that cotangents of cholesky(::StridedMatrix) and
# (cholesky ∘ Symmetric)(::StridedMatrix) are equal.
@testset "Symmetric" begin
X_symmetric, sym_back = rrule(Symmetric, X, :U)
C, chol_back_sym = rrule(cholesky, X_symmetric)

Δ = Composite{typeof(C)}((U=UpperTriangular(randn(size(X)))))
ΔX_symmetric = chol_back_sym(Δ)[2]
@test sym_back(ΔX_symmetric)[2] ≈ dX_pullback(Δ)[2]
end
end
end