From 7b3d131df845ae42d0f698c46a08ba63f2957b1f Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Fri, 11 Oct 2019 21:55:37 -0400 Subject: [PATCH 1/4] Make StructArrays broadcast aware Fixes #89 --- src/structarray.jl | 11 +++++++++-- test/runtests.jl | 5 +++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/structarray.jl b/src/structarray.jl index 3175f84d..26913dd8 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -22,14 +22,14 @@ index_type(::Type{NamedTuple{names, types}}) where {names, types} = index_type(t index_type(::Type{Tuple{}}) = Int function index_type(::Type{T}) where {T<:Tuple} S, U = tuple_type_head(T), tuple_type_tail(T) - IndexStyle(S) isa IndexCartesian ? CartesianIndex{ndims(S)} : index_type(U) + IndexStyle(S) isa IndexCartesian ? CartesianIndex{ndims(S)} : index_type(U) end index_type(::Type{StructArray{T, N, C, I}}) where {T, N, C, I} = I function StructArray{T}(c::C) where {T, C<:Tup} cols = strip_params(staticschema(T))(c) - N = isempty(cols) ? 1 : ndims(cols[1]) + N = isempty(cols) ? 1 : ndims(cols[1]) StructArray{T, N, typeof(cols)}(cols) end @@ -225,3 +225,10 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T showfields(io, Tuple(fieldarrays(s))) toplevel && print(io, " with eltype ", T) end + +# broadcast +import Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted + +BroadcastStyle(::Type{<:StructArray}) = ArrayStyle{StructArray}() +Base.similar(bc::Broadcasted{ArrayStyle{StructArray}}, ::Type{ElType}) where {N,ElType} = + similar(StructArray{ElType}, axes(bc)) diff --git a/test/runtests.jl b/test/runtests.jl index 756a63cc..77aae447 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -714,3 +714,8 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs) @test t.b.c isa Array @test t.b.d isa Array end + +@testset "broadcast" begin + s = StructArray{ComplexF64}((rand(2,2), rand(2,2))) + @test isa(s .+ s, StructArray) +end From 4124b20941faf65dd6006c2e56652c56f98e79bd Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Sat, 11 Jul 2020 04:43:31 -0500 Subject: [PATCH 2/4] Limit circumstances in which broadcasting returns a StructArray While allowing broadcasting to return a StructArray, this limits it to cases where: - no other arrays in the broadcast operation, including those wrapped by the StructArray, have non-default BroadcastStyle - the eltype returned from the function is a struct type It should be straightforward to define precedence rules to handle other cases, e.g., StructArrays of CuArrays. --- src/structarray.jl | 22 ++++++++++++++++++---- test/runtests.jl | 11 ++++++++++- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/src/structarray.jl b/src/structarray.jl index 26913dd8..d871c85e 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -27,6 +27,10 @@ end index_type(::Type{StructArray{T, N, C, I}}) where {T, N, C, I} = I +array_types(::Type{StructArray{T, N, C, I}}) where {T, N, C, I} = array_types(C) +array_types(::Type{NamedTuple{names, types}}) where {names, types} = types +array_types(::Type{TT}) where {TT<:Tuple} = TT + function StructArray{T}(c::C) where {T, C<:Tup} cols = strip_params(staticschema(T))(c) N = isempty(cols) ? 1 : ndims(cols[1]) @@ -227,8 +231,18 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T end # broadcast -import Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted +import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle + +struct StructArrayStyle{Style} <: AbstractArrayStyle{Any} end + +@inline combine_style_types(::Type{A}, args...) where A<:AbstractArray = + combine_style_types(BroadcastStyle(A), args...) +@inline combine_style_types(s::BroadcastStyle, ::Type{A}, args...) where A<:AbstractArray = + combine_style_types(Broadcast.result_style(s, BroadcastStyle(A)), args...) +combine_style_types(s::BroadcastStyle) = s + +Base.@pure cst(::Type{SA}) where SA = combine_style_types(array_types(SA).parameters...) -BroadcastStyle(::Type{<:StructArray}) = ArrayStyle{StructArray}() -Base.similar(bc::Broadcasted{ArrayStyle{StructArray}}, ::Type{ElType}) where {N,ElType} = - similar(StructArray{ElType}, axes(bc)) +BroadcastStyle(::Type{SA}) where SA<:StructArray = StructArrayStyle{typeof(cst(SA))}() +Base.similar(bc::Broadcasted{StructArrayStyle{S}}, ::Type{ElType}) where {S<:DefaultArrayStyle,N,ElType} = + isstructtype(ElType) ? similar(StructArray{ElType}, axes(bc)) : similar(Array{ElType}, axes(bc)) diff --git a/test/runtests.jl b/test/runtests.jl index 77aae447..86801638 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -717,5 +717,14 @@ end @testset "broadcast" begin s = StructArray{ComplexF64}((rand(2,2), rand(2,2))) - @test isa(s .+ s, StructArray) + @test isa(@inferred(s .+ s), StructArray) + @test (s .+ s).re == 2*s.re + @test (s .+ s).im == 2*s.im + @test isa(@inferred(broadcast(t->1, s)), Array) + @test all(x->x==1, broadcast(t->1, s)) + @test isa(@inferred(s .+ 1), StructArray) + @test s .+ 1 == StructArray{ComplexF64}((s.re .+ 1, s.im)) + r = rand(2,2) + @test isa(@inferred(s .+ r), StructArray) + @test s .+ r == StructArray{ComplexF64}((s.re .+ r, s.im)) end From 9f217f12879cce53cda1663539e84d6970d9e564 Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Sun, 12 Jul 2020 07:51:37 -0500 Subject: [PATCH 3/4] Add a test for custom-broadcasting internal arrays --- src/structarray.jl | 3 +++ test/runtests.jl | 17 +++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/src/structarray.jl b/src/structarray.jl index d871c85e..49604131 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -244,5 +244,8 @@ combine_style_types(s::BroadcastStyle) = s Base.@pure cst(::Type{SA}) where SA = combine_style_types(array_types(SA).parameters...) BroadcastStyle(::Type{SA}) where SA<:StructArray = StructArrayStyle{typeof(cst(SA))}() + Base.similar(bc::Broadcasted{StructArrayStyle{S}}, ::Type{ElType}) where {S<:DefaultArrayStyle,N,ElType} = isstructtype(ElType) ? similar(StructArray{ElType}, axes(bc)) : similar(Array{ElType}, axes(bc)) +Base.similar(bc::Broadcasted{StructArrayStyle{S}}, ::Type{ElType}) where {S<:ArrayStyle{A},N,ElType} where A = + similar(A{ElType}, axes(bc)) diff --git a/test/runtests.jl b/test/runtests.jl index 86801638..d924517b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -715,6 +715,18 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs) @test t.b.d isa Array end +struct MyArray{T,N} <: AbstractArray{T,N} + A::Array{T,N} +end +MyArray{T}(::UndefInitializer, sz::Dims) where T = MyArray(Array{T}(undef, sz)) +Base.IndexStyle(::Type{<:MyArray}) = IndexLinear() +Base.getindex(A::MyArray, i::Int) = A.A[i] +Base.setindex!(A::MyArray, val, i::Int) = A.A[i] = val +Base.size(A::MyArray) = Base.size(A.A) +Base.BroadcastStyle(::Type{<:MyArray}) = Broadcast.ArrayStyle{MyArray}() +Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{ElType}) where ElType = + MyArray{ElType}(undef, size(bc)) + @testset "broadcast" begin s = StructArray{ComplexF64}((rand(2,2), rand(2,2))) @test isa(@inferred(s .+ s), StructArray) @@ -727,4 +739,9 @@ end r = rand(2,2) @test isa(@inferred(s .+ r), StructArray) @test s .+ r == StructArray{ComplexF64}((s.re .+ r, s.im)) + + s = StructArray{ComplexF64}((MyArray(rand(2,2)), MyArray(rand(2,2)))) + @test isa(@inferred(s .+ s), MyArray) + @test real.((s .+ s)) == 2*s.re + @test imag.((s .+ s)) == 2*s.im end From 6bc0f9498dd290581aecaeed65adef24555a17f2 Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Sun, 12 Jul 2020 10:36:31 -0500 Subject: [PATCH 4/4] Embrace the MethodError --- src/structarray.jl | 2 -- test/runtests.jl | 4 +--- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/structarray.jl b/src/structarray.jl index 49604131..34fe3bd1 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -247,5 +247,3 @@ BroadcastStyle(::Type{SA}) where SA<:StructArray = StructArrayStyle{typeof(cst(S Base.similar(bc::Broadcasted{StructArrayStyle{S}}, ::Type{ElType}) where {S<:DefaultArrayStyle,N,ElType} = isstructtype(ElType) ? similar(StructArray{ElType}, axes(bc)) : similar(Array{ElType}, axes(bc)) -Base.similar(bc::Broadcasted{StructArrayStyle{S}}, ::Type{ElType}) where {S<:ArrayStyle{A},N,ElType} where A = - similar(A{ElType}, axes(bc)) diff --git a/test/runtests.jl b/test/runtests.jl index d924517b..463e51b2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -741,7 +741,5 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El @test s .+ r == StructArray{ComplexF64}((s.re .+ r, s.im)) s = StructArray{ComplexF64}((MyArray(rand(2,2)), MyArray(rand(2,2)))) - @test isa(@inferred(s .+ s), MyArray) - @test real.((s .+ s)) == 2*s.re - @test imag.((s .+ s)) == 2*s.im + @test_throws MethodError s .+ s end