diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 60c22c33..3222a8f7 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -10,7 +10,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert, import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!, dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec, - issymmetric, ishermitian, AdjOrTransAbsVec, checksquare, mul! + issymmetric, ishermitian, AdjOrTransAbsVec, checksquare, mul!, kron import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape @@ -411,6 +411,7 @@ Base.replace_in_print_matrix(A::RectDiagonal, i::Integer, j::Integer, s::Abstrac const RectOrDiagonal{T,V,Axes} = Union{RectDiagonal{T,V,Axes}, Diagonal{T,V}} +const RectDiagonalEye{T} = RectDiagonal{T,<:Ones{T,1}} const SquareEye{T,Axes} = Diagonal{T,Ones{T,1,Tuple{Axes}}} const Eye{T,Axes} = RectOrDiagonal{T,Ones{T,1,Tuple{Axes}}} diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index 3dd222c3..6d5c8394 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -434,3 +434,23 @@ fillzero(::Type{Zeros{T,N,AXIS}}, n, m) where {T,N,AXIS} = Zeros{T,N,AXIS}((n, m fillzero(::Type{F}, n, m) where F = throw(ArgumentError("Cannot create a zero array of type $F")) diagzero(D::Diagonal{F}, i, j) where F<:AbstractFill = fillzero(F, axes(D.diag[i], 1), axes(D.diag[j], 2)) + +# kron + +_kronsize(f::AbstractFillVector, g::AbstractFillVector) = (size(f,1)*size(g,1),) +_kronsize(f::AbstractFillVecOrMat, g::AbstractFillVecOrMat) = (size(f,1)*size(g,1), size(f,2)*size(g,2)) +function _kron(f::AbstractFill, g::AbstractFill, sz) + v = getindex_value(f)*getindex_value(g) + Fill(v, sz) +end +function _kron(f::Zeros, g::Zeros, sz) + Zeros{promote_type(eltype(f), eltype(g))}(sz) +end +function _kron(f::Ones, g::Ones, sz) + Ones{promote_type(eltype(f), eltype(g))}(sz) +end +function kron(f::AbstractFillVecOrMat, g::AbstractFillVecOrMat) + sz = _kronsize(f, g) + _kron(f, g, sz) +end +kron(E1::RectDiagonalEye, E2::RectDiagonalEye) = kron(sparse(E1), sparse(E2)) diff --git a/test/runtests.jl b/test/runtests.jl index f28f64bd..1ffd2d5a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1489,6 +1489,53 @@ end end end +@testset "kron" begin + for T in (Fill, Zeros, Ones), sz in ((2,), (2,2)) + f = T{Int}((T == Fill ? (3,sz...) : sz)...) + g = Ones{Int}(2) + z = Zeros{Int}(2) + fc = collect(f) + gc = collect(g) + zc = collect(z) + @test kron(f, f) == kron(fc, fc) + @test kron(f, f) isa T{Int,length(sz)} + @test kron(f, g) == kron(fc, gc) + @test kron(f, g) isa AbstractFill{Int,length(sz)} + @test kron(g, f) == kron(gc, fc) + @test kron(g, f) isa AbstractFill{Int,length(sz)} + @test kron(f, z) == kron(fc, zc) + @test kron(f, z) isa AbstractFill{Int,length(sz)} + @test kron(z, f) == kron(zc, fc) + @test kron(z, f) isa AbstractFill{Int,length(sz)} + @test kron(f, f .+ 0.5) == kron(fc, fc .+ 0.5) + @test kron(f, f .+ 0.5) isa AbstractFill{Float64,length(sz)} + @test kron(f, g .+ 0.5) isa AbstractFill{Float64,length(sz)} + end + + for m in (Fill(2,2,2), "a"), sz in ((2,2), (2,)) + f = Fill(m, sz) + g = fill(m, sz) + @test kron(f, f) == kron(g, g) + end + + @test_throws MethodError kron(Fill("a",2), Zeros(1)) # can't multiply String and Float64 + + E = Eye(2) + K = kron(E, E) + @test K isa Diagonal + if VERSION >= v"1.9" + @test K isa typeof(E) + end + C = collect(E) + @test K == kron(C, C) + + E = Eye(2,3) + K = kron(E, E) + C = collect(E) + @test K == kron(C, C) + @test issparse(kron(E,E)) +end + @testset "dot products" begin n = 15 o = Ones(1:n)