Skip to content

Commit

Permalink
Accelerate _mul_ApplyStyle (#261)
Browse files Browse the repository at this point in the history
* Update mul.jl

* Update mul.jl

* generated

* Update mul.jl

* Update multests.jl

* Add parantheses

* Update Project.toml

---------

Co-authored-by: Sheehan Olver <solver@mac.com>
  • Loading branch information
putianyi889 and dlfivefifty authored Jul 20, 2023
1 parent 6f90647 commit 7009bd1
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 7 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
deps/deps.jl
.DS_Store
Manifest.toml
statprof/*
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "LazyArrays"
uuid = "5078a376-72f3-5289-bfd5-ec5146d43c02"
version = "1.4.0"
version = "1.4.1"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
13 changes: 12 additions & 1 deletion src/linalg/mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,18 @@ combine_mul_styles(a, b, c...) = combine_mul_styles(combine_mul_styles(a, b), c.
# We need to combine all branches to determine whether it can be simplified
ApplyStyle(::typeof(*), a) = DefaultApplyStyle()
ApplyStyle(::typeof(*), a::AbstractArray) = DefaultArrayApplyStyle()
_mul_ApplyStyle(a, b...) = combine_mul_styles(ApplyStyle(*, a, Base.front(b)...), ApplyStyle(*, b...))
# naive recursion is more comprehensive but is slower than the implemented algorithm as of Julia 1.9.2.
# @generated _mul_ApplyStyle(a...) = combine_mul_styles(_mul_ApplyStyle(Base.front(a)...), _mul_ApplyStyle(Base.tail(a)...))
@generated function _mul_ApplyStyle(a...)
list = ApplyStyle[_mul_ApplyStyle(x) for x in a]
for countdown in length(list)-1:-1:1
for k in 1:countdown
list[k] = combine_mul_styles(list[k], list[k+1])
end
end
list[1]
end
_mul_ApplyStyle(a) = MulStyle()
ApplyStyle(::typeof(*), a, b...) = _mul_ApplyStyle(a, b...)
if !(AbstractQ <: AbstractMatrix) # VERSION >= v"1.10-"
ApplyStyle(::typeof(*), a::Type{<:Union{AbstractArray,AbstractQ}}, b::Type{<:Union{AbstractArray,AbstractQ}}...) = _mul_ApplyStyle(a, b...)
Expand Down
10 changes: 5 additions & 5 deletions test/multests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ end
@test applied(*, UpperTriangular(A), x) isa Applied{MulStyle}
@test similar(applied(*, UpperTriangular(A), x), Float64) isa Vector{Float64}

@test ApplyStyle(*, typeof(UpperTriangular(A)), typeof(x)) isa MulStyle
@test @inferred(ApplyStyle(*, typeof(UpperTriangular(A)), typeof(x))) isa MulStyle

@test all((y = copy(x); y .= applied(*, UpperTriangular(A),y) ) .===
(similar(x) .= applied(*, UpperTriangular(A),x)) .===
Expand Down Expand Up @@ -883,13 +883,13 @@ end
A = randn(5,5)
B = Diagonal(randn(5))
@test MemoryLayout(typeof(B)) == DiagonalLayout{DenseColumnMajor}()
@test ApplyStyle(*, typeof(A), typeof(B)) == MulStyle()
@test @inferred(ApplyStyle(*, typeof(A), typeof(B))) == MulStyle()
@test apply(*,A,B) == A*B == materialize!(Rmul(copy(A),B))

@test ApplyStyle(*, typeof(B), typeof(A)) == MulStyle()
@test @inferred(ApplyStyle(*, typeof(B), typeof(A))) == MulStyle()
@test apply(*,B,A) == B*A

@test ApplyStyle(*, typeof(B), typeof(B)) == MulStyle()
@test @inferred(ApplyStyle(*, typeof(B), typeof(B))) == MulStyle()
@test apply(*,B,B) == B*B
@test apply(*,B,B) isa Diagonal

Expand Down Expand Up @@ -961,7 +961,7 @@ end

@testset "ApplyArray MulTest" begin
A = ApplyArray(*,randn(2,2), randn(2,2))
@test ApplyStyle(*,typeof(A),typeof(randn(2,2))) isa MulStyle
@test @inferred(ApplyStyle(*,typeof(A),typeof(randn(2,2)))) isa MulStyle
@test ApplyArray(*,Diagonal(Fill(2,10)), Fill(3,10,10))*Fill(3,10) Fill(180,10)
@test ApplyArray(*,Diagonal(Fill(2,10)), Fill(3,10,10))*ApplyArray(*,Diagonal(Fill(2,10)), Fill(3,10,10)) == Fill(360,10,10)
@test A' isa ApplyArray
Expand Down

2 comments on commit 7009bd1

@dlfivefifty
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/87925

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.4.1 -m "<description of version>" 7009bd174b4367b961c4a632a35ab0cb9dfc8055
git push origin v1.4.1

Please sign in to comment.