Skip to content

Commit

Permalink
Make StructArrayStyle track inputs dimension
Browse files Browse the repository at this point in the history
fix #185
  • Loading branch information
N5N3 committed Feb 9, 2022
1 parent 8e67e4e commit 9b9d8b2
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
10 changes: 7 additions & 3 deletions src/structarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,11 @@ end
# broadcast
import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle

struct StructArrayStyle{Style} <: AbstractArrayStyle{Any} end
struct StructArrayStyle{S,N} <: AbstractArrayStyle{N} end
# If `S` also track input's dimensionality, we'd better also update it.
StructArrayStyle{S,M}(::Val{N}) where {M,S<:AbstractArrayStyle{M},N} =
StructArrayStyle{typeof(S(Val(N))),N}()
StructArrayStyle{S,M}(::Val{N}) where {M,S,N} = StructArrayStyle{S,N}()

@inline combine_style_types(::Type{A}, args...) where A<:AbstractArray =
combine_style_types(BroadcastStyle(A), args...)
Expand All @@ -455,9 +459,9 @@ combine_style_types(s::BroadcastStyle) = s

Base.@pure cst(::Type{SA}) where SA = combine_style_types(array_types(SA).parameters...)

BroadcastStyle(::Type{SA}) where SA<:StructArray = StructArrayStyle{typeof(cst(SA))}()
BroadcastStyle(::Type{SA}) where SA<:StructArray = StructArrayStyle{typeof(cst(SA)),ndims(SA)}()

Base.similar(bc::Broadcasted{StructArrayStyle{S}}, ::Type{ElType}) where {S<:DefaultArrayStyle,N,ElType} =
Base.similar(bc::Broadcasted{<:StructArrayStyle{S}}, ::Type{ElType}) where {S<:DefaultArrayStyle,N,ElType} =
isstructtype(ElType) ? similar(StructArray{ElType}, axes(bc)) : similar(Array{ElType}, axes(bc))

# for aliasing analysis during broadcast
Expand Down
18 changes: 17 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -926,8 +926,24 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El
# used inside of broadcast but we also test it here explicitly
@test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N})

s = StructArray{ComplexF64}((MyArray(rand(2,2)), MyArray(rand(2,2))))
s = StructArray{ComplexF64}((MyArray(rand(2)), MyArray(rand(2))))
@test_throws MethodError s .+ s

# test for dimensionality track
@test Base.broadcasted(+, s, s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
@test Base.broadcasted(+, s, [1,2]) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
@test Base.broadcasted(+, s, [1;;2]) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{2}}
@test Base.broadcasted(+, [1;;;2], s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}}

a = StructArray([1;2+im])
b = StructArray([1;;2+im])
@test a .+ b == a .+ collect(b) == collect(a) .+ b == collect(a) .+ collect(b)

# issue #185
A = StructArray(randn(ComplexF64, 3, 3))
B = randn(ComplexF64, 3, 3)
c = StructArray(randn(ComplexF64, 3))
@test (A .= B .* c) === A
end

@testset "staticarrays" begin
Expand Down

0 comments on commit 9b9d8b2

Please sign in to comment.