Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Banded * Strided is always simplifiable #343

Merged
merged 8 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "LazyArrays"
uuid = "5078a376-72f3-5289-bfd5-ec5146d43c02"
version = "2.1.9"
version = "2.2"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
8 changes: 4 additions & 4 deletions ext/LazyArraysBandedMatricesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -567,8 +567,8 @@
copy(M::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}}, <:BandedLazyLayouts}) = copy(mulreduce(M))
copy(M::Mul{<:BandedLazyLayouts, <:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = copy(mulreduce(M))

simplifiable(::Mul{<:BandedLazyLayouts, <:DiagonalLayout{<:OnesLayout}}) = Val(true)
simplifiable(::Mul{<:DiagonalLayout{<:OnesLayout}, <:BandedLazyLayouts}) = Val(true)
simplifiable(::Mul{<:BandedLayouts, <:DiagonalLayout{<:OnesLayout}}) = Val(true)
simplifiable(::Mul{<:DiagonalLayout{<:OnesLayout}, <:BandedLayouts}) = Val(true)

Check warning on line 571 in ext/LazyArraysBandedMatricesExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LazyArraysBandedMatricesExt.jl#L570-L571

Added lines #L570 - L571 were not covered by tests
copy(M::Mul{<:BandedLazyLayouts, <:DiagonalLayout{<:OnesLayout}}) = _copy_oftype(M.A, eltype(M))
copy(M::Mul{<:DiagonalLayout{<:OnesLayout}, <:BandedLazyLayouts}) = _copy_oftype(M.B, eltype(M))

Expand Down Expand Up @@ -644,8 +644,8 @@
simplifiable(M::Mul{<:AbstractInvLayout{<:BandedLayouts}, <:Union{PaddedColumns,PaddedLayout,AbstractStridedLayout}}) = Val(true)
simplifiable(M::Mul{<:Union{PaddedRows,PaddedLayout,DualLayout{<:PaddedRows}}, <:BandedLayouts}) = Val(true)
simplifiable(M::Mul{<:BandedLayouts, <:Union{PaddedColumns,PaddedLayout}}) = Val(true)
simplifiable(M::Mul{<:Union{AbstractStridedLayout,DualLayout{<:AbstractStridedLayout}}, <:BandedLazyLayouts}) = Val(true)
simplifiable(M::Mul{<:BandedLazyLayouts, <:AbstractStridedLayout}) = Val(true)
simplifiable(M::Mul{<:Union{AbstractStridedLayout,DualLayout{<:AbstractStridedLayout}}, <:BandedLayouts}) = Val(true)
simplifiable(M::Mul{<:BandedLayouts, <:AbstractStridedLayout}) = Val(true)

Check warning on line 648 in ext/LazyArraysBandedMatricesExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LazyArraysBandedMatricesExt.jl#L647-L648

Added lines #L647 - L648 were not covered by tests

copy(L::Ldiv{ApplyBandedLayout{typeof(*)}, Lay}) where Lay = copy(Ldiv{ApplyLayout{typeof(*)},Lay}(L.A, L.B))
copy(L::Ldiv{ApplyBandedLayout{typeof(*)}, Lay}) where {Lay<:AbstractLazyLayout} = copy(Ldiv{ApplyLayout{typeof(*)},Lay}(L.A, L.B))
Expand Down
2 changes: 1 addition & 1 deletion src/LazyArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import ArrayLayouts: AbstractQLayout, Dot, Dotu, Ldiv, Lmul, MatMulMatAdd, MatMu
hermitianlayout, layout_getindex, layout_replace_in_print_matrix, ldivaxes, materialize,
materialize!, mulreduce, reshapedlayout, rowsupport, scalarone, scalarzero, sub_materialize,
sublayout, symmetriclayout, symtridiagonallayout, transposelayout, triangulardata,
triangularlayout, tridiagonallayout, zero!, transtype
triangularlayout, tridiagonallayout, zero!, transtype, OnesLayout

import FillArrays: AbstractFill, getindex_value

Expand Down
21 changes: 13 additions & 8 deletions src/linalg/mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,18 +371,23 @@
transtype(A::MulMatrix) = transtype(first(A.args))

#TODO: Why not all DiagonalLayout?
@inline simplifiable(M::Mul{<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
@inline simplifiable(M::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
@inline simplifiable(M::Mul{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
@inline simplifiable(::Mul{<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
@inline simplifiable(::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
@inline simplifiable(::Mul{<:DiagonalLayout{<:OnesLayout},<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
@inline simplifiable(::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout{<:OnesLayout}}) = Val(true)
@inline simplifiable(::Mul{<:DiagonalLayout{<:OnesLayout}, <:DiagonalLayout{<:OnesLayout}}) = Val(true) # ambiguity
@inline simplifiable(::Mul{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
@inline simplifiable(::Mul{<:Union{AbstractStridedLayout,DualLayout{<:AbstractStridedLayout}}, <:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
@inline simplifiable(::Mul{<:DiagonalLayout{<:AbstractFillLayout}, <:AbstractStridedLayout}) = Val(true)

Check warning on line 381 in src/linalg/mul.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg/mul.jl#L374-L381

Added lines #L374 - L381 were not covered by tests
@inline copy(M::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:LazyLayouts}) = copy(mulreduce(M))
@inline copy(M::Mul{<:LazyLayouts,<:DiagonalLayout{<:AbstractFillLayout}}) = copy(mulreduce(M))
@inline copy(M::Mul{BroadcastLayout{typeof(*)},<:DiagonalLayout{<:AbstractFillLayout}}) = copy(mulreduce(M))

@inline simplifiable(M::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}},<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)
@inline simplifiable(M::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)
@inline simplifiable(M::Mul{<:Any,<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)
@inline simplifiable(M::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}},<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
@inline simplifiable(M::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)
@inline simplifiable(::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}},<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)
@inline simplifiable(::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)
@inline simplifiable(::Mul{<:Any,<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)
@inline simplifiable(::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}},<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
@inline simplifiable(::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)

Check warning on line 390 in src/linalg/mul.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg/mul.jl#L386-L390

Added lines #L386 - L390 were not covered by tests


# inv
Expand Down
19 changes: 18 additions & 1 deletion test/multests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using Test, LinearAlgebra, LazyArrays, StaticArrays, FillArrays, Base64, Random, BandedMatrices
using Test, LinearAlgebra, LazyArrays, FillArrays, Base64, Random
using BandedMatrices
using StaticArrays
import LazyArrays: MulAdd, MemoryLayout, DenseColumnMajor, DiagonalLayout, SymTridiagonalLayout, Add, AddArray,
MulStyle, MulAddStyle, Applied, ApplyStyle, Lmul, ApplyArrayBroadcastStyle, DefaultArrayApplyStyle,
Rmul, ApplyLayout, arguments, colsupport, rowsupport, lazymaterialize
Expand Down Expand Up @@ -1176,8 +1178,23 @@ end

@testset "simplifiable tests" begin
A = randn(5,5)
b = randn(5)
D = Diagonal(Fill(2,5))
E = Eye(5)
@test LazyArrays.simplifiable(*, A) == Val(false)
@test LazyArrays.simplify(Applied(*, A, A)) == A*A
@test LazyArrays.simplifiable(*, A, D) == Val(true)
@test LazyArrays.simplifiable(*, D, A) == Val(true)
@test LazyArrays.simplifiable(*, b', D) == Val(true)
@test LazyArrays.simplifiable(*, D, b) == Val(true)
@test LazyArrays.simplifiable(*, A, E) == Val(true)
@test LazyArrays.simplifiable(*, E, A) == Val(true)
@test LazyArrays.simplifiable(*, b', E) == Val(true)
@test LazyArrays.simplifiable(*, E, b) == Val(true)
@test LazyArrays.simplifiable(*, E, D) == Val(true)
@test LazyArrays.simplifiable(*, D, E) == Val(true)
@test LazyArrays.simplifiable(*, D, D) == Val(true)
@test LazyArrays.simplifiable(*, E, E) == Val(true)
end
end

Expand Down
Loading