-
Notifications
You must be signed in to change notification settings - Fork 65
/
tvl1_flow_trainable.py
356 lines (269 loc) · 17.3 KB
/
tvl1_flow_trainable.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
import numpy as np
import tensorflow as tf
import spatial_transformer
class Tvl1Flow(object):
    """TV-L1 optical flow unrolled as a TensorFlow 1.x graph.

    The solver alternates a pointwise thresholding step on the linearized
    data term with a dual (p) update on the total-variation term, repeated
    for several warpings per scale on a coarse-to-fine image pyramid.

    The finite-difference kernels built by ``forward_gradient`` and
    ``divergence`` are created with ``trainable=True``, so the unrolled
    iterations can be fine-tuned. The colour, smoothing and
    centered-gradient kernels are fixed (``trainable=False``).

    Variable names come from the ``tf.variable_scope`` and layer names used
    below (and, for the unnamed layers, from their creation order). Changing
    any of these renames the variables and breaks existing checkpoints.
    """

    # Small constant guarding divisions / sqrt(0) and marking an image
    # gradient as numerically zero.
    GRAD_IS_ZERO = 1e-12

    def __init__(self):
        pass

    def grey_scale_image(self, x):
        """Convert a 3-channel image batch to one grey channel.

        Args:
            x: 4-D tensor whose last dimension has static size 3.

        Returns:
            4-D single-channel tensor, floored to whole grey levels.
        """
        assert len(x.shape) == 4
        assert x.shape[-1].value == 3, 'number of channels must be 3 (i.e. RGB)'
        # Weights applied to channels (0, 1, 2) are (0.114, 0.587, 0.299):
        # the .299 R + .587 G + .114 B luma weights, listed in B, G, R order.
        # NOTE(review): the assert message says RGB, but this ordering is
        # correct for BGR input (e.g. images read with OpenCV) -- confirm the
        # caller's channel order before changing either.
        ker_init = tf.constant_initializer([[0.114], [0.587], [0.299]])
        grey_x = tf.layers.conv2d(x, 1, [1, 1], padding='same',
                                  kernel_initializer=ker_init, use_bias=False,
                                  trainable=False)
        # Presumably mimics integer grey levels. tf.floor has no gradient, so
        # nothing flows back into x through this op.
        return tf.floor(grey_x)

    def normalize_images(self, x1, x2):
        """Jointly rescale two image batches to [0, 255].

        For each batch element, min and max are taken over BOTH images, so
        x1 and x2 get the same affine map. Elements whose two images are
        constant (max == min) are returned unchanged.

        Returns:
            (x1_norm, x2_norm) with the same shapes as the inputs.
        """
        reduction_axes = list(range(1, len(x1.shape)))
        min_val = tf.minimum(tf.reduce_min(x1, axis=reduction_axes),
                             tf.reduce_min(x2, axis=reduction_axes))
        max_val = tf.maximum(tf.reduce_max(x1, axis=reduction_axes),
                             tf.reduce_max(x2, axis=reduction_axes))
        den = max_val - min_val
        # Use 1 where the range is zero. Then the branch tf.where discards
        # never divides by zero, whose NaN would otherwise leak into gradients.
        safe_den = tf.where(den > 0, den, tf.ones_like(den))
        # Per-example scalars reshaped to (batch, 1, 1, ...) for broadcasting.
        expand_dims = [-1] + [1] * (len(x1.shape) - 1)
        min_val_ex = tf.reshape(min_val, expand_dims)
        den_ex = tf.reshape(safe_den, expand_dims)
        # Rank-1 condition: tf.where (TF1) selects whole batch rows.
        x1_norm = tf.where(den > 0, 255. * (x1 - min_val_ex) / den_ex, x1)
        x2_norm = tf.where(den > 0, 255. * (x2 - min_val_ex) / den_ex, x2)
        return x1_norm, x2_norm

    def gaussian_smooth(self, x):
        """Blur x with a fixed 5x5 Gaussian (sigma = 0.8), 'same' padding.

        The 25-value initializer describes a single-channel kernel; every
        caller in this class passes a 1-channel tensor.
        """
        assert len(x.shape) == 4
        # http://dev.theomader.com/gaussian-kernel-calculator/
        # sigma = 0.8, kernel size = 5
        ker_init = tf.constant_initializer([[0.000874, 0.006976, 0.01386, 0.006976, 0.000874],
                                            [0.006976, 0.0557, 0.110656, 0.0557, 0.006976],
                                            [0.01386, 0.110656, 0.219833, 0.110656, 0.01386],
                                            [0.006976, 0.0557, 0.110656, 0.0557, 0.006976],
                                            [0.000874, 0.006976, 0.01386, 0.006976, 0.000874]])
        smooth_x = tf.layers.conv2d(x, x.shape[-1].value, [5, 5], padding='same',
                                    kernel_initializer=ker_init, use_bias=False,
                                    trainable=False)
        return smooth_x

    def warp_image(self, x, u, v):
        """Resample x at positions displaced by the flow (u, v).

        Args:
            x: 4-D batch (batch, height, width, channels).
            u: horizontal displacement in pixels, rank 3; callers pass
               (batch, 1, height * width).
            v: vertical displacement in pixels, same shape as u.

        Returns:
            spatial_transformer.transformer output for an output size of
            (height, width); callers reshape it back to x's shape.
        """
        assert len(x.shape) == 4
        assert len(u.shape) == 3
        assert len(v.shape) == 3
        # Convert pixel displacements to normalized units spanning 2 across
        # the image -- presumably the [-1, 1] grid spatial_transformer
        # expects; confirm in that module.
        u = u / x.shape[2].value * 2
        v = v / x.shape[1].value * 2
        delta = tf.concat(axis=1, values=[u, v])
        return spatial_transformer.transformer(x, delta, (x.shape[-3].value, x.shape[-2].value))

    def _flatten_flow(self, u, ref):
        """Reshape flow component u to (batch, 1, H*W), with batch/H/W from ref."""
        return tf.reshape(u, (tf.shape(ref)[0], 1, ref.shape[1].value * ref.shape[2].value))

    def _warp_to_shape(self, x, u_flat, v_flat):
        """warp_image(x, u_flat, v_flat) reshaped back to x's dynamic shape."""
        return tf.reshape(self.warp_image(x, u_flat, v_flat), tf.shape(x))

    def centered_gradient(self, x, name):
        """Centered finite differences of x along width (x) and height (y).

        Interior pixels use 0.5 * (x[i+1] - x[i-1]). The first/last column
        and row use the one-sided 0.5 * (x[1] - x[0]) and
        0.5 * (x[-1] - x[-2]). Kernels are fixed (trainable=False) and
        written for single-channel input.

        Args:
            x: 4-D tensor (batch, height, width, channels) with static
               spatial/channel sizes; assumes height and width >= 2.
            name: prefix for the two conv layers' names.

        Returns:
            (diff_x, diff_y), each shaped like x.
        """
        assert len(x.shape) == 4
        h, w, c = x.shape[1].value, x.shape[2].value, x.shape[3].value
        with tf.variable_scope('centered_gradient'):
            # conv2d is a cross-correlation: [-0.5, 0, 0.5] -> 0.5*(x[j+1]-x[j-1]).
            x_ker_init = tf.constant_initializer([[-0.5, 0, 0.5]])
            diff_x = tf.layers.conv2d(x, c, [1, 3], padding='same',
                                      kernel_initializer=x_ker_init, use_bias=False,
                                      name=name + '_diff_x', trainable=False)
            y_ker_init = tf.constant_initializer([[-0.5], [0], [0.5]])
            diff_y = tf.layers.conv2d(x, c, [3, 1], padding='same',
                                      kernel_initializer=y_ker_init, use_bias=False,
                                      name=name + '_diff_y', trainable=False)
            # Replace the zero-padded borders with one-sided differences.
            first_col = 0.5 * (tf.slice(x, [0, 0, 1, 0], [-1, h, 1, c]) -
                               tf.slice(x, [0, 0, 0, 0], [-1, h, 1, c]))
            last_col = 0.5 * (tf.slice(x, [0, 0, w - 1, 0], [-1, h, 1, c]) -
                              tf.slice(x, [0, 0, w - 2, 0], [-1, h, 1, c]))
            diff_x_valid = tf.slice(diff_x, begin=[0, 0, 1, 0], size=[-1, h, w - 2, c])
            diff_x = tf.concat(axis=2, values=[first_col, diff_x_valid, last_col])

            first_row = 0.5 * (tf.slice(x, [0, 1, 0, 0], [-1, 1, w, c]) -
                               tf.slice(x, [0, 0, 0, 0], [-1, 1, w, c]))
            last_row = 0.5 * (tf.slice(x, [0, h - 1, 0, 0], [-1, 1, w, c]) -
                              tf.slice(x, [0, h - 2, 0, 0], [-1, 1, w, c]))
            diff_y_valid = tf.slice(diff_y, begin=[0, 1, 0, 0], size=[-1, h - 2, w, c])
            diff_y = tf.concat(axis=1, values=[first_row, diff_y_valid, last_row])
        return diff_x, diff_y

    def forward_gradient(self, x, name):
        """Forward differences of x along width and height.

        With the initial kernels this is x[j+1] - x[j]; the last column
        (for diff_x) and last row (for diff_y) are set to 0. The kernels are
        trainable, so after training they may differ from [-1, 1].

        Args:
            x: 4-D tensor (batch, height, width, channels), static H/W/C.
            name: prefix for the two conv layers' names.

        Returns:
            (diff_x, diff_y), each shaped like x.
        """
        assert len(x.shape) == 4
        h, w, c = x.shape[1].value, x.shape[2].value, x.shape[3].value
        with tf.variable_scope('forward_gradient'):
            # Even kernel with 'same' padding pads one zero on the right/bottom,
            # so output[j] = -x[j] + x[j+1].
            x_ker_init = tf.constant_initializer([[-1, 1]])
            diff_x = tf.layers.conv2d(x, c, [1, 2], padding='same',
                                      kernel_initializer=x_ker_init, use_bias=False,
                                      name=name + '_diff_x', trainable=True)
            y_ker_init = tf.constant_initializer([[-1], [1]])
            diff_y = tf.layers.conv2d(x, c, [2, 1], padding='same',
                                      kernel_initializer=y_ker_init, use_bias=False,
                                      name=name + '_diff_y', trainable=True)
            # Zero the last column/row instead of differencing against padding.
            diff_x_valid = tf.slice(diff_x, begin=[0, 0, 0, 0], size=[-1, h, w - 1, c])
            last_col = tf.zeros([tf.shape(x)[0], h, 1, c], dtype=tf.float32)
            diff_x = tf.concat(axis=2, values=[diff_x_valid, last_col])
            diff_y_valid = tf.slice(diff_y, begin=[0, 0, 0, 0], size=[-1, h - 1, w, c])
            last_row = tf.zeros([tf.shape(x)[0], 1, w, c], dtype=tf.float32)
            diff_y = tf.concat(axis=1, values=[diff_y_valid, last_row])
        return diff_x, diff_y

    def divergence(self, x, y, name):
        """Discrete divergence of the vector field (x, y).

        With the initial kernels this is d/dx of x plus d/dy of y via
        backward differences. The first column/row takes the value itself and
        the last column/row takes minus the second-to-last value (the negative
        adjoint of forward_gradient). The kernels are trainable.

        Args:
            x: horizontal component, 4-D (batch, height, width, channels).
            y: vertical component, same layout as x.
            name: prefix for the two conv layers' names.

        Returns:
            Tensor shaped like x.
        """
        assert len(x.shape) == 4
        xh, xw, xc = x.shape[1].value, x.shape[2].value, x.shape[3].value
        yh, yw, yc = y.shape[1].value, y.shape[2].value, y.shape[3].value
        with tf.variable_scope('divergence'):
            # Shift right/down by one with a leading zero, so the forward kernel
            # below produces x[j] - x[j-1] (x[-1] == 0) and -x[W-2] at the end.
            x_valid = tf.slice(x, begin=[0, 0, 0, 0], size=[-1, xh, xw - 1, xc])
            first_col = tf.zeros([tf.shape(x)[0], xh, 1, xc], dtype=tf.float32)
            x_pad = tf.concat(axis=2, values=[first_col, x_valid])
            y_valid = tf.slice(y, begin=[0, 0, 0, 0], size=[-1, yh - 1, yw, yc])
            first_row = tf.zeros([tf.shape(y)[0], 1, yw, yc], dtype=tf.float32)
            y_pad = tf.concat(axis=1, values=[first_row, y_valid])

            x_ker_init = tf.constant_initializer([[-1, 1]])
            diff_x = tf.layers.conv2d(x_pad, xc, [1, 2], padding='same',
                                      kernel_initializer=x_ker_init, use_bias=False,
                                      name=name + '_diff_x', trainable=True)
            y_ker_init = tf.constant_initializer([[-1], [1]])
            diff_y = tf.layers.conv2d(y_pad, yc, [2, 1], padding='same',
                                      kernel_initializer=y_ker_init, use_bias=False,
                                      name=name + '_diff_y', trainable=True)
        return diff_x + diff_y

    def zoom_size(self, height, width, factor):
        """Return (height, width) scaled by factor and rounded half-up to ints."""
        new_height = int(float(height) * factor + 0.5)
        new_width = int(float(width) * factor + 0.5)
        return new_height, new_width

    def zoom_image(self, x, new_height, new_width):
        """Resample x to (new_height, new_width).

        Uses spatial_transformer with a zero displacement field, i.e. an
        identity sampling grid at the new resolution. The interpolation is
        whatever spatial_transformer implements (presumably bilinear --
        confirm in that module).
        """
        assert len(x.shape) == 4
        delta = tf.zeros((tf.shape(x)[0], 2, new_height * new_width))
        zoomed_x = spatial_transformer.transformer(x, delta, (new_height, new_width))
        return tf.reshape(zoomed_x, [tf.shape(x)[0], new_height, new_width, x.shape[-1].value])

    def dual_tvl1_optic_flow(self, x1, x2, u1, u2,
                             tau=0.25,
                             lbda=0.15,
                             theta=0.3,
                             warps=5,
                             max_iterations=5
                             ):
        """Refine the flow (u1, u2) from x1 to x2 at a single scale.

        Args:
            x1: reference image, 4-D (batch, height, width, 1).
            x2: image warped towards x1, same shape as x1.
            u1, u2: initial horizontal / vertical flow, shaped like x1.
            tau: time step of the dual update.
            lbda: weight of the data term.
            theta: weight of the (u - v)^2 coupling term.
            warps: number of warpings (re-linearizations) at this scale.
            max_iterations: fixed number of iterations per warping.

        Returns:
            (u1, u2, rho). rho is the linearized residual from the last
            iteration. Assumes warps >= 1 and max_iterations >= 1; otherwise
            rho is never assigned.
        """
        l_t = lbda * theta
        taut = tau / theta
        diff2_x, diff2_y = self.centered_gradient(x2, 'x2')
        p11 = p12 = p21 = p22 = tf.zeros_like(x1)
        for warpings in range(warps):
            with tf.variable_scope('warping%d' % (warpings,)):
                u1_flat = self._flatten_flow(u1, x2)
                u2_flat = self._flatten_flow(u2, x2)
                x2_warp = self._warp_to_shape(x2, u1_flat, u2_flat)
                diff2_x_warp = self._warp_to_shape(diff2_x, u1_flat, u2_flat)
                diff2_y_warp = self._warp_to_shape(diff2_y, u1_flat, u2_flat)
                grad = tf.square(diff2_x_warp) + tf.square(diff2_y_warp) + self.GRAD_IS_ZERO
                # Constant part of the residual linearized around the current u.
                rho_c = x2_warp - diff2_x_warp * u1 - diff2_y_warp * u2 - x1
                zeros = tf.zeros_like(diff2_x_warp)
                for ii in range(max_iterations):
                    with tf.variable_scope('iter%d' % (ii,)):
                        # NOTE(review): the GRAD_IS_ZERO offset on rho only shifts
                        # the thresholds by ~1e-12; kept to preserve results.
                        rho = rho_c + diff2_x_warp * u1 + diff2_y_warp * u2 + self.GRAD_IS_ZERO
                        # Thresholding step: three mutually exclusive cases.
                        masks1 = rho < -l_t * grad
                        masks2 = rho > l_t * grad
                        masks3 = (~masks1) & (~masks2) & (grad > self.GRAD_IS_ZERO)
                        fi = -rho / grad
                        d1 = (tf.where(masks1, l_t * diff2_x_warp, zeros)
                              + tf.where(masks2, -l_t * diff2_x_warp, zeros)
                              + tf.where(masks3, fi * diff2_x_warp, zeros))
                        d2 = (tf.where(masks1, l_t * diff2_y_warp, zeros)
                              + tf.where(masks2, -l_t * diff2_y_warp, zeros)
                              + tf.where(masks3, fi * diff2_y_warp, zeros))
                        v1 = d1 + u1
                        v2 = d2 + u2
                        u1 = v1 + theta * self.divergence(p11, p12, 'div_p1')
                        u2 = v2 + theta * self.divergence(p21, p22, 'div_p2')
                        u1x, u1y = self.forward_gradient(u1, 'u1')
                        u2x, u2y = self.forward_gradient(u2, 'u2')
                        # Dual update: each normalizer is shared by both
                        # components of its field, so compute it once.
                        ng1 = 1.0 + taut * tf.sqrt(tf.square(u1x) + tf.square(u1y) + self.GRAD_IS_ZERO)
                        ng2 = 1.0 + taut * tf.sqrt(tf.square(u2x) + tf.square(u2y) + self.GRAD_IS_ZERO)
                        p11 = (p11 + taut * u1x) / ng1
                        p12 = (p12 + taut * u1y) / ng1
                        p21 = (p21 + taut * u2x) / ng2
                        p22 = (p22 + taut * u2y) / ng2
        return u1, u2, rho

    def _num_scales(self, height, width, zfactor, max_scales):
        """Number of pyramid levels for a height x width image.

        Computes 1 + log(diagonal / 4) / log(1 / zfactor). With truncation,
        the coarsest level's diagonal stays at or above ~4 pixels. The result
        is capped at max_scales, truncated to a Python int (so it can drive
        range()) and never less than 1.
        """
        n_scales = 1 + np.log(np.sqrt(height ** 2 + width ** 2) / 4.0) / np.log(1 / zfactor)
        return max(1, int(min(n_scales, max_scales)))

    def dual_tvl1_optic_flow_multiscale(self, x1, x2,
                                        tau=0.25,
                                        lbda=0.15,
                                        theta=0.3,
                                        warps=5,
                                        zfactor=0.5,
                                        max_scales=5,
                                        max_iterations=5
                                        ):
        """Coarse-to-fine TV-L1 flow from x1 to x2.

        Both 3-channel inputs are converted to grey, jointly normalized to
        [0, 255] and Gaussian-smoothed. Flow is then estimated from the
        coarsest scale down to full resolution. The coarsest scale's initial
        flow is a trainable zero-initialized variable tiled over the batch.

        Args:
            x1, x2: 4-D image batches with identical static shapes.
            tau, lbda, theta, warps, max_iterations: see dual_tvl1_optic_flow.
            zfactor: downsampling factor between pyramid levels.
            max_scales: upper bound on the number of pyramid levels.

        Returns:
            (u1, u2, grey_x2): full-resolution flow and the grey version of
            x2. The third element is grey_x2, not the solver's rho.
        """
        for dim1, dim2 in zip(x1.shape, x2.shape):
            assert dim1.value == dim2.value
        zfactor = np.float32(zfactor)
        height = x1.shape[-3].value
        width = x1.shape[-2].value
        n_scales = self._num_scales(height, width, zfactor, max_scales)
        with tf.variable_scope('tvl1_flow'):
            # Keep this creation order: these layers are unnamed, so their
            # variable names depend on it.
            grey_x1 = self.grey_scale_image(x1)
            grey_x2 = self.grey_scale_image(x2)
            norm_x1, norm_x2 = self.normalize_images(grey_x1, grey_x2)
            smooth_x1 = self.gaussian_smooth(norm_x1)
            smooth_x2 = self.gaussian_smooth(norm_x2)
            for ss in range(n_scales - 1, -1, -1):
                with tf.variable_scope('scale%d' % ss):
                    down_height, down_width = self.zoom_size(height, width, zfactor ** ss)
                    if ss == n_scales - 1:
                        u1 = tf.get_variable('u1', shape=[1, down_height, down_width, 1],
                                             dtype=tf.float32, initializer=tf.zeros_initializer)
                        u2 = tf.get_variable('u2', shape=[1, down_height, down_width, 1],
                                             dtype=tf.float32, initializer=tf.zeros_initializer)
                        batch = tf.shape(smooth_x1)[0]
                        u1 = tf.tile(u1, [batch, 1, 1, 1])
                        u2 = tf.tile(u2, [batch, 1, 1, 1])
                    down_x1 = self.zoom_image(smooth_x1, down_height, down_width)
                    down_x2 = self.zoom_image(smooth_x2, down_height, down_width)
                    u1, u2, _ = self.dual_tvl1_optic_flow(down_x1, down_x2, u1, u2,
                                                          tau=tau, lbda=lbda, theta=theta,
                                                          warps=warps,
                                                          max_iterations=max_iterations)
                    if ss == 0:
                        return u1, u2, grey_x2
                    # Upsample to the next finer level; flow magnitudes scale
                    # with resolution, hence the division by zfactor.
                    up_height, up_width = self.zoom_size(height, width, zfactor ** (ss - 1))
                    u1 = self.zoom_image(u1, up_height, up_width) / zfactor
                    u2 = self.zoom_image(u2, up_height, up_width) / zfactor

    def _flow_loss(self, u1, u2, x1, x2, lbda):
        """Unsupervised loss for a flow field.

        Computes lbda * mean|warp(x2, u) - x1| + mean(|u1x| + |u1y| + |u2x| + |u2y|),
        with flow derivatives taken by forward_gradient.
        """
        # NOTE(review): forward_gradient creates trainable kernels here. If the
        # optimizer updates them, it can shrink the smoothness term by scaling
        # the kernels down rather than by smoothing the flow -- confirm var_list.
        u1x, u1y = self.forward_gradient(u1, 'u1')
        u2x, u2y = self.forward_gradient(u2, 'u2')
        x2_warp = self._warp_to_shape(x2, self._flatten_flow(u1, x2), self._flatten_flow(u2, x2))
        data_term = tf.reduce_mean(tf.abs(x2_warp - x1))
        smooth_term = tf.reduce_mean(tf.abs(u1x) + tf.abs(u1y) + tf.abs(u2x) + tf.abs(u2y))
        return lbda * data_term + smooth_term

    def get_loss(self, x1, x2,
                 tau=0.25,
                 lbda=0.15,
                 theta=0.3,
                 warps=5,
                 zfactor=0.5,
                 max_scales=5,
                 max_iterations=5
                 ):
        """Estimate flow with dual_tvl1_optic_flow_multiscale and score it.

        Arguments are forwarded to dual_tvl1_optic_flow_multiscale. lbda also
        weights the data term of the loss.

        Returns:
            (loss, u1, u2).
        """
        u1, u2, _ = self.dual_tvl1_optic_flow_multiscale(x1, x2,
                                                         tau=tau, lbda=lbda, theta=theta,
                                                         warps=warps, zfactor=zfactor,
                                                         max_scales=max_scales,
                                                         max_iterations=max_iterations)
        loss = self._flow_loss(u1, u2, x1, x2, lbda)
        return loss, u1, u2

    def get_loss_from_flow(self, u1, u2, x1, x2, lbda=0.15):
        """Same loss as get_loss, for an externally supplied flow (u1, u2).

        Args:
            u1, u2: flow components at x1/x2 resolution.
            x1, x2: 4-D image batches.
            lbda: weight of the data term (default matches get_loss).
        """
        return self._flow_loss(u1, u2, x1, x2, lbda)

    def supervised_loss(self, u1, u2, u1_gt, u2_gt):
        """Average end-point error between (u1, u2) and ground truth.

        A GRAD_IS_ZERO epsilon inside the sqrt keeps its gradient finite
        where the error is exactly zero. This adds at most ~1e-6 to the
        value.
        """
        bias1 = u1 - u1_gt
        bias2 = u2 - u2_gt
        return tf.reduce_mean(tf.sqrt(tf.square(bias1) + tf.square(bias2) + self.GRAD_IS_ZERO))