Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
nalimilan committed May 8, 2019
1 parent d627393 commit 7f364f4
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 5 deletions.
13 changes: 10 additions & 3 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -724,9 +724,12 @@ sum!(f::Function, r::AbstractArray, A::AbstractArray;
_sum!(f, r, A, weights; init=init)
_sum!(f, r::AbstractArray, A::AbstractArray, ::Nothing; init::Bool=true) =
mapreducedim!(f, add_sum, initarray!(r, add_sum, init, A), A)
_sum(f, A, dims, ::Nothing) = mapreduce(f, add_sum, A, dims=dims)
_sum(A::AbstractArray, dims, w::AbstractArray) =
_sum(A::AbstractArray, dims, weights) = _sum(identity, A, dims, weights)
_sum(f, A::AbstractArray, dims, ::Nothing) = mapreduce(f, add_sum, A, dims=dims)
_sum(::typeof(identity), A::AbstractArray, dims, w::AbstractArray) =
_sum!(identity, reducedim_init(t -> t*zero(eltype(w)), add_sum, A, dims), A, w)
_sum(f, A::AbstractArray, dims, w::AbstractArray) =
throw(ArgumentError("Passing a function is not supported with `weights`"))


# Weighted sum
Expand Down Expand Up @@ -832,8 +835,11 @@ _wsum!(R::AbstractArray, A::AbstractArray, w::AbstractVector,
dim::Int, init::Bool) =
_wsum_general!(R, A, w, dim, init)

function _sum!(::typeof(identity), R::AbstractArray, A::AbstractArray{T,N}, w::AbstractVector;
function _sum!(f, R::AbstractArray, A::AbstractArray{T,N}, w::AbstractArray;
init::Bool=true) where {T,N}
f === identity || throw(ArgumentError("Passing a function is not supported with `weights`"))
w isa AbstractVector || throw(ArgumentError("Only vector `weights` are supported"))

check_reducedims(R,A)
reddims = size(R) .!= size(A)
dim = something(findfirst(reddims), ndims(R)+1)
Expand All @@ -849,6 +855,7 @@ function _sum!(::typeof(identity), R::AbstractArray, A::AbstractArray{T,N}, w::A
_wsum!(R, A, w, dim, init)
end


##### findmin & findmax #####
# The initial values of Rval are not used if the corresponding indices in Rind are 0.
#
Expand Down
4 changes: 2 additions & 2 deletions stdlib/Statistics/src/Statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,14 +195,14 @@ function _mean(r::AbstractRange{<:Real}, dims::Colon, weights::Nothing)
end

_mean(A::AbstractArray, dims, weights::Nothing) =
_mean!(Base.reducedim_init(t -> t/2, +, A, dims), A, nothing)
_mean!(Base.reducedim_init(t -> t/2, Base.add_sum, A, dims), A, nothing)
_mean(A::AbstractArray, dims::Colon, weights::Nothing) = sum(A) / length(A)

_mean(A::AbstractArray, dims::Colon, w::AbstractArray) =
sum(A, weights=w) / sum(w)

_mean(A::AbstractArray, dims, w::AbstractArray) =
_mean!(Base.reducedim_init(t -> (t*zero(eltype(w)))/2, +, A, dims), A, w)
_mean!(Base.reducedim_init(t -> (t*zero(eltype(w)))/2, Base.add_sum, A, dims), A, w)

##### variances #####

Expand Down
4 changes: 4 additions & 0 deletions test/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -563,4 +563,8 @@ x = [j+7 for j in i]
@test typeof(res) == typeof(expected)
end
end

@test_throws ArgumentError sum(exp, [1], weights=[1])
@test_throws ArgumentError sum!(exp, [0 0], [1 2], weights=[1, 10])
@test_throws ArgumentError sum!([0 0], [1 2], weights=[1 10])
end
3 changes: 3 additions & 0 deletions test/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -481,4 +481,7 @@ end
@test_throws DimensionMismatch sum(a, weights=w, dims=4)
end
end

@test_throws ArgumentError sum(exp, [1 2], weights=[1, 10], dims=1)
@test_throws ArgumentError sum([1 2], weights=[1 10], dims=1)
end

0 comments on commit 7f364f4

Please sign in to comment.