diff --git a/src/kronecker.jl b/src/kronecker.jl index f6beefe5..57e1079b 100644 --- a/src/kronecker.jl +++ b/src/kronecker.jl @@ -52,6 +52,9 @@ Base.kron(A::KroneckerMap, B::LinearMap) = KroneckerMap{promote_type(eltype(A), eltype(B))}(tuple(A.maps..., B)) Base.kron(A::KroneckerMap, B::KroneckerMap) = KroneckerMap{promote_type(eltype(A), eltype(B))}(tuple(A.maps..., B.maps...)) +Base.kron(A::ScaledMap, B::LinearMap) = A.λ * kron(A.lmap, B) +Base.kron(A::LinearMap{<:RealOrComplex}, B::ScaledMap) = B.λ * kron(A, B.lmap) +Base.kron(A::ScaledMap, B::ScaledMap) = (A.λ * B.λ) * kron(A.lmap, B.lmap) Base.kron(A::LinearMap, B::LinearMap, C::LinearMap, Ds::LinearMap...) = kron(kron(A, B), C, Ds...) Base.kron(A::AbstractMatrix, B::LinearMap) = kron(LinearMap(A), B) @@ -104,9 +107,10 @@ Base.:(==)(A::KroneckerMap, B::KroneckerMap) = (eltype(A) == eltype(B) && A.maps # multiplication helper functions ################# -@inline function _kronmul!(y, B, X, At, T) +@inline function _kronmul!(y, B, x, At, T) na, ma = size(At) mb, nb = size(B) + X = reshape(x, (nb, na)) v = zeros(T, ma) temp1 = similar(y, na) temp2 = similar(y, nb) @@ -119,14 +123,23 @@ Base.:(==)(A::KroneckerMap, B::KroneckerMap) = (eltype(A) == eltype(B) && A.maps end return y end -@inline function _kronmul!(y, B, X, At::Union{MatrixMap, UniformScalingMap}, T) +@inline function _kronmul!(y, B, x, At::UniformScalingMap, _) + na, ma = size(At) + mb, nb = size(B) + X = reshape(x, (nb, na)) + Y = reshape(y, (mb, ma)) + _unsafe_mul!(Y, B, X, At.λ, false) + return y +end +@inline function _kronmul!(y, B, x, At::MatrixMap, _) na, ma = size(At) mb, nb = size(B) + X = reshape(x, (nb, na)) Y = reshape(y, (mb, ma)) if nb*ma < mb*na - _unsafe_mul!(Y, B, Matrix(X*At)) + _unsafe_mul!(Y, B, X * At.lmap) else - _unsafe_mul!(Y, Matrix(B*X), _parent(At)) + _unsafe_mul!(Y, Matrix(B*X), At.lmap) end return y end @@ -140,18 +153,14 @@ const KroneckerMap2{T} = KroneckerMap{T, <:Tuple{LinearMap, LinearMap}} function _unsafe_mul!(y::AbstractVecOrMat, L::KroneckerMap2, x::AbstractVector) require_one_based_indexing(y) A, B = L.maps - X = LinearMap(reshape(x, (size(B, 2), size(A, 2))); - issymmetric = false, ishermitian = false, isposdef = false) - _kronmul!(y, B, X, transpose(A), eltype(L)) + _kronmul!(y, B, x, transpose(A), eltype(L)) return y end function _unsafe_mul!(y::AbstractVecOrMat, L::KroneckerMap, x::AbstractVector) require_one_based_indexing(y) A = first(L.maps) B = kron(Base.tail(L.maps)...) - X = LinearMap(reshape(x, (size(B, 2), size(A, 2))); - issymmetric = false, ishermitian = false, isposdef = false) - _kronmul!(y, B, X, transpose(A), eltype(L)) + _kronmul!(y, B, x, transpose(A), eltype(L)) return y end # mixed-product rule, prefer the right if possible diff --git a/src/uniformscalingmap.jl b/src/uniformscalingmap.jl index 350812bf..175c1be5 100644 --- a/src/uniformscalingmap.jl +++ b/src/uniformscalingmap.jl @@ -17,7 +17,6 @@ MulStyle(::UniformScalingMap) = FiveArg() # properties Base.size(A::UniformScalingMap) = (A.M, A.M) -_parent(A::UniformScalingMap) = A.λ Base.isreal(A::UniformScalingMap) = isreal(A.λ) LinearAlgebra.issymmetric(::UniformScalingMap) = true LinearAlgebra.ishermitian(A::UniformScalingMap) = isreal(A) diff --git a/src/wrappedmap.jl b/src/wrappedmap.jl index 4631c246..48528c13 100644 --- a/src/wrappedmap.jl +++ b/src/wrappedmap.jl @@ -38,7 +38,6 @@ Base.:(==)(A::MatrixMap, B::MatrixMap) = # properties Base.size(A::WrappedMap) = size(A.lmap) -_parent(A::WrappedMap) = A.lmap LinearAlgebra.issymmetric(A::WrappedMap) = A._issymmetric LinearAlgebra.ishermitian(A::WrappedMap) = A._ishermitian LinearAlgebra.isposdef(A::WrappedMap) = A._isposdef diff --git a/test/kronecker.jl b/test/kronecker.jl index 7acd6e78..1c511a6f 100644 --- a/test/kronecker.jl +++ b/test/kronecker.jl @@ -8,6 +8,10 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays LA = LinearMap(A) LB = LinearMap(B) LK = @inferred kron(LA, LB) + @test kron(LA, 2LB) isa LinearMaps.ScaledMap + @test kron(3LA, LB) isa LinearMaps.ScaledMap + @test kron(3LA, 2LB) isa LinearMaps.ScaledMap + @test kron(3LA, 2LB).λ == 6 @test_throws ErrorException LinearMaps.KroneckerMap{Float64}((LA, LB)) @test occursin("6×6 LinearMaps.KroneckerMap{$(eltype(LK))}", sprint((t, s) -> show(t, "text/plain", s), LK)) @test @inferred size(LK) == size(K) diff --git a/test/uniformscalingmap.jl b/test/uniformscalingmap.jl index 8b158701..0a958cdb 100644 --- a/test/uniformscalingmap.jl +++ b/test/uniformscalingmap.jl @@ -11,7 +11,6 @@ using Test, LinearMaps, LinearAlgebra, BenchmarkTools w = similar(v) Id = @inferred LinearMap(I, 10) @test occursin("10×10 LinearMaps.UniformScalingMap{Bool}", sprint((t, s) -> show(t, "text/plain", s), Id)) - @test LinearMaps._parent(Id) == true @test_throws ErrorException LinearMaps.UniformScalingMap(1, 10, 20) @test_throws ErrorException LinearMaps.UniformScalingMap(1, (10, 20)) @test size(Id) == (10, 10) diff --git a/test/wrappedmap.jl b/test/wrappedmap.jl index 485ea99b..db36fb1b 100644 --- a/test/wrappedmap.jl +++ b/test/wrappedmap.jl @@ -7,7 +7,6 @@ using Test, LinearMaps, LinearAlgebra SB = B'B + I L = @inferred LinearMap{Float64}(A) @test occursin("10×20 LinearMaps.WrappedMap{Float64}", sprint((t, s) -> show(t, "text/plain", s), L)) - @test LinearMaps._parent(L) === A MA = @inferred LinearMap(SA) MB = @inferred LinearMap(SB) @test eltype(Matrix{Complex{Float32}}(LinearMap(A))) <: Complex