Skip to content

Commit

Permalink
Make sure StructArrayStyle{<:StaticArrayStyle} lose to `DefaultArra…
Browse files Browse the repository at this point in the history
…yStyle`
  • Loading branch information
N5N3 committed Feb 10, 2022
1 parent 793d035 commit 26b153d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 20 deletions.
3 changes: 2 additions & 1 deletion src/structarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T
end

# broadcast
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown

struct StructArrayStyle{S,N} <: AbstractArrayStyle{N} end

Expand All @@ -468,6 +468,7 @@ function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S,N}) wher
end
StructArrayStyle{typeof(S′),_dimmax(N,M)}()
end
BroadcastStyle(::StructArrayStyle, ::DefaultArrayStyle) = Unknown()

@inline combine_style_types(::Type{A}, args...) where A<:AbstractArray =
combine_style_types(BroadcastStyle(A), args...)
Expand Down
41 changes: 22 additions & 19 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,7 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray1}}, ::Type{E
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray2}}, ::Type{ElType}) where ElType =
MyArray2{ElType}(undef, size(bc))
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyArray3}) = Broadcast.ArrayStyle{MyArray1}()
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayStyle) = S

@testset "broadcast" begin
s = StructArray{ComplexF64}((rand(2,2), rand(2,2)))
Expand All @@ -935,27 +936,24 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyA
@test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N})

# Make sure we can handle style with similar defined
# And we can handle most conflict
# s1 and s2 has similar defined, but s3 not
# s2 are conflict with s1 and s3.
# s2 are conflict with s1 and s3. (And it's weaker than DefaultArrayStyle)
s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2))))
s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2))))
s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2))))
s4 = StructArray{ComplexF64}((rand(2), rand(2)))

function _test_similar(a, b)
flag = false
function _test_similar(a, b, c)
try
c = StructArray{ComplexF64}((a.re .+ b.re, a.im .+ b.im))
flag = true
d = StructArray{ComplexF64}((a.re .+ b.re .- c.re, a.im .+ b.im .- c.im))
@test typeof(a .+ b .- c) == typeof(d)
catch
end
if flag
@test typeof(@inferred(a .+ b)) == typeof(c)
else
@test_throws MethodError a .+ b
@test_throws MethodError a .+ b .- c
end
end
for s in (s1,s2,s3), s′ in (s1,s2,s3)
_test_similar(s, s′)
for s in (s1,s2,s3,s4), s′ in (s1,s2,s3,s4), s″ in (s1,s2,s3,s4)
_test_similar(s, s′, s″)
end

# test for dimensionality track
Expand All @@ -973,16 +971,21 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyA
@test (A .= B .* c) === A

# ambiguity check (can we do this better?)
function _test(a, b)
if a isa StructArray || b isa StructArray
d = @inferred a .+ b
@test d == collect(a) .+ collect(b)
function _test(a, b, c)
if a isa StructArray || b isa StructArray || c isa StructArray
d = @inferred a .+ b .- c
@test d == collect(a) .+ collect(b) .- collect(c)
@test d isa StructArray
end
end
testset = StructArray([1;2+im]), StructArray([1 2+im]), 1:2, (1,2), (@SArray [1 2])
for aa in testset, bb in testset
_test(aa, bb)
testset = Any[StructArray([1;2+im]),
StructArray([1 2+im]),
1:2,
(1,2),
(@SArray [1 2]),
StructArray(@SArray [1 1+2im]),]
for aa in testset, bb in testset, cc in testset
_test(aa, bb, cc)
end
a = StructArray([1;2+im])
b = StructArray([1 2+im])
Expand Down

0 comments on commit 26b153d

Please sign in to comment.