Skip to content

Commit

Permalink
Add rules for Symmetric/Hermitian eigen and svd (#323)
Browse files Browse the repository at this point in the history
* Move symmetric rules to own file

* Move symmetric tests to own file

* Adapt eigen and eigvals rules from #321

* Don't allocate

* Implement mutating forms for frule

* Add sortby keyword

* Use fewer indices

* Correctly reference BlasReal

* Move eigen pullback to external function

* Add hermitian svd rrule

* Hermitrize in pullback

* Separate out eigvals sign code

* Add svdvals rrule

* Rearrange functions

* Realify eigenvalue cotangents

* Restrict to StridedMatrixes

* Reduce allocations and unnecessary ops

* Avoid unnecessary allocation in svd

* Explicitly create eigvals pullback inputs

* Simplify _hermitrize!

* Add svdvals section

* Don't use convenience type not in v1.0

* Fix multiplication order

* Remove ambiguity in signature

* Define missing variable

* Make pure imaginary

* Add tests for eigendecomposition rules

* Test from nonsymmetric matrix

* Add newlines

* Remove unnecessary unthunks

* Test type-stability

* Test mixtures of Zeros

* Use more informative testset names

* Fix svd pullback bugs

* Add svd pullback tests

* Return correct argument

* Remove unused (co)tangents

* Add svdvals tests

* Fix typo

* Restrict SVD test to greater than v1.3.0

* Only check type-stability on 1.6

* Avoid specifying sortby keyword

This is not defined on earlier Julia versions

* Remove obsolete comment

* Abandon ship when derivatives explode

* Handle Hermitian special-case for general eigen

* Fix comment

* Resolve type-instability

* Handle when just ΔV is Zero

* Test eigen for hermitian Matrix

* Make hermitian Matrix

* Call eigen pullback from eigvals

* Only pass sortby to Hermitian eigen

* Support Julia 1.0's return for _symherm_back

* Call Hermitian eigvals! frule

* Call eigen rrule in eigval rrule

* Test eigvals for hermitian Matrix-es

* Do less expensive eltype check first

* Correctly handle sortby default

* Increment version number

* Add references and notes

* More clearly name tests

* Add comment explaining test set
  • Loading branch information
sethaxen authored Jan 7, 2021
1 parent ba7bbd0 commit 053a6c0
Show file tree
Hide file tree
Showing 4 changed files with 610 additions and 24 deletions.
62 changes: 46 additions & 16 deletions src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,16 @@ end
# - support degenerate matrices (see #144)

function frule((_, ΔA), ::typeof(eigen!), A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}
ΔA isa AbstractZero && return (eigen!(A; kwargs...), ΔA)
if ishermitian(A)
sortby = get(kwargs, :sortby, VERSION v"1.2.0" ? LinearAlgebra.eigsortby : nothing)
return if sortby === nothing
frule((Zero(), Hermitian(ΔA)), eigen!, Hermitian(A))
else
frule((Zero(), Hermitian(ΔA)), eigen!, Hermitian(A); sortby=sortby)
end
end
F = eigen!(A; kwargs...)
ΔA isa AbstractZero && return F, ΔA
λ, V = F.values, F.vectors
tmp = V \ ΔA
∂K = tmp * V
Expand All @@ -96,8 +104,14 @@ function rrule(::typeof(eigen), A::StridedMatrix{T}; kwargs...) where {T<:Union{
function eigen_pullback(ΔF::Composite{<:Eigen})
λ, V = F.values, F.vectors
Δλ, ΔV = ΔF.values, ΔF.vectors
if ΔV isa AbstractZero
Δλ isa AbstractZero && return (NO_FIELDS, Δλ + ΔV)
ΔV isa AbstractZero && Δλ isa AbstractZero && return (NO_FIELDS, Δλ + ΔV)
if eltype(λ) <: Real && ishermitian(A)
hermA = Hermitian(A)
∂V = ΔV isa AbstractZero ? ΔV : copyto!(similar(ΔV), ΔV)
∂hermA = eigen_rev!(hermA, λ, V, Δλ, ∂V)
∂Atriu = _symherm_back(typeof(hermA), ∂hermA, hermA.uplo)
∂A = ∂Atriu isa AbstractTriangular ? triu!(∂Atriu.data) : ∂Atriu
elseif ΔV isa AbstractZero
∂K = Diagonal(Δλ)
∂A = V' \ ∂K * V'
else
Expand Down Expand Up @@ -173,31 +187,47 @@ end

function frule((_, ΔA), ::typeof(eigvals!), A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}
ΔA isa AbstractZero && return eigvals!(A; kwargs...), ΔA
F = eigen!(A; kwargs...)
λ, V = F.values, F.vectors
tmp = V \ ΔA
∂λ = similar(λ)
# diag(tmp * V) without computing full matrix product
if eltype(∂λ) <: Real
broadcast!((a, b) -> sum(real prod, zip(a, b)), ∂λ, eachrow(tmp), eachcol(V))
if ishermitian(A)
λ, ∂λ = frule((Zero(), Hermitian(ΔA)), eigvals!, Hermitian(A))
sortby = get(kwargs, :sortby, VERSION v"1.2.0" ? LinearAlgebra.eigsortby : nothing)
_sorteig!_fwd(∂λ, λ, sortby)
else
broadcast!((a, b) -> sum(prod, zip(a, b)), ∂λ, eachrow(tmp), eachcol(V))
F = eigen!(A; kwargs...)
λ, V = F.values, F.vectors
tmp = V \ ΔA
∂λ = similar(λ)
# diag(tmp * V) without computing full matrix product
if eltype(∂λ) <: Real
broadcast!((a, b) -> sum(real prod, zip(a, b)), ∂λ, eachrow(tmp), eachcol(V))
else
broadcast!((a, b) -> sum(prod, zip(a, b)), ∂λ, eachrow(tmp), eachcol(V))
end
end
return λ, ∂λ
end

function rrule(::typeof(eigvals), A::StridedMatrix{T}; kwargs...) where {T<:Union{Real,Complex}}
F = eigen(A; kwargs...)
F, eigen_back = rrule(eigen, A; kwargs...)
λ = F.values
function eigvals_pullback(Δλ)
V = F.vectors
∂A = V' \ Diagonal(Δλ) * V'
return NO_FIELDS, T <: Real ? real(∂A) : ∂A
∂F = Composite{typeof(F)}(values = Δλ)
_, ∂A = eigen_back(∂F)
return NO_FIELDS, ∂A
end
eigvals_pullback(Δλ::AbstractZero) = (NO_FIELDS, Δλ)
return λ, eigvals_pullback
end

# adapted from LinearAlgebra.sorteig!
function _sorteig!_fwd(Δλ, λ, sortby)
Δλ isa AbstractZero && return (sort!(λ; by=sortby), Δλ)
if sortby !== nothing
p = sortperm(λ; alg=QuickSort, by=sortby)
permute!(λ, p)
permute!(Δλ, p)
end
return (λ, Δλ)
end

#####
##### `cholesky`
#####
Expand Down
233 changes: 228 additions & 5 deletions src/rulesets/LinearAlgebra/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,6 @@ function rrule(TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym)
end
rrule(::Type{Array}, A::LinearAlgebra.HermOrSym) = rrule(Matrix, A)

# Get type (Symmetric or Hermitian) from type or matrix
_symhermtype(::Type{<:Symmetric}) = Symmetric
_symhermtype(::Type{<:Hermitian}) = Hermitian
_symhermtype(A) = _symhermtype(typeof(A))

# for Ω = Matrix(A::HermOrSym), push forward ΔA to get ∂Ω
function _symherm_forward(A, ΔA)
TA = _symhermtype(A)
Expand Down Expand Up @@ -77,3 +72,231 @@ function _hermitian_back(ΔΩ::LinearAlgebra.AbstractTriangular, uplo)
return Matrix(uplo == 'U' ? ∂UL' : ∂UL)
end
end

#####
##### `eigen!`/`eigen`
#####

# rule is old but the usual references are
# real rules:
# Giles M. B., An extended collection of matrix derivative results for forward and reverse
# mode algorithmic differentiation.
# https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf.
# complex rules:
# Boeddeker C., Hanebrink P., et al, On the Computation of Complex-valued Gradients with
# Application to Statistically Optimum Beamforming. arXiv:1701.00392v2 [cs.NA]
#
# accounting for normalization convention appears in Boeddeker && Hanebrink.
# account for phase convention is unpublished.
function frule(
(_, ΔA),
::typeof(eigen!),
A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix};
kwargs...,
)
F = eigen!(A; kwargs...)
ΔA isa AbstractZero && return F, ΔA
λ, U = F.values, F.vectors
tmp = U' * ΔA
∂K = mul!(ΔA.data, tmp, U)
∂Kdiag = @view ∂K[diagind(∂K)]
∂λ = real.(∂Kdiag)
∂K ./= λ' .- λ
fill!(∂Kdiag, 0)
∂U = mul!(tmp, U, ∂K)
_eigen_norm_phase_fwd!(∂U, A, U)
∂F = Composite{typeof(F)}(values = ∂λ, vectors = ∂U)
return F, ∂F
end

function rrule(
::typeof(eigen),
A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix};
kwargs...,
)
F = eigen(A; kwargs...)
function eigen_pullback(ΔF::Composite{<:Eigen})
λ, U = F.values, F.vectors
Δλ, ΔU = ΔF.values, ΔF.vectors
ΔU = ΔU isa AbstractZero ? ΔU : copy(ΔU)
∂A = eigen_rev!(A, λ, U, Δλ, ΔU)
return NO_FIELDS, ∂A
end
eigen_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF)
return F, eigen_pullback
end

# ∂U is overwritten if not an `AbstractZero`
function eigen_rev!(A::LinearAlgebra.RealHermSymComplexHerm, λ, U, ∂λ, ∂U)
∂λ isa AbstractZero && ∂U isa AbstractZero && return ∂λ + ∂U
∂A = similar(A, eltype(U))
tmp = ∂U
if ∂U isa AbstractZero
mul!(∂A.data, U, real.(∂λ) .* U')
else
_eigen_norm_phase_rev!(∂U, A, U)
∂K = mul!(∂A.data, U', ∂U)
∂K ./= λ' .- λ
∂K[diagind(∂K)] .= real.(∂λ)
mul!(tmp, ∂K, U')
mul!(∂A.data, U, tmp)
@inbounds _hermitrize!(∂A.data)
end
return ∂A
end

# NOTE: for small vₖ, the derivative of sign(vₖ) explodes, causing the tangents to become
# unstable even for phase-invariant programs. So for small vₖ we don't account for the phase
# in the gradient. Then derivatives are accurate for phase-invariant programs but inaccurate
# for phase-dependent programs that have low vₖ.

_eigen_norm_phase_fwd!(∂V, ::Union{Symmetric{T,S},Hermitian{T,S}}, V) where {T<:Real,S} = ∂V
function _eigen_norm_phase_fwd!(∂V, A::Hermitian{<:Complex}, V)
k = A.uplo === 'U' ? size(A, 1) : 1
ϵ = sqrt(eps(real(eltype(V))))
@inbounds for i in axes(V, 2)
v = @view V[:, i]
vₖ = real(v[k])
if abs(vₖ) > ϵ
∂v = @view ∂V[:, i]
∂v .-= v .* (im * (imag(∂v[k]) / vₖ))
end
end
return ∂V
end

_eigen_norm_phase_rev!(∂V, ::Union{Symmetric{T,S},Hermitian{T,S}}, V) where {T<:Real,S} = ∂V
function _eigen_norm_phase_rev!(∂V, A::Hermitian{<:Complex}, V)
k = A.uplo === 'U' ? size(A, 1) : 1
ϵ = sqrt(eps(real(eltype(V))))
@inbounds for i in axes(V, 2)
v = @view V[:, i]
vₖ = real(v[k])
if abs(vₖ) > ϵ
∂v = @view ∂V[:, i]
∂c = dot(v, ∂v)
∂v[k] -= im * (imag(∂c) / vₖ)
end
end
return ∂V
end

#####
##### `eigvals!`/`eigvals`
#####

function frule(
(_, ΔA),
::typeof(eigvals!),
A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix};
kwargs...,
)
ΔA isa AbstractZero && return eigvals!(A; kwargs...), ΔA
F = eigen!(A; kwargs...)
λ, U = F.values, F.vectors
tmp = ΔA * U
# diag(U' * tmp) without computing matrix product
∂λ = similar(λ)
@inbounds for i in eachindex(λ)
∂λ[i] = @views real(dot(U[:, i], tmp[:, i]))
end
return λ, ∂λ
end

function rrule(
::typeof(eigvals),
A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix};
kwargs...,
)
F, eigen_back = rrule(eigen, A; kwargs...)
λ = F.values
function eigvals_pullback(Δλ)
∂F = Composite{typeof(F)}(values = Δλ)
_, ∂A = eigen_back(∂F)
return NO_FIELDS, ∂A
end
return λ, eigvals_pullback
end

#####
##### `svd`
#####

# NOTE: rrule defined because the `svd` primal mutates after calling `eigen`.
# otherwise, this rule just applies the chain rule and can be removed when mutation
# is supported by reverse-mode AD packages
function rrule(::typeof(svd), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix})
F = svd(A)
function svd_pullback(ΔF::Composite{<:SVD})
U, V = F.U, F.V
c = _svd_eigvals_sign!(similar(F.S), U, V)
λ = F.S .* c
∂λ = ΔF.S isa AbstractZero ? ΔF.S : ΔF.S .* c
if all(x -> x isa AbstractZero, (ΔF.U, ΔF.V, ΔF.Vt))
∂U = ΔF.U + ΔF.V + ΔF.Vt
else
∂U = ΔF.U .+ (ΔF.V .+ ΔF.Vt') .* c'
end
∂A = eigen_rev!(A, λ, U, ∂λ, ∂U)
return NO_FIELDS, ∂A
end
svd_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF)
return F, svd_pullback
end

# given singular vectors, compute sign of eigenvalues corresponding to singular values
function _svd_eigvals_sign!(c, U, V)
n = size(U, 1)
@inbounds broadcast!(c, eachindex(c)) do i
u = @views U[:, i]
# find element not close to zero
# at least one element satisfies abs2(x) ≥ 1/n > 1/(n + 1)
k = findfirst(x -> (n + 1) * abs2(x) 1, u)
return sign(real(u[k]) * real(V[k, i]))
end
return c
end

#####
##### `svdvals`
#####

# NOTE: rrule defined because `svdvals` calls mutating `svdvals!` internally.
# can be removed when mutation is supported by reverse-mode AD packages
function rrule(::typeof(svdvals), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix})
λ, back = rrule(eigvals, A)
S = abs.(λ)
p = sortperm(S; rev=true)
permute!(S, p)
function svdvals_pullback(ΔS)
∂λ = real.(ΔS)
invpermute!(∂λ, p)
∂λ .*= sign.(λ)
_, ∂A = back(∂λ)
return NO_FIELDS, unthunk(∂A)
end
svdvals_pullback(ΔS::AbstractZero) = (NO_FIELDS, ΔS)
return S, svdvals_pullback
end

#####
##### utilities
#####

# Get type (Symmetric or Hermitian) from type or matrix
_symhermtype(::Type{<:Symmetric}) = Symmetric
_symhermtype(::Type{<:Hermitian}) = Hermitian
_symhermtype(A) = _symhermtype(typeof(A))

# in-place hermitrize matrix
function _hermitrize!(A)
n = size(A, 1)
for i in 1:n
for j in (i + 1):n
A[i, j] = (A[i, j] + conj(A[j, i])) / 2
A[j, i] = conj(A[i, j])
end
A[i, i] = real(A[i, i])
end
return A
end
Loading

2 comments on commit 053a6c0

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: "Tag with name v0.7.43 already exists and points to a different commit"

Please sign in to comment.