Skip to content

Commit

Permalink
Address part of #28866, prevent array operations from dropping 0d con…
Browse files Browse the repository at this point in the history
…tainers (#32122)

This is a simple workaround for the handful of elementwise operations that are defined on arrays _without_ the need for explicit broadcast but use broadcasting (with an extra shape check) in their implementation. These were the only affected cases I could find.
  • Loading branch information
mbauman authored and StefanKarpinski committed Jul 3, 2019
1 parent b74ef05 commit 38b38cf
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 7 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ Standard library changes
* `Cmd` interpolation (`` `$(x::Cmd) a b c` `` where) now propagates `x`'s process flags
(environment, flags, working directory, etc) if `x` is the first interpolant and errors
otherwise ([#24353]).
* Zero-dimensional arrays are now consistently preserved in the return values of mathematical
functions that operate on the array(s) as a whole (and are not explicitly broadcasted across their elements).
Previously, the functions `+`, `-`, `*`, `/`, `conj`, `real` and `imag` returned the unwrapped element
when operating over zero-dimensional arrays ([#32122]).
* `IPAddr` subtypes now behave like scalars when used in broadcasting ([#32133]).
* `clamp` can now handle missing values ([#31066]).

Expand Down
10 changes: 5 additions & 5 deletions base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ julia> A
conj!(A::AbstractArray{<:Number}) = (@inbounds broadcast!(conj, A, A); A)

for f in (:-, :conj, :real, :imag)
@eval ($f)(A::AbstractArray) = broadcast($f, A)
@eval ($f)(A::AbstractArray) = broadcast_preserving_zero_d($f, A)
end


Expand All @@ -36,23 +36,23 @@ end
for f in (:+, :-)
@eval function ($f)(A::AbstractArray, B::AbstractArray)
promote_shape(A, B) # check size compatibility
broadcast($f, A, B)
broadcast_preserving_zero_d($f, A, B)
end
end

function +(A::Array, Bs::Array...)
for B in Bs
promote_shape(A, B) # check size compatibility
end
broadcast(+, A, Bs...)
broadcast_preserving_zero_d(+, A, Bs...)
end

for f in (:/, :\, :*)
if f != :/
@eval ($f)(A::Number, B::AbstractArray) = broadcast($f, A, B)
@eval ($f)(A::Number, B::AbstractArray) = broadcast_preserving_zero_d($f, A, B)
end
if f != :\
@eval ($f)(A::AbstractArray, B::Number) = broadcast($f, A, B)
@eval ($f)(A::AbstractArray, B::Number) = broadcast_preserving_zero_d($f, A, B)
end
end

Expand Down
18 changes: 17 additions & 1 deletion base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using .Base.Cartesian
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin,
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias
import .Base: copy, copyto!, axes
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, broadcast_preserving_zero_d

## Computing the result's axes: deprecated name
const broadcast_axes = axes
Expand Down Expand Up @@ -790,6 +790,22 @@ julia> A
"""
broadcast!(f::Tf, dest, As::Vararg{Any,N}) where {Tf,N} = (materialize!(dest, broadcasted(f, As...)); dest)

"""
broadcast_preserving_zero_d(f, As...)
Like [`broadcast`](@ref), except in the case of a 0-dimensional result where it returns a 0-dimensional container
Broadcast automatically unwraps zero-dimensional results to be just the element itself,
but in some cases it is necessary to always return a container — even in the 0-dimensional case.
"""
function broadcast_preserving_zero_d(f, As...)
bc = broadcasted(f, As...)
r = materialize(bc)
return length(axes(bc)) == 0 ? fill!(similar(bc, typeof(r)), r) : r
end
broadcast_preserving_zero_d(f) = fill(f())
broadcast_preserving_zero_d(f, as::Number...) = fill(f(as...))

"""
Broadcast.materialize(bc)
Expand Down
5 changes: 4 additions & 1 deletion stdlib/LinearAlgebra/src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,10 @@ quasiparentt(x) = parent(x); quasiparentt(x::Number) = x # to handle numbers in
quasiparenta(x) = parent(x); quasiparenta(x::Number) = conj(x) # to handle numbers in the defs below
broadcast(f, avs::Union{Number,AdjointAbsVec}...) = adjoint(broadcast((xs...) -> adjoint(f(adjoint.(xs)...)), quasiparenta.(avs)...))
broadcast(f, tvs::Union{Number,TransposeAbsVec}...) = transpose(broadcast((xs...) -> transpose(f(transpose.(xs)...)), quasiparentt.(tvs)...))
# TODO unify and allow mixed combinations
# Hack to preserve behavior after #32122; this needs to be done with a broadcast style instead to support dotted fusion
Broadcast.broadcast_preserving_zero_d(f, avs::Union{Number,AdjointAbsVec}...) = adjoint(broadcast((xs...) -> adjoint(f(adjoint.(xs)...)), quasiparenta.(avs)...))
Broadcast.broadcast_preserving_zero_d(f, tvs::Union{Number,TransposeAbsVec}...) = transpose(broadcast((xs...) -> transpose(f(transpose.(xs)...)), quasiparentt.(tvs)...))
# TODO unify and allow mixed combinations with a broadcast style

### linear algebra

Expand Down
17 changes: 17 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2628,6 +2628,23 @@ Base.view(::T25958, args...) = args
@test t[end,end,end] == @view(t[end,end,end]) == @views t[end,end,end]
end

@testset "0-dimensional container operations" begin
for op in (-, conj, real, imag)
@test op(fill(2)) == fill(op(2))
@test op(fill(1+2im)) == fill(op(1+2im))
end
for op in (+, -)
@test op(fill(1), fill(2)) == fill(op(1, 2))
@test op(fill(1), fill(2)) isa AbstractArray{Int, 0}
end
@test fill(1) + fill(2) + fill(3) == fill(1+2+3)
@test fill(1) / 2 == fill(1/2)
@test 2 \ fill(1) == fill(1/2)
@test 2*fill(1) == fill(2)
@test fill(1)*2 == fill(2)
end


# Fix oneunit bug for unitful arrays
@test oneunit([Second(1) Second(2); Second(3) Second(4)]) == [Second(1) Second(0); Second(0) Second(1)]

Expand Down

0 comments on commit 38b38cf

Please sign in to comment.