-
Notifications
You must be signed in to change notification settings - Fork 12
/
model_share_sd_resnet_up.py
75 lines (67 loc) · 2.11 KB
/
model_share_sd_resnet_up.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import tensorflow as tf
import util
# Module-level flag. build_model below performs its own learned upsampling
# (a strided conv2d_transpose named 'up'), so this is presumably read by the
# driver/training script to skip upsampling the input itself.
# NOTE(review): consumer is not visible in this file — confirm against caller.
upsample = True
def build_model(x, scale, training, reuse):
    """Build a super-resolution ResNet with a shared, stochastic-depth HR trunk.

    Graph layout:
      1. 'in': 1x1 conv projecting ``x`` to ``hidden_size`` channels.
      2. Five low-resolution residual blocks ('lr_conv0'..'lr_conv4'), each
         with its own weights, followed by a ReLU.
      3. 'up': conv2d_transpose with kernel size and stride both ``scale``.
      4. Five high-resolution residual blocks that all share ONE set of
         weights ('hr_conv_share'), regularised with stochastic depth,
         followed by a ReLU.
      5. 'out': 1x1 conv to 3 channels.

    Each residual branch (``conv``) contains an unpadded 3x3 convolution, so
    the shortcut is passed through ``util.crop_by_pixel(x, 1)`` — presumably
    to trim the one-pixel border the branch loses (TODO confirm in util).

    Args:
      x: input feature/image tensor (channels-last is the tf.layers default;
        NOTE(review): confirm the caller's layout).
      scale: integer upsampling factor used for the 'up' layer.
      training: Python bool (it drives Python ``if``; a tensor won't work).
        When true, each HR block's residual is kept with probability
        ``survival_rate`` per session run; when false, the residual is scaled
        by ``survival_rate`` instead.
      reuse: forwarded to every layer. Pass True to share all variables with
        an earlier call of this function.

    Returns:
      Output tensor with 3 channels.
    """
    hidden_size = 128
    bottleneck_size = 32
    survival_rate = 0.5
    survival_rate = tf.constant(survival_rate, name='survival_rate')

    x = tf.layers.conv2d(
        x, hidden_size, 1, activation=None, name='in', reuse=reuse)

    # Low-resolution trunk: every block owns its own weights.
    for i in range(5):
        x = util.crop_by_pixel(x, 1) + conv(
            x, hidden_size, bottleneck_size, training, 'lr_conv' + str(i),
            reuse)
    x = tf.nn.relu(x)

    x = tf.layers.conv2d_transpose(
        x,
        hidden_size,
        scale,
        strides=scale,
        activation=None,
        name='up',
        reuse=reuse)
    # Debug output of the static upsampled shape. Call form keeps the module
    # parseable under Python 3 and prints identically under Python 2.
    print(x.get_shape().as_list())

    # High-resolution trunk: all five blocks share the 'hr_conv_share'
    # variables.
    for i in range(5):
        shortcut = util.crop_by_pixel(x, 1)
        # The first application follows the caller's `reuse` (so a second
        # build_model(..., reuse=True) reuses the shared weights instead of
        # re-creating them); later applications always reuse.
        resblock = conv(
            x,
            hidden_size,
            bottleneck_size,
            training,
            'hr_conv_share',
            reuse=reuse if i == 0 else True)
        if training:
            # Stochastic depth: one scalar draw per run decides whether the
            # residual is applied; a dropped block is an identity shortcut.
            survival_roll = tf.random_uniform(
                shape=[], minval=0.0, maxval=1.0, name='suvival' + str(i))
            survive = tf.less(survival_roll, survival_rate)
            x = tf.cond(survive,
                        lambda: tf.add(shortcut, resblock),
                        lambda: shortcut)
        else:
            # Inference: scale the residual by its survival probability to
            # match its expected contribution during training.
            # (tf.mul was removed in TF 1.0; tf.multiply is its replacement.)
            x = tf.add(tf.multiply(resblock, survival_rate), shortcut)
    x = tf.nn.relu(x)

    x = tf.layers.conv2d(x, 3, 1, activation=None, name='out', reuse=reuse)
    return x
def conv(x, hidden_size, bottleneck_size, training, name, reuse):
    """Residual branch of a pre-activation bottleneck block.

    Runs three ReLU -> conv2d stages (the convs themselves have no
    activation):
      '<name>_proj': 1x1 conv to ``bottleneck_size`` channels,
      '<name>_filt': 3x3 conv to ``bottleneck_size`` channels,
      '<name>_recv': 1x1 conv back to ``hidden_size`` channels.
    No padding argument is passed, so the tf.layers.conv2d default ('valid')
    applies and the 3x3 stage trims a one-pixel border from each spatial edge.

    Args:
      x: input feature tensor.
      hidden_size: channel count of the returned tensor.
      bottleneck_size: channel count inside the bottleneck.
      training: not used by this function; kept for a uniform signature.
      name: prefix for the three conv layer names.
      reuse: forwarded unchanged to every conv layer.

    Returns:
      The branch output tensor, with no trailing activation.
    """
    # (filters, kernel_size, layer-name suffix) for each stage, in order.
    stages = (
        (bottleneck_size, 1, '_proj'),
        (bottleneck_size, 3, '_filt'),
        (hidden_size, 1, '_recv'),
    )
    out = x
    for filters, kernel_size, suffix in stages:
        activated = tf.nn.relu(out)
        out = tf.layers.conv2d(
            activated,
            filters,
            kernel_size,
            activation=None,
            name=name + suffix,
            reuse=reuse)
    return out