From 3c4b7f2497d024366b6fd202767760b54ec47182 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Thu, 10 Feb 2022 13:46:50 +0800 Subject: [PATCH] Fix regression on `StaticArray` Make sure `StructArrayStyle{<:StaticArrayStyle}` lose to `DefaultArrayStyle` --- src/staticarrays_support.jl | 9 +++++-- src/structarray.jl | 3 ++- test/runtests.jl | 48 ++++++++++++++++++++----------------- 3 files changed, 35 insertions(+), 25 deletions(-) diff --git a/src/staticarrays_support.jl b/src/staticarrays_support.jl index ce1875b0..d03be3b0 100644 --- a/src/staticarrays_support.jl +++ b/src/staticarrays_support.jl @@ -1,4 +1,4 @@ -import StaticArrays: StaticArray, FieldArray, tuple_prod +import StaticArrays: StaticArray, FieldArray, tuple_prod, StaticArrayStyle """ StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T} @@ -26,4 +26,9 @@ StructArrays.component(s::StaticArray, i) = getindex(s, i) invoke(StructArrays.staticschema, Tuple{Type{<:Any}}, T) end StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i) -StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(createinstance, Tuple{Type{<:Any}, Vararg}, T, args...) \ No newline at end of file +StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(createinstance, Tuple{Type{<:Any}, Vararg}, T, args...) + +function Base.copy(bc::Broadcasted{StructArrayStyle{StaticArrayStyle{N},N}}) where {N} + B = convert(Broadcasted{StructArrayStyle{Broadcast.DefaultArrayStyle{N},N}}, bc) + copy(B) +end diff --git a/src/structarray.jl b/src/structarray.jl index c1055331..d665cca3 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -443,7 +443,7 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T end # broadcast -import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted +import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown struct StructArrayStyle{S,N} <: AbstractArrayStyle{N} end @@ -468,6 +468,7 @@ function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S,N}) wher end StructArrayStyle{typeof(S′),_dimmax(N,M)}() end +BroadcastStyle(::StructArrayStyle, ::DefaultArrayStyle) = Unknown() @inline combine_style_types(::Type{A}, args...) where A<:AbstractArray = combine_style_types(BroadcastStyle(A), args...) diff --git a/test/runtests.jl b/test/runtests.jl index a4f19ce4..e5cb2afb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -917,6 +917,7 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray1}}, ::Type{E Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray2}}, ::Type{ElType}) where ElType = MyArray2{ElType}(undef, size(bc)) Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyArray3}) = Broadcast.ArrayStyle{MyArray1}() +Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayStyle) = S @testset "broadcast" begin s = StructArray{ComplexF64}((rand(2,2), rand(2,2))) @@ -935,27 +936,24 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyA @test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N}) # Make sure we can handle style with similar defined + # And we can handle most conflict # s1 and s2 has similar defined, but s3 not - # s2 are conflict with s1 and s3. + # s2 are conflict with s1 and s3. (And it's weaker than DefaultArrayStyle) s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2)))) s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2)))) s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2)))) + s4 = StructArray{ComplexF64}((rand(2), rand(2))) - function _test_similar(a, b) - flag = false + function _test_similar(a, b, c) try - c = StructArray{ComplexF64}((a.re .+ b.re, a.im .+ b.im)) - flag = true + d = StructArray{ComplexF64}((a.re .+ b.re .- c.re, a.im .+ b.im .- c.im)) + @test typeof(a .+ b .- c) == typeof(d) catch - end - if flag - @test typeof(@inferred(a .+ b)) == typeof(c) - else - @test_throws MethodError a .+ b + @test_throws MethodError a .+ b .- c end end - for s in (s1,s2,s3), s′ in (s1,s2,s3) - _test_similar(s, s′) + for s in (s1,s2,s3,s4), s′ in (s1,s2,s3,s4), s″ in (s1,s2,s3,s4) + _test_similar(s, s′, s″) end # test for dimensionality track @@ -973,20 +971,26 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyA @test (A .= B .* c) === A # ambiguity check (can we do this better?) - function _test(a, b) - if a isa StructArray || b isa StructArray - d = @inferred a .+ b - @test d == collect(a) .+ collect(b) + function _test(a, b, c) + if a isa StructArray || b isa StructArray || c isa StructArray + d = @inferred a .+ b .- c + @test d == collect(a) .+ collect(b) .- collect(c) @test d isa StructArray end end - testset = StructArray([1;2+im]), StructArray([1 2+im]), 1:2, (1,2), (@SArray [1 2]) - for aa in testset, bb in testset - _test(aa, bb) + testset = (StructArray([1;2+im]), + StructArray([1 2+im]), + 1:2, + (1,2), + (@SArray [1 2]), + StructArray(@SArray [1 1+2im])) + for aa in testset, bb in testset, cc in testset + _test(aa, bb, cc) end - a = StructArray([1;2+im]) - b = StructArray([1 2+im]) - @test @inferred(a .+ b .+ a .* a' .+ (1,2) .+ (1:2) .- b') isa StructArray + + a = @SArray randn(3,3); + b = StructArray{ComplexF64}((a,a)) + @test a[:,1] .+ b isa StructArray && (a[:,1] .+ b).re isa SizedMatrix end @testset "staticarrays" begin