diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl index 0c19e26e3dd4a..789ad40a9de94 100644 --- a/stdlib/LinearAlgebra/src/symmetric.jl +++ b/stdlib/LinearAlgebra/src/symmetric.jl @@ -453,7 +453,7 @@ function triu(A::Symmetric, k::Integer=0) end end -for (T, trans, real) in [(:Symmetric, :transpose, :identity), (:Hermitian, :adjoint, :real)] +for (T, trans, real) in [(:Symmetric, :transpose, :identity), (:(Hermitian{<:Union{Real,Complex}}), :adjoint, :real)] @eval begin function dot(A::$T, B::$T) n = size(A, 2) diff --git a/stdlib/LinearAlgebra/test/symmetric.jl b/stdlib/LinearAlgebra/test/symmetric.jl index 224b7b31a50df..82236c2a677eb 100644 --- a/stdlib/LinearAlgebra/test/symmetric.jl +++ b/stdlib/LinearAlgebra/test/symmetric.jl @@ -4,6 +4,11 @@ module TestSymmetric using Test, LinearAlgebra, Random +const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test") + +isdefined(Main, :Quaternions) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "Quaternions.jl")) +using .Main.Quaternions + Random.seed!(1010) @testset "Pauli σ-matrices: $σ" for σ in map(Hermitian, @@ -462,6 +467,17 @@ end end end +# bug identified in PR #52318: dot products of quaternionic Hermitian matrices, +# or any number type where conj(a)*conj(b) ≠ conj(a*b): +@testset "dot Hermitian quaternion #52318" begin + A, B = [Quaternion.(randn(3,3), randn(3, 3), randn(3, 3), randn(3,3)) |> t -> t + t' for i in 1:2] + @test A == Hermitian(A) && B == Hermitian(B) + @test dot(A, B) ≈ dot(Hermitian(A), Hermitian(B)) + A, B = [Quaternion.(randn(3,3), randn(3, 3), randn(3, 3), randn(3,3)) |> t -> t + transpose(t) for i in 1:2] + @test A == Symmetric(A) && B == Symmetric(B) + @test dot(A, B) ≈ dot(Symmetric(A), Symmetric(B)) +end + #Issue #7647: test xsyevr, xheevr, xstevr drivers. @testset "Eigenvalues in interval for $(typeof(Mi7647))" for Mi7647 in (Symmetric(diagm(0 => 1.0:3.0)), diff --git a/test/testhelpers/Quaternions.jl b/test/testhelpers/Quaternions.jl index 1eddad322ec40..81b7a0c2d0121 100644 --- a/test/testhelpers/Quaternions.jl +++ b/test/testhelpers/Quaternions.jl @@ -20,6 +20,7 @@ Base.abs2(q::Quaternion) = q.s*q.s + q.v1*q.v1 + q.v2*q.v2 + q.v3*q.v3 Base.float(z::Quaternion{T}) where T = Quaternion(float(z.s), float(z.v1), float(z.v2), float(z.v3)) Base.abs(q::Quaternion) = sqrt(abs2(q)) Base.real(::Type{Quaternion{T}}) where {T} = T +Base.real(q::Quaternion) = q.s Base.conj(q::Quaternion) = Quaternion(q.s, -q.v1, -q.v2, -q.v3) Base.isfinite(q::Quaternion) = isfinite(q.s) & isfinite(q.v1) & isfinite(q.v2) & isfinite(q.v3) Base.zero(::Type{Quaternion{T}}) where T = Quaternion{T}(zero(T), zero(T), zero(T), zero(T)) @@ -33,7 +34,9 @@ Base.:(*)(q::Quaternion, w::Quaternion) = Quaternion(q.s*w.s - q.v1*w.v1 - q.v2* q.s*w.v2 - q.v1*w.v3 + q.v2*w.s + q.v3*w.v1, q.s*w.v3 + q.v1*w.v2 - q.v2*w.v1 + q.v3*w.s) Base.:(*)(q::Quaternion, r::Real) = Quaternion(q.s*r, q.v1*r, q.v2*r, q.v3*r) -Base.:(*)(q::Quaternion, b::Bool) = b * q # remove method ambiguity +Base.:(*)(q::Quaternion, r::Bool) = Quaternion(q.s*r, q.v1*r, q.v2*r, q.v3*r) # remove method ambiguity +Base.:(*)(r::Real, q::Quaternion) = q * r +Base.:(*)(r::Bool, q::Quaternion) = q * r # remove method ambiguity Base.:(/)(q::Quaternion, w::Quaternion) = q * conj(w) * (1.0 / abs2(w)) Base.:(\)(q::Quaternion, w::Quaternion) = conj(q) * w * (1.0 / abs2(q))