diff --git a/base/reducedim.jl b/base/reducedim.jl
index 996729be8bc4c..9d0bcc592790c 100644
--- a/base/reducedim.jl
+++ b/base/reducedim.jl
@@ -353,9 +353,16 @@ reduce(op, A::AbstractArray; kw...) = mapreduce(identity, op, A; kw...)
 ##### Specific reduction functions #####
 
 """
-    sum(A::AbstractArray; dims)
+    sum(A::AbstractArray; [dims], [weights::AbstractArray])
 
-Sum elements of an array over the given dimensions.
+Compute the sum of array `A`.
+If `dims` is provided, return an array of sums over these dimensions.
+If `weights` is provided, return the weighted sum(s). `weights` must be
+either an array of the same size as `A` if `dims` is omitted,
+or a vector of length `size(A, dims)` if `dims` is provided.
+
+!!! compat "Julia 1.4"
+    The `weights` keyword argument requires at least Julia 1.4.
 
 # Examples
 ```jldoctest
@@ -372,14 +379,26 @@ julia> sum(A, dims=2)
 2×1 Array{Int64,2}:
  3
  7
+
+julia> sum(A, weights=[2 1; 2 1])
+14
+
+julia> sum(A, weights=[2, 1], dims=1)
+1×2 Array{Int64,2}:
+ 5  8
 ```
 """
 sum(A::AbstractArray; dims)
 
 """
-    sum!(r, A)
+    sum!(r, A; [weights::AbstractVector])
 
 Sum elements of `A` over the singleton dimensions of `r`, and write results to `r`.
+If `r` has only one singleton dimension `i`, `weights` can be a vector of length
+`size(A, i)` to compute the weighted sum.
+
+!!! compat "Julia 1.4"
+    The `weights` keyword argument requires at least Julia 1.4.
 
 # Examples
 ```jldoctest
@@ -645,7 +664,7 @@ julia> any!([1 1], A)
 """
 any!(r, A)
 
-for (fname, _fname, op) in [(:sum, :_sum, :add_sum), (:prod, :_prod, :mul_prod),
+for (fname, _fname, op) in [(:prod, :_prod, :mul_prod),
                             (:maximum, :_maximum, :max), (:minimum, :_minimum, :min)]
     @eval begin
         # User-facing methods with keyword arguments
@@ -658,6 +677,14 @@ for (fname, _fname, op) in [(:sum, :_sum, :add_sum), (:prod, :_prod,
     end
 end
 
+# Sum is the only reduction which supports weights
+sum(a::AbstractArray; dims=:, weights::Union{AbstractArray,Nothing}=nothing) =
+    _sum(a, dims, weights)
+sum(f, a::AbstractArray; dims=:, weights::Union{AbstractArray,Nothing}=nothing) =
+    _sum(f, a, dims, weights)
+sum(a, ::Colon, weights) = _sum(identity, a, :, weights)
+sum(f, a, ::Colon, ::Nothing) = mapreduce(f, add_sum, a)
+
 any(a::AbstractArray; dims=:) = _any(a, dims)
 any(f::Function, a::AbstractArray; dims=:) = _any(f, a, dims)
 _any(a, ::Colon) = _any(identity, a, :)
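
A minimal sketch of the semantics the new keyword methods provide (example
values only, mirroring the equivalences stated in the docstring above):

    A = [1 2; 3 4]
    @assert sum(A, weights=[2 1; 2 1]) == sum(A .* [2 1; 2 1]) == 14
    @assert sum(A, weights=[2, 1], dims=1) == sum(A .* [2, 1], dims=1) == [5 8]
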
@@ -665,21 +692,173 @@ all(a::AbstractArray; dims=:) = _all(a, dims)
 all(f::Function, a::AbstractArray; dims=:) = _all(f, a, dims)
 _all(a, ::Colon) = _all(identity, a, :)
 
-for (fname, op) in [(:sum, :add_sum), (:prod, :mul_prod),
+for (fname, op) in [(:prod, :mul_prod),
                     (:maximum, :max), (:minimum, :min),
                     (:all, :&), (:any, :|)]
     fname! = Symbol(fname, '!')
+    _fname! = Symbol('_', fname, '!')
     _fname = Symbol('_', fname)
     @eval begin
+        $(fname!)(r::AbstractArray, A::AbstractArray; init::Bool=true) =
+            $(fname!)(identity, r, A; init=init)
         $(fname!)(f::Function, r::AbstractArray, A::AbstractArray; init::Bool=true) =
-            mapreducedim!(f, $(op), initarray!(r, $(op), init, A), A)
-        $(fname!)(r::AbstractArray, A::AbstractArray; init::Bool=true) = $(fname!)(identity, r, A; init=init)
+            $(_fname!)(f, r, A; init=init)
 
-        $(_fname)(A, dims) = $(_fname)(identity, A, dims)
+        # Underlying implementations using dispatch
+        $(_fname!)(f, r::AbstractArray, A::AbstractArray; init::Bool=true) =
+            mapreducedim!(f, $(op), initarray!(r, $(op), init, A), A)
+        $(_fname)(A, dims) = $(_fname)(identity, A, dims)
         $(_fname)(f, A, dims) = mapreduce(f, $(op), A, dims=dims)
     end
 end
 
+# Sum is the only reduction which supports weights
+sum!(r::AbstractArray, A::AbstractArray;
+     init::Bool=true, weights::Union{AbstractArray,Nothing}=nothing) =
+    sum!(identity, r, A; init=init, weights=weights)
+sum!(f::Function, r::AbstractArray, A::AbstractArray;
+     init::Bool=true, weights::Union{AbstractArray,Nothing}=nothing) =
+    _sum!(f, r, A, weights; init=init)
+_sum!(f, r::AbstractArray, A::AbstractArray, ::Nothing; init::Bool=true) =
+    mapreducedim!(f, add_sum, initarray!(r, add_sum, init, A), A)
+_sum(A::AbstractArray, dims, weights) = _sum(identity, A, dims, weights)
+_sum(f, A::AbstractArray, dims, ::Nothing) = mapreduce(f, add_sum, A, dims=dims)
+_sum(::typeof(identity), A::AbstractArray, dims, w::AbstractArray) =
+    _sum!(identity, reducedim_init(t -> t*zero(eltype(w)), add_sum, A, dims), A, w)
+_sum(f, A::AbstractArray, dims, w::AbstractArray) =
+    throw(ArgumentError("Passing a function is not supported with `weights`"))
+
+
+# Weighted sum
+function _sum(A::AbstractArray, dims::Colon, w::AbstractArray{<:Real})
+    sw = size(w)
+    sA = size(A)
+    if sw != sA
+        throw(DimensionMismatch("weights must have the same dimension as data (got $sw and $sA)."))
+    end
+    s0 = zero(eltype(A)) * zero(eltype(w))
+    s = add_sum(s0, s0)
+    @inbounds @simd for i in eachindex(A, w)
+        s += A[i] * w[i]
+    end
+    s
+end
+
+# Weighted sum over dimensions
+#
+# Brief explanation of the algorithm:
+# ------------------------------------
+#
+# 1. _wsum! provides the core implementation, which assumes that
+#    the dimensions of all input arguments are consistent, and no
+#    dimension checking is performed therein.
+#
+#    _sum and _sum! perform argument checking and call _wsum!
+#    internally.
+#
+# 2. _wsum! adopts a Cartesian-based implementation for general
+#    subtypes of AbstractArray. In particular, a faster routine
+#    that keeps a local accumulator is used when dim == 1.
+#
+#    The internal function that implements this is _wsum_general!
+#
+# 3. _wsum! is specialized for the following cases:
+#    (a) A is a vector: we invoke the vector version of _sum above.
+#        The internal function that implements this is _wsum1!
+#
+#    (b) A is a dense matrix with eltype <: BlasReal: we call gemv!
+#        The internal function that implements this is _wsum2_blas!
+#        (in LinearAlgebra/src/wsum.jl)
+#
+#    (c) A is a contiguous array with eltype <: BlasReal:
+#        dim == 1: treat A like a matrix of size (d1, d2 x ... x dN)
+#        dim == N: treat A like a matrix of size (d1 x ... x d(N-1), dN)
+#        otherwise: decompose A into multiple pages, and apply _wsum2_blas!
+#        for each
+#        The internal function that implements this is _wsumN!
+#        (in LinearAlgebra/src/wsum.jl)
+#
+#    (d) A is a general dense array with eltype <: BlasReal:
+#        dim <= 2: delegate to (a) and (b)
+#        otherwise, decompose A into multiple pages
+#        The internal function that implements this is _wsumN!
+#        (in LinearAlgebra/src/wsum.jl)
+
+function _wsum1!(R::AbstractArray, A::AbstractVector, w::AbstractVector, init::Bool)
+    r = _sum(A, :, w)
+    if init
+        R[1] = r
+    else
+        R[1] += r
+    end
+    return R
+end
+
+function _wsum_general!(R::AbstractArray{S}, A::AbstractArray, w::AbstractVector,
+                        dim::Int, init::Bool) where {S}
+    # following the implementation of _mapreducedim!
+    lsiz = check_reducedims(R,A)
+    !isempty(R) && init && fill!(R, zero(S))
+    isempty(A) && return R
+
+    indsAt, indsRt = safe_tail(axes(A)), safe_tail(axes(R)) # handle d=1 manually
+    keep, Idefault = Broadcast.shapeindexer(indsRt)
+    if reducedim1(R, A)
+        i1 = first(axes1(R))
+        for IA in CartesianIndices(indsAt)
+            IR = Broadcast.newindex(IA, keep, Idefault)
+            r = R[i1,IR]
+            @inbounds @simd for i in axes(A, 1)
+                r += A[i,IA] * w[dim > 1 ? IA[dim-1] : i]
+            end
+            R[i1,IR] = r
+        end
+    else
+        for IA in CartesianIndices(indsAt)
+            IR = Broadcast.newindex(IA, keep, Idefault)
+            @inbounds @simd for i in axes(A, 1)
+                R[i,IR] += A[i,IA] * w[dim > 1 ? IA[dim-1] : i]
+            end
+        end
+    end
+    return R
+end
+
+_wsum!(R::AbstractArray, A::AbstractVector, w::AbstractVector,
+       dim::Int, init::Bool) =
+    _wsum1!(R, A, w, init)
+
+_wsum!(R::AbstractArray, A::AbstractArray, w::AbstractVector,
+       dim::Int, init::Bool) =
+    _wsum_general!(R, A, w, dim, init)
+
+function _sum!(f, R::AbstractArray, A::AbstractArray{T,N}, w::AbstractArray;
+               init::Bool=true) where {T,N}
+    f === identity || throw(ArgumentError("Passing a function is not supported with `weights`"))
+    w isa AbstractVector || throw(ArgumentError("Only vector `weights` are supported"))
+
+    check_reducedims(R,A)
+    reddims = size(R) .!= size(A)
+    dim = something(findfirst(reddims), ndims(R)+1)
+    if dim > N
+        dim1 = findfirst(==(1), size(A))
+        if dim1 !== nothing
+            dim = dim1
+        end
+    end
+    if findnext(reddims, dim+1) !== nothing
+        throw(ArgumentError("reducing over more than one dimension is not supported with weights"))
+    end
+    lw = length(w)
+    ldim = size(A, dim)
+    if lw != ldim
+        throw(DimensionMismatch("weights must have the same length as the dimension " *
+                                "over which reduction is performed (got $lw and $ldim)."))
+    end
+    _wsum!(R, A, w, dim, init)
+end
+
+
 ##### findmin & findmax #####
 # The initial values of Rval are not used if the corresponding indices in Rind are 0.
 #
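
The dims-wise weighted sum implemented above is semantically a
broadcast-and-reduce. A minimal reference sketch (the helper name `wsum_ref`
is hypothetical, introduced only for this illustration):

    # Hypothetical reference helper, not defined by the patch:
    # reshape w so it lies along dimension `dim`, broadcast-multiply, then reduce.
    wsum_ref(A, w, dim) =
        sum(A .* reshape(w, ntuple(d -> d == dim ? length(w) : 1, ndims(A))), dims=dim)

    A = rand(3, 4); w = [2.0, 1.0, 3.0]
    @assert sum(A, weights=w, dims=1) ≈ wsum_ref(A, w, 1)
    @assert sum(permutedims(A), weights=w, dims=2) ≈ wsum_ref(permutedims(A), w, 2)
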
diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl
index dee9bd21ee0d6..83e19efe365b4 100644
--- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl
+++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl
@@ -380,6 +380,7 @@ include("bitarray.jl")
 include("ldlt.jl")
 include("schur.jl")
 include("structuredbroadcast.jl")
+include("wsum.jl")
 include("deprecated.jl")
 
 const ⋅ = dot
diff --git a/stdlib/LinearAlgebra/src/wsum.jl b/stdlib/LinearAlgebra/src/wsum.jl
new file mode 100644
index 0000000000000..8b0ecc1478cf7
--- /dev/null
+++ b/stdlib/LinearAlgebra/src/wsum.jl
@@ -0,0 +1,94 @@
+# Optimized method for weighted sum with BlasReal
+# dot cannot be used for other types as it uses + rather than add_sum for accumulation,
+# and therefore does not return the correct type
+Base._sum(A::AbstractArray{T}, dims::Colon, w::AbstractArray{T}) where {T<:BlasReal} =
+    dot(vec(A), vec(w))
+
+# Optimized methods for weighted sum over dimensions with BlasReal
+# (generic method is defined in base/reducedim.jl)
+#
+# _wsum! is specialized for the following cases:
+# (a) A is a dense matrix with eltype <: BlasReal: we call gemv!
+#     The internal function that implements this is _wsum2_blas!
+#
+# (b) A is a contiguous array with eltype <: BlasReal:
+#     dim == 1: treat A like a matrix of size (d1, d2 x ... x dN)
+#     dim == N: treat A like a matrix of size (d1 x ... x d(N-1), dN)
+#     otherwise: decompose A into multiple pages, and apply _wsum2_blas!
+#     for each
+#     The internal function that implements this is _wsumN!
+#
+# (c) A is a general dense array with eltype <: BlasReal:
+#     dim <= 2: delegate to (a) and (b)
+#     otherwise, decompose A into multiple pages
+#     The internal function that implements this is _wsumN!
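# Illustration (assumed example values): the gemv! mapping used below, in
# plain matrix terms. Reducing a matrix over dimension 1 with weights w
# computes transpose(A)*w, hence trans = 'T'; reducing over dimension 2
# computes A*w, hence trans = 'N'.
#
#     A = [1.0 2.0; 3.0 4.0]; w = [10.0, 1.0]
#     vec(sum(A .* w,  dims=1)) == transpose(A) * w   # dim == 1 -> 'T'
#     vec(sum(A .* w', dims=2)) == A * w              # dim == 2 -> 'N'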
+
+function _wsum2_blas!(R::StridedVector{T}, A::StridedMatrix{T}, w::StridedVector{T},
+                      dim::Int, init::Bool) where T<:BlasReal
+    beta = ifelse(init, zero(T), one(T))
+    trans = dim == 1 ? 'T' : 'N'
+    BLAS.gemv!(trans, one(T), A, w, beta, R)
+    return R
+end
+
+function _wsumN!(R::StridedArray{T}, A::StridedArray{T,N}, w::StridedVector{T},
+                 dim::Int, init::Bool) where {T<:BlasReal,N}
+    if dim == 1
+        m = size(A, 1)
+        n = div(length(A), m)
+        _wsum2_blas!(view(R,:), reshape(A, (m, n)), w, 1, init)
+    elseif dim == N
+        n = size(A, N)
+        m = div(length(A), n)
+        _wsum2_blas!(view(R,:), reshape(A, (m, n)), w, 2, init)
+    else # 1 < dim < N
+        m = 1
+        for i = 1:dim-1
+            m *= size(A, i)
+        end
+        n = size(A, dim)
+        k = 1
+        for i = dim+1:N
+            k *= size(A, i)
+        end
+        Av = reshape(A, (m, n, k))
+        Rv = reshape(R, (m, k))
+        for i = 1:k
+            _wsum2_blas!(view(Rv,:,i), view(Av,:,:,i), w, 2, init)
+        end
+    end
+    return R
+end
+
+function _wsumN!(R::StridedArray{T}, A::DenseArray{T,N}, w::StridedVector{T},
+                 dim::Int, init::Bool) where {T<:BlasReal,N}
+    @assert N >= 3
+    if dim <= 2
+        m = size(A, 1)
+        n = size(A, 2)
+        npages = 1
+        for i = 3:N
+            npages *= size(A, i)
+        end
+        rlen = ifelse(dim == 1, n, m)
+        Rv = reshape(R, (rlen, npages))
+        for i = 1:npages
+            _wsum2_blas!(view(Rv,:,i), view(A,:,:,i), w, dim, init)
+        end
+    else
+        Base._wsum_general!(R, A, w, dim, init)
+    end
+    return R
+end
+
+Base._wsum!(R::StridedArray{T}, A::DenseMatrix{T}, w::StridedVector{T},
+            dim::Int, init::Bool) where {T<:BlasReal} =
+    _wsum2_blas!(view(R,:), A, w, dim, init)
+
+Base._wsum!(R::StridedArray{T}, A::DenseArray{T}, w::StridedVector{T},
+            dim::Int, init::Bool) where {T<:BlasReal} =
+    _wsumN!(R, A, w, dim, init)
+
+Base._wsum!(R::StridedVector{T}, A::DenseArray{T}, w::StridedVector{T},
+            dim::Int, init::Bool) where {T<:BlasReal} =
+    Base._wsum1!(R, A, w, init)
\ No newline at end of file
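
The paging strategy of `_wsumN!` above can be checked against the broadcast
definition; a minimal sketch (example values only), reducing a 3-d array over
its middle dimension page by page:

    A = rand(2, 3, 4); w = rand(3)
    R = similar(A, 2, 1, 4)
    for i in 1:4
        R[:, 1, i] = @view(A[:, :, i]) * w   # per-page A*w, the trans = 'N' case
    end
    @assert R ≈ sum(A .* reshape(w, 1, :, 1), dims=2)
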
diff --git a/test/reduce.jl b/test/reduce.jl
index eb585e8a630f1..9b89e8f6741b6 100644
--- a/test/reduce.jl
+++ b/test/reduce.jl
@@ -533,3 +533,38 @@ x = [j^2 for j in i]
 i = Base.Slice(0:0)
 x = [j+7 for j in i]
 @test sum(x) == 7
+
+@testset "weighted sum" begin
+    wts = ([1.4, 2.5, 10.1], [1.4f0, 2.5f0, 10.1f0], [0.0, 2.3, 5.6],
+           [NaN, 2.3, 5.6], [Inf, 2.3, 5.6],
+           [2, 1, 3], Int8[1, 2, 3], [1, 1, 1])
+    for a in (rand(3), rand(Int, 3), rand(Int8, 3))
+        for w in wts
+            res = @inferred sum(a, weights=w)
+            expected = sum(a.*w)
+            if isfinite(res)
+                @test res ≈ expected
+            else
+                @test isequal(res, expected)
+            end
+            @test typeof(res) == typeof(expected)
+        end
+    end
+    for a in (rand(3, 5), rand(Float32, 3, 5), rand(Int, 3, 5), rand(Int8, 3, 5))
+        for w in wts
+            wr = repeat(w, outer=(1, 5))
+            res = @inferred sum(a, weights=wr)
+            expected = sum(a.*wr)
+            if isfinite(res)
+                @test res ≈ expected
+            else
+                @test isequal(res, expected)
+            end
+            @test typeof(res) == typeof(expected)
+        end
+    end
+
+    @test_throws ArgumentError sum(exp, [1], weights=[1])
+    @test_throws ArgumentError sum!(exp, [0 0], [1 2], weights=[1, 10])
+    @test_throws ArgumentError sum!([0 0], [1 2], weights=[1 10])
+end
\ No newline at end of file
diff --git a/test/reducedim.jl b/test/reducedim.jl
index 2a39b706d2a73..c8c030d3704ce 100644
--- a/test/reducedim.jl
+++ b/test/reducedim.jl
@@ -398,3 +398,93 @@ end
     @test_throws DimensionMismatch maximum!(fill(0, 1, 1, 2, 1), B)
     @test_throws DimensionMismatch minimum!(fill(0, 1, 1, 2, 1), B)
 end
+
+@testset "weighted sum over dimensions" begin
+    wts = ([1.4, 2.5, 10.1], [1.4f0, 2.5f0, 10.1f0], [0.0, 2.3, 5.6],
+           [NaN, 2.3, 5.6], [Inf, 2.3, 5.6],
+           [2, 1, 3], Int8[1, 2, 3], [1, 1, 1])
+
+    ainf = rand(3)
+    ainf[1] = Inf
+    anan = rand(3)
+    anan[1] = NaN
+    for a in (rand(3), rand(Float32, 3), ainf, anan,
+              rand(Int, 3), rand(Int8, 3),
+              view(rand(5), 2:4))
+        for w in wts
+            if all(isfinite, a) && all(isfinite, w)
+                expected = sum(a.*w, dims=1)
+                res = @inferred sum(a, weights=w, dims=1)
+                @test res ≈ expected
+                @test typeof(res) == typeof(expected)
+                x = rand!(similar(expected))
+                y = copy(x)
+                @inferred sum!(y, a, weights=w)
+                @test y ≈ expected
+                y = copy(x)
+                @inferred sum!(y, a, weights=w, init=false)
+                @test y ≈ x + expected
+            else
+                expected = sum(a.*w, dims=1)
+                res = @inferred sum(a, weights=w, dims=1)
+                @test isfinite.(res) == isfinite.(expected)
+                @test typeof(res) == typeof(expected)
+                x = rand!(similar(expected))
+                y = copy(x)
+                @inferred sum!(y, a, weights=w)
+                @test isfinite.(y) == isfinite.(expected)
+                y = copy(x)
+                @inferred sum!(y, a, weights=w, init=false)
+                @test isfinite.(y) == isfinite.(expected)
+            end
+        end
+    end
+
+    ainf = rand(3, 3, 3)
+    ainf[1] = Inf
+    anan = rand(3, 3, 3)
+    anan[1] = NaN
+    for a in (rand(3, 3, 3), rand(Float32, 3, 3, 3), ainf, anan,
+              rand(Int, 3, 3, 3), rand(Int8, 3, 3, 3),
+              view(rand(3, 3, 5), :, :, 2:4))
+        for w in wts
+            for (d, rw) in ((1, reshape(w, :, 1, 1)),
+                            (2, reshape(w, 1, :, 1)),
+                            (3, reshape(w, 1, 1, :)))
+                if all(isfinite, a) && all(isfinite, w)
+                    expected = sum(a.*rw, dims=d)
+                    res = @inferred sum(a, weights=w, dims=d)
+                    @test res ≈ expected
+                    @test typeof(res) == typeof(expected)
+                    x = rand!(similar(expected))
+                    y = copy(x)
+                    @inferred sum!(y, a, weights=w)
+                    @test y ≈ expected
+                    y = copy(x)
+                    @inferred sum!(y, a, weights=w, init=false)
+                    @test y ≈ x + expected
+                else
+                    expected = sum(a.*rw, dims=d)
+                    res = @inferred sum(a, weights=w, dims=d)
+                    @test isfinite.(res) == isfinite.(expected)
+                    @test typeof(res) == typeof(expected)
+                    x = rand!(similar(expected))
+                    y = copy(x)
+                    @inferred sum!(y, a, weights=w)
+                    @test isfinite.(y) == isfinite.(expected)
+                    y = copy(x)
+                    @inferred sum!(y, a, weights=w, init=false)
+                    @test isfinite.(y) == isfinite.(expected)
+                end
+            end
+
+            @test_throws DimensionMismatch sum(a, weights=w, dims=4)
+        end
+    end
+
+    # Corner case with a single row
+    @test sum([1 2], weights=[2], dims=1) == [2 4]
+
+    @test_throws ArgumentError sum(exp, [1 2], weights=[1, 10], dims=1)
+    @test_throws ArgumentError sum([1 2], weights=[1 10], dims=1)
+end
\ No newline at end of file
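
As exercised by the `init=false` branches in the testsets above, `sum!` either
overwrites the destination or accumulates into it; a short usage sketch
(example values only):

    A = [1 2; 3 4]
    r = zeros(Int, 1, 2)
    sum!(r, A, weights=[2, 1])               # overwrites: r == [5 8]
    sum!(r, A, weights=[2, 1], init=false)   # accumulates: r == [10 16]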