Skip to content

Commit

Permalink
Fix JuliaLang#24914 (WIP).
Browse files Browse the repository at this point in the history
  • Loading branch information
tkoolen committed Dec 8, 2017
1 parent 6eec805 commit f819811
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 25 deletions.
20 changes: 12 additions & 8 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,6 @@ broadcast_indices
# special cases defined for performance
broadcast(f, x::Number...) = f(x...)
@inline broadcast(f, t::NTuple{N,Any}, ts::Vararg{NTuple{N,Any}}) where {N} = map(f, t, ts...)
@inline broadcast!(::typeof(identity), x::AbstractArray{T,N}, y::AbstractArray{S,N}) where {T,S,N} =
Base.indices(x) == Base.indices(y) ? copy!(x, y) : _broadcast!(identity, x, y)

# special cases for "X .= ..." (broadcast!) assignments
broadcast!(::typeof(identity), X::AbstractArray, x::Number) = fill!(X, x)
broadcast!(f, X::AbstractArray, x::Number...) = (@inbounds for I in eachindex(X); X[I] = f(x...); end; X)

## logic for deciding the BroadcastStyle
# Dimensionality: computing max(M,N) in the type domain so we preserve inferrability
Expand Down Expand Up @@ -448,8 +442,18 @@ Note that `dest` is only used to store the result, and does not supply
arguments to `f` unless it is also listed in the `As`,
as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`.
"""
@inline broadcast!(f, C::AbstractArray, A, Bs::Vararg{Any,N}) where {N} =
_broadcast!(f, C, A, Bs...)
broadcast!(f, dest, As...) = broadcast!(f, dest, combine_styles(As...), As...)
broadcast!(f, dest, ::BroadcastStyle, As...) = broadcast!(f, dest, nothing, As...)
@inline function broadcast!(f, C, ::Void, A, Bs::Vararg{Any,N}) where N
if isa(f, typeof(identity)) && N == 0
if isa(A, Number)
return fill!(C, A)
elseif isa(C, AbstractArray) && isa(A, AbstractArray) && Base.indices(C) == Base.indices(A)
return copy!(C, A)
end
end
return _broadcast!(f, C, A, Bs...)
end

# This indirection allows size-dependent implementations (e.g., see the copying `identity`
# specialization above)
Expand Down
34 changes: 17 additions & 17 deletions base/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ end
# (3) broadcast[!] entry points
broadcast(f::Tf, A::SparseVector) where {Tf} = _noshapecheck_map(f, A)
broadcast(f::Tf, A::SparseMatrixCSC) where {Tf} = _noshapecheck_map(f, A)
function broadcast!(f::Tf, C::SparseVecOrMat) where Tf

function broadcast!(f::Tf, C::SparseVecOrMat, ::Void) where Tf
isempty(C) && return _finishempty!(C)
fofnoargs = f()
if _iszero(fofnoargs) # f() is zero, so empty C
Expand All @@ -106,14 +107,13 @@ function broadcast!(f::Tf, C::SparseVecOrMat) where Tf
end
return C
end
function broadcast!(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N}
_aresameshape(C, A, Bs...) && return _noshapecheck_map!(f, C, A, Bs...)
Base.Broadcast.check_broadcast_indices(indices(C), A, Bs...)
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = _iszero(fofzeros)
return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) :
_broadcast_notzeropres!(f, fofzeros, C, A, Bs...)
function broadcast!(f, dest::SparseVecOrMat, ::Void, A, Bs::Vararg{Any,N}) where N
if isa(f, typeof(identity)) && N == 0 && isa(A, Number)
return fill!(dest, A)
end
return spbroadcast_args!(f, dest, Broadcast.combine_styles(A, Bs...), A, Bs...)
end

# the following three similar defs are necessary for type stability in the mixed vector/matrix case
broadcast(f::Tf, A::SparseVector, Bs::Vararg{SparseVector,N}) where {Tf,N} =
_aresameshape(A, Bs...) ? _noshapecheck_map(f, A, Bs...) : _diffshape_broadcast(f, A, Bs...)
Expand Down Expand Up @@ -1005,18 +1005,18 @@ Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.DefaultArrayStyle{N}) whe
broadcast(f, ::PromoteToSparse, ::Void, ::Void, As::Vararg{Any,N}) where {N} =
broadcast(f, map(_sparsifystructured, As)...)

# ambiguity resolution
broadcast!(::typeof(identity), dest::SparseVecOrMat, x::Number) =
fill!(dest, x)
broadcast!(f, dest::SparseVecOrMat, x::Number...) =
spbroadcast_args!(f, dest, SPVM, x...)

# For broadcast! with ::Any inputs, we need a layer of indirection to determine whether
# the inputs can be promoted to SparseVecOrMat. If it's just SparseVecOrMat and scalars,
# we can handle it here, otherwise see below for the promotion machinery.
broadcast!(f, dest::SparseVecOrMat, mixedsrcargs::Vararg{Any,N}) where N =
spbroadcast_args!(f, dest, Broadcast.combine_styles(mixedsrcargs...), mixedsrcargs...)
function spbroadcast_args!(f, dest, ::Type{SPVM}, mixedsrcargs::Vararg{Any,N}) where N
function spbroadcast_args!(f::Tf, C, ::SPVM, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N}
_aresameshape(C, A, Bs...) && return _noshapecheck_map!(f, C, A, Bs...)
Base.Broadcast.check_broadcast_indices(indices(C), A, Bs...)
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = _iszero(fofzeros)
return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) :
_broadcast_notzeropres!(f, fofzeros, C, A, Bs...)
end
function spbroadcast_args!(f, dest, ::SPVM, mixedsrcargs::Vararg{Any,N}) where N
# mixedsrcargs contains nothing but SparseVecOrMat and scalars
parevalf, passedsrcargstup = capturescalars(f, mixedsrcargs)
return broadcast!(parevalf, dest, passedsrcargstup...)
Expand Down

0 comments on commit f819811

Please sign in to comment.