Skip to content

Commit

Permalink
Make StructArrays broadcast aware (#136)
Browse files Browse the repository at this point in the history
* Make StructArrays broadcast aware

Fixes #89

* Limit circumstances in which broadcasting returns a StructArray

While allowing broadcasting to return a StructArray, this limits it to
cases where:

- no other arrays in the broadcast operation, including those wrapped by the
  StructArray, have non-default BroadcastStyle
- the eltype returned from the function is a struct type

It should be straightforward to define precedence rules to handle other
cases, e.g., StructArrays of CuArrays.

* Add a test for custom-broadcasting internal arrays

* Embrace the MethodError

Co-authored-by: Keno Fischer <keno@alumni.harvard.edu>
  • Loading branch information
timholy and Keno authored Jul 14, 2020
1 parent 7b77672 commit 0c4adf6
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
26 changes: 24 additions & 2 deletions src/structarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,18 @@ index_type(::Type{NamedTuple{names, types}}) where {names, types} = index_type(t
index_type(::Type{Tuple{}}) = Int
function index_type(::Type{T}) where {T<:Tuple}
S, U = tuple_type_head(T), tuple_type_tail(T)
IndexStyle(S) isa IndexCartesian ? CartesianIndex{ndims(S)} : index_type(U)
IndexStyle(S) isa IndexCartesian ? CartesianIndex{ndims(S)} : index_type(U)
end

index_type(::Type{StructArray{T, N, C, I}}) where {T, N, C, I} = I

array_types(::Type{StructArray{T, N, C, I}}) where {T, N, C, I} = array_types(C)
array_types(::Type{NamedTuple{names, types}}) where {names, types} = types
array_types(::Type{TT}) where {TT<:Tuple} = TT

function StructArray{T}(c::C) where {T, C<:Tup}
cols = strip_params(staticschema(T))(c)
N = isempty(cols) ? 1 : ndims(cols[1])
N = isempty(cols) ? 1 : ndims(cols[1])
StructArray{T, N, typeof(cols)}(cols)
end

Expand Down Expand Up @@ -225,3 +229,21 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T
showfields(io, Tuple(fieldarrays(s)))
toplevel && print(io, " with eltype ", T)
end

# broadcast
import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle

struct StructArrayStyle{Style} <: AbstractArrayStyle{Any} end

@inline combine_style_types(::Type{A}, args...) where A<:AbstractArray =
combine_style_types(BroadcastStyle(A), args...)
@inline combine_style_types(s::BroadcastStyle, ::Type{A}, args...) where A<:AbstractArray =
combine_style_types(Broadcast.result_style(s, BroadcastStyle(A)), args...)
combine_style_types(s::BroadcastStyle) = s

Base.@pure cst(::Type{SA}) where SA = combine_style_types(array_types(SA).parameters...)

BroadcastStyle(::Type{SA}) where SA<:StructArray = StructArrayStyle{typeof(cst(SA))}()

Base.similar(bc::Broadcasted{StructArrayStyle{S}}, ::Type{ElType}) where {S<:DefaultArrayStyle,N,ElType} =
isstructtype(ElType) ? similar(StructArray{ElType}, axes(bc)) : similar(Array{ElType}, axes(bc))
29 changes: 29 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -714,3 +714,32 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs)
@test t.b.c isa Array
@test t.b.d isa Array
end

struct MyArray{T,N} <: AbstractArray{T,N}
A::Array{T,N}
end
MyArray{T}(::UndefInitializer, sz::Dims) where T = MyArray(Array{T}(undef, sz))
Base.IndexStyle(::Type{<:MyArray}) = IndexLinear()
Base.getindex(A::MyArray, i::Int) = A.A[i]
Base.setindex!(A::MyArray, val, i::Int) = A.A[i] = val
Base.size(A::MyArray) = Base.size(A.A)
Base.BroadcastStyle(::Type{<:MyArray}) = Broadcast.ArrayStyle{MyArray}()
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{ElType}) where ElType =
MyArray{ElType}(undef, size(bc))

@testset "broadcast" begin
s = StructArray{ComplexF64}((rand(2,2), rand(2,2)))
@test isa(@inferred(s .+ s), StructArray)
@test (s .+ s).re == 2*s.re
@test (s .+ s).im == 2*s.im
@test isa(@inferred(broadcast(t->1, s)), Array)
@test all(x->x==1, broadcast(t->1, s))
@test isa(@inferred(s .+ 1), StructArray)
@test s .+ 1 == StructArray{ComplexF64}((s.re .+ 1, s.im))
r = rand(2,2)
@test isa(@inferred(s .+ r), StructArray)
@test s .+ r == StructArray{ComplexF64}((s.re .+ r, s.im))

s = StructArray{ComplexF64}((MyArray(rand(2,2)), MyArray(rand(2,2))))
@test_throws MethodError s .+ s
end

0 comments on commit 0c4adf6

Please sign in to comment.