Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Import weighted stats and moments from StatsBase to Statistics #31395

Closed
wants to merge 12 commits into from
198 changes: 190 additions & 8 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -353,9 +353,19 @@ reduce(op, A::AbstractArray; kw...) = mapreduce(identity, op, A; kw...)

##### Specific reduction functions #####
"""
sum(A::AbstractArray; dims)
sum(A::AbstractArray; [dims], [weights::AbstractArray])

Sum elements of an array over the given dimensions.
Compute the sum of array `A`.
If `dims` is provided, return an array of sums over these dimensions.
If `weights` is provided, return the weighted sum(s). `weights` must be
either an array of the same size as `A` if `dims` is omitted,
or a vector with the same length as `size(A, dims)` if `dims` is provided.

!!! compat "Julia 1.1"
`mean` for empty arrays requires at least Julia 1.1.

!!! compat "Julia 1.3"
The `weights` keyword argument requires at least Julia 1.3.

# Examples
```jldoctest
Expand All @@ -372,14 +382,26 @@ julia> sum(A, dims=2)
2×1 Array{Int64,2}:
3
7

julia> sum(A, weights=[2 1; 2 1])
14

julia> sum(A, weights=[2, 1], dims=1)
1×2 Array{Int64,2}:
5 8
```
"""
sum(A::AbstractArray; dims)

"""
sum!(r, A)
sum!(r, A; [weights::AbstractVector])

Sum elements of `A` over the singleton dimensions of `r`, and write results to `r`.
If `r` has only one singleton dimension `i`, `weights` can be a vector of length
`size(v, i)` to compute the weighted mean.

!!! compat "Julia 1.3"
The `weights` keyword argument requires at least Julia 1.3.

# Examples
```jldoctest
Expand Down Expand Up @@ -645,7 +667,7 @@ julia> any!([1 1], A)
"""
any!(r, A)

for (fname, _fname, op) in [(:sum, :_sum, :add_sum), (:prod, :_prod, :mul_prod),
for (fname, _fname, op) in [(:prod, :_prod, :mul_prod),
(:maximum, :_maximum, :max), (:minimum, :_minimum, :min)]
@eval begin
# User-facing methods with keyword arguments
Expand All @@ -658,28 +680,188 @@ for (fname, _fname, op) in [(:sum, :_sum, :add_sum), (:prod, :_prod,
end
end

# Sum is the only reduction which supports weights
sum(a::AbstractArray; dims=:, weights::Union{AbstractArray,Nothing}=nothing) =
_sum(a, dims, weights)
sum(f, a::AbstractArray; dims=:, weights::Union{AbstractArray,Nothing}=nothing) =
_sum(f, a, dims, weights)
sum(a, ::Colon, weights) = _sum(identity, a, :, weights)
sum(f, a, ::Colon, ::Nothing) = mapreduce(f, add_sum, a)

any(a::AbstractArray; dims=:) = _any(a, dims)
any(f::Function, a::AbstractArray; dims=:) = _any(f, a, dims)
_any(a, ::Colon) = _any(identity, a, :)
all(a::AbstractArray; dims=:) = _all(a, dims)
all(f::Function, a::AbstractArray; dims=:) = _all(f, a, dims)
_all(a, ::Colon) = _all(identity, a, :)

for (fname, op) in [(:sum, :add_sum), (:prod, :mul_prod),
for (fname, op) in [(:prod, :mul_prod),
(:maximum, :max), (:minimum, :min),
(:all, :&), (:any, :|)]
fname! = Symbol(fname, '!')
_fname! = Symbol('_', fname, '!')
_fname = Symbol('_', fname)
@eval begin
$(fname!)(r::AbstractArray, A::AbstractArray; init::Bool=true) =
$(fname!)(identity, r, A; init=init)
$(fname!)(f::Function, r::AbstractArray, A::AbstractArray; init::Bool=true) =
mapreducedim!(f, $(op), initarray!(r, $(op), init, A), A)
$(fname!)(r::AbstractArray, A::AbstractArray; init::Bool=true) = $(fname!)(identity, r, A; init=init)
$(_fname!)(f, r, A; init=init)

$(_fname)(A, dims) = $(_fname)(identity, A, dims)
# Underlying implementations using dispatch
$(_fname!)(f, r::AbstractArray, A::AbstractArray; init::Bool=true) =
mapreducedim!(f, $(op), initarray!(r, $(op), init, A), A)
$(_fname)(A, dims) = $(_fname)(identity, A, dims)
$(_fname)(f, A, dims) = mapreduce(f, $(op), A, dims=dims)
end
end

# Sum is the only reduction which supports weights
sum!(r::AbstractArray, A::AbstractArray;
init::Bool=true, weights::Union{AbstractArray,Nothing}=nothing) =
sum!(identity, r, A; init=init, weights=weights)
sum!(f::Function, r::AbstractArray, A::AbstractArray;
init::Bool=true, weights::Union{AbstractArray,Nothing}=nothing) =
_sum!(f, r, A, weights; init=init)
_sum!(f, r::AbstractArray, A::AbstractArray, ::Nothing; init::Bool=true) =
mapreducedim!(f, add_sum, initarray!(r, add_sum, init, A), A)
_sum(A::AbstractArray, dims, weights) = _sum(identity, A, dims, weights)
_sum(f, A::AbstractArray, dims, ::Nothing) = mapreduce(f, add_sum, A, dims=dims)
_sum(::typeof(identity), A::AbstractArray, dims, w::AbstractArray) =
_sum!(identity, reducedim_init(t -> t*zero(eltype(w)), add_sum, A, dims), A, w)
_sum(f, A::AbstractArray, dims, w::AbstractArray) =
throw(ArgumentError("Passing a function is not supported with `weights`"))


# Weighted sum
function _sum(A::AbstractArray, dims::Colon, w::AbstractArray{<:Real})
sw = size(w)
sA = size(A)
if sw != sA
throw(DimensionMismatch("weights must have the same dimension as data (got $sw and $sA)."))
end
s0 = zero(eltype(A)) * zero(eltype(w))
s = add_sum(s0, s0)
@inbounds @simd for i in eachindex(A, w)
s += A[i] * w[i]
end
s
end

# Weighted sum over dimensions
#
# Brief explanation of the algorithm:
# ------------------------------------
#
# 1. _wsum! provides the core implementation, which assumes that
# the dimensions of all input arguments are consistent, and no
# dimension checking is performed therein.
#
# wsum and wsum! perform argument checking and call _wsum!
# internally.
#
# 2. _wsum! adopt a Cartesian based implementation for general
# sub types of AbstractArray. Particularly, a faster routine
# that keeps a local accumulator will be used when dim = 1.
#
# The internal function that implements this is _wsum_general!
#
# 3. _wsum! is specialized for following cases:
# (a) A is a vector: we invoke the vector version wsum above.
# The internal function that implements this is _wsum1!
#
# (b) A is a dense matrix with eltype <: BlasReal: we call gemv!
# The internal function that implements this is _wsum2_blas!
# (in LinearAlgebra/src/wsum.jl)
#
# (c) A is a contiguous array with eltype <: BlasReal:
# dim == 1: treat A like a matrix of size (d1, d2 x ... x dN)
# dim == N: treat A like a matrix of size (d1 x ... x d(N-1), dN)
# otherwise: decompose A into multiple pages, and apply _wsum2_blas!
# for each
# The internal function that implements this is _wsumN!
# (in LinearAlgebra/src/wsum.jl)
#
# (d) A is a general dense array with eltype <: BlasReal:
# dim <= 2: delegate to (a) and (b)
# otherwise, decompose A into multiple pages
# The internal function that implements this is _wsumN!
# (in LinearAlgebra/src/wsum.jl)

function _wsum1!(R::AbstractArray, A::AbstractVector, w::AbstractVector, init::Bool)
r = _sum(A, :, w)
if init
R[1] = r
else
R[1] += r
end
return R
end

function _wsum_general!(R::AbstractArray{S}, A::AbstractArray, w::AbstractVector,
dim::Int, init::Bool) where {S}
# following the implementation of _mapreducedim!
lsiz = check_reducedims(R,A)
!isempty(R) && init && fill!(R, zero(S))
isempty(A) && return R

indsAt, indsRt = safe_tail(axes(A)), safe_tail(axes(R)) # handle d=1 manually
keep, Idefault = Broadcast.shapeindexer(indsRt)
if reducedim1(R, A)
i1 = first(axes1(R))
for IA in CartesianIndices(indsAt)
IR = Broadcast.newindex(IA, keep, Idefault)
r = R[i1,IR]
@inbounds @simd for i in axes(A, 1)
r += A[i,IA] * w[dim > 1 ? IA[dim-1] : i]
end
R[i1,IR] = r
end
else
for IA in CartesianIndices(indsAt)
IR = Broadcast.newindex(IA, keep, Idefault)
@inbounds @simd for i in axes(A, 1)
R[i,IR] += A[i,IA] * w[dim > 1 ? IA[dim-1] : i]
end
end
end
return R
end

_wsum!(R::AbstractArray, A::AbstractVector, w::AbstractVector,
dim::Int, init::Bool) =
_wsum1!(R, A, w, init)

_wsum!(R::AbstractArray, A::AbstractArray, w::AbstractVector,
dim::Int, init::Bool) =
_wsum_general!(R, A, w, dim, init)

function _sum!(f, R::AbstractArray, A::AbstractArray{T,N}, w::AbstractArray;
init::Bool=true) where {T,N}
f === identity || throw(ArgumentError("Passing a function is not supported with `weights`"))
w isa AbstractVector || throw(ArgumentError("Only vector `weights` are supported"))

check_reducedims(R,A)
reddims = size(R) .!= size(A)
dim = something(findfirst(reddims), ndims(R)+1)
if dim > N
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is quite ugly but I'm not sure what's the best solution. For unweighted sum, reducing over dim > N is a no-op, so that's easy, but for the weighted sum it amounts to multiplying values by their corresponding weight. Maybe this should just be an error?

dim1 = findfirst(==(1), size(A))
if dim1 !== nothing
dim = dim1
end
end
if findnext(reddims, dim+1) !== nothing
throw(ArgumentError("reducing over more than one dimension is not supported with weights"))
end
lw = length(w)
ldim = size(A, dim)
if lw != ldim
throw(DimensionMismatch("weights must have the same length as the dimension " *
"over which reduction is performed (got $lw and $ldim)."))
end
_wsum!(R, A, w, dim, init)
end


##### findmin & findmax #####
# The initial values of Rval are not used if the corresponding indices in Rind are 0.
#
Expand Down
3 changes: 2 additions & 1 deletion stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, as
getproperty, imag, inv, isapprox, isone, iszero, IndexStyle, kron, length, log, map, ndims,
oneunit, parent, power_by_squaring, print_matrix, promote_rule, real, round, sec, sech,
setindex!, show, similar, sin, sincos, sinh, size, sqrt,
strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec
strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec, _wsum!
using Base: hvcat_fill, IndexLinear, promote_op, promote_typeof,
@propagate_inbounds, @pure, reduce, typed_vcat, require_one_based_indexing
using Base.Broadcast: Broadcasted
Expand Down Expand Up @@ -373,6 +373,7 @@ include("bitarray.jl")
include("ldlt.jl")
include("schur.jl")
include("structuredbroadcast.jl")
include("wsum.jl")
include("deprecated.jl")

const ⋅ = dot
Expand Down
94 changes: 94 additions & 0 deletions stdlib/LinearAlgebra/src/wsum.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Optimized method for weighted sum with BlasReal
# dot cannot be used for other types as it uses + rather than add_sum for accumulation,
# and therefore does not return the correct type
Base._sum(A::AbstractArray{T}, dims::Colon, w::AbstractArray{T}) where {T<:BlasReal} =
dot(vec(A), vec(w))

# Optimized methods for weighted sum over dimensions with BlasReal
# (generic method is defined in base/reducedim.jl)
#
# _wsum! is specialized for following cases:
# (a) A is a dense matrix with eltype <: BlasReal: we call gemv!
# The internal function that implements this is _wsum2_blas!
#
# (b) A is a contiguous array with eltype <: BlasReal:
# dim == 1: treat A like a matrix of size (d1, d2 x ... x dN)
# dim == N: treat A like a matrix of size (d1 x ... x d(N-1), dN)
# otherwise: decompose A into multiple pages, and apply _wsum2_blas!
# for each
# The internal function that implements this is _wsumN!
#
# (c) A is a general dense array with eltype <: BlasReal:
# dim <= 2: delegate to (a) and (b)
# otherwise, decompose A into multiple pages
# The internal function that implements this is _wsumN!

function _wsum2_blas!(R::StridedVector{T}, A::StridedMatrix{T}, w::StridedVector{T},
dim::Int, init::Bool) where T<:BlasReal
beta = ifelse(init, zero(T), one(T))
trans = dim == 1 ? 'T' : 'N'
BLAS.gemv!(trans, one(T), A, w, beta, R)
return R
end

function _wsumN!(R::StridedArray{T}, A::StridedArray{T,N}, w::StridedVector{T},
dim::Int, init::Bool) where {T<:BlasReal,N}
if dim == 1
m = size(A, 1)
n = div(length(A), m)
_wsum2_blas!(view(R,:), reshape(A, (m, n)), w, 1, init)
elseif dim == N
n = size(A, N)
m = div(length(A), n)
_wsum2_blas!(view(R,:), reshape(A, (m, n)), w, 2, init)
else # 1 < dim < N
m = 1
for i = 1:dim-1
m *= size(A, i)
end
n = size(A, dim)
k = 1
for i = dim+1:N
k *= size(A, i)
end
Av = reshape(A, (m, n, k))
Rv = reshape(R, (m, k))
for i = 1:k
_wsum2_blas!(view(Rv,:,i), view(Av,:,:,i), w, 2, init)
end
end
return R
end

function _wsumN!(R::StridedArray{T}, A::DenseArray{T,N}, w::StridedVector{T},
dim::Int, init::Bool) where {T<:BlasReal,N}
@assert N >= 3
if dim <= 2
m = size(A, 1)
n = size(A, 2)
npages = 1
for i = 3:N
npages *= size(A, i)
end
rlen = ifelse(dim == 1, n, m)
Rv = reshape(R, (rlen, npages))
for i = 1:npages
_wsum2_blas!(view(Rv,:,i), view(A,:,:,i), w, dim, init)
end
else
Base._wsum_general!(R, A, w, dim, init)
end
return R
end

Base._wsum!(R::StridedArray{T}, A::DenseMatrix{T}, w::StridedVector{T},
dim::Int, init::Bool) where {T<:BlasReal} =
_wsum2_blas!(view(R,:), A, w, dim, init)

Base._wsum!(R::StridedArray{T}, A::DenseArray{T}, w::StridedVector{T},
dim::Int, init::Bool) where {T<:BlasReal} =
_wsumN!(R, A, w, dim, init)

Base._wsum!(R::StridedVector{T}, A::DenseArray{T}, w::StridedVector{T},
dim::Int, init::Bool) where {T<:BlasReal} =
Base._wsum1!(R, A, w, init)
Loading