From 6e5ea12828393c94b518036d377ba23788f03a59 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 23 Dec 2024 11:57:59 +0530 Subject: [PATCH] `sqrt`, `cbrt` and `log` for dense diagonal matrices (#1156) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR improves performance by only applying the functions to the diagonal elements: ```julia julia> A = diagm(0=>ones(100)); julia> @btime log($A); 364.163 μs (22 allocations: 401.62 KiB) # master 13.528 μs (7 allocations: 80.02 KiB) # this PR ``` Similar improvements for `sqrt` and `cbrt` as well. --- src/dense.jl | 12 +++++++++--- test/dense.jl | 16 ++++++++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/dense.jl b/src/dense.jl index 7939e18..e68f7e9 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -890,7 +890,9 @@ julia> log(A) """ function log(A::AbstractMatrix) # If possible, use diagonalization - if ishermitian(A) + if isdiag(A) + return applydiagonal(log, A) + elseif ishermitian(A) logHermA = log(Hermitian(A)) return ishermitian(logHermA) ? copytri!(parent(logHermA), 'U', true) : parent(logHermA) elseif istriu(A) @@ -969,7 +971,9 @@ sqrt(::AbstractMatrix) function sqrt(A::AbstractMatrix{T}) where {T<:Union{Real,Complex}} if checksquare(A) == 0 - return copy(A) + return copy(float(A)) + elseif isdiag(A) + return applydiagonal(sqrt, A) elseif ishermitian(A) sqrtHermA = sqrt(Hermitian(A)) return ishermitian(sqrtHermA) ? copytri!(parent(sqrtHermA), 'U', true) : parent(sqrtHermA) @@ -1035,7 +1039,9 @@ true """ function cbrt(A::AbstractMatrix{<:Real}) if checksquare(A) == 0 - return copy(A) + return copy(float(A)) + elseif isdiag(A) + return applydiagonal(cbrt, A) elseif issymmetric(A) return cbrt(Symmetric(A, :U)) else diff --git a/test/dense.jl b/test/dense.jl index 41f14f0..236587d 100644 --- a/test/dense.jl +++ b/test/dense.jl @@ -817,6 +817,7 @@ end A13 = convert(Matrix{elty}, [2 0; 0 2]) @test typeof(log(A13)) == Array{elty, 2} + @test exp(log(A13)) ≈ log(exp(A13)) ≈ A13 T = elty == Float64 ? Symmetric : Hermitian @test typeof(log(T(A13))) == T{elty, Array{elty, 2}} @@ -968,6 +969,10 @@ end @test typeof(sqrt(A8)) == Matrix{elty} end end +@testset "sqrt for diagonal" begin + A = diagm(0 => [1, 2, 3]) + @test sqrt(A)^2 ≈ A +end @testset "issue #40141" begin x = [-1 -eps() 0 0; eps() -1 0 0; 0 0 -1 -eps(); 0 0 eps() -1] @@ -1280,6 +1285,7 @@ end T = cbrt(Symmetric(S,:U)) @test T*T*T ≈ S @test eltype(S) == eltype(T) + @test cbrt(Array(Symmetric(S,:U))) == T # Real valued symmetric S = (A -> (A+A')/2)(randn(N,N)) T = cbrt(Symmetric(S,:L)) @@ -1300,6 +1306,16 @@ end T = cbrt(A) @test T*T*T ≈ A @test eltype(A) == eltype(T) + @testset "diagonal" begin + A = diagm(0 => [1, 2, 3]) + @test cbrt(A)^3 ≈ A + end + @testset "empty" begin + A = Matrix{Float64}(undef, 0, 0) + @test cbrt(A) == A + A = Matrix{Int}(undef, 0, 0) + @test cbrt(A) isa Matrix{Float64} + end end @testset "tr" begin