Skip to content

Commit

Permalink
make SparseArrays a weak dependency (#134)
Browse files Browse the repository at this point in the history
  • Loading branch information
IanButterworth authored Sep 1, 2023
1 parent 04e5d89 commit 81a90af
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 90 deletions.
9 changes: 8 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,16 @@ julia = "1.9"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[extensions]
SparseArraysExt = ["SparseArrays"]

[extras]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Random", "Test"]
test = ["Random", "SparseArrays", "Test"]
101 changes: 101 additions & 0 deletions ext/SparseArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
module SparseArraysExt

##### SparseArrays optimizations #####

using Base: require_one_based_indexing
using LinearAlgebra
using SparseArrays
using Statistics
using Statistics: centralize_sumabs2, unscaled_covzm

# extended functions
import Statistics: cov, centralize_sumabs2!

function cov(X::SparseMatrixCSC; dims::Int=1, corrected::Bool=true)
vardim = dims
a, b = size(X)
n, p = vardim == 1 ? (a, b) : (b, a)

# The covariance can be decomposed into two terms
# 1/(n - 1) ∑ (x_i - x̄)*(x_i - x̄)' = 1/(n - 1) (∑ x_i*x_i' - n*x̄*x̄')
# which can be evaluated via a sparse matrix-matrix product

# Compute ∑ x_i*x_i' = X'X using sparse matrix-matrix product
out = Matrix(unscaled_covzm(X, vardim))

# Compute x̄
x̄ᵀ = mean(X, dims=vardim)

# Subtract n*x̄*x̄' from X'X
@inbounds for j in 1:p, i in 1:p
out[i,j] -= x̄ᵀ[i] * x̄ᵀ[j]' * n
end

# scale with the sample size n or the corrected sample size n - 1
return rmul!(out, inv(n - corrected))
end

# This is the function that does the reduction underlying var/std
function centralize_sumabs2!(R::AbstractArray{S}, A::SparseMatrixCSC{Tv,Ti}, means::AbstractArray) where {S,Tv,Ti}
require_one_based_indexing(R, A, means)
lsiz = Base.check_reducedims(R,A)
for i in 1:max(ndims(R), ndims(means))
if axes(means, i) != axes(R, i)
throw(DimensionMismatch("dimension $i of `mean` should have indices $(axes(R, i)), but got $(axes(means, i))"))
end
end
isempty(R) || fill!(R, zero(S))
isempty(A) && return R

rowval = rowvals(A)
nzval = nonzeros(A)
m = size(A, 1)
n = size(A, 2)

if size(R, 1) == size(R, 2) == 1
# Reduction along both columns and rows
R[1, 1] = centralize_sumabs2(A, means[1])
elseif size(R, 1) == 1
# Reduction along rows
@inbounds for col = 1:n
mu = means[col]
r = convert(S, (m - length(nzrange(A, col)))*abs2(mu))
@simd for j = nzrange(A, col)
r += abs2(nzval[j] - mu)
end
R[1, col] = r
end
elseif size(R, 2) == 1
# Reduction along columns
rownz = fill(convert(Ti, n), m)
@inbounds for col = 1:n
@simd for j = nzrange(A, col)
row = rowval[j]
R[row, 1] += abs2(nzval[j] - means[row])
rownz[row] -= 1
end
end
for i = 1:m
R[i, 1] += rownz[i]*abs2(means[i])
end
else
# Reduction along a dimension > 2
@inbounds for col = 1:n
lastrow = 0
@simd for j = nzrange(A, col)
row = rowval[j]
for i = lastrow+1:row-1
R[i, col] = abs2(means[i, col])
end
R[row, col] = abs2(nzval[j] - means[row, col])
lastrow = row
end
for i = lastrow+1:m
R[i, col] = abs2(means[i, col])
end
end
end
return R
end

end # module
93 changes: 4 additions & 89 deletions src/Statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Standard library module for basic statistics functionality.
"""
module Statistics

using LinearAlgebra, SparseArrays
using LinearAlgebra

using Base: has_offset_axes, require_one_based_indexing

Expand Down Expand Up @@ -1095,94 +1095,9 @@ quantile(itr, p; sorted::Bool=false, alpha::Real=1.0, beta::Real=alpha) =
quantile(v::AbstractVector, p; sorted::Bool=false, alpha::Real=1.0, beta::Real=alpha) =
quantile!(sorted ? v : Base.copymutable(v), p; sorted=sorted, alpha=alpha, beta=beta)


##### SparseArrays optimizations #####

function cov(X::SparseMatrixCSC; dims::Int=1, corrected::Bool=true)
vardim = dims
a, b = size(X)
n, p = vardim == 1 ? (a, b) : (b, a)

# The covariance can be decomposed into two terms
# 1/(n - 1) ∑ (x_i - x̄)*(x_i - x̄)' = 1/(n - 1) (∑ x_i*x_i' - n*x̄*x̄')
# which can be evaluated via a sparse matrix-matrix product

# Compute ∑ x_i*x_i' = X'X using sparse matrix-matrix product
out = Matrix(unscaled_covzm(X, vardim))

# Compute x̄
x̄ᵀ = mean(X, dims=vardim)

# Subtract n*x̄*x̄' from X'X
@inbounds for j in 1:p, i in 1:p
out[i,j] -= x̄ᵀ[i] * x̄ᵀ[j]' * n
end

# scale with the sample size n or the corrected sample size n - 1
return rmul!(out, inv(n - corrected))
end

# This is the function that does the reduction underlying var/std
function centralize_sumabs2!(R::AbstractArray{S}, A::SparseMatrixCSC{Tv,Ti}, means::AbstractArray) where {S,Tv,Ti}
require_one_based_indexing(R, A, means)
lsiz = Base.check_reducedims(R,A)
for i in 1:max(ndims(R), ndims(means))
if axes(means, i) != axes(R, i)
throw(DimensionMismatch("dimension $i of `mean` should have indices $(axes(R, i)), but got $(axes(means, i))"))
end
end
isempty(R) || fill!(R, zero(S))
isempty(A) && return R

rowval = rowvals(A)
nzval = nonzeros(A)
m = size(A, 1)
n = size(A, 2)

if size(R, 1) == size(R, 2) == 1
# Reduction along both columns and rows
R[1, 1] = centralize_sumabs2(A, means[1])
elseif size(R, 1) == 1
# Reduction along rows
@inbounds for col = 1:n
mu = means[col]
r = convert(S, (m - length(nzrange(A, col)))*abs2(mu))
@simd for j = nzrange(A, col)
r += abs2(nzval[j] - mu)
end
R[1, col] = r
end
elseif size(R, 2) == 1
# Reduction along columns
rownz = fill(convert(Ti, n), m)
@inbounds for col = 1:n
@simd for j = nzrange(A, col)
row = rowval[j]
R[row, 1] += abs2(nzval[j] - means[row])
rownz[row] -= 1
end
end
for i = 1:m
R[i, 1] += rownz[i]*abs2(means[i])
end
else
# Reduction along a dimension > 2
@inbounds for col = 1:n
lastrow = 0
@simd for j = nzrange(A, col)
row = rowval[j]
for i = lastrow+1:row-1
R[i, col] = abs2(means[i, col])
end
R[row, col] = abs2(nzval[j] - means[row, col])
lastrow = row
end
for i = lastrow+1:m
R[i, col] = abs2(means[i, col])
end
end
end
return R
# If package extensions are not supported in this Julia version
if !isdefined(Base, :get_extension)
include("../ext/SparseArraysExt.jl")
end

end # module

0 comments on commit 81a90af

Please sign in to comment.