Skip to content

Commit

Permalink
Add unwrapping mechanism for triangular matrices
Browse files Browse the repository at this point in the history
(cherry picked from commit e67ddaa)
  • Loading branch information
dkarrasch authored and KristofferC committed Jul 17, 2023
1 parent d3276e1 commit 6aa1ab3
Show file tree
Hide file tree
Showing 6 changed files with 550 additions and 504 deletions.
8 changes: 7 additions & 1 deletion stdlib/LinearAlgebra/src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ end
Adjoint(A) = Adjoint{Base.promote_op(adjoint,eltype(A)),typeof(A)}(A)
Transpose(A) = Transpose{Base.promote_op(transpose,eltype(A)),typeof(A)}(A)

# TODO: remove, is already replaced by wrapperop
"""
adj_or_trans(::AbstractArray) -> adjoint|transpose|identity
adj_or_trans(::Type{<:AbstractArray}) -> adjoint|transpose|identity
Return [`adjoint`](@ref) from an `Adjoint` type or object and
[`transpose`](@ref) from a `Transpose` type or object. Otherwise,
return [`identity`](@ref). Note that `Adjoint` and `Transpose` have
Expand All @@ -94,9 +94,15 @@ inplace_adj_or_trans(::Type{<:AbstractArray}) = copyto!
inplace_adj_or_trans(::Type{<:Adjoint}) = adjoint!
inplace_adj_or_trans(::Type{<:Transpose}) = transpose!

# unwraps Adjoint, Transpose, Symmetric, Hermitian
_unwrap(A::Adjoint) = parent(A)
_unwrap(A::Transpose) = parent(A)

# unwraps Adjoint and Transpose only
_unwrap_at(A) = A
_unwrap_at(A::Adjoint) = parent(A)
_unwrap_at(A::Transpose) = parent(A)

Base.dataids(A::Union{Adjoint, Transpose}) = Base.dataids(A.parent)
Base.unaliascopy(A::Union{Adjoint,Transpose}) = typeof(A)(Base.unaliascopy(A.parent))

Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,7 @@ function ldiv!(c::AbstractVecOrMat, A::Bidiagonal, b::AbstractVecOrMat)
end
ldiv!(A::AdjOrTrans{<:Any,<:Bidiagonal}, b::AbstractVecOrMat) = @inline ldiv!(b, A, b)
ldiv!(c::AbstractVecOrMat, A::AdjOrTrans{<:Any,<:Bidiagonal}, b::AbstractVecOrMat) =
(t = adj_or_trans(A); _rdiv!(t(c), t(b), t(A)); return c)
(t = wrapperop(A); _rdiv!(t(c), t(b), t(A)); return c)

### Generic promotion methods and fallbacks
\(A::Bidiagonal, B::AbstractVecOrMat) = ldiv!(_initarray(\, eltype(A), eltype(B), B), A, B)
Expand Down Expand Up @@ -833,7 +833,7 @@ end
rdiv!(A::AbstractMatrix, B::Bidiagonal) = @inline _rdiv!(A, A, B)
rdiv!(A::AbstractMatrix, B::AdjOrTrans{<:Any,<:Bidiagonal}) = @inline _rdiv!(A, A, B)
_rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::AdjOrTrans{<:Any,<:Bidiagonal}) =
(t = adj_or_trans(B); ldiv!(t(C), t(B), t(A)); return C)
(t = wrapperop(B); ldiv!(t(C), t(B), t(A)); return C)

/(A::AbstractMatrix, B::Bidiagonal) = _rdiv!(_initarray(/, eltype(A), eltype(B), A), A, B)

Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/hessenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,11 @@ for T = (:Number, :UniformScaling, :Diagonal)
end

function *(H::UpperHessenberg, U::UpperOrUnitUpperTriangular)
HH = _mulmattri!(_initarray(*, eltype(H), eltype(U), H), H, U)
HH = mul!(_initarray(*, eltype(H), eltype(U), H), H, U)
UpperHessenberg(HH)
end
function *(U::UpperOrUnitUpperTriangular, H::UpperHessenberg)
HH = _multrimat!(_initarray(*, eltype(U), eltype(H), H), U, H)
HH = mul!(_initarray(*, eltype(U), eltype(H), H), U, H)
UpperHessenberg(HH)
end

Expand Down
16 changes: 6 additions & 10 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@ AdjOrTransStridedMat{T} = Union{Adjoint{<:Any, <:StridedMatrix{T}}, Transpose{<:
StridedMaybeAdjOrTransMat{T} = Union{StridedMatrix{T}, Adjoint{<:Any, <:StridedMatrix{T}}, Transpose{<:Any, <:StridedMatrix{T}}}
StridedMaybeAdjOrTransVecOrMat{T} = Union{StridedVecOrMat{T}, AdjOrTrans{<:Any, <:StridedVecOrMat{T}}}

_parent(A) = A
_parent(A::Adjoint) = parent(A)
_parent(A::Transpose) = parent(A)

matprod(x, y) = x*y + x*y

# dot products
Expand Down Expand Up @@ -115,14 +111,14 @@ end
function (*)(A::StridedMaybeAdjOrTransMat{<:BlasReal}, B::StridedMaybeAdjOrTransMat{<:BlasReal})
TS = promote_type(eltype(A), eltype(B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
wrapperop(A)(convert(AbstractArray{TS}, _parent(A))),
wrapperop(B)(convert(AbstractArray{TS}, _parent(B))))
wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))),
wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B))))
end
function (*)(A::StridedMaybeAdjOrTransMat{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasComplex})
TS = promote_type(eltype(A), eltype(B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
wrapperop(A)(convert(AbstractArray{TS}, _parent(A))),
wrapperop(B)(convert(AbstractArray{TS}, _parent(B))))
wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))),
wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B))))
end

# Complex Matrix times real matrix: We use that it is generally faster to reinterpret the
Expand All @@ -131,13 +127,13 @@ function (*)(A::StridedMatrix{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:Bla
TS = promote_type(eltype(A), eltype(B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
convert(AbstractArray{TS}, A),
wrapperop(B)(convert(AbstractArray{real(TS)}, _parent(B))))
wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B))))
end
function (*)(A::AdjOrTransStridedMat{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasReal})
TS = promote_type(eltype(A), eltype(B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
copymutable_oftype(A, TS), # remove AdjOrTrans to use reinterpret trick below
wrapperop(B)(convert(AbstractArray{real(TS)}, _parent(B))))
wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B))))
end
# the following case doesn't seem to benefit from the translation A*B = (B' * A')'
function (*)(A::StridedMatrix{<:BlasReal}, B::StridedMatrix{<:BlasComplex})
Expand Down
Loading

0 comments on commit 6aa1ab3

Please sign in to comment.