Skip to content

Commit

Permalink
Fix lu! requiring real matrices to be square. (#1513)
Browse files Browse the repository at this point in the history
  • Loading branch information
mjrodgers authored Aug 18, 2023
1 parent 407b0e0 commit 1d474ab
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 94 deletions.
25 changes: 1 addition & 24 deletions src/arb/ComplexMat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -568,30 +568,7 @@ function lu!(P::Generic.Perm, x::ComplexMat)
r == 0 && error("Could not find $(nrows(x)) invertible pivot elements")
P.d .+= 1
inv!(P)
return nrows(x)
end

function lu(P::Generic.Perm, x::ComplexMat)
ncols(x) != nrows(x) && error("Matrix must be square")
parent(P).n != nrows(x) && error("Permutation does not match matrix")
R = base_ring(x)
L = similar(x)
U = deepcopy(x)
n = ncols(x)
lu!(P, U)
for i = 1:n
for j = 1:n
if i > j
L[i, j] = U[i, j]
U[i, j] = R()
elseif i == j
L[i, j] = one(R)
else
L[i, j] = R()
end
end
end
return L, U
return min(nrows(x), ncols(x))
end

function solve!(z::ComplexMat, x::ComplexMat, y::ComplexMat)
Expand Down
25 changes: 1 addition & 24 deletions src/arb/RealMat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,6 @@ end
###############################################################################

function lu!(P::Generic.Perm, x::RealMat)
ncols(x) != nrows(x) && error("Matrix must be square")
parent(P).n != nrows(x) && error("Permutation does not match matrix")
P.d .-= 1
r = ccall((:arb_mat_lu, libarb), Cint,
Expand All @@ -511,29 +510,7 @@ function lu!(P::Generic.Perm, x::RealMat)
r == 0 && error("Could not find $(nrows(x)) invertible pivot elements")
P.d .+= 1
inv!(P)
return nrows(x)
end

function lu(x::RealMat, P = SymmetricGroup(nrows(x)))
p = one(P)
R = base_ring(x)
L = similar(x)
U = deepcopy(x)
n = ncols(x)
r = lu!(p, U)
for i = 1:n
for j = 1:n
if i > j
L[i, j] = U[i, j]
U[i, j] = R()
elseif i == j
L[i, j] = one(R)
else
L[i, j] = R()
end
end
end
return r, p, L, U
return min(nrows(x), ncols(x))
end

function solve!(z::RealMat, x::RealMat, y::RealMat)
Expand Down
23 changes: 0 additions & 23 deletions src/arb/acb_mat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -574,29 +574,6 @@ function lu!(P::Generic.Perm, x::acb_mat)
return nrows(x)
end

function lu(P::Generic.Perm, x::acb_mat)
ncols(x) != nrows(x) && error("Matrix must be square")
parent(P).n != nrows(x) && error("Permutation does not match matrix")
R = base_ring(x)
L = similar(x)
U = deepcopy(x)
n = ncols(x)
lu!(P, U)
for i = 1:n
for j = 1:n
if i > j
L[i, j] = U[i, j]
U[i, j] = R()
elseif i == j
L[i, j] = one(R)
else
L[i, j] = R()
end
end
end
return L, U
end

function solve!(z::acb_mat, x::acb_mat, y::acb_mat)
r = ccall((:acb_mat_solve, libarb), Cint,
(Ref{acb_mat}, Ref{acb_mat}, Ref{acb_mat}, Int),
Expand Down
23 changes: 0 additions & 23 deletions src/arb/arb_mat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,6 @@ end
###############################################################################

function lu!(P::Generic.Perm, x::arb_mat)
ncols(x) != nrows(x) && error("Matrix must be square")
parent(P).n != nrows(x) && error("Permutation does not match matrix")
P.d .-= 1
r = ccall((:arb_mat_lu, libarb), Cint,
Expand All @@ -518,28 +517,6 @@ function lu!(P::Generic.Perm, x::arb_mat)
return nrows(x)
end

function lu(x::arb_mat, P = SymmetricGroup(nrows(x)))
p = one(P)
R = base_ring(x)
L = similar(x)
U = deepcopy(x)
n = ncols(x)
r = lu!(p, U)
for i = 1:n
for j = 1:n
if i > j
L[i, j] = U[i, j]
U[i, j] = R()
elseif i == j
L[i, j] = one(R)
else
L[i, j] = R()
end
end
end
return r, p, L, U
end

function solve!(z::arb_mat, x::arb_mat, y::arb_mat)
r = ccall((:arb_mat_solve, libarb), Cint,
(Ref{arb_mat}, Ref{arb_mat}, Ref{arb_mat}, Int),
Expand Down
12 changes: 12 additions & 0 deletions test/arb/ComplexMat-test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,18 @@ end
@test overlaps(B, C)
end

@testset "ComplexMat.lu_nonsquare" begin
S = matrix_space(CC, 2, 3)

A = S(["1.0 +/- 0.01" "1.0 +/- 0.01" "1.0 +/- 0.01";
"1.0 +/- 0.01" "0.0 +/- 0.01" "-1.0 +/- 0.01"])

r, p, L, U = lu(A)

@test overlaps(L*U, p*A)
@test r == 2
end

@testset "ComplexMat.linear_solving" begin
S = matrix_space(CC, 3, 3)
T = matrix_space(ZZ, 3, 3)
Expand Down
12 changes: 12 additions & 0 deletions test/arb/RealMat-test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,18 @@ end
@test overlaps(B, C)
end

@testset "RealMat.lu_nonsquare" begin
S = matrix_space(RR, 2, 3)

A = S(["1.0 +/- 0.01" "1.0 +/- 0.01" "1.0 +/- 0.01";
"1.0 +/- 0.01" "0.0 +/- 0.01" "-1.0 +/- 0.01"])

r, p, L, U = lu(A)

@test overlaps(L*U, p*A)
@test r == 2
end

@testset "RealMat.linear_solving" begin
S = matrix_space(RR, 3, 3)
T = matrix_space(ZZ, 3, 3)
Expand Down
12 changes: 12 additions & 0 deletions test/arb/acb_mat-test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,18 @@ end
@test overlaps(B, C)
end

@testset "acb_mat.lu_nonsquare" begin
S = matrix_space(CC, 2, 3)

A = S(["1.0 +/- 0.01" "1.0 +/- 0.01" "1.0 +/- 0.01";
"1.0 +/- 0.01" "0.0 +/- 0.01" "-1.0 +/- 0.01"])

r, p, L, U = lu(A)

@test overlaps(L*U, p*A)
@test r == 2
end

@testset "acb_mat.linear_solving" begin
S = matrix_space(CC, 3, 3)
T = matrix_space(ZZ, 3, 3)
Expand Down
12 changes: 12 additions & 0 deletions test/arb/arb_mat-test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,18 @@ end
@test overlaps(B, C)
end

@testset "arb_mat.lu_nonsquare" begin
S = matrix_space(RR, 2, 3)

A = S(["1.0 +/- 0.01" "1.0 +/- 0.01" "1.0 +/- 0.01";
"1.0 +/- 0.01" "0.0 +/- 0.01" "-1.0 +/- 0.01"])

r, p, L, U = lu(A)

@test overlaps(L*U, p*A)
@test r == 2
end

@testset "arb_mat.linear_solving" begin
S = matrix_space(RR, 3, 3)
T = matrix_space(ZZ, 3, 3)
Expand Down

0 comments on commit 1d474ab

Please sign in to comment.