Skip to content

Commit

Permalink
add maxunpool1d op
Browse files Browse the repository at this point in the history
  • Loading branch information
andyjiang1116 committed Jan 6, 2022
1 parent a2280c5 commit 1bf3f97
Show file tree
Hide file tree
Showing 6 changed files with 353 additions and 0 deletions.
156 changes: 156 additions & 0 deletions python/paddle/fluid/tests/unittests/test_unpool1d_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.nn.functional as F

# The tests in this file run in static-graph mode by default; the dygraph
# cases below switch modes themselves and switch back afterwards.
paddle.enable_static()
# Fixed seed handed to Paddle's seeding call.
# NOTE(review): the dygraph inputs come from np.random.rand, which this call
# presumably does not seed -- confirm whether NumPy should be seeded as well.
paddle.seed(2022)


def _unpool_output_size(x, kernel_size, stride, padding, output_size):
input_size = x.shape
default_size = []
for d in range(len(kernel_size)):
default_size.append((input_size[-len(kernel_size) + d] - 1) * stride[d]
+ kernel_size[d] - 2 * padding[d])
if output_size is None:
ret = default_size
else:
ret = output_size
return ret


def unpool1dmax_forward_naive(input, indices, ksize, strides, paddings,
                              output_size):
    """NumPy reference implementation of 1-D max unpooling.

    Args:
        input: 3-D array holding the pooled values.
        indices: Array with the same layout as ``input`` holding, for every
            pooled value, the position it is scattered to.
        ksize: One-element list with the kernel size.
        strides: One-element list with the stride.
        paddings: One-element list with the padding.
        output_size: Explicit output size (its first entry is the output
            length), or None to derive it from the other arguments.

    Returns:
        A zero-initialised array of shape ``(N, C, L_out)`` in which every
        pooled value has been written at ``index % L_out``.
    """
    batch, channels, pooled_len = input.shape
    out_len = _unpool_output_size(input, ksize, strides, paddings,
                                  output_size)[0]
    result = np.zeros((batch, channels, out_len))
    # np.ndindex walks the (n, c, l) triples in the same row-major order as
    # three nested range() loops, so later writes still win on duplicates.
    for n, c, pos in np.ndindex(batch, channels, pooled_len):
        target = indices[n, c, pos] % out_len
        result[n, c, target] = input[n, c, pos]

    return result


class TestUnpool1DOpAPI_dygraph(unittest.TestCase):
    """Dygraph check of `F.max_unpool1d` called with an explicit stride."""

    def test_case(self):
        """Pool a random tensor, unpool it, and compare with the reference."""
        places = [paddle.CPUPlace()]
        if paddle.fluid.core.is_compiled_with_cuda():
            places.append(paddle.CUDAPlace(0))
        for place in places:
            # Hand the place to dygraph mode explicitly; without it every
            # iteration runs on the default device and the loop adds nothing.
            paddle.disable_static(place)
            try:
                input_data = np.random.rand(1, 3, 16)
                input_x = paddle.to_tensor(input_data)
                output, indices = F.max_pool1d(
                    input_x, kernel_size=2, stride=2, return_mask=True)
                output_unpool = F.max_unpool1d(
                    output, indices, kernel_size=2, stride=2)
                expected_output_unpool = unpool1dmax_forward_naive(
                    output.numpy(), indices.numpy(), [2], [2], [0], [16])
                self.assertTrue(
                    np.allclose(output_unpool.numpy(),
                                expected_output_unpool))
            finally:
                # This module runs in static mode by default; restore it
                # even when the assertion above fails.
                paddle.enable_static()


class TestUnpool1DOpAPI_dygraph2(unittest.TestCase):
    """Dygraph check of `F.max_unpool1d` with `stride=None`."""

    def test_case(self):
        """Pool a random tensor, unpool it, and compare with the reference."""
        places = [paddle.CPUPlace()]
        if paddle.fluid.core.is_compiled_with_cuda():
            places.append(paddle.CUDAPlace(0))
        for place in places:
            # Hand the place to dygraph mode explicitly; without it every
            # iteration runs on the default device and the loop adds nothing.
            paddle.disable_static(place)
            try:
                input_data = np.random.rand(1, 3, 16)
                input_x = paddle.to_tensor(input_data)
                output, indices = F.max_pool1d(
                    input_x, kernel_size=2, stride=2, return_mask=True)
                # stride=None makes max_unpool1d fall back to the kernel size,
                # so the reference below is computed with stride [2].
                output_unpool = F.max_unpool1d(
                    output, indices, kernel_size=2, stride=None)
                expected_output_unpool = unpool1dmax_forward_naive(
                    output.numpy(), indices.numpy(), [2], [2], [0], [16])
                self.assertTrue(
                    np.allclose(output_unpool.numpy(),
                                expected_output_unpool))
            finally:
                # This module runs in static mode by default; restore it
                # even when the assertion above fails.
                paddle.enable_static()


class TestUnpool1DOpAPI_dygraph3(unittest.TestCase):
    """Dygraph check of the `paddle.nn.MaxUnPool1D` layer."""

    def test_case(self):
        """Pool and unpool through the layer API and compare with the reference."""
        places = [paddle.CPUPlace()]
        if paddle.fluid.core.is_compiled_with_cuda():
            places.append(paddle.CUDAPlace(0))
        for place in places:
            # Hand the place to dygraph mode explicitly; without it every
            # iteration runs on the default device and the loop adds nothing.
            paddle.disable_static(place)
            try:
                input_data = np.random.rand(1, 3, 16)
                input_x = paddle.to_tensor(input_data)
                Pool1d = paddle.nn.MaxPool1D(
                    kernel_size=2, stride=2, return_mask=True)
                UnPool1d = paddle.nn.MaxUnPool1D(kernel_size=2, stride=2)

                output, indices = Pool1d(input_x)
                output_unpool = UnPool1d(output, indices)
                expected_output_unpool = unpool1dmax_forward_naive(
                    output.numpy(), indices.numpy(), [2], [2], [0], [16])
                self.assertTrue(
                    np.allclose(output_unpool.numpy(),
                                expected_output_unpool))
            finally:
                # This module runs in static mode by default; restore it
                # even when the assertion above fails.
                paddle.enable_static()


class TestUnpool1DOpAPI_static(unittest.TestCase):
    """Static-graph check of `F.max_unpool1d` on a fixed 1x3x4 input."""

    def test_case(self):
        """Build pool + unpool in a fresh program and compare with the reference."""
        paddle.enable_static()
        devices = [paddle.CPUPlace()]
        if paddle.fluid.core.is_compiled_with_cuda():
            devices.append(paddle.CUDAPlace(0))

        for place in devices:
            main_prog = paddle.static.Program()
            startup_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, startup_prog):
                # Rows 1..4, 5..8 and 9..12: the maximum of every window of
                # two is its second element.
                feed_value = np.array(
                    [[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]])
                feed_value = feed_value.astype("float32")

                x = paddle.fluid.data(
                    name='x', shape=[1, 3, 4], dtype='float32')
                pooled, mask = F.max_pool1d(
                    x, kernel_size=2, stride=2, return_mask=True)
                unpooled = F.max_unpool1d(
                    pooled, mask, kernel_size=2, stride=None)

                executor = paddle.fluid.Executor(place)
                results = executor.run(
                    paddle.fluid.default_main_program(),
                    feed={"x": feed_value},
                    fetch_list=[unpooled],
                    return_numpy=True)

                # Hand-written pooling result and indices for the input above.
                pooled_ref = np.array([[[2., 4.], [6., 8.], [10., 12.]]])
                pooled_ref = pooled_ref.astype("float32")
                mask_ref = np.array([[[1, 3], [1, 3], [1, 3]]])
                mask_ref = mask_ref.astype("int32")
                reference = unpool1dmax_forward_naive(
                    pooled_ref, mask_ref, [2], [2], [0], [4])
                self.assertTrue(np.allclose(results[0], reference))


# Run the test cases above when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
2 changes: 2 additions & 0 deletions python/paddle/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
from .layer.pooling import MaxPool1D # noqa: F401
from .layer.pooling import MaxPool2D # noqa: F401
from .layer.pooling import MaxPool3D # noqa: F401
from .layer.pooling import MaxUnPool1D # noqa: F401
from .layer.pooling import MaxUnPool2D # noqa: F401
from .layer.pooling import MaxUnPool3D # noqa: F401
from .layer.pooling import AdaptiveAvgPool1D # noqa: F401
Expand Down Expand Up @@ -300,6 +301,7 @@ def weight_norm(*args):
'ReLU6',
'LayerDict',
'ZeroPad2D',
'MaxUnPool1D',
'MaxUnPool2D',
'MaxUnPool3D',
'HingeEmbeddingLoss',
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/nn/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
from .pooling import adaptive_avg_pool1d # noqa: F401
from .pooling import adaptive_avg_pool2d # noqa: F401
from .pooling import adaptive_avg_pool3d # noqa: F401
from .pooling import max_unpool1d # noqa: F401
from .pooling import max_unpool2d # noqa: F401
from .pooling import max_unpool3d # noqa: F401

Expand Down Expand Up @@ -178,6 +179,7 @@
'max_pool1d',
'max_pool2d',
'max_pool3d',
'max_unpool1d',
'max_unpool2d',
'max_unpool3d',
'adaptive_avg_pool1d',
Expand Down
110 changes: 110 additions & 0 deletions python/paddle/nn/functional/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,116 @@ def _unpool_output_size(x, kernel_size, stride, padding, output_size):
return ret


def max_unpool1d(x,
                 indices,
                 kernel_size,
                 stride=None,
                 padding=0,
                 data_format="NCL",
                 output_size=None,
                 name=None):
    r"""
    This API implements max unpooling 1d operation.
    `max_unpool1d` accepts the output of `max_pool1d` as input,
    including the indices of the maximum value, and calculates the partial inverse.
    All non-maximum values are set to zero.

    - Input: :math:`(N, C, L_{in})`
    - Output: :math:`(N, C, L_{out})`, where

    .. math::
        L_{out} = (L_{in} - 1) * stride - 2 * padding + kernel\_size

    or as given by :attr:`output_size` in the call operator.

    Args:
        x (Tensor): The input tensor of unpooling operator which is a 3-D tensor with
            shape [N, C, L]. The format of input tensor is `"NCL"`,
            where `N` is batch size, `C` is the number of channels, `L` is
            the length of the feature. The data type is float32 or float64.
        indices (Tensor): The indices given out by `max_pool1d` (called with
            `return_mask=True`), a 3-D tensor with shape [N, C, L]. The format
            of input tensor is `"NCL"`, where `N` is batch size, `C` is the
            number of channels, `L` is the length of the feature. These are
            integer positions, not floating-point values.
        kernel_size (int|list|tuple): The unpool kernel size. If unpool kernel size is a tuple or list,
            it must contain an integer.
        stride (int|list|tuple, optional): The unpool stride size. If unpool stride size is a tuple or list,
            it must contain an integer. Default: None, in which case the
            kernel size is used as the stride.
        padding (int|tuple): Padding that was added to the input. Default: 0.
        data_format (string): The data format of the input and output data.
            The default is `"NCL"`. When it is `"NCL"`, the data is stored in the order of:
            `[batch_size, input_channels, input_length]`. Only `"NCL"` is accepted.
        output_size (list|tuple, optional): The target output size. If output_size is not specified,
            the actual output shape will be automatically calculated by (input_shape,
            kernel_size, stride, padding). A one-element value `[L_out]` is
            accepted and expanded internally.
        name (str, optional): For detailed information, please refer
            to :ref:`api_guide_Name`. Usually name is no need to set and
            None by default.

    Returns:
        Tensor: The output tensor of unpooling result.

    Raises:
        ValueError: If `data_format` is not `"NCL"`.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            data = paddle.rand(shape=[1, 3, 16])
            pool_out, indices = F.max_pool1d(data, kernel_size=2, stride=2, padding=0, return_mask=True)
            # pool_out shape: [1, 3, 8], indices shape: [1, 3, 8]
            unpool_out = F.max_unpool1d(pool_out, indices, kernel_size=2, padding=0)
            # unpool_out shape: [1, 3, 16]
    """
    # Validate the caller's layout first: it is replaced by the internal 2-D
    # layout right below, and checking afterwards could never fail.
    if data_format != "NCL":
        raise ValueError("Attr(data_format) should be 'NCL'. Received "
                         "Attr(data_format): %s." % str(data_format))

    # The 1-D case is computed with the 2-D unpool operator: NCL -> NCHW
    # with a height axis of size 1 inserted at position 2.
    data_format = "NCHW"
    x = unsqueeze(x, [2])
    indices = unsqueeze(indices, [2])
    kernel_size = [1] + utils.convert_to_list(kernel_size, 1, 'pool_size')
    if stride is None:
        stride = kernel_size
    else:
        stride = [1] + utils.convert_to_list(stride, 1, 'pool_stride')
    padding, _ = _update_padding_nd(padding, 1)
    # Using the 2-D operator for 1-D data requires expanding padding in advance.
    padding = _expand_low_nd_padding(padding)

    # kernel_size, stride and padding now describe two spatial axes, so a
    # one-element target length must be expanded the same way before it is
    # handed to the 2-D operator.
    if isinstance(output_size, (list, tuple)) and len(output_size) == 1:
        output_size = [1] + list(output_size)

    output_size = _unpool_output_size(x, kernel_size, stride, padding,
                                      output_size)

    if in_dygraph_mode():
        output = _C_ops.unpool(x, indices, 'unpooling_type', 'max', 'ksize',
                               kernel_size, 'strides', stride, 'paddings',
                               padding, "output_size", output_size,
                               "data_format", data_format)
        # Drop the helper height axis again: NCHW -> NCL.
        return squeeze(output, [2])

    op_type = "unpool"
    helper = LayerHelper(op_type, **locals())
    dtype = helper.input_dtype(input_param_name="x")
    unpool_out = helper.create_variable_for_type_inference(dtype)

    helper.append_op(
        type=op_type,
        inputs={"X": x,
                "Indices": indices},
        outputs={"Out": unpool_out},
        attrs={
            "unpooling_type": "max",
            "ksize": kernel_size,
            "strides": stride,
            "paddings": padding,
            "output_size": output_size
        })
    # Drop the helper height axis again: NCHW -> NCL.
    return squeeze(unpool_out, [2])


def max_unpool2d(x,
indices,
kernel_size,
Expand Down
1 change: 1 addition & 0 deletions python/paddle/nn/layer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from .pooling import AdaptiveMaxPool1D # noqa: F401
from .pooling import AdaptiveMaxPool2D # noqa: F401
from .pooling import AdaptiveMaxPool3D # noqa: F401
from .pooling import MaxUnPool1D # noqa: F401
from .pooling import MaxUnPool2D # noqa: F401
from .pooling import MaxUnPool3D # noqa: F401
from .conv import Conv1D # noqa: F401
Expand Down
Loading

0 comments on commit 1bf3f97

Please sign in to comment.