Skip to content

Commit

Permalink
fix #29392: HermOrSym should preserve structure when scaled with Numb…
Browse files Browse the repository at this point in the history
…ers.
  • Loading branch information
fredrikekre committed Oct 18, 2018
1 parent 0bab957 commit 4837cf9
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 6 deletions.
13 changes: 7 additions & 6 deletions stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -440,12 +440,13 @@ mul!(C::StridedMatrix{T}, A::StridedMatrix{T}, B::Hermitian{T,<:StridedMatrix})
*(adjA::Adjoint{<:Any,<:RealHermSymComplexHerm}, B::AbstractTriangular) = adjA.parent * B
*(A::AbstractTriangular, adjB::Adjoint{<:Any,<:RealHermSymComplexHerm}) = A * adjB.parent

for T in (:Symmetric, :Hermitian), op in (:*, :/)
# Deal with an ambiguous case
@eval ($op)(A::$T, x::Bool) = ($T)(($op)(A.data, x), sym_uplo(A.uplo))
S = T == :Hermitian ? :Real : :Number
@eval ($op)(A::$T, x::$S) = ($T)(($op)(A.data, x), sym_uplo(A.uplo))
end
# Scaling with Number
*(A::Symmetric, x::Number) = Symmetric(A.data*x, sym_uplo(A.uplo))
*(x::Number, A::Symmetric) = Symmetric(x*A.data, sym_uplo(A.uplo))
*(A::Hermitian, x::Real) = Hermitian(A.data*x, sym_uplo(A.uplo))
*(x::Real, A::Hermitian) = Hermitian(x*A.data, sym_uplo(A.uplo))
/(A::Symmetric, x::Number) = Symmetric(A.data/x, sym_uplo(A.uplo))
/(A::Hermitian, x::Real) = Hermitian(A.data/x, sym_uplo(A.uplo))

function factorize(A::HermOrSym{T}) where T
TT = typeof(sqrt(oneunit(T)))
Expand Down
44 changes: 44 additions & 0 deletions stdlib/LinearAlgebra/test/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -517,4 +517,48 @@ end
@test Hermitian(A, :U)[1,1] == Hermitian(A, :L)[1,1] == real(A[1,1])
end

@testset "issue #29392: SymOrHerm scaled with Number" begin
R = rand(Float64, 2, 2); C = rand(ComplexF64, 2, 2)
# Symmetric * Real, Real * Symmetric
A = Symmetric(R); x = 2.0
@test (A * x)::Symmetric == (x * A)::Symmetric
A = Symmetric(C); x = 2.0
@test (A * x)::Symmetric == (x * A)::Symmetric
# Symmetric * Complex, Complex * Symmetrics
A = Symmetric(R); x = 2.0im
@test (A * x)::Symmetric == (x * A)::Symmetric
A = Symmetric(C); x = 2.0im
@test (A * x)::Symmetric == (x * A)::Symmetric
# Hermitian * Real, Real * Hermitian
A = Hermitian(R); x = 2.0
@test (A * x)::Hermitian == (x * A)::Hermitian
A = Hermitian(C); x = 2.0
@test (A * x)::Hermitian == (x * A)::Hermitian
# Hermitian * Complex, Complex * Hermitian
A = Hermitian(R); x = 2.0im
@test (A * x)::Matrix == (x * A)::Matrix
A = Hermitian(C); x = 2.0im
@test (A * x)::Matrix == (x * A)::Matrix
# Symmetric / Real
A = Symmetric(R); x = 2.0
@test (A / x)::Symmetric == Matrix(A) / x
A = Symmetric(C); x = 2.0
@test (A / x)::Symmetric == Matrix(A) / x
# Symmetric / Complex
A = Symmetric(R); x = 2.0im
@test (A / x)::Symmetric == Matrix(A) / x
A = Symmetric(C); x = 2.0im
@test (A / x)::Symmetric == Matrix(A) / x
# Hermitian / Real
A = Hermitian(R); x = 2.0
@test (A / x)::Hermitian == Matrix(A) / x
A = Hermitian(C); x = 2.0
@test (A / x)::Hermitian == Matrix(A) / x
# Hermitian / Complex
A = Hermitian(R); x = 2.0im
@test (A / x)::Matrix == Matrix(A) / x
A = Hermitian(C); x = 2.0im
@test (A / x)::Matrix == Matrix(A) / x
end

end # module TestSymmetric

0 comments on commit 4837cf9

Please sign in to comment.