diff --git a/Project.toml b/Project.toml index 0714d9ceb..cb49e6add 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.36" +version = "0.7.37" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 6e7aa3677..10b0e24a0 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -70,20 +70,69 @@ end ##### `cholesky` ##### -function rrule(::typeof(cholesky), X::AbstractMatrix{<:Real}) - F = cholesky(X) - function cholesky_pullback(Ȳ::Composite) - ∂X = if F.uplo === 'U' - chol_blocked_rev(Ȳ.U, F.U, 25, true) - else - chol_blocked_rev(Ȳ.L, F.L, 25, false) - end - return (NO_FIELDS, ∂X) +function rrule(::typeof(cholesky), A::Real, uplo::Symbol=:U) + C = cholesky(A, uplo) + function cholesky_pullback(ΔC::Composite) + return NO_FIELDS, ΔC.factors[1, 1] / (2 * C.U[1, 1]), DoesNotExist() end - return F, cholesky_pullback + return C, cholesky_pullback +end + +function rrule(::typeof(cholesky), A::Diagonal{<:Real}, ::Val{false}; check::Bool=true) + C = cholesky(A, Val(false); check=check) + function cholesky_pullback(ΔC::Composite) + Ā = Diagonal(diag(ΔC.factors) .* inv.(2 .* C.factors.diag)) + return NO_FIELDS, Ā, DoesNotExist() + end + return C, cholesky_pullback +end + +# The appropriate cotangent is different depending upon whether A is Symmetric / Hermitian, +# or just a StridedMatrix. +# Implementation due to Seeger, Matthias, et al. "Auto-differentiating linear algebra." +function rrule( + ::typeof(cholesky), + A::LinearAlgebra.HermOrSym{<:LinearAlgebra.BlasReal, <:StridedMatrix}, + ::Val{false}; + check::Bool=true, +) + C = cholesky(A, Val(false); check=check) + function cholesky_pullback(ΔC::Composite) + Ā, U = _cholesky_pullback_shared_code(C, ΔC) + Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)) / 2, U.data, Ā) + return NO_FIELDS, _symhermtype(A)(Ā), DoesNotExist() + end + return C, cholesky_pullback +end + +function rrule( + ::typeof(cholesky), + A::StridedMatrix{<:LinearAlgebra.BlasReal}, + ::Val{false}; + check::Bool=true, +) + C = cholesky(A, Val(false); check=check) + function cholesky_pullback(ΔC::Composite) + Ā, U = _cholesky_pullback_shared_code(C, ΔC) + Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)), U.data, Ā) + idx = diagind(Ā) + @views Ā[idx] .= real.(Ā[idx]) ./ 2 + return (NO_FIELDS, UpperTriangular(Ā), DoesNotExist()) + end + return C, cholesky_pullback +end + +function _cholesky_pullback_shared_code(C, ΔC) + U = C.U + Ū = ΔC.U + Ā = similar(U.data) + Ā = mul!(Ā, Ū, U') + Ā = LinearAlgebra.copytri!(Ā, 'U', true) + Ā = ldiv!(U, Ā) + return Ā, U end -function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: Cholesky +function rrule(::typeof(getproperty), F::T, x::Symbol) where {T <: Cholesky} function getproperty_cholesky_pullback(Ȳ) C = Composite{T} ∂F = if x === :U @@ -103,161 +152,3 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: Cholesky end return getproperty(F, x), getproperty_cholesky_pullback end - -# See "Differentiation of the Cholesky decomposition" (Murray 2016), pages 5-9 in particular, -# for derivations. Here we're implementing the algorithms and their transposes. - -""" - level2partition(A::AbstractMatrix, j::Integer, upper::Bool) - -Returns views to various bits of the lower triangle of `A` according to the -`level2partition` procedure defined in [1] if `upper` is `false`. If `upper` is `true` then -the transposed views are returned from the upper triangle of `A`. - -[1]: "Differentiation of the Cholesky decomposition", Murray 2016 -""" -function level2partition(A::AbstractMatrix, j::Integer, upper::Bool) - n = checksquare(A) - @boundscheck checkbounds(1:n, j) - if upper - r = view(A, 1:j-1, j) - d = view(A, j, j) - B = view(A, 1:j-1, j+1:n) - c = view(A, j, j+1:n) - else - r = view(A, j, 1:j-1) - d = view(A, j, j) - B = view(A, j+1:n, 1:j-1) - c = view(A, j+1:n, j) - end - return r, d, B, c -end - -""" - level3partition(A::AbstractMatrix, j::Integer, k::Integer, upper::Bool) - -Returns views to various bits of the lower triangle of `A` according to the -`level3partition` procedure defined in [1] if `upper` is `false`. If `upper` is `true` then -the transposed views are returned from the upper triangle of `A`. - -[1]: "Differentiation of the Cholesky decomposition", Murray 2016 -""" -function level3partition(A::AbstractMatrix, j::Integer, k::Integer, upper::Bool) - n = checksquare(A) - @boundscheck checkbounds(1:n, j) - if upper - R = view(A, 1:j-1, j:k) - D = view(A, j:k, j:k) - B = view(A, 1:j-1, k+1:n) - C = view(A, j:k, k+1:n) - else - R = view(A, j:k, 1:j-1) - D = view(A, j:k, j:k) - B = view(A, k+1:n, 1:j-1) - C = view(A, k+1:n, j:k) - end - return R, D, B, C -end - -""" - chol_unblocked_rev!(Ā::AbstractMatrix, L::AbstractMatrix, upper::Bool) - -Compute the reverse-mode sensitivities of the Cholesky factorization in an unblocked manner. -If `upper` is `false`, then the sensitivites are computed from and stored in the lower triangle -of `Ā` and `L` respectively. If `upper` is `true` then they are computed and stored in the -upper triangles. If at input `upper` is `false` and `tril(Ā) = L̄`, at output -`tril(Ā) = tril(Σ̄)`, where `Σ = LLᵀ`. Analogously, if at input `upper` is `true` and -`triu(Ā) = triu(Ū)`, at output `triu(Ā) = triu(Σ̄)` where `Σ = UᵀU`. -""" -function chol_unblocked_rev!(Σ̄::AbstractMatrix{T}, L::AbstractMatrix{T}, upper::Bool) where T<:Real - n = checksquare(Σ̄) - j = n - @inbounds for _ in 1:n - r, d, B, c = level2partition(L, j, upper) - r̄, d̄, B̄, c̄ = level2partition(Σ̄, j, upper) - - # d̄ <- d̄ - c'c̄ / d. - d̄[1] -= dot(c, c̄) / d[1] - - # [d̄ c̄'] <- [d̄ c̄'] / d. - d̄ ./= d - c̄ ./= d - - # r̄ <- r̄ - [d̄ c̄'] [r' B']'. - r̄ = axpy!(-Σ̄[j,j], r, r̄) - r̄ = gemv!(upper ? 'n' : 'T', -one(T), B, c̄, one(T), r̄) - - # B̄ <- B̄ - c̄ r. - B̄ = upper ? ger!(-one(T), r, c̄, B̄) : ger!(-one(T), c̄, r, B̄) - d̄ ./= 2 - j -= 1 - end - return (upper ? triu! : tril!)(Σ̄) -end - -function chol_unblocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, upper::Bool) - return chol_unblocked_rev!(copy(Σ̄), L, upper) -end - -""" - chol_blocked_rev!(Σ̄::StridedMatrix, L::StridedMatrix, nb::Integer, upper::Bool) - -Compute the sensitivities of the Cholesky factorization using a blocked, cache-friendly -procedure. `Σ̄` are the sensitivities of `L`, and will be transformed into the sensitivities -of `Σ`, where `Σ = LLᵀ`. `nb` is the block size to use. If the upper triangle has been used -to represent the factorization, that is `Σ = UᵀU` where `U := Lᵀ`, then this should be -indicated by passing `upper = true`. -""" -function chol_blocked_rev!(Σ̄::StridedMatrix{T}, L::StridedMatrix{T}, nb::Integer, upper::Bool) where T<:Real - n = checksquare(Σ̄) - tmp = Matrix{T}(undef, nb, nb) - k = n - if upper - @inbounds for _ in 1:nb:n - j = max(1, k - nb + 1) - R, D, B, C = level3partition(L, j, k, true) - R̄, D̄, B̄, C̄ = level3partition(Σ̄, j, k, true) - - C̄ = trsm!('L', 'U', 'N', 'N', one(T), D, C̄) - gemm!('N', 'N', -one(T), R, C̄, one(T), B̄) - gemm!('N', 'T', -one(T), C, C̄, one(T), D̄) - chol_unblocked_rev!(D̄, D, true) - gemm!('N', 'T', -one(T), B, C̄, one(T), R̄) - if size(D̄, 1) == nb - tmp = axpy!(one(T), D̄, transpose!(tmp, D̄)) - gemm!('N', 'N', -one(T), R, tmp, one(T), R̄) - else - gemm!('N', 'N', -one(T), R, D̄ + D̄', one(T), R̄) - end - - k -= nb - end - return triu!(Σ̄) - else - @inbounds for _ in 1:nb:n - j = max(1, k - nb + 1) - R, D, B, C = level3partition(L, j, k, false) - R̄, D̄, B̄, C̄ = level3partition(Σ̄, j, k, false) - - C̄ = trsm!('R', 'L', 'N', 'N', one(T), D, C̄) - gemm!('N', 'N', -one(T), C̄, R, one(T), B̄) - gemm!('T', 'N', -one(T), C̄, C, one(T), D̄) - chol_unblocked_rev!(D̄, D, false) - gemm!('T', 'N', -one(T), C̄, B, one(T), R̄) - if size(D̄, 1) == nb - tmp = axpy!(one(T), D̄, transpose!(tmp, D̄)) - gemm!('N', 'N', -one(T), tmp, R, one(T), R̄) - else - gemm!('N', 'N', -one(T), D̄ + D̄', R, one(T), R̄) - end - - k -= nb - end - return tril!(Σ̄) - end -end - -function chol_blocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, nb::Integer, upper::Bool) - # Convert to `Matrix`s because blas functions require StridedMatrix input. - return chol_blocked_rev!(Matrix(Σ̄), Matrix(L), nb, upper) -end diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 60824a4f4..71e2236fe 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -1,4 +1,15 @@ -using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblocked_rev +function FiniteDifferences.to_vec(C::Cholesky) + C_vec, factors_from_vec = to_vec(C.factors) + function cholesky_from_vec(v) + return Cholesky(factors_from_vec(v), C.uplo, C.info) + end + return C_vec, cholesky_from_vec +end + +function FiniteDifferences.to_vec(x::Val) + Val_from_vec(v) = x + return Bool[], Val_from_vec +end @testset "Factorizations" begin @testset "svd" begin @@ -73,69 +84,54 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo @test ChainRules._eyesubx!(copy(X)) ≈ I - X end end + + # These tests are generally a bit tricky to write because FiniteDifferences doesn't + # have fantastic support for this stuff at the minute. @testset "cholesky" begin - @testset "the thing" begin - X = generate_well_conditioned_matrix(10) - V = generate_well_conditioned_matrix(10) - F, dX_pullback = rrule(cholesky, X) - for p in [:U, :L] - Y, dF_pullback = rrule(getproperty, F, p) - Ȳ = (p === :U ? UpperTriangular : LowerTriangular)(randn(size(Y))) - (dself, dF, dp) = dF_pullback(Ȳ) - @test dself === NO_FIELDS - @test dp === DoesNotExist() + @testset "Real" begin + C = cholesky(rand() + 0.1) + ΔC = Composite{typeof(C)}((factors=rand_tangent(C.factors))) + rrule_test(cholesky, ΔC, (rand() + 0.1, randn())) + end + @testset "Diagonal{<:Real}" begin + D = Diagonal(rand(5) .+ 0.1) + C = cholesky(D) + ΔC = Composite{typeof(C)}((factors=Diagonal(randn(5)))) + rrule_test(cholesky, ΔC, (D, Diagonal(randn(5))), (Val(false), nothing)) + end - # NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp` - # machinery from FiniteDifferences because that isn't set up to respect - # necessary special properties of the input. In the case of the Cholesky - # factorization, we need the input to be Hermitian. - ΔF = unthunk(dF) - _, dX = dX_pullback(ΔF) - X̄_ad = dot(unthunk(dX), V) - X̄_fd = _fdm(0.0) do ε - dot(Ȳ, getproperty(cholesky(X .+ ε .* V), p)) - end - @test X̄_ad ≈ X̄_fd rtol=1e-6 atol=1e-6 + X = generate_well_conditioned_matrix(10) + V = generate_well_conditioned_matrix(10) + F, dX_pullback = rrule(cholesky, X, Val(false)) + @testset "uplo=$p" for p in [:U, :L] + Y, dF_pullback = rrule(getproperty, F, p) + Ȳ = (p === :U ? UpperTriangular : LowerTriangular)(randn(size(Y))) + (dself, dF, dp) = dF_pullback(Ȳ) + @test dself === NO_FIELDS + @test dp === DoesNotExist() + + # NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp` + # machinery from FiniteDifferences because that isn't set up to respect + # necessary special properties of the input. In the case of the Cholesky + # factorization, we need the input to be Hermitian. + ΔF = unthunk(dF) + _, dX = dX_pullback(ΔF) + X̄_ad = dot(unthunk(dX), V) + X̄_fd = central_fdm(5, 1)(0.000_001) do ε + dot(Ȳ, getproperty(cholesky(X .+ ε .* V), p)) end + @test X̄_ad ≈ X̄_fd rtol=1e-4 end - @testset "helper functions" begin - A = randn(5, 5) - r, d, B2, c = level2partition(A, 4, false) - R, D, B3, C = level3partition(A, 4, 4, false) - @test all(r .== R') - @test all(d .== D) - @test B2[1] == B3[1] - @test all(c .== C) - - # Check that level 2 partition with `upper == true` is consistent with `false` - rᵀ, dᵀ, B2ᵀ, cᵀ = level2partition(transpose(A), 4, true) - @test r == rᵀ - @test d == dᵀ - @test B2' == B2ᵀ - @test c == cᵀ - - # Check that level 3 partition with `upper == true` is consistent with `false` - R, D, B3, C = level3partition(A, 2, 4, false) - Rᵀ, Dᵀ, B3ᵀ, Cᵀ = level3partition(transpose(A), 2, 4, true) - @test transpose(R) == Rᵀ - @test transpose(D) == Dᵀ - @test transpose(B3) == B3ᵀ - @test transpose(C) == Cᵀ - - A = Matrix(LowerTriangular(randn(10, 10))) - Ā = Matrix(LowerTriangular(randn(10, 10))) - # NOTE: BLAS gets angry if we don't materialize the Transpose objects first - B = Matrix(transpose(A)) - B̄ = Matrix(transpose(Ā)) - @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 1, false) - @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 3, false) - @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 5, false) - @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 10, false) - @test chol_unblocked_rev(Ā, A, false) ≈ transpose(chol_unblocked_rev(B̄, B, true)) - - @test chol_unblocked_rev(B̄, B, true) ≈ chol_blocked_rev(B̄, B, 1, true) - @test chol_unblocked_rev(B̄, B, true) ≈ chol_blocked_rev(B̄, B, 5, true) - @test chol_unblocked_rev(B̄, B, true) ≈ chol_blocked_rev(B̄, B, 10, true) + + # Ensure that cotangents of cholesky(::StridedMatrix) and + # (cholesky ∘ Symmetric)(::StridedMatrix) are equal. + @testset "Symmetric" begin + X_symmetric, sym_back = rrule(Symmetric, X, :U) + C, chol_back_sym = rrule(cholesky, X_symmetric, Val(false)) + + Δ = Composite{typeof(C)}((U=UpperTriangular(randn(size(X))))) + ΔX_symmetric = chol_back_sym(Δ)[2] + @test sym_back(ΔX_symmetric)[2] ≈ dX_pullback(Δ)[2] end end end