-
Notifications
You must be signed in to change notification settings - Fork 19
/
utils.py
317 lines (258 loc) · 11.5 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
#!/usr/bin/env python
# coding=utf-8
from __future__ import (print_function, division, absolute_import, unicode_literals)
import os
import numpy as np
import tensorflow as tf
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.colors import hsv_to_rgb
from sklearn.metrics import adjusted_mutual_info_score
def _bytescale(data, cmin=None, cmax=None):
    """Linearly map ``data`` from ``[cmin, cmax]`` onto uint8 ``[0, 255]``.

    Mirrors the scaling that the removed ``scipy.misc.bytescale`` applied:
    uint8 input is returned untouched, missing limits default to the data
    range, and a zero-width range is treated as width 1.
    """
    data = np.asarray(data)
    if data.dtype == np.uint8:
        return data
    cmin = data.min() if cmin is None else cmin
    cmax = data.max() if cmax is None else cmax
    span = cmax - cmin
    if span < 0:
        raise ValueError("`cmax` should be larger than `cmin`.")
    if span == 0:
        span = 1
    scaled = (data - cmin) * (255.0 / span)
    return (scaled.clip(0, 255) + 0.5).astype(np.uint8)


def save_image(filename, image_array):
    """Write an image array to ``filename``.

    :param filename: destination path; the image format is derived from it
    :param image_array: array indexed as (rows, cols, channels). A single
        channel is saved as a grey image scaled from [0, 1] when no value is
        negative, otherwise from [-1, 1]. Any other channel count is saved as
        a colour image stretched over the range of the data.

    ``scipy.misc.toimage`` is used when the installed SciPy still ships it
    (it was removed in SciPy 1.2). Otherwise the same byte scaling is done
    with NumPy and the file is written through matplotlib; in that case grey
    images are stored as grey-valued colour pixels rather than a
    single-channel file, and channels are assumed to be the last axis.
    """
    if image_array.shape[2] == 1:
        pixels = image_array[:, :, 0]
        cmin = 0.0 if np.min(image_array) >= 0. else -1.0
        cmax = 1.0
    else:
        pixels = 255 * image_array
        # None makes both code paths fall back to the range of the data
        cmin = cmax = None

    try:
        import scipy.misc
        toimage = scipy.misc.toimage
    except (ImportError, AttributeError):
        toimage = None

    if toimage is not None:
        toimage(pixels, cmin=cmin, cmax=cmax).save(filename)
        return

    scaled = _bytescale(pixels, cmin=cmin, cmax=cmax)
    if scaled.ndim == 2:
        plt.imsave(filename, scaled, cmap='gray', vmin=0, vmax=255)
    else:
        plt.imsave(filename, scaled)
def delete_files(folder, recursive=False):
    """Empty ``folder`` on a best-effort basis.

    :param folder: directory whose entries are removed (the directory itself
        is kept)
    :param recursive: when True, sub-directories are emptied and removed as
        well; when False they are left untouched

    Failures on individual entries are printed and skipped so that one
    undeletable entry does not stop the clean-up.
    """
    for entry in os.listdir(folder):
        path = os.path.join(folder, entry)
        try:
            if os.path.isfile(path):
                os.unlink(path)
            elif recursive and os.path.isdir(path):
                delete_files(path, recursive)
                if os.path.islink(path):
                    # a symlink to a directory is removed like a file
                    os.unlink(path)
                else:
                    # os.unlink cannot remove a directory; rmdir can once it is empty
                    os.rmdir(path)
        except Exception as e:
            print(e)
def create_directory(directory):
    """Create ``directory`` (including missing parents) unless it already exists.

    :param directory: path to create

    The creation is attempted directly instead of checking first, so another
    process creating the same path at the same moment does not cause a
    failure. As before, nothing happens when the path already exists; any
    other OS error is propagated.
    """
    try:
        os.makedirs(directory)
    except OSError:
        if not os.path.exists(directory):
            raise
def print_vars(_vars):
    """Print the name and shape of every variable and the total element count.

    :param _vars: iterable of objects exposing ``name`` and
        ``get_shape().as_list()``
    :return: the summed number of elements over all variables
    """
    element_count = 0
    for variable in _vars:
        shape = variable.get_shape().as_list()
        print(variable.name, shape)
        element_count += np.prod(shape)
    print(element_count, 'total variables')
    return element_count
# Lookup table from an activation name to a callable that is applied to a tensor.
ACTIVATION_FUNCTIONS = {
    'sigmoid': tf.nn.sigmoid,
    'tanh': tf.nn.tanh,
    'relu': tf.nn.relu,
    'elu': tf.nn.elu,
    # identity: returns its input unchanged
    'linear': lambda x: x,
    'exp': lambda x: tf.exp(x),
    'softplus': tf.nn.softplus,
    # clamp values to [-1, 1]
    'clip': lambda x: tf.clip_by_value(x, -1., 1.),
    # clamp values to [-1, 1e6], i.e. mainly a lower bound
    'clip_low': lambda x: tf.clip_by_value(x, -1., 1e6)
}
def parse_activation_function(name_list):
    """Translate activation names into callables via ACTIVATION_FUNCTIONS.

    :param name_list: iterable of keys of ACTIVATION_FUNCTIONS
    :return: list of the corresponding callables, in the same order
    :raises KeyError: if a name is not a key of ACTIVATION_FUNCTIONS
    """
    functions = []
    for activation_name in name_list:
        functions.append(ACTIVATION_FUNCTIONS[activation_name])
    return functions
def evaluate_groups_seq(true_groups, predicted, weights):
    """ Weighted average over time steps of the AMI scores and mean confidences.

    Each time step is scored with ``evaluate_groups``; the per-sample results
    are combined using ``weights`` and divided by the sum of all weights.

    :param true_groups: (T, B, 1, W, H, 1)
    :param predicted: (T, B, K, W, H, 1)
    :param weights: (T)
    :return: scores, confidences (B,)
    """
    assert true_groups.ndim == predicted.ndim == 6, true_groups.shape
    score_acc = 0.
    conf_acc = 0.
    for step, step_truth in enumerate(true_groups):
        step_scores, step_confs = evaluate_groups(step_truth, predicted[step])
        step_weight = weights[step]
        score_acc += step_weight * np.array(step_scores)
        conf_acc += step_weight * np.array(step_confs)
    total_weight = np.sum(weights)
    return score_acc / total_weight, conf_acc / total_weight
def evaluate_groups(true_groups, predicted):
    """ Score each sample's hard group assignment against the true groups.

    Pixels whose true group is 0 are excluded. The predicted group of a pixel
    is the argmax over K; its confidence is the corresponding maximum.

    :param true_groups: (B, 1, W, H, 1)
    :param predicted: (B, K, W, H, 1)
    :return: scores, confidences (B,) -- lists of AMI scores and mean confidences
    """
    assert true_groups.ndim == predicted.ndim == 5, true_groups.shape
    n_samples, n_groups = predicted.shape[:2]
    flat_truth = true_groups.reshape(n_samples, -1)
    flat_pred = predicted.reshape(n_samples, n_groups, -1)
    hard_assignment = flat_pred.argmax(1)
    max_gamma = flat_pred.max(1)

    scores = []
    confidences = []
    for truth, assigned, conf in zip(flat_truth, hard_assignment, max_gamma):
        # keep only pixels that belong to a non-zero group
        keep = np.flatnonzero(truth != 0.0)
        scores.append(adjusted_mutual_info_score(truth[keep], assigned[keep]))
        confidences.append(np.mean(conf[keep]))
    return scores, confidences
def tf_adjusted_rand_index(groups, gammas, iter_weights):
    """
    Build TF ops for the adjusted Rand index (ARI) between the true groups and
    the argmax of the gammas, plus the mean confidence of that argmax.

    The first time step is dropped and pixels whose group value is not
    greater than 0.5 are masked out of both measures.

    Inputs:
        groups: shape=(T, B, 1, W, H, 1)
            These are the masks as stored in the hdf5 files
        gammas: shape=(T, B, K, W, H, 1)
            These are the gammas as predicted by the network
        iter_weights: sequence of per-iteration weights. It is turned into a
            column and multiplied with the (T-1, B) score matrix, so it
            presumably needs one entry per retained iteration -- TODO confirm

    Returns:
        seq_ARI: batch mean of the iter_weights-weighted average ARI
        last_ARI: batch mean of the ARI at the last iteration
        seq_conf: batch mean of the iter_weights-weighted average confidence
        last_conf: batch mean of the confidence at the last iteration
    """
    with tf.name_scope('ARI'):
        # ignore first iteration
        groups = groups[1:]
        gammas = gammas[1:]
        # reshape gammas and convert to one-hot
        yshape = tf.shape(gammas)
        gammas = tf.reshape(gammas, shape=tf.stack([yshape[0] * yshape[1], yshape[2], yshape[3] * yshape[4] * yshape[5]]))
        Y = tf.one_hot(tf.argmax(gammas, axis=1), depth=yshape[2], axis=1)
        # reshape masks
        gshape = tf.shape(groups)
        groups = tf.reshape(groups, shape=tf.stack([gshape[0] * gshape[1], 1, gshape[3] * gshape[4] * gshape[5]]))
        G = tf.one_hot(tf.cast(groups[:, 0], tf.int32), depth=tf.cast(tf.reduce_max(groups) + 1, tf.int32), axis=1)
        # now Y has dim (B*T, K, N) and G has dim (B*T, max(groups)+1, N) where N=W*H*C
        # mask entries with group 0
        M = tf.cast(tf.greater(groups, 0.5), tf.float32)
        # number of unmasked pixels per (iteration, sample)
        n = tf.cast(tf.reduce_sum(M, axis=[1, 2]), tf.float32)
        DM = G * M
        YM = Y * M
        # contingency table for overlap between G and Y
        # (axis 1 indexes the true groups, axis 2 the predicted groups)
        nij = tf.einsum('bij,bkj->bki', YM, DM)
        # marginals: a sums over the true groups, b sums over the predicted groups
        a = tf.reduce_sum(nij, axis=1)
        b = tf.reduce_sum(nij, axis=2)
        # rand index
        rindex = tf.cast(tf.reduce_sum(nij * (nij-1), axis=[1, 2]), tf.float32)
        aindex = tf.cast(tf.reduce_sum(a * (a-1), axis=1), tf.float32)
        bindex = tf.cast(tf.reduce_sum(b * (b-1), axis=1), tf.float32)
        # the 1e-6 terms and the clipping below guard against division by zero
        expected_rindex = aindex * bindex / (n*(n-1) + 1e-6)
        max_rindex = (aindex + bindex) / 2
        ARI = (rindex - expected_rindex)/tf.clip_by_value(max_rindex - expected_rindex, 1e-6, 1e6)
        # back to one score per (iteration, sample)
        ARI = tf.reshape(ARI, shape=(yshape[0], yshape[1]))
        iter_weigths= tf.constant(np.array(iter_weights)[:, None], dtype=tf.float32)
        sum_iter_weights = tf.constant(np.sum(iter_weights), dtype=tf.float32)
        seq_ARI = tf.reduce_mean(tf.reduce_sum(ARI * iter_weigths, axis=0) / sum_iter_weights)
        last_ARI = tf.reduce_mean(ARI[-1])
        # confidence: average over unmasked pixels of the largest gamma per pixel
        confidences = tf.reduce_sum(tf.reduce_max(gammas, axis=1, keep_dims=True) * M, axis=[1, 2]) / n
        confidences = tf.reshape(confidences, shape=(yshape[0], yshape[1]))
        seq_conf = tf.reduce_mean(tf.reduce_sum(confidences * iter_weigths, axis=0) / sum_iter_weights)
        last_conf = tf.reduce_mean(confidences[-1])
        return seq_ARI, last_ARI, seq_conf, last_conf
def color_spines(ax, color, lw=2):
    """Show all four spines of ``ax`` with the given colour and line width.

    :param ax: object whose ``spines`` mapping has 'top', 'bottom', 'left'
        and 'right' entries
    :param color: colour applied to every spine
    :param lw: line width applied to every spine
    """
    for side in ('top', 'bottom', 'left', 'right'):
        spine = ax.spines[side]
        spine.set_linewidth(lw)
        spine.set_color(color)
        spine.set_visible(True)
def color_half_spines(ax, color1, color2, lw=2):
    """Show all four spines of ``ax`` using two colours.

    :param ax: object whose ``spines`` mapping has 'top', 'bottom', 'left'
        and 'right' entries
    :param color1: colour for the top and left spines
    :param color2: colour for the bottom and right spines
    :param lw: line width applied to every spine
    """
    side_colors = (
        ('top', color1),
        ('left', color1),
        ('bottom', color2),
        ('right', color2),
    )
    for side, shade in side_colors:
        spine = ax.spines[side]
        spine.set_linewidth(lw)
        spine.set_color(shade)
        spine.set_visible(True)
def get_gamma_colors(nr_colors):
    """Return ``nr_colors`` RGB colours with evenly spaced hues.

    Hues start at 2/3 and wrap around modulo 1; saturation and value are
    fixed at 1.

    :param nr_colors: number of colours to generate
    :return: array of shape (nr_colors, 3) as produced by ``hsv_to_rgb``
    """
    hues = (np.linspace(0, 1, nr_colors, endpoint=False) + 2 / 3) % 1.0
    hsv = np.ones((nr_colors, 3))
    hsv[:, 0] = hues
    return hsv_to_rgb(hsv)
def overview_plot(i, gammas, preds, inputs, corrupted=None, **kwargs):
    """Build a grid figure summarising all iterations for one sample.

    The grid has one column per time step (column 0 only shows the initial
    gammas) and K + 4 rows: inputs, reconstruction, gammas, one row per
    component, and the corrupted inputs.

    :param i: index of the sample to plot along the batch axis
    :param gammas: array whose shape unpacks as (T, B, K, W, H, C); only
        index 0 of the last axis is used
    :param preds: predictions; if the second axis is not B, index 0 of that
        axis is taken before sample ``i`` is selected -- presumably
        (T, B, K, W, H, C) or with one extra axis after T; TODO confirm
    :param inputs: indexed as ``inputs[:, i, 0]``
    :param corrupted: indexed like ``inputs``; defaults to ``inputs``
    :param kwargs: an optional 'attentions' entry switches the component rows
        to attention summaries; it is indexed as ``attentions[t-1, i, k]``
    :return: the matplotlib figure
    """
    attentions = np.array(kwargs['attentions']) if 'attentions' in kwargs else None
    T, B, K, W, H, C = gammas.shape
    T -= 1 # the initialization doesn't count as iteration
    corrupted = corrupted if corrupted is not None else inputs
    gamma_colors = get_gamma_colors(K)
    # restrict to sample i and get rid of useless dims
    inputs = inputs[:, i, 0]
    gammas = gammas[:, i, :, :, :, 0]
    if preds.shape[1] != B:
        preds = preds[:, 0]
    preds = preds[:, i]
    corrupted = corrupted[:, i, 0]
    # everything that is displayed as an image is clipped to [0, 1]
    inputs = np.clip(inputs, 0., 1.)
    preds = np.clip(preds, 0., 1.)
    corrupted = np.clip(corrupted, 0., 1.)

    def plot_img(ax, data, cmap='Greys_r', xlabel=None, ylabel=None, border_color=None):
        # single-channel data is drawn with the colormap, anything else as-is
        if data.shape[-1] == 1:
            ax.matshow(data[:, :, 0], cmap=cmap, vmin=0., vmax=1., interpolation='nearest')
        else:
            ax.imshow(data, interpolation='nearest')
        ax.set_xticks([]); ax.set_yticks([])
        ax.set_xlabel(xlabel, color=border_color or 'k') if xlabel else None
        ax.set_ylabel(ylabel, color=border_color or 'k') if ylabel else None
        if border_color:
            color_spines(ax, color=border_color)

    def plot_attention_summary_img(ax, attention, k_excluded, preds, cmap='Greys_r', xlabel=None, ylabel=None, border_color=None):
        # copy so we don't mutate
        preds = np.copy(preds)
        attention = np.copy(attention)
        # get focus object as rgb version of black-white
        focus_pred = np.tile(np.copy(preds[k_excluded]), [1, 1, 3])
        # we are safe to do what we want to do to the k_excluded row
        preds[k_excluded] = 0
        attention = np.insert(attention, k_excluded, 0) # zero out the focus object
        # mask preds by attention
        preds = np.transpose(preds[:, :, :, 0], [1, 2, 0]) # components moved to the last axis, e.g. (28, 28, K)
        preds *= attention
        # color the preds
        preds = preds.reshape(-1, preds.shape[-1]).dot(gamma_colors).reshape(preds.shape[:-1] + (3,))
        # add in the focus object
        preds += focus_pred
        # plot
        ax.imshow(preds, interpolation='nearest')
        ax.set_xticks([]); ax.set_yticks([])
        ax.set_xlabel(xlabel, color=border_color or 'k') if xlabel else None
        ax.set_ylabel(ylabel, color=border_color or 'k') if ylabel else None
        if border_color:
            color_spines(ax, color=border_color)

    def plot_gamma(ax, gamma, xlabel=None, ylabel=None):
        # mix the per-component colours with the gammas as weights
        gamma = np.transpose(gamma, [1, 2, 0])
        gamma = gamma.reshape(-1, gamma.shape[-1]).dot(gamma_colors).reshape(gamma.shape[:-1] + (3,))
        ax.imshow(gamma, interpolation='nearest')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel(xlabel) if xlabel else None
        ax.set_ylabel(ylabel) if ylabel else None

    nrows, ncols = (K + 4, T + 1)
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                             figsize=(2 * ncols, 2 * nrows))
    # column 0 only shows the initial gammas; every other cell in it is hidden
    axes[0, 0].set_visible(False)
    axes[1, 0].set_visible(False)
    plot_gamma(axes[2, 0], gammas[0], ylabel='Gammas')
    for k in range(K + 1):
        axes[k + 3, 0].set_visible(False)
    for t in range(1, T + 1):
        g = gammas[t]
        p = preds[t]
        # reconstruction: predictions weighted by their gammas, summed over components
        reconst = np.sum(g[:, :, :, None] * p, axis=0)
        plot_img(axes[0, t], inputs[t])
        plot_img(axes[1, t], reconst)
        plot_gamma(axes[2, t], g)
        for k in range(K):
            if attentions is not None:
                plot_attention_summary_img(axes[k + 3, t], attentions[t-1, i, k], k, p,
                                           border_color=tuple(gamma_colors[k]),
                                           ylabel=('contexts {} for {}'.format(k-1, k) if t == 1 else None))
            else:
                plot_img(axes[k + 3, t], p[k], border_color=tuple(gamma_colors[k]), ylabel=('mu_{}'.format(k) if t == 1 else None))
        # last row: the corrupted input, indexed one step behind t
        plot_img(axes[K + 3, t], corrupted[t - 1])
    plt.subplots_adjust(hspace=0.1, wspace=0.1)
    return fig
def curve_plot(values_dict, coarse_range, fine_range):
    """Plot every series of ``values_dict`` against its index ("epochs").

    :param values_dict: mapping from a legend label to a sequence of values
    :param coarse_range: (ymin, ymax) of the "coarse range" panel
    :param fine_range: (ymin, ymax) of a second "fine range" panel, or None
        to draw only the coarse panel
    :return: the matplotlib figure
    """
    with_fine_panel = fine_range is not None
    if with_fine_panel:
        fig, panels = plt.subplots(1, 2, figsize=(40, 10))
    else:
        fig, single_panel = plt.subplots(1, 1, figsize=(20, 10))
        panels = [single_panel]

    def draw(panel, series, label, y_range, title):
        # one curve plus the panel decoration, re-applied for every series
        panel.plot(series, label=label)
        panel.set_xlabel('epochs')
        panel.axis([0, len(series), y_range[0], y_range[1]])
        panel.set_title(title)
        panel.legend()

    for key, values in values_dict.items():
        draw(panels[0], values, key, coarse_range, "coarse range")
        if with_fine_panel:
            draw(panels[1], values, key, fine_range, "fine range")

    return fig