Skip to content

Commit

Permalink
Replace SVD with randomized SVD (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
tsano430 authored Apr 13, 2021
1 parent 567d89c commit 79f8f0e
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 3 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NonNegLeastSquares = "b7351bd1-99d9-5c5d-8786-f205a815c4d7"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RandomizedLinAlg = "0448d7d9-159c-5637-8537-fd72090fea46"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
NonNegLeastSquares = "0.2.0"
NonNegLeastSquares = "0.2.0, 0.3.0"
RandomizedLinAlg = "0.1.0"
StatsBase = "0.25, 0.26, 0.27, 0.28, 0.29, 0.30, 0.31, 0.32, 0.33"
julia = "0.7, 1"

Expand Down
1 change: 1 addition & 0 deletions src/NMF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ module NMF
using LinearAlgebra
using NonNegLeastSquares
using Random
using RandomizedLinAlg

export nnmf

Expand Down
4 changes: 2 additions & 2 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ function _nndsvd!(X, W, Ht, inith::Bool, variant::Int)
k = size(W, 2)
T = eltype(W)

# compute SVD
(U, s, V) = svd(X, full=false)
# compute randomized SVD
(U, s, V) = rsvd(X, k)

# main loop
v0 = variant == 0 ? zero(T) :
Expand Down
2 changes: 2 additions & 0 deletions test/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ W, H = NMF.randinit(X, 5; normalize=true)

## nndsvd

Random.seed!(5678)
W, H = NMF.nndsvd(X, 5)
@test size(W) == (8, 5)
@test size(H) == (5, 12)
@test all(W .>= 0.0)
@test all(H .>= 0.0)

Random.seed!(5678)
W2, H2 = NMF.nndsvd(X, 5; zeroh=true)
@test size(W) == (8, 5)
@test size(H) == (5, 12)
Expand Down

0 comments on commit 79f8f0e

Please sign in to comment.