Skip to content

Commit

Permalink
Banded * Strided is always simplifiable (JuliaArrays#343)
Browse files Browse the repository at this point in the history
* Banded * Strided is always simplifiable

* fix tests

* Update mul.jl

* Update mul.jl

* Update multests.jl

* v2.2

* Update mul.jl

* Update multests.jl
  • Loading branch information
dlfivefifty authored Aug 7, 2024
1 parent 223d86f commit 7777232
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "LazyArrays"
uuid = "5078a376-72f3-5289-bfd5-ec5146d43c02"
version = "2.1.9"
version = "2.2"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
8 changes: 4 additions & 4 deletions ext/LazyArraysBandedMatricesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -567,8 +567,8 @@ copy(M::Mul{<:DiagonalLayout, <:BandedLazyLayouts}) = simplify(M)
copy(M::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}}, <:BandedLazyLayouts}) = copy(mulreduce(M))
copy(M::Mul{<:BandedLazyLayouts, <:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = copy(mulreduce(M))

simplifiable(::Mul{<:BandedLazyLayouts, <:DiagonalLayout{<:OnesLayout}}) = Val(true)
simplifiable(::Mul{<:DiagonalLayout{<:OnesLayout}, <:BandedLazyLayouts}) = Val(true)
simplifiable(::Mul{<:BandedLayouts, <:DiagonalLayout{<:OnesLayout}}) = Val(true)
simplifiable(::Mul{<:DiagonalLayout{<:OnesLayout}, <:BandedLayouts}) = Val(true)
copy(M::Mul{<:BandedLazyLayouts, <:DiagonalLayout{<:OnesLayout}}) = _copy_oftype(M.A, eltype(M))
copy(M::Mul{<:DiagonalLayout{<:OnesLayout}, <:BandedLazyLayouts}) = _copy_oftype(M.B, eltype(M))

Expand Down Expand Up @@ -644,8 +644,8 @@ copy(M::Mul{<:Union{AbstractStridedLayout,DualLayout{<:AbstractStridedLayout}},
simplifiable(M::Mul{<:AbstractInvLayout{<:BandedLayouts}, <:Union{PaddedColumns,PaddedLayout,AbstractStridedLayout}}) = Val(true)
simplifiable(M::Mul{<:Union{PaddedRows,PaddedLayout,DualLayout{<:PaddedRows}}, <:BandedLayouts}) = Val(true)
simplifiable(M::Mul{<:BandedLayouts, <:Union{PaddedColumns,PaddedLayout}}) = Val(true)
simplifiable(M::Mul{<:Union{AbstractStridedLayout,DualLayout{<:AbstractStridedLayout}}, <:BandedLazyLayouts}) = Val(true)
simplifiable(M::Mul{<:BandedLazyLayouts, <:AbstractStridedLayout}) = Val(true)
simplifiable(M::Mul{<:Union{AbstractStridedLayout,DualLayout{<:AbstractStridedLayout}}, <:BandedLayouts}) = Val(true)
simplifiable(M::Mul{<:BandedLayouts, <:AbstractStridedLayout}) = Val(true)

copy(L::Ldiv{ApplyBandedLayout{typeof(*)}, Lay}) where Lay = copy(Ldiv{ApplyLayout{typeof(*)},Lay}(L.A, L.B))
copy(L::Ldiv{ApplyBandedLayout{typeof(*)}, Lay}) where {Lay<:AbstractLazyLayout} = copy(Ldiv{ApplyLayout{typeof(*)},Lay}(L.A, L.B))
Expand Down
2 changes: 1 addition & 1 deletion src/LazyArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import ArrayLayouts: AbstractQLayout, Dot, Dotu, Ldiv, Lmul, MatMulMatAdd, MatMu
hermitianlayout, layout_getindex, layout_replace_in_print_matrix, ldivaxes, materialize,
materialize!, mulreduce, reshapedlayout, rowsupport, scalarone, scalarzero, sub_materialize,
sublayout, symmetriclayout, symtridiagonallayout, transposelayout, triangulardata,
triangularlayout, tridiagonallayout, zero!, transtype
triangularlayout, tridiagonallayout, zero!, transtype, OnesLayout

import FillArrays: AbstractFill, getindex_value

Expand Down
21 changes: 13 additions & 8 deletions src/linalg/mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,18 +371,23 @@ applylayout(::Type{typeof(*)}, ::DualLayout{Lay}, args...) where Lay = DualLayou
transtype(A::MulMatrix) = transtype(first(A.args))

#TODO: Why not all DiagonalLayout?
@inline simplifiable(M::Mul{<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
@inline simplifiable(M::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
@inline simplifiable(M::Mul{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
@inline simplifiable(::Mul{<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
@inline simplifiable(::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
@inline simplifiable(::Mul{<:DiagonalLayout{<:OnesLayout},<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
@inline simplifiable(::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout{<:OnesLayout}}) = Val(true)
@inline simplifiable(::Mul{<:DiagonalLayout{<:OnesLayout}, <:DiagonalLayout{<:OnesLayout}}) = Val(true) # ambiguity
@inline simplifiable(::Mul{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
@inline simplifiable(::Mul{<:Union{AbstractStridedLayout,DualLayout{<:AbstractStridedLayout}}, <:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
@inline simplifiable(::Mul{<:DiagonalLayout{<:AbstractFillLayout}, <:AbstractStridedLayout}) = Val(true)
@inline copy(M::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:LazyLayouts}) = copy(mulreduce(M))
@inline copy(M::Mul{<:LazyLayouts,<:DiagonalLayout{<:AbstractFillLayout}}) = copy(mulreduce(M))
@inline copy(M::Mul{BroadcastLayout{typeof(*)},<:DiagonalLayout{<:AbstractFillLayout}}) = copy(mulreduce(M))

@inline simplifiable(M::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}},<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)
@inline simplifiable(M::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)
@inline simplifiable(M::Mul{<:Any,<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)
@inline simplifiable(M::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}},<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
@inline simplifiable(M::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)
@inline simplifiable(::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}},<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)
@inline simplifiable(::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)
@inline simplifiable(::Mul{<:Any,<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)
@inline simplifiable(::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}},<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
@inline simplifiable(::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)


# inv
Expand Down
19 changes: 18 additions & 1 deletion test/multests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using Test, LinearAlgebra, LazyArrays, StaticArrays, FillArrays, Base64, Random, BandedMatrices
using Test, LinearAlgebra, LazyArrays, FillArrays, Base64, Random
using BandedMatrices
using StaticArrays
import LazyArrays: MulAdd, MemoryLayout, DenseColumnMajor, DiagonalLayout, SymTridiagonalLayout, Add, AddArray,
MulStyle, MulAddStyle, Applied, ApplyStyle, Lmul, ApplyArrayBroadcastStyle, DefaultArrayApplyStyle,
Rmul, ApplyLayout, arguments, colsupport, rowsupport, lazymaterialize
Expand Down Expand Up @@ -1176,8 +1178,23 @@ end

@testset "simplifiable tests" begin
A = randn(5,5)
b = randn(5)
D = Diagonal(Fill(2,5))
E = Eye(5)
@test LazyArrays.simplifiable(*, A) == Val(false)
@test LazyArrays.simplify(Applied(*, A, A)) == A*A
@test LazyArrays.simplifiable(*, A, D) == Val(true)
@test LazyArrays.simplifiable(*, D, A) == Val(true)
@test LazyArrays.simplifiable(*, b', D) == Val(true)
@test LazyArrays.simplifiable(*, D, b) == Val(true)
@test LazyArrays.simplifiable(*, A, E) == Val(true)
@test LazyArrays.simplifiable(*, E, A) == Val(true)
@test LazyArrays.simplifiable(*, b', E) == Val(true)
@test LazyArrays.simplifiable(*, E, b) == Val(true)
@test LazyArrays.simplifiable(*, E, D) == Val(true)
@test LazyArrays.simplifiable(*, D, E) == Val(true)
@test LazyArrays.simplifiable(*, D, D) == Val(true)
@test LazyArrays.simplifiable(*, E, E) == Val(true)
end
end

Expand Down

0 comments on commit 7777232

Please sign in to comment.