Skip to content

Commit

Permalink
Fix bug in pinv (#45009)
Browse files Browse the repository at this point in the history
(cherry picked from commit b4eb88a)
  • Loading branch information
hyrodium authored and KristofferC committed May 16, 2022
1 parent f5e292a commit 6d4b8d0
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 37 deletions.
7 changes: 4 additions & 3 deletions stdlib/LinearAlgebra/src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1449,12 +1449,13 @@ function pinv(A::AbstractMatrix{T}; atol::Real = 0.0, rtol::Real = (eps(real(flo
return similar(A, Tout, (n, m))
end
if isdiag(A)
ind = diagind(A)
dA = view(A, ind)
indA = diagind(A)
dA = view(A, indA)
maxabsA = maximum(abs, dA)
tol = max(rtol * maxabsA, atol)
B = fill!(similar(A, Tout, (n, m)), 0)
B[ind] .= (x -> abs(x) > tol ? pinv(x) : zero(x)).(dA)
indB = diagind(B)
B[indB] .= (x -> abs(x) > tol ? pinv(x) : zero(x)).(dA)
return B
end
SVD = svd(A)
Expand Down
61 changes: 27 additions & 34 deletions stdlib/LinearAlgebra/test/pinv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,39 +63,23 @@ function tridiag(T::Type, m::Integer, n::Integer)
end
tridiag(m::Integer, n::Integer) = tridiag(Float64, m::Integer, n::Integer)

function randn_float64(m::Integer, n::Integer)
a=randn(m,n)
b = Matrix{Float64}(undef, m, n)
for i=1:n
for j=1:m
b[j,i]=convert(Float64,a[j,i])
end
end
return b
end

function randn_float32(m::Integer, n::Integer)
a=randn(m,n)
b = Matrix{Float32}(undef, m, n)
for i=1:n
for j=1:m
b[j,i]=convert(Float32,a[j,i])
end
end
return b
end

function test_pinv(a,tol1,tol2)
m,n = size(a)

function test_pinv(a,m,n,tol1,tol2,tol3)
apinv = @inferred pinv(a)

@test size(apinv) == (n,m)
@test norm(a*apinv*a-a)/norm(a) 0 atol=tol1
x0 = randn(n); b = a*x0; x = apinv*b
@test norm(apinv*a*apinv-apinv)/norm(apinv) 0 atol=tol1
b = a*randn(n)
x = apinv*b
@test norm(a*x-b)/norm(b) 0 atol=tol1
apinv = pinv(a,sqrt(eps(real(one(eltype(a))))))

apinv = @inferred pinv(a,sqrt(eps(real(one(eltype(a))))))
@test size(apinv) == (n,m)
@test norm(a*apinv*a-a)/norm(a) 0 atol=tol2
x0 = randn(n); b = a*x0; x = apinv*b
@test norm(apinv*a*apinv-apinv)/norm(apinv) 0 atol=tol2
b = a*randn(n)
x = apinv*b
@test norm(a*x-b)/norm(b) 0 atol=tol2
end

Expand All @@ -104,28 +88,25 @@ end
default_tol = (real(one(eltya))) * max(m,n) * 10
tol1 = 1e-2
tol2 = 1e-5
tol3 = 1e-5
if real(eltya) == Float32
tol1 = 1e0
tol2 = 1e-2
tol3 = 1e-2
end
@testset "dense/ill-conditioned matrix" begin
### a = randn_float64(m,n) * hilb(eltya,n)
a = hilb(eltya, m, n)
test_pinv(a, m, n, tol1, tol2, tol3)
test_pinv(a, tol1, tol2)
end
@testset "dense/diagonal matrix" begin
a = onediag(eltya, m, n)
test_pinv(a, m, n, default_tol, default_tol, default_tol)
test_pinv(a, default_tol, default_tol)
end
@testset "dense/tri-diagonal matrix" begin
a = tridiag(eltya, m, n)
test_pinv(a, m, n, default_tol, tol2, default_tol)
test_pinv(a, default_tol, tol2)
end
@testset "Diagonal matrix" begin
a = onediag_sparse(eltya, m)
test_pinv(a, m, m, default_tol, default_tol, default_tol)
test_pinv(a, default_tol, default_tol)
end
@testset "Vector" begin
a = rand(eltya, m)
Expand Down Expand Up @@ -164,6 +145,18 @@ end
@test C ones(2,2)
end

@testset "non-square diagonal matrices" begin
A = eltya[1 0 ; 0 1 ; 0 0]
B = pinv(A)
@test A*B*A A
@test B*A*B B

A = eltya[1 0 0 ; 0 1 0]
B = pinv(A)
@test A*B*A A
@test B*A*B B
end

if eltya <: LinearAlgebra.BlasReal
@testset "sub-normal numbers/vectors/matrices" begin
a = pinv(floatmin(eltya)/100)
Expand Down

0 comments on commit 6d4b8d0

Please sign in to comment.