Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LinearAlgebra: copyto! between banded matrix types #54041

Merged
merged 2 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,22 @@ function Base.copy(tB::Transpose{<:Any,<:Bidiagonal})
return Bidiagonal(map(x -> copy.(transpose.(x)), (B.dv, B.ev))..., B.uplo == 'U' ? :L : :U)
end

# copyto! for matching axes
function _copyto_banded!(A::Bidiagonal, B::Bidiagonal)
A.dv .= B.dv
if A.uplo == B.uplo
A.ev .= B.ev
elseif iszero(B.ev) # diagonal source
A.ev .= zero.(A.ev)
else
zeroband = istriu(A) ? "lower" : "upper"
uplo = A.uplo
throw(ArgumentError(string("cannot set the ",
zeroband, " bidiagonal band to a nonzero value for uplo=:", uplo)))
end
return A
end

iszero(M::Bidiagonal) = iszero(M.dv) && iszero(M.ev)
isone(M::Bidiagonal) = all(isone, M.dv) && iszero(M.ev)
function istriu(M::Bidiagonal, k::Integer=0)
Expand Down Expand Up @@ -334,6 +350,8 @@ function istril(M::Bidiagonal, k::Integer=0)
end
end
isdiag(M::Bidiagonal) = iszero(M.ev)
issymmetric(M::Bidiagonal) = isdiag(M) && all(issymmetric, M.dv)
ishermitian(M::Bidiagonal) = isdiag(M) && all(ishermitian, M.dv)

function tril!(M::Bidiagonal{T}, k::Integer=0) where T
n = length(M.dv)
Expand Down
3 changes: 2 additions & 1 deletion stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ Diagonal{T}(::UndefInitializer, n::Integer) where T = Diagonal(Vector{T}(undef,
similar(D::Diagonal, ::Type{T}) where {T} = Diagonal(similar(D.diag, T))
similar(D::Diagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = similar(D.diag, T, dims)

copyto!(D1::Diagonal, D2::Diagonal) = (copyto!(D1.diag, D2.diag); D1)
# copyto! for matching axes
_copyto_banded!(D1::Diagonal, D2::Diagonal) = (copyto!(D1.diag, D2.diag); D1)

size(D::Diagonal) = (n = length(D.diag); (n,n))

Expand Down
86 changes: 86 additions & 0 deletions stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,92 @@ isdiag(A::HermOrSym{<:Any,<:Diagonal}) = isdiag(parent(A))
dot(x::AbstractVector, A::RealHermSymComplexSym{<:Real,<:Diagonal}, y::AbstractVector) =
dot(x, A.data, y)

# O(N) implementations using the banded structure
function copyto!(dest::BandedMatrix, src::BandedMatrix)
if axes(dest) == axes(src)
_copyto_banded!(dest, src)
else
@invoke copyto!(dest::AbstractMatrix, src::AbstractMatrix)
end
return dest
end
function _copyto_banded!(T::Tridiagonal, D::Diagonal)
T.d .= D.diag
T.dl .= zero.(T.dl)
T.du .= zero.(T.du)
return T
end
function _copyto_banded!(SymT::SymTridiagonal, D::Diagonal)
issymmetric(D) || throw(ArgumentError("cannot copy a non-symmetric Diagonal matrix to a SymTridiagonal"))
SymT.dv .= D.diag
_ev = _evview(SymT)
_ev .= zero.(_ev)
return SymT
end
function _copyto_banded!(B::Bidiagonal, D::Diagonal)
B.dv .= D.diag
B.ev .= zero.(B.ev)
return B
end
function _copyto_banded!(D::Diagonal, B::Bidiagonal)
isdiag(B) ||
throw(ArgumentError("cannot copy a Bidiagonal with a non-zero off-diagonal band to a Diagonal"))
D.diag .= B.dv
return D
end
function _copyto_banded!(D::Diagonal, T::Tridiagonal)
isdiag(T) ||
throw(ArgumentError("cannot copy a Tridiagonal with a non-zero off-diagonal band to a Diagonal"))
D.diag .= T.d
return D
end
function _copyto_banded!(D::Diagonal, SymT::SymTridiagonal)
isdiag(SymT) ||
throw(ArgumentError("cannot copy a SymTridiagonal with a non-zero off-diagonal band to a Diagonal"))
# we broadcast identity for numbers using the fact that symmetric(x::Number) = x
# this potentially allows us to access faster copyto! paths
_symmetric = eltype(SymT) <: Number ? identity : symmetric
D.diag .= _symmetric.(SymT.dv)
return D
end
function _copyto_banded!(T::Tridiagonal, B::Bidiagonal)
T.d .= B.dv
if B.uplo == 'U'
T.du .= B.ev
T.dl .= zero.(T.dl)
else
T.dl .= B.ev
T.du .= zero.(T.du)
end
return T
end
function _copyto_banded!(SymT::SymTridiagonal, B::Bidiagonal)
issymmetric(B) || throw(ArgumentError("cannot copy a non-symmetric Bidiagonal matrix to a SymTridiagonal"))
SymT.dv .= B.dv
_ev = _evview(SymT)
_ev .= zero.(_ev)
return SymT
end
function _copyto_banded!(B::Bidiagonal, T::Tridiagonal)
if B.uplo == 'U' && !iszero(T.dl)
throw(ArgumentError("cannot copy a Tridiagonal with a non-zero subdiagonal to a Bidiagonal with uplo=:U"))
elseif B.uplo == 'L' && !iszero(T.du)
throw(ArgumentError("cannot copy a Tridiagonal with a non-zero superdiagonal to a Bidiagonal with uplo=:L"))
end
B.dv .= T.d
B.ev .= B.uplo == 'U' ? T.du : T.dl
return B
end
function _copyto_banded!(B::Bidiagonal, SymT::SymTridiagonal)
isdiag(SymT) ||
throw(ArgumentError("cannot copy a SymTridiagonal with a non-zero off-diagonal band to a Bidiagonal"))
# we broadcast identity for numbers using the fact that symmetric(x::Number) = x
# this potentially allows us to access faster copyto! paths
_symmetric = eltype(SymT) <: Number ? identity : symmetric
B.dv .= _symmetric.(SymT.dv)
return B
end

# equals and approx equals methods for structured matrices
# SymTridiagonal == Tridiagonal is already defined in tridiag.jl

Expand Down
29 changes: 27 additions & 2 deletions stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ axes(M::SymTridiagonal) = (ax = axes(M.dv, 1); (ax, ax))
similar(S::SymTridiagonal, ::Type{T}) where {T} = SymTridiagonal(similar(S.dv, T), similar(S.ev, T))
similar(S::SymTridiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = similar(S.dv, T, dims)

copyto!(dest::SymTridiagonal, src::SymTridiagonal) =
# copyto! for matching axes
_copyto_banded!(dest::SymTridiagonal, src::SymTridiagonal) =
(copyto!(dest.dv, src.dv); copyto!(dest.ev, _evview(src)); dest)

#Elementary operations
Expand Down Expand Up @@ -607,7 +608,13 @@ similar(M::Tridiagonal, ::Type{T}) where {T} = Tridiagonal(similar(M.dl, T), sim
similar(M::Tridiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = similar(M.d, T, dims)

# Operations on Tridiagonal matrices
copyto!(dest::Tridiagonal, src::Tridiagonal) = (copyto!(dest.dl, src.dl); copyto!(dest.d, src.d); copyto!(dest.du, src.du); dest)
# copyto! for matching axes
function _copyto_banded!(dest::Tridiagonal, src::Tridiagonal)
copyto!(dest.dl, src.dl)
copyto!(dest.d, src.d)
copyto!(dest.du, src.du)
dest
end

#Elementary operations
for func in (:conj, :copy, :real, :imag)
Expand Down Expand Up @@ -984,3 +991,21 @@ function ldiv!(A::Tridiagonal, B::AbstractVecOrMat)
end
return B
end

# combinations of Tridiagonal and Symtridiagonal
# copyto! for matching axes
function _copyto_banded!(A::Tridiagonal, B::SymTridiagonal)
Bev = _evview(B)
A.du .= Bev
# Broadcast identity for numbers to access the faster copyto! path
# This uses the fact that transpose(x::Number) = x and symmetric(x::Number) = x
A.dl .= (eltype(B) <: Number ? identity : transpose).(Bev)
A.d .= (eltype(B) <: Number ? identity : symmetric).(B.dv)
return A
end
function _copyto_banded!(A::SymTridiagonal, B::Tridiagonal)
issymmetric(B) || throw(ArgumentError("cannot copy a non-symmetric Tridiagonal matrix to a SymTridiagonal"))
A.dv .= B.d
_evview(A) .= B.du
return A
end
20 changes: 20 additions & 0 deletions stdlib/LinearAlgebra/test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,26 @@ end
end
end

@testset "copyto!" begin
ev, dv = [1:4;], [1:5;]
B = Bidiagonal(dv, ev, :U)
B2 = copyto!(zero(B), B)
@test B2 == B
for (ul1, ul2) in ((:U, :L), (:L, :U))
B3 = Bidiagonal(dv, zero(ev), ul1)
B2 = Bidiagonal(zero(dv), zero(ev), ul2)
@test copyto!(B2, B3) == B3
end

@testset "mismatched sizes" begin
dv2 = [4; @view dv[2:end]]
@test copyto!(B, Bidiagonal([4], Int[], :U)) == Bidiagonal(dv2, ev, :U)
@test copyto!(B, Bidiagonal([4], Int[], :L)) == Bidiagonal(dv2, ev, :U)
@test copyto!(B, Bidiagonal(Int[], Int[], :U)) == Bidiagonal(dv, ev, :U)
@test copyto!(B, Bidiagonal(Int[], Int[], :L)) == Bidiagonal(dv, ev, :U)
end
end

@testset "copyto! with UniformScaling" begin
@testset "Fill" begin
for len in (4, InfiniteArrays.Infinity())
Expand Down
Loading