Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance of weighted sum #778

Merged
merged 2 commits into from
Mar 31, 2022
Merged

Improve performance of weighted sum #778

merged 2 commits into from
Mar 31, 2022

Conversation

nalimilan
Copy link
Member

The current code is calling the AbstractArray matrix multiplication fallback, which is slower than BLAS.

Fixes #775.

julia> using BenchmarkTools

julia> using StatsBase

julia> x = rand(1000); w = weights(rand(1000));

# master
julia> @btime wsum(x, w);
  1.029 μs (1 allocation: 16 bytes)

# This PR
julia> @btime wsum(x, w);
  97.572 ns (1 allocation: 16 bytes)

Tests appear to cover all weights types and AbstractVector already.

The current code is calling the `AbstractArray` matrix multiplication fallback,
which is slower than BLAS.
@rofinn
Copy link
Member

rofinn commented Mar 29, 2022

Seems like a reasonable extension in this case. Looks like there isn't a test for the UnitWeights dispatch.

@nalimilan
Copy link
Member Author

AFAICT unit weights are tested here, right?

@test sum([1.0, 2.0, 3.0], wt) 6.0

@mschauer
Copy link
Member

The CI error seems to be related to sampling?

* wsampling.jl ...
[172](https://github.com/JuliaStats/StatsBase.jl/runs/5703799652?check_suite_focus=true#step:6:172)
┌ Warning: Assignment to `a` in soft scope is ambiguous because a global variable by the same name exists: `a` will be treated as a new local. Disambiguate by using `local a` to suppress this warning or `global a` to assign to the existing global variable.
[173](https://github.com/JuliaStats/StatsBase.jl/runs/5703799652?check_suite_focus=true#step:6:173)
└ @ ~/work/StatsBase.jl/StatsBase.jl/test/wsampling.jl:47
[174](https://github.com/JuliaStats/StatsBase.jl/runs/5703799652?check_suite_focus=true#step:6:174)
ERROR: LoadError: MethodError: no method matching direct_sample!(::UnitRange{Int64}, ::Vector{Float64}, ::Matrix{Int64})
[175](https://github.com/JuliaStats/StatsBase.jl/runs/5703799652?check_suite_focus=true#step:6:175)
Closest candidates are:
[176](https://github.com/JuliaStats/StatsBase.jl/runs/5703799652?check_suite_focus=true#step:6:176)
  direct_sample!(::UnitRange, ::AbstractArray) at ~/work/StatsBase.jl/StatsBase.jl/src/sampling.jl:26
[177](https://github.com/JuliaStats/StatsBase.jl/runs/5703799652?check_suite_focus=true#step:6:177)
  direct_sample!(::AbstractRNG, ::AbstractArray, ::AbstractArray) at ~/work/StatsBase.jl/StatsBase.jl/src/sampling.jl:36
[178](https://github.com/JuliaStats/StatsBase.jl/runs/5703799652?check_suite_focus=true#step:6:178)
  direct_sample!(::AbstractArray, ::AbstractWeights, ::AbstractArray) at ~/work/StatsBase.jl/StatsBase.jl/src/sampling.jl:584
[179](https://github.com/JuliaStats/StatsBase.jl/runs/5703799652?check_suite_focus=true#step:6:179)
  ...
[180](https://github.com/JuliaStats/StatsBase.jl/runs/5703799652?check_suite_focus=true#step:6:180)
Stacktrace:
[181](https://github.com/JuliaStats/StatsBase.jl/runs/5703799652?check_suite_focus=true#step:6:181)
 [1] top-level scope
[182](https://github.com/JuliaStats/StatsBase.jl/runs/5703799652?check_suite_focus=true#step:6:182)
   @ ~/work/StatsBase.jl/StatsBase.jl/test/wsampling.jl:47
[183](https://github.com/JuliaStats/StatsBase.jl/runs/5703799652?check_suite_focus=true#step:6:183)
 [2] include(fname::String)
[184](https://github.com/JuliaStats/StatsBase.jl/runs/5703799652?check_suite_focus=true#step:6:184)
   @ Base.MainInclude ./client.jl:476
[185](https://github.com/JuliaStats/StatsBase.jl/runs/5703799652?check_suite_focus=true#step:6:185)
 [3] top-level scope
[186](https://github.com/JuliaStats/StatsBase.jl/runs/5703799652?check_suite_focus=true#step:6:186)
   @ ~/work/StatsBase.jl/StatsBase.jl/test/runtests.jl:34
[187](https://github.com/JuliaStats/StatsBase.jl/runs/5703799652?check_suite_focus=true#step:6:187)
 [4] include(fname::String)
[188](https://github.com/JuliaStats/StatsBase.jl/runs/5703799652?check_suite_focus=true#step:6:188)
   @ Base.MainInclude ./client.jl:476
[189](https://github.com/JuliaStats/StatsBase.jl/runs/5703799652?check_suite_focus=true#step:6:189)
 [5] top-level scope
[190](https://github.com/JuliaStats/StatsBase.jl/runs/5703799652?check_suite_focus=true#step:6:190)
   @ none:6
[191](https://github.com/JuliaStats/StatsBase.jl/runs/5703799652?check_suite_focus=true#step:6:191)
in expression starting at /home/runner/work/StatsBase.jl/StatsBase.jl/test/wsampling.jl:40
[192](https://github.com/JuliaStats/StatsBase.jl/runs/5703799652?check_suite_focus=true#step:6:192)
in expression starting at /home/runner/work/StatsBase.jl/StatsBase.jl/test/runtests.jl:31
[193](https://github.com/JuliaStats/StatsBase.jl/runs/5703799652?check_suite_focus=true#step:6:193)
ERROR: Package StatsBase errored during testing

@nalimilan
Copy link
Member Author

Yeah these also happen on master: https://github.com/JuliaStats/StatsBase.jl/actions/runs/2044798742

@rofinn
Copy link
Member

rofinn commented Mar 30, 2022

AFAICT unit weights are tested here, right?

Hmm, codecov says your new dispatch for wsum at line 391 isn't getting hit though :/

@rofinn
Copy link
Member

rofinn commented Mar 30, 2022

Alright, I figured out why that method wasn't being hit. The specific line you posted calls a sum dispatch which never calls the wsum call.

https://github.com/JuliaStats/StatsBase.jl/blob/nl/wsum/src/weights.jl#L616

The wsum tests in that block all pass an integer and not Colon, so it hits a different dispatch.

https://github.com/JuliaStats/StatsBase.jl/blob/nl/wsum/src/weights.jl#L588

Adding a wsum call with Colon should solve the immediate coverage problem, though I wonder if the number of method overrides for sum and wsum suggestion an organizational/design issue?

@nalimilan
Copy link
Member Author

Indeed. I've added a commit which hopefully simplifies the dispatch logic without breaking anything.

@nalimilan
Copy link
Member Author

Good to go?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Slow weighted sum/dot etc.
4 participants