Skip to content

Commit

Permalink
Fix regression on StaticArray
Browse files Browse the repository at this point in the history
Make sure `StructArrayStyle{<:StaticArrayStyle}` lose to `DefaultArrayStyle`
  • Loading branch information
N5N3 committed Feb 10, 2022
1 parent c7a4dc8 commit 3c4b7f2
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 25 deletions.
9 changes: 7 additions & 2 deletions src/staticarrays_support.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import StaticArrays: StaticArray, FieldArray, tuple_prod
import StaticArrays: StaticArray, FieldArray, tuple_prod, StaticArrayStyle

"""
StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
Expand Down Expand Up @@ -26,4 +26,9 @@ StructArrays.component(s::StaticArray, i) = getindex(s, i)
invoke(StructArrays.staticschema, Tuple{Type{<:Any}}, T)
end
StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i)
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)

function Base.copy(bc::Broadcasted{StructArrayStyle{StaticArrayStyle{N},N}}) where {N}
B = convert(Broadcasted{StructArrayStyle{Broadcast.DefaultArrayStyle{N},N}}, bc)
copy(B)
end
3 changes: 2 additions & 1 deletion src/structarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T
end

# broadcast
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown

struct StructArrayStyle{S,N} <: AbstractArrayStyle{N} end

Expand All @@ -468,6 +468,7 @@ function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S,N}) wher
end
StructArrayStyle{typeof(S′),_dimmax(N,M)}()
end
BroadcastStyle(::StructArrayStyle, ::DefaultArrayStyle) = Unknown()

@inline combine_style_types(::Type{A}, args...) where A<:AbstractArray =
combine_style_types(BroadcastStyle(A), args...)
Expand Down
48 changes: 26 additions & 22 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,7 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray1}}, ::Type{E
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray2}}, ::Type{ElType}) where ElType =
MyArray2{ElType}(undef, size(bc))
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyArray3}) = Broadcast.ArrayStyle{MyArray1}()
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayStyle) = S

@testset "broadcast" begin
s = StructArray{ComplexF64}((rand(2,2), rand(2,2)))
Expand All @@ -935,27 +936,24 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyA
@test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N})

# Make sure we can handle style with similar defined
# And we can handle most conflict
# s1 and s2 has similar defined, but s3 not
# s2 are conflict with s1 and s3.
# s2 are conflict with s1 and s3. (And it's weaker than DefaultArrayStyle)
s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2))))
s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2))))
s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2))))
s4 = StructArray{ComplexF64}((rand(2), rand(2)))

function _test_similar(a, b)
flag = false
function _test_similar(a, b, c)
try
c = StructArray{ComplexF64}((a.re .+ b.re, a.im .+ b.im))
flag = true
d = StructArray{ComplexF64}((a.re .+ b.re .- c.re, a.im .+ b.im .- c.im))
@test typeof(a .+ b .- c) == typeof(d)
catch
end
if flag
@test typeof(@inferred(a .+ b)) == typeof(c)
else
@test_throws MethodError a .+ b
@test_throws MethodError a .+ b .- c
end
end
for s in (s1,s2,s3), s′ in (s1,s2,s3)
_test_similar(s, s′)
for s in (s1,s2,s3,s4), s′ in (s1,s2,s3,s4), s″ in (s1,s2,s3,s4)
_test_similar(s, s′, s″)
end

# test for dimensionality track
Expand All @@ -973,20 +971,26 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyA
@test (A .= B .* c) === A

# ambiguity check (can we do this better?)
function _test(a, b)
if a isa StructArray || b isa StructArray
d = @inferred a .+ b
@test d == collect(a) .+ collect(b)
function _test(a, b, c)
if a isa StructArray || b isa StructArray || c isa StructArray
d = @inferred a .+ b .- c
@test d == collect(a) .+ collect(b) .- collect(c)
@test d isa StructArray
end
end
testset = StructArray([1;2+im]), StructArray([1 2+im]), 1:2, (1,2), (@SArray [1 2])
for aa in testset, bb in testset
_test(aa, bb)
testset = (StructArray([1;2+im]),
StructArray([1 2+im]),
1:2,
(1,2),
(@SArray [1 2]),
StructArray(@SArray [1 1+2im]))
for aa in testset, bb in testset, cc in testset
_test(aa, bb, cc)
end
a = StructArray([1;2+im])
b = StructArray([1 2+im])
@test @inferred(a .+ b .+ a .* a' .+ (1,2) .+ (1:2) .- b') isa StructArray

a = @SArray randn(3,3);
b = StructArray{ComplexF64}((a,a))
@test a[:,1] .+ b isa StructArray && (a[:,1] .+ b).re isa SizedMatrix
end

@testset "staticarrays" begin
Expand Down

0 comments on commit 3c4b7f2

Please sign in to comment.