diff --git a/Project.toml b/Project.toml index 9767e53..0a4bd75 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "LazyArrays" uuid = "5078a376-72f3-5289-bfd5-ec5146d43c02" -version = "2.1.9" +version = "2.2" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/ext/LazyArraysBandedMatricesExt.jl b/ext/LazyArraysBandedMatricesExt.jl index dc1ff8b..e596b2c 100644 --- a/ext/LazyArraysBandedMatricesExt.jl +++ b/ext/LazyArraysBandedMatricesExt.jl @@ -567,8 +567,8 @@ copy(M::Mul{<:DiagonalLayout, <:BandedLazyLayouts}) = simplify(M) copy(M::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}}, <:BandedLazyLayouts}) = copy(mulreduce(M)) copy(M::Mul{<:BandedLazyLayouts, <:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = copy(mulreduce(M)) -simplifiable(::Mul{<:BandedLazyLayouts, <:DiagonalLayout{<:OnesLayout}}) = Val(true) -simplifiable(::Mul{<:DiagonalLayout{<:OnesLayout}, <:BandedLazyLayouts}) = Val(true) +simplifiable(::Mul{<:BandedLayouts, <:DiagonalLayout{<:OnesLayout}}) = Val(true) +simplifiable(::Mul{<:DiagonalLayout{<:OnesLayout}, <:BandedLayouts}) = Val(true) copy(M::Mul{<:BandedLazyLayouts, <:DiagonalLayout{<:OnesLayout}}) = _copy_oftype(M.A, eltype(M)) copy(M::Mul{<:DiagonalLayout{<:OnesLayout}, <:BandedLazyLayouts}) = _copy_oftype(M.B, eltype(M)) @@ -644,8 +644,8 @@ copy(M::Mul{<:Union{AbstractStridedLayout,DualLayout{<:AbstractStridedLayout}}, simplifiable(M::Mul{<:AbstractInvLayout{<:BandedLayouts}, <:Union{PaddedColumns,PaddedLayout,AbstractStridedLayout}}) = Val(true) simplifiable(M::Mul{<:Union{PaddedRows,PaddedLayout,DualLayout{<:PaddedRows}}, <:BandedLayouts}) = Val(true) simplifiable(M::Mul{<:BandedLayouts, <:Union{PaddedColumns,PaddedLayout}}) = Val(true) -simplifiable(M::Mul{<:Union{AbstractStridedLayout,DualLayout{<:AbstractStridedLayout}}, <:BandedLazyLayouts}) = Val(true) -simplifiable(M::Mul{<:BandedLazyLayouts, <:AbstractStridedLayout}) = Val(true) +simplifiable(M::Mul{<:Union{AbstractStridedLayout,DualLayout{<:AbstractStridedLayout}}, <:BandedLayouts}) = Val(true) +simplifiable(M::Mul{<:BandedLayouts, <:AbstractStridedLayout}) = Val(true) copy(L::Ldiv{ApplyBandedLayout{typeof(*)}, Lay}) where Lay = copy(Ldiv{ApplyLayout{typeof(*)},Lay}(L.A, L.B)) copy(L::Ldiv{ApplyBandedLayout{typeof(*)}, Lay}) where {Lay<:AbstractLazyLayout} = copy(Ldiv{ApplyLayout{typeof(*)},Lay}(L.A, L.B)) diff --git a/src/LazyArrays.jl b/src/LazyArrays.jl index fb6fe84..008a627 100644 --- a/src/LazyArrays.jl +++ b/src/LazyArrays.jl @@ -38,7 +38,7 @@ import ArrayLayouts: AbstractQLayout, Dot, Dotu, Ldiv, Lmul, MatMulMatAdd, MatMu hermitianlayout, layout_getindex, layout_replace_in_print_matrix, ldivaxes, materialize, materialize!, mulreduce, reshapedlayout, rowsupport, scalarone, scalarzero, sub_materialize, sublayout, symmetriclayout, symtridiagonallayout, transposelayout, triangulardata, - triangularlayout, tridiagonallayout, zero!, transtype + triangularlayout, tridiagonallayout, zero!, transtype, OnesLayout import FillArrays: AbstractFill, getindex_value diff --git a/src/linalg/mul.jl b/src/linalg/mul.jl index 1cb944b..39afb86 100644 --- a/src/linalg/mul.jl +++ b/src/linalg/mul.jl @@ -371,18 +371,23 @@ applylayout(::Type{typeof(*)}, ::DualLayout{Lay}, args...) where Lay = DualLayou transtype(A::MulMatrix) = transtype(first(A.args)) #TODO: Why not all DiagonalLayout? -@inline simplifiable(M::Mul{<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true) -@inline simplifiable(M::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true) -@inline simplifiable(M::Mul{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true) +@inline simplifiable(::Mul{<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true) +@inline simplifiable(::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true) +@inline simplifiable(::Mul{<:DiagonalLayout{<:OnesLayout},<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true) +@inline simplifiable(::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout{<:OnesLayout}}) = Val(true) +@inline simplifiable(::Mul{<:DiagonalLayout{<:OnesLayout}, <:DiagonalLayout{<:OnesLayout}}) = Val(true) # ambiguity +@inline simplifiable(::Mul{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true) +@inline simplifiable(::Mul{<:Union{AbstractStridedLayout,DualLayout{<:AbstractStridedLayout}}, <:DiagonalLayout{<:AbstractFillLayout}}) = Val(true) +@inline simplifiable(::Mul{<:DiagonalLayout{<:AbstractFillLayout}, <:AbstractStridedLayout}) = Val(true) @inline copy(M::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:LazyLayouts}) = copy(mulreduce(M)) @inline copy(M::Mul{<:LazyLayouts,<:DiagonalLayout{<:AbstractFillLayout}}) = copy(mulreduce(M)) @inline copy(M::Mul{BroadcastLayout{typeof(*)},<:DiagonalLayout{<:AbstractFillLayout}}) = copy(mulreduce(M)) -@inline simplifiable(M::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}},<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true) -@inline simplifiable(M::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true) -@inline simplifiable(M::Mul{<:Any,<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true) -@inline simplifiable(M::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}},<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true) -@inline simplifiable(M::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true) +@inline simplifiable(::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}},<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true) +@inline simplifiable(::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true) +@inline simplifiable(::Mul{<:Any,<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true) +@inline simplifiable(::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}},<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true) +@inline simplifiable(::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true) # inv diff --git a/test/multests.jl b/test/multests.jl index cddfc3f..1735ada 100644 --- a/test/multests.jl +++ b/test/multests.jl @@ -1,4 +1,6 @@ -using Test, LinearAlgebra, LazyArrays, StaticArrays, FillArrays, Base64, Random, BandedMatrices +using Test, LinearAlgebra, LazyArrays, FillArrays, Base64, Random +using BandedMatrices +using StaticArrays import LazyArrays: MulAdd, MemoryLayout, DenseColumnMajor, DiagonalLayout, SymTridiagonalLayout, Add, AddArray, MulStyle, MulAddStyle, Applied, ApplyStyle, Lmul, ApplyArrayBroadcastStyle, DefaultArrayApplyStyle, Rmul, ApplyLayout, arguments, colsupport, rowsupport, lazymaterialize @@ -1176,8 +1178,23 @@ end @testset "simplifiable tests" begin A = randn(5,5) + b = randn(5) + D = Diagonal(Fill(2,5)) + E = Eye(5) @test LazyArrays.simplifiable(*, A) == Val(false) @test LazyArrays.simplify(Applied(*, A, A)) == A*A + @test LazyArrays.simplifiable(*, A, D) == Val(true) + @test LazyArrays.simplifiable(*, D, A) == Val(true) + @test LazyArrays.simplifiable(*, b', D) == Val(true) + @test LazyArrays.simplifiable(*, D, b) == Val(true) + @test LazyArrays.simplifiable(*, A, E) == Val(true) + @test LazyArrays.simplifiable(*, E, A) == Val(true) + @test LazyArrays.simplifiable(*, b', E) == Val(true) + @test LazyArrays.simplifiable(*, E, b) == Val(true) + @test LazyArrays.simplifiable(*, E, D) == Val(true) + @test LazyArrays.simplifiable(*, D, E) == Val(true) + @test LazyArrays.simplifiable(*, D, D) == Val(true) + @test LazyArrays.simplifiable(*, E, E) == Val(true) end end