Skip to content

Commit

Permalink
Improve performance of Kronecker products (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch authored Jan 11, 2021
1 parent 51f5997 commit ef64037
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 14 deletions.
29 changes: 19 additions & 10 deletions src/kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ Base.kron(A::KroneckerMap, B::LinearMap) =
KroneckerMap{promote_type(eltype(A), eltype(B))}(tuple(A.maps..., B))
Base.kron(A::KroneckerMap, B::KroneckerMap) =
KroneckerMap{promote_type(eltype(A), eltype(B))}(tuple(A.maps..., B.maps...))
Base.kron(A::ScaledMap, B::LinearMap) = A.λ * kron(A.lmap, B)
Base.kron(A::LinearMap{<:RealOrComplex}, B::ScaledMap) = B.λ * kron(A, B.lmap)
Base.kron(A::ScaledMap, B::ScaledMap) = (A.λ * B.λ) * kron(A.lmap, B.lmap)
Base.kron(A::LinearMap, B::LinearMap, C::LinearMap, Ds::LinearMap...) =
kron(kron(A, B), C, Ds...)
Base.kron(A::AbstractMatrix, B::LinearMap) = kron(LinearMap(A), B)
Expand Down Expand Up @@ -104,9 +107,10 @@ Base.:(==)(A::KroneckerMap, B::KroneckerMap) = (eltype(A) == eltype(B) && A.maps
# multiplication helper functions
#################

@inline function _kronmul!(y, B, X, At, T)
@inline function _kronmul!(y, B, x, At, T)
na, ma = size(At)
mb, nb = size(B)
X = reshape(x, (nb, na))
v = zeros(T, ma)
temp1 = similar(y, na)
temp2 = similar(y, nb)
Expand All @@ -119,14 +123,23 @@ Base.:(==)(A::KroneckerMap, B::KroneckerMap) = (eltype(A) == eltype(B) && A.maps
end
return y
end
@inline function _kronmul!(y, B, X, At::Union{MatrixMap, UniformScalingMap}, T)
@inline function _kronmul!(y, B, x, At::UniformScalingMap, _)
na, ma = size(At)
mb, nb = size(B)
X = reshape(x, (nb, na))
Y = reshape(y, (mb, ma))
_unsafe_mul!(Y, B, X, At.λ, false)
return y
end
@inline function _kronmul!(y, B, x, At::MatrixMap, _)
na, ma = size(At)
mb, nb = size(B)
X = reshape(x, (nb, na))
Y = reshape(y, (mb, ma))
if nb*ma < mb*na
_unsafe_mul!(Y, B, Matrix(X*At))
_unsafe_mul!(Y, B, X * At.lmap)
else
_unsafe_mul!(Y, Matrix(B*X), _parent(At))
_unsafe_mul!(Y, Matrix(B*X), At.lmap)
end
return y
end
Expand All @@ -140,18 +153,14 @@ const KroneckerMap2{T} = KroneckerMap{T, <:Tuple{LinearMap, LinearMap}}
function _unsafe_mul!(y::AbstractVecOrMat, L::KroneckerMap2, x::AbstractVector)
require_one_based_indexing(y)
A, B = L.maps
X = LinearMap(reshape(x, (size(B, 2), size(A, 2)));
issymmetric = false, ishermitian = false, isposdef = false)
_kronmul!(y, B, X, transpose(A), eltype(L))
_kronmul!(y, B, x, transpose(A), eltype(L))
return y
end
function _unsafe_mul!(y::AbstractVecOrMat, L::KroneckerMap, x::AbstractVector)
require_one_based_indexing(y)
A = first(L.maps)
B = kron(Base.tail(L.maps)...)
X = LinearMap(reshape(x, (size(B, 2), size(A, 2)));
issymmetric = false, ishermitian = false, isposdef = false)
_kronmul!(y, B, X, transpose(A), eltype(L))
_kronmul!(y, B, x, transpose(A), eltype(L))
return y
end
# mixed-product rule, prefer the right if possible
Expand Down
1 change: 0 additions & 1 deletion src/uniformscalingmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ MulStyle(::UniformScalingMap) = FiveArg()

# properties
Base.size(A::UniformScalingMap) = (A.M, A.M)
_parent(A::UniformScalingMap) = A.λ
Base.isreal(A::UniformScalingMap) = isreal(A.λ)
LinearAlgebra.issymmetric(::UniformScalingMap) = true
LinearAlgebra.ishermitian(A::UniformScalingMap) = isreal(A)
Expand Down
1 change: 0 additions & 1 deletion src/wrappedmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ Base.:(==)(A::MatrixMap, B::MatrixMap) =

# properties
Base.size(A::WrappedMap) = size(A.lmap)
_parent(A::WrappedMap) = A.lmap
LinearAlgebra.issymmetric(A::WrappedMap) = A._issymmetric
LinearAlgebra.ishermitian(A::WrappedMap) = A._ishermitian
LinearAlgebra.isposdef(A::WrappedMap) = A._isposdef
Expand Down
4 changes: 4 additions & 0 deletions test/kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays
LA = LinearMap(A)
LB = LinearMap(B)
LK = @inferred kron(LA, LB)
@test kron(LA, 2LB) isa LinearMaps.ScaledMap
@test kron(3LA, LB) isa LinearMaps.ScaledMap
@test kron(3LA, 2LB) isa LinearMaps.ScaledMap
@test kron(3LA, 2LB).λ == 6
@test_throws ErrorException LinearMaps.KroneckerMap{Float64}((LA, LB))
@test occursin("6×6 LinearMaps.KroneckerMap{$(eltype(LK))}", sprint((t, s) -> show(t, "text/plain", s), LK))
@test @inferred size(LK) == size(K)
Expand Down
1 change: 0 additions & 1 deletion test/uniformscalingmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ using Test, LinearMaps, LinearAlgebra, BenchmarkTools
w = similar(v)
Id = @inferred LinearMap(I, 10)
@test occursin("10×10 LinearMaps.UniformScalingMap{Bool}", sprint((t, s) -> show(t, "text/plain", s), Id))
@test LinearMaps._parent(Id) == true
@test_throws ErrorException LinearMaps.UniformScalingMap(1, 10, 20)
@test_throws ErrorException LinearMaps.UniformScalingMap(1, (10, 20))
@test size(Id) == (10, 10)
Expand Down
1 change: 0 additions & 1 deletion test/wrappedmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ using Test, LinearMaps, LinearAlgebra
SB = B'B + I
L = @inferred LinearMap{Float64}(A)
@test occursin("10×20 LinearMaps.WrappedMap{Float64}", sprint((t, s) -> show(t, "text/plain", s), L))
@test LinearMaps._parent(L) === A
MA = @inferred LinearMap(SA)
MB = @inferred LinearMap(SB)
@test eltype(Matrix{Complex{Float32}}(LinearMap(A))) <: Complex
Expand Down

0 comments on commit ef64037

Please sign in to comment.