fix docs, test=develop
ceci3 committed Aug 17, 2020
1 parent 0763302 commit 6e46c44
Showing 1 changed file with 35 additions and 75 deletions.
110 changes: 35 additions & 75 deletions python/paddle/fluid/dygraph/nn.py
@@ -3198,7 +3198,7 @@ class SyncBatchNorm(layers.Layer):
Internal Covariate Shift <https://arxiv.org/pdf/1502.03167.pdf>`_
for more details.
When model in train mode, the :math:`\\mu_{\\beta}`
When the model is in training mode, the :math:`\\mu_{\\beta}`
and :math:`\\sigma_{\\beta}^{2}` are the statistics of the whole mini-batch data across all GPUs.
Calculated as follows:
@@ -3209,12 +3209,12 @@ class SyncBatchNorm(layers.Layer):
\\sigma_{\\beta}^{2} &\\gets \\frac{1}{m} \\sum_{i=1}^{m}(x_i - \\
\\mu_{\\beta})^2 \\qquad &//\ mini-batch\ variance \\\\
- :math:`x` : mini-batch data
- :math:`m` : the size of the mini-batch data
- :math:`x` : the whole mini-batch data across all GPUs
- :math:`m` : the size of the whole mini-batch data
When model in eval mode, the :math:`\\mu_{\\beta}`
and :math:`\\sigma_{\\beta}^{2}` are global or running statistics (moving_mean and moving_variance).
It usually got from the pre-trained model. Calculated as follows:
When the model is in evaluation mode, the :math:`\\mu_{\\beta}`
and :math:`\\sigma_{\\beta}^{2}` are global statistics (moving_mean and moving_variance,
which are usually obtained from a pre-trained model). The global statistics are calculated as follows:
.. math::
moving\_mean = moving\_mean * momentum + \mu_{\beta} * (1. - momentum) \quad &// global mean \\
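The formulas above can be restated as a short NumPy reference. This is only an illustrative sketch and is not part of the commit: it computes the per-channel statistics in a single process, whereas the real sync_batch_norm op aggregates the mean and variance across all cards; the function name batch_norm_train_step and the NCHW layout are assumptions.

import numpy as np

def batch_norm_train_step(x, gamma, beta, moving_mean, moving_var,
                          momentum=0.9, eps=1e-5):
    # x has layout (N, C, H, W); statistics are per channel, over the whole batch.
    mu = x.mean(axis=(0, 2, 3))   # mini-batch mean
    var = x.var(axis=(0, 2, 3))   # mini-batch variance
    x_hat = (x - mu[None, :, None, None]) / np.sqrt(var[None, :, None, None] + eps)
    y = gamma[None, :, None, None] * x_hat + beta[None, :, None, None]
    # moving ("global") statistics, matching the formulas in the docstring
    new_moving_mean = moving_mean * momentum + mu * (1.0 - momentum)
    new_moving_var = moving_var * momentum + var * (1.0 - momentum)
    return y, new_moving_mean, new_moving_var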
@@ -3229,25 +3229,23 @@ class SyncBatchNorm(layers.Layer):
y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift
- :math:`\\epsilon` : a small value added to the variance to prevent division by zero
- :math:`\\gamma` : trainable proportional parameter
- :math:`\\beta` : trainable deviation parameter
**Note**:
moving mean and moving variance will be calculated whether `track_running_stats` is set to `True`
or `False`, we will fix it in the next version.
- :math:`\\gamma` : trainable scale parameter vector
- :math:`\\beta` : trainable shift parameter vector
Parameters:
num_features(int): Indicates the number of channels of the input ``Tensor``.
epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5.
momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
weight_attr(ParamAttr, optional): The parameter attribute for Parameter `scale`
of batch_norm. If it is set to None or one attribute of ParamAttr, batch_norm
weight_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale`
of this layer. If it is set to None or one attribute of ParamAttr, this layer
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr(ParamAttr, optional): The parameter attribute for the bias of batch_norm.
If it is set to None or one attribute of ParamAttr, batch_norm
is not set, the parameter is initialized with Xavier. If it is set to False,
this layer will not have a trainable scale parameter. Default: None.
bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of this layer.
If it is set to None or one attribute of ParamAttr, this layer
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
is not set, the bias is initialized to zero. If it is set to False, this layer will not
have a trainable bias parameter. Default: None.
track_running_stats(bool, optional): Whether to compute global stats, which include running mean and
running variance. Default: True.
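For reference, a minimal usage sketch of the parameters documented above. It assumes a released Paddle 2.x build where this class is exported as paddle.nn.SyncBatchNorm and a GPU install, since the sync_batch_norm op only has a GPU kernel; it is not an example taken from this commit.

import paddle
import paddle.nn as nn

paddle.disable_static()

# 5 input channels; weight_attr=False means the scale (gamma) is created but not trainable.
sbn = nn.SyncBatchNorm(5, epsilon=1e-5, momentum=0.9, weight_attr=False)

x = paddle.randn([2, 5, 8, 8])   # NCHW input whose channel count matches num_features
y = sbn(x)                       # on a single card this behaves like ordinary batch norm
print(y.shape)                   # [2, 5, 8, 8]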
@@ -3289,31 +3287,35 @@ def __init__(self,
self._epsilon = epsilon
self._track_running_stats = track_running_stats

if self._track_running_stats == False:
logging.warn(
"moving mean and moving variance will be calculated whether `track_running_stats` is set to `True` or `False`, we will fix it in the next version."
)

param_shape = [self._num_features]

### TODO(lvmengsi): remove create param when weight_attr=False in python when BatchNorm kernel support
# create parameter
if weight_attr == False:
self.weight = self.create_parameter(
attr=self._weight_attr,
shape=param_shape,
default_initializer=Constant(1.0))
self.weight.stop_gradient = True
else:
self.weight = self.create_parameter(
attr=self._weight_attr,
shape=param_shape,
default_initializer=Constant(1.0))
self.weight.stop_gradient = self._weight_attr != None and self._weight_attr.learning_rate == 0.

if bias_attr == False:
self.bias = self.create_parameter(
attr=self._bias_attr,
shape=param_shape,
default_initializer=Constant(0.0),
is_bias=True)
self.bias.stop_gradient = True

else:
# create parameter
self.weight = self.create_parameter(
attr=self._weight_attr,
shape=param_shape,
default_initializer=Constant(1.0))
self.weight.stop_gradient = self._weight_attr != None and self._weight_attr.learning_rate == 0.

self.bias = self.create_parameter(
attr=self._bias_attr, shape=param_shape, is_bias=True)
self.bias.stop_gradient = self._weight_attr != None and self._weight_attr.learning_rate == 0.
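The branches above encode a freezing rule: weight_attr=False (or bias_attr=False) still creates the parameter but marks it with stop_gradient=True, and a ParamAttr whose learning_rate is 0. has the same effect. A hedged sketch of that rule, again assuming a released Paddle 2.x build rather than the exact code at this commit:

import paddle
import paddle.nn as nn

frozen_by_flag = nn.SyncBatchNorm(4, weight_attr=False)
frozen_by_lr = nn.SyncBatchNorm(4, weight_attr=paddle.ParamAttr(learning_rate=0.0))

# Both scales exist and start at 1.0, but neither should receive gradients.
print(frozen_by_flag.weight.stop_gradient)   # expected True per the rule above
print(frozen_by_lr.weight.stop_gradient)     # expected True per the rule above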
@@ -3338,7 +3340,7 @@ def __init__(self,
dtype=self._dtype)
self._variance.stop_gradient = True

def forward(self, input):
def forward(self, x):
# create output
# mean and mean_out share the same memory
mean_out = self._mean
@@ -3353,13 +3355,13 @@ def forward(self, input):
False, "use_global_stats", not self.training,
'trainable_statistics', False)
sync_batch_norm_out, _, _, _, _, _ = core.ops.sync_batch_norm(
input, self.weight, self.bias, self._mean, self._variance,
mean_out, variance_out, *attrs)
x, self.weight, self.bias, self._mean, self._variance, mean_out,
variance_out, *attrs)

return sync_batch_norm_out

check_variable_and_dtype(input, 'input',
['float16', 'float32', 'float64'], 'BatchNorm')
check_variable_and_dtype(x, 'input', ['float16', 'float32', 'float64'],
'BatchNorm')

attrs = {
"momentum": self._momentum,
@@ -3373,7 +3375,7 @@ def forward(self, input):
}

inputs = {
"X": [input],
"X": [x],
"Scale": [self.weight],
"Bias": [self.bias],
"Mean": [self._mean],
@@ -3399,48 +3401,6 @@ def forward(self, input):
type="sync_batch_norm", inputs=inputs, outputs=outputs, attrs=attrs)
return sync_batch_norm_out
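The "use_global_stats" attribute in forward is driven by self.training, so the layer follows the usual train()/eval() switch: batch statistics (plus moving-average updates) in training, the stored moving statistics in evaluation. A hedged sketch of that switch, assuming a Paddle 2.x GPU build and not taken from this commit:

import paddle
import paddle.nn as nn

sbn = nn.SyncBatchNorm(3)
x = paddle.randn([4, 3, 16, 16])

sbn.train()
y_train = sbn(x)   # normalizes with batch statistics and updates the moving stats

sbn.eval()
y_eval = sbn(x)    # normalizes with moving_mean / moving_variance instead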

### TODO: remove comment after BatchNorm merged.
#@classmethod
#def convert_sync_batchnorm(cls, layer):
# """
# Helper function to convert :class: `paddle.nn.BatchNorm` in the model to :class: `paddle.nn.SyncBatchNorm` layers.

# Parameters:
# layer(paddle.fluid.dygraph.Layer): layer containing one or more `BatchNorm` layers.

# Returns:
# A new SyncBatchNorm layer object if origin layer is BatchNorm Layer.

# Examples:

# .. code-block:: python
# import paddle
# import paddle.nn as nn

# paddle.disable_static()
# model = nn.Sequential(nn.Conv2D(3, 5, 3), nn.BatchNorm(5))
# sync_model = nn.SyncBatchNorm.convert(model)

# """
# layer_output = layer
# if isinstance(layer, BatchNorm):
# layer_output = SyncBatchNorm(layer._num_features,
# layer._epsilon, layer._momentum,
# layer._weight_attr, layer._bias_attr
# layer._data_layout)

# if layer._weight_attr != False and layer._bias_attr != False:
# with no_grad():
# layer_output.weight = layer.weight
# layer_output.bias = layer.bias
# layer_output._mean = layer._mean
# layer_output._variance = layer._variance

# for name, sublayer in layer.named_sublayer():
# layer_output.add_sublayer(name, cls.convert_sync_batchnorm(sublayer))
# del layer
# return layer_output


class Flatten(layers.Layer):
"""
