Skip to content

Commit

Permalink
Adapt eigen and eigvals rules from JuliaDiff#321
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen committed Dec 6, 2020
1 parent 76bfac3 commit 539e9d4
Showing 1 changed file with 94 additions and 0 deletions.
94 changes: 94 additions & 0 deletions src/rulesets/LinearAlgebra/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,97 @@ function _hermitian_back(ΔΩ::LinearAlgebra.AbstractTriangular, uplo)
return Matrix(uplo == 'U' ? ∂UL' : ∂UL)
end
end

#####
##### `eigen`
#####

function frule((_, ΔA), ::typeof(eigen), A::LinearAlgebra.RealHermSymComplexHerm; sortby::Union{Function,Nothing}=nothing)
F = eigen(A; sortby=sortby)
ΔA isa AbstractZero && return F, ΔA
λ, U = F.values, F.vectors
tmp = U' * ΔA
∂K = tmp * U
∂Kdiag = @view ∂K[diagind(∂K)]
∂λ = real.(∂Kdiag)
∂K ./= λ' .- λ
fill!(∂Kdiag, 0)
∂U = mul!(tmp, U, ∂K)
_eigen_norm_phase_fwd!(∂U, A, U)
∂F = Composite{typeof(F)}(values = ∂λ, vectors = ∂U)
return F, ∂F
end

function rrule(::typeof(eigen), A::LinearAlgebra.RealHermSymComplexHerm; sortby::Union{Function,Nothing}=nothing)
F = eigen(A; sortby=sortby)
function eigen_pullback(ΔF::Composite{<:Eigen})
λ, U = F.values, F.vectors
Δλ, ΔU = ΔF.values, ΔF.vectors
if ΔU isa AbstractZero
Δλ isa AbstractZero && return (NO_FIELDS, Δλ + ΔU)
∂K = Diagonal(Δλ)
∂A = U * ∂K * U'
else
∂U = copyto!(similar(ΔU), ΔU)
_eigen_norm_phase_rev!(∂U, A, U)
∂K = U' * ∂U
∂K ./= λ' .- λ
∂K[diagind(∂K)] = Δλ
∂A = mul!(∂K, U * ∂K, U')
end
return NO_FIELDS, ∂A
end
eigen_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF)
return F, eigen_pullback
end

_eigen_norm_phase_fwd!(∂V, ::LinearAlgebra.RealHermSym, V) = ∂V
function _eigen_norm_phase_fwd!(∂V, A::Hermitian, V)
k = A.uplo === 'U' ? size(A, 1) : 1
@inbounds for i in axes(V, 2)
vᵢ = @view V[:, i]
vₖᵢ, ∂vₖᵢ = real(vᵢ[k]), ∂V[k, i]
∂vᵢ .-= vᵢ .* (imag(∂vₖᵢ) / ifelse(iszero(vₖᵢ), one(vₖᵢ), vₖᵢ))
end
return ∂V
end

_eigen_norm_phase_rev!(∂V, ::LinearAlgebra.RealHermSym, V) = ∂V
function _eigen_norm_phase_rev!(∂V, A::Hermitian, V)
k = A.uplo === 'U' ? size(A, 1) : 1
@inbounds for i in axes(V, 2)
vᵢ, ∂vᵢ = @views V[:, i], ∂V[:, i]
vₖᵢ = real(vᵢ[k])
∂cᵢ = dot(vᵢ, ∂vᵢ)
∂vᵢ[k] -= im * (imag(∂cᵢ) / ifelse(iszero(vₖᵢ), one(vₖᵢ), vₖᵢ))
end
return ∂V
end

#####
##### `eigvals`
#####

function frule((_, ΔA), ::typeof(eigvals), A::LinearAlgebra.RealHermSymComplexHerm)
ΔA isa AbstractZero && return eigvals(A), ΔA
F = eigen(A)
λ, U = F.values, F.vectors
tmp = ΔA * U
∂λ = similar(λ)
@inbounds for i in eachindex(λ)
∂λ[i] = real(dot(U[:, i], tmp[:, i]))
end
return λ, ∂λ
end

function rrule(::typeof(eigvals), A::LinearAlgebra.RealHermSymComplexHerm)
F = eigen(A)
λ = F.values
function eigvals_pullback(Δλ)
U = F.vectors
∂A = U * Diagonal(Δλ) * U'
return NO_FIELDS, ∂A
end
eigvals_pullback(Δλ::AbstractZero) = (NO_FIELDS, Δλ)
return λ, eigvals_pullback
end

0 comments on commit 539e9d4

Please sign in to comment.