diff --git a/.gitignore b/.gitignore index 86ff35b4..c68c3851 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ deps/deps.jl .DS_Store Manifest.toml +statprof/* diff --git a/Project.toml b/Project.toml index c1f92bf3..af9ae6a3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "LazyArrays" uuid = "5078a376-72f3-5289-bfd5-ec5146d43c02" -version = "1.4.0" +version = "1.4.1" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/linalg/mul.jl b/src/linalg/mul.jl index b696f3b8..27214214 100644 --- a/src/linalg/mul.jl +++ b/src/linalg/mul.jl @@ -92,7 +92,18 @@ combine_mul_styles(a, b, c...) = combine_mul_styles(combine_mul_styles(a, b), c. # We need to combine all branches to determine whether it can be simplified ApplyStyle(::typeof(*), a) = DefaultApplyStyle() ApplyStyle(::typeof(*), a::AbstractArray) = DefaultArrayApplyStyle() -_mul_ApplyStyle(a, b...) = combine_mul_styles(ApplyStyle(*, a, Base.front(b)...), ApplyStyle(*, b...)) +# naive recursion is more comprehensive but is slower than the implemented algorithm as of Julia 1.9.2. +# @generated _mul_ApplyStyle(a...) = combine_mul_styles(_mul_ApplyStyle(Base.front(a)...), _mul_ApplyStyle(Base.tail(a)...)) +@generated function _mul_ApplyStyle(a...) + list = ApplyStyle[_mul_ApplyStyle(x) for x in a] + for countdown in length(list)-1:-1:1 + for k in 1:countdown + list[k] = combine_mul_styles(list[k], list[k+1]) + end + end + list[1] +end +_mul_ApplyStyle(a) = MulStyle() ApplyStyle(::typeof(*), a, b...) = _mul_ApplyStyle(a, b...) if !(AbstractQ <: AbstractMatrix) # VERSION >= v"1.10-" ApplyStyle(::typeof(*), a::Type{<:Union{AbstractArray,AbstractQ}}, b::Type{<:Union{AbstractArray,AbstractQ}}...) = _mul_ApplyStyle(a, b...) diff --git a/test/multests.jl b/test/multests.jl index ce9ec373..da62abea 100644 --- a/test/multests.jl +++ b/test/multests.jl @@ -662,7 +662,7 @@ end @test applied(*, UpperTriangular(A), x) isa Applied{MulStyle} @test similar(applied(*, UpperTriangular(A), x), Float64) isa Vector{Float64} - @test ApplyStyle(*, typeof(UpperTriangular(A)), typeof(x)) isa MulStyle + @test @inferred(ApplyStyle(*, typeof(UpperTriangular(A)), typeof(x))) isa MulStyle @test all((y = copy(x); y .= applied(*, UpperTriangular(A),y) ) .=== (similar(x) .= applied(*, UpperTriangular(A),x)) .=== @@ -883,13 +883,13 @@ end A = randn(5,5) B = Diagonal(randn(5)) @test MemoryLayout(typeof(B)) == DiagonalLayout{DenseColumnMajor}() - @test ApplyStyle(*, typeof(A), typeof(B)) == MulStyle() + @test @inferred(ApplyStyle(*, typeof(A), typeof(B))) == MulStyle() @test apply(*,A,B) == A*B == materialize!(Rmul(copy(A),B)) - @test ApplyStyle(*, typeof(B), typeof(A)) == MulStyle() + @test @inferred(ApplyStyle(*, typeof(B), typeof(A))) == MulStyle() @test apply(*,B,A) == B*A - @test ApplyStyle(*, typeof(B), typeof(B)) == MulStyle() + @test @inferred(ApplyStyle(*, typeof(B), typeof(B))) == MulStyle() @test apply(*,B,B) == B*B @test apply(*,B,B) isa Diagonal @@ -961,7 +961,7 @@ end @testset "ApplyArray MulTest" begin A = ApplyArray(*,randn(2,2), randn(2,2)) - @test ApplyStyle(*,typeof(A),typeof(randn(2,2))) isa MulStyle + @test @inferred(ApplyStyle(*,typeof(A),typeof(randn(2,2)))) isa MulStyle @test ApplyArray(*,Diagonal(Fill(2,10)), Fill(3,10,10))*Fill(3,10) ≡ Fill(180,10) @test ApplyArray(*,Diagonal(Fill(2,10)), Fill(3,10,10))*ApplyArray(*,Diagonal(Fill(2,10)), Fill(3,10,10)) == Fill(360,10,10) @test A' isa ApplyArray