From f311d3c1cfa986c6742a8e2e35759772d3a3df84 Mon Sep 17 00:00:00 2001 From: joejiong Date: Thu, 27 Aug 2020 19:51:25 +0800 Subject: [PATCH 1/8] Fix pow api type error with python side method, merge elementwise_pow and pow. (#26163) As the title --- .../elementwise/elementwise_pow_op.h | 17 +- python/paddle/__init__.py | 0 .../tests/unittests/test_activation_op.py | 0 .../paddle/fluid/tests/unittests/test_pow.py | 239 ++++++++++++++++++ python/paddle/tensor/__init__.py | 0 python/paddle/tensor/math.py | 123 +++++---- 6 files changed, 326 insertions(+), 53 deletions(-) mode change 100644 => 100755 paddle/fluid/operators/elementwise/elementwise_pow_op.h mode change 100644 => 100755 python/paddle/__init__.py mode change 100644 => 100755 python/paddle/fluid/tests/unittests/test_activation_op.py create mode 100755 python/paddle/fluid/tests/unittests/test_pow.py mode change 100644 => 100755 python/paddle/tensor/__init__.py diff --git a/paddle/fluid/operators/elementwise/elementwise_pow_op.h b/paddle/fluid/operators/elementwise/elementwise_pow_op.h old mode 100644 new mode 100755 index ff55d2f2040a1..a910c326196bc --- a/paddle/fluid/operators/elementwise/elementwise_pow_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_pow_op.h @@ -22,15 +22,20 @@ namespace operators { template struct PowFunctor { inline HOSTDEVICE T operator()(T a, T b) const { -#ifdef __CUDA_ARCH__ - // On CUDAPlace, std::pow(3, 1) calls pow(float, float), and - // it will return a float number like 2.99... , which floor to 2 - // when cast to int by default and it is wrong. - // Use llrint to cast it to the nearest integer, which is 3. + // TODO(wujionghao): A potential speed improvement is supporting different + // types in C++. + // #ifdef __CUDA_ARCH__ + // // On CUDAPlace, std::pow(3, 1) calls pow(float, float), and + // // it will return a float number like 2.99... , which floor to 2 + // // when cast to int by default and it is wrong. + // // Use llrint to cast it to the nearest integer, which is 3. + // if (std::is_integral::value) { + // return std::llrint(std::pow(a, b)); + // } + // #endif if (std::is_integral::value) { return std::llrint(std::pow(a, b)); } -#endif return std::pow(a, b); } }; diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py old mode 100644 new mode 100755 diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py old mode 100644 new mode 100755 diff --git a/python/paddle/fluid/tests/unittests/test_pow.py b/python/paddle/fluid/tests/unittests/test_pow.py new file mode 100755 index 0000000000000..0764cb580e40d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_pow.py @@ -0,0 +1,239 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
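# The tests in this new file exercise the unified ``paddle.pow`` from this
# patch, which accepts either a Python scalar or a Tensor as the exponent and
# routes to the ``pow`` or ``elementwise_pow`` op accordingly. A minimal
# usage sketch, mirroring the docstring example added to
# python/paddle/tensor/math.py later in this patch (illustrative only):
#
#     import paddle
#     import numpy as np
#
#     paddle.disable_static()
#     x = paddle.to_tensor(np.array([1, 2, 3], dtype='float32'))
#     print(paddle.pow(x, 2).numpy())    # scalar exponent -> [1. 4. 9.]
#     y = paddle.to_tensor(np.array([2, 2, 2], dtype='float32'))
#     print(paddle.pow(x, y).numpy())    # tensor exponent -> [1. 4. 9.]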
+ +from __future__ import print_function +import paddle +import paddle.tensor as tensor +import paddle.fluid as fluid +from paddle.static import Program, program_guard +import numpy as np +import unittest + +DYNAMIC = 1 +STATIC = 2 + + +def _run_power(mode, x, y): + # dynamic mode + if mode == DYNAMIC: + paddle.disable_static() + # y is scalar + if isinstance(y, (int, float)): + x_ = paddle.to_tensor(x) + y_ = y + res = paddle.pow(x_, y_) + return res.numpy() + # y is tensor + else: + x_ = paddle.to_tensor(x) + y_ = paddle.to_tensor(y) + res = paddle.pow(x_, y_) + return res.numpy() + # static mode + elif mode == STATIC: + paddle.enable_static() + # y is scalar + if isinstance(y, (int, float)): + with program_guard(Program(), Program()): + x_ = paddle.static.data(name="x", shape=x.shape, dtype=x.dtype) + y_ = y + res = paddle.pow(x_, y_) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + outs = exe.run(feed={'x': x}, fetch_list=[res]) + return outs[0] + # y is tensor + else: + with program_guard(Program(), Program()): + x_ = paddle.static.data(name="x", shape=x.shape, dtype=x.dtype) + y_ = paddle.static.data(name="y", shape=y.shape, dtype=y.dtype) + res = paddle.pow(x_, y_) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + outs = exe.run(feed={'x': x, 'y': y}, fetch_list=[res]) + return outs[0] + + +class TestPowerAPI(unittest.TestCase): + """TestPowerAPI.""" + + def test_power(self): + """test_power.""" + np.random.seed(7) + # test 1-d float tensor ** float scalar + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.float64) + y = np.random.rand() * 10 + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d float tensor ** int scalar + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.float64) + y = int(np.random.rand() * 10) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + x = (np.random.rand(*dims) * 10).astype(np.int64) + y = int(np.random.rand() * 10) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d float tensor ** 1-d float tensor + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.float64) + y = (np.random.rand(*dims) * 10).astype(np.float64) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d float tensor ** 1-d int tensor + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.float64) + y = (np.random.rand(*dims) * 10).astype(np.int64) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d int tensor ** 1-d float tensor + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.int64) + y = (np.random.rand(*dims) * 10).astype(np.float64) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d int tensor ** 1-d int tensor + dims = 
(np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.int64) + y = (np.random.rand(*dims) * 10).astype(np.int64) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d int tensor ** 1-d int tensor + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.int32) + y = (np.random.rand(*dims) * 10).astype(np.int32) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d int tensor ** 1-d int tensor + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.int64) + y = (np.random.rand(*dims) * 10).astype(np.int32) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d int tensor ** 1-d int tensor + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.int32) + y = (np.random.rand(*dims) * 10).astype(np.int64) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d int tensor ** 1-d int tensor + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.float32) + y = (np.random.rand(*dims) * 10).astype(np.float32) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d int tensor ** 1-d int tensor + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.float64) + y = (np.random.rand(*dims) * 10).astype(np.float32) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d int tensor ** 1-d int tensor + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.float64) + y = (np.random.rand(*dims) * 10).astype(np.int32) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test 1-d int tensor ** 1-d int tensor + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.float32) + y = (np.random.rand(*dims) * 10).astype(np.int64) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + # test broadcast + dims = (np.random.randint(1, 10), np.random.randint(5, 10), + np.random.randint(5, 10)) + x = (np.random.rand(*dims) * 10).astype(np.float64) + y = (np.random.rand(dims[-1]) * 10).astype(np.float64) + res = _run_power(DYNAMIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + res = _run_power(STATIC, x, y) + self.assertTrue(np.allclose(res, np.power(x, y))) + + +class TestPowerError(unittest.TestCase): + """TestPowerError.""" + + def test_errors(self): + """test_errors.""" + np.random.seed(7) + + # test dynamic computation graph: inputs must be broadcastable + dims = (np.random.randint(1, 10), np.random.randint(5, 10), + np.random.randint(5, 10)) + x = 
(np.random.rand(*dims) * 10).astype(np.float64) + y = (np.random.rand(dims[-1] + 1) * 10).astype(np.float64) + self.assertRaises(fluid.core.EnforceNotMet, _run_power, DYNAMIC, x, y) + self.assertRaises(fluid.core.EnforceNotMet, _run_power, STATIC, x, y) + + # test dynamic computation graph: inputs must be broadcastable + dims = (np.random.randint(1, 10), np.random.randint(5, 10), + np.random.randint(5, 10)) + x = (np.random.rand(*dims) * 10).astype(np.float64) + y = (np.random.rand(dims[-1] + 1) * 10).astype(np.int8) + self.assertRaises(TypeError, paddle.pow, x, y) + + # test 1-d float tensor ** int string + dims = (np.random.randint(200, 300), ) + x = (np.random.rand(*dims) * 10).astype(np.float64) + y = int(np.random.rand() * 10) + self.assertRaises(TypeError, paddle.pow, x, str(y)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py old mode 100644 new mode 100755 diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 9dfb31a5ac25b..e0317f4faceed 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -17,6 +17,8 @@ from __future__ import print_function from paddle.common_ops_import import * +from paddle.tensor import cast +import paddle from ..fluid import layers from ..fluid.framework import core, _varbase_creator, in_dygraph_mode, Variable from ..fluid.layer_helper import LayerHelper @@ -64,6 +66,7 @@ from ..fluid import layers import paddle + __all__ = [ 'abs', 'acos', @@ -86,8 +89,8 @@ 'logsumexp', 'mul', 'multiplex', - 'prod', 'pow', + 'prod', 'reciprocal', 'reduce_max', 'reduce_min', @@ -147,64 +150,87 @@ VarDesc.VarType.FP64, ] -@templatedoc() -def pow(input, exponent, name=None): +def pow(x, y, name=None): """ - :alias_main: paddle.pow - :alias: paddle.pow,paddle.tensor.pow,paddle.tensor.math.pow + Compute the power of tensor elements. The equation is: - This is Pow Activation Operator. + .. math:: + out = x^{y} - :math:`out = input^{exponent}` + **Note**: + ``paddle.pow`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting` . - Args: - input(Variable): A ``Tensor`` or ``LoDTensor`` . The data type is ``float32`` or ``float64``. - exponent(float32|Variable): A scalar with type ``float32`` or a ``Tensor`` with shape [1] and type ``float32``. - name(str, optional): The default value is None. Normally there is no need for user to set this property. - For more information, please refer to :ref:`api_guide_Name` . + Args: + x (Tensor): An N-D Tensor, the data type is float32, float64, int32 or int64. + y (Tensor): An N-D Tensor with type float32, float64, int32 or int64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + Returns: - Variable: A ``Tensor`` or ``LoDTensor``. The data type is same as ``input``. + N-D Tensor. A location into which the result is stored. Its dimension equals with $x$. Examples: - .. code-block:: python + .. 
code-block:: python import paddle - import paddle.fluid as fluid - - x = fluid.data(name="x", shape=[32,32], dtype="float32") + import numpy as np - # example 1: argument exponent is float - y_1 = paddle.pow(x, 2.0) - # y_1 is x^{2.0} + paddle.disable_static() + + # example 1: y is a float + x_data = np.array([1, 2, 3]) + y = 2 + x = paddle.to_tensor(x_data) + res = paddle.pow(x, y) + print(res.numpy()) # [1 4 9] + + # example 2: y is a Tensor + y = paddle.fill_constant(shape=[1], value=2, dtype='float32') + res = paddle.pow(x, y) + print(res.numpy()) # [1 4 9] - # example 2: argument exponent is Variable - exponent_tensor = fluid.layers.fill_constant([1], "float32", 3.0) - y_2 = paddle.pow(x, exponent_tensor) - # y_2 is x^{3.0} """ + # in dynamic graph mode if in_dygraph_mode(): - return core.ops.pow(input, "exponent", exponent) - - helper = LayerHelper('pow', **locals()) - inputs = {'X': input} - attrs = {} - if isinstance(exponent, Variable): - exponent.stop_gradient = True - inputs['FactorTensor'] = exponent + if isinstance(y, (int, float)): + return core.ops.pow(x, 'factor', y) + elif isinstance(y, (paddle.Tensor, Variable)): + + if x.dtype != y.dtype: + y = cast(y, dtype='float64') + x = cast(x, dtype='float64') + out_dygraph = _elementwise_op_in_dygraph( + x, y, axis=-1, act=None, op_name='elementwise_pow') + return out_dygraph + + return _elementwise_op_in_dygraph( + x, y, axis=-1, act=None, op_name='elementwise_pow') + else: + raise TypeError('y must be scalar or tensor type, but received: %s '% (y.dtype)) + # in static graph mode else: - attrs['factor'] = exponent - - out = helper.create_variable_for_type_inference(dtype=input.dtype) - check_dtype( - out.dtype, out.name, - convert_dtype(input.dtype), 'pow', - '(The out data type in pow must be the same with input data type.)') + if isinstance(y, (int, float)): + helper = LayerHelper('pow', **locals()) + inputs = {'X': x} + attrs = {'factor': y} + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='pow', inputs=inputs, outputs={'Out': out}, attrs=attrs) + return out + elif isinstance(y, (paddle.Tensor, Variable)): + # TODO A potential speed improvement is supporting different types in C++ and removing the cast ops here + helper = LayerHelper('elementwise_pow', **locals()) + if x.dtype != y.dtype: + y = cast(y, dtype='float64') + x = cast(x, dtype='float64') + out = helper.create_variable_for_type_inference(dtype=x.dtype) + else: + out = helper.create_variable_for_type_inference(dtype=x.dtype) + return _elementwise_op(LayerHelper('elementwise_pow', **locals())) + else: + raise TypeError('y must be scalar or tensor type, but received: %s '% (type(y))) - helper.append_op( - type='pow', inputs=inputs, outputs={'Out': out}, attrs=attrs) - return out @dygraph_only @@ -227,6 +253,8 @@ def _elementwise_op(helper): x = helper.kwargs.get('x', None) y = helper.kwargs.get('y', None) + out = helper.kwargs.get('out', None) + assert x is not None, 'x cannot be None in {}'.format(original_op_type) assert y is not None, 'y cannot be None in {}'.format(original_op_type) check_variable_and_dtype( @@ -239,11 +267,12 @@ def _elementwise_op(helper): axis = helper.kwargs.get('axis', -1) use_mkldnn = helper.kwargs.get('use_mkldnn', False) name = helper.kwargs.get('name', None) - if name is None: - out = helper.create_variable_for_type_inference(dtype=x.dtype) - else: - out = helper.create_variable( - name=name, dtype=x.dtype, persistable=False) + + if out is None: + if name is None: + out = 
helper.create_variable_for_type_inference(dtype=x.dtype) + else: + out = helper.create_variable(name=name, dtype=x.dtype, persistable=False) helper.append_op( type=op_type, From f4083010a756f0a9554fdfef26fea262e8978f91 Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Thu, 27 Aug 2020 20:52:18 +0800 Subject: [PATCH 2/8] Add unified RNN APIs (#26588) * Add RNN related apis in paddl.nn test=develop * new rnn api, cell almost done * add new progresses in rnn APIs for 2.0 * refine rnn APIs and docstrings. * add unittets * disable gpu tests when paddle is not compiled with cuda support * remove unnecessary imports * fix docstring * add to no_sample wlist * backport to python2 to avoid yield from * add **kwargs, fix typos * update docstrings for birnn * rename argument for SimpleRNN and SimpleRNNCell, fix sample code * add default value for initial_states in fluid.layers.birnn Co-authored-by: guosheng --- python/paddle/fluid/layers/rnn.py | 288 +++- .../fluid/tests/unittests/CMakeLists.txt | 1 + .../fluid/tests/unittests/rnn/CMakeLists.txt | 6 + .../fluid/tests/unittests/rnn/__init__.py | 13 + .../fluid/tests/unittests/rnn/convert.py | 51 + .../fluid/tests/unittests/rnn/rnn_numpy.py | 516 +++++++ .../tests/unittests/rnn/test_rnn_cells.py | 166 ++ .../unittests/rnn/test_rnn_cells_static.py | 326 ++++ .../tests/unittests/rnn/test_rnn_nets.py | 269 ++++ .../unittests/rnn/test_rnn_nets_static.py | 470 ++++++ python/paddle/nn/__init__.py | 3 + python/paddle/nn/functional/__init__.py | 2 + python/paddle/nn/functional/rnn.py | 8 +- python/paddle/nn/layer/__init__.py | 2 + python/paddle/nn/layer/rnn.py | 1331 ++++++++++++++++- tools/wlist.json | 15 +- 16 files changed, 3391 insertions(+), 76 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/rnn/CMakeLists.txt create mode 100644 python/paddle/fluid/tests/unittests/rnn/__init__.py create mode 100644 python/paddle/fluid/tests/unittests/rnn/convert.py create mode 100644 python/paddle/fluid/tests/unittests/rnn/rnn_numpy.py create mode 100644 python/paddle/fluid/tests/unittests/rnn/test_rnn_cells.py create mode 100644 python/paddle/fluid/tests/unittests/rnn/test_rnn_cells_static.py create mode 100644 python/paddle/fluid/tests/unittests/rnn/test_rnn_nets.py create mode 100644 python/paddle/fluid/tests/unittests/rnn/test_rnn_nets_static.py diff --git a/python/paddle/fluid/layers/rnn.py b/python/paddle/fluid/layers/rnn.py index bc1368b562d7b..fe8ed83923e88 100644 --- a/python/paddle/fluid/layers/rnn.py +++ b/python/paddle/fluid/layers/rnn.py @@ -38,6 +38,7 @@ 'Decoder', 'BeamSearchDecoder', 'rnn', + 'birnn', 'dynamic_decode', 'DecodeHelper', 'TrainingHelper', @@ -438,61 +439,146 @@ def rnn(cell, is_reverse=False, **kwargs): """ - :api_attr: Static Graph - rnn creates a recurrent neural network specified by RNNCell `cell`, - which performs :code:`cell.call()` repeatedly until reaches to the maximum - length of `inputs`. - - Parameters: - cell(RNNCell): An instance of `RNNCell`. - inputs(Variable): A (possibly nested structure of) tensor variable[s]. - The shape of tensor should be `[batch_size, sequence_length, ...]` - for `time_major == False` or `[sequence_length, batch_size, ...]` - for `time_major == True`. It represents the inputs to be unrolled - in RNN. - initial_states(Variable, optional): A (possibly nested structure of) - tensor variable[s], representing the initial state for RNN. - If not provided, `cell.get_initial_states` would be used to produce - the initial state. Default None. 
- sequence_length(Variable, optional): A tensor with shape `[batch_size]`. - It stores real length of each instance, thus enables users to extract - the last valid state when past a batch element's sequence length for - correctness. If not provided, the paddings would be treated same as - non-padding inputs. Default None. - time_major(bool, optional): Indicate the data layout of Tensor included - in `input` and `output` tensors. If `False`, the data layout would - be batch major with shape `[batch_size, sequence_length, ...]`. If - `True`, the data layout would be time major with shape - `[sequence_length, batch_size, ...]`. Default: `False`. - is_reverse(bool, optional): Indicate whether to calculate in the reverse - order of input sequences. Default: `False`. - **kwargs: Additional keyword arguments. Arguments passed to `cell.call`. + which performs :code:`cell.call()` (for dygraph mode :code:`cell.forward`) + repeatedly until reaches to the maximum length of `inputs`. + + Arguments: + cell(RNNCellBase): An instance of `RNNCellBase`. + inputs(Tensor): the input sequences. + If time_major is True, the shape is + `[time_steps, batch_size, input_size]` + else the shape is `[batch_size, time_steps, input_size]`. + initial_states(Tensor|tuple|list, optional): the initial state of the + rnn cell. Tensor or a possibly nested structure of tensors. If not + provided, `cell.get_initial_states` would be called to produce + the initial state. Defaults to None. + sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64 + or int32. The valid lengths of input sequences. Defaults to None. + If `sequence_length` is not None, the inputs are treated as + padded sequences. In each input sequence, elements whose time step + index are not less than the valid length are treated as paddings. + time_major (bool): Whether the first dimension of the input means the + time steps. Defaults to False. + is_reverse (bool, optional): Indicate whether to calculate in the reverse + order of input sequences. Defaults to False. + **kwargs: Additional keyword arguments to pass to `forward` of the cell. Returns: - tuple: A tuple( :code:`(final_outputs, final_states)` ) including the final \ - outputs and states, both are Tensor or nested structure of Tensor. \ - `final_outputs` has the same structure and data types as \ - the returned `outputs` of :code:`cell.call` , and each Tenser in `final_outputs` \ - stacks all time steps' counterpart in `outputs` thus has shape `[batch_size, sequence_length, ...]` \ - for `time_major == False` or `[sequence_length, batch_size, ...]` for `time_major == True`. \ - `final_states` is the counterpart at last time step of initial states, \ - thus has the same structure with it and has tensors with same shapes \ - and data types. + (outputs, final_states) + outputs (Tensor|list|tuple): the output sequence. Tensor or nested + structure of Tensors. + If `time_major` is True, the shape of each tensor in outpus is + `[time_steps, batch_size, hidden_size]`, else + `[batch_size, time_steps, hidden_size]`. + final_states (Tensor|list|tuple): final states. A (possibly nested structure of) + tensor[s], representing the final state for RNN. It has the same + structure of intial state. Each tensor in final states has the same + shape and dtype as the corresponding tensor in initial states. Examples: .. 
code-block:: python - - import paddle.fluid as fluid - inputs = fluid.data(name="inputs", - shape=[-1, 32, 128], - dtype="float32") - cell = fluid.layers.GRUCell(hidden_size=128) - outputs = fluid.layers.rnn(cell=cell, inputs=inputs) + import paddle + paddle.disable_static() + + cell = paddle.nn.SimpleRNNCell(16, 32) + + inputs = paddle.rand((4, 23, 16)) + prev_h = paddle.randn((4, 32)) + outputs, final_states = paddle.nn.functional.rnn(cell, inputs, prev_h) + """ + if in_dygraph_mode(): + return _rnn_dynamic_graph(cell, inputs, initial_states, sequence_length, + time_major, is_reverse, **kwargs) + else: + return _rnn_static_graph(cell, inputs, initial_states, sequence_length, + time_major, is_reverse, **kwargs) + + +class ArrayWrapper(object): + def __init__(self, x): + self.array = [x] + + def append(self, x): + self.array.append(x) + return self + + +def _maybe_copy(state, new_state, step_mask): + """update rnn state or just pass the old state through""" + new_state = nn.elementwise_mul(new_state, step_mask, axis=0) \ + + nn.elementwise_mul(state, (1 - step_mask), axis=0) + return new_state + + +def _transpose_batch_time(x): + perm = [1, 0] + list(range(2, len(x.shape))) + return nn.transpose(x, perm) + + +def _rnn_dynamic_graph(cell, + inputs, + initial_states=None, + sequence_length=None, + time_major=False, + is_reverse=False, + **kwargs): + time_step_index = 0 if time_major else 1 + flat_inputs = flatten(inputs) + time_steps = flat_inputs[0].shape[time_step_index] + + if not time_major: + inputs = map_structure(_transpose_batch_time, inputs) + + if sequence_length is not None: + mask = sequence_lod.sequence_mask( + sequence_length, maxlen=time_steps, dtype=inputs.dtype) + mask = nn.transpose(mask, [1, 0]) + + if is_reverse: + inputs = map_structure(lambda x: tensor.reverse(x, axis=[0]), inputs) + mask = tensor.reverse(mask, axis=[0]) \ + if sequence_length is not None else None + + states = initial_states + outputs = [] + for i in range(time_steps): + step_inputs = map_structure(lambda x: x[i], inputs) + step_outputs, new_states = cell(step_inputs, states, **kwargs) + if sequence_length is not None: + new_states = map_structure( + partial( + _maybe_copy, step_mask=mask[i]), states, new_states) + states = new_states + outputs = map_structure(lambda x: ArrayWrapper(x), + step_outputs) if i == 0 else map_structure( + lambda x, x_array: x_array.append(x), + step_outputs, outputs) + + final_outputs = map_structure( + lambda x: nn.stack(x.array, axis=time_step_index), + outputs) + + if is_reverse: + final_outputs = map_structure( + lambda x: tensor.reverse(x, axis=time_step_index), + final_outputs) + + final_states = new_states + return final_outputs, final_states + + +def _rnn_static_graph(cell, + inputs, + initial_states=None, + sequence_length=None, + time_major=False, + is_reverse=False, + **kwargs): check_type(inputs, 'inputs', (Variable, list, tuple), 'rnn') if isinstance(inputs, (list, tuple)): for i, input_x in enumerate(inputs): @@ -500,30 +586,10 @@ def rnn(cell, ['float32', 'float64'], 'rnn') check_type(initial_states, 'initial_states', (Variable, list, tuple, type(None)), 'rnn') - if isinstance(initial_states, (list, tuple)): - states = map_structure(lambda x: x, initial_states)[0] - for i, state in enumerate(states): - if isinstance(state, (list, tuple)): - for j, state_j in enumerate(state): - check_variable_and_dtype(state_j, 'state_j[' + str(j) + ']', - ['float32', 'float64'], 'rnn') - else: - check_variable_and_dtype(state, 'states[' + str(i) + ']', - ['float32', 
'float64'], 'rnn') check_type(sequence_length, 'sequence_length', (Variable, type(None)), 'rnn') - def _maybe_copy(state, new_state, step_mask): - # TODO: use where_op - new_state = nn.elementwise_mul( - new_state, step_mask, axis=0) - nn.elementwise_mul( - state, (step_mask - 1), axis=0) - return new_state - - def _transpose_batch_time(x): - return nn.transpose(x, [1, 0] + list(range(2, len(x.shape)))) - def _switch_grad(x, stop=False): x.stop_gradient = stop return x @@ -582,6 +648,98 @@ def _switch_grad(x, stop=False): return (final_outputs, final_states) +def birnn(cell_fw, + cell_bw, + inputs, + initial_states=None, + sequence_length=None, + time_major=False, + **kwargs): + """ + birnn creates a bidirectional recurrent neural network specified by + RNNCell `cell_fw` and `cell_bw`, which performs :code:`cell.call()` + (for dygraph mode :code:`cell.forward`) repeatedly until reaches to + the maximum length of `inputs` and then concat the ouputs for both RNNs + along the last axis. + + Arguments: + cell_fw(RNNCellBase): An instance of `RNNCellBase`. + cell_bw(RNNCellBase): An instance of `RNNCellBase`. + inputs(Tensor): the input sequences. + If time_major is True, the shape is + `[time_steps, batch_size, input_size]` + else the shape is `[batch_size, time_steps, input_size]`. + initial_states(tuple, optional): A tuple of initial states of + `cell_fw` and `cell_bw`. + If not provided, `cell.get_initial_states` would be called to + produce initial state for each cell. Defaults to None. + sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64 + or int32. The valid lengths of input sequences. Defaults to None. + If `sequence_length` is not None, the inputs are treated as + padded sequences. In each input sequence, elements whose time step + index are not less than the valid length are treated as paddings. + time_major (bool): Whether the first dimension of the input means the + time steps. Defaults to False. + **kwargs: Additional keyword arguments to pass to `forward` of each cell. + + Returns: + (outputs, final_states) + outputs (Tensor): the outputs of the bidirectional RNN. It is the + concatenation of the outputs from the forward RNN and backward + RNN along the last axis. + If time major is True, the shape is `[time_steps, batch_size, size]`, + else the shape is `[batch_size, time_steps, size]`, where size is + `cell_fw.hidden_size + cell_bw.hidden_size`. + final_states (tuple): A tuple of the final states of the forward + cell and backward cell. + + Examples: + + .. 
code-block:: python + + import paddle + paddle.disable_static() + + cell_fw = paddle.nn.LSTMCell(16, 32) + cell_bw = paddle.nn.LSTMCell(16, 32) + + inputs = paddle.rand((4, 23, 16)) + hf, cf = paddle.rand((4, 32)), paddle.rand((4, 32)) + hb, cb = paddle.rand((4, 32)), paddle.rand((4, 32)) + initial_states = ((hf, cf), (hb, cb)) + outputs, final_states = paddle.nn.functional.birnn( + cell_fw, cell_bw, inputs, initial_states) + + """ + if initial_states is None: + states_fw = cell_fw.get_initial_states( + batch_ref=inputs, batch_dim_idx=1 if time_major else 0) + states_bw = cell_fw.get_initial_states( + batch_ref=inputs, batch_dim_idx=1 if time_major else 0) + else: + states_fw, states_bw = initial_states + outputs_fw, states_fw = rnn(cell_fw, + inputs, + states_fw, + sequence_length, + time_major=time_major, + **kwargs) + + outputs_bw, states_bw = rnn(cell_bw, + inputs, + states_bw, + sequence_length, + time_major=time_major, + is_reverse=True, + **kwargs) + + outputs = map_structure(lambda x, y: tensor.concat([x, y], -1), outputs_fw, + outputs_bw) + + final_states = (states_fw, states_bw) + return outputs, final_states + + class Decoder(object): """ :api_attr: Static Graph diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 7e51ecf07e599..6220bf62c79c3 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -542,6 +542,7 @@ endif() add_subdirectory(sequence) add_subdirectory(dygraph_to_static) +add_subdirectory(rnn) if (WITH_MKLDNN) add_subdirectory(mkldnn) diff --git a/python/paddle/fluid/tests/unittests/rnn/CMakeLists.txt b/python/paddle/fluid/tests/unittests/rnn/CMakeLists.txt new file mode 100644 index 0000000000000..f71e04c09aa38 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/rnn/CMakeLists.txt @@ -0,0 +1,6 @@ +file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") + +foreach(TEST_OP ${TEST_OPS}) + py_test_modules(${TEST_OP} MODULES ${TEST_OP}) +endforeach(TEST_OP) diff --git a/python/paddle/fluid/tests/unittests/rnn/__init__.py b/python/paddle/fluid/tests/unittests/rnn/__init__.py new file mode 100644 index 0000000000000..abf198b97e6e8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/rnn/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/paddle/fluid/tests/unittests/rnn/convert.py b/python/paddle/fluid/tests/unittests/rnn/convert.py new file mode 100644 index 0000000000000..02f10694a4b47 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/rnn/convert.py @@ -0,0 +1,51 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import numpy as np + + +def convert_params_for_cell(np_cell, paddle_cell): + state = np_cell.parameters + for k, v in paddle_cell.named_parameters(): + v.set_value(state[k]) + + +def convert_params_for_cell_static(np_cell, paddle_cell, place): + state = np_cell.parameters + for k, v in paddle_cell.named_parameters(): + scope = paddle.static.global_scope() + tensor = scope.find_var(v.name).get_tensor() + tensor.set(state[k], place) + + +def convert_params_for_net(np_net, paddle_net): + for np_layer, paddle_layer in zip(np_net, paddle_net): + if hasattr(np_layer, "cell"): + convert_params_for_cell(np_layer.cell, paddle_layer.cell) + else: + convert_params_for_cell(np_layer.cell_fw, paddle_layer.cell_fw) + convert_params_for_cell(np_layer.cell_bw, paddle_layer.cell_bw) + + +def convert_params_for_net_static(np_net, paddle_net, place): + for np_layer, paddle_layer in zip(np_net, paddle_net): + if hasattr(np_layer, "cell"): + convert_params_for_cell_static(np_layer.cell, paddle_layer.cell, + place) + else: + convert_params_for_cell_static(np_layer.cell_fw, + paddle_layer.cell_fw, place) + convert_params_for_cell_static(np_layer.cell_bw, + paddle_layer.cell_bw, place) diff --git a/python/paddle/fluid/tests/unittests/rnn/rnn_numpy.py b/python/paddle/fluid/tests/unittests/rnn/rnn_numpy.py new file mode 100644 index 0000000000000..7e0b8374b95cf --- /dev/null +++ b/python/paddle/fluid/tests/unittests/rnn/rnn_numpy.py @@ -0,0 +1,516 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import math + + +class LayerMixin(object): + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + +class LayerListMixin(LayerMixin): + def __init__(self, layers=None): + self._layers = list(layers) if layers else [] + + def append(self, layer): + self._layers.append(layer) + + def __iter__(self): + return iter(self._layers) + + +class SimpleRNNCell(LayerMixin): + def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh"): + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + if nonlinearity == 'tanh': + self.nonlinearity = np.tanh + else: + self.nonlinearity = lambda x: np.maximum(x, 0.) 
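        # Numpy reference for the vanilla RNN cell that paddle.nn.SimpleRNNCell
        # is tested against; forward() below computes
        #     h' = act(x @ W_ih.T + b_ih + h @ W_hh.T + b_hh)
        # with act being tanh or ReLU depending on ``nonlinearity``.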
+ + self.parameters = dict() + std = 1.0 / math.sqrt(hidden_size) + self.weight_ih = np.random.uniform(-std, std, ( + hidden_size, input_size)).astype('float64') + self.weight_hh = np.random.uniform(-std, std, ( + hidden_size, hidden_size)).astype('float64') + self.parameters['weight_ih'] = self.weight_ih + self.parameters['weight_hh'] = self.weight_hh + if bias: + self.bias_ih = np.random.uniform(-std, std, + (hidden_size, )).astype('float64') + self.bias_hh = np.random.uniform(-std, std, + (hidden_size, )).astype('float64') + self.parameters['bias_ih'] = self.bias_ih + self.parameters['bias_hh'] = self.bias_hh + else: + self.bias_ih = None + self.bias_hh = None + + def init_state(self, inputs): + batch_size = inputs.shape[0] + return np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype) + + def forward(self, inputs, hx=None): + if hx is None: + hx = self.init_state(inputs) + pre_h = hx + i2h = np.matmul(inputs, self.weight_ih.T) + if self.bias_ih is not None: + i2h += self.bias_ih + h2h = np.matmul(pre_h, self.weight_hh.T) + if self.bias_hh is not None: + h2h += self.bias_hh + h = self.nonlinearity(i2h + h2h) + return h, h + + +class GRUCell(LayerMixin): + def __init__(self, input_size, hidden_size, bias=True): + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + self.parameters = dict() + std = 1.0 / math.sqrt(hidden_size) + self.weight_ih = np.random.uniform(-std, std, ( + 3 * hidden_size, input_size)).astype('float64') + self.weight_hh = np.random.uniform(-std, std, ( + 3 * hidden_size, hidden_size)).astype('float64') + self.parameters['weight_ih'] = self.weight_ih + self.parameters['weight_hh'] = self.weight_hh + if bias: + self.bias_ih = np.random.uniform(-std, std, ( + 3 * hidden_size)).astype('float64') + self.bias_hh = np.random.uniform(-std, std, ( + 3 * hidden_size)).astype('float64') + self.parameters['bias_ih'] = self.bias_ih + self.parameters['bias_hh'] = self.bias_hh + else: + self.bias_ih = None + self.bias_hh = None + + def init_state(self, inputs): + batch_size = inputs.shape[0] + return np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype) + + def forward(self, inputs, hx=None): + if hx is None: + hx = self.init_state(inputs) + pre_hidden = hx + x_gates = np.matmul(inputs, self.weight_ih.T) + if self.bias_ih is not None: + x_gates = x_gates + self.bias_ih + h_gates = np.matmul(pre_hidden, self.weight_hh.T) + if self.bias_hh is not None: + h_gates = h_gates + self.bias_hh + + x_r, x_z, x_c = np.split(x_gates, 3, 1) + h_r, h_z, h_c = np.split(h_gates, 3, 1) + + r = 1.0 / (1.0 + np.exp(-(x_r + h_r))) + z = 1.0 / (1.0 + np.exp(-(x_z + h_z))) + c = np.tanh(x_c + r * h_c) # apply reset gate after mm + h = (pre_hidden - c) * z + c + return h, h + + +class LSTMCell(LayerMixin): + def __init__(self, input_size, hidden_size, bias=True): + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + self.parameters = dict() + std = 1.0 / math.sqrt(hidden_size) + self.weight_ih = np.random.uniform(-std, std, ( + 4 * hidden_size, input_size)).astype('float64') + self.weight_hh = np.random.uniform(-std, std, ( + 4 * hidden_size, hidden_size)).astype('float64') + self.parameters['weight_ih'] = self.weight_ih + self.parameters['weight_hh'] = self.weight_hh + if bias: + self.bias_ih = np.random.uniform(-std, std, ( + 4 * hidden_size)).astype('float64') + self.bias_hh = np.random.uniform(-std, std, ( + 4 * hidden_size)).astype('float64') + self.parameters['bias_ih'] = self.bias_ih + self.parameters['bias_hh'] = 
self.bias_hh + else: + self.bias_ih = None + self.bias_hh = None + + def init_state(self, inputs): + batch_size = inputs.shape[0] + init_h = np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype) + init_c = np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype) + return init_h, init_c + + def forward(self, inputs, hx=None): + if hx is None: + hx = self.init_state(inputs) + pre_hidden, pre_cell = hx + gates = np.matmul(inputs, self.weight_ih.T) + if self.bias_ih is not None: + gates = gates + self.bias_ih + gates += np.matmul(pre_hidden, self.weight_hh.T) + if self.bias_hh is not None: + gates = gates + self.bias_hh + + chunked_gates = np.split(gates, 4, -1) + + i = 1.0 / (1.0 + np.exp(-chunked_gates[0])) + f = 1.0 / (1.0 + np.exp(-chunked_gates[1])) + o = 1.0 / (1.0 + np.exp(-chunked_gates[3])) + c = f * pre_cell + i * np.tanh(chunked_gates[2]) + h = o * np.tanh(c) + + return h, (h, c) + + +def sequence_mask(lengths, max_len=None): + if max_len is None: + max_len = np.max(lengths) + else: + assert max_len >= np.max(lengths) + return np.arange(max_len) < np.expand_dims(lengths, -1) + + +def update_state(mask, new, old): + if not isinstance(old, (tuple, list)): + return np.where(mask, new, old) + else: + return tuple(map(lambda x, y: np.where(mask, x, y), new, old)) + + +def rnn(cell, + inputs, + initial_states, + sequence_length=None, + time_major=False, + is_reverse=False): + if not time_major: + inputs = np.transpose(inputs, [1, 0, 2]) + if is_reverse: + inputs = np.flip(inputs, 0) + + if sequence_length is None: + mask = None + else: + mask = np.transpose(sequence_mask(sequence_length), [1, 0]) + mask = np.expand_dims(mask, -1) + if is_reverse: + mask = np.flip(mask, 0) + + time_steps = inputs.shape[0] + state = initial_states + outputs = [] + for t in range(time_steps): + x_t = inputs[t] + if mask is not None: + m_t = mask[t] + y, new_state = cell(x_t, state) + y = np.where(m_t, y, 0.) 
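            # For steps past a sequence's valid length the mask is False:
            # the emitted output is zeroed above, and update_state() below
            # keeps the state from the last valid step instead of new_state.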
+ outputs.append(y) + state = update_state(m_t, new_state, state) + else: + y, new_state = cell(x_t, state) + outputs.append(y) + state = new_state + + outputs = np.stack(outputs) + final_state = state + + if is_reverse: + outputs = np.flip(outputs, 0) + if not time_major: + outputs = np.transpose(outputs, [1, 0, 2]) + return outputs, final_state + + +def birnn(cell_fw, + cell_bw, + inputs, + initial_states, + sequence_length=None, + time_major=False): + states_fw, states_bw = initial_states + outputs_fw, states_fw = rnn(cell_fw, + inputs, + states_fw, + sequence_length, + time_major=time_major) + + outputs_bw, states_bw = rnn(cell_bw, + inputs, + states_bw, + sequence_length, + time_major=time_major, + is_reverse=True) + + outputs = np.concatenate((outputs_fw, outputs_bw), -1) + final_states = (states_fw, states_bw) + return outputs, final_states + + +def flatten(nested): + return list(_flatten(nested)) + + +def _flatten(nested): + for item in nested: + if isinstance(item, (list, tuple)): + for subitem in _flatten(item): + yield subitem + else: + yield item + + +def unstack(array, axis=0): + num = array.shape[axis] + sub_arrays = np.split(array, num, axis) + return [np.squeeze(sub_array, axis) for sub_array in sub_arrays] + + +def dropout(array, p=0.5): + if p == 0.0: + return array + + mask = (np.random.uniform(size=array.shape) < (1 - p)).astype(array.dtype) + return array * (mask / (1 - p)) + + +def split_states(states, bidirectional=False, state_components=1): + if state_components == 1: + states = unstack(states) + if not bidirectional: + return states + else: + return list(zip(states[::2], states[1::2])) + else: + assert len(states) == state_components + states = tuple([unstack(item) for item in states]) + if not bidirectional: + return list(zip(*states)) + else: + states = list(zip(*states)) + return list(zip(states[::2], states[1::2])) + + +def concat_states(states, bidirectional=False, state_components=1): + if state_components == 1: + return np.stack(flatten(states)) + else: + states = flatten(states) + componnets = [] + for i in range(state_components): + componnets.append(states[i::state_components]) + return [np.stack(item) for item in componnets] + + +class RNN(LayerMixin): + def __init__(self, cell, is_reverse=False, time_major=False): + super(RNN, self).__init__() + self.cell = cell + if not hasattr(self.cell, "call"): + # for non-dygraph mode, `rnn` api uses cell.call + self.cell.call = self.cell.forward + self.is_reverse = is_reverse + self.time_major = time_major + + def forward(self, inputs, initial_states=None, sequence_length=None): + final_outputs, final_states = rnn(self.cell, + inputs, + initial_states=initial_states, + sequence_length=sequence_length, + time_major=self.time_major, + is_reverse=self.is_reverse) + return final_outputs, final_states + + +class BiRNN(LayerMixin): + def __init__(self, cell_fw, cell_bw, time_major=False): + super(BiRNN, self).__init__() + self.cell_fw = cell_fw + self.cell_bw = cell_bw + self.time_major = time_major + + def forward(self, + inputs, + initial_states=None, + sequence_length=None, + **kwargs): + if isinstance(initial_states, (list, tuple)): + assert len(initial_states) == 2, \ + "length of initial_states should be 2 when it is a list/tuple" + else: + initial_states = [initial_states, initial_states] + + outputs, final_states = birnn(self.cell_fw, self.cell_bw, inputs, + initial_states, sequence_length, + self.time_major) + return outputs, final_states + + +class RNNMixin(LayerListMixin): + def forward(self, inputs, 
initial_states=None, sequence_length=None): + batch_index = 1 if self.time_major else 0 + batch_size = inputs.shape[batch_index] + dtype = inputs.dtype + if initial_states is None: + state_shape = (self.num_layers * self.num_directions, batch_size, + self.hidden_size) + if self.state_components == 1: + initial_states = np.zeros(state_shape, dtype) + else: + initial_states = tuple([ + np.zeros(state_shape, dtype) + for _ in range(self.state_components) + ]) + + states = split_states(initial_states, self.num_directions == 2, + self.state_components) + final_states = [] + + for i, rnn_layer in enumerate(self): + if i > 0: + inputs = dropout(inputs, self.dropout) + outputs, final_state = rnn_layer(inputs, states[i], sequence_length) + final_states.append(final_state) + inputs = outputs + + final_states = concat_states(final_states, self.num_directions == 2, + self.state_components) + return outputs, final_states + + +class SimpleRNN(RNNMixin): + def __init__(self, + input_size, + hidden_size, + num_layers=1, + nonlinearity="tanh", + direction="forward", + dropout=0., + time_major=False): + super(SimpleRNN, self).__init__() + + if direction in ["forward", "backward"]: + is_reverse = direction == "backward" + cell = SimpleRNNCell(input_size, hidden_size, nonlinearity) + self.append(RNN(cell, is_reverse, time_major)) + for i in range(1, num_layers): + cell = SimpleRNNCell(hidden_size, hidden_size, nonlinearity) + self.append(RNN(cell, is_reverse, time_major)) + elif direction == "bidirectional": + cell_fw = SimpleRNNCell(input_size, hidden_size, nonlinearity) + cell_bw = SimpleRNNCell(input_size, hidden_size, nonlinearity) + self.append(BiRNN(cell_fw, cell_bw, time_major)) + for i in range(1, num_layers): + cell_fw = SimpleRNNCell(2 * hidden_size, hidden_size, + nonlinearity) + cell_bw = SimpleRNNCell(2 * hidden_size, hidden_size, + nonlinearity) + self.append(BiRNN(cell_fw, cell_bw, time_major)) + else: + raise ValueError( + "direction should be forward, backward or bidirectional, " + "received direction = {}".format(direction)) + + self.input_size = input_size + self.hidden_size = hidden_size + self.dropout = dropout + self.num_directions = 2 if direction == "bidirectional" else 1 + self.time_major = time_major + self.num_layers = num_layers + self.state_components = 1 + + +class LSTM(RNNMixin): + def __init__(self, + input_size, + hidden_size, + num_layers=1, + direction="forward", + dropout=0., + time_major=False): + super(LSTM, self).__init__() + + if direction in ["forward", "backward"]: + is_reverse = direction == "backward" + cell = LSTMCell(input_size, hidden_size) + self.append(RNN(cell, is_reverse, time_major)) + for i in range(1, num_layers): + cell = LSTMCell(hidden_size, hidden_size) + self.append(RNN(cell, is_reverse, time_major)) + elif direction == "bidirectional": + cell_fw = LSTMCell(input_size, hidden_size) + cell_bw = LSTMCell(input_size, hidden_size) + self.append(BiRNN(cell_fw, cell_bw, time_major)) + for i in range(1, num_layers): + cell_fw = LSTMCell(2 * hidden_size, hidden_size) + cell_bw = LSTMCell(2 * hidden_size, hidden_size) + self.append(BiRNN(cell_fw, cell_bw, time_major)) + else: + raise ValueError( + "direction should be forward, backward or bidirectional, " + "received direction = {}".format(direction)) + + self.input_size = input_size + self.hidden_size = hidden_size + self.dropout = dropout + self.num_directions = 2 if direction == "bidirectional" else 1 + self.time_major = time_major + self.num_layers = num_layers + self.state_components = 2 + + +class 
GRU(RNNMixin): + def __init__(self, + input_size, + hidden_size, + num_layers=1, + direction="forward", + dropout=0., + time_major=False): + super(GRU, self).__init__() + + if direction in ["forward", "backward"]: + is_reverse = direction == "backward" + cell = GRUCell(input_size, hidden_size) + self.append(RNN(cell, is_reverse, time_major)) + for i in range(1, num_layers): + cell = GRUCell(hidden_size, hidden_size) + self.append(RNN(cell, is_reverse, time_major)) + elif direction == "bidirectional": + cell_fw = GRUCell(input_size, hidden_size) + cell_bw = GRUCell(input_size, hidden_size) + self.append(BiRNN(cell_fw, cell_bw, time_major)) + for i in range(1, num_layers): + cell_fw = GRUCell(2 * hidden_size, hidden_size) + cell_bw = GRUCell(2 * hidden_size, hidden_size) + self.append(BiRNN(cell_fw, cell_bw, time_major)) + else: + raise ValueError( + "direction should be forward, backward or bidirectional, " + "received direction = {}".format(direction)) + + self.input_size = input_size + self.hidden_size = hidden_size + self.dropout = dropout + self.num_directions = 2 if direction == "bidirectional" else 1 + self.time_major = time_major + self.num_layers = num_layers + self.state_components = 1 diff --git a/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells.py b/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells.py new file mode 100644 index 0000000000000..8d2677229a03f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells.py @@ -0,0 +1,166 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
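# These tests run the numpy reference cells from rnn_numpy.py and the new
# paddle.nn cells side by side in dygraph mode: convert_params_for_cell()
# copies the numpy cell's randomly initialized weights into the paddle cell,
# both are fed the same random input (with and without an explicit initial
# state), and the resulting hidden/cell states are compared with
# np.testing.assert_allclose.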
+ +import paddle +paddle.framework.set_default_dtype("float64") + +import numpy as np +import unittest + +from rnn_numpy import SimpleRNNCell, LSTMCell, GRUCell +from convert import convert_params_for_cell + + +class TestSimpleRNNCell(unittest.TestCase): + def __init__(self, bias=True, place="cpu"): + super(TestSimpleRNNCell, self).__init__(methodName="runTest") + self.bias = bias + self.place = paddle.CPUPlace() if place == "cpu" \ + else paddle.CUDAPlace(0) + + def setUp(self): + paddle.disable_static(self.place) + rnn1 = SimpleRNNCell(16, 32, bias=self.bias) + rnn2 = paddle.nn.SimpleRNNCell( + 16, 32, bias_ih_attr=self.bias, bias_hh_attr=self.bias) + convert_params_for_cell(rnn1, rnn2) + + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + def test_with_initial_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(4, 16) + prev_h = np.random.randn(4, 32) + + y1, h1 = rnn1(x, prev_h) + y2, h2 = rnn2(paddle.to_variable(x), paddle.to_variable(prev_h)) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + + def test_with_zero_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(4, 16) + + y1, h1 = rnn1(x) + y2, h2 = rnn2(paddle.to_variable(x)) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + + def runTest(self): + self.test_with_initial_state() + self.test_with_zero_state() + + +class TestGRUCell(unittest.TestCase): + def __init__(self, bias=True, place="cpu"): + super(TestGRUCell, self).__init__(methodName="runTest") + self.bias = bias + self.place = paddle.CPUPlace() if place == "cpu" \ + else paddle.CUDAPlace(0) + + def setUp(self): + paddle.disable_static(self.place) + rnn1 = GRUCell(16, 32, bias=self.bias) + rnn2 = paddle.nn.GRUCell( + 16, 32, bias_ih_attr=self.bias, bias_hh_attr=self.bias) + convert_params_for_cell(rnn1, rnn2) + + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + def test_with_initial_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(4, 16) + prev_h = np.random.randn(4, 32) + + y1, h1 = rnn1(x, prev_h) + y2, h2 = rnn2(paddle.to_variable(x), paddle.to_variable(prev_h)) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + + def test_with_zero_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(4, 16) + + y1, h1 = rnn1(x) + y2, h2 = rnn2(paddle.to_variable(x)) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + + def runTest(self): + self.test_with_initial_state() + self.test_with_zero_state() + + +class TestLSTMCell(unittest.TestCase): + def __init__(self, bias=True, place="cpu"): + super(TestLSTMCell, self).__init__(methodName="runTest") + self.bias = bias + self.place = paddle.CPUPlace() if place == "cpu" \ + else paddle.CUDAPlace(0) + + def setUp(self): + rnn1 = LSTMCell(16, 32, bias=self.bias) + rnn2 = paddle.nn.LSTMCell( + 16, 32, bias_ih_attr=self.bias, bias_hh_attr=self.bias) + convert_params_for_cell(rnn1, rnn2) + + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + def test_with_initial_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(4, 16) + prev_h = np.random.randn(4, 32) + prev_c = np.random.randn(4, 32) + + y1, (h1, c1) = rnn1(x, (prev_h, prev_c)) + y2, (h2, c2) = rnn2( + paddle.to_variable(x), + (paddle.to_variable(prev_h), paddle.to_variable(prev_c))) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5) + + def test_with_zero_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(4, 16) + + y1, (h1, c1) = 
rnn1(x) + y2, (h2, c2) = rnn2(paddle.to_variable(x)) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5) + + def runTest(self): + self.test_with_initial_state() + self.test_with_zero_state() + + +def load_tests(loader, tests, pattern): + suite = unittest.TestSuite() + devices = ["cpu", "gpu"] if paddle.fluid.is_compiled_with_cuda() \ + else ["cpu"] + for bias in [True, False]: + for device in devices: + for test_class in [TestSimpleRNNCell, TestGRUCell, TestLSTMCell]: + suite.addTest(test_class(bias, device)) + return suite diff --git a/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells_static.py b/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells_static.py new file mode 100644 index 0000000000000..948e47d5b9946 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells_static.py @@ -0,0 +1,326 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +paddle.framework.set_default_dtype("float64") + +import numpy as np +import unittest + +from convert import convert_params_for_cell_static +from rnn_numpy import SimpleRNNCell, LSTMCell, GRUCell + + +class TestSimpleRNNCell(unittest.TestCase): + def __init__(self, bias=True, place="cpu"): + super(TestSimpleRNNCell, self).__init__(methodName="runTest") + self.bias = bias + self.place = paddle.CPUPlace() if place == "cpu" \ + else paddle.CUDAPlace(0) + + def setUp(self): + rnn1 = SimpleRNNCell(16, 32, bias=self.bias) + + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + rnn2 = paddle.nn.SimpleRNNCell( + 16, 32, bias_ih_attr=self.bias, bias_hh_attr=self.bias) + + place = self.place + exe = paddle.static.Executor(place) + scope = paddle.fluid.Scope() + with paddle.static.scope_guard(scope): + exe.run(sp) + convert_params_for_cell_static(rnn1, rnn2, place) + + self.mp = mp + self.sp = sp + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + self.executor = exe + self.scope = scope + + def test_with_initial_state(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(4, 16) + prev_h = np.random.randn(4, 32) + + y1, h1 = rnn1(x, prev_h) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, 16], + dtype=paddle.framework.get_default_dtype()) + init_h = paddle.data( + "init_h", [-1, 32], + dtype=paddle.framework.get_default_dtype()) + y, h = rnn2(x_data, init_h) + + feed_dict = {x_data.name: x, init_h.name: prev_h} + with paddle.static.scope_guard(scope): + y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h]) + + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + + def test_with_zero_state(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + 
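        # As in test_with_initial_state, but the cell builds its own zero
        # initial state: only the input placeholder is fed, and the program
        # is pruned to the fetched [y, h] before running.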
x = np.random.randn(4, 16) + + y1, h1 = rnn1(x) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, 16], + dtype=paddle.framework.get_default_dtype()) + y, h = rnn2(x_data) + + feed_dict = {x_data.name: x} + + with paddle.static.scope_guard(scope): + y2, h2 = exe.run(mp, + feed=feed_dict, + fetch_list=[y, h], + use_prune=True) + + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + + def runTest(self): + self.test_with_initial_state() + self.test_with_zero_state() + + +class TestGRUCell(unittest.TestCase): + def __init__(self, bias=True, place="cpu"): + super(TestGRUCell, self).__init__(methodName="runTest") + self.bias = bias + self.place = paddle.CPUPlace() if place == "cpu" \ + else paddle.CUDAPlace(0) + + def setUp(self): + rnn1 = GRUCell(16, 32, bias=self.bias) + + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + rnn2 = paddle.nn.GRUCell( + 16, 32, bias_ih_attr=self.bias, bias_hh_attr=self.bias) + + place = self.place + exe = paddle.static.Executor(place) + scope = paddle.fluid.Scope() + with paddle.static.scope_guard(scope): + exe.run(sp) + convert_params_for_cell_static(rnn1, rnn2, place) + + self.mp = mp + self.sp = sp + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + self.place = place + self.executor = exe + self.scope = scope + + def test_with_initial_state(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(4, 16) + prev_h = np.random.randn(4, 32) + + y1, h1 = rnn1(x, prev_h) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, 16], + dtype=paddle.framework.get_default_dtype()) + init_h = paddle.data( + "init_h", [-1, 32], + dtype=paddle.framework.get_default_dtype()) + y, h = rnn2(x_data, init_h) + + feed_dict = {x_data.name: x, init_h.name: prev_h} + with paddle.static.scope_guard(scope): + y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h]) + + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + + def test_with_zero_state(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(4, 16) + + y1, h1 = rnn1(x) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, 16], + dtype=paddle.framework.get_default_dtype()) + y, h = rnn2(x_data) + + feed_dict = {x_data.name: x} + + with paddle.static.scope_guard(scope): + y2, h2 = exe.run(mp, + feed=feed_dict, + fetch_list=[y, h], + use_prune=True) + + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + + def runTest(self): + self.test_with_initial_state() + self.test_with_zero_state() + + +class TestLSTMCell(unittest.TestCase): + def __init__(self, bias=True, place="cpu"): + super(TestLSTMCell, self).__init__(methodName="runTest") + self.bias = bias + self.place = paddle.CPUPlace() if place == "cpu" \ + else paddle.CUDAPlace(0) + + def setUp(self): + rnn1 = LSTMCell(16, 32, bias=self.bias) + + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + rnn2 = paddle.nn.LSTMCell( + 16, 32, bias_ih_attr=self.bias, bias_hh_attr=self.bias) + + place = self.place + exe = paddle.static.Executor(place) + scope = paddle.fluid.Scope() + with 
paddle.static.scope_guard(scope): + exe.run(sp) + convert_params_for_cell_static(rnn1, rnn2, place) + + self.mp = mp + self.sp = sp + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + self.place = place + self.executor = exe + self.scope = scope + + def test_with_initial_state(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(4, 16) + prev_h = np.random.randn(4, 32) + prev_c = np.random.randn(4, 32) + + y1, (h1, c1) = rnn1(x, (prev_h, prev_c)) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, 16], + dtype=paddle.framework.get_default_dtype()) + init_h = paddle.data( + "init_h", [-1, 32], + dtype=paddle.framework.get_default_dtype()) + init_c = paddle.data( + "init_c", [-1, 32], + dtype=paddle.framework.get_default_dtype()) + y, (h, c) = rnn2(x_data, (init_h, init_c)) + + feed_dict = {x_data.name: x, init_h.name: prev_h, init_c.name: prev_c} + with paddle.static.scope_guard(scope): + y2, h2, c2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h, c]) + + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(c1, c2, atol=1e-8, rtol=1e-5) + + def test_with_zero_state(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(4, 16) + + y1, (h1, c1) = rnn1(x) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, 16], + dtype=paddle.framework.get_default_dtype()) + y, (h, c) = rnn2(x_data) + + feed_dict = {x_data.name: x} + + with paddle.static.scope_guard(scope): + y2, h2, c2 = exe.run(mp, + feed=feed_dict, + fetch_list=[y, h, c], + use_prune=True) + + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(c1, c2, atol=1e-8, rtol=1e-5) + + def runTest(self): + self.test_with_initial_state() + self.test_with_zero_state() + + +def load_tests(loader, tests, pattern): + suite = unittest.TestSuite() + devices = ["cpu", "gpu"] if paddle.fluid.is_compiled_with_cuda() \ + else ["cpu"] + for bias in [True, False]: + for device in devices: + for test_class in [TestSimpleRNNCell, TestGRUCell, TestLSTMCell]: + suite.addTest(test_class(bias, device)) + return suite diff --git a/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets.py b/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets.py new file mode 100644 index 0000000000000..ef297b3bb6249 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets.py @@ -0,0 +1,269 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
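+
+# A minimal, self-contained usage sketch of the dygraph API exercised by the
+# tests below. paddle.nn.LSTM is shown; SimpleRNN and GRU follow the same
+# calling pattern. The sizes here are illustrative choices only, and this
+# helper is not used by the test classes.
+def _example_dygraph_lstm():
+    import numpy as np
+
+    import paddle
+
+    paddle.disable_static()
+    paddle.set_default_dtype("float64")
+    # A 2-layer bidirectional LSTM over batch-first inputs of feature size 16.
+    lstm = paddle.nn.LSTM(16, 32, 2, direction="bidirectional")
+    x = paddle.to_tensor(np.random.randn(4, 23, 16))  # [batch, time, input]
+    y, (h, c) = lstm(x)
+    # y: [4, 23, 2 * 32]; h and c: [2 * 2, 4, 32]
+    # (num_layers * num_directions along the first state axis)
+    return y, (h, c)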
+ +import paddle +paddle.set_default_dtype("float64") +from paddle.fluid.layers import sequence_mask + +import numpy as np +import unittest + +from convert import convert_params_for_net +from rnn_numpy import SimpleRNN, LSTM, GRU + + +class TestSimpleRNN(unittest.TestCase): + def __init__(self, time_major=True, direction="forward", place="cpu"): + super(TestSimpleRNN, self).__init__("runTest") + self.time_major = time_major + self.direction = direction + self.num_directions = 2 if direction == "bidirectional" else 1 + self.place = paddle.CPUPlace() if place == "cpu" \ + else paddle.CUDAPlace(0) + + def setUp(self): + paddle.disable_static(self.place) + rnn1 = SimpleRNN( + 16, 32, 2, time_major=self.time_major, direction=self.direction) + rnn2 = paddle.nn.SimpleRNN( + 16, 32, 2, time_major=self.time_major, direction=self.direction) + convert_params_for_net(rnn1, rnn2) + + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + def test_with_initial_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + prev_h = np.random.randn(2 * self.num_directions, 4, 32) + + y1, h1 = rnn1(x, prev_h) + y2, h2 = rnn2(paddle.to_variable(x), paddle.to_variable(prev_h)) + np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + + def test_with_zero_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + + y1, h1 = rnn1(x) + y2, h2 = rnn2(paddle.to_variable(x)) + np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + + def test_with_input_lengths(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + sequence_length = np.array([12, 10, 9, 8], dtype=np.int64) + + y1, h1 = rnn1(x, sequence_length=sequence_length) + + seq_len = paddle.to_variable(sequence_length) + mask = sequence_mask(seq_len, dtype=paddle.get_default_dtype()) + if self.time_major: + mask = paddle.transpose(mask, [1, 0]) + y2, h2 = rnn2(paddle.to_variable(x), sequence_length=seq_len) + y2 = paddle.multiply(y2, mask, axis=0) + + np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + + def runTest(self): + self.test_with_initial_state() + self.test_with_zero_state() + self.test_with_input_lengths() + + +class TestGRU(unittest.TestCase): + def __init__(self, time_major=True, direction="forward", place="cpu"): + super(TestGRU, self).__init__("runTest") + self.time_major = time_major + self.direction = direction + self.num_directions = 2 if direction == "bidirectional" else 1 + self.place = paddle.CPUPlace() if place == "cpu" \ + else paddle.CUDAPlace(0) + + def setUp(self): + paddle.disable_static(self.place) + rnn1 = GRU(16, + 32, + 2, + time_major=self.time_major, + direction=self.direction) + rnn2 = paddle.nn.GRU(16, + 32, + 2, + time_major=self.time_major, + direction=self.direction) + convert_params_for_net(rnn1, rnn2) + + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + def test_with_initial_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + prev_h = np.random.randn(2 * self.num_directions, 4, 32) + + y1, h1 = rnn1(x, prev_h) + y2, h2 = rnn2(paddle.to_variable(x), 
paddle.to_variable(prev_h)) + np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + + def test_with_zero_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + + y1, h1 = rnn1(x) + y2, h2 = rnn2(paddle.to_variable(x)) + np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + + def test_with_input_lengths(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + sequence_length = np.array([12, 10, 9, 8], dtype=np.int64) + + y1, h1 = rnn1(x, sequence_length=sequence_length) + + seq_len = paddle.to_variable(sequence_length) + mask = sequence_mask(seq_len, dtype=paddle.get_default_dtype()) + if self.time_major: + mask = paddle.transpose(mask, [1, 0]) + y2, h2 = rnn2(paddle.to_variable(x), sequence_length=seq_len) + y2 = paddle.multiply(y2, mask, axis=0) + + np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + + def runTest(self): + self.test_with_initial_state() + self.test_with_zero_state() + self.test_with_input_lengths() + + +class TestLSTM(unittest.TestCase): + def __init__(self, time_major=True, direction="forward", place="cpu"): + super(TestLSTM, self).__init__("runTest") + self.time_major = time_major + self.direction = direction + self.num_directions = 2 if direction == "bidirectional" else 1 + self.place = paddle.CPUPlace() if place == "cpu" \ + else paddle.CUDAPlace(0) + + def setUp(self): + paddle.disable_static(self.place) + rnn1 = LSTM( + 16, 32, 2, time_major=self.time_major, direction=self.direction) + rnn2 = paddle.nn.LSTM( + 16, 32, 2, time_major=self.time_major, direction=self.direction) + convert_params_for_net(rnn1, rnn2) + + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + def test_with_initial_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + prev_h = np.random.randn(2 * self.num_directions, 4, 32) + prev_c = np.random.randn(2 * self.num_directions, 4, 32) + + y1, (h1, c1) = rnn1(x, (prev_h, prev_c)) + y2, (h2, c2) = rnn2( + paddle.to_variable(x), + (paddle.to_variable(prev_h), paddle.to_variable(prev_c))) + np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5) + + def test_with_zero_state(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + + y1, (h1, c1) = rnn1(x) + y2, (h2, c2) = rnn2(paddle.to_variable(x)) + np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5) + + def test_with_input_lengths(self): + rnn1 = self.rnn1 + rnn2 = self.rnn2 + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + sequence_length = np.array([12, 10, 9, 8], dtype=np.int64) + + y1, (h1, c1) = rnn1(x, sequence_length=sequence_length) + + seq_len = paddle.to_variable(sequence_length) + mask = sequence_mask(seq_len, dtype=paddle.get_default_dtype()) + if self.time_major: + mask = 
paddle.transpose(mask, [1, 0]) + y2, (h2, c2) = rnn2(paddle.to_variable(x), sequence_length=seq_len) + y2 = paddle.multiply(y2, mask, axis=0) + + np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5) + + def runTest(self): + self.test_with_initial_state() + self.test_with_zero_state() + self.test_with_input_lengths() + + +def load_tests(loader, tests, pattern): + suite = unittest.TestSuite() + devices = ["cpu", "gpu"] if paddle.fluid.is_compiled_with_cuda() \ + else ["cpu"] + for direction in ["forward", "backward", "bidirectional"]: + for time_major in [True, False]: + for device in devices: + for test_class in [TestSimpleRNN, TestLSTM, TestGRU]: + suite.addTest(test_class(time_major, direction, device)) + return suite diff --git a/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets_static.py b/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets_static.py new file mode 100644 index 0000000000000..90ed6b8b4c907 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets_static.py @@ -0,0 +1,470 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
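+
+# A minimal, self-contained sketch of the static-graph pattern these tests
+# follow: build the network inside a program_guard, run the startup program
+# once, then feed NumPy arrays and fetch the results. The sizes here are
+# illustrative choices only, and this helper is not used by the test classes.
+def _example_static_simple_rnn():
+    import numpy as np
+
+    import paddle
+
+    paddle.enable_static()
+    paddle.set_default_dtype("float64")
+    main = paddle.static.Program()
+    startup = paddle.static.Program()
+    with paddle.static.program_guard(main, startup):
+        x = paddle.data(
+            "x", [-1, -1, 16], dtype=paddle.framework.get_default_dtype())
+        rnn = paddle.nn.SimpleRNN(16, 32, 2)
+        y, h = rnn(x)
+
+    exe = paddle.static.Executor(paddle.CPUPlace())
+    exe.run(startup)
+    x_np = np.random.randn(4, 23, 16)  # [batch, time, input]
+    y_np, h_np = exe.run(main, feed={"x": x_np}, fetch_list=[y, h])
+    return y_np, h_np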
+ +import paddle +paddle.set_default_dtype("float64") +from paddle.fluid.layers import sequence_mask + +import numpy as np +import unittest + +from convert import convert_params_for_net_static +from rnn_numpy import SimpleRNN, LSTM, GRU + + +class TestSimpleRNN(unittest.TestCase): + def __init__(self, time_major=True, direction="forward", place="cpu"): + super(TestSimpleRNN, self).__init__("runTest") + self.time_major = time_major + self.direction = direction + self.num_directions = 2 if direction == "bidirectional" else 1 + self.place = paddle.CPUPlace() if place == "cpu" \ + else paddle.CUDAPlace(0) + + def setUp(self): + rnn1 = SimpleRNN( + 16, 32, 2, time_major=self.time_major, direction=self.direction) + + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + rnn2 = paddle.nn.SimpleRNN( + 16, + 32, + 2, + time_major=self.time_major, + direction=self.direction) + + place = self.place + exe = paddle.static.Executor(place) + scope = paddle.fluid.Scope() + with paddle.static.scope_guard(scope): + exe.run(sp) + convert_params_for_net_static(rnn1, rnn2, place) + + self.mp = mp + self.sp = sp + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + self.place = place + self.executor = exe + self.scope = scope + + def test_with_initial_state(self): + mp = self.mp.clone().clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + prev_h = np.random.randn(2 * self.num_directions, 4, 32) + + y1, h1 = rnn1(x, prev_h) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, -1, 16], + dtype=paddle.framework.get_default_dtype()) + init_h = paddle.data( + "init_h", [2 * self.num_directions, -1, 32], + dtype=paddle.framework.get_default_dtype()) + y, h = rnn2(x_data, init_h) + + feed_dict = {x_data.name: x, init_h.name: prev_h} + with paddle.static.scope_guard(scope): + y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h]) + + np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + + def test_with_zero_state(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + + y1, h1 = rnn1(x) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, -1, 16], + dtype=paddle.framework.get_default_dtype()) + y, h = rnn2(x_data) + + feed_dict = {x_data.name: x} + + with paddle.static.scope_guard(scope): + y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h]) + + np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + + def test_with_input_lengths(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + sequence_length = np.array([12, 10, 9, 8], dtype=np.int64) + + y1, h1 = rnn1(x, sequence_length=sequence_length) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, -1, 16], + dtype=paddle.framework.get_default_dtype()) + seq_len = 
paddle.data("seq_len", [-1], dtype="int64") + mask = sequence_mask(seq_len, dtype=paddle.get_default_dtype()) + if self.time_major: + mask = paddle.transpose(mask, [1, 0]) + y, h = rnn2(x_data, sequence_length=seq_len) + y = paddle.multiply(y, mask, axis=0) + + feed_dict = {x_data.name: x, seq_len.name: sequence_length} + + with paddle.static.scope_guard(scope): + y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h]) + + np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + + def runTest(self): + self.test_with_initial_state() + self.test_with_zero_state() + self.test_with_input_lengths() + + +class TestGRU(unittest.TestCase): + def __init__(self, time_major=True, direction="forward", place="cpu"): + super(TestGRU, self).__init__("runTest") + self.time_major = time_major + self.direction = direction + self.num_directions = 2 if direction == "bidirectional" else 1 + self.place = paddle.CPUPlace() if place == "cpu" \ + else paddle.CUDAPlace(0) + + def setUp(self): + rnn1 = GRU(16, + 32, + 2, + time_major=self.time_major, + direction=self.direction) + + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + rnn2 = paddle.nn.GRU(16, + 32, + 2, + time_major=self.time_major, + direction=self.direction) + + place = self.place + exe = paddle.static.Executor(place) + scope = paddle.fluid.Scope() + with paddle.static.scope_guard(scope): + exe.run(sp) + convert_params_for_net_static(rnn1, rnn2, place) + + self.mp = mp + self.sp = sp + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + self.place = place + self.executor = exe + self.scope = scope + + def test_with_initial_state(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + + prev_h = np.random.randn(2 * self.num_directions, 4, 32) + + y1, h1 = rnn1(x, prev_h) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, -1, 16], + dtype=paddle.framework.get_default_dtype()) + init_h = paddle.data( + "init_h", [2 * self.num_directions, -1, 32], + dtype=paddle.framework.get_default_dtype()) + y, h = rnn2(x_data, init_h) + + feed_dict = {x_data.name: x, init_h.name: prev_h} + with paddle.static.scope_guard(scope): + y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h]) + + np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + + def test_with_zero_state(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + + y1, h1 = rnn1(x) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, -1, 16], + dtype=paddle.framework.get_default_dtype()) + y, h = rnn2(x_data) + + feed_dict = {x_data.name: x} + + with paddle.static.scope_guard(scope): + y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h]) + + np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + + def test_with_input_lengths(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = 
np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + sequence_length = np.array([12, 10, 9, 8], dtype=np.int64) + + y1, h1 = rnn1(x, sequence_length=sequence_length) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, -1, 16], + dtype=paddle.framework.get_default_dtype()) + seq_len = paddle.data("seq_len", [-1], dtype="int64") + mask = sequence_mask(seq_len, dtype=paddle.get_default_dtype()) + if self.time_major: + mask = paddle.transpose(mask, [1, 0]) + y, h = rnn2(x_data, sequence_length=seq_len) + y = paddle.multiply(y, mask, axis=0) + + feed_dict = {x_data.name: x, seq_len.name: sequence_length} + + with paddle.static.scope_guard(scope): + y2, h2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h]) + + np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + + def runTest(self): + self.test_with_initial_state() + self.test_with_zero_state() + + +class TestLSTM(unittest.TestCase): + def __init__(self, time_major=True, direction="forward", place="cpu"): + super(TestLSTM, self).__init__("runTest") + self.time_major = time_major + self.direction = direction + self.num_directions = 2 if direction == "bidirectional" else 1 + self.place = paddle.CPUPlace() if place == "cpu" \ + else paddle.CUDAPlace(0) + + def setUp(self): + rnn1 = LSTM( + 16, 32, 2, time_major=self.time_major, direction=self.direction) + + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + rnn2 = paddle.nn.LSTM( + 16, + 32, + 2, + time_major=self.time_major, + direction=self.direction) + + place = self.place + exe = paddle.static.Executor(place) + scope = paddle.fluid.Scope() + with paddle.static.scope_guard(scope): + exe.run(sp) + convert_params_for_net_static(rnn1, rnn2, place) + + self.mp = mp + self.sp = sp + self.rnn1 = rnn1 + self.rnn2 = rnn2 + + self.place = place + self.executor = exe + self.scope = scope + + def test_with_initial_state(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + prev_h = np.random.randn(2 * self.num_directions, 4, 32) + prev_c = np.random.randn(2 * self.num_directions, 4, 32) + + y1, (h1, c1) = rnn1(x, (prev_h, prev_c)) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, -1, 16], + dtype=paddle.framework.get_default_dtype()) + init_h = paddle.data( + "init_h", [2 * self.num_directions, -1, 32], + dtype=paddle.framework.get_default_dtype()) + init_c = paddle.data( + "init_c", [2 * self.num_directions, -1, 32], + dtype=paddle.framework.get_default_dtype()) + y, (h, c) = rnn2(x_data, (init_h, init_c)) + + feed_dict = {x_data.name: x, init_h.name: prev_h, init_c.name: prev_c} + with paddle.static.scope_guard(scope): + y2, h2, c2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h, c]) + + np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(c1, c2, atol=1e-8, rtol=1e-5) + + def test_with_zero_state(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 
2]) + + y1, (h1, c1) = rnn1(x) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, -1, 16], + dtype=paddle.framework.get_default_dtype()) + y, (h, c) = rnn2(x_data) + + feed_dict = {x_data.name: x} + + with paddle.static.scope_guard(scope): + y2, h2, c2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h, c]) + + np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(c1, c2, atol=1e-8, rtol=1e-5) + + def test_with_input_lengths(self): + mp = self.mp.clone() + sp = self.sp + rnn1 = self.rnn1 + rnn2 = self.rnn2 + exe = self.executor + scope = self.scope + + x = np.random.randn(12, 4, 16) + if not self.time_major: + x = np.transpose(x, [1, 0, 2]) + sequence_length = np.array([12, 10, 9, 8], dtype=np.int64) + + y1, (h1, c1) = rnn1(x, sequence_length=sequence_length) + + with paddle.fluid.unique_name.guard(): + with paddle.static.program_guard(mp, sp): + x_data = paddle.data( + "input", [-1, -1, 16], + dtype=paddle.framework.get_default_dtype()) + seq_len = paddle.data("seq_len", [-1], dtype="int64") + mask = sequence_mask(seq_len, dtype=paddle.get_default_dtype()) + if self.time_major: + mask = paddle.transpose(mask, [1, 0]) + y, (h, c) = rnn2(x_data, sequence_length=seq_len) + y = paddle.multiply(y, mask, axis=0) + + feed_dict = {x_data.name: x, seq_len.name: sequence_length} + + with paddle.static.scope_guard(scope): + y2, h2, c2 = exe.run(mp, feed=feed_dict, fetch_list=[y, h, c]) + + np.testing.assert_allclose(y1, y2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(h1, h2, atol=1e-8, rtol=1e-5) + np.testing.assert_allclose(c1, c2, atol=1e-8, rtol=1e-5) + + def runTest(self): + self.test_with_initial_state() + self.test_with_zero_state() + self.test_with_input_lengths() + + +def load_tests(loader, tests, pattern): + suite = unittest.TestSuite() + devices = ["cpu", "gpu"] if paddle.fluid.is_compiled_with_cuda() \ + else ["cpu"] + for direction in ["forward", "backward", "bidirectional"]: + for time_major in [True, False]: + for device in devices: + for test_class in [TestSimpleRNN, TestLSTM, TestGRU]: + suite.addTest(test_class(time_major, direction, device)) + return suite diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 62d389209baed..645b2115650a1 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -18,6 +18,7 @@ from .layer import norm from .functional import extension from .layer import common +from .layer import rnn from .utils import weight_norm_hook from . 
import initializer @@ -26,6 +27,7 @@ __all__ += norm.__all__ __all__ += extension.__all__ __all__ += common.__all__ +__all__ += rnn.__all__ __all__ += weight_norm_hook.__all__ # TODO: define alias in nn directory @@ -136,6 +138,7 @@ from .layer.norm import BatchNorm1d #DEFINE_ALIAS from .layer.norm import BatchNorm2d #DEFINE_ALIAS from .layer.norm import BatchNorm3d #DEFINE_ALIAS +from .layer.rnn import * # from .layer.rnn import RNNCell #DEFINE_ALIAS # from .layer.rnn import GRUCell #DEFINE_ALIAS # from .layer.rnn import LSTMCell #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 97a4d5432bdc2..75e2da4cf7e92 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -177,6 +177,8 @@ from .pooling import pool3d #DEFINE_ALIAS from .pooling import adaptive_pool2d #DEFINE_ALIAS from .pooling import adaptive_pool3d #DEFINE_ALIAS +from .rnn import rnn #DEFINE_ALIAS +from .rnn import birnn #DEFINE_ALIAS from .pooling import avg_pool2d #DEFINE_ALIAS from .pooling import max_pool2d #DEFINE_ALIAS from .pooling import avg_pool3d #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/rnn.py b/python/paddle/nn/functional/rnn.py index 520cf44360dc3..b7a97bc5aa303 100644 --- a/python/paddle/nn/functional/rnn.py +++ b/python/paddle/nn/functional/rnn.py @@ -12,10 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# TODO: define function of recurrent neural network +from paddle.fluid.layers.rnn import rnn, birnn -__all__ = [ - # 'gru_unit', - # 'lstm', - # 'lstm_unit' -] +__all__ = ['rnn', 'birnn'] diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 2eb9358f7f1a9..3399e4e34c9e3 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -20,6 +20,7 @@ from . import extension from . import activation from . import norm +from . import rnn from . import vision from . import distance from . import transformer @@ -30,6 +31,7 @@ from .extension import * from .activation import * from .norm import * +from .rnn import * from .vision import * from .transformer import * diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py index 4717609503f7f..6f1c5f199ac99 100644 --- a/python/paddle/nn/layer/rnn.py +++ b/python/paddle/nn/layer/rnn.py @@ -12,10 +12,1333 @@ # See the License for the specific language governing permissions and # limitations under the License. -# TODO: define classes of recurrent neural network +import copy +import collections +import itertools +import six +import math +import sys +import warnings +from functools import partial, reduce + +import paddle +from paddle import framework +from paddle.nn import functional as F +from paddle.nn import initializer as I +from paddle.fluid.dygraph import Layer, LayerList +from paddle.fluid.layers import utils +from paddle.fluid.layers.utils import map_structure, flatten, pack_sequence_as +from paddle.fluid.data_feeder import convert_dtype __all__ = [ - # 'RNNCell', - # 'GRUCell', - # 'LSTMCell' + 'RNNCellBase', + 'SimpleRNNCell', + 'LSTMCell', + 'GRUCell', + 'RNN', + 'BiRNN', + 'SimpleRNN', + 'LSTM', + 'GRU', ] + + +def split_states(states, bidirectional=False, state_components=1): + r""" + Split states of RNN network into possibly nested list or tuple of + states of each RNN cells of the RNN network. + + Arguments: + states (Tensor|tuple|list): the concatenated states for RNN network. 
+ When `state_components` is 1, states in a Tensor with shape + `(L*D, N, C)` where `L` is the number of layers of the RNN + network, `D` is the number of directions of the RNN network(1 + for unidirectional RNNs and 2 for bidirectional RNNs), `N` is + the batch size of the input to the RNN network, `C` is the + hidden size of the RNN network. + + When `state_components` is larger than 1, `states` is a tuple of + `state_components` Tensors that meet the requirements described + above. + + For SimpleRNNs and GRUs, `state_components` is 1, and for LSTMs, + `state_components` is 2. + bidirectional (bool): whether the state is of a bidirectional RNN + network. Defaults to False. + state_components (int): the number of the components of the states. see + `states` above. Defaults to 1. + + Returns: + A nested list or tuple of RNN cell states. + If `bidirectional` is True, it can be indexed twice to get an RNN + cell state. The first index indicates the layer, the second index + indicates the direction. + If `bidirectional` is False, it can be indexed once to get an RNN + cell state. The index indicates the layer. + Note that if `state_components` is larger than 1, an RNN cell state + can be indexed one more time to get a tensor of shape(N, C), where + `N` is the batch size of the input to the RNN cell, and `C` is the + hidden size of the RNN cell. + """ + if state_components == 1: + states = paddle.unstack(states) + if not bidirectional: + return states + else: + return list(zip(states[::2], states[1::2])) + else: + assert len(states) == state_components + states = tuple([paddle.unstack(item) for item in states]) + if not bidirectional: + return list(zip(*states)) + else: + states = list(zip(*states)) + return list(zip(states[::2], states[1::2])) + + +def concat_states(states, bidirectional=False, state_components=1): + r""" + Concatenate a possibly nested list or tuple of RNN cell states into a + compact form. + + Arguments: + states (list|tuple): a possibly nested list or tuple of RNN cell + states. + If `bidirectional` is True, it can be indexed twice to get an + RNN cell state. The first index indicates the layer, the second + index indicates the direction. + If `bidirectional` is False, it can be indexed once to get an RNN + cell state. The index indicates the layer. + Note that if `state_components` is larger than 1, an RNN cell + state can be indexed one more time to get a tensor of shape(N, C), + where `N` is the batch size of the input to the RNN cell, and + `C` is the hidden size of the RNN cell. + bidirectional (bool): whether the state is of a bidirectional RNN + network. Defaults to False. + state_components (int): the number of the components of the states. see + `states` above. Defaults to 1. + + Returns: + Concatenated states for RNN network. + When `state_components` is 1, states in a Tensor with shape + `(L\*D, N, C)` where `L` is the number of layers of the RNN + network, `D` is the number of directions of the RNN network(1 for + unidirectional RNNs and 2 for bidirectional RNNs), `N` is the batch + size of the input to the RNN network, `C` is the hidden size of the + RNN network. 
+ + """ + if state_components == 1: + return paddle.stack(flatten(states)) + else: + states = flatten(states) + componnets = [] + for i in range(state_components): + componnets.append(states[i::state_components]) + return [paddle.stack(item) for item in componnets] + + +class RNNCellBase(Layer): + r""" + RNNCellBase is the base class for abstraction representing the calculations + mapping the input and state to the output and new state. It is suitable to + and mostly used in RNN. + """ + + def get_initial_states(self, + batch_ref, + shape=None, + dtype=None, + init_value=0., + batch_dim_idx=0): + r""" + Generate initialized states according to provided shape, data type and + value. + Arguments: + batch_ref (Tensor): A tensor, which shape would be used to + determine the batch size, which is used to generate initial + states. For `batch_ref`'s shape d, `d[batch_dim_idx]` is + treated as batch size. + shape (list|tuple, optional): A (possibly nested structure of) shape[s], + where a shape is a list/tuple of integer). `-1` (for batch size) + will be automatically prepended if a shape does not starts with + it. If None, property `state_shape` will be used. Defaults to + None. + dtype (str|list|tuple, optional): A (possibly nested structure of) + data type[s]. The structure must be same as that of `shape`, + except when all tensors' in states has the same data type, a + single data type can be used. If None and property `cell.state_shape` + is not available, current default floating type of paddle is + used. Defaults to None. + init_value (float, optional): A float value used to initialize states. + Defaults to 0. + batch_dim_idx (int, optional): An integer indicating which + dimension of the of `batch_ref` represents batch. Defaults to 0. + Returns: + init_states (Tensor|tuple|list): tensor of the provided shape and + dtype, or list of tensors that each satisfies the requirements, + packed in the same structure as `shape` and `type` does. 
+ """ + # TODO: use inputs and batch_size + batch_ref = flatten(batch_ref)[0] + + def _is_shape_sequence(seq): + if sys.version_info < (3, ): + integer_types = ( + int, + long, ) + else: + integer_types = (int, ) + """For shape, list/tuple of integer is the finest-grained objection""" + if (isinstance(seq, list) or isinstance(seq, tuple)): + if reduce(lambda flag, x: isinstance(x, integer_types) and flag, + seq, True): + return False + # TODO: Add check for the illegal + if isinstance(seq, dict): + return True + return (isinstance(seq, collections.Sequence) and + not isinstance(seq, six.string_types)) + + class Shape(object): + def __init__(self, shape): + self.shape = shape if shape[0] == -1 else ([-1] + list(shape)) + + # nested structure of shapes + states_shapes = self.state_shape if shape is None else shape + is_sequence_ori = utils.is_sequence + utils.is_sequence = _is_shape_sequence + states_shapes = map_structure(lambda shape: Shape(shape), states_shapes) + utils.is_sequence = is_sequence_ori + + # nested structure of dtypes + try: + states_dtypes = self.state_dtype if dtype is None else dtype + except NotImplementedError: + states_dtypes = framework.get_default_dtype() + if len(flatten(states_dtypes)) == 1: + dtype = flatten(states_dtypes)[0] + states_dtypes = map_structure(lambda shape: dtype, states_shapes) + + init_states = map_structure( + lambda shape, dtype: paddle.fluid.layers.fill_constant_batch_size_like( + input=batch_ref, + shape=shape.shape, + dtype=dtype, + value=init_value, + input_dim_idx=batch_dim_idx), states_shapes, states_dtypes) + return init_states + + @property + def state_shape(self): + r""" + Abstract method (property). + Used to initialize states. + A (possiblely nested structure of) shape[s], where a shape is a + list/tuple of integers (-1 for batch size would be automatically + inserted into a shape if shape is not started with it). + Not necessary to be implemented if states are not initialized by + `get_initial_states` or the `shape` argument is provided when using + `get_initial_states`. + """ + raise NotImplementedError( + "Please add implementaion for `state_shape` in the used cell.") + + @property + def state_dtype(self): + r""" + Abstract method (property). + Used to initialize states. + A (possiblely nested structure of) data types[s]. The structure must be + same as that of `shape`, except when all tensors' in states has the same + data type, a signle data type can be used. + Not necessary to be implemented if states are not initialized + by `get_initial_states` or the `dtype` argument is provided when using + `get_initial_states`. + """ + raise NotImplementedError( + "Please add implementaion for `state_dtype` in the used cell.") + + +class SimpleRNNCell(RNNCellBase): + r""" + Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it + computes the outputs and updates states. + + The formula used is as follows: + + .. math:: + h_{t} & = \mathrm{tanh}(W_{ih}x_{t} + b_{ih} + W_{hh}h{t-1} + b_{hh}) + y_{t} & = h_{t} + + where :math:`\sigma` is the sigmoid fucntion, and \* is the elemetwise + multiplication operator. + + Please refer to `Finding Structure in Time + `_ for more details. + + Arguments: + input_size (int): The input size. + hidden_size (int): The hidden size. + activation (str, optional): The activation in the SimpleRNN cell. + It can be `tanh` or `relu`. Defaults to `tanh`. + weight_ih_attr (ParamAttr, optional): The parameter attribute for + `weight_ih`. Default: None. 
+        weight_hh_attr(ParamAttr, optional): The parameter attribute for
+            `weight_hh`. Default: None.
+        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
+            `bias_ih`. Default: None.
+        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
+            `bias_hh`. Default: None.
+        name (str, optional): Name for the operation (optional, default is
+            None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Parameters:
+        weight_ih (Parameter): shape (hidden_size, input_size), input to hidden
+            weight, corresponding to :math:`W_{ih}` in the formula.
+        weight_hh (Parameter): shape (hidden_size, hidden_size), hidden to
+            hidden weight, corresponding to :math:`W_{hh}` in the formula.
+        bias_ih (Parameter): shape (hidden_size, ), input to hidden bias,
+            corresponding to :math:`b_{ih}` in the formula.
+        bias_hh (Parameter): shape (hidden_size, ), hidden to hidden bias,
+            corresponding to :math:`b_{hh}` in the formula.
+
+    Inputs:
+        inputs (Tensor): shape `[batch_size, input_size]`, the input,
+            corresponding to :math:`x_t` in the formula.
+        states (Tensor, optional): shape `[batch_size, hidden_size]`, the
+            previous hidden state, corresponding to :math:`h_{t-1}` in the
+            formula. When states is None, zero state is used. Defaults to None.
+
+    Returns:
+        (outputs, new_states)
+        outputs (Tensor): shape `[batch_size, hidden_size]`, the output,
+            corresponding to :math:`h_{t}` in the formula.
+        states (Tensor): shape `[batch_size, hidden_size]`, the new hidden
+            state, corresponding to :math:`h_{t}` in the formula.
+
+    Notes:
+        All the weights and biases are initialized with `Uniform(-std, std)` by
+        default, where :math:`std = \frac{1}{\sqrt{hidden\_size}}`. For more
+        information about parameter initialization, please refer to
+        :ref:`api_fluid_ParamAttr`.
+
+    Examples:
+
+        ..
code-block:: python + + import paddle + paddle.disable_static() + + x = paddle.randn((4, 16)) + prev_h = paddle.randn((4, 32)) + + cell = paddle.nn.SimpleRNNCell(16, 32) + y, h = cell(x, prev_h) + + """ + + def __init__(self, + input_size, + hidden_size, + activation="tanh", + weight_ih_attr=None, + weight_hh_attr=None, + bias_ih_attr=None, + bias_hh_attr=None, + name=None): + super(SimpleRNNCell, self).__init__() + std = 1.0 / math.sqrt(hidden_size) + self.weight_ih = self.create_parameter( + (hidden_size, input_size), + weight_ih_attr, + default_initializer=I.Uniform(-std, std)) + self.weight_hh = self.create_parameter( + (hidden_size, hidden_size), + weight_hh_attr, + default_initializer=I.Uniform(-std, std)) + self.bias_ih = self.create_parameter( + (hidden_size, ), + bias_ih_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std)) + self.bias_hh = self.create_parameter( + (hidden_size, ), + bias_hh_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std)) + + self.input_size = input_size + self.hidden_size = hidden_size + if activation not in ["tanh", "relu"]: + raise ValueError( + "activation for SimpleRNNCell should be tanh or relu, " + "but get {}".format(activation)) + self.activation = activation + self._activation_fn = paddle.tanh \ + if activation == "tanh" \ + else F.relu + + def forward(self, inputs, states=None): + if states is None: + states = self.get_initial_states(inputs, self.state_shape) + pre_h = states + i2h = paddle.matmul(inputs, self.weight_ih, transpose_y=True) + if self.bias_ih is not None: + i2h += self.bias_ih + h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True) + if self.bias_hh is not None: + h2h += self.bias_hh + h = self._activation_fn(i2h + h2h) + return h, h + + @property + def state_shape(self): + return (self.hidden_size, ) + + +class LSTMCell(RNNCellBase): + r""" + Long-Short Term Memory(LSTM) RNN cell. Given the inputs and previous states, + it computes the outputs and updates states. + + The formula used is as follows: + + .. math:: + i_{t} & = \sigma(W_{ii}x_{t} + b_{ii} + W_{hi}h_{t-1} + b_{hi}) + f_{t} & = \sigma(W_{if}x_{t} + b_{if} + W_{hf}h_{t-1} + b_{hf}) + o_{t} & = \sigma(W_{io}x_{t} + b_{io} + W_{ho}h_{t-1} + b_{ho}) + \\widetilde{c}_{t} & = \\tanh (W_{ig}x_{t} + b_{ig} + W_{hg}h_{t-1} + b_{hg}) + c_{t} & = f_{t} \* c{t-1} + i{t} \* \\widetile{c}_{t} + h_{t} & = o_{t} \* \\tanh(c_{t}) + y_{t} & = h_{t} + + where :math:`\sigma` is the sigmoid fucntion, and \* is the elemetwise + multiplication operator. + + Please refer to `An Empirical Exploration of Recurrent Network Architectures + `_ for more details. + + Arguments: + input_size (int): The input size. + hidden_size (int): The hidden size. + weight_ih_attr(ParamAttr, optional): The parameter attribute for + `weight_ih`. Default: None. + weight_hh_attr(ParamAttr, optional): The parameter attribute for + `weight_hh`. Default: None. + bias_ih_attr (ParamAttr, optional): The parameter attribute for the + `bias_ih`. Default: None. + bias_ih_attr (ParamAttr, optional): The parameter attribute for the + `bias_hh`. Default: None. + name (str, optional): Name for the operation (optional, default is + None). For more information, please refer to :ref:`api_guide_Name`. + + Parameters: + weight_ih (Parameter): shape (4 * hidden_size, input_size), input to + hidden weight, which corresponds to the concatenation of + :math:`W_{ii}, W_{if}, W_{ig}, W_{io}` in the formula. 
+ weight_hh (Parameter): shape (4 * hidden_size, hidden_size), hidden to + hidden weight, which corresponds to the concatenation of + :math:`W_{hi}, W_{hf}, W_{hg}, W_{ho}` in the formula. + bias_ih (Parameter): shape (4 * hidden_size, ), input to hidden bias, + which corresponds to the concatenation of + :math:`b_{ii}, b_{if}, b_{ig}, b_{io}` in the formula. + bias_hh (Parameter): shape (4 * hidden_size, ), hidden to hidden bias, + which corresponds to the concatenation of + :math:`b_{hi}, b_{hf}, b_{hg}, b_{ho}` in the formula. + + Inputs: + inputs (Tensor): shape `[batch_size, input_size]`, the input, + corresponding to :math:`x_t` in the formula. + states (tuple, optional): a tuple of two tensors, each of shape + `[batch_size, hidden_size]`, the previous hidden state, + corresponding to :math:`h_{t-1}, c_{t-1}` in the formula. + When states is None, zero state is used. Defaults to None. + + Returns: + (outputs, new_states) + outputs (Tensor): shape `[batch_size, hidden_size]`, the output, + corresponding to :math:`h_{t}` in the formula. + states (tuple): a tuple of two tensors, each of shape + `[batch_size, hidden_size]`, the new hidden states, + corresponding to :math:`h_{t}, c{t}` in the formula. + + Notes: + All the weights and bias are initialized with `Uniform(-std, std)` by + default. Where std = :math:`\frac{1}{\sqrt{hidden_size}}`. For more + information about parameter initialization, please refer to + :ref:`api_fluid_ParamAttr`. + + Examples: + + .. code-block:: python + + import paddle + paddle.disable_static() + + x = paddle.randn((4, 16)) + prev_h = paddle.randn((4, 32)) + prev_c = paddle.randn((4, 32)) + + cell = paddle.nn.LSTMCell(16, 32) + y, (h, c) = cell(x, (prev_h, prev_c)) + + """ + + def __init__(self, + input_size, + hidden_size, + weight_ih_attr=None, + weight_hh_attr=None, + bias_ih_attr=None, + bias_hh_attr=None, + name=None): + super(LSTMCell, self).__init__() + std = 1.0 / math.sqrt(hidden_size) + self.weight_ih = self.create_parameter( + (4 * hidden_size, input_size), + weight_ih_attr, + default_initializer=I.Uniform(-std, std)) + self.weight_hh = self.create_parameter( + (4 * hidden_size, hidden_size), + weight_hh_attr, + default_initializer=I.Uniform(-std, std)) + self.bias_ih = self.create_parameter( + (4 * hidden_size, ), + bias_ih_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std)) + self.bias_hh = self.create_parameter( + (4 * hidden_size, ), + bias_hh_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std)) + + self.hidden_size = hidden_size + self.input_size = input_size + self._gate_activation = F.sigmoid + self._activation = paddle.tanh + + def forward(self, inputs, states=None): + if states is None: + states = self.get_initial_states(inputs, self.state_shape) + pre_hidden, pre_cell = states + gates = paddle.matmul(inputs, self.weight_ih, transpose_y=True) + if self.bias_ih is not None: + gates = gates + self.bias_ih + gates += paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True) + if self.bias_hh is not None: + gates = gates + self.bias_hh + + chunked_gates = paddle.split(gates, num_or_sections=4, axis=-1) + + i = self._gate_activation(chunked_gates[0]) + f = self._gate_activation(chunked_gates[1]) + o = self._gate_activation(chunked_gates[3]) + c = f * pre_cell + i * self._activation(chunked_gates[2]) + h = o * self._activation(c) + + return h, (h, c) + + @property + def state_shape(self): + r""" + The `state_shape` of LSTMCell is a tuple with two shapes: + `((hidden_size, ), (hidden_size,))`. 
(-1 for batch size would be
+        automatically inserted into shape). These two shapes correspond
+        to :math:`h_{t-1}` and :math:`c_{t-1}` separately.
+        """
+        return ((self.hidden_size, ), (self.hidden_size, ))
+
+
+class GRUCell(RNNCellBase):
+    r"""
+    Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
+    it computes the outputs and updates states.
+
+    The formula for GRU used is as follows:
+
+    .. math::
+
+        r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})
+        z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})
+        \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))
+        h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}
+        y_{t} & = h_{t}
+
+    where :math:`\sigma` is the sigmoid function, and \* is the elementwise
+    multiplication operator.
+
+    Please refer to `An Empirical Exploration of Recurrent Network Architectures
+    `_ for more details.
+
+    Arguments:
+        input_size (int): The input size.
+        hidden_size (int): The hidden size.
+        weight_ih_attr(ParamAttr, optional): The parameter attribute for
+            `weight_ih`. Default: None.
+        weight_hh_attr(ParamAttr, optional): The parameter attribute for
+            `weight_hh`. Default: None.
+        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
+            `bias_ih`. Default: None.
+        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
+            `bias_hh`. Default: None.
+        name (str, optional): Name for the operation (optional, default is
+            None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Parameters:
+        weight_ih (Parameter): shape (3 * hidden_size, input_size), input to
+            hidden weight, which corresponds to the concatenation of
+            :math:`W_{ir}, W_{iz}, W_{ic}` in the formula.
+        weight_hh (Parameter): shape (3 * hidden_size, hidden_size), hidden to
+            hidden weight, which corresponds to the concatenation of
+            :math:`W_{hr}, W_{hz}, W_{hc}` in the formula.
+        bias_ih (Parameter): shape (3 * hidden_size, ), input to hidden bias,
+            which corresponds to the concatenation of
+            :math:`b_{ir}, b_{iz}, b_{ic}` in the formula.
+        bias_hh (Parameter): shape (3 * hidden_size, ), hidden to hidden bias,
+            which corresponds to the concatenation of
+            :math:`b_{hr}, b_{hz}, b_{hc}` in the formula.
+
+    Inputs:
+        inputs (Tensor): A tensor with shape `[batch_size, input_size]`,
+            corresponding to :math:`x_t` in the formula.
+        states (Tensor): A tensor with shape `[batch_size, hidden_size]`,
+            corresponding to :math:`h_{t-1}` in the formula.
+
+    Returns:
+        (outputs, new_states)
+        outputs (Tensor): shape `[batch_size, hidden_size]`, the output,
+            corresponding to :math:`h_{t}` in the formula.
+        states (Tensor): shape `[batch_size, hidden_size]`, the new hidden
+            state, corresponding to :math:`h_{t}` in the formula.
+
+    Notes:
+        All the weights and biases are initialized with `Uniform(-std, std)` by
+        default, where :math:`std = \frac{1}{\sqrt{hidden\_size}}`. For more
+        information about parameter initialization, please refer to
+        :ref:`api_fluid_ParamAttr`.
+
+    Examples:
+
+        ..
code-block:: python + + import paddle + paddle.disable_static() + + x = paddle.randn((4, 16)) + prev_h = paddle.randn((4, 32)) + + cell = paddle.nn.GRUCell(16, 32) + y, h = cell(x, prev_h) + + """ + + def __init__(self, + input_size, + hidden_size, + weight_ih_attr=None, + weight_hh_attr=None, + bias_ih_attr=None, + bias_hh_attr=None, + name=None): + super(GRUCell, self).__init__() + std = 1.0 / math.sqrt(hidden_size) + self.weight_ih = self.create_parameter( + (3 * hidden_size, input_size), + weight_ih_attr, + default_initializer=I.Uniform(-std, std)) + self.weight_hh = self.create_parameter( + (3 * hidden_size, hidden_size), + weight_hh_attr, + default_initializer=I.Uniform(-std, std)) + self.bias_ih = self.create_parameter( + (3 * hidden_size, ), + bias_ih_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std)) + self.bias_hh = self.create_parameter( + (3 * hidden_size, ), + bias_hh_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std)) + + self.hidden_size = hidden_size + self.input_size = input_size + self._gate_activation = F.sigmoid + self._activation = paddle.tanh + + def forward(self, inputs, states=None): + if states is None: + states = self.get_initial_states(inputs, self.state_shape) + + pre_hidden = states + x_gates = paddle.matmul(inputs, self.weight_ih, transpose_y=True) + if self.bias_ih is not None: + x_gates = x_gates + self.bias_ih + h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True) + if self.bias_hh is not None: + h_gates = h_gates + self.bias_hh + + x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1) + h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1) + + r = self._gate_activation(x_r + h_r) + z = self._gate_activation(x_z + h_z) + c = self._activation(x_c + r * h_c) # apply reset gate after mm + h = (pre_hidden - c) * z + c + + return h, h + + @property + def state_shape(self): + r""" + The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch + size would be automatically inserted into shape). The shape corresponds + to the shape of :math:`h_{t-1}`. + """ + return (self.hidden_size, ) + + +class RNN(Layer): + r""" + Wrapper for RNN, which creates a recurrent neural network with an RNN cell. + It performs :code:`cell.forward()` repeatedly until reaches to the maximum + length of `inputs`. + + Arguments: + cell(RNNCellBase): An instance of `RNNCellBase`. + is_reverse (bool, optional): Indicate whether to calculate in the reverse + order of input sequences. Defaults to False. + time_major (bool): Whether the first dimension of the input means the + time steps. Defaults to False. + + Inputs: + inputs (Tensor): A (possibly nested structure of) tensor[s]. The input + sequences. + If time major is True, the shape is `[batch_size, time_steps, input_size]` + If time major is False, the shape is [time_steps, batch_size, input_size]` + where `input_size` is the input size of the cell. + initial_states (Tensor|list|tuple, optional): Tensor of a possibly + nested structure of tensors, representing the initial state for + the rnn cell. If not provided, `cell.get_initial_states` would be + called to produce the initial states. Defaults to None. + sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64 + or int32. The valid lengths of input sequences. Defaults to None. + If `sequence_length` is not None, the inputs are treated as + padded sequences. In each input sequence, elements whose time step + index are not less than the valid length are treated as paddings. 
+ **kwargs: Additional keyword arguments to pass to `forward` of the cell. + + Returns: + (outputs, final_states) + outputs (Tensor|list|tuple): the output sequences. + If `time_major` is True, the shape is + `[time_steps, batch_size, hidden_size]`, else + `[batch_size, time_steps, hidden_size]`. + final_states (Tensor|list|tuple): final states of the cell. Tensor or + a possibly nested structure of tensors which has the same structure + with intial state. Each tensor in final states has the same shape + and dtype as the corresponding tensor in initial states. + + Notes: + This class is a low level API for wrapping rnn cell into a RNN network. + Users should take care of the state of the cell. If `initial_states` is + passed to the `forward` method, make sure that it satisfies the + requirements of the cell. + + Examples: + + .. code-block:: python + + import paddle + paddle.disable_static() + + inputs = paddle.rand((4, 23, 16)) + prev_h = paddle.randn((4, 32)) + + cell = paddle.nn.SimpleRNNCell(16, 32) + rnn = paddle.nn.RNN(cell) + outputs, final_states = rnn(inputs, prev_h) + + """ + + def __init__(self, cell, is_reverse=False, time_major=False): + super(RNN, self).__init__() + self.cell = cell + if not hasattr(self.cell, "call"): + # for non-dygraph mode, `rnn` api uses cell.call + self.cell.call = self.cell.forward + self.is_reverse = is_reverse + self.time_major = time_major + + def forward(self, + inputs, + initial_states=None, + sequence_length=None, + **kwargs): + if initial_states is None: + initial_states = self.cell.get_initial_states( + batch_ref=inputs, + dtype=inputs.dtype, + batch_dim_idx=self.batch_index) + + final_outputs, final_states = F.rnn(self.cell, + inputs, + initial_states=initial_states, + sequence_length=sequence_length, + time_major=self.time_major, + is_reverse=self.is_reverse, + **kwargs) + return final_outputs, final_states + + +class BiRNN(Layer): + r""" + Wrapper for bidirectional RNN, which builds a bidiretional RNN given the + forward rnn cell and backward rnn cell. A BiRNN applies forward RNN and + backward RNN with coresponding cells separately and concats the outputs + along the last axis. + + Arguments: + cell_fw (RNNCellBase): A RNNCellBase instance used for forward RNN. + cell_bw (RNNCellBase): A RNNCellBase instance used for backward RNN. + time_major (bool): Whether the first dimension of the input means the + time steps. Defaults to False. + + Inputs: + inputs (Tensor): the input sequences of both RNN. + If time_major is True, the shape of is + `[time_steps, batch_size, input_size]`, else the shape is + `[batch_size, time_steps, input_size]`, where input_size is the + input size of both cells. + initial_states (list|tuple, optional): A tuple/list of the initial + states of the forward cell and backward cell. Defaults to None. + If not provided, `cell.get_initial_states` would be called to + produce the initial states for each cell. Defaults to None. + sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64 + or int32. The valid lengths of input sequences. Defaults to None. + If `sequence_length` is not None, the inputs are treated as + padded sequences. In each input sequence, elements whose time step + index are not less than the valid length are treated as paddings. + **kwargs: Additional keyword arguments. Arguments passed to `forward` + for each cell. + + Outputs: + (outputs, final_states) + outputs (Tensor): the outputs of the bidirectional RNN. 
It is the + concatenation of the outputs from the forward RNN and backward + RNN along the last axis. + If time major is True, the shape is `[time_steps, batch_size, size]`, + else the shape is `[batch_size, time_steps, size]`, where size is + `cell_fw.hidden_size + cell_bw.hidden_size`. + final_states (tuple): A tuple of the final states of the forward + cell and backward cell. + + Notes: + This class is a low level API for wrapping rnn cells into a BiRNN + network. Users should take care of the states of the cells. + If `initial_states` is passed to the `forward` method, make sure that + it satisfies the requirements of the cells. + + Examples: + + .. code-block:: python + + import paddle + paddle.disable_static() + + cell_fw = paddle.nn.LSTMCell(16, 32) + cell_bw = paddle.nn.LSTMCell(16, 32) + rnn = paddle.nn.BiRNN(cell_fw, cell_bw) + + inputs = paddle.rand((2, 23, 16)) + outputs, final_states = rnn(inputs) + + """ + + def __init__(self, cell_fw, cell_bw, time_major=False): + super(BiRNN, self).__init__() + self.cell_fw = cell_fw + self.cell_bw = cell_bw + if cell_fw.input_size != cell_bw.input_size: + raise ValueError("input size of forward cell({}) does not equals" + "that of backward cell({})".format( + cell_fw.input_size, cell_bw.input_size)) + for cell in [self.cell_fw, self.cell_bw]: + if not hasattr(cell, "call"): + # for non-dygraph mode, `rnn` api uses cell.call + cell.call = cell.forward + self.time_major = time_major + + def forward(self, + inputs, + initial_states=None, + sequence_length=None, + **kwargs): + if isinstance(initial_states, (list, tuple)): + assert len(initial_states) == 2, \ + "length of initial_states should be 2 when it is a list/tuple" + else: + initial_states = [initial_states, initial_states] + + outputs, final_states = F.birnn(self.cell_fw, self.cell_bw, inputs, + initial_states, sequence_length, + self.time_major, **kwargs) + return outputs, final_states + + +class RNNMixin(LayerList): + r""" + A Mixin class for RNN networks. It provides `forward` method for SimpleRNN, + LSTM and GRU. + """ + + def forward(self, inputs, initial_states=None, sequence_length=None): + batch_index = 1 if self.time_major else 0 + dtype = inputs.dtype + if initial_states is None: + state_shape = (self.num_layers * self.num_directions, -1, + self.hidden_size) + if self.state_components == 1: + initial_states = paddle.fluid.layers.fill_constant_batch_size_like( + inputs, state_shape, dtype, 0, batch_index, 1) + else: + initial_states = tuple([ + paddle.fluid.layers.fill_constant_batch_size_like( + inputs, state_shape, dtype, 0, batch_index, 1) + for _ in range(self.state_components) + ]) + + states = split_states(initial_states, self.num_directions == 2, + self.state_components) + final_states = [] + + for i, rnn_layer in enumerate(self): + if i > 0: + inputs = F.dropout( + inputs, + self.dropout, + training=self.training, + mode="upscale_in_train") + outputs, final_state = rnn_layer(inputs, states[i], sequence_length) + final_states.append(final_state) + inputs = outputs + + final_states = concat_states(final_states, self.num_directions == 2, + self.state_components) + return outputs, final_states + + +class SimpleRNN(RNNMixin): + r""" + Multilayer Elman network(SimpleRNN). It takes input sequences and initial + states as inputs, and returns the output sequences and the final states. 
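A minimal sketch (not part of the diff itself) of the bidirectional, multi-layer case that `RNNMixin.forward` above handles; the shapes are assumed from the `Inputs`/`Returns` notes in the docstrings that follow:

.. code-block:: python

    import paddle
    paddle.disable_static()

    # 2 layers, bidirectional: states are stacked along the first axis as
    # [num_layers * num_directions, batch_size, hidden_size]
    rnn = paddle.nn.SimpleRNN(16, 32, num_layers=2, direction="bidirectional")

    x = paddle.randn((4, 23, 16))      # [batch_size, time_steps, input_size]
    prev_h = paddle.randn((4, 4, 32))  # [2 * 2, batch_size, hidden_size]
    y, h = rnn(x, prev_h)              # y: [4, 23, 2 * 32]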
+
+    Each layer inside the SimpleRNN maps the input sequences and initial states
+    to the output sequences and final states in the following manner: at each
+    step, it takes step inputs(:math:`x_{t}`) and previous
+    states(:math:`h_{t-1}`) as inputs, and returns step outputs(:math:`y_{t}`)
+    and new states(:math:`h_{t}`).
+
+    .. math::
+
+        h_{t} & = \mathrm{tanh}(W_{ih}x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
+        y_{t} & = h_{t}
+
+    where :math:`\mathrm{tanh}` is the hyperbolic tangent activation; it can be
+    replaced by :math:`\mathrm{relu}` via the `activation` argument.
+
+    Arguments:
+        input_size (int): The input size for the first layer's cell.
+        hidden_size (int): The hidden size for each layer's cell.
+        num_layers (int, optional): Number of layers. Defaults to 1.
+        activation (str, optional): The activation in each SimpleRNN cell. It can be
+            `tanh` or `relu`. Defaults to `tanh`.
+        direction (str, optional): The direction of the network. It can be "forward",
+            "backward" or "bidirectional". Defaults to "forward".
+        dropout (float, optional): The dropout probability. Dropout is applied to the
+            input of each layer except for the first layer. Defaults to 0.
+        time_major (bool, optional): Whether the first dimension of the input means the
+            time steps. Defaults to False.
+        weight_ih_attr (ParamAttr, optional): The parameter attribute for
+            `weight_ih` of each cell. Defaults to None.
+        weight_hh_attr (ParamAttr, optional): The parameter attribute for
+            `weight_hh` of each cell. Defaults to None.
+        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
+            `bias_ih` of each cell. Defaults to None.
+        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
+            `bias_hh` of each cell. Defaults to None.
+        name (str, optional): Name for the operation (optional, default is
+            None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Inputs:
+        inputs (Tensor): the input sequence.
+            If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`,
+            else, the shape is `[batch_size, time_steps, input_size]`.
+        initial_states (Tensor, optional): the initial state. The shape is
+            `[num_layers * num_directions, batch_size, hidden_size]`.
+            If `initial_states` is not given, zero initial states are used.
+        sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64
+            or int32. The valid lengths of input sequences. Defaults to None.
+            If `sequence_length` is not None, the inputs are treated as
+            padded sequences. In each input sequence, elements whose time step
+            index is not less than the valid length are treated as paddings.
+
+    Returns:
+        (outputs, final_states)
+        outputs (Tensor): the output sequence.
+            If `time_major` is True, the shape is
+            `[time_steps, batch_size, num_directions * hidden_size]`,
+            else, the shape is
+            `[batch_size, time_steps, num_directions * hidden_size]`.
+            Note that `num_directions` is 2 if direction is "bidirectional"
+            else 1.
+        final_states (Tensor): final states. The shape is
+            `[num_layers * num_directions, batch_size, hidden_size]`.
+            Note that `num_directions` is 2 if direction is "bidirectional"
+            else 1.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            paddle.disable_static()
+
+            rnn = paddle.nn.SimpleRNN(16, 32, 2)
+
+            x = paddle.randn((4, 23, 16))
+            prev_h = paddle.randn((2, 4, 32))
+            y, h = rnn(x, prev_h)
+
+    """
+
+    def __init__(self,
+                 input_size,
+                 hidden_size,
+                 num_layers=1,
+                 activation="tanh",
+                 direction="forward",
+                 dropout=0.,
+                 time_major=False,
+                 weight_ih_attr=None,
+                 weight_hh_attr=None,
+                 bias_ih_attr=None,
+                 bias_hh_attr=None,
+                 name=None):
+        super(SimpleRNN, self).__init__()
+
+        if direction in ["forward", "backward"]:
+            is_reverse = direction == "backward"
+            cell = SimpleRNNCell(input_size, hidden_size, activation,
+                                 weight_ih_attr, weight_hh_attr, bias_ih_attr,
+                                 bias_hh_attr)
+            self.append(RNN(cell, is_reverse, time_major))
+            for i in range(1, num_layers):
+                cell = SimpleRNNCell(hidden_size, hidden_size, activation,
+                                     weight_ih_attr, weight_hh_attr,
+                                     bias_ih_attr, bias_hh_attr)
+                self.append(RNN(cell, is_reverse, time_major))
+        elif direction == "bidirectional":
+            cell_fw = SimpleRNNCell(input_size, hidden_size, activation,
+                                    weight_ih_attr, weight_hh_attr,
+                                    bias_ih_attr, bias_hh_attr)
+            cell_bw = SimpleRNNCell(input_size, hidden_size, activation,
+                                    weight_ih_attr, weight_hh_attr,
+                                    bias_ih_attr, bias_hh_attr)
+            self.append(BiRNN(cell_fw, cell_bw, time_major))
+            for i in range(1, num_layers):
+                cell_fw = SimpleRNNCell(
+                    2 * hidden_size, hidden_size, activation, weight_ih_attr,
+                    weight_hh_attr, bias_ih_attr, bias_hh_attr)
+                cell_bw = SimpleRNNCell(
+                    2 * hidden_size, hidden_size, activation, weight_ih_attr,
+                    weight_hh_attr, bias_ih_attr, bias_hh_attr)
+                self.append(BiRNN(cell_fw, cell_bw, time_major))
+        else:
+            raise ValueError(
+                "direction should be forward, backward or bidirectional, "
+                "received direction = {}".format(direction))
+
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.dropout = dropout
+        self.num_directions = 2 if direction == "bidirectional" else 1
+        self.time_major = time_major
+        self.num_layers = num_layers
+        self.state_components = 1
+
+
+class LSTM(RNNMixin):
+    r"""
+    Multilayer LSTM. It takes a sequence and an initial state as inputs, and
+    returns the output sequences and the final states.
+
+    Each layer inside the LSTM maps the input sequences and initial states
+    to the output sequences and final states in the following manner: at each
+    step, it takes step inputs(:math:`x_{t}`) and previous
+    states(:math:`h_{t-1}, c_{t-1}`) as inputs, and returns step
+    outputs(:math:`y_{t}`) and new states(:math:`h_{t}, c_{t}`).
+
+    .. math::
+
+        i_{t} & = \sigma(W_{ii}x_{t} + b_{ii} + W_{hi}h_{t-1} + b_{hi})
+        f_{t} & = \sigma(W_{if}x_{t} + b_{if} + W_{hf}h_{t-1} + b_{hf})
+        o_{t} & = \sigma(W_{io}x_{t} + b_{io} + W_{ho}h_{t-1} + b_{ho})
+        \widetilde{c}_{t} & = \tanh(W_{ig}x_{t} + b_{ig} + W_{hg}h_{t-1} + b_{hg})
+        c_{t} & = f_{t} * c_{t-1} + i_{t} * \widetilde{c}_{t}
+        h_{t} & = o_{t} * \tanh(c_{t})
+        y_{t} & = h_{t}
+
+    where :math:`\sigma` is the sigmoid function, and \* is the elementwise
+    multiplication operator.
+
+    Arguments:
+        input_size (int): The input size for the first layer's cell.
+        hidden_size (int): The hidden size for each layer's cell.
+        num_layers (int, optional): Number of layers. Defaults to 1.
+        direction (str, optional): The direction of the network. It can be
+            "forward", "backward" or "bidirectional". Defaults to "forward".
+        dropout (float, optional): The dropout probability. Dropout is applied
+            to the input of each layer except for the first layer. Defaults to 0.
+        time_major (bool, optional): Whether the first dimension of the input
+            means the time steps. Defaults to False.
+        weight_ih_attr (ParamAttr, optional): The parameter attribute for
+            `weight_ih` of each cell. Default: None.
+        weight_hh_attr (ParamAttr, optional): The parameter attribute for
+            `weight_hh` of each cell. Default: None.
+        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
+            `bias_ih` of each cell. Default: None.
+        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
+            `bias_hh` of each cell. Default: None.
+        name (str, optional): Name for the operation (optional, default is
+            None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Inputs:
+        inputs (Tensor): the input sequence.
+            If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`,
+            else, the shape is `[batch_size, time_steps, input_size]`.
+        initial_states (tuple, optional): the initial state, a tuple of (h, c),
+            the shape of each is `[num_layers * num_directions, batch_size, hidden_size]`.
+            If `initial_states` is not given, zero initial states are used.
+        sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64
+            or int32. The valid lengths of input sequences. Defaults to None.
+            If `sequence_length` is not None, the inputs are treated as
+            padded sequences. In each input sequence, elements whose time step
+            index is not less than the valid length are treated as paddings.
+
+    Returns:
+        (outputs, final_states)
+        outputs (Tensor): the output sequence.
+            If `time_major` is True, the shape is
+            `[time_steps, batch_size, num_directions * hidden_size]`;
+            if `time_major` is False, the shape is
+            `[batch_size, time_steps, num_directions * hidden_size]`.
+            Note that `num_directions` is 2 if direction is "bidirectional"
+            else 1.
+        final_states (tuple): the final state, a tuple of two tensors, h and c.
+            The shape of each is
+            `[num_layers * num_directions, batch_size, hidden_size]`.
+            Note that `num_directions` is 2 if direction is "bidirectional"
+            else 1.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            paddle.disable_static()
+
+            rnn = paddle.nn.LSTM(16, 32, 2)
+
+            x = paddle.randn((4, 23, 16))
+            prev_h = paddle.randn((2, 4, 32))
+            prev_c = paddle.randn((2, 4, 32))
+            y, (h, c) = rnn(x, (prev_h, prev_c))
+
+    """
+
+    def __init__(self,
+                 input_size,
+                 hidden_size,
+                 num_layers=1,
+                 direction="forward",
+                 dropout=0.,
+                 time_major=False,
+                 weight_ih_attr=None,
+                 weight_hh_attr=None,
+                 bias_ih_attr=None,
+                 bias_hh_attr=None,
+                 name=None):
+        super(LSTM, self).__init__()
+
+        if direction in ["forward", "backward"]:
+            is_reverse = direction == "backward"
+            cell = LSTMCell(input_size, hidden_size, weight_ih_attr,
+                            weight_hh_attr, bias_ih_attr, bias_hh_attr)
+            self.append(RNN(cell, is_reverse, time_major))
+            for i in range(1, num_layers):
+                cell = LSTMCell(hidden_size, hidden_size, weight_ih_attr,
+                                weight_hh_attr, bias_ih_attr, bias_hh_attr)
+                self.append(RNN(cell, is_reverse, time_major))
+        elif direction == "bidirectional":
+            cell_fw = LSTMCell(input_size, hidden_size, weight_ih_attr,
+                               weight_hh_attr, bias_ih_attr, bias_hh_attr)
+            cell_bw = LSTMCell(input_size, hidden_size, weight_ih_attr,
+                               weight_hh_attr, bias_ih_attr, bias_hh_attr)
+            self.append(BiRNN(cell_fw, cell_bw, time_major))
+            for i in range(1, num_layers):
+                cell_fw = LSTMCell(2 * hidden_size, hidden_size, weight_ih_attr,
+                                   weight_hh_attr, bias_ih_attr, bias_hh_attr)
+                cell_bw = LSTMCell(2 * hidden_size, hidden_size, weight_ih_attr,
+                                   weight_hh_attr, bias_ih_attr, bias_hh_attr)
+                self.append(BiRNN(cell_fw, cell_bw, time_major))
+        else:
+            raise ValueError(
+                "direction should be forward, backward or bidirectional, "
+                "received direction = {}".format(direction))
+
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.dropout = dropout
+        self.num_directions = 2 if direction == "bidirectional" else 1
+        self.time_major = time_major
+        self.num_layers = num_layers
+        self.state_components = 2
+
+
+class GRU(RNNMixin):
+    r"""
+    Multilayer GRU. It takes input sequences and initial states as inputs, and
+    returns the output sequences and the final states.
+
+    Each layer inside the GRU maps the input sequences and initial states
+    to the output sequences and final states in the following manner: at each
+    step, it takes step inputs(:math:`x_{t}`) and previous
+    states(:math:`h_{t-1}`) as inputs, and returns step outputs(:math:`y_{t}`)
+    and new states(:math:`h_{t}`).
+
+    .. math::
+
+        r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})
+        z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})
+        \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))
+        h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}
+        y_{t} & = h_{t}
+
+    where :math:`\sigma` is the sigmoid function, and \* is the elementwise
+    multiplication operator.
+
+    Arguments:
+        input_size (int): The input size for the first layer's cell.
+        hidden_size (int): The hidden size for each layer's cell.
+        num_layers (int, optional): Number of layers. Defaults to 1.
+        direction (str, optional): The direction of the network. It can be
+            "forward", "backward" or "bidirectional". Defaults to "forward".
+        dropout (float, optional): The dropout probability. Dropout is applied
+            to the input of each layer except for the first layer. Defaults to 0.
+        time_major (bool, optional): Whether the first dimension of the input
+            means the time steps. Defaults to False.
+        weight_ih_attr (ParamAttr, optional): The parameter attribute for
+            `weight_ih` of each cell. Default: None.
+ weight_hh_attr (ParamAttr, optional): The parameter attribute for + `weight_hh` of each cell. Default: None. + bias_ih_attr (ParamAttr, optional): The parameter attribute for the + `bias_ih` of each cells. Default: None. + bias_ih_attr (ParamAttr, optional): The parameter attribute for the + `bias_hh` of each cells. Default: None. + name (str, optional): Name for the operation (optional, default is + None). For more information, please refer to :ref:`api_guide_Name`. + + Inputs: + inputs (Tensor): the input sequence. + If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`, + else, the shape is `[batch_size, time_steps, hidden_size]`. + initial_states (Tensor, optional): the initial state. The shape is + `[num_lauers * num_directions, batch_size, hidden_size]`. + If initial_state is not given, zero initial states are used. + Defaults to None. + sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64 + or int32. The valid lengths of input sequences. Defaults to None. + If `sequence_length` is not None, the inputs are treated as + padded sequences. In each input sequence, elements whos time step + index are not less than the valid length are treated as paddings. + + Returns: + (outputs, final_states) + outputs (Tensor): the output sequence. + If `time_major` is True, the shape is + `[time_steps, batch_size, num_directions * hidden_size]`, + else, the shape is + `[batch_size, time_steps, num_directions * hidden_size]`. + Note that `num_directions` is 2 if direction is "bidirectional" + else 1. + final_states (Tensor): final states. The shape is + `[num_lauers * num_directions, batch_size, hidden_size]`. + Note that `num_directions` is 2 if direction is "bidirectional" + else 1. + + Examples: + + .. code-block:: python + + import paddle + paddle.disable_static() + + rnn = paddle.nn.GRU(16, 32, 2) + + x = paddle.randn((4, 23, 16)) + prev_h = paddle.randn((2, 4, 32)) + y, h = rnn(x, prev_h) + + """ + + def __init__(self, + input_size, + hidden_size, + num_layers=1, + direction="forward", + dropout=0., + time_major=False, + weight_ih_attr=None, + weight_hh_attr=None, + bias_ih_attr=None, + bias_hh_attr=None, + name=None): + super(GRU, self).__init__() + + if direction in ["forward", "backward"]: + is_reverse = direction == "backward" + cell = GRUCell(input_size, hidden_size, weight_ih_attr, + weight_hh_attr, bias_ih_attr, bias_hh_attr) + self.append(RNN(cell, is_reverse, time_major)) + for i in range(1, num_layers): + cell = GRUCell(hidden_size, hidden_size, weight_ih_attr, + weight_hh_attr, bias_ih_attr, bias_hh_attr) + self.append(RNN(cell, is_reverse, time_major)) + elif direction == "bidirectional": + cell_fw = GRUCell(input_size, hidden_size, weight_ih_attr, + weight_hh_attr, bias_ih_attr, bias_hh_attr) + cell_bw = GRUCell(input_size, hidden_size, weight_ih_attr, + weight_hh_attr, bias_ih_attr, bias_hh_attr) + self.append(BiRNN(cell_fw, cell_bw, time_major)) + for i in range(1, num_layers): + cell_fw = GRUCell(2 * hidden_size, hidden_size, weight_ih_attr, + weight_hh_attr, bias_ih_attr, bias_hh_attr) + cell_bw = GRUCell(2 * hidden_size, hidden_size, weight_ih_attr, + weight_hh_attr, bias_ih_attr, bias_hh_attr) + self.append(BiRNN(cell_fw, cell_bw, time_major)) + else: + raise ValueError( + "direction should be forward, backward or bidirectional, " + "received direction = {}".format(direction)) + + self.input_size = input_size + self.hidden_size = hidden_size + self.dropout = dropout + self.num_directions = 2 if direction == "bidirectional" else 1 + 
self.time_major = time_major + self.num_layers = num_layers + self.state_components = 1 diff --git a/tools/wlist.json b/tools/wlist.json index c6114918e5932..ce6f5fb176b5b 100644 --- a/tools/wlist.json +++ b/tools/wlist.json @@ -148,7 +148,20 @@ "Callback.on_eval_batch_end", "Callback.on_test_batch_begin", "Callback.on_test_batch_end", - "Model.prepare" + "Model.prepare", + "SimpleRNNCell", + "SimpleRNNCell.forward", + "LSTMCell", + "LSTMCell.forward", + "GRUCell", + "GRUCell.forward", + "SimpleRNN", + "GRU", + "LSTM", + "RNN", + "BiRNN", + "RNNCellBase", + "RNNCellBase.get_initial_states" ], "wlist_no_op_pass":[ "gelu", From 8071d2307336ebb9eaf9ec8f5e63ecff26a4ba20 Mon Sep 17 00:00:00 2001 From: Zhou Wei <52485244+zhouwei25@users.noreply.github.com> Date: Thu, 27 Aug 2020 21:14:59 +0800 Subject: [PATCH 3/8] fix bug that can't print int8_t (#26712) fix bug that can't print int8_t --- paddle/fluid/framework/tensor_util.cc | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 3b3271fc5b936..c3626c5c9e050 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -913,10 +913,20 @@ std::ostream& print_tensor(std::ostream& os, const framework::Tensor& tensor) { auto element_num = tensor.numel(); os << " - data: ["; - if (element_num > 0) { - os << inspect[0]; - for (int j = 1; j < element_num; ++j) { - os << " " << inspect[j]; + // Note: int8_t && uint8_t is typedf of char, ostream unable to print properly + if (typeid(int8_t) == typeid(T) || typeid(uint8_t) == typeid(T)) { + if (element_num > 0) { + os << signed(inspect[0]); + for (int j = 1; j < element_num; ++j) { + os << " " << signed(inspect[j]); + } + } + } else { + if (element_num > 0) { + os << inspect[0]; + for (int j = 1; j < element_num; ++j) { + os << " " << inspect[j]; + } } } os << "]"; From 1c898b66d6c668048ab77ee33b2457687b8b36be Mon Sep 17 00:00:00 2001 From: Wilber Date: Thu, 27 Aug 2020 22:56:33 +0800 Subject: [PATCH 4/8] add bug fix enum. 
(#26736) --- paddle/fluid/framework/op_version_registry.h | 14 +++++++++++++- paddle/fluid/framework/op_version_registry_test.cc | 4 ++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/op_version_registry.h b/paddle/fluid/framework/op_version_registry.h index 2a85c60305bd3..79b15fc87d0b0 100644 --- a/paddle/fluid/framework/op_version_registry.h +++ b/paddle/fluid/framework/op_version_registry.h @@ -34,7 +34,8 @@ struct OpUpdateRecord { kModifyAttr, kNewAttr, kNewInput, - kNewOutput + kNewOutput, + kBugfixWithBehaviorChanged, }; Type type_; std::string remark_; @@ -82,6 +83,11 @@ struct NewOutput : OpUpdateRecord { std::string name_; }; +struct BugfixWithBehaviorChanged : OpUpdateRecord { + explicit BugfixWithBehaviorChanged(const std::string& remark) + : OpUpdateRecord({Type::kBugfixWithBehaviorChanged, remark}) {} +}; + class OpVersionDesc { public: OpVersionDesc& ModifyAttr(const std::string& name, const std::string& remark, @@ -110,6 +116,12 @@ class OpVersionDesc { return *this; } + OpVersionDesc& BugfixWithBehaviorChanged(const std::string& remark) { + infos_.push_back(std::shared_ptr( + new compatible::BugfixWithBehaviorChanged(remark))); + return *this; + } + private: std::vector> infos_; }; diff --git a/paddle/fluid/framework/op_version_registry_test.cc b/paddle/fluid/framework/op_version_registry_test.cc index 052bf3a4b882b..80ad51ad07b5a 100644 --- a/paddle/fluid/framework/op_version_registry_test.cc +++ b/paddle/fluid/framework/op_version_registry_test.cc @@ -23,6 +23,10 @@ namespace compatible { TEST(test_operator_version, test_operator_version) { REGISTER_OP_VERSION(test__) + .AddCheckpoint( + R"ROC(Fix the bug of reshape op, support the case of axis < 0)ROC", + framework::compatible::OpVersionDesc().BugfixWithBehaviorChanged( + "Support the case of axis < 0")) .AddCheckpoint( R"ROC( Upgrade reshape, modified one attribute [axis] and add a new attribute [size]. From f9066e6a6fcfaea2c7bdf19762eb630c7a0a7985 Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Fri, 28 Aug 2020 07:35:11 +0800 Subject: [PATCH 5/8] Update the demo code and the doc of varbase.backward. (#26506) * update the demo code and the doc of varbase.backward. * update the doc of the fake interface `paddle.fluid.Variable`. * remove BackwardStrategy. 
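The patch below removes `BackwardStrategy` in favor of the global `FLAGS_sort_sum_gradient` flag; a minimal sketch of the replacement pattern, following the updated tests further below (the specific layers used here are only illustrative):

.. code-block:: python

    import numpy as np
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        # sorted gradient accumulation is now toggled by a global flag
        # instead of fluid.dygraph.BackwardStrategy()
        fluid.set_flags({'FLAGS_sort_sum_gradient': True})

        x = fluid.dygraph.to_variable(np.ones([2, 2], np.float32))
        x.stop_gradient = False
        loss = fluid.layers.reduce_sum(fluid.layers.scale(x))
        loss.backward()  # no BackwardStrategy argument any more
        print(x.gradient())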
--- paddle/fluid/imperative/backward_strategy.h | 33 ---------- paddle/fluid/imperative/basic_engine.cc | 9 +-- paddle/fluid/imperative/basic_engine.h | 6 +- .../fluid/imperative/partial_grad_engine.cc | 22 +++---- paddle/fluid/imperative/partial_grad_engine.h | 4 +- paddle/fluid/imperative/tests/test_tracer.cc | 6 +- paddle/fluid/platform/flags.cc | 13 ++++ .../pybind/global_value_getter_setter.cc | 3 +- paddle/fluid/pybind/imperative.cc | 64 ++----------------- python/paddle/__init__.py | 1 - python/paddle/fluid/__init__.py | 1 + python/paddle/fluid/dygraph/__init__.py | 4 -- .../paddle/fluid/dygraph/backward_strategy.py | 19 ------ python/paddle/fluid/dygraph/base.py | 18 ++---- .../fluid/dygraph/varbase_patch_methods.py | 49 ++++++-------- python/paddle/fluid/framework.py | 43 ++++++------- .../unittests/test_directory_migration.py | 4 +- .../unittests/test_imperative_auto_prune.py | 8 +-- .../tests/unittests/test_imperative_basic.py | 20 +++--- .../tests/unittests/test_imperative_deepcf.py | 5 +- .../unittests/test_imperative_double_grad.py | 15 ++--- .../tests/unittests/test_imperative_gan.py | 7 +- .../test_imperative_hook_for_layer.py | 6 +- ..._imperative_lod_tensor_to_selected_rows.py | 7 +- .../test_imperative_mnist_sorted_gradient.py | 5 +- .../test_imperative_ocr_attention_model.py | 5 +- ...test_imperative_ptb_rnn_sorted_gradient.py | 5 +- .../test_imperative_resnet_sorted_gradient.py | 5 +- .../test_imperative_selected_rows.py | 14 ++-- ..._imperative_selected_rows_to_lod_tensor.py | 7 +- ...perative_star_gan_with_gradient_penalty.py | 7 +- .../test_imperative_static_runner_mnist.py | 5 +- .../test_imperative_static_runner_while.py | 6 +- ..._imperative_transformer_sorted_gradient.py | 5 +- .../test_paddle_imperative_double_grad.py | 5 +- python/paddle/framework/__init__.py | 6 +- 36 files changed, 148 insertions(+), 294 deletions(-) delete mode 100644 paddle/fluid/imperative/backward_strategy.h delete mode 100644 python/paddle/fluid/dygraph/backward_strategy.py diff --git a/paddle/fluid/imperative/backward_strategy.h b/paddle/fluid/imperative/backward_strategy.h deleted file mode 100644 index 0f04d6db8e63d..0000000000000 --- a/paddle/fluid/imperative/backward_strategy.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// -// Created by Jiabin on 2019-04-25. 
-// -#pragma once - -namespace paddle { -namespace imperative { -namespace detail { - -struct BackwardStrategy { - /* DyGraph now support two kinds of backward strategy, one is sorted sum - * gradient, another is sum gradient once they are created */ - // TODO(jiabin): add more Strategy when we support - bool sorted_sum_gradient_{false}; -}; - -} // namespace detail -} // namespace imperative -} // namespace paddle diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc index de1246883f101..a91f14e56b719 100644 --- a/paddle/fluid/imperative/basic_engine.cc +++ b/paddle/fluid/imperative/basic_engine.cc @@ -30,12 +30,13 @@ #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/profiler.h" +DECLARE_bool(sort_sum_gradient); + namespace paddle { namespace imperative { -void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy, - bool retain_graph) { - backward_strategy_ = strategy; +void BasicEngine::Init(VarBase* var, bool retain_graph) { + sorted_sum_gradient_ = FLAGS_sort_sum_gradient; retain_graph_ = retain_graph; init_node_ = var->GradVarBase()->GradNode(); var->GradVarBase()->ClearGradNode(); @@ -105,7 +106,7 @@ void BasicEngine::PrepareGradAccumulators(const OpBase& op) { auto& accumulator = accumulators_[var.get()]; if (!accumulator) { - if (backward_strategy_.sorted_sum_gradient_) { + if (sorted_sum_gradient_) { accumulator.reset(new SortedGradientAccumulator(var.get())); } else { accumulator.reset(new EagerGradientAccumulator(var.get())); diff --git a/paddle/fluid/imperative/basic_engine.h b/paddle/fluid/imperative/basic_engine.h index 4d25d81235098..d1aa69f16868d 100644 --- a/paddle/fluid/imperative/basic_engine.h +++ b/paddle/fluid/imperative/basic_engine.h @@ -18,7 +18,6 @@ #include #include #include -#include "paddle/fluid/imperative/backward_strategy.h" #include "paddle/fluid/imperative/engine.h" #include "paddle/fluid/imperative/gradient_accumulator.h" @@ -30,8 +29,7 @@ class OpBase; class BasicEngine : public Engine { public: - void Init(VarBase* var, const detail::BackwardStrategy& strategy, - bool retain_graph = false); + void Init(VarBase* var, bool retain_graph = false); void Execute() override; @@ -46,7 +44,7 @@ class BasicEngine : public Engine { private: std::shared_ptr init_node_; - detail::BackwardStrategy backward_strategy_; + bool sorted_sum_gradient_; std::unordered_map node_deps_; std::unordered_map> accumulators_; diff --git a/paddle/fluid/imperative/partial_grad_engine.cc b/paddle/fluid/imperative/partial_grad_engine.cc index 4f133bf80c790..3afe5af7f6348 100644 --- a/paddle/fluid/imperative/partial_grad_engine.cc +++ b/paddle/fluid/imperative/partial_grad_engine.cc @@ -33,6 +33,8 @@ #include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/string/string_helper.h" +DECLARE_bool(sort_sum_gradient); + namespace paddle { namespace imperative { @@ -529,8 +531,7 @@ class PartialGradTask { const std::vector> &output_targets, const std::vector> &output_grads, const std::vector> &no_grad_vars, - const platform::Place &place, - const detail::BackwardStrategy &strategy, bool create_graph, + const platform::Place &place, bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs); std::vector> Run(); @@ -577,7 +578,7 @@ class PartialGradTask { bool retain_graph_; bool allow_unused_; bool only_inputs_; - detail::BackwardStrategy strategy_; + bool sorted_sum_gradient_{FLAGS_sort_sum_gradient}; }; PartialGradTask::PartialGradTask( @@ -585,15 +586,14 @@ 
PartialGradTask::PartialGradTask( const std::vector> &output_targets, const std::vector> &output_grads, const std::vector> &no_grad_vars, - const platform::Place &place, const detail::BackwardStrategy &strategy, - bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs) { + const platform::Place &place, bool create_graph, bool retain_graph, + bool allow_unused, bool only_inputs) { input_targets_ = input_targets; place_ = place; create_graph_ = create_graph; retain_graph_ = retain_graph; allow_unused_ = allow_unused; only_inputs_ = only_inputs; - strategy_ = strategy; PADDLE_ENFORCE_EQ(only_inputs_, true, platform::errors::Unimplemented( @@ -981,7 +981,7 @@ void PartialGradTask::PrepareInitialGradientAccumulators(const OpBase *op) { if (!accumulator) { accumulator.reset(new GradientAccumulationInfo( - var, strategy_.sorted_sum_gradient_, create_graph_)); + var, sorted_sum_gradient_, create_graph_)); } accumulator->IncreaseTotalRefCnt(); @@ -1033,11 +1033,11 @@ PartialGradEngine::PartialGradEngine( const std::vector> &output_targets, const std::vector> &output_grads, const std::vector> &no_grad_vars, - const platform::Place &place, const detail::BackwardStrategy &strategy, - bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs) + const platform::Place &place, bool create_graph, bool retain_graph, + bool allow_unused, bool only_inputs) : task_(new PartialGradTask(input_targets, output_targets, output_grads, - no_grad_vars, place, strategy, create_graph, - retain_graph, allow_unused, only_inputs)) {} + no_grad_vars, place, create_graph, retain_graph, + allow_unused, only_inputs)) {} PartialGradEngine::~PartialGradEngine() { Clear(); } diff --git a/paddle/fluid/imperative/partial_grad_engine.h b/paddle/fluid/imperative/partial_grad_engine.h index a7f28c49ec395..b5da39f8d4237 100644 --- a/paddle/fluid/imperative/partial_grad_engine.h +++ b/paddle/fluid/imperative/partial_grad_engine.h @@ -16,7 +16,6 @@ #include #include -#include "paddle/fluid/imperative/backward_strategy.h" #include "paddle/fluid/imperative/engine.h" #include "paddle/fluid/platform/place.h" @@ -33,8 +32,7 @@ class PartialGradEngine : public Engine { const std::vector> &output_targets, const std::vector> &output_grads, const std::vector> &no_grad_vars, - const platform::Place &place, - const detail::BackwardStrategy &strategy, bool create_graph, + const platform::Place &place, bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs); ~PartialGradEngine(); diff --git a/paddle/fluid/imperative/tests/test_tracer.cc b/paddle/fluid/imperative/tests/test_tracer.cc index 3c3ec2e626339..892acffb712d9 100644 --- a/paddle/fluid/imperative/tests/test_tracer.cc +++ b/paddle/fluid/imperative/tests/test_tracer.cc @@ -240,9 +240,8 @@ TEST(test_tracer, test_trace_op_with_multi_device_inputs) { framework::AttributeMap reduce_attr_map; tracer.TraceOp("reduce_sum", reduce_in, reduce_out, reduce_attr_map, gpu_place, true); - detail::BackwardStrategy back_st; imperative::BasicEngine engine; - engine.Init(reduce_sum_out.get(), back_st); + engine.Init(reduce_sum_out.get()); engine.Execute(); framework::LoDTensor rlt; @@ -356,9 +355,8 @@ TEST(test_tracer, test_var_without_grad_var) { ASSERT_EQ(y_in->GradVarBase()->GradOpNum(), 0UL); ASSERT_EQ(vout->GradVarBase()->GradOpNum(), 1UL); - detail::BackwardStrategy back_st; imperative::BasicEngine engine; - engine.Init(vout.get(), back_st); + engine.Init(vout.get()); engine.Execute(); // check the grad diff --git a/paddle/fluid/platform/flags.cc 
b/paddle/fluid/platform/flags.cc index 8667375c6f272..af8798a4b7cf5 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -508,3 +508,16 @@ DEFINE_int32( "summary will be shown." "If FLAGS_call_stack_level == 2, the python stack, c++ stack, and " "error message summary will be shown."); + +/** + * Debug related FLAG + * Name: sort_sum_gradient + * Since Version: 2.0.0 + * Value Range: bool, default=false + * Example: + * Note: If True, gradients are summed by the reverse order of + * the forward execution sequence. + */ +DEFINE_bool(sort_sum_gradient, false, + "Sum gradients by the reverse order of " + "the forward execution sequence."); diff --git a/paddle/fluid/pybind/global_value_getter_setter.cc b/paddle/fluid/pybind/global_value_getter_setter.cc index deca9625e63d0..f1084018d9c79 100644 --- a/paddle/fluid/pybind/global_value_getter_setter.cc +++ b/paddle/fluid/pybind/global_value_getter_setter.cc @@ -38,6 +38,7 @@ DECLARE_bool(enable_rpc_profiler); DECLARE_int32(multiple_of_cupti_buffer_size); DECLARE_bool(reader_queue_speed_test_mode); DECLARE_int32(call_stack_level); +DECLARE_bool(sort_sum_gradient); // device management DECLARE_int32(paddle_num_threads); // executor @@ -340,7 +341,7 @@ static void RegisterGlobalVarGetterSetter() { REGISTER_PUBLIC_GLOBAL_VAR( FLAGS_eager_delete_tensor_gb, FLAGS_enable_parallel_graph, FLAGS_allocator_strategy, FLAGS_use_system_allocator, FLAGS_check_nan_inf, - FLAGS_call_stack_level, FLAGS_cpu_deterministic, + FLAGS_call_stack_level, FLAGS_sort_sum_gradient, FLAGS_cpu_deterministic, FLAGS_enable_rpc_profiler, FLAGS_multiple_of_cupti_buffer_size, FLAGS_reader_queue_speed_test_mode, FLAGS_pe_profile_fname, FLAGS_print_sub_graph_dir, FLAGS_fraction_of_cpu_memory_to_use, diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 021d10ca7facb..489dd19887620 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -30,7 +30,6 @@ limitations under the License. */ #include "paddle/fluid/imperative/all_reduce.h" #include "paddle/fluid/imperative/amp_auto_cast.h" -#include "paddle/fluid/imperative/backward_strategy.h" #include "paddle/fluid/imperative/basic_engine.h" #include "paddle/fluid/imperative/data_loader.h" #include "paddle/fluid/imperative/layer.h" @@ -507,50 +506,6 @@ void BindImperative(py::module *m_ptr) { []() { memory::allocation::MemoryMapFdSet::Instance().Clear(); }); #endif - py::class_ backward_strategy( - m, "BackwardStrategy", R"DOC( - - BackwardStrategy is a descriptor of how to run the backward process. - - **Note**: - **This API is only available in** `Dygraph <../../user_guides/howto/dygraph/DyGraph.html>`_ **Mode** - - Attribute: - **sort_sum_gradient**: - - If framework will sum the gradient by the reverse order of trace. eg. x_var ( :ref:`api_guide_Variable` ) will be the input of multiple OP such as :ref:`api_fluid_layers_scale` , this attr will decide if framework will sum gradient of `x_var` by the reverse order. - - By Default: False - - Examples: - .. 
code-block:: python - - import numpy as np - import paddle.fluid as fluid - - x = np.ones([2, 2], np.float32) - with fluid.dygraph.guard(): - x_var = fluid.dygraph.to_variable(x) - sums_inputs = [] - # x_var will be multi-scales' input here - for _ in range(10): - sums_inputs.append(fluid.layers.scale(x_var)) - ret2 = fluid.layers.sums(sums_inputs) - loss2 = fluid.layers.reduce_sum(ret2) - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True - loss2.backward(backward_strategy) - )DOC"); - backward_strategy.def(py::init()) - .def_property("sort_sum_gradient", - [](const imperative::detail::BackwardStrategy &self) { - return self.sorted_sum_gradient_; - }, - [](imperative::detail::BackwardStrategy &self, - bool sorted_sum_gradient) { - self.sorted_sum_gradient_ = sorted_sum_gradient; - }); - m.def("start_imperative_gperf_profiler", []() { imperative::StartProfile(); }); @@ -745,21 +700,18 @@ void BindImperative(py::module *m_ptr) { inputs2.append(tmp) ret2 = fluid.layers.sums(inputs2) loss2 = fluid.layers.reduce_sum(ret2) - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True - loss2.backward(backward_strategy) + loss2.backward() print(loss2.gradient()) loss2.clear_gradient() print("After clear {}".format(loss2.gradient())) )DOC") .def("_run_backward", - [](imperative::VarBase &self, - const imperative::detail::BackwardStrategy &bckst, - const imperative::Tracer &tracer, bool retain_graph) { + [](imperative::VarBase &self, const imperative::Tracer &tracer, + bool retain_graph) { // TODO(jiabin): when we impl more backward execution we can // select them auto *engine = tracer.GetEngine(); - engine->Init(&self, bckst, retain_graph); + engine->Init(&self, retain_graph); VLOG(3) << "Start backward"; engine->Execute(); VLOG(3) << "Finish backward"; @@ -1024,13 +976,11 @@ void BindImperative(py::module *m_ptr) { &output_targets, const std::vector> &output_grads, const std::vector> &no_grad_vars, - const platform::Place &place, - const imperative::detail::BackwardStrategy &strategy, - bool create_graph, bool retain_graph, bool allow_unused, - bool only_inputs) { + const platform::Place &place, bool create_graph, bool retain_graph, + bool allow_unused, bool only_inputs) { imperative::PartialGradEngine engine( input_targets, output_targets, output_grads, no_grad_vars, place, - strategy, create_graph, retain_graph, allow_unused, only_inputs); + create_graph, retain_graph, allow_unused, only_inputs); engine.Execute(); return engine.GetResult(); }, diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 4e1e04043ad7d..c22eee3df6f29 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -225,7 +225,6 @@ from .framework import CUDAPlace #DEFINE_ALIAS from .framework import CUDAPinnedPlace #DEFINE_ALIAS -from .framework import BackwardStrategy #DEFINE_ALIAS from .framework import to_variable #DEFINE_ALIAS from .framework import grad #DEFINE_ALIAS from .framework import no_grad #DEFINE_ALIAS diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 2ed8642c86d95..9f748b7956f9f 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -196,6 +196,7 @@ def __bootstrap__(): 'free_idle_chunk', 'free_when_no_cache_hit', 'call_stack_level', + 'sort_sum_gradient', ] if 'Darwin' not in sysstr: read_env_flags.append('use_pinned_memory') diff --git a/python/paddle/fluid/dygraph/__init__.py b/python/paddle/fluid/dygraph/__init__.py index 
fc14e9b390e6a..cf270ced3b704 100644 --- a/python/paddle/fluid/dygraph/__init__.py +++ b/python/paddle/fluid/dygraph/__init__.py @@ -38,9 +38,6 @@ from . import learning_rate_scheduler from .learning_rate_scheduler import * -from . import backward_strategy -from .backward_strategy import * - from . import jit from .jit import * @@ -69,7 +66,6 @@ __all__ += parallel.__all__ __all__ += checkpoint.__all__ __all__ += learning_rate_scheduler.__all__ -__all__ += backward_strategy.__all__ __all__ += jit.__all__ __all__ += io.__all__ __all__ += rnn.__all__ diff --git a/python/paddle/fluid/dygraph/backward_strategy.py b/python/paddle/fluid/dygraph/backward_strategy.py deleted file mode 100644 index bfcf66af31ce1..0000000000000 --- a/python/paddle/fluid/dygraph/backward_strategy.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from paddle.fluid import core - -__all__ = ["BackwardStrategy"] - -BackwardStrategy = core.BackwardStrategy diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py index d4f1ca333945d..0c4a1964838c6 100644 --- a/python/paddle/fluid/dygraph/base.py +++ b/python/paddle/fluid/dygraph/base.py @@ -319,8 +319,7 @@ def grad(outputs, create_graph=False, only_inputs=True, allow_unused=False, - no_grad_vars=None, - backward_strategy=None): + no_grad_vars=None): ''' .. note:: **This API is ONLY available in Dygraph mode.** @@ -363,9 +362,6 @@ def grad(outputs, their gradients if allow_unused=True. Default False. no_grad_vars (Variable|list(Variable)|tuple(Variable)|set(Variable), optional): the Variables whose gradients are not needed to compute. Default None. - backward_strategy (BackwardStrategy, optional): The backward strategy to - compute gradients. See :ref:`api_fluid_dygraph_BackwardStrategy` for - details. Default None. 
Returns: tuple: a tuple of Variables, whose length is the same as the Variable number @@ -503,12 +499,6 @@ def check_in_out(in_out_list, name): raise AssertionError( "no_grad_vars must be None, Variable or list/tuple/set of Variables") - if backward_strategy is None: - backward_strategy = core.BackwardStrategy() - - assert isinstance(backward_strategy, core.BackwardStrategy), \ - "backward_strategy must be type paddle.fluid.dygraph.BackwardStrategy" - assert isinstance(create_graph, bool), "create_graph must be True or False" if retain_graph is None: @@ -524,9 +514,9 @@ def check_in_out(in_out_list, name): place = core.Place() place.set_place(framework._current_expected_place()) - return core.dygraph_partial_grad( - inputs, outputs, grad_outputs, no_grad_vars, place, backward_strategy, - create_graph, retain_graph, allow_unused, only_inputs) + return core.dygraph_partial_grad(inputs, outputs, grad_outputs, + no_grad_vars, place, create_graph, + retain_graph, allow_unused, only_inputs) @framework.dygraph_only diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index 9dbaab2580d21..7cb17843396a6 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -15,7 +15,6 @@ import inspect from .. import framework from .. import core -from . import BackwardStrategy from ..framework import Variable, Parameter, ParamBase from .base import switch_to_static_graph import numpy as np @@ -129,19 +128,18 @@ def set_value(self, value): framework._current_expected_place()) @framework.dygraph_only - def backward(self, backward_strategy=None, retain_graph=False): + def backward(self, retain_graph=False): """ **Notes**: **This API is ONLY available in Dygraph mode** - Run backward of current Graph which starts from current Variable + Run backward of current Graph which starts from current Tensor. Args: - backward_strategy( :ref:`api_fluid_dygraph_BackwardStrategy` ): The Backward Strategy to run backward retain_graph(bool, optional): If False, the graph used to compute grads will be freed. If you would - like to add more ops to the built graph after calling this method(`backward`), set the parameter - `retain_graph` to True, then the grads will be retained. Thus, seting it to False is much more memory-efficient. - Defaults to False. + like to add more ops to the built graph after calling this method( :code:`backward` ), set the parameter + :code:`retain_graph` to True, then the grads will be retained. Thus, seting it to False is much more memory-efficient. + Defaults to False. Returns: NoneType: None @@ -149,32 +147,25 @@ def backward(self, backward_strategy=None, retain_graph=False): Examples: .. code-block:: python - import paddle.fluid as fluid import numpy as np + import paddle + paddle.disable_static() x = np.ones([2, 2], np.float32) - with fluid.dygraph.guard(): - inputs2 = [] - for _ in range(10): - tmp = fluid.dygraph.base.to_variable(x) - # if we don't set tmp's stop_gradient as False then, all path to loss will has no gradient since - # there is no one need gradient on it. 
- tmp.stop_gradient=False - inputs2.append(tmp) - ret2 = fluid.layers.sums(inputs2) - loss2 = fluid.layers.reduce_sum(ret2) - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True - loss2.backward(backward_strategy) + inputs = [] + for _ in range(10): + tmp = paddle.to_tensor(x) + # if we don't set tmp's stop_gradient as False then, all path to loss will has no gradient since + # there is no one need gradient on it. + tmp.stop_gradient=False + inputs.append(tmp) + ret = paddle.sums(inputs) + loss = paddle.reduce_sum(ret) + loss.backward() """ if framework.in_dygraph_mode(): - if backward_strategy is None: - backward_strategy = BackwardStrategy() - backward_strategy.sort_sum_gradient = False - - self._run_backward(backward_strategy, - framework._dygraph_tracer(), retain_graph) + self._run_backward(framework._dygraph_tracer(), retain_graph) else: raise ValueError( "Variable.backward() is only available in DyGraph mode") @@ -205,9 +196,7 @@ def gradient(self): inputs2.append(tmp) ret2 = fluid.layers.sums(inputs2) loss2 = fluid.layers.reduce_sum(ret2) - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True - loss2.backward(backward_strategy) + loss2.backward() print(loss2.gradient()) """ diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index ef50294b8e762..fc4e91aad4fff 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1106,15 +1106,18 @@ def set_value(self, value): pass @fake_interface_only - def backward(self, backward_strategy=None): + def backward(self, retain_graph=False): """ **Notes**: **This API is ONLY available in Dygraph mode** - Run backward of current Graph which starts from current Variable + Run backward of current Graph which starts from current Tensor. Args: - backward_strategy( :ref:`api_fluid_dygraph_BackwardStrategy` ): The Backward Strategy to run backward + retain_graph(bool, optional): If False, the graph used to compute grads will be freed. If you would + like to add more ops to the built graph after calling this method( :code:`backward` ), set the parameter + :code:`retain_graph` to True, then the grads will be retained. Thus, seting it to False is much more memory-efficient. + Defaults to False. Returns: NoneType: None @@ -1122,23 +1125,21 @@ def backward(self, backward_strategy=None): Examples: .. code-block:: python - import paddle.fluid as fluid import numpy as np + import paddle + paddle.disable_static() x = np.ones([2, 2], np.float32) - with fluid.dygraph.guard(): - inputs2 = [] - for _ in range(10): - tmp = fluid.dygraph.base.to_variable(x) - # if we don't set tmp's stop_gradient as False then, all path to loss will has no gradient since - # there is no one need gradient on it. - tmp.stop_gradient=False - inputs2.append(tmp) - ret2 = fluid.layers.sums(inputs2) - loss2 = fluid.layers.reduce_sum(ret2) - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True - loss2.backward(backward_strategy) + inputs = [] + for _ in range(10): + tmp = paddle.to_tensor(x) + # if we don't set tmp's stop_gradient as False then, all path to loss will has no gradient since + # there is no one need gradient on it. 
+ tmp.stop_gradient=False + inputs.append(tmp) + ret = paddle.sums(inputs) + loss = paddle.reduce_sum(ret) + loss.backward() """ pass @@ -1170,9 +1171,7 @@ def gradient(self): inputs2.append(tmp) ret2 = fluid.layers.sums(inputs2) loss2 = fluid.layers.reduce_sum(ret2) - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True - loss2.backward(backward_strategy) + loss2.backward() print(loss2.gradient()) # example2: return tuple of ndarray @@ -1218,9 +1217,7 @@ def clear_gradient(self): inputs2.append(tmp) ret2 = fluid.layers.sums(inputs2) loss2 = fluid.layers.reduce_sum(ret2) - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True - loss2.backward(backward_strategy) + loss2.backward() print(loss2.gradient()) loss2.clear_gradient() print("After clear {}".format(loss2.gradient())) diff --git a/python/paddle/fluid/tests/unittests/test_directory_migration.py b/python/paddle/fluid/tests/unittests/test_directory_migration.py index bc85882805807..74cc87bd9dbd6 100644 --- a/python/paddle/fluid/tests/unittests/test_directory_migration.py +++ b/python/paddle/fluid/tests/unittests/test_directory_migration.py @@ -38,8 +38,7 @@ def test_new_directory(self): 'paddle.enable_static', 'paddle.disable_static', 'paddle.in_dynamic_mode', 'paddle.to_variable', 'paddle.grad', 'paddle.no_grad', 'paddle.save', 'paddle.load', - 'paddle.static.save', 'paddle.static.load', - 'paddle.BackwardStrategy', 'paddle.ParallelEnv', + 'paddle.static.save', 'paddle.static.load', 'paddle.ParallelEnv', 'paddle.prepare_context', 'paddle.DataParallel', 'paddle.jit', 'paddle.jit.TracedLayer', 'paddle.jit.to_static', 'paddle.jit.ProgramTranslator', 'paddle.jit.TranslatedLayer', @@ -98,7 +97,6 @@ def test_old_directory(self): 'paddle.imperative.enable', 'paddle.imperative.guard', 'paddle.imperative.grad', 'paddle.imperative.no_grad', 'paddle.imperative.save', 'paddle.imperative.load', - 'paddle.imperative.BackwardStrategy', 'paddle.imperative.ParallelEnv', 'paddle.imperative.prepare_context', 'paddle.imperative.DataParalell', 'paddle.imperative.jit', diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py index 2a25bf6f8abad..837e82882e9df 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py @@ -238,8 +238,7 @@ def test_auto_prune7(self): out2 = linear2(b) out1.stop_gradient = True out = fluid.layers.concat(input=[out1, out2, c], axis=1) - backward_strategy = fluid.dygraph.BackwardStrategy() - out.backward(backward_strategy) + out.backward() self.assertTrue(linear.weight.gradient() is None) self.assertTrue(out1.gradient() is None) @@ -311,9 +310,8 @@ def test_auto_prune10(self): out2 = linear2(b) out1.stop_gradient = True out = fluid.layers.concat(input=[out1, out2, c], axis=1) - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True - out.backward(backward_strategy) + fluid.set_flags({'FLAGS_sort_sum_gradient': True}) + out.backward() self.assertTrue(linear.weight.gradient() is None) self.assertTrue(out1.gradient() is None) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_basic.py b/python/paddle/fluid/tests/unittests/test_imperative_basic.py index f83f8ef35215e..b74182d27ab8c 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_basic.py +++ 
b/python/paddle/fluid/tests/unittests/test_imperative_basic.py @@ -314,9 +314,8 @@ def test_sum_op(self): inputs2.append(tmp) ret2 = fluid.layers.sums(inputs2) loss2 = fluid.layers.reduce_sum(ret2) - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True - loss2.backward(backward_strategy) + fluid.set_flags({'FLAGS_sort_sum_gradient': True}) + loss2.backward() self.assertTrue(np.allclose(ret.numpy(), x * 10)) self.assertTrue(np.allclose(inputs[0].gradient(), x)) @@ -403,9 +402,8 @@ def test_layer_in_out(self): x2 = l2(var_inp2)[0] self.assertIsNotNone(x2) dy_out2 = x2.numpy() - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True - x2.backward(backward_strategy) + fluid.set_flags({'FLAGS_sort_sum_gradient': True}) + x2.backward() dy_grad2 = l2._x_for_debug.gradient() with new_program_scope(): @@ -442,9 +440,8 @@ def test_mlp(self): mlp2 = MLP(input_size=2) out2 = mlp2(var_inp2) dy_out2 = out2.numpy() - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True - out2.backward(backward_strategy) + fluid.set_flags({'FLAGS_sort_sum_gradient': True}) + out2.backward() dy_grad2 = mlp2._linear1.weight.gradient() with new_program_scope(): @@ -552,9 +549,8 @@ def test_rnn(self): simple_rnn2 = SimpleRNN() outs2, pre_hiddens2 = simple_rnn2.forward(var_inp2) dy_out2 = outs2[3].numpy() - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True - outs2[3].backward(backward_strategy) + fluid.set_flags({'FLAGS_sort_sum_gradient': True}) + outs2[3].backward() dy_grad_h2o2 = simple_rnn2._cell._h2o_w.gradient() dy_grad_h2h2 = simple_rnn2._cell._h2h_w.gradient() dy_grad_i2h2 = simple_rnn2._cell._i2h_w.gradient() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_deepcf.py b/python/paddle/fluid/tests/unittests/test_imperative_deepcf.py index f76c3bd958081..af71d9d27b9a3 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_deepcf.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_deepcf.py @@ -275,8 +275,7 @@ def test_deefcf(self): deepcf2 = DeepCF(num_users, num_items, matrix) adam2 = fluid.optimizer.AdamOptimizer( 0.01, parameter_list=deepcf2.parameters()) - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True + fluid.set_flags({'FLAGS_sort_sum_gradient': True}) for e in range(NUM_EPOCHES): sys.stderr.write('epoch %d\n' % e) for slice in range(0, BATCH_SIZE * NUM_BATCHES, BATCH_SIZE): @@ -289,7 +288,7 @@ def test_deefcf(self): fluid.layers.log_loss(prediction2, to_variable(labels_np[ slice:slice + BATCH_SIZE]))) - loss2.backward(backward_strategy) + loss2.backward() adam2.minimize(loss2) deepcf2.clear_gradients() dy_loss2 = loss2.numpy() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py index 429736803a192..227cd5d4acb29 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py @@ -52,8 +52,7 @@ def grad(self, retain_graph=None, create_graph=False, allow_unused=False): - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = self.sort_sum_gradient + fluid.set_flags({'FLAGS_sort_sum_gradient': self.sort_sum_gradient}) return fluid.dygraph.grad( outputs=outputs, inputs=inputs, @@ -61,8 +60,7 @@ def grad(self, no_grad_vars=no_grad_vars, 
retain_graph=retain_graph, create_graph=create_graph, - allow_unused=allow_unused, - backward_strategy=backward_strategy) + allow_unused=allow_unused) @dygraph_guard def test_exception(self): @@ -310,8 +308,8 @@ def model_f(input): out = out + linear(input) return out - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True + fluid.set_flags({'FLAGS_sort_sum_gradient': True}) + with fluid.dygraph.guard(): paddle.manual_seed(123) a = fluid.dygraph.to_variable(value) @@ -324,8 +322,7 @@ def model_f(input): inputs=[a], create_graph=False, only_inputs=True, - allow_unused=False, - backward_strategy=backward_strategy) + allow_unused=False) grad_1 = dx[0].numpy() @@ -335,7 +332,7 @@ def model_f(input): a.stop_gradient = False out = model_f(a) - out.backward(backward_strategy) + out.backward() grad_2 = a.gradient() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_gan.py b/python/paddle/fluid/tests/unittests/test_imperative_gan.py index b7ebd23a0b742..80bdf2ea8a898 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_gan.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_gan.py @@ -179,9 +179,8 @@ def test_gan_float32(self): with fluid.dygraph.guard(): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed + fluid.set_flags({'FLAGS_sort_sum_gradient': True}) - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True discriminator2 = Discriminator() generator2 = Generator() sgd2 = SGDOptimizer( @@ -201,7 +200,7 @@ def test_gan_float32(self): x=d_fake2, label=to_variable(np.zeros([2, 1], np.float32)))) d_loss2 = d_loss_real2 + d_loss_fake2 - d_loss2.backward(backward_strategy) + d_loss2.backward() sgd2.minimize(d_loss2) discriminator2.clear_gradients() generator2.clear_gradients() @@ -211,7 +210,7 @@ def test_gan_float32(self): g_loss2 = fluid.layers.reduce_mean( fluid.layers.sigmoid_cross_entropy_with_logits( x=d_fake2, label=to_variable(np.ones([2, 1], np.float32)))) - g_loss2.backward(backward_strategy) + g_loss2.backward() sgd2.minimize(g_loss2) for p in discriminator2.parameters(): dy_params2[p.name] = p.numpy() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_hook_for_layer.py b/python/paddle/fluid/tests/unittests/test_imperative_hook_for_layer.py index 4fe4d963ca5ee..317353684317f 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_hook_for_layer.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_hook_for_layer.py @@ -62,8 +62,7 @@ def test_forward_hook_return_value(self): with fluid.dygraph.guard(place): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True + fluid.set_flags({'FLAGS_sort_sum_gradient': True}) input_word = np.array( [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7, @@ -132,8 +131,7 @@ def test_forward_hook(self): with fluid.dygraph.guard(place): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True + fluid.set_flags({'FLAGS_sort_sum_gradient': True}) global call_forward_hook global call_forward_pre_hook diff --git a/python/paddle/fluid/tests/unittests/test_imperative_lod_tensor_to_selected_rows.py b/python/paddle/fluid/tests/unittests/test_imperative_lod_tensor_to_selected_rows.py index 
69fd7d80327f1..6349d71760934 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_lod_tensor_to_selected_rows.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_lod_tensor_to_selected_rows.py @@ -113,8 +113,9 @@ def simple_net_float32(self, is_sparse, dtype): dy_loss = None helper = DyGraphProgramDescTracerTestHelper(self) - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = is_sort_sum_gradient + fluid.set_flags({ + 'FLAGS_sort_sum_gradient': is_sort_sum_gradient + }) for i in range(batch_num): x_data = np.arange(12).reshape(4, 3).astype('int64') @@ -129,7 +130,7 @@ def simple_net_float32(self, is_sparse, dtype): if i == 0: for param in simple_net.parameters(): dy_param_init[param.name] = param.numpy() - dy_loss.backward(backward_strategy) + dy_loss.backward() sgd.minimize(dy_loss) sgd.clear_gradients() if i == batch_num - 1: diff --git a/python/paddle/fluid/tests/unittests/test_imperative_mnist_sorted_gradient.py b/python/paddle/fluid/tests/unittests/test_imperative_mnist_sorted_gradient.py index 4ce0ca350ddb9..bda1958c0f354 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_mnist_sorted_gradient.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_mnist_sorted_gradient.py @@ -36,8 +36,7 @@ def test_mnist_sort_gradient_float32(self): with fluid.dygraph.guard(): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True + fluid.set_flags({'FLAGS_sort_sum_gradient': True}) mnist2 = MNIST() sgd2 = SGDOptimizer( @@ -69,7 +68,7 @@ def test_mnist_sort_gradient_float32(self): for param in mnist2.parameters(): dy_param_init_value2[param.name] = param.numpy() - avg_loss2.backward(backward_strategy) + avg_loss2.backward() sgd2.minimize(avg_loss2) mnist2.clear_gradients() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_ocr_attention_model.py b/python/paddle/fluid/tests/unittests/test_imperative_ocr_attention_model.py index 246b013f1ada6..499a4311f6e17 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_ocr_attention_model.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_ocr_attention_model.py @@ -403,8 +403,7 @@ def test_while_op(self): with fluid.dygraph.guard(): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True + fluid.set_flags({'FLAGS_sort_sum_gradient': True}) ocr_attention = OCRAttention() if Config.learning_rate_decay == "piecewise_decay": @@ -438,7 +437,7 @@ def test_while_op(self): for param in ocr_attention.parameters(): if param.name not in dy_param_init_value: dy_param_init_value[param.name] = param.numpy() - avg_loss.backward(backward_strategy) + avg_loss.backward() dy_grad_value = {} for param in ocr_attention.parameters(): if param.trainable: diff --git a/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn_sorted_gradient.py b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn_sorted_gradient.py index 8e85fe5dfefea..526c1706e2d08 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn_sorted_gradient.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn_sorted_gradient.py @@ -45,8 +45,7 @@ def ptb_rnn_sort_gradient_cpu_float32(self, is_sparse): with fluid.dygraph.guard(): fluid.default_startup_program().random_seed = seed 
fluid.default_main_program().random_seed = seed - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True + fluid.set_flags({'FLAGS_sort_sum_gradient': True}) # TODO: marsyang1993 Change seed to ptb_model = PtbModel( hidden_size=hidden_size, @@ -82,7 +81,7 @@ def ptb_rnn_sort_gradient_cpu_float32(self, is_sparse): if i == 0: for param in ptb_model.parameters(): dy_param_init[param.name] = param.numpy() - dy_loss.backward(backward_strategy) + dy_loss.backward() sgd.minimize(dy_loss) ptb_model.clear_gradients() if i == batch_num - 1: diff --git a/python/paddle/fluid/tests/unittests/test_imperative_resnet_sorted_gradient.py b/python/paddle/fluid/tests/unittests/test_imperative_resnet_sorted_gradient.py index 8cbd08ea3e245..d26d6f25aa8ff 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_resnet_sorted_gradient.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_resnet_sorted_gradient.py @@ -79,8 +79,7 @@ def test_resnet_sort_gradient_float32(self): with fluid.dygraph.guard(): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True + fluid.set_flags({'FLAGS_sort_sum_gradient': True}) resnet = ResNet() optimizer = optimizer_setting( train_parameters, parameter_list=resnet.parameters()) @@ -119,7 +118,7 @@ def test_resnet_sort_gradient_float32(self): if param.name not in dy_param_init_value: dy_param_init_value[param.name] = param.numpy() - avg_loss.backward(backward_strategy) + avg_loss.backward() dy_grad_value = {} for param in resnet.parameters(): diff --git a/python/paddle/fluid/tests/unittests/test_imperative_selected_rows.py b/python/paddle/fluid/tests/unittests/test_imperative_selected_rows.py index 9878e2f9ad772..59ddb365e5396 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_selected_rows.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_selected_rows.py @@ -48,8 +48,9 @@ def test_selectedrows_gradient1(self): for dtype in ["float32", "float64"]: for sort_sum_gradient in [True, False]: paddle.disable_static(place) - backward_strategy = paddle.BackwardStrategy() - backward_strategy.sort_sum_gradient = sort_sum_gradient + fluid.set_flags({ + 'FLAGS_sort_sum_gradient': sort_sum_gradient + }) # grad_clip = fluid.clip.GradientClipByGlobalNorm(5.0) input_word = np.array([[1, 2], [2, 1]]).astype('int64') @@ -65,7 +66,7 @@ def test_selectedrows_gradient1(self): self.assertTrue(emb.weight.gradient() is None) self.assertTrue(input_emb.gradient() is None) - input_emb.backward(backward_strategy) + input_emb.backward() adam.minimize(input_emb) self.assertTrue(emb.weight.gradient() is not None) @@ -84,8 +85,9 @@ def test_selectedrows_gradient2(self): for place in places: for sort_sum_gradient in [True, False]: with fluid.dygraph.guard(place): - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = sort_sum_gradient + fluid.set_flags({ + 'FLAGS_sort_sum_gradient': sort_sum_gradient + }) grad_clip = fluid.clip.GradientClipByGlobalNorm(5.0) input_word = np.array([[1, 2], [2, 1]]).astype('int64') @@ -101,7 +103,7 @@ def test_selectedrows_gradient2(self): self.assertTrue(emb.weight.gradient() is None) self.assertTrue(input_emb.gradient() is None) - input_emb.backward(backward_strategy) + input_emb.backward() adam.minimize(input_emb) self.assertTrue(emb.weight.gradient() is not None) diff --git 
a/python/paddle/fluid/tests/unittests/test_imperative_selected_rows_to_lod_tensor.py b/python/paddle/fluid/tests/unittests/test_imperative_selected_rows_to_lod_tensor.py index a42a62019ba54..3765cb784d652 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_selected_rows_to_lod_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_selected_rows_to_lod_tensor.py @@ -119,8 +119,9 @@ def simple_net_float(self, is_sparse, dtype): dy_param_init = dict() dy_loss = None - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = is_sort_sum_gradient + fluid.set_flags({ + 'FLAGS_sort_sum_gradient': is_sort_sum_gradient + }) for i in range(batch_num): x_data = np.arange(12).reshape(4, 3).astype('int64') @@ -135,7 +136,7 @@ def simple_net_float(self, is_sparse, dtype): if i == 0: for param in simple_net.parameters(): dy_param_init[param.name] = param.numpy() - dy_loss.backward(backward_strategy) + dy_loss.backward() sgd.minimize(dy_loss) sgd.clear_gradients() if i == batch_num - 1: diff --git a/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py b/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py index 649dc1ad91d38..d603a7d6ca0de 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py @@ -479,8 +479,7 @@ def __init__(self, cfg): self.cfg = cfg - self.backward_strategy = fluid.dygraph.BackwardStrategy() - self.backward_strategy.sort_sum_gradient = cfg.sort_sum_gradient + fluid.set_flags({'FLAGS_sort_sum_gradient': cfg.sort_sum_gradient}) def clear_gradients(self): if self.g_optimizer: @@ -497,7 +496,7 @@ def run(self, image_real, label_org, label_trg): g_loss = get_generator_loss(image_real, label_org, label_trg, self.generator, self.discriminator, self.cfg) - g_loss.backward(self.backward_strategy) + g_loss.backward() if self.g_optimizer: self.g_optimizer.minimize(g_loss) @@ -506,7 +505,7 @@ def run(self, image_real, label_org, label_trg): d_loss = get_discriminator_loss(image_real, label_org, label_trg, self.generator, self.discriminator, self.cfg) - d_loss.backward(self.backward_strategy) + d_loss.backward() if self.d_optimizer: self.d_optimizer.minimize(d_loss) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_static_runner_mnist.py b/python/paddle/fluid/tests/unittests/test_imperative_static_runner_mnist.py index acc56b7db27f4..f10d2df7f06f9 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_static_runner_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_static_runner_mnist.py @@ -121,8 +121,7 @@ def load_and_train_dygraph(self): with fluid.dygraph.guard(place): fluid.default_startup_program().random_seed = self.seed fluid.default_main_program().random_seed = self.seed - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True + fluid.set_flags({'FLAGS_sort_sum_gradient': True}) mnist = fluid.dygraph.static_runner.StaticModelRunner( model_dir=self.save_dirname, @@ -156,7 +155,7 @@ def load_and_train_dygraph(self): loss = fluid.layers.cross_entropy(cost, label) avg_loss = fluid.layers.mean(loss) - avg_loss.backward(backward_strategy) + avg_loss.backward() sgd.minimize(avg_loss) mnist.clear_gradients() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_static_runner_while.py 
b/python/paddle/fluid/tests/unittests/test_imperative_static_runner_while.py index 0792582175ef0..db47170c7bfff 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_static_runner_while.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_static_runner_while.py @@ -111,9 +111,7 @@ def load_and_train_dygraph(self): fluid.default_startup_program().random_seed = self.seed fluid.default_main_program().random_seed = self.seed np.random.seed(self.seed) - - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True + fluid.set_flags({'FLAGS_sort_sum_gradient': True}) while_net = fluid.dygraph.static_runner.StaticModelRunner( self.save_dirname) @@ -141,7 +139,7 @@ def load_and_train_dygraph(self): loss = fluid.layers.cross_entropy(cost, label) avg_loss = fluid.layers.mean(loss) - avg_loss.backward(backward_strategy) + avg_loss.backward() sgd.minimize(avg_loss) while_net.clear_gradients() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py b/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py index 29cc718f14ff9..c59ce44ec96a8 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py @@ -951,8 +951,7 @@ def transformer_sort_gradient_float32(self, is_sparse): with guard(): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = True + fluid.set_flags({'FLAGS_sort_sum_gradient': True}) transformer = TransFormer( ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size, @@ -1021,7 +1020,7 @@ def transformer_sort_gradient_float32(self, is_sparse): for param in transformer.parameters(): dy_param_init[param.name] = param.numpy() - dy_avg_cost.backward(backward_strategy) + dy_avg_cost.backward() optimizer.minimize(dy_avg_cost) transformer.clear_gradients() diff --git a/python/paddle/fluid/tests/unittests/test_paddle_imperative_double_grad.py b/python/paddle/fluid/tests/unittests/test_paddle_imperative_double_grad.py index 858d56c1fc04f..2ffe523ef6dda 100644 --- a/python/paddle/fluid/tests/unittests/test_paddle_imperative_double_grad.py +++ b/python/paddle/fluid/tests/unittests/test_paddle_imperative_double_grad.py @@ -52,8 +52,6 @@ def grad(self, retain_graph=None, create_graph=False, allow_unused=False): - backward_strategy = fluid.dygraph.BackwardStrategy() - backward_strategy.sort_sum_gradient = self.sort_sum_gradient return paddle.grad( outputs=outputs, inputs=inputs, @@ -61,8 +59,7 @@ def grad(self, no_grad_vars=no_grad_vars, retain_graph=retain_graph, create_graph=create_graph, - allow_unused=allow_unused, - backward_strategy=backward_strategy) + allow_unused=allow_unused) @dygraph_guard def test_exception(self): diff --git a/python/paddle/framework/__init__.py b/python/paddle/framework/__init__.py index f01dc01973a60..95a0cb5204679 100644 --- a/python/paddle/framework/__init__.py +++ b/python/paddle/framework/__init__.py @@ -20,8 +20,8 @@ ] __all__ += [ - 'BackwardStrategy', 'grad', 'LayerList', 'load', 'save', 'prepare_context', - 'to_variable', 'no_grad', 'ParallelEnv', 'DataParallel' + 'grad', 'LayerList', 'load', 'save', 'prepare_context', 'to_variable', + 'no_grad', 'ParallelEnv', 'DataParallel' ] __all__ += [ @@ -61,5 +61,3 @@ from ..fluid.dygraph.learning_rate_scheduler import InverseTimeDecay 
#DEFINE_ALIAS from ..fluid.dygraph.learning_rate_scheduler import PolynomialDecay #DEFINE_ALIAS from ..fluid.dygraph.learning_rate_scheduler import CosineDecay #DEFINE_ALIAS - -BackwardStrategy = core.BackwardStrategy From f1ae017fa9294d8bd024aefbea0678e0fee59dfc Mon Sep 17 00:00:00 2001 From: lilong12 Date: Fri, 28 Aug 2020 07:51:57 +0800 Subject: [PATCH 6/8] update copyright year, test=document_fix (#26586) --- python/paddle/fluid/tests/unittests/test_collective_api_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_collective_api_base.py b/python/paddle/fluid/tests/unittests/test_collective_api_base.py index b04bd0cbdefbd..437b8b7befae4 100644 --- a/python/paddle/fluid/tests/unittests/test_collective_api_base.py +++ b/python/paddle/fluid/tests/unittests/test_collective_api_base.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 2f75465d9a699540115e80c93a76fcd50815fc7e Mon Sep 17 00:00:00 2001 From: lilong12 Date: Fri, 28 Aug 2020 07:52:34 +0800 Subject: [PATCH 7/8] fix the call to core.ops.x, test=develop (#26729) --- python/paddle/tensor/manipulation.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 7f3dddc4e472f..845d2cf4d1993 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -23,7 +23,6 @@ import numpy as np # TODO: define functions to manipulate a tensor from ..fluid.layers import cast #DEFINE_ALIAS -from ..fluid.layers import expand_as #DEFINE_ALIAS from ..fluid.layers import slice #DEFINE_ALIAS from ..fluid.layers import strided_slice #DEFINE_ALIAS from ..fluid.layers import transpose #DEFINE_ALIAS @@ -1100,6 +1099,9 @@ def tile(x, repeat_times, name=None): np_out = out.numpy() # [[1, 2, 3], [1, 2, 3]] """ + if in_dygraph_mode(): + return core.ops.tile(x, 'repeat_times', repeat_times) + check_variable_and_dtype( x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'tile') check_type(repeat_times, 'repeat_times', (list, tuple, Variable), 'tile') @@ -1109,9 +1111,6 @@ def tile(x, repeat_times, name=None): "must set its stop_gradient to be True by " "some_var.stop_gradient == True supporting some_var is the input.") - if in_dygraph_mode(): - return core.ops.tile(x, 'repeat_times', repeat_times) - helper = LayerHelper('tile', **locals()) inputs = {"X": [x]} @@ -1176,6 +1175,9 @@ def expand_as(x, y, name=None): np_out = out.numpy() # [[1, 2, 3], [1, 2, 3]] """ + if in_dygraph_mode(): + return core.ops.expand_as_v2(x, y) + check_variable_and_dtype( x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'expand_as') check_type(y, 'y', Variable, 'expand_as') @@ -1188,9 +1190,6 @@ def expand_as(x, y, name=None): "some_var as the input 'x'.") inputs = {"X": [x], "target_tensor": [y]} - if in_dygraph_mode(): - return core.ops.expand_as_v2(x, y) - helper = LayerHelper('expand_as', **locals()) dtype = helper.input_dtype(input_param_name='x') out = helper.create_variable_for_type_inference(dtype) @@ -1229,6 +1228,9 @@ def expand(x, shape, name=None): out = out.numpy() # [[1, 2, 3], [1, 2, 3]] """ + if in_dygraph_mode(): + return core.ops.expand_v2(x, 'shape', shape) + check_variable_and_dtype( x, 'x', ['bool', 'float32', 
'float64', 'int32', 'int64'], 'expand') check_type(shape, 'shape', (list, tuple, Variable), 'expand') @@ -1241,9 +1243,6 @@ def expand(x, shape, name=None): "some_var.stop_gradient = True, supporting " "some_var as the input.") - if in_dygraph_mode(): - return core.ops.expand_v2(x, 'shape', shape) - helper = LayerHelper('expand', **locals()) def get_attr_expand_shape(list_expand_shape): From edf5f3173a25ae2230e9619ab5426317b4bd7cde Mon Sep 17 00:00:00 2001 From: donproc Date: Fri, 28 Aug 2020 08:24:15 +0800 Subject: [PATCH 8/8] [2.0 API] add paddle.nn.functional.linear and fix paddle.nn.Linear (#26480) --- .../fluid/tests/unittests/test_adamax_api.py | 2 +- .../fluid/tests/unittests/test_adamw_op.py | 4 +- .../unittests/test_imperative_layer_apply.py | 5 +- .../fluid/tests/unittests/test_linear.py | 78 ++++++++++++++++ .../fluid/tests/unittests/test_rmsprop_op.py | 2 +- .../paddle/incubate/hapi/tests/test_model.py | 18 ++-- .../test_uncombined_weight2state_dict.py | 8 +- .../incubate/hapi/vision/models/lenet.py | 8 +- .../paddle/incubate/hapi/vision/models/vgg.py | 8 +- python/paddle/nn/__init__.py | 1 + python/paddle/nn/functional/__init__.py | 1 + python/paddle/nn/functional/common.py | 88 ++++++++++++++++++- python/paddle/nn/layer/common.py | 84 +++++++++++++++++- 13 files changed, 275 insertions(+), 32 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_linear.py diff --git a/python/paddle/fluid/tests/unittests/test_adamax_api.py b/python/paddle/fluid/tests/unittests/test_adamax_api.py index f6946dc80b5e5..5a33e11d2862c 100644 --- a/python/paddle/fluid/tests/unittests/test_adamax_api.py +++ b/python/paddle/fluid/tests/unittests/test_adamax_api.py @@ -26,7 +26,7 @@ def test_adamax_api_dygraph(self): paddle.disable_static() value = np.arange(26).reshape(2, 13).astype("float32") a = paddle.to_variable(value) - linear = paddle.nn.Linear(13, 5, dtype="float32") + linear = paddle.nn.Linear(13, 5) adam = paddle.optimizer.Adamax( learning_rate=0.01, parameters=linear.parameters(), diff --git a/python/paddle/fluid/tests/unittests/test_adamw_op.py b/python/paddle/fluid/tests/unittests/test_adamw_op.py index ddb70d6e6400c..0a7cf54e2e0f1 100644 --- a/python/paddle/fluid/tests/unittests/test_adamw_op.py +++ b/python/paddle/fluid/tests/unittests/test_adamw_op.py @@ -23,7 +23,7 @@ def test_adamw_op_dygraph(self): paddle.disable_static() value = np.arange(26).reshape(2, 13).astype("float32") a = paddle.to_variable(value) - linear = paddle.nn.Linear(13, 5, dtype="float32") + linear = paddle.nn.Linear(13, 5) adam = paddle.optimizer.AdamW( learning_rate=0.01, parameters=linear.parameters(), @@ -38,7 +38,7 @@ def test_adamw_op_coverage(self): paddle.disable_static() value = np.arange(26).reshape(2, 13).astype("float32") a = paddle.to_variable(value) - linear = paddle.nn.Linear(13, 5, dtype="float32") + linear = paddle.nn.Linear(13, 5) adam = paddle.optimizer.AdamW( learning_rate=0.0, parameters=linear.parameters(), diff --git a/python/paddle/fluid/tests/unittests/test_imperative_layer_apply.py b/python/paddle/fluid/tests/unittests/test_imperative_layer_apply.py index b15ad911ee79d..f61d1ab888a51 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_layer_apply.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_layer_apply.py @@ -40,9 +40,8 @@ def __init__(self, num_classes=10, classifier_activation='softmax'): if num_classes > 0: self.fc = nn.Sequential( nn.Linear(400, 120), - nn.Linear(120, 84), - nn.Linear( - 84, 10, act=classifier_activation)) + nn.Linear(120, 
84), nn.Linear(84, 10), + nn.Softmax()) #Todo: accept any activation def forward(self, inputs): x = self.features(inputs) diff --git a/python/paddle/fluid/tests/unittests/test_linear.py b/python/paddle/fluid/tests/unittests/test_linear.py new file mode 100644 index 0000000000000..9d07a80da15db --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_linear.py @@ -0,0 +1,78 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle.fluid.core as core +from op_test import OpTest +import paddle +from paddle import fluid, nn +import paddle.fluid.dygraph as dg +import paddle.nn.functional as F +import paddle.fluid.initializer as I + + +class LinearTestCase(unittest.TestCase): + def setUp(self): + self.dtype = 'float32' + self.input = np.ones((3, 1, 2)).astype(self.dtype) + self.weight = np.ones((2, 2)).astype(self.dtype) + self.bias = np.ones((2)).astype(self.dtype) + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else paddle.CPUPlace() + + def functional(self, place): + paddle.disable_static(place) + input = paddle.to_tensor(self.input) + weight = paddle.to_tensor(self.weight) + bias = paddle.to_tensor(self.bias) + out = F.linear(input, weight, bias) + return out.numpy() + + def paddle_nn_layer(self, place): + paddle.disable_static(place) + input = paddle.to_tensor(self.input) + weight_attr = fluid.ParamAttr( + name="linear_weight", + learning_rate=1.0, + trainable=False, + regularizer=None, + initializer=paddle.fluid.initializer.ConstantInitializer(value=1.0)) + bias_attr = fluid.ParamAttr( + name="linear_bias", + learning_rate=1.0, + trainable=False, + regularizer=None, + initializer=paddle.fluid.initializer.ConstantInitializer(value=1.0)) + linear = paddle.nn.Linear( + 2, 2, weight_attr=weight_attr, bias_attr=bias_attr) + y = linear(input) + return y.numpy() + + def numpy_cal(self): + res = np.matmul(self.input, self.weight) + self.bias + return res + + def test_error(self, place=paddle.CPUPlace()): + res_f = self.functional(place) + res_nn = self.paddle_nn_layer(place) + res_np = self.numpy_cal() + np.testing.assert_array_almost_equal(res_f, res_nn) + np.testing.assert_array_almost_equal(res_nn, res_np) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_rmsprop_op.py b/python/paddle/fluid/tests/unittests/test_rmsprop_op.py index 0f225758ced3b..f7b9d4214d36a 100644 --- a/python/paddle/fluid/tests/unittests/test_rmsprop_op.py +++ b/python/paddle/fluid/tests/unittests/test_rmsprop_op.py @@ -228,7 +228,7 @@ def test_rmsprop_dygraph(self): paddle.disable_static() value = np.arange(26).reshape(2, 13).astype("float32") a = paddle.to_tensor(value) - linear = paddle.nn.Linear(13, 5, dtype="float32") + linear = paddle.nn.Linear(13, 5) # This can be any optimizer supported by dygraph. 
adam = paddle.optimizer.RMSProp( learning_rate=0.01, diff --git a/python/paddle/incubate/hapi/tests/test_model.py b/python/paddle/incubate/hapi/tests/test_model.py index 96c432e1bfd8f..8e0c051ee8c39 100644 --- a/python/paddle/incubate/hapi/tests/test_model.py +++ b/python/paddle/incubate/hapi/tests/test_model.py @@ -23,7 +23,7 @@ import tempfile from paddle import fluid -from paddle.nn import Conv2d, Pool2D, Linear, ReLU, Sequential +from paddle.nn import Conv2d, Pool2D, Linear, ReLU, Sequential, Softmax from paddle.fluid.dygraph.base import to_variable import paddle.incubate.hapi as hapi @@ -53,10 +53,8 @@ def __init__(self, num_classes=10, classifier_activation=None): if num_classes > 0: self.fc = Sequential( - Linear(400, 120), - Linear(120, 84), - Linear( - 84, 10, act=classifier_activation)) + Linear(400, 120), Linear(120, 84), Linear(84, 10), + Softmax()) #Todo: accept any activation def forward(self, inputs): x = self.features(inputs) @@ -83,10 +81,8 @@ def __init__(self, num_classes=10, classifier_activation=None): if num_classes > 0: self.fc = Sequential( - Linear(400, 120), - Linear(120, 84), - Linear( - 84, 10, act=classifier_activation)) + Linear(400, 120), Linear(120, 84), Linear(84, 10), + Softmax()) #Todo: accept any activation @declarative def forward(self, inputs): @@ -320,10 +316,12 @@ def predict(self, dynamic): class MyModel(fluid.dygraph.Layer): def __init__(self, classifier_activation='softmax'): super(MyModel, self).__init__() - self._fc = Linear(20, 10, act=classifier_activation) + self._fc = Linear(20, 10) + self._act = Softmax() #Todo: accept any activation def forward(self, x): y = self._fc(x) + y = self._act(y) return y diff --git a/python/paddle/incubate/hapi/tests/test_uncombined_weight2state_dict.py b/python/paddle/incubate/hapi/tests/test_uncombined_weight2state_dict.py index 26ec53014b1c3..6df9b31217aae 100644 --- a/python/paddle/incubate/hapi/tests/test_uncombined_weight2state_dict.py +++ b/python/paddle/incubate/hapi/tests/test_uncombined_weight2state_dict.py @@ -22,7 +22,7 @@ import tempfile from paddle import fluid -from paddle.nn import Conv2d, Pool2D, Linear, ReLU, Sequential +from paddle.nn import Conv2d, Pool2D, Linear, ReLU, Sequential, Softmax from paddle.incubate.hapi.utils import uncombined_weight_to_state_dict @@ -43,10 +43,8 @@ def __init__(self, num_classes=10, classifier_activation='softmax'): if num_classes > 0: self.fc = Sequential( - Linear(400, 120), - Linear(120, 84), - Linear( - 84, 10, act=classifier_activation)) + Linear(400, 120), Linear(120, 84), Linear(84, 10), + Softmax()) #Todo: accept any activation def forward(self, inputs): x = self.features(inputs) diff --git a/python/paddle/incubate/hapi/vision/models/lenet.py b/python/paddle/incubate/hapi/vision/models/lenet.py index dc7b094de0f26..169f70562f6ed 100644 --- a/python/paddle/incubate/hapi/vision/models/lenet.py +++ b/python/paddle/incubate/hapi/vision/models/lenet.py @@ -13,7 +13,7 @@ #limitations under the License. 
import paddle.fluid as fluid -from paddle.nn import Conv2d, Pool2D, Linear, ReLU, Sequential +from paddle.nn import Conv2d, Pool2D, Linear, ReLU, Sequential, Softmax __all__ = ['LeNet'] @@ -50,10 +50,8 @@ def __init__(self, num_classes=10, classifier_activation='softmax'): if num_classes > 0: self.fc = Sequential( - Linear(400, 120), - Linear(120, 84), - Linear( - 84, 10, act=classifier_activation)) + Linear(400, 120), Linear(120, 84), Linear(84, 10), + Softmax()) #Todo: accept any activation def forward(self, inputs): x = self.features(inputs) diff --git a/python/paddle/incubate/hapi/vision/models/vgg.py b/python/paddle/incubate/hapi/vision/models/vgg.py index 30f6e120b2502..4352a768eb720 100644 --- a/python/paddle/incubate/hapi/vision/models/vgg.py +++ b/python/paddle/incubate/hapi/vision/models/vgg.py @@ -13,7 +13,7 @@ # limitations under the License. import paddle.fluid as fluid -from paddle.nn import Conv2d, Pool2D, BatchNorm, Linear, ReLU +from paddle.nn import Conv2d, Pool2D, BatchNorm, Linear, ReLU, Softmax from paddle.fluid.dygraph.container import Sequential from ...download import get_weights_path_from_url @@ -37,7 +37,8 @@ def __init__(self, num_classes, classifier_activation='softmax'): super(Classifier, self).__init__() self.linear1 = Linear(512 * 7 * 7, 4096) self.linear2 = Linear(4096, 4096) - self.linear3 = Linear(4096, num_classes, act=classifier_activation) + self.linear3 = Linear(4096, num_classes) + self.act = Softmax() #Todo: accept any activation def forward(self, x): x = self.linear1(x) @@ -46,7 +47,8 @@ def forward(self, x): x = self.linear2(x) x = fluid.layers.relu(x) x = fluid.layers.dropout(x, 0.5) - out = self.linear3(x) + x = self.linear3(x) + out = self.act(x) return out diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 645b2115650a1..76063458d44de 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -115,6 +115,7 @@ # from .layer.learning_rate import NoamDecay #DEFINE_ALIAS # from .layer.learning_rate import PiecewiseDecay #DEFINE_ALIAS # from .layer.learning_rate import PolynomialDecay #DEFINE_ALIAS +from .layer.common import Linear # from .layer.loss import NCELoss #DEFINE_ALIAS from .layer.loss import BCEWithLogitsLoss #DEFINE_ALIAS from .layer.loss import CrossEntropyLoss #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 75e2da4cf7e92..414e70853eb71 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -75,6 +75,7 @@ from .common import bilinear #DEFINE_ALIAS from .conv import conv1d #DEFINE_ALIAS from .conv import conv_transpose1d #DEFINE_ALIAS +from .common import linear #DEFINE_ALIAS from .conv import conv2d #DEFINE_ALIAS from .conv import conv_transpose2d #DEFINE_ALIAS from .conv import conv3d #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 6a462b53b753c..8408e224d8737 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -17,7 +17,8 @@ from ...fluid.framework import in_dygraph_mode, default_main_program from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layers.tensor import Variable, fill_constant, zeros, concat - +from ...fluid.layers import core +from ...fluid import dygraph_utils # TODO: define the common functions to build a neural network from ...fluid.layers import label_smooth #DEFINE_ALIAS from ...fluid import one_hot #DEFINE_ALIAS @@ -30,6 +31,10 @@ 
from ...tensor import clip from ...tensor import sum from ...tensor import sqrt +from ...tensor import sum #DEFINE_ALIAS +from ...tensor import sqrt #DEFINE_ALIAS +from ...fluid.data_feeder import check_variable_and_dtype, check_dtype +from ...fluid.framework import Variable, in_dygraph_mode, _varbase_creator #from ...fluid.layers import fc #DEFINE_ALIAS from ...fluid.layers import pad_constant_like #DEFINE_ALIAS @@ -46,6 +51,7 @@ # 'embedding', # 'fc', 'label_smooth', + 'linear', 'one_hot', 'pad', 'pad_constant_like', @@ -1348,3 +1354,83 @@ def cosine_similarity(x1, x2, axis=1, eps=1e-8): n12 = sqrt(clip(w1 * w2, min=eps * eps)) cos_sim = w12 / n12 return cos_sim + + +def linear(x, weight, bias=None, name=None): + """ + + Fully-connected linear transformation op + + .. math:: + + Out = {XW + b} + + where :math:`X` is the input Tensor, :math:`W` and :math:`b` are weight and bias respectively. + + The linear op multiplies input tensor with weight matrix and + produces an output Tensor of shape [N, *, output_dim], + where N is batch size and `*` means any number of additional dimensions and output_dim is the last dim of ``weight``. + If ``bias`` is not None, a bias will be added to the output. + + Args: + x(Tensor): Input tensor, its data type is float16, float32 or float64 + weight(Tensor): Weight tensor, its data type is float16, float32 or float64 + bias(Tensor|None, optional): Bias tensor, its data type is float16, float32 or float64. If it is set to None, no bias will be added to the output units. + name(str|None, optional): For detailed information, please refer to :ref:`api_guide_Name`. Default: None. + + Returns: + Output tensor + + Examples: + .. code-block:: python + + import numpy as np + import paddle + import paddle.nn.functional as F + + input = np.ones((3,1,2), dtype=np.float32) + weight = np.ones((2,2), dtype=np.float32) + bias = np.ones((2), dtype=np.float32) + place = paddle.CPUPlace() + paddle.disable_static(place) + input = paddle.to_tensor(input) + weight = paddle.to_tensor(weight) + bias = paddle.to_tensor(bias) + out = F.linear(input, weight, bias) + print(out) #[3 3 3 3 3 3] + + """ + if in_dygraph_mode(): + pre_bias = _varbase_creator(dtype=x.dtype) + core.ops.matmul(x, weight, pre_bias, 'transpose_X', False, + 'transpose_Y', False, "alpha", 1) + return dygraph_utils._append_bias_in_dygraph( + pre_bias, bias, axis=len(x.shape) - 1) + else: + helper = LayerHelper('linear', **locals()) + dtype = x.dtype + + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + 'linear') + check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear') + + inputs = {'X': [x], 'Y': [weight]} + attrs = { + 'transpose_X': False, + 'transpose_Y': False, + 'alpha': 1, + } + tmp = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type='matmul', inputs=inputs, outputs={'Out': tmp}, attrs=attrs) + if bias is not None: + res = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type='elementwise_add', + inputs={'X': [tmp], + 'Y': [bias]}, + outputs={'Out': [res]}, + attrs={'axis': len(x.shape) - 1}) + else: + res = tmp + return res diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 9f32c1365a39d..a1e6508c67d96 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -16,7 +16,6 @@ from ...fluid.dygraph import BilinearTensorProduct #DEFINE_ALIAS from ...fluid.dygraph import Pool2D #DEFINE_ALIAS from ...fluid.dygraph import Embedding #DEFINE_ALIAS -from 
...fluid.dygraph import Linear #DEFINE_ALIAS from ...fluid.dygraph import Flatten #DEFINE_ALIAS from ...fluid.dygraph import layers from .. import functional as F @@ -49,6 +48,89 @@ ] +class Linear(layers.Layer): + """ + + Fully-connected linear transformation layer: + + .. math:: + + Out = {XW + b} + + where :math:`X` is the input Tensor, :math:`W` and :math:`b` are weight and bias respectively. + + Linear layer takes only one ``Tensor`` input. + The Linear layer multiplies input tensor with weight matrix and + produces an output Tensor of shape [N, *, `output_dim`], + where N is batch size and `*` means any number of additional dimensions. + If ``bias_attr`` is not None, a bias variable will be created and added to the output. + + Parameters: + in_features(int): The number of input units in this layer. + out_features(int): The number of output units in this layer. + weight_attr(ParamAttr or list of ParamAttr, optional): The parameter attribute for learnable + weights(Parameter) of this layer. Default: None. + bias_attr(ParamAttr or list of ParamAttr, optional): The attribute for the bias + of this layer. If it is set to False, no bias will be added to the output units. + If it is set to None, the bias is initialized zero. Default: None. + name(str|None): For detailed information, please refer to :ref:`api_guide_Name`. Default: None. + + Attributes: + **weight** (Parameter): the learnable weights of this layer. + + **bias** (Parameter or None): the learnable bias of this layer. + + Returns: + None + + Examples: + .. code-block:: python + + import paddle + from paddle import nn + import numpy as np + + data = np.ones((3,1,2), np.float32) + place = paddle.CPUPlace() + paddle.disable_static(place) + data = paddle.to_tensor(data) + weight_attr=paddle.framework.ParamAttr(name="linear_weight", learning_rate=1.0, + trainable=False, regularizer=None, initializer=paddle.fluid.initializer.ConstantInitializer(value=1.0)) + bias_attr=paddle.framework.ParamAttr(name="linear_bias", learning_rate=1.0, + trainable=False, regularizer=None, initializer=paddle.fluid.initializer.ConstantInitializer(value=1.0)) + linear = nn.Linear(2,2,weight_attr=weight_attr, bias_attr=bias_attr) + res = linear(data) # [3 3 3 3 3 3] + """ + + def __init__(self, + in_features, + out_features, + weight_attr=None, + bias_attr=None, + name=None): + super(Linear, self).__init__() + self._dtype = self._helper.get_default_dtype() + self._weight_attr = weight_attr + self._bias_attr = bias_attr + self.name = name + self.weight = self.create_parameter( + shape=[in_features, out_features], + attr=self._weight_attr, + dtype=self._dtype, + is_bias=False) + self.bias = self.create_parameter( + shape=[out_features], + attr=self._bias_attr, + dtype=self._dtype, + is_bias=True) + self.name = name + + def forward(self, input): + out = F.linear( + x=input, weight=self.weight, bias=self.bias, name=self.name) + return out + + class UpSample(layers.Layer): """ This op resizes a batch of images.
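
Taken together, the dygraph test changes earlier in this series replace the per-call BackwardStrategy object with the global FLAGS_sort_sum_gradient flag. A minimal sketch of the resulting usage pattern, assuming a dygraph-capable build; the small Linear model and reduce_mean loss here are only illustrative, not taken from any one test:

.. code-block:: python

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    paddle.disable_static()
    x = paddle.to_tensor(np.ones((2, 13), dtype='float32'))
    linear = paddle.nn.Linear(13, 5)

    # Old style (removed by this series):
    #   backward_strategy = fluid.dygraph.BackwardStrategy()
    #   backward_strategy.sort_sum_gradient = True
    #   loss.backward(backward_strategy)
    # New style: set the global flag once, then call backward() with no arguments.
    fluid.set_flags({'FLAGS_sort_sum_gradient': True})

    loss = fluid.layers.reduce_mean(linear(x))
    loss.backward()
    grad = linear.weight.gradient()  # gradients are populated as before

The new ``paddle.nn.Linear`` and ``paddle.nn.functional.linear`` added in the last patch drop the old ``act``/``dtype`` arguments; an activation is applied as a separate layer such as ``nn.Softmax``, which is how the updated tests and models above handle it. A hedged usage sketch reusing the shapes from test_linear.py:

.. code-block:: python

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    paddle.disable_static()
    x = paddle.to_tensor(np.ones((3, 1, 2), dtype='float32'))
    w = paddle.to_tensor(np.ones((2, 2), dtype='float32'))
    b = paddle.to_tensor(np.ones((2,), dtype='float32'))

    out = F.linear(x, w, b)          # functional form: x @ w + b, every value is 3.0

    linear = paddle.nn.Linear(2, 2)  # no act= argument any more
    softmax = paddle.nn.Softmax()    # activation is a separate layer now
    y = softmax(linear(x))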