From 26b153dd5dd4631630e55c171b62349b96bc8ee6 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Thu, 10 Feb 2022 22:47:19 +0800 Subject: [PATCH] Make sure `StructArrayStyle{<:StaticArrayStyle}` lose to `DefaultArrayStyle` --- src/structarray.jl | 3 ++- test/runtests.jl | 41 ++++++++++++++++++++++------------------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/src/structarray.jl b/src/structarray.jl index c1055331..d665cca3 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -443,7 +443,7 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T end # broadcast -import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted +import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown struct StructArrayStyle{S,N} <: AbstractArrayStyle{N} end @@ -468,6 +468,7 @@ function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S,N}) wher end StructArrayStyle{typeof(S′),_dimmax(N,M)}() end +BroadcastStyle(::StructArrayStyle, ::DefaultArrayStyle) = Unknown() @inline combine_style_types(::Type{A}, args...) where A<:AbstractArray = combine_style_types(BroadcastStyle(A), args...) diff --git a/test/runtests.jl b/test/runtests.jl index 802ca059..3f2da716 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -917,6 +917,7 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray1}}, ::Type{E Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray2}}, ::Type{ElType}) where ElType = MyArray2{ElType}(undef, size(bc)) Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyArray3}) = Broadcast.ArrayStyle{MyArray1}() +Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayStyle) = S @testset "broadcast" begin s = StructArray{ComplexF64}((rand(2,2), rand(2,2))) @@ -935,27 +936,24 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyA @test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N}) # Make sure we can handle style with similar defined + # And we can handle most conflict # s1 and s2 has similar defined, but s3 not - # s2 are conflict with s1 and s3. + # s2 are conflict with s1 and s3. (And it's weaker than DefaultArrayStyle) s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2)))) s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2)))) s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2)))) + s4 = StructArray{ComplexF64}((rand(2), rand(2))) - function _test_similar(a, b) - flag = false + function _test_similar(a, b, c) try - c = StructArray{ComplexF64}((a.re .+ b.re, a.im .+ b.im)) - flag = true + d = StructArray{ComplexF64}((a.re .+ b.re .- c.re, a.im .+ b.im .- c.im)) + @test typeof(a .+ b .- c) == typeof(d) catch - end - if flag - @test typeof(@inferred(a .+ b)) == typeof(c) - else - @test_throws MethodError a .+ b + @test_throws MethodError a .+ b .- c end end - for s in (s1,s2,s3), s′ in (s1,s2,s3) - _test_similar(s, s′) + for s in (s1,s2,s3,s4), s′ in (s1,s2,s3,s4), s″ in (s1,s2,s3,s4) + _test_similar(s, s′, s″) end # test for dimensionality track @@ -973,16 +971,21 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyA @test (A .= B .* c) === A # ambiguity check (can we do this better?) - function _test(a, b) - if a isa StructArray || b isa StructArray - d = @inferred a .+ b - @test d == collect(a) .+ collect(b) + function _test(a, b, c) + if a isa StructArray || b isa StructArray || c isa StructArray + d = @inferred a .+ b .- c + @test d == collect(a) .+ collect(b) .- collect(c) @test d isa StructArray end end - testset = StructArray([1;2+im]), StructArray([1 2+im]), 1:2, (1,2), (@SArray [1 2]) - for aa in testset, bb in testset - _test(aa, bb) + testset = Any[StructArray([1;2+im]), + StructArray([1 2+im]), + 1:2, + (1,2), + (@SArray [1 2]), + StructArray(@SArray [1 1+2im]),] + for aa in testset, bb in testset, cc in testset + _test(aa, bb, cc) end a = StructArray([1;2+im]) b = StructArray([1 2+im])