Skip to content

Commit

Permalink
Merge branch 'master' into nl/revert
Browse files Browse the repository at this point in the history
  • Loading branch information
nalimilan authored Sep 9, 2023
2 parents f91df38 + 81a90af commit 5c00836
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 96 deletions.
9 changes: 8 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,16 @@ julia = "1.9"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[extensions]
SparseArraysExt = ["SparseArrays"]

[extras]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Random", "Test"]
test = ["Random", "SparseArrays", "Test"]
101 changes: 101 additions & 0 deletions ext/SparseArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
module SparseArraysExt

##### SparseArrays optimizations #####

using Base: require_one_based_indexing
using LinearAlgebra
using SparseArrays
using Statistics
using Statistics: centralize_sumabs2, unscaled_covzm

# extended functions
import Statistics: cov, centralize_sumabs2!

# Covariance matrix of a sparse matrix, computed from the sparse product X'X (or X*X')
# instead of from a centered (and therefore dense) copy of `X`.
#
# Keywords:
# - `dims`: the dimension along which observations run; the result is `p × p`, where
#   `p` is the size of `X` along the other dimension.
# - `corrected`: when `true` the sum is scaled by `1/(n - 1)`, otherwise by `1/n`.
#
# Returns a dense matrix. Its element type is the promotion of the element types of the
# sparse product and of the mean, so integer-valued input yields a floating-point result
# rather than failing when fractional values are written back.
function cov(X::SparseMatrixCSC; dims::Int=1, corrected::Bool=true)
    vardim = dims
    a, b = size(X)
    n, p = vardim == 1 ? (a, b) : (b, a)

    # The covariance can be decomposed into two terms
    # 1/(n - 1) ∑ (x_i - x̄)*(x_i - x̄)' = 1/(n - 1) (∑ x_i*x_i' - n*x̄*x̄')
    # which can be evaluated via a sparse matrix-matrix product

    # Compute ∑ x_i*x_i' = X'X using sparse matrix-matrix product
    xtx = unscaled_covzm(X, vardim)

    # Compute x̄
    x̄ᵀ = mean(X, dims=vardim)

    # Densify into an element type that can hold the centered values. Using the element
    # type of X'X alone breaks for integer input: the mean is fractional, and storing
    # the difference into an integer matrix throws InexactError.
    T = promote_type(eltype(xtx), eltype(x̄ᵀ))
    out = Matrix{T}(xtx)

    # Subtract n*x̄*x̄' from X'X
    @inbounds for j in 1:p, i in 1:p
        out[i, j] -= x̄ᵀ[i] * x̄ᵀ[j]' * n
    end

    # scale with the sample size n or the corrected sample size n - 1
    return rmul!(out, inv(n - corrected))
end

# This is the function that does the reduction underlying var/std
#
# Fills `R` with sums of `abs2(A[i, j] - mean)` over the dimensions of the sparse matrix
# `A` that are reduced (size 1) in `R`. `means` must have the same axes as `R`.
# Entries of `A` that are not stored are accounted for as zeros, i.e. each contributes
# `abs2(mean)`, without visiting them one by one where that can be avoided.
#
# Throws `DimensionMismatch` when the axes of `means` and `R` differ, and whatever
# `Base.check_reducedims` throws when `R` is not a valid reduction target for `A`.
# Returns `R`.
function centralize_sumabs2!(R::AbstractArray{S}, A::SparseMatrixCSC{Tv,Ti}, means::AbstractArray) where {S,Tv,Ti}
    require_one_based_indexing(R, A, means)
    # Called for its validation only; the value it returns is not needed here.
    Base.check_reducedims(R, A)
    for i in 1:max(ndims(R), ndims(means))
        if axes(means, i) != axes(R, i)
            throw(DimensionMismatch("dimension $i of `mean` should have indices $(axes(R, i)), but got $(axes(means, i))"))
        end
    end
    isempty(R) || fill!(R, zero(S))
    isempty(A) && return R

    rowval = rowvals(A)
    nzval = nonzeros(A)
    m = size(A, 1)
    n = size(A, 2)

    if size(R, 1) == size(R, 2) == 1
        # Reduction along both columns and rows
        R[1, 1] = centralize_sumabs2(A, means[1])
    elseif size(R, 1) == 1
        # Reduction along rows
        @inbounds for col = 1:n
            mu = means[col]
            # Start from the contribution of the entries of this column that are not
            # stored, then add the stored ones.
            r = convert(S, (m - length(nzrange(A, col)))*abs2(mu))
            @simd for j = nzrange(A, col)
                r += abs2(nzval[j] - mu)
            end
            R[1, col] = r
        end
    elseif size(R, 2) == 1
        # Reduction along columns
        # rownz[i] counts down from n to the number of non-stored entries in row i
        rownz = fill(convert(Ti, n), m)
        @inbounds for col = 1:n
            @simd for j = nzrange(A, col)
                row = rowval[j]
                R[row, 1] += abs2(nzval[j] - means[row])
                rownz[row] -= 1
            end
        end
        for i = 1:m
            R[i, 1] += rownz[i]*abs2(means[i])
        end
    else
        # Reduction along a dimension > 2
        # Neither rows nor columns are reduced, so every position of R is written once.
        @inbounds for col = 1:n
            lastrow = 0
            # No `@simd` on this loop: it contains an inner loop and each iteration
            # depends on `lastrow` from the previous one, so iterations cannot be
            # reordered. It relies on stored row indices being increasing in a column.
            for j = nzrange(A, col)
                row = rowval[j]
                # Rows skipped between two stored entries hold zeros
                for i = lastrow+1:row-1
                    R[i, col] = abs2(means[i, col])
                end
                R[row, col] = abs2(nzval[j] - means[row, col])
                lastrow = row
            end
            # Trailing rows after the last stored entry of the column
            for i = lastrow+1:m
                R[i, col] = abs2(means[i, col])
            end
        end
    end
    return R
end

end # module
103 changes: 11 additions & 92 deletions src/Statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Standard library module for basic statistics functionality.
"""
module Statistics

using LinearAlgebra, SparseArrays
using LinearAlgebra

using Base: has_offset_axes, require_one_based_indexing

Expand Down Expand Up @@ -106,9 +106,13 @@ mean(f, A::AbstractArray; dims=:) = _mean(f, A, dims)
function mean(f::Number, itr::Number)
f_value = try
f(itr)
catch MethodError
rethrow(ArgumentError("""mean(f, itr) requires a function and an iterable.
Perhaps you meant middle(x, y)?""",))
catch err
if err isa MethodError && err.f === f && err.args == (itr,)
rethrow(ArgumentError("""mean(f, itr) requires a function and an iterable.
Perhaps you meant mean((x, y))?"""))
else
rethrow(err)
end
end
Base.reduce_first(+, f_value)/1
end
Expand Down Expand Up @@ -1089,94 +1093,9 @@ quantile(itr, p; sorted::Bool=false, alpha::Real=1.0, beta::Real=alpha) =
quantile(v::AbstractVector, p; sorted::Bool=false, alpha::Real=1.0, beta::Real=alpha) =
quantile!(sorted ? v : Base.copymutable(v), p; sorted=sorted, alpha=alpha, beta=beta)


##### SparseArrays optimizations #####

# Covariance matrix of a sparse matrix, computed from the sparse product X'X (or X*X')
# instead of from a centered (and therefore dense) copy of `X`.
#
# Keywords:
# - `dims`: the dimension along which observations run; the result is `p × p`, where
#   `p` is the size of `X` along the other dimension.
# - `corrected`: when `true` the sum is scaled by `1/(n - 1)`, otherwise by `1/n`.
#
# Returns a dense matrix. Its element type is the promotion of the element types of the
# sparse product and of the mean, so integer-valued input yields a floating-point result
# rather than failing when fractional values are written back.
function cov(X::SparseMatrixCSC; dims::Int=1, corrected::Bool=true)
    vardim = dims
    a, b = size(X)
    n, p = vardim == 1 ? (a, b) : (b, a)

    # The covariance can be decomposed into two terms
    # 1/(n - 1) ∑ (x_i - x̄)*(x_i - x̄)' = 1/(n - 1) (∑ x_i*x_i' - n*x̄*x̄')
    # which can be evaluated via a sparse matrix-matrix product

    # Compute ∑ x_i*x_i' = X'X using sparse matrix-matrix product
    xtx = unscaled_covzm(X, vardim)

    # Compute x̄
    x̄ᵀ = mean(X, dims=vardim)

    # Densify into an element type that can hold the centered values. Using the element
    # type of X'X alone breaks for integer input: the mean is fractional, and storing
    # the difference into an integer matrix throws InexactError.
    T = promote_type(eltype(xtx), eltype(x̄ᵀ))
    out = Matrix{T}(xtx)

    # Subtract n*x̄*x̄' from X'X
    @inbounds for j in 1:p, i in 1:p
        out[i, j] -= x̄ᵀ[i] * x̄ᵀ[j]' * n
    end

    # scale with the sample size n or the corrected sample size n - 1
    return rmul!(out, inv(n - corrected))
end

# This is the function that does the reduction underlying var/std
function centralize_sumabs2!(R::AbstractArray{S}, A::SparseMatrixCSC{Tv,Ti}, means::AbstractArray) where {S,Tv,Ti}
require_one_based_indexing(R, A, means)
lsiz = Base.check_reducedims(R,A)
for i in 1:max(ndims(R), ndims(means))
if axes(means, i) != axes(R, i)
throw(DimensionMismatch("dimension $i of `mean` should have indices $(axes(R, i)), but got $(axes(means, i))"))
end
end
isempty(R) || fill!(R, zero(S))
isempty(A) && return R

rowval = rowvals(A)
nzval = nonzeros(A)
m = size(A, 1)
n = size(A, 2)

if size(R, 1) == size(R, 2) == 1
# Reduction along both columns and rows
R[1, 1] = centralize_sumabs2(A, means[1])
elseif size(R, 1) == 1
# Reduction along rows
@inbounds for col = 1:n
mu = means[col]
r = convert(S, (m - length(nzrange(A, col)))*abs2(mu))
@simd for j = nzrange(A, col)
r += abs2(nzval[j] - mu)
end
R[1, col] = r
end
elseif size(R, 2) == 1
# Reduction along columns
rownz = fill(convert(Ti, n), m)
@inbounds for col = 1:n
@simd for j = nzrange(A, col)
row = rowval[j]
R[row, 1] += abs2(nzval[j] - means[row])
rownz[row] -= 1
end
end
for i = 1:m
R[i, 1] += rownz[i]*abs2(means[i])
end
else
# Reduction along a dimension > 2
@inbounds for col = 1:n
lastrow = 0
@simd for j = nzrange(A, col)
row = rowval[j]
for i = lastrow+1:row-1
R[i, col] = abs2(means[i, col])
end
R[row, col] = abs2(nzval[j] - means[row, col])
lastrow = row
end
for i = lastrow+1:m
R[i, col] = abs2(means[i, col])
end
end
end
return R
# Julia versions without `Base.get_extension` have no package-extension mechanism,
# so on those the SparseArrays methods are included directly into this module.
isdefined(Base, :get_extension) || include("../ext/SparseArraysExt.jl")

end # module
12 changes: 9 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,18 @@ end
@test isnan(@inferred mean(Iterators.filter(x -> true, Float64[])))

# using a number as a "function"
@test_throws "ArgumentError: mean(f, itr) requires a function and an iterable.\nPerhaps you meant middle(x, y)" mean(1, 2)
@test_throws "ArgumentError: mean(f, itr) requires a function and an iterable.\nPerhaps you meant mean((x, y))" mean(1, 2)
struct T <: Number
x::Int
end
(t::T)(y) = t.x * y
@test @inferred mean(T(2), 3) === 6.0
(t::T)(y) = t.x == 0 ? t(y, y + 1, y + 2) : t.x * y
@test mean(T(2), 3) === 6.0
@test_throws MethodError mean(T(0), 3)
struct U <: Number
x::Int
end
(t::U)(y) = t.x == 0 ? throw(MethodError(T)) : t.x * y
@test @inferred mean(U(2), 3) === 6.0
end

@testset "mean/median for ranges" begin
Expand Down

0 comments on commit 5c00836

Please sign in to comment.