-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
RFC: Add weights argument to sum #33310
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -353,9 +353,16 @@ reduce(op, A::AbstractArray; kw...) = mapreduce(identity, op, A; kw...) | |
|
||
##### Specific reduction functions ##### | ||
""" | ||
sum(A::AbstractArray; dims) | ||
sum(A::AbstractArray; [dims], [weights::AbstractArray]) | ||
|
||
Sum elements of an array over the given dimensions. | ||
Compute the sum of array `A`. | ||
If `dims` is provided, return an array of sums over these dimensions. | ||
If `weights` is provided, return the weighted sum(s). `weights` must be | ||
either an array of the same size as `A` if `dims` is omitted, | ||
or a vector with the same length as `size(A, dims)` if `dims` is provided. | ||
|
||
!!! compat "Julia 1.4" | ||
The `weights` keyword argument requires at least Julia 1.4. | ||
|
||
# Examples | ||
```jldoctest | ||
|
@@ -372,14 +379,26 @@ julia> sum(A, dims=2) | |
2×1 Array{Int64,2}: | ||
3 | ||
7 | ||
|
||
julia> sum(A, weights=[2 1; 2 1]) | ||
14 | ||
|
||
julia> sum(A, weights=[2, 1], dims=1) | ||
1×2 Array{Int64,2}: | ||
5 8 | ||
``` | ||
""" | ||
sum(A::AbstractArray; dims) | ||
|
||
""" | ||
sum!(r, A) | ||
sum!(r, A; [weights::AbstractVector]) | ||
|
||
Sum elements of `A` over the singleton dimensions of `r`, and write results to `r`. | ||
If `r` has only one singleton dimension `i`, `weights` can be a vector of length | ||
`size(v, i)` to compute the weighted mean. | ||
|
||
!!! compat "Julia 1.4" | ||
The `weights` keyword argument requires at least Julia 1.4. | ||
|
||
# Examples | ||
```jldoctest | ||
|
@@ -645,7 +664,7 @@ julia> any!([1 1], A) | |
""" | ||
any!(r, A) | ||
|
||
for (fname, _fname, op) in [(:sum, :_sum, :add_sum), (:prod, :_prod, :mul_prod), | ||
for (fname, _fname, op) in [(:prod, :_prod, :mul_prod), | ||
(:maximum, :_maximum, :max), (:minimum, :_minimum, :min)] | ||
@eval begin | ||
# User-facing methods with keyword arguments | ||
|
@@ -658,28 +677,188 @@ for (fname, _fname, op) in [(:sum, :_sum, :add_sum), (:prod, :_prod, | |
end | ||
end | ||
|
||
# Sum is the only reduction which supports weights | ||
sum(a::AbstractArray; dims=:, weights::Union{AbstractArray,Nothing}=nothing) = | ||
_sum(a, dims, weights) | ||
sum(f, a::AbstractArray; dims=:, weights::Union{AbstractArray,Nothing}=nothing) = | ||
_sum(f, a, dims, weights) | ||
sum(a, ::Colon, weights) = _sum(identity, a, :, weights) | ||
sum(f, a, ::Colon, ::Nothing) = mapreduce(f, add_sum, a) | ||
|
||
any(a::AbstractArray; dims=:) = _any(a, dims) | ||
any(f::Function, a::AbstractArray; dims=:) = _any(f, a, dims) | ||
_any(a, ::Colon) = _any(identity, a, :) | ||
all(a::AbstractArray; dims=:) = _all(a, dims) | ||
all(f::Function, a::AbstractArray; dims=:) = _all(f, a, dims) | ||
_all(a, ::Colon) = _all(identity, a, :) | ||
|
||
for (fname, op) in [(:sum, :add_sum), (:prod, :mul_prod), | ||
for (fname, op) in [(:prod, :mul_prod), | ||
(:maximum, :max), (:minimum, :min), | ||
(:all, :&), (:any, :|)] | ||
fname! = Symbol(fname, '!') | ||
_fname! = Symbol('_', fname, '!') | ||
_fname = Symbol('_', fname) | ||
@eval begin | ||
$(fname!)(r::AbstractArray, A::AbstractArray; init::Bool=true) = | ||
$(fname!)(identity, r, A; init=init) | ||
$(fname!)(f::Function, r::AbstractArray, A::AbstractArray; init::Bool=true) = | ||
mapreducedim!(f, $(op), initarray!(r, $(op), init, A), A) | ||
$(fname!)(r::AbstractArray, A::AbstractArray; init::Bool=true) = $(fname!)(identity, r, A; init=init) | ||
$(_fname!)(f, r, A; init=init) | ||
|
||
$(_fname)(A, dims) = $(_fname)(identity, A, dims) | ||
# Underlying implementations using dispatch | ||
$(_fname!)(f, r::AbstractArray, A::AbstractArray; init::Bool=true) = | ||
mapreducedim!(f, $(op), initarray!(r, $(op), init, A), A) | ||
$(_fname)(A, dims) = $(_fname)(identity, A, dims) | ||
$(_fname)(f, A, dims) = mapreduce(f, $(op), A, dims=dims) | ||
end | ||
end | ||
|
||
# Sum is the only reduction which supports weights | ||
sum!(r::AbstractArray, A::AbstractArray; | ||
init::Bool=true, weights::Union{AbstractArray,Nothing}=nothing) = | ||
sum!(identity, r, A; init=init, weights=weights) | ||
sum!(f::Function, r::AbstractArray, A::AbstractArray; | ||
init::Bool=true, weights::Union{AbstractArray,Nothing}=nothing) = | ||
_sum!(f, r, A, weights; init=init) | ||
_sum!(f, r::AbstractArray, A::AbstractArray, ::Nothing; init::Bool=true) = | ||
mapreducedim!(f, add_sum, initarray!(r, add_sum, init, A), A) | ||
_sum(A::AbstractArray, dims, weights) = _sum(identity, A, dims, weights) | ||
_sum(f, A::AbstractArray, dims, ::Nothing) = mapreduce(f, add_sum, A, dims=dims) | ||
_sum(::typeof(identity), A::AbstractArray, dims, w::AbstractArray) = | ||
_sum!(identity, reducedim_init(t -> t*zero(eltype(w)), add_sum, A, dims), A, w) | ||
_sum(f, A::AbstractArray, dims, w::AbstractArray) = | ||
throw(ArgumentError("Passing a function is not supported with `weights`")) | ||
|
||
|
||
# Weighted sum | ||
function _sum(A::AbstractArray, dims::Colon, w::AbstractArray{<:Real}) | ||
sw = size(w) | ||
sA = size(A) | ||
if sw != sA | ||
throw(DimensionMismatch("weights must have the same dimension as data (got $sw and $sA).")) | ||
end | ||
s0 = zero(eltype(A)) * zero(eltype(w)) | ||
s = add_sum(s0, s0) | ||
@inbounds @simd for i in eachindex(A, w) | ||
s += A[i] * w[i] | ||
end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will be much less accurate than the non-weighted algorithm based on pairwise summation. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, probably (I haven't considered this, that's just a port of the StatsBase code with some cleanup). We could improve this, but note that for If we want to preserve accuracy, that would be an argument in favor of this PR, since a naive |
||
s | ||
end | ||
|
||
# Weighted sum over dimensions | ||
# | ||
# Brief explanation of the algorithm: | ||
# ------------------------------------ | ||
# | ||
# 1. _wsum! provides the core implementation, which assumes that | ||
# the dimensions of all input arguments are consistent, and no | ||
# dimension checking is performed therein. | ||
# | ||
# wsum and wsum! perform argument checking and call _wsum! | ||
# internally. | ||
# | ||
# 2. _wsum! adopt a Cartesian based implementation for general | ||
# sub types of AbstractArray. Particularly, a faster routine | ||
# that keeps a local accumulator will be used when dim = 1. | ||
# | ||
# The internal function that implements this is _wsum_general! | ||
# | ||
# 3. _wsum! is specialized for following cases: | ||
# (a) A is a vector: we invoke the vector version wsum above. | ||
# The internal function that implements this is _wsum1! | ||
# | ||
# (b) A is a dense matrix with eltype <: BlasReal: we call gemv! | ||
# The internal function that implements this is _wsum2_blas! | ||
# (in LinearAlgebra/src/wsum.jl) | ||
# | ||
# (c) A is a contiguous array with eltype <: BlasReal: | ||
# dim == 1: treat A like a matrix of size (d1, d2 x ... x dN) | ||
# dim == N: treat A like a matrix of size (d1 x ... x d(N-1), dN) | ||
# otherwise: decompose A into multiple pages, and apply _wsum2_blas! | ||
# for each | ||
# The internal function that implements this is _wsumN! | ||
# (in LinearAlgebra/src/wsum.jl) | ||
# | ||
# (d) A is a general dense array with eltype <: BlasReal: | ||
# dim <= 2: delegate to (a) and (b) | ||
# otherwise, decompose A into multiple pages | ||
# The internal function that implements this is _wsumN! | ||
# (in LinearAlgebra/src/wsum.jl) | ||
|
||
function _wsum1!(R::AbstractArray, A::AbstractVector, w::AbstractVector, init::Bool) | ||
r = _sum(A, :, w) | ||
if init | ||
R[1] = r | ||
else | ||
R[1] += r | ||
end | ||
return R | ||
end | ||
|
||
function _wsum_general!(R::AbstractArray{S}, A::AbstractArray, w::AbstractVector, | ||
dim::Int, init::Bool) where {S} | ||
# following the implementation of _mapreducedim! | ||
lsiz = check_reducedims(R,A) | ||
!isempty(R) && init && fill!(R, zero(S)) | ||
isempty(A) && return R | ||
|
||
indsAt, indsRt = safe_tail(axes(A)), safe_tail(axes(R)) # handle d=1 manually | ||
keep, Idefault = Broadcast.shapeindexer(indsRt) | ||
if reducedim1(R, A) | ||
i1 = first(axes1(R)) | ||
for IA in CartesianIndices(indsAt) | ||
IR = Broadcast.newindex(IA, keep, Idefault) | ||
r = R[i1,IR] | ||
@inbounds @simd for i in axes(A, 1) | ||
r += A[i,IA] * w[dim > 1 ? IA[dim-1] : i] | ||
end | ||
R[i1,IR] = r | ||
end | ||
else | ||
for IA in CartesianIndices(indsAt) | ||
IR = Broadcast.newindex(IA, keep, Idefault) | ||
@inbounds @simd for i in axes(A, 1) | ||
R[i,IR] += A[i,IA] * w[dim > 1 ? IA[dim-1] : i] | ||
end | ||
end | ||
end | ||
return R | ||
end | ||
|
||
_wsum!(R::AbstractArray, A::AbstractVector, w::AbstractVector, | ||
dim::Int, init::Bool) = | ||
_wsum1!(R, A, w, init) | ||
|
||
_wsum!(R::AbstractArray, A::AbstractArray, w::AbstractVector, | ||
dim::Int, init::Bool) = | ||
_wsum_general!(R, A, w, dim, init) | ||
|
||
function _sum!(f, R::AbstractArray, A::AbstractArray{T,N}, w::AbstractArray; | ||
init::Bool=true) where {T,N} | ||
f === identity || throw(ArgumentError("Passing a function is not supported with `weights`")) | ||
w isa AbstractVector || throw(ArgumentError("Only vector `weights` are supported")) | ||
|
||
check_reducedims(R,A) | ||
reddims = size(R) .!= size(A) | ||
dim = something(findfirst(reddims), ndims(R)+1) | ||
if dim > N | ||
dim1 = findfirst(==(1), size(A)) | ||
if dim1 !== nothing | ||
dim = dim1 | ||
end | ||
end | ||
if findnext(reddims, dim+1) !== nothing | ||
throw(ArgumentError("reducing over more than one dimension is not supported with weights")) | ||
end | ||
lw = length(w) | ||
ldim = size(A, dim) | ||
if lw != ldim | ||
throw(DimensionMismatch("weights must have the same length as the dimension " * | ||
"over which reduction is performed (got $lw and $ldim).")) | ||
end | ||
_wsum!(R, A, w, dim, init) | ||
end | ||
|
||
|
||
##### findmin & findmax ##### | ||
# The initial values of Rval are not used if the corresponding indices in Rind are 0. | ||
# | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# Optimized method for weighted sum with BlasReal | ||
# dot cannot be used for other types as it uses + rather than add_sum for accumulation, | ||
# and therefore does not return the correct type | ||
Base._sum(A::AbstractArray{T}, dims::Colon, w::AbstractArray{T}) where {T<:BlasReal} = | ||
dot(vec(A), vec(w)) | ||
|
||
# Optimized methods for weighted sum over dimensions with BlasReal | ||
# (generic method is defined in base/reducedim.jl) | ||
# | ||
# _wsum! is specialized for following cases: | ||
# (a) A is a dense matrix with eltype <: BlasReal: we call gemv! | ||
# The internal function that implements this is _wsum2_blas! | ||
# | ||
# (b) A is a contiguous array with eltype <: BlasReal: | ||
# dim == 1: treat A like a matrix of size (d1, d2 x ... x dN) | ||
# dim == N: treat A like a matrix of size (d1 x ... x d(N-1), dN) | ||
# otherwise: decompose A into multiple pages, and apply _wsum2_blas! | ||
# for each | ||
# The internal function that implements this is _wsumN! | ||
# | ||
# (c) A is a general dense array with eltype <: BlasReal: | ||
# dim <= 2: delegate to (a) and (b) | ||
# otherwise, decompose A into multiple pages | ||
# The internal function that implements this is _wsumN! | ||
|
||
function _wsum2_blas!(R::StridedVector{T}, A::StridedMatrix{T}, w::StridedVector{T}, | ||
dim::Int, init::Bool) where T<:BlasReal | ||
beta = ifelse(init, zero(T), one(T)) | ||
trans = dim == 1 ? 'T' : 'N' | ||
BLAS.gemv!(trans, one(T), A, w, beta, R) | ||
return R | ||
end | ||
|
||
function _wsumN!(R::StridedArray{T}, A::StridedArray{T,N}, w::StridedVector{T}, | ||
dim::Int, init::Bool) where {T<:BlasReal,N} | ||
if dim == 1 | ||
m = size(A, 1) | ||
n = div(length(A), m) | ||
_wsum2_blas!(view(R,:), reshape(A, (m, n)), w, 1, init) | ||
elseif dim == N | ||
n = size(A, N) | ||
m = div(length(A), n) | ||
_wsum2_blas!(view(R,:), reshape(A, (m, n)), w, 2, init) | ||
else # 1 < dim < N | ||
m = 1 | ||
for i = 1:dim-1 | ||
m *= size(A, i) | ||
end | ||
n = size(A, dim) | ||
k = 1 | ||
for i = dim+1:N | ||
k *= size(A, i) | ||
end | ||
Av = reshape(A, (m, n, k)) | ||
Rv = reshape(R, (m, k)) | ||
for i = 1:k | ||
_wsum2_blas!(view(Rv,:,i), view(Av,:,:,i), w, 2, init) | ||
end | ||
end | ||
return R | ||
end | ||
|
||
function _wsumN!(R::StridedArray{T}, A::DenseArray{T,N}, w::StridedVector{T}, | ||
dim::Int, init::Bool) where {T<:BlasReal,N} | ||
@assert N >= 3 | ||
if dim <= 2 | ||
m = size(A, 1) | ||
n = size(A, 2) | ||
npages = 1 | ||
for i = 3:N | ||
npages *= size(A, i) | ||
end | ||
rlen = ifelse(dim == 1, n, m) | ||
Rv = reshape(R, (rlen, npages)) | ||
for i = 1:npages | ||
_wsum2_blas!(view(Rv,:,i), view(A,:,:,i), w, dim, init) | ||
end | ||
else | ||
Base._wsum_general!(R, A, w, dim, init) | ||
end | ||
return R | ||
end | ||
|
||
Base._wsum!(R::StridedArray{T}, A::DenseMatrix{T}, w::StridedVector{T}, | ||
dim::Int, init::Bool) where {T<:BlasReal} = | ||
_wsum2_blas!(view(R,:), A, w, dim, init) | ||
|
||
Base._wsum!(R::StridedArray{T}, A::DenseArray{T}, w::StridedVector{T}, | ||
dim::Int, init::Bool) where {T<:BlasReal} = | ||
_wsumN!(R, A, w, dim, init) | ||
|
||
Base._wsum!(R::StridedVector{T}, A::DenseArray{T}, w::StridedVector{T}, | ||
dim::Int, init::Bool) where {T<:BlasReal} = | ||
Base._wsum1!(R, A, w, init) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why should
weights
have to be an array, as opposed to just an iterable?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's how things work currently in StatsBase. I guess we could make this more general at least for the simple case where
dims=:
. But that would require adding a separate method as@simd
only works when using indices.