-
Notifications
You must be signed in to change notification settings - Fork 34
/
common.jl
111 lines (88 loc) · 3.12 KB
/
common.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# common facilities
# tools to check size
"""
    nmf_checksize(X, W, H) -> (p, n, k)

Check that `W` (`p × k`) and `H` (`k × n`) are conformable with `X` (`p × n`),
where `p`/`n` come from `X` and `k` is taken from `size(W, 2)`.

Return the tuple `(p, n, k)`; throw `DimensionMismatch` when any dimension
disagrees.
"""
function nmf_checksize(X, W::AbstractMatrix, H::AbstractMatrix)
    p, n = size(X, 1), size(X, 2)
    k = size(W, 2)
    dims_ok = size(W, 1) == p && size(H) == (k, n)
    dims_ok || throw(DimensionMismatch("Dimensions of X, W, and H are inconsistent."))
    return (p, n, k)
end
# the result type
"""
    Result{T}

Outcome of an NMF run.

# Fields
- `W::Matrix{T}`, `H::Matrix{T}`: the factor matrices. The constructor throws
  `DimensionMismatch` unless `size(W, 2) == size(H, 1)`.
- `niters::Int`: number of iterations performed.
- `converged::Bool`: whether the stopping criterion was met.
- `objvalue::T`: objective value, converted to `T` on construction.
"""
struct Result{T}
    W::Matrix{T}
    H::Matrix{T}
    niters::Int
    converged::Bool
    objvalue::T

    function Result{T}(W::Matrix{T}, H::Matrix{T}, niters::Int, converged::Bool, objv) where T
        if size(W, 2) != size(H, 1)
            throw(DimensionMismatch("Inner dimensions of W and H mismatch."))
        end
        return new{T}(W, H, niters, converged, objv)
    end
end

# Numeric equality: like `==` on floats, `NaN` entries never compare equal and
# `0.0 == -0.0`.
Base.:(==)(A::Result, B::Result) =
    A.W == B.W && A.H == B.H && A.niters == B.niters &&
    A.converged == B.converged && A.objvalue == B.objvalue

# `hash` must be consistent with `isequal` (not `==`). Without this method
# `isequal` would fall back to `==`, so `-0.0`/`0.0` results would be "equal"
# yet hash differently, and a result with a `NaN` objective would not be equal
# to itself, which breaks `Dict`/`Set` usage.
Base.isequal(A::Result, B::Result) =
    isequal(A.W, B.W) && isequal(A.H, B.H) && A.niters == B.niters &&
    A.converged == B.converged && isequal(A.objvalue, B.objvalue)

Base.hash(s::Result, h::UInt) =
    hash(s.objvalue, hash(s.converged, hash(s.niters, hash(s.H,
        hash(s.W, h + (0x09c9f08cfcba6de3 % UInt))))))
# common algorithmic skeleton for iterative updating methods
# `NMFUpdater{T}`: abstract supertype of the algorithm-specific update rules
# driven by `nmf_skeleton!`, where `T` is the element type of `W` and `H`.
# Judging from the calls in `nmf_skeleton!`, a concrete subtype is expected to
# provide methods for `prepare_state`, `update_wh!` and `evaluate_objv`.
abstract type NMFUpdater{T} end
"""
    nmf_skeleton!(updater, X, W, H, maxiter, verbose, tol) -> Result{T}

Generic driver shared by the iterative NMF algorithms.

Starting from the initial factors `W` and `H`, repeatedly calls
`update_wh!(updater, state, X, W, H)`, which is expected to update `W` and `H`
in place. The loop stops when `stop_condition` reports convergence with
tolerance `tol`, or after `maxiter` iterations.

With `verbose = true`, a header and one progress line per iteration are
printed. Each line shows the iteration, elapsed seconds, the objective, the
objective change, and the relative change of `W` & `H`.

Returns a `Result{T}` with the final `W`, `H`, the number of iterations run,
the convergence flag and the objective value from `evaluate_objv`.
"""
function nmf_skeleton!(updater::NMFUpdater{T},
                       X, W::Matrix{T}, H::Matrix{T},
                       maxiter::Int, verbose::Bool, tol) where T
    # algorithm-specific working state, built once before iterating
    state = prepare_state(updater, X, W, H)
    # snapshots of the factors from the previous iteration
    Wlast = similar(W)
    Hlast = similar(H)
    objv = convert(T, NaN)

    if verbose
        tstart = time()
        objv = evaluate_objv(updater, state, X, W, H)
        @printf("%-5s %-13s %-13s %-13s %-13s\n", "Iter", "Elapsed time", "objv", "objv.change", "(W & H).relchange")
        @printf("%5d %13.6e %13.6e\n", 0, 0.0, objv)
    end

    iter = 0
    converged = false
    while iter < maxiter && !converged
        iter += 1
        copyto!(Wlast, W)
        copyto!(Hlast, H)
        # one step of the algorithm: updates both W and H
        update_wh!(updater, state, X, W, H)
        converged, relchange = stop_condition(W, Wlast, H, Hlast, tol)

        verbose || continue
        # elapsed time is sampled before the objective is re-evaluated
        elapsed = time() - tstart
        lastobjv = objv
        objv = evaluate_objv(updater, state, X, W, H)
        @printf("%5d %13.6e %13.6e %13.6e %13.6e\n",
                iter, elapsed, objv, objv - lastobjv, relchange)
    end

    # quiet mode evaluates the objective only once, at the end
    verbose || (objv = evaluate_objv(updater, state, X, W, H))
    return Result{T}(W, H, iter, converged, objv)
end
# Relative change `sqrt(dev / tot)`, defined as zero when nothing changed
# (`dev == 0`). This avoids the 0/0 = NaN that an all-zero component would
# otherwise produce, which `max` would then propagate forever.
function _relchange(dev, tot)
    r = sqrt(dev / tot)
    return iszero(dev) ? zero(r) : r
end

"""
    stop_condition(W, preW, H, preH, tol) -> (converged::Bool, devmax)

Convergence test for one NMF iteration. Compares the current factors `W`, `H`
with their previous values `preW`, `preH`, one component at a time. Component
`j` is column `j` of `W` together with row `j` of `H`.

The iteration is `converged` when, for every column `x` of `W` and every row
`x` of `H` (with previous value `xp`), `‖x - xp‖ ≤ tol * ‖x + xp‖`. Any `NaN`
in the factors makes the test fail rather than pass.

`devmax` is the largest relative change `‖x - xp‖ / ‖x + xp‖` over *all*
components and is meant for progress reporting. A component that is
identically zero before and after the update contributes zero.
"""
function stop_condition(W::AbstractArray{T}, preW::AbstractArray, H::AbstractArray, preH::AbstractArray, tol::Real) where T
    converged = true
    devmax = zero(T)
    for j in axes(W, 2)
        dev_w = sum_w = zero(T)
        for i in axes(W, 1)
            dev_w += (W[i, j] - preW[i, j])^2
            sum_w += (W[i, j] + preW[i, j])^2
        end
        dev_h = sum_h = zero(T)
        for i in axes(H, 2)
            dev_h += (H[j, i] - preH[j, i])^2
            sum_h += (H[j, i] + preH[j, i])^2
        end
        # `<=` (rather than negating `>`) so NaN comparisons count as failure.
        # No early exit: `devmax` must cover every component.
        converged &= sqrt(dev_w) <= tol * sqrt(sum_w) &&
                     sqrt(dev_h) <= tol * sqrt(sum_h)
        devmax = max(devmax, _relchange(dev_w, sum_w), _relchange(dev_h, sum_h))
    end
    return converged, devmax
end