
Commit 6eb0beb

warning about default relu
ppwwyyxx committed Oct 17, 2016
1 parent: dc59ad5
Showing 11 changed files with 48 additions and 39 deletions.
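
This commit removes the implicit nl=tf.nn.relu default from the Conv2D and FullyConnected layers: a call that omits nl now logs a deprecation warning and then falls back to ReLU, and every example is updated to name its nonlinearity explicitly, usually once per model via argscope. A minimal sketch of the migration, assuming the tensorpack argscope/LinearWrap API shown in the diffs below (layer names are illustrative):

    # Before: relied on the implicit default nl=tf.nn.relu (now warns).
    l = Conv2D('conv0', image, out_channel=32, kernel_shape=3)

    # After: state the nonlinearity once; it applies to every Conv2D
    # call inside the argscope block.
    with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu):
        l = Conv2D('conv0', image, out_channel=32)
        l = Conv2D('conv1', l, out_channel=64)
    # Output layers keep an explicit identity, as in the examples below.
    logits = FullyConnected('fc', l, out_dim=10, nl=tf.identity)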
examples/DisturbLabel/mnist-disturb.py (2 additions, 2 deletions)

@@ -29,13 +29,13 @@ def _build_graph(self, input_vars):
         image, label = input_vars
         image = tf.expand_dims(image, 3)
 
-        with argscope(Conv2D, kernel_shape=5):
+        with argscope(Conv2D, kernel_shape=5, nl=tf.nn.relu):
             logits = (LinearWrap(image)  # the starting brace is only for line-breaking
                       .Conv2D('conv0', out_channel=32, padding='VALID')
                       .MaxPooling('pool0', 2)
                       .Conv2D('conv1', out_channel=64, padding='VALID')
                       .MaxPooling('pool1', 2)
-                      .FullyConnected('fc0', 512)
+                      .FullyConnected('fc0', 512, nl=tf.nn.relu)
                       .FullyConnected('fc1', out_dim=10, nl=tf.identity)())
         prob = tf.nn.softmax(logits, name='prob')
examples/HED/hed.py (1 addition, 1 deletion)

@@ -34,7 +34,7 @@ def branch(name, l, up):
                 up = up / 2
             return l
 
-        with argscope(Conv2D, kernel_shape=3):
+        with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu):
             l = Conv2D('conv1_1', image, 64)
             l = Conv2D('conv1_2', l, 64)
             b1 = branch('branch1', l, 1)
examples/Inception/inception-bn.py (2 additions, 2 deletions)

@@ -71,7 +71,7 @@ def inception(name, x, nr1x1, nr3x3r, nr3x3, nr233r, nr233, nrpool, pooltype):
         l = inception('incep3c', l, 0, 128, 160, 64, 96, 0, 'max')
 
         br1 = Conv2D('loss1conv', l, 128, 1)
-        br1 = FullyConnected('loss1fc', br1, 1024)
+        br1 = FullyConnected('loss1fc', br1, 1024, nl=tf.nn.relu)
         br1 = FullyConnected('loss1logit', br1, 1000, nl=tf.identity)
         loss1 = tf.nn.sparse_softmax_cross_entropy_with_logits(br1, label)
         loss1 = tf.reduce_mean(loss1, name='loss1')

@@ -84,7 +84,7 @@ def inception(name, x, nr1x1, nr3x3r, nr3x3, nr233r, nr233, nrpool, pooltype):
         l = inception('incep4e', l, 0, 128, 192, 192, 256, 0, 'max')
 
         br2 = Conv2D('loss2conv', l, 128, 1)
-        br2 = FullyConnected('loss2fc', br2, 1024)
+        br2 = FullyConnected('loss2fc', br2, 1024, nl=tf.nn.relu)
         br2 = FullyConnected('loss2logit', br2, 1000, nl=tf.identity)
         loss2 = tf.nn.sparse_softmax_cross_entropy_with_logits(br2, label)
         loss2 = tf.reduce_mean(loss2, name='loss2')
examples/cifar-convnet.py (2 additions, 2 deletions)

@@ -51,10 +51,10 @@ def _build_graph(self, input_vars):
             .MaxPooling('pool2', 3, stride=2, padding='SAME') \
             .Conv2D('conv3.1', out_channel=128, padding='VALID') \
             .Conv2D('conv3.2', out_channel=128, padding='VALID') \
-            .FullyConnected('fc0', 1024 + 512,
+            .FullyConnected('fc0', 1024 + 512, nl=tf.nn.relu,
                             b_init=tf.constant_initializer(0.1)) \
             .tf.nn.dropout(keep_prob) \
-            .FullyConnected('fc1', 512,
+            .FullyConnected('fc1', 512, nl=tf.nn.relu,
                             b_init=tf.constant_initializer(0.1)) \
             .FullyConnected('linear', out_dim=self.cifar_classnum, nl=tf.identity)()
examples/load-alexnet.py (13 additions, 12 deletions)

@@ -29,21 +29,22 @@ def _build_graph(self, inputs):

         image, label = inputs
 
-        l = Conv2D('conv1', image, out_channel=96, kernel_shape=11, stride=4, padding='VALID')
-        l = tf.nn.lrn(l, 2, bias=1.0, alpha=2e-5, beta=0.75, name='norm1')
-        l = MaxPooling('pool1', l, 3, stride=2, padding='VALID')
+        with argscope([Conv2D, FullyConnected], nl=tf.nn.relu):
+            l = Conv2D('conv1', image, out_channel=96, kernel_shape=11, stride=4, padding='VALID')
+            l = tf.nn.lrn(l, 2, bias=1.0, alpha=2e-5, beta=0.75, name='norm1')
+            l = MaxPooling('pool1', l, 3, stride=2, padding='VALID')
 
-        l = Conv2D('conv2', l, out_channel=256, kernel_shape=5, split=2)
-        l = tf.nn.lrn(l, 2, bias=1.0, alpha=2e-5, beta=0.75, name='norm2')
-        l = MaxPooling('pool2', l, 3, stride=2, padding='VALID')
+            l = Conv2D('conv2', l, out_channel=256, kernel_shape=5, split=2)
+            l = tf.nn.lrn(l, 2, bias=1.0, alpha=2e-5, beta=0.75, name='norm2')
+            l = MaxPooling('pool2', l, 3, stride=2, padding='VALID')
 
-        l = Conv2D('conv3', l, out_channel=384, kernel_shape=3)
-        l = Conv2D('conv4', l, out_channel=384, kernel_shape=3, split=2)
-        l = Conv2D('conv5', l, out_channel=256, kernel_shape=3, split=2)
-        l = MaxPooling('pool3', l, 3, stride=2, padding='VALID')
+            l = Conv2D('conv3', l, out_channel=384, kernel_shape=3)
+            l = Conv2D('conv4', l, out_channel=384, kernel_shape=3, split=2)
+            l = Conv2D('conv5', l, out_channel=256, kernel_shape=3, split=2)
+            l = MaxPooling('pool3', l, 3, stride=2, padding='VALID')
 
-        l = FullyConnected('fc6', l, 4096)
-        l = FullyConnected('fc7', l, out_dim=4096)
+            l = FullyConnected('fc6', l, 4096)
+            l = FullyConnected('fc7', l, out_dim=4096)
         # fc will have activation summary by default. disable this for the output layer
         logits = FullyConnected('fc8', l, out_dim=1000, nl=tf.identity)
         prob = tf.nn.softmax(logits, name='output')
examples/load-vgg16.py (3 additions, 4 deletions)

@@ -36,7 +36,7 @@ def _build_graph(self, inputs, is_training):

         image, label = inputs
 
-        with argscope(Conv2D, kernel_shape=3):
+        with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu):
             # 224
             logits = (LinearWrap(image)
                       .Conv2D('conv1_1', 64)

@@ -62,10 +62,9 @@ def _build_graph(self, inputs, is_training):
                       .Conv2D('conv5_3', 512)
                       .MaxPooling('pool5', 2)
                       # 7
-                      .FullyConnected('fc6', 4096)
+                      .FullyConnected('fc6', 4096, nl=tf.nn.relu)
                       .Dropout('drop0', 0.5)
-                      .print_tensor()
-                      .FullyConnected('fc7', 4096)
+                      .FullyConnected('fc7', 4096, nl=tf.nn.relu)
                       .Dropout('drop1', 0.5)
                       .FullyConnected('fc8', out_dim=1000, nl=tf.identity)())
         prob = tf.nn.softmax(logits, name='output')
examples/mnist-convnet.py (1 addition, 1 deletion)

@@ -54,7 +54,7 @@ def _build_graph(self, input_vars):
                  .Conv2D('conv2')
                  .MaxPooling('pool1', 2)
                  .Conv2D('conv3')
-                 .FullyConnected('fc0', 512)
+                 .FullyConnected('fc0', 512, nl=tf.nn.relu)
                  .Dropout('dropout', 0.5)
                  .FullyConnected('fc1', out_dim=10, nl=tf.identity)())
         prob = tf.nn.softmax(logits, name='prob')  # a Bx10 with probabilities
examples/svhn-digit-convnet.py (12 additions, 11 deletions)

@@ -29,17 +29,18 @@ def _build_graph(self, input_vars):

         image = image / 128.0 - 1
 
-        logits = (LinearWrap(image)
-                  .Conv2D('conv1', 24, 5, padding='VALID')
-                  .MaxPooling('pool1', 2, padding='SAME')
-                  .Conv2D('conv2', 32, 3, padding='VALID')
-                  .Conv2D('conv3', 32, 3, padding='VALID')
-                  .MaxPooling('pool2', 2, padding='SAME')
-                  .Conv2D('conv4', 64, 3, padding='VALID')
-                  .Dropout('drop', 0.5)
-                  .FullyConnected('fc0', 512,
-                                  b_init=tf.constant_initializer(0.1))
-                  .FullyConnected('linear', out_dim=10, nl=tf.identity)())
+        with argscope(Conv2D, nl=tf.nn.relu):
+            logits = (LinearWrap(image)
+                      .Conv2D('conv1', 24, 5, padding='VALID')
+                      .MaxPooling('pool1', 2, padding='SAME')
+                      .Conv2D('conv2', 32, 3, padding='VALID')
+                      .Conv2D('conv3', 32, 3, padding='VALID')
+                      .MaxPooling('pool2', 2, padding='SAME')
+                      .Conv2D('conv4', 64, 3, padding='VALID')
+                      .Dropout('drop', 0.5)
+                      .FullyConnected('fc0', 512,
+                                      b_init=tf.constant_initializer(0.1), nl=tf.nn.relu)
+                      .FullyConnected('linear', out_dim=10, nl=tf.identity)())
         prob = tf.nn.softmax(logits, name='output')
 
         # compute the number of failed samples, for ClassificationError to use at test time
tensorpack/models/conv2d.py (6 additions, 2 deletions)

@@ -7,15 +7,15 @@
 import tensorflow as tf
 import math
 from ._common import *
-from ..utils import map_arg
+from ..utils import map_arg, logger
 
 __all__ = ['Conv2D']
 
 @layer_register()
 def Conv2D(x, out_channel, kernel_shape,
            padding='SAME', stride=1,
            W_init=None, b_init=None,
-           nl=tf.nn.relu, split=1, use_bias=True):
+           nl=None, split=1, use_bias=True):
     """
     2D convolution on 4D inputs.

@@ -59,5 +59,9 @@ def Conv2D(x, out_channel, kernel_shape,
     outputs = [tf.nn.conv2d(i, k, stride, padding)
                for i, k in zip(inputs, kernels)]
     conv = tf.concat(3, outputs)
+    if nl is None:
+        logger.warn("[DEPRECATED] Default nonlinearity for Conv2D and FullyConnected will be deprecated.")
+        logger.warn("[DEPRECATED] Please use argscope instead.")
+        nl = tf.nn.relu
     return nl(tf.nn.bias_add(conv, b) if use_bias else conv, name='output')
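
The nl=None default is a sentinel: it lets the layer distinguish "caller omitted the argument" from "caller explicitly chose ReLU", warn, and still preserve the old behavior. A toy sketch of the same pattern outside tensorpack (function and message are illustrative, not library code):

    import warnings

    def relu(v):
        return max(v, 0.0)

    def toy_layer(v, nl=None):
        """Apply a nonlinearity; warn when the caller relies on the old default."""
        if nl is None:
            # Sentinel default: the caller did not pick a nonlinearity.
            warnings.warn("default nonlinearity is deprecated; pass nl explicitly",
                          DeprecationWarning)
            nl = relu  # keep the legacy behavior while the warning is in effect
        return nl(2.0 * v)

    print(toy_layer(-1.5))          # warns, applies relu -> 0.0
    print(toy_layer(-1.5, nl=abs))  # explicit choice, no warning -> 3.0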

tensorpack/models/fc.py (5 additions, 1 deletion)

@@ -14,7 +14,7 @@
 @layer_register()
 def FullyConnected(x, out_dim,
                    W_init=None, b_init=None,
-                   nl=tf.nn.relu, use_bias=True):
+                   nl=None, use_bias=True):
     """
     Fully-Connected layer.

@@ -39,4 +39,8 @@ def FullyConnected(x, out_dim,
     if use_bias:
         b = tf.get_variable('b', [out_dim], initializer=b_init)
     prod = tf.nn.xw_plus_b(x, W, b) if use_bias else tf.matmul(x, W)
+    if nl is None:
+        logger.warn("[DEPRECATED] Default nonlinearity for Conv2D and FullyConnected will be deprecated.")
+        logger.warn("[DEPRECATED] Please use argscope instead.")
+        nl = tf.nn.relu
     return nl(prod, name='output')
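
FullyConnected now follows the same sentinel-and-warn pattern as Conv2D, so the fix at call sites is uniform: declare the nonlinearity once with argscope. As the load-alexnet.py diff above shows, argscope also accepts a list of layer types; a hedged sketch (names are illustrative):

    # nl=tf.nn.relu applies to every Conv2D and FullyConnected in the
    # block, so no call falls back to the deprecated default.
    with argscope([Conv2D, FullyConnected], nl=tf.nn.relu):
        l = Conv2D('conv', image, out_channel=64, kernel_shape=3)
        l = FullyConnected('fc', l, out_dim=256)
    # An explicit nl never warns; output layers use identity.
    logits = FullyConnected('logits', l, out_dim=10, nl=tf.identity)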
tensorpack/models/model_desc.py (1 addition, 1 deletion)

@@ -133,7 +133,7 @@ def build_graph(self, model_inputs):
         :returns: the cost to minimize. a scalar variable
         """
         if len(inspect.getargspec(self._build_graph).args) == 3:
-            logger.warn("_build_graph(self, input_vars, is_training) is deprecated! \
+            logger.warn("[DEPRECATED] _build_graph(self, input_vars, is_training) is deprecated! \
 Use _build_graph(self, input_vars) and get_current_tower_context().is_training instead.")
             self._build_graph(model_inputs, get_current_tower_context().is_training)
         else:
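
The model_desc.py change only tags the existing warning with the same [DEPRECATED] prefix used by the layer warnings. A sketch of the migration that warning asks for, assuming the get_current_tower_context() helper it references (the model body is illustrative):

    # Old, deprecated signature (still dispatched, but warns):
    #     def _build_graph(self, input_vars, is_training): ...
    #
    # New signature: read the training flag from the tower context.
    class Model(ModelDesc):
        def _build_graph(self, input_vars):
            is_training = get_current_tower_context().is_training
            keep_prob = 0.5 if is_training else 1.0  # e.g. gate dropout on it
            ...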
