From fb2e0e070d24f7cf0c8687ca72ba84002f0f8f8e Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 28 Jun 2023 23:19:38 +0530 Subject: [PATCH] kron for AbstractFill --- Project.toml | 2 +- src/FillArrays.jl | 2 +- src/fillalgebra.jl | 19 +++++++++++++++++++ test/runtests.jl | 22 ++++++++++++++++++++++ 4 files changed, 43 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 44a5e5e8..c4182cfd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FillArrays" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.2.1" +version = "1.3.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/FillArrays.jl b/src/FillArrays.jl index af3ee9af..3d90f3bb 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -10,7 +10,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert, import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!, dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec, - issymmetric, ishermitian, AdjOrTransAbsVec, checksquare, mul! + issymmetric, ishermitian, AdjOrTransAbsVec, checksquare, mul!, kron import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index 3dd222c3..b11fa5a2 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -434,3 +434,22 @@ fillzero(::Type{Zeros{T,N,AXIS}}, n, m) where {T,N,AXIS} = Zeros{T,N,AXIS}((n, m fillzero(::Type{F}, n, m) where F = throw(ArgumentError("Cannot create a zero array of type $F")) diagzero(D::Diagonal{F}, i, j) where F<:AbstractFill = fillzero(F, axes(D.diag[i], 1), axes(D.diag[j], 2)) + +# kron + +_kronsize(f::AbstractFillVector, g::AbstractFillVector) = (size(f,1)*size(g,1),) +_kronsize(f::AbstractFillVecOrMat, g::AbstractFillVecOrMat) = (size(f,1)*size(g,1), size(f,2)*size(g,2)) +function _kron(f::AbstractFill, g::AbstractFill, sz) + v = getindex_value(f)*getindex_value(g) + Fill(v, sz) +end +function _kron(f::Zeros, g::Zeros, sz) + Zeros{promote_type(eltype(f), eltype(g))}(sz) +end +function _kron(f::Ones, g::Ones, sz) + Ones{promote_type(eltype(f), eltype(g))}(sz) +end +function kron(f::AbstractFillVecOrMat, g::AbstractFillVecOrMat) + sz = _kronsize(f, g) + _kron(f, g, sz) +end diff --git a/test/runtests.jl b/test/runtests.jl index df8d25db..bd736745 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1482,6 +1482,28 @@ end end end +@testset "kron" begin + for T in (Fill, Zeros, Ones), sz in ((2,), (2,2)) + f = T{Int}((T == Fill ? (3,sz...) : sz)...) + g = Ones{Int}(2) + fc = collect(f) + gc = collect(g) + @test kron(f, f) == kron(fc, fc) + @test kron(f, f) isa T{Int,length(sz)} + @test kron(f, g) == kron(fc, gc) + @test kron(f, g) isa AbstractFill{Int,length(sz)} + @test kron(g, f) == kron(gc, fc) + @test kron(g, f) isa AbstractFill{Int,length(sz)} + @test kron(f, f .+ 0.5) == kron(fc, fc .+ 0.5) + @test kron(f, f .+ 0.5) isa AbstractFill{Float64,length(sz)} + @test kron(f, g .+ 0.5) isa AbstractFill{Float64,length(sz)} + end + + E = Eye(2) + @test kron(E, E) isa typeof(E) + @test kron(E, E) == kron(collect(E), collect(E)) +end + @testset "dot products" begin n = 15 o = Ones(1:n)