Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: Add weights argument to sum #33310

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 187 additions & 8 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -353,9 +353,16 @@ reduce(op, A::AbstractArray; kw...) = mapreduce(identity, op, A; kw...)

##### Specific reduction functions #####
"""
sum(A::AbstractArray; dims)
sum(A::AbstractArray; [dims], [weights::AbstractArray])

Sum elements of an array over the given dimensions.
Compute the sum of array `A`.
If `dims` is provided, return an array of sums over these dimensions.
If `weights` is provided, return the weighted sum(s). `weights` must be
either an array of the same size as `A` if `dims` is omitted,
or a vector with the same length as `size(A, dims)` if `dims` is provided.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why should weights have to be an array, as opposed to just an iterable?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's how things work currently in StatsBase. I guess we could make this more general at least for the simple case where dims=:. But that would require adding a separate method as @simd only works when using indices.


!!! compat "Julia 1.4"
The `weights` keyword argument requires at least Julia 1.4.

# Examples
```jldoctest
Expand All @@ -372,14 +379,26 @@ julia> sum(A, dims=2)
2×1 Array{Int64,2}:
3
7

julia> sum(A, weights=[2 1; 2 1])
14

julia> sum(A, weights=[2, 1], dims=1)
1×2 Array{Int64,2}:
5 8
```
"""
sum(A::AbstractArray; dims)

"""
sum!(r, A)
sum!(r, A; [weights::AbstractVector])

Sum elements of `A` over the singleton dimensions of `r`, and write results to `r`.
If `r` has only one singleton dimension `i`, `weights` can be a vector of length
`size(A, i)` to compute the weighted sum.

!!! compat "Julia 1.4"
The `weights` keyword argument requires at least Julia 1.4.

# Examples
```jldoctest
Expand Down Expand Up @@ -645,7 +664,7 @@ julia> any!([1 1], A)
"""
any!(r, A)

for (fname, _fname, op) in [(:sum, :_sum, :add_sum), (:prod, :_prod, :mul_prod),
for (fname, _fname, op) in [(:prod, :_prod, :mul_prod),
(:maximum, :_maximum, :max), (:minimum, :_minimum, :min)]
@eval begin
# User-facing methods with keyword arguments
Expand All @@ -658,28 +677,188 @@ for (fname, _fname, op) in [(:sum, :_sum, :add_sum), (:prod, :_prod,
end
end

# `sum` is the only reduction that accepts a `weights` keyword, so its user-facing
# methods are spelled out here instead of being generated by the `@eval` loop above.
function sum(a::AbstractArray; dims=:, weights::Union{AbstractArray,Nothing}=nothing)
    return _sum(a, dims, weights)
end
function sum(f, a::AbstractArray; dims=:, weights::Union{AbstractArray,Nothing}=nothing)
    return _sum(f, a, dims, weights)
end
# NOTE(review): the two methods below add positional `(::Colon, weights)` methods to
# `sum` itself rather than to the internal `_sum`; this looks unintended -- confirm
# before merging (a plain rename could be ambiguous with the `_sum` methods that take
# `A::AbstractArray`).
function sum(a, ::Colon, weights)
    return _sum(identity, a, :, weights)
end
function sum(f, a, ::Colon, ::Nothing)
    return mapreduce(f, add_sum, a)
end

# User-facing `any`/`all` forward the `dims` keyword positionally to `_any`/`_all`;
# without a predicate and with `dims = :`, `identity` is used as the predicate.
function any(a::AbstractArray; dims=:)
    return _any(a, dims)
end
function any(f::Function, a::AbstractArray; dims=:)
    return _any(f, a, dims)
end
function _any(a, ::Colon)
    return _any(identity, a, :)
end
function all(a::AbstractArray; dims=:)
    return _all(a, dims)
end
function all(f::Function, a::AbstractArray; dims=:)
    return _all(f, a, dims)
end
function _all(a, ::Colon)
    return _all(identity, a, :)
end

for (fname, op) in [(:sum, :add_sum), (:prod, :mul_prod),
for (fname, op) in [(:prod, :mul_prod),
(:maximum, :max), (:minimum, :min),
(:all, :&), (:any, :|)]
fname! = Symbol(fname, '!')
_fname! = Symbol('_', fname, '!')
_fname = Symbol('_', fname)
@eval begin
$(fname!)(r::AbstractArray, A::AbstractArray; init::Bool=true) =
$(fname!)(identity, r, A; init=init)
$(fname!)(f::Function, r::AbstractArray, A::AbstractArray; init::Bool=true) =
mapreducedim!(f, $(op), initarray!(r, $(op), init, A), A)
$(fname!)(r::AbstractArray, A::AbstractArray; init::Bool=true) = $(fname!)(identity, r, A; init=init)
$(_fname!)(f, r, A; init=init)

$(_fname)(A, dims) = $(_fname)(identity, A, dims)
# Underlying implementations using dispatch
$(_fname!)(f, r::AbstractArray, A::AbstractArray; init::Bool=true) =
mapreducedim!(f, $(op), initarray!(r, $(op), init, A), A)
$(_fname)(A, dims) = $(_fname)(identity, A, dims)
$(_fname)(f, A, dims) = mapreduce(f, $(op), A, dims=dims)
end
end

# `sum!` is the only in-place reduction that accepts `weights`; the keyword is turned
# into a positional argument so that `_sum!`/`_sum` can dispatch on it.
function sum!(r::AbstractArray, A::AbstractArray;
              init::Bool=true, weights::Union{AbstractArray,Nothing}=nothing)
    return sum!(identity, r, A; init=init, weights=weights)
end
function sum!(f::Function, r::AbstractArray, A::AbstractArray;
              init::Bool=true, weights::Union{AbstractArray,Nothing}=nothing)
    return _sum!(f, r, A, weights; init=init)
end

# No weights: ordinary reduction with `add_sum` into `r`.
function _sum!(f, r::AbstractArray, A::AbstractArray, ::Nothing; init::Bool=true)
    dest = initarray!(r, add_sum, init, A)
    return mapreducedim!(f, add_sum, dest, A)
end

# Without an explicit function, sum the elements themselves.
function _sum(A::AbstractArray, dims, weights)
    return _sum(identity, A, dims, weights)
end

# No weights: plain `mapreduce` over `dims`.
function _sum(f, A::AbstractArray, dims, ::Nothing)
    return mapreduce(f, add_sum, A, dims=dims)
end

# Weights with `identity`: allocate the destination via `reducedim_init`, seeding it
# from each element multiplied by a zero of the weight type, then fill it in place.
function _sum(::typeof(identity), A::AbstractArray, dims, w::AbstractArray)
    R = reducedim_init(t -> t*zero(eltype(w)), add_sum, A, dims)
    return _sum!(identity, R, A, w)
end

# Weights combined with any other function are rejected.
function _sum(f, A::AbstractArray, dims, w::AbstractArray)
    throw(ArgumentError("Passing a function is not supported with `weights`"))
end


# Weighted sum
function _sum(A::AbstractArray, dims::Colon, w::AbstractArray{<:Real})
sw = size(w)
sA = size(A)
if sw != sA
throw(DimensionMismatch("weights must have the same dimension as data (got $sw and $sA)."))
end
s0 = zero(eltype(A)) * zero(eltype(w))
s = add_sum(s0, s0)
@inbounds @simd for i in eachindex(A, w)
s += A[i] * w[i]
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will be much less accurate than the non-weighted algorithm based on pairwise summation.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, probably (I haven't considered this, that's just a port of the StatsBase code with some cleanup). We could improve this, but note that for BlasReal the BLAS-optimized dot is called instead, which has different properties anyway (not sure whether it's going to be more accurate or less).

If we want to preserve accuracy, that would be an argument in favor of this PR, since a naive broadcast call won't use pairwise summation.

s
end

# Weighted sum over dimensions
#
# Brief explanation of the algorithm:
# ------------------------------------
#
# 1. _wsum! provides the core implementation, which assumes that
# the dimensions of all input arguments are consistent; no
# dimension checking is performed therein.
#
# _sum! (the method taking `weights`) performs argument checking and
# calls _wsum! internally.
#
# 2. _wsum! adopts a Cartesian-based implementation for general
# subtypes of AbstractArray. In particular, a faster routine
# that keeps a local accumulator is used when dim = 1.
#
# The internal function that implements this is _wsum_general!
#
# 3. _wsum! is specialized for the following cases:
# (a) A is a vector: we invoke the vector version of the weighted sum above.
# The internal function that implements this is _wsum1!
#
# (b) A is a dense matrix with eltype <: BlasReal: we call gemv!
# The internal function that implements this is _wsum2_blas!
# (in LinearAlgebra/src/wsum.jl)
#
# (c) A is a contiguous array with eltype <: BlasReal:
# dim == 1: treat A like a matrix of size (d1, d2 x ... x dN)
# dim == N: treat A like a matrix of size (d1 x ... x d(N-1), dN)
# otherwise: decompose A into multiple pages, and apply _wsum2_blas!
# for each
# The internal function that implements this is _wsumN!
# (in LinearAlgebra/src/wsum.jl)
#
# (d) A is a general dense array with eltype <: BlasReal:
# dim <= 2: delegate to (a) and (b)
# otherwise, decompose A into multiple pages
# The internal function that implements this is _wsumN!
# (in LinearAlgebra/src/wsum.jl)

# Weighted sum of a vector: compute the scalar weighted sum of `A` with weights `w`
# and either store it in `R[1]` (`init == true`) or add it to the value already there.
# Returns `R`.
function _wsum1!(R::AbstractArray, A::AbstractVector, w::AbstractVector, init::Bool)
    total = _sum(A, :, w)
    R[1] = init ? total : R[1] + total
    return R
end

# Generic kernel for the weighted sum of `A` along dimension `dim`, accumulated into `R`.
#
# Every element `A[i, IA]` is multiplied by the entry of `w` selected by that element's
# index along `dim` and added to the corresponding slot of `R`. When `init` is true a
# non-empty `R` is first filled with `zero(S)`; otherwise the products are added to the
# existing contents of `R`. Returns `R`.
#
# The loop structure follows `_mapreducedim!`. The length of `w` is not checked here and
# `w` is indexed under `@inbounds`, so callers must have validated it beforehand.
function _wsum_general!(R::AbstractArray{S}, A::AbstractArray, w::AbstractVector,
                        dim::Int, init::Bool) where {S}
    # Called only for its checks (presumably that `R` is a valid reduction target for
    # `A` -- confirm); its return value is not needed here.
    check_reducedims(R,A)
    !isempty(R) && init && fill!(R, zero(S))
    isempty(A) && return R

    indsAt, indsRt = safe_tail(axes(A)), safe_tail(axes(R)) # handle d=1 manually
    keep, Idefault = Broadcast.shapeindexer(indsRt)
    if reducedim1(R, A)
        # The first dimension is reduced: accumulate each slice in a local variable
        # and write it back once.
        i1 = first(axes1(R))
        for IA in CartesianIndices(indsAt)
            IR = Broadcast.newindex(IA, keep, Idefault)
            r = R[i1,IR]
            @inbounds @simd for i in axes(A, 1)
                # `IA` indexes dimensions 2:N, hence the `dim-1` offset
                r += A[i,IA] * w[dim > 1 ? IA[dim-1] : i]
            end
            R[i1,IR] = r
        end
    else
        # The first dimension is kept: update `R` element by element.
        for IA in CartesianIndices(indsAt)
            IR = Broadcast.newindex(IA, keep, Idefault)
            @inbounds @simd for i in axes(A, 1)
                R[i,IR] += A[i,IA] * w[dim > 1 ? IA[dim-1] : i]
            end
        end
    end
    return R
end

# Generic `_wsum!` dispatch: vectors go to `_wsum1!` (`dim` is not needed there),
# every other array goes to the Cartesian kernel `_wsum_general!`.
function _wsum!(R::AbstractArray, A::AbstractVector, w::AbstractVector,
                dim::Int, init::Bool)
    return _wsum1!(R, A, w, init)
end

function _wsum!(R::AbstractArray, A::AbstractArray, w::AbstractVector,
                dim::Int, init::Bool)
    return _wsum_general!(R, A, w, dim, init)
end

# Weighted in-place sum: validate the arguments, work out the single dimension being
# reduced from the sizes of `R` and `A`, then hand over to `_wsum!`.
#
# Throws `ArgumentError` if `f` is not `identity`, if `w` is not a vector, or if more
# than one dimension is reduced; throws `DimensionMismatch` if `length(w)` differs from
# the size of `A` along the reduced dimension.
function _sum!(f, R::AbstractArray, A::AbstractArray{T,N}, w::AbstractArray;
               init::Bool=true) where {T,N}
    if f !== identity
        throw(ArgumentError("Passing a function is not supported with `weights`"))
    end
    if !(w isa AbstractVector)
        throw(ArgumentError("Only vector `weights` are supported"))
    end

    check_reducedims(R,A)
    # `true` for every dimension along which `R` and `A` differ in size
    reduced = size(R) .!= size(A)
    first_reduced = findfirst(reduced)
    dim = first_reduced === nothing ? ndims(R)+1 : first_reduced
    if dim > N
        # No dimension differs: fall back to the first singleton dimension of `A`, if any
        singleton = findfirst(==(1), size(A))
        singleton === nothing || (dim = singleton)
    end
    findnext(reduced, dim+1) === nothing ||
        throw(ArgumentError("reducing over more than one dimension is not supported with weights"))
    nweights = length(w)
    nreduced = size(A, dim)
    nweights == nreduced ||
        throw(DimensionMismatch("weights must have the same length as the dimension " *
                                "over which reduction is performed (got $nweights and $nreduced)."))
    return _wsum!(R, A, w, dim, init)
end


##### findmin & findmax #####
# The initial values of Rval are not used if the corresponding indices in Rind are 0.
#
Expand Down
1 change: 1 addition & 0 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ include("bitarray.jl")
include("ldlt.jl")
include("schur.jl")
include("structuredbroadcast.jl")
include("wsum.jl")
include("deprecated.jl")

const ⋅ = dot
Expand Down
94 changes: 94 additions & 0 deletions stdlib/LinearAlgebra/src/wsum.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Optimized method for weighted sum with BlasReal
# dot cannot be used for other types as it uses + rather than add_sum for accumulation,
# and therefore does not return the correct type
#
# `vec` discards the shape of its argument, so the sizes are compared explicitly first:
# arrays of equal length but different size are rejected with the same
# `DimensionMismatch` as the generic method in base/reducedim.jl.
function Base._sum(A::AbstractArray{T}, dims::Colon,
                   w::AbstractArray{T}) where {T<:BlasReal}
    sw = size(w)
    sA = size(A)
    if sw != sA
        throw(DimensionMismatch("weights must have the same dimension as data (got $sw and $sA)."))
    end
    return dot(vec(A), vec(w))
end

# Optimized methods for weighted sum over dimensions with BlasReal
# (generic method is defined in base/reducedim.jl)
#
# _wsum! is specialized for following cases:
# (a) A is a dense matrix with eltype <: BlasReal: we call gemv!
# The internal function that implements this is _wsum2_blas!
#
# (b) A is a contiguous array with eltype <: BlasReal:
# dim == 1: treat A like a matrix of size (d1, d2 x ... x dN)
# dim == N: treat A like a matrix of size (d1 x ... x d(N-1), dN)
# otherwise: decompose A into multiple pages, and apply _wsum2_blas!
# for each
# The internal function that implements this is _wsumN!
#
# (c) A is a general dense array with eltype <: BlasReal:
# dim <= 2: delegate to (a) and (b)
# otherwise, decompose A into multiple pages
# The internal function that implements this is _wsumN!

# Weighted sum of the matrix `A` along `dim` through a single BLAS `gemv!` call.
# `dim == 1` uses the transposed product, any other `dim` the plain product of `A`
# with `w`. With `init == true` the result overwrites `R` (beta = 0), otherwise it is
# added to `R` (beta = 1). Returns `R`.
function _wsum2_blas!(R::StridedVector{T}, A::StridedMatrix{T}, w::StridedVector{T},
                      dim::Int, init::Bool) where T<:BlasReal
    beta = init ? zero(T) : one(T)
    trans = ifelse(dim == 1, 'T', 'N')
    BLAS.gemv!(trans, one(T), A, w, beta, R)
    return R
end

# Weighted sum along `dim` for an N-dimensional array, expressed as matrix problems:
#   dim == 1: view `A` as a (d1, d2*...*dN) matrix and reduce its rows;
#   dim == N: view `A` as a (d1*...*d(N-1), dN) matrix and reduce its columns;
#   otherwise: view `A` as (leading, reduced, trailing) pages and reduce each page.
# Returns `R`.
function _wsumN!(R::StridedArray{T}, A::StridedArray{T,N}, w::StridedVector{T},
                 dim::Int, init::Bool) where {T<:BlasReal,N}
    if dim == 1
        nrows = size(A, 1)
        ncols = div(length(A), nrows)
        _wsum2_blas!(view(R,:), reshape(A, (nrows, ncols)), w, 1, init)
    elseif dim == N
        ncols = size(A, N)
        nrows = div(length(A), ncols)
        _wsum2_blas!(view(R,:), reshape(A, (nrows, ncols)), w, 2, init)
    else # 1 < dim < N
        nlead = mapreduce(d -> size(A, d), *, 1:dim-1; init=1)
        nred = size(A, dim)
        ntrail = mapreduce(d -> size(A, d), *, dim+1:N; init=1)
        pagesA = reshape(A, (nlead, nred, ntrail))
        pagesR = reshape(R, (nlead, ntrail))
        for p in 1:ntrail
            _wsum2_blas!(view(pagesR,:,p), view(pagesA,:,:,p), w, 2, init)
        end
    end
    return R
end

# Weighted sum along `dim` for a dense array with at least three dimensions.
#   dim <= 2: treat `A` as a stack of (d1, d2) matrix pages and apply `_wsum2_blas!`
#             to each page;
#   otherwise: fall back to the generic kernel `Base._wsum_general!`.
# Returns `R`.
function _wsumN!(R::StridedArray{T}, A::DenseArray{T,N}, w::StridedVector{T},
                 dim::Int, init::Bool) where {T<:BlasReal,N}
    @assert N >= 3
    if dim <= 2
        m = size(A, 1)
        n = size(A, 2)
        npages = 1
        for i = 3:N
            npages *= size(A, i)
        end
        rlen = ifelse(dim == 1, n, m)
        # Collapse dimensions 3:N into one page dimension so that every page can be
        # addressed by a single index. Indexing the N-dimensional `A` directly with
        # `[:, :, i]` is only valid when dimensions 4:N all have length 1.
        Av = reshape(A, (m, n, npages))
        Rv = reshape(R, (rlen, npages))
        for i = 1:npages
            _wsum2_blas!(view(Rv,:,i), view(Av,:,:,i), w, dim, init)
        end
    else
        Base._wsum_general!(R, A, w, dim, init)
    end
    return R
end

# BLAS-backed `_wsum!` methods: dense matrices use one `gemv!` call, other dense
# arrays go through `_wsumN!`, and a vector destination uses the vector routine.
# NOTE(review): the first and third signatures cross (`DenseMatrix` vs `StridedVector`
# destination); a vector `R` combined with a matrix `A` appears to match both --
# confirm that this combination cannot reach here.
function Base._wsum!(R::StridedArray{T}, A::DenseMatrix{T}, w::StridedVector{T},
                     dim::Int, init::Bool) where {T<:BlasReal}
    return _wsum2_blas!(view(R,:), A, w, dim, init)
end

function Base._wsum!(R::StridedArray{T}, A::DenseArray{T}, w::StridedVector{T},
                     dim::Int, init::Bool) where {T<:BlasReal}
    return _wsumN!(R, A, w, dim, init)
end

function Base._wsum!(R::StridedVector{T}, A::DenseArray{T}, w::StridedVector{T},
                     dim::Int, init::Bool) where {T<:BlasReal}
    return Base._wsum1!(R, A, w, init)
end
35 changes: 35 additions & 0 deletions test/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -533,3 +533,38 @@ x = [j^2 for j in i]
i = Base.Slice(0:0)
x = [j+7 for j in i]
@test sum(x) == 7

@testset "weighted sum" begin
    wts = ([1.4, 2.5, 10.1], [1.4f0, 2.5f0, 10.1f0], [0.0, 2.3, 5.6],
           [NaN, 2.3, 5.6], [Inf, 2.3, 5.6],
           [2, 1, 3], Int8[1, 2, 3], [1, 1, 1])

    # Compare `sum(a, weights=w)` with the broadcast reference `sum(a.*w)`,
    # in both value and type; non-finite results are compared with `isequal`.
    function check_weighted_sum(a, w)
        res = @inferred sum(a, weights=w)
        expected = sum(a.*w)
        if isfinite(res)
            @test res ≈ expected
        else
            @test isequal(res, expected)
        end
        @test typeof(res) == typeof(expected)
    end

    # Vector data with vector weights
    for a in (rand(3), rand(Int, 3), rand(Int8, 3)), w in wts
        check_weighted_sum(a, w)
    end

    # Matrix data with weights repeated across the columns
    for a in (rand(3, 5), rand(Float32, 3, 5), rand(Int, 3, 5), rand(Int8, 3, 5)), w in wts
        check_weighted_sum(a, repeat(w, outer=(1, 5)))
    end

    # Unsupported combinations: a function together with weights, non-vector weights
    @test_throws ArgumentError sum(exp, [1], weights=[1])
    @test_throws ArgumentError sum!(exp, [0 0], [1 2], weights=[1, 10])
    @test_throws ArgumentError sum!([0 0], [1 2], weights=[1 10])
end
Loading