From 1bf3f97bd622cbc550b0416672f5ef505531351b Mon Sep 17 00:00:00 2001 From: andyjpaddle Date: Thu, 6 Jan 2022 09:04:52 +0000 Subject: [PATCH] add maxunpool1d op --- .../fluid/tests/unittests/test_unpool1d_op.py | 156 ++++++++++++++++++ python/paddle/nn/__init__.py | 2 + python/paddle/nn/functional/__init__.py | 2 + python/paddle/nn/functional/pooling.py | 110 ++++++++++++ python/paddle/nn/layer/__init__.py | 1 + python/paddle/nn/layer/pooling.py | 82 +++++++++ 6 files changed, 353 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_unpool1d_op.py diff --git a/python/paddle/fluid/tests/unittests/test_unpool1d_op.py b/python/paddle/fluid/tests/unittests/test_unpool1d_op.py new file mode 100644 index 0000000000000..95d19210acb72 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_unpool1d_op.py @@ -0,0 +1,156 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.nn.functional as F + +paddle.enable_static() +paddle.seed(2022) + + +def _unpool_output_size(x, kernel_size, stride, padding, output_size): + input_size = x.shape + default_size = [] + for d in range(len(kernel_size)): + default_size.append((input_size[-len(kernel_size) + d] - 1) * stride[d] + + kernel_size[d] - 2 * padding[d]) + if output_size is None: + ret = default_size + else: + ret = output_size + return ret + + +def unpool1dmax_forward_naive(input, indices, ksize, strides, paddings, + output_size): + s0, s1, s2 = input.shape + output_size = _unpool_output_size(input, ksize, strides, paddings, + output_size) + out_lsize = output_size[0] + out = np.zeros((s0, s1, out_lsize)) + for nidx in range(s0): + for cidx in range(s1): + for l in range(s2): + index = indices[nidx, cidx, l] + lidx = index % out_lsize + out[nidx, cidx, lidx] = input[nidx, cidx, l] + + return out + + +class TestUnpool1DOpAPI_dygraph(unittest.TestCase): + def test_case(self): + places = [paddle.CPUPlace()] + if paddle.fluid.core.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + for place in places: + paddle.disable_static() + input_data = np.random.rand(1, 3, 16) + input_x = paddle.to_tensor(input_data) + output, indices = F.max_pool1d( + input_x, kernel_size=2, stride=2, return_mask=True) + output_unpool = F.max_unpool1d( + output, indices, kernel_size=2, stride=2) + expected_output_unpool = unpool1dmax_forward_naive( + output.numpy(), indices.numpy(), [2], [2], [0], [16]) + self.assertTrue( + np.allclose(output_unpool.numpy(), expected_output_unpool)) + + paddle.enable_static() + + +class TestUnpool1DOpAPI_dygraph2(unittest.TestCase): + def test_case(self): + places = [paddle.CPUPlace()] + if paddle.fluid.core.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + for place in places: + paddle.disable_static() + input_data = np.random.rand(1, 3, 16) + input_x = paddle.to_tensor(input_data) + output, indices = F.max_pool1d( + input_x, kernel_size=2, stride=2, return_mask=True) + output_unpool = F.max_unpool1d( + output, indices, kernel_size=2, stride=None) + expected_output_unpool = unpool1dmax_forward_naive( + output.numpy(), indices.numpy(), [2], [2], [0], [16]) + self.assertTrue( + np.allclose(output_unpool.numpy(), expected_output_unpool)) + + paddle.enable_static() + + +class TestUnpool1DOpAPI_dygraph3(unittest.TestCase): + def test_case(self): + places = [paddle.CPUPlace()] + if paddle.fluid.core.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + for place in places: + paddle.disable_static() + input_data = np.random.rand(1, 3, 16) + input_x = paddle.to_tensor(input_data) + Pool1d = paddle.nn.MaxPool1D( + kernel_size=2, stride=2, return_mask=True) + UnPool1d = paddle.nn.MaxUnPool1D(kernel_size=2, stride=2) + + output, indices = Pool1d(input_x) + output_unpool = UnPool1d(output, indices) + expected_output_unpool = unpool1dmax_forward_naive( + output.numpy(), indices.numpy(), [2], [2], [0], [16]) + self.assertTrue( + np.allclose(output_unpool.numpy(), expected_output_unpool)) + + paddle.enable_static() + + +class TestUnpool1DOpAPI_static(unittest.TestCase): + def test_case(self): + paddle.enable_static() + places = [paddle.CPUPlace()] + if paddle.fluid.core.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + for place in places: + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + + input_data = np.array([[[1, 2, 3, 4], [5, 6, 7, 8], + [9, 10, 11, 12]]]).astype("float32") + x = paddle.fluid.data( + name='x', shape=[1, 3, 4], dtype='float32') + output, indices = F.max_pool1d( + x, kernel_size=2, stride=2, return_mask=True) + output_unpool = F.max_unpool1d( + output, indices, kernel_size=2, stride=None) + + exe = paddle.fluid.Executor(place) + fetches = exe.run(paddle.fluid.default_main_program(), + feed={"x": input_data}, + fetch_list=[output_unpool], + return_numpy=True) + pool1d_out_np = np.array( + [[[2., 4.], [6., 8.], [10., 12.]]]).astype("float32") + indices_np = np.array( + [[[1, 3], [1, 3], [1, 3]]]).astype("int32") + expected_output_unpool = unpool1dmax_forward_naive( + pool1d_out_np, indices_np, [2], [2], [0], [4]) + self.assertTrue(np.allclose(fetches[0], expected_output_unpool)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 60bc5f3111d8f..57e1b710cab0d 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -76,6 +76,7 @@ from .layer.pooling import MaxPool1D # noqa: F401 from .layer.pooling import MaxPool2D # noqa: F401 from .layer.pooling import MaxPool3D # noqa: F401 +from .layer.pooling import MaxUnPool1D # noqa: F401 from .layer.pooling import MaxUnPool2D # noqa: F401 from .layer.pooling import MaxUnPool3D # noqa: F401 from .layer.pooling import AdaptiveAvgPool1D # noqa: F401 @@ -300,6 +301,7 @@ def weight_norm(*args): 'ReLU6', 'LayerDict', 'ZeroPad2D', + 'MaxUnPool1D', 'MaxUnPool2D', 'MaxUnPool3D', 'HingeEmbeddingLoss', diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 7611d06a8f957..683d7ad01b6b8 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -106,6 +106,7 @@ from .pooling import adaptive_avg_pool1d # noqa: F401 from .pooling import adaptive_avg_pool2d # noqa: F401 from .pooling import adaptive_avg_pool3d # noqa: F401 +from .pooling import max_unpool1d # noqa: F401 from .pooling import max_unpool2d # noqa: F401 from .pooling import max_unpool3d # noqa: F401 @@ -178,6 +179,7 @@ 'max_pool1d', 'max_pool2d', 'max_pool3d', + 'max_unpool1d', 'max_unpool2d', 'max_unpool3d', 'adaptive_avg_pool1d', diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index 11b25265d7057..aebd36426bc71 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -664,6 +664,116 @@ def _unpool_output_size(x, kernel_size, stride, padding, output_size): return ret +def max_unpool1d(x, + indices, + kernel_size, + stride=None, + padding=0, + data_format="NCL", + output_size=None, + name=None): + """ + This API implements max unpooling 1d opereation. + `max_unpool1d` accepts the output of `max_pool1d` as input, + including the indices of the maximum value and calculate the partial inverse. + All non-maximum values ​​are set to zero. + + - Input: :math:`(N, C, L_{in})` + - Output: :math:`(N, C, L_{out})`, where + + .. math:: + L_{out} = (L_{in} - 1) * stride - 2 * padding + kernel\_size + + or as given by :attr:`output_size` in the call operator. + + + Args: + x (Tensor): The input tensor of unpooling operator which is a 3-D tensor with + shape [N, C, L]. The format of input tensor is `"NCL"`, + where `N` is batch size, `C` is the number of channels, `L` is + the length of the feature. The data type is float32 or float64. + indices (Tensor): The indices given out by maxpooling1d which is a 3-D tensor with + shape [N, C, L]. The format of input tensor is `"NCL"` , + where `N` is batch size, `C` is the number of channels, `L` is + the length of the featuree. The data type is float32 or float64. + kernel_size (int|list|tuple): The unpool kernel size. If unpool kernel size is a tuple or list, + it must contain an integer. + stride (int|list|tuple): The unpool stride size. If unpool stride size is a tuple or list, + it must contain an integer. + padding (int | tuple): Padding that was added to the input. + output_size(list|tuple, optional): The target output size. If output_size is not specified, + the actual output shape will be automatically calculated by (input_shape, + kernel_size, stride, padding). + data_format (string): The data format of the input and output data. + The default is `"NCL"`. When it is `"NCL"`, the data is stored in the order of: + `[batch_size, input_channels, input_length]`. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + Tensor: The output tensor of unpooling result. + + Examples: + .. code-block:: python + + import paddle + import paddle.nn.functional as F + + data = paddle.rand(shape=[1, 3, 16]) + pool_out, indices = F.max_pool1d(data, kernel_size=2, stride=2, padding=0, return_mask=True) + # pool_out shape: [1, 3, 8], indices shape: [1, 3, 8] + unpool_out = F.max_unpool1d(pool_out, indices, kernel_size=2, padding=0) + # unpool_out shape: [1, 3, 16] + + """ + """NCL to NCHW""" + data_format = "NCHW" + x = unsqueeze(x, [2]) + indices = unsqueeze(indices, [2]) + kernel_size = [1] + utils.convert_to_list(kernel_size, 1, 'pool_size') + if stride is None: + stride = kernel_size + else: + stride = [1] + utils.convert_to_list(stride, 1, 'pool_stride') + padding, padding_algorithm = _update_padding_nd(padding, 1) + # use 2d to implenment 1d should expand padding in advance. + padding = _expand_low_nd_padding(padding) + + if data_format not in ["NCHW"]: + raise ValueError("Attr(data_format) should be 'NCHW'. Received " + "Attr(data_format): %s." % str(data_format)) + + output_size = _unpool_output_size(x, kernel_size, stride, padding, + output_size) + + if in_dygraph_mode(): + output = _C_ops.unpool(x, indices, 'unpooling_type', 'max', 'ksize', + kernel_size, 'strides', stride, 'paddings', + padding, "output_size", output_size, + "data_format", data_format) + return squeeze(output, [2]) + + op_type = "unpool" + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype(input_param_name="x") + unpool_out = helper.create_variable_for_type_inference(dtype) + + helper.append_op( + type=op_type, + inputs={"X": x, + "Indices": indices}, + outputs={"Out": unpool_out}, + attrs={ + "unpooling_type": "max", + "ksize": kernel_size, + "strides": stride, + "paddings": padding, + "output_size": output_size + }) + return squeeze(unpool_out, [2]) + + def max_unpool2d(x, indices, kernel_size, diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 772d6d390bf44..2b50508065605 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -57,6 +57,7 @@ from .pooling import AdaptiveMaxPool1D # noqa: F401 from .pooling import AdaptiveMaxPool2D # noqa: F401 from .pooling import AdaptiveMaxPool3D # noqa: F401 +from .pooling import MaxUnPool1D # noqa: F401 from .pooling import MaxUnPool2D # noqa: F401 from .pooling import MaxUnPool3D # noqa: F401 from .conv import Conv1D # noqa: F401 diff --git a/python/paddle/nn/layer/pooling.py b/python/paddle/nn/layer/pooling.py index dd3feee535657..96942f5c8500a 100755 --- a/python/paddle/nn/layer/pooling.py +++ b/python/paddle/nn/layer/pooling.py @@ -1130,6 +1130,88 @@ def extra_repr(self): self._return_mask) +class MaxUnPool1D(Layer): + """ + This API implements max unpooling 1d opereation. + + `max_unpool1d` accepts the output of `max_pool1d` as input, + including the indices of the maximum value and calculate the partial inverse. + All non-maximum values ​​are set to zero. + + - Input: :math:`(N, C, L_{in})` + - Output: :math:`(N, C, L_{out})`, where + + .. math:: + L_{out} = (L_{in} - 1) * stride - 2 * padding + kernel\_size + + or as given by :attr:`output_size` in the call operator. + + Parameters: + kernel_size (int|list|tuple): The unpool kernel size. If unpool kernel size is a tuple or list, + it must contain an integer. + stride (int|list|tuple): The unpool stride size. If unpool stride size is a tuple or list, + it must contain an integer. + padding (int | tuple): Padding that was added to the input. + output_size(list|tuple, optional): The target output size. If output_size is not specified, + the actual output shape will be automatically calculated by (input_shape, + kernel_size, stride, padding). + data_format (string): The data format of the input and output data. + The default is `"NCL"`. When it is `"NCL"`, the data is stored in the order of: + `[batch_size, input_channels, input_length]`. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + + Returns: + A callable object of MaxUnPool1D. + + Examples: + .. code-block:: python + + import paddle + import paddle.nn.functional as F + import numpy as np + + data = paddle.rand(shape=[1, 3, 16]) + pool_out, indices = F.max_pool1d(data, kernel_size=2, stride=2, padding=0, return_mask=True) + # pool_out shape: [1, 3, 8], indices shape: [1, 3, 8] + Unpool1D = paddle.nn.MaxUnPool1D(kernel_size=2, padding=0) + unpool_out = Unpool1D(pool_out, indices) + # unpool_out shape: [1, 3, 16] + + """ + + def __init__(self, + kernel_size, + stride=None, + padding=0, + data_format="NCL", + output_size=None, + name=None): + super(MaxUnPool1D, self).__init__() + self.ksize = kernel_size + self.stride = stride + self.padding = padding + self.data_format = data_format + self.output_size = output_size + self.name = name + + def forward(self, x, indices): + return F.max_unpool1d( + x, + indices, + kernel_size=self.ksize, + stride=self.stride, + padding=self.padding, + data_format=self.data_format, + output_size=self.output_size, + name=self.name) + + def extra_repr(self): + return 'output_size={}'.format(self.output_size) + + class MaxUnPool2D(Layer): """ This API implements max unpooling 2d opereation.