From 3526cad6c272f4138aeda15a7f02f249b9388e52 Mon Sep 17 00:00:00 2001 From: andyjpaddle Date: Wed, 5 Jan 2022 08:05:51 +0000 Subject: [PATCH 1/7] add maxunpool3d op --- paddle/fluid/operators/math/unpooling.cc | 93 +++++- paddle/fluid/operators/math/unpooling.cu | 113 ++++++- paddle/fluid/operators/math/unpooling.h | 18 +- paddle/fluid/operators/unpool_op.cc | 158 +++++++++- paddle/fluid/operators/unpool_op.cu.cc | 9 +- paddle/fluid/operators/unpool_op.h | 51 ++- .../fluid/tests/unittests/test_unpool3d_op.py | 293 ++++++++++++++++++ python/paddle/nn/__init__.py | 1 + python/paddle/nn/functional/__init__.py | 2 + python/paddle/nn/functional/pooling.py | 110 +++++++ python/paddle/nn/layer/__init__.py | 1 + python/paddle/nn/layer/pooling.py | 90 ++++++ 12 files changed, 933 insertions(+), 6 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_unpool3d_op.py diff --git a/paddle/fluid/operators/math/unpooling.cc b/paddle/fluid/operators/math/unpooling.cc index bcb2b92780cc8..69fd2dbb85246 100644 --- a/paddle/fluid/operators/math/unpooling.cc +++ b/paddle/fluid/operators/math/unpooling.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -96,10 +96,101 @@ class Unpool2dMaxGradFunctor { } } }; + +template +class Unpool3dMaxFunctor { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& indices, framework::Tensor* output) { + const int batch_size = input.dims()[0]; + const int input_depth = input.dims()[2]; + const int input_height = input.dims()[3]; + const int input_width = input.dims()[4]; + const int output_channels = output->dims()[1]; + const int output_depth = output->dims()[2]; + const int output_height = output->dims()[3]; + const int output_width = output->dims()[4]; + int input_feasize = input_depth * input_height * input_width; + int output_feasize = output_depth * output_height * output_width; + const T* input_data = input.data(); + const int* indices_data = indices.data(); + T* output_data = output->mutable_data(context.GetPlace()); + for (int b = 0; b < batch_size; ++b) { + for (int c = 0; c < output_channels; ++c) { + for (int i = 0; i < input_feasize; ++i) { + int index = indices_data[i]; + + PADDLE_ENFORCE_LT( + index, output_feasize, + platform::errors::InvalidArgument( + "index should less than output tensor depth * output tensor " + "height " + "* output tensor width. Expected %ld < %ld, but got " + "%ld >= %ld. Please check input value.", + index, output_feasize, index, output_feasize)); + output_data[index] = input_data[i]; + } + input_data += input_feasize; + indices_data += input_feasize; + output_data += output_feasize; + } + } + } +}; +template +class Unpool3dMaxGradFunctor { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& indices, + const framework::Tensor& output, + const framework::Tensor& output_grad, + framework::Tensor* input_grad) { + const int batch_size = input.dims()[0]; + const int input_depth = input.dims()[2]; + const int input_height = input.dims()[3]; + const int input_width = input.dims()[4]; + const int output_channels = output.dims()[1]; + const int output_depth = output.dims()[2]; + const int output_height = output.dims()[3]; + const int output_width = output.dims()[4]; + int input_feasize = input_depth * input_height * input_width; + int output_feasize = output_depth * output_height * output_width; + const int* indices_data = indices.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); + + for (int b = 0; b < batch_size; ++b) { + for (int c = 0; c < output_channels; ++c) { + for (int i = 0; i < input_feasize; ++i) { + int index = indices_data[i]; + PADDLE_ENFORCE_LT( + index, output_feasize, + platform::errors::InvalidArgument( + "index should less than output tensor depth * output tensor " + "height " + "* output tensor width. Expected %ld < %ld, but got " + "%ld >= %ld. Please check input value.", + index, output_feasize, index, output_feasize)); + input_grad_data[i] = output_grad_data[index]; + } + input_grad_data += input_feasize; + indices_data += input_feasize; + output_grad_data += output_feasize; + } + } + } +}; + template class Unpool2dMaxGradFunctor; template class Unpool2dMaxGradFunctor; template class Unpool2dMaxFunctor; template class Unpool2dMaxFunctor; +template class Unpool3dMaxGradFunctor; +template class Unpool3dMaxGradFunctor; +template class Unpool3dMaxFunctor; +template class Unpool3dMaxFunctor; } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/unpooling.cu b/paddle/fluid/operators/math/unpooling.cu index dbb3d64350cae..973865caba688 100644 --- a/paddle/fluid/operators/math/unpooling.cu +++ b/paddle/fluid/operators/math/unpooling.cu @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 paddlepaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -51,6 +51,45 @@ __global__ void KernelUnpool2dMaxGrad( /* * All tensors are in NCHW format. */ + +template +__global__ void KernelUnpool3dMax(const int nthreads, const T* input_data, + const int* indices_data, + const int input_depth, const int input_height, + const int input_width, const int channels, + T* output_data, const int output_depth, + const int output_height, + const int output_width) { + CUDA_KERNEL_LOOP(linearIndex, nthreads) { + int c = (linearIndex / input_depth / input_width / input_height) % channels; + int n = linearIndex / input_depth / input_width / input_height / channels; + output_data += + (n * channels + c) * output_depth * output_height * output_width; + int maxind = indices_data[linearIndex]; + output_data[maxind] = input_data[linearIndex]; + } +} + +template +__global__ void KernelUnpool3dMaxGrad( + const int nthreads, const T* input_data, const int* indices_data, + const int input_depth, const int input_height, const int input_width, + const int channels, const T* output_data, const T* output_grad, + const int output_depth, const int output_height, const int output_width, + T* input_grad) { + CUDA_KERNEL_LOOP(linearIndex, nthreads) { + int c = (linearIndex / input_depth / input_width / input_height) % channels; + int n = linearIndex / input_depth / input_width / input_height / channels; + output_grad += + (n * channels + c) * output_depth * output_height * output_width; + int maxind = indices_data[linearIndex]; + input_grad[linearIndex] = output_grad[maxind]; + } +} +/* + * All tensors are in NCDHW format. + */ + template class Unpool2dMaxFunctor { public: @@ -112,10 +151,82 @@ class Unpool2dMaxGradFunctor { output_width, input_grad_data); } }; + +template +class Unpool3dMaxFunctor { + public: + void operator()(const platform::CUDADeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& indices, framework::Tensor* output) { + const int batch_size = input.dims()[0]; + const int input_depth = input.dims()[2]; + const int input_height = input.dims()[3]; + const int input_width = input.dims()[4]; + const int output_channels = output->dims()[1]; + const int output_depth = output->dims()[2]; + const int output_height = output->dims()[3]; + const int output_width = output->dims()[4]; + const T* input_data = input.data(); + const int* indices_data = indices.data(); + T* output_data = output->mutable_data(context.GetPlace()); +#ifdef __HIPCC__ + int threads = 256; +#else + int threads = 1024; +#endif + int grid = (input.numel() + threads - 1) / threads; + KernelUnpool3dMax<<>>( + input.numel(), input_data, indices_data, input_depth, input_height, + input_width, output_channels, output_data, output_depth, output_height, + output_width); + } +}; +/* + * All tensors are in NCDHW format. + */ +template +class Unpool3dMaxGradFunctor { + public: + void operator()(const platform::CUDADeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& indices, + const framework::Tensor& output, + const framework::Tensor& output_grad, + framework::Tensor* input_grad) { + const int batch_size = input.dims()[0]; + const int input_depth = input.dims()[2]; + const int input_height = input.dims()[3]; + const int input_width = input.dims()[4]; + const int output_channels = output.dims()[1]; + const int output_depth = output.dims()[2]; + const int output_height = output.dims()[3]; + const int output_width = output.dims()[4]; + const T* input_data = input.data(); + const int* indices_data = indices.data(); + const T* output_data = output.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); +#ifdef __HIPCC__ + int threads = 256; +#else + int threads = 1024; +#endif + int grid = (input.numel() + threads - 1) / threads; + KernelUnpool3dMaxGrad<<>>( + input.numel(), input_data, indices_data, input_depth, input_height, + input_width, output_channels, output_data, output_grad_data, + output_depth, output_height, output_width, input_grad_data); + } +}; + template class Unpool2dMaxGradFunctor; template class Unpool2dMaxGradFunctor; template class Unpool2dMaxFunctor; template class Unpool2dMaxFunctor; +template class Unpool3dMaxGradFunctor; +template class Unpool3dMaxGradFunctor; +template class Unpool3dMaxFunctor; +template class Unpool3dMaxFunctor; } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/unpooling.h b/paddle/fluid/operators/math/unpooling.h index 74ca39d114e26..63bd8186adeb2 100644 --- a/paddle/fluid/operators/math/unpooling.h +++ b/paddle/fluid/operators/math/unpooling.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -33,6 +33,22 @@ class Unpool2dMaxGradFunctor { const framework::Tensor& output_grad, framework::Tensor* input_grad); }; + +template +class Unpool3dMaxFunctor { + public: + void operator()(const DeviceContext& context, const framework::Tensor& input, + const framework::Tensor& indices, framework::Tensor* output); +}; +template +class Unpool3dMaxGradFunctor { + public: + void operator()(const DeviceContext& context, const framework::Tensor& input, + const framework::Tensor& indices, + const framework::Tensor& output, + const framework::Tensor& output_grad, + framework::Tensor* input_grad); +}; } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/unpool_op.cc b/paddle/fluid/operators/unpool_op.cc index 108cd2722b5ed..8edfb4bc6c52f 100644 --- a/paddle/fluid/operators/unpool_op.cc +++ b/paddle/fluid/operators/unpool_op.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -76,6 +76,65 @@ Paper: http://www.matthewzeiler.com/wp-content/uploads/2017/07/iccv2011.pdf } }; +class Unpool3dOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput( + "X", + "(Tensor) The input tensor of unpool operator. " + "The format of input tensor is NCDHW. Where N is batch size, C is the " + "number of channels, D, H and W is the depth, height and width of " + "feature."); + AddInput( + "Indices", + "(Tensor) The input tensor of the indices given out by MaxPool3d. " + "The format of input tensor is NCDHW. Where N is batch size, C is the " + "number of channels, D, H and W is the depth, height and width of " + "feature."); + AddOutput("Out", + "(Tensor) The output tensor of unpool operator." + "The format of output tensor is also NCDHW." + "Where N is batch size, C is " + "the number of channels, D, H and W is the depth, height and " + "width of feature."); + AddAttr>( + "ksize", + "(vector), the unpooling window size(depth, height, width) " + "of unpooling operator."); + AddAttr>( + "strides", + "(vector, default:{1, 1, 1}), " + "strides (depth, height, width) of unpooling operator.") + .SetDefault({1, 1, 1}); + AddAttr>( + "paddings", + "(vector default:{0, 0,0}), " + "paddings (depth, height, width) of unpooling operator.") + .SetDefault({0, 0, 0}); + AddAttr( + "unpooling_type", + "(string), unpooling type, can be \"max\" for max-unpooling ") + .InEnum({"max"}); + AddAttr>("output_size", + "(vector, optional). The shape of output.") + .SetDefault({0, 0, 0}); + AddAttr( + "data_format", + "(string, default NCDHW)" + "Defaults to \"NCDHW\". Specify the data format of the output data, ") + .SetDefault("NCDHW"); + AddComment(R"DOC( +Input shape is: $(N, C_{in}, D_{in}, H_{in}, W_{in})$, Output shape is: +$(N, C_{out}, D_{out}, H_{out}, W_{out})$, where +$$ +D_{out} = (D_{in}-1) * strides[0] - 2 * paddings[0] + ksize[0] \\ +H_{out} = (H_{in}-1) * strides[1] - 2 * paddings[1] + ksize[1] \\ +W_{out} = (W_{in}-1) * strides[2] - 2 * paddings[2] + ksize[2] +$$ +)DOC"); + } +}; + int UnpoolOutputSize(int input_size, int ksize, int padding, int stride) { int output_size = (input_size - 1) * stride - 2 * padding + ksize; return output_size; @@ -130,6 +189,55 @@ class UnpoolOp : public framework::OperatorWithKernel { } }; +class Unpool3dOp : public framework::OperatorWithKernel { + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); + } + + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Unpool3d"); + OP_INOUT_CHECK(ctx->HasInput("Indices"), "Input", "Indices", "Unpool3d"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Unpool3d"); + auto in_x_dims = ctx->GetInputDim("X"); + auto in_y_dims = ctx->GetInputDim("Indices"); + std::string unpooling_type = + ctx->Attrs().Get("unpooling_type"); + std::vector ksize = ctx->Attrs().Get>("ksize"); + std::vector strides = ctx->Attrs().Get>("strides"); + std::vector paddings = ctx->Attrs().Get>("paddings"); + std::vector output_size = + ctx->Attrs().Get>("output_size"); + PADDLE_ENFORCE_EQ(in_x_dims.size() == 5, true, + platform::errors::InvalidArgument( + "Unpool Intput(X) must be of 5-dimensional, but " + "received Input(X)'s dimensions is %d.", + in_x_dims.size())); + PADDLE_ENFORCE_EQ(in_x_dims, in_y_dims, + platform::errors::InvalidArgument( + "The dimensions of Input(X) must equal to be" + "the dimensions of Input(Indices), but received" + "dimensions of Input(X) is [%d], received dimensions" + "of Input(Indices) is [%d]", + in_x_dims, in_y_dims)); + + std::vector output_shape({in_x_dims[0], in_x_dims[1]}); + for (size_t i = 0; i < ksize.size(); ++i) { + if (!ctx->IsRuntime() && in_x_dims[i + 2] <= 0) { + output_shape.push_back(-1); + } else { + output_shape.push_back(output_size[i]); + } + } + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + } +}; + template class UnpoolOpGradMaker : public framework::SingleGradOpMaker { public: @@ -145,6 +253,21 @@ class UnpoolOpGradMaker : public framework::SingleGradOpMaker { } }; +template +class Unpool3dOpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + void Apply(GradOpPtr op) const override { + op->SetType(this->ForwardOpType() + "_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput("Indices", this->Input("Indices")); + op->SetInput("Out", this->Output("Out")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetAttrMap(this->Attrs()); + } +}; + class UnpoolOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( @@ -163,6 +286,26 @@ class UnpoolOpGrad : public framework::OperatorWithKernel { ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } }; + +class Unpool3dOpGrad : public framework::OperatorWithKernel { + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); + } + + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Unpool3dGrad"); + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", + framework::GradVarName("X"), "Unpool3dGrad"); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } +}; + } // namespace operators } // namespace paddle @@ -179,3 +322,16 @@ REGISTER_OP_CPU_KERNEL( unpool_grad, ops::UnpoolGradKernel, ops::UnpoolGradKernel); + +REGISTER_OPERATOR(unpool3d, ops::Unpool3dOp, ops::Unpool3dOpMaker, + ops::Unpool3dOpGradMaker, + ops::Unpool3dOpGradMaker); + +REGISTER_OPERATOR(unpool3d_grad, ops::Unpool3dOpGrad); +REGISTER_OP_CPU_KERNEL( + unpool3d, ops::Unpool3dKernel, + ops::Unpool3dKernel); +REGISTER_OP_CPU_KERNEL( + unpool3d_grad, + ops::Unpool3dGradKernel, + ops::Unpool3dGradKernel); diff --git a/paddle/fluid/operators/unpool_op.cu.cc b/paddle/fluid/operators/unpool_op.cu.cc index 7c59a0feaa472..e3cab4426b4d8 100644 --- a/paddle/fluid/operators/unpool_op.cu.cc +++ b/paddle/fluid/operators/unpool_op.cu.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,3 +22,10 @@ REGISTER_OP_CUDA_KERNEL( unpool_grad, ops::UnpoolGradKernel, ops::UnpoolGradKernel); +REGISTER_OP_CUDA_KERNEL( + unpool3d, ops::Unpool3dKernel, + ops::Unpool3dKernel); +REGISTER_OP_CUDA_KERNEL( + unpool3d_grad, + ops::Unpool3dGradKernel, + ops::Unpool3dGradKernel); diff --git a/paddle/fluid/operators/unpool_op.h b/paddle/fluid/operators/unpool_op.h index e388ec5ae3937..52849cb3e0f8e 100644 --- a/paddle/fluid/operators/unpool_op.h +++ b/paddle/fluid/operators/unpool_op.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -69,5 +69,54 @@ class UnpoolGradKernel : public framework::OpKernel { unpool2d_max_backward(device_ctx, *in_x, *in_y, *out, *out_grad, in_x_grad); } }; + +template +class Unpool3dKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const framework::Tensor* in_x = context.Input("X"); + const framework::Tensor* in_y = context.Input("Indices"); + auto* out = context.Output("Out"); + std::string unpooling_type = context.Attr("unpooling_type"); + std::vector ksize = context.Attr>("ksize"); + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + T* output_data = out->mutable_data(context.GetPlace()); + auto& dev_ctx = context.template device_context(); + if (output_data) { + math::SetConstant set_zero; + set_zero(dev_ctx, out, static_cast(0)); + } + math::Unpool3dMaxFunctor unpool3d_max_forward; + unpool3d_max_forward(dev_ctx, *in_x, *in_y, out); + } +}; + +template +class Unpool3dGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const framework::Tensor* in_x = context.Input("X"); + const framework::Tensor* in_y = context.Input("Indices"); + const framework::Tensor* out = context.Input("Out"); + const framework::Tensor* out_grad = + context.Input(framework::GradVarName("Out")); + framework::Tensor* in_x_grad = + context.Output(framework::GradVarName("X")); + std::string unpooling_type = context.Attr("unpooling_type"); + std::vector ksize = context.Attr>("ksize"); + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + + auto& device_ctx = context.template device_context(); + math::SetConstant zero; + + in_x_grad->mutable_data(context.GetPlace()); + zero(device_ctx, in_x_grad, static_cast(0)); + + math::Unpool3dMaxGradFunctor unpool3d_max_backward; + unpool3d_max_backward(device_ctx, *in_x, *in_y, *out, *out_grad, in_x_grad); + } +}; } // namespace operators } // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_unpool3d_op.py b/python/paddle/fluid/tests/unittests/test_unpool3d_op.py new file mode 100644 index 0000000000000..e6031d9cee8b1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_unpool3d_op.py @@ -0,0 +1,293 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.nn.functional as F + +paddle.enable_static() +paddle.seed(2022) + + +def _unpool_output_size(x, kernel_size, stride, padding, output_size): + input_size = x.shape + default_size = [] + for d in range(len(kernel_size)): + default_size.append((input_size[-len(kernel_size) + d] - 1) * stride[d] + + kernel_size[d] - 2 * padding[d]) + if output_size is None: + ret = default_size + else: + ret = output_size + return ret + + +def unpool3dmax_forward_naive(input, indices, ksize, strides, paddings, + output_size): + s0, s1, s2, s3, s4 = input.shape + output_size = _unpool_output_size(input, ksize, strides, paddings, + output_size) + out_dsize = output_size[0] + out_hsize = output_size[1] + out_wsize = output_size[2] + out = np.zeros((s0, s1, out_dsize, out_hsize, out_wsize)) + for nidx in range(s0): + for cidx in range(s1): + for d in range(s2): + for h in range(s3): + for w in range(s4): + index = indices[nidx, cidx, d, h, w] + didx = index // (out_wsize * out_hsize) + hidx = ( + index - didx * out_hsize * out_wsize) // out_wsize + widx = ( + index - didx * out_hsize * out_wsize) % out_wsize + out[nidx, cidx, didx, hidx, widx] = \ + input[nidx, cidx, d, h, w] + + return out + + +class TestUnpool3DOp(OpTest): + def setUp(self): + self.op_type = "unpool3d" + self.init_test_case() + inputs = np.random.randint(0, 100, self.shape) + nsize, csize, dsize, hsize, wsize = inputs.shape + self.output_size = _unpool_output_size(inputs, self.ksize, self.strides, + self.paddings, self.output_size) + indices = np.random.permutation( + np.arange(0, self.output_size[0] * self.output_size[1] * + self.output_size[2]))[:dsize * hsize * wsize] + indices = np.reshape(indices, [dsize, hsize, wsize]) + idx_list = [] + for n in range(nsize): + c_list = [] + for c in range(csize): + c_list.append(indices.tolist()) + idx_list.append(c_list) + indices = np.array(idx_list) + + output = self.unpool3d_forward_naive(inputs, indices, self.ksize, \ + self.strides, self.paddings, self.output_size).astype("float64") + + self.inputs = { + 'X': inputs.astype('float64'), + 'Indices': indices.astype('int32') + } + self.attrs = { + 'strides': self.strides, + 'paddings': self.paddings, + 'ksize': self.ksize, + 'unpooling_type': self.unpooling_type, + 'output_size': self.output_size, + } + self.outputs = {'Out': output.astype('float64')} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + def init_test_case(self): + self.unpool3d_forward_naive = unpool3dmax_forward_naive + self.unpooling_type = "max" + self.shape = [1, 1, 4, 5, 6] + self.ksize = [2, 2, 2] + self.strides = [2, 2, 2] + self.paddings = [0, 0, 0] + self.output_size = None + + +class TestUnpool3DOpcase1(TestUnpool3DOp): + def init_test_case(self): + self.unpool3d_forward_naive = unpool3dmax_forward_naive + self.unpooling_type = "max" + self.shape = [1, 3, 4, 5, 6] + self.ksize = [2, 2, 2] + self.strides = [2, 2, 2] + self.paddings = [0, 0, 0] + self.output_size = None + + +class TestUnpool3DOpOutput(TestUnpool3DOp): + def init_test_case(self): + self.unpool3d_forward_naive = unpool3dmax_forward_naive + self.unpooling_type = "max" + self.shape = [1, 3, 4, 5, 6] + self.ksize = [2, 2, 2] + self.strides = [2, 2, 2] + self.paddings = [0, 0, 0] + self.output_size = [7, 9, 11] + + +class TestUnpool3DOpException(unittest.TestCase): + def test_exception(self): + def indices_size_error(): + data = paddle.randint(shape=[1, 1, 3, 3, 3]) + indices = paddle.reshape( + paddle.arange(0, 36), shape=[1, 1, 3, 3, 4]) + MaxUnPool3D = F.maxunpool3d(data, indices, kernel_size=2, stride=2) + + def indices_value_error(): + data = paddle.randint(shape=[1, 1, 3, 3, 3]) + indices = paddle.reshape( + paddle.arange(4, 40), shape=[1, 1, 3, 3, 3]) + MaxUnPool3D = F.maxunpool3d(data, indices, kernel_size=2, stride=2) + + def data_format_error(): + data = paddle.randint(shape=[1, 1, 3, 3, 3]) + indices = paddle.reshape( + paddle.arange(0, 27), shape=[1, 1, 3, 3, 3]) + MaxUnPool3D = F.maxunpool3d( + data, indices, kernel_size=2, stride=2, data_format="NDHWC") + + def data_outputsize_error(): + data = paddle.randint(shape=[1, 1, 3, 3, 3]) + indices = paddle.reshape( + paddle.arange(0, 27), shape=[1, 1, 3, 3, 3]) + MaxUnPool3D = F.maxunpool3d( + data, + indices, + kernel_size=2, + stride=2, + output_size=[2, 2, 3, 4, 5]) + + def data_outputsize_error2(): + data = paddle.randint(shape=[1, 1, 3, 3, 3]) + indices = paddle.reshape( + paddle.arange(0, 27), shape=[1, 1, 3, 3, 3]) + MaxUnPool3D = F.maxunpool3d( + data, + indices, + kernel_size=2, + stride=2, + output_size=[10, 10, 10]) + + self.assertRaises(ValueError, indices_size_error) + self.assertRaises(ValueError, indices_value_error) + self.assertRaises(ValueError, data_format_error) + self.assertRaises(ValueError, data_outputsize_error) + self.assertRaises(ValueError, data_outputsize_error2) + + +class TestUnpool3DOpAPI_dygraph(unittest.TestCase): + def test_case(self): + places = [paddle.CPUPlace()] + if paddle.fluid.core.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + for place in places: + paddle.disable_static() + input_data = np.random.rand(1, 3, 4, 4, 6) + input_x = paddle.to_tensor(input_data) + output, indices = F.max_pool3d( + input_x, kernel_size=2, stride=2, return_mask=True) + output_unpool = F.max_unpool3d( + output, indices, kernel_size=2, stride=2) + expected_output_unpool = unpool3dmax_forward_naive( + output.numpy(), + indices.numpy(), [2, 2, 2], [2, 2, 2], [0, 0, 0], [4, 4, 6]) + self.assertTrue( + np.allclose(output_unpool.numpy(), expected_output_unpool)) + + paddle.enable_static() + + +class TestUnpool3DOpAPI_dygraph2(unittest.TestCase): + def test_case(self): + places = [paddle.CPUPlace()] + if paddle.fluid.core.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + for place in places: + paddle.disable_static() + input_data = np.random.rand(1, 3, 4, 4, 6) + input_x = paddle.to_tensor(input_data) + output, indices = F.max_pool3d( + input_x, kernel_size=2, stride=2, return_mask=True) + output_unpool = F.max_unpool3d( + output, indices, kernel_size=2, stride=None) + expected_output_unpool = unpool3dmax_forward_naive( + output.numpy(), + indices.numpy(), [2, 2, 2], [2, 2, 2], [0, 0, 0], [4, 4, 6]) + self.assertTrue( + np.allclose(output_unpool.numpy(), expected_output_unpool)) + + paddle.enable_static() + + +class TestUnpool3DOpAPI_dygraph3(unittest.TestCase): + def test_case(self): + places = [paddle.CPUPlace()] + if paddle.fluid.core.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + for place in places: + paddle.disable_static() + input_data = np.random.rand(1, 3, 4, 4, 6) + input_x = paddle.to_tensor(input_data) + Pool3d = paddle.nn.MaxPool3D( + kernel_size=2, stride=2, return_mask=True) + UnPool3d = paddle.nn.MaxUnPool3D(kernel_size=2, stride=2) + + output, indices = Pool3d(input_x) + output_unpool = UnPool3d(output, indices) + expected_output_unpool = unpool3dmax_forward_naive( + output.numpy(), + indices.numpy(), [2, 2, 2], [2, 2, 2], [0, 0, 0], [4, 4, 6]) + self.assertTrue( + np.allclose(output_unpool.numpy(), expected_output_unpool)) + + paddle.enable_static() + + +class TestUnpool3DOpAPI_static(unittest.TestCase): + def test_case(self): + paddle.enable_static() + places = [paddle.CPUPlace()] + if paddle.fluid.core.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + for place in places: + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + + input_data = np.array([[[[[1, 2, 3, 4], [5, 6, 7, 8], \ + [9, 10, 11, 12], [13, 14, 15, 16]], [[1, 2, 3, 4], [5, 6, 7, 8], \ + [9, 10, 11, 12], [13, 14, 15, 16]]]]]).astype("float32") + x = paddle.fluid.data( + name='x', shape=[1, 1, 2, 4, 4], dtype='float32') + output, indices = F.max_pool3d( + x, kernel_size=2, stride=2, return_mask=True) + output_unpool = F.max_unpool3d( + output, indices, kernel_size=2, stride=None) + + exe = paddle.fluid.Executor(place) + fetches = exe.run(paddle.fluid.default_main_program(), + feed={"x": input_data}, + fetch_list=[output_unpool], + return_numpy=True) + pool3d_out_np = np.array( + [[[[[6., 8.], [14., 16.]]]]]).astype("float32") + indices_np = np.array([[[[[5, 7], [13, 15]]]]]).astype("int32") + expected_output_unpool = unpool3dmax_forward_naive( + pool3d_out_np, indices_np, [2, 2, 2], [2, 2, 2], [0, 0, 0], + [2, 4, 4]) + self.assertTrue(np.allclose(fetches[0], expected_output_unpool)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 37df0d4446767..f817ebaeeffca 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -77,6 +77,7 @@ from .layer.pooling import MaxPool2D # noqa: F401 from .layer.pooling import MaxPool3D # noqa: F401 from .layer.pooling import MaxUnPool2D # noqa: F401 +from .layer.pooling import MaxUnPool3D # noqa: F401 from .layer.pooling import AdaptiveAvgPool1D # noqa: F401 from .layer.pooling import AdaptiveAvgPool2D # noqa: F401 from .layer.pooling import AdaptiveAvgPool3D # noqa: F401 diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 676d7259f2843..7611d06a8f957 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -107,6 +107,7 @@ from .pooling import adaptive_avg_pool2d # noqa: F401 from .pooling import adaptive_avg_pool3d # noqa: F401 from .pooling import max_unpool2d # noqa: F401 +from .pooling import max_unpool3d # noqa: F401 from .vision import affine_grid # noqa: F401 from .vision import grid_sample # noqa: F401 @@ -178,6 +179,7 @@ 'max_pool2d', 'max_pool3d', 'max_unpool2d', + 'max_unpool3d', 'adaptive_avg_pool1d', 'adaptive_avg_pool2d', 'adaptive_avg_pool3d', diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index 27f4d4a7db345..52ced060775d8 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -779,6 +779,116 @@ def max_unpool2d(x, return unpool_out +def max_unpool3d(x, + indices, + kernel_size, + stride=None, + padding=0, + data_format="NCDHW", + output_size=None, + name=None): + """ + This API implements max unpooling 3d opereation. + See more details in :ref:`api_nn_pooling_MaxUnPool3D` . + + + Args: + x (Tensor): The input tensor of unpooling operator which is a 5-D tensor with + shape [N, C, D, H, W]. The format of input tensor is `"NCDHW"`, + where `N` is batch size, `C` is the number of channels, `D` is + the depth of the feature, `H` is the height of the feature, + and `W` is the width of the feature. The data type is float32 or float64. + indices (Tensor): The indices given out by maxpooling3d which is a 5-D tensor with + shape [N, C, D, H, W]. The format of input tensor is `"NCDHW"` , + where `N` is batch size, `C` is the number of channels, `D` is + the depth of the feature, `H` is the height of the feature, + and `W` is the width of the feature. The data type is float32 or float64. + kernel_size (int|list|tuple): The unpool kernel size. If unpool kernel size is a tuple or list, + it must contain an integer. + stride (int|list|tuple): The unpool stride size. If unpool stride size is a tuple or list, + it must contain an integer. + padding (int | tuple): Padding that was added to the input. + output_size(list|tuple, optional): The target output size. If output_size is not specified, + the actual output shape will be automatically calculated by (input_shape, + kernel_size, stride, padding). + data_format (string): The data format of the input and output data. + The default is `"NCDHW"`. When it is `"NCDHW"`, the data is stored in the order of: + `[batch_size, input_channels, input_depth, input_height, input_width]`. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where + .. math:: + D_{out} = (D_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]} + + .. math:: + H_{out} = (H_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]} + + .. math:: + W_{out} = (W_{in} - 1) \times \text{stride[2]} - 2 \times \text{padding[2]} + \text{kernel\_size[2]} + + or as given by :attr:`output_size` in the call operator + + Returns: + Tensor: The output tensor of unpooling result. + + Examples: + .. code-block:: python + + import paddle + import paddle.nn.functional as F + + data = paddle.rand(shape=[1, 1, 6, 6, 6]) + pool_out, indices = F.max_pool3d(data, kernel_size=2, stride=2, padding=0, return_mask=True) + # pool_out shape: [1, 1, 3, 3, 3], indices shape: [1, 1, 3, 3, 3] + unpool_out = F.max_unpool3d(pool_out, indices, kernel_size=2, padding=0) + # unpool_out shape: [1, 1, 6, 6, 6] + + """ + kernel_size = utils.convert_to_list(kernel_size, 3, 'pool_size') + if stride is None: + stride = kernel_size + else: + stride = utils.convert_to_list(stride, 3, 'pool_stride') + padding = utils.convert_to_list(padding, 3, 'padding') + + if data_format not in ["NCDHW"]: + raise ValueError("Attr(data_format) should be 'NCDHW'. Received " + "Attr(data_format): %s." % str(data_format)) + + output_size = _unpool_output_size(x, kernel_size, stride, padding, + output_size) + + if in_dygraph_mode(): + output = _C_ops.unpool3d(x, indices, 'unpooling_type', 'max', 'ksize', + kernel_size, 'strides', stride, 'paddings', + padding, "output_size", output_size, + "data_format", data_format) + return output + + op_type = "unpool3d" + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype(input_param_name="x") + unpool_out = helper.create_variable_for_type_inference(dtype) + + helper.append_op( + type=op_type, + inputs={"X": x, + "Indices": indices}, + outputs={"Out": unpool_out}, + attrs={ + "unpooling_type": "max", + "ksize": kernel_size, + "strides": stride, + "paddings": padding, + "output_size": output_size + }) + return unpool_out + + def max_pool2d(x, kernel_size, stride=None, diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index f536c3d5ff379..772d6d390bf44 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -58,6 +58,7 @@ from .pooling import AdaptiveMaxPool2D # noqa: F401 from .pooling import AdaptiveMaxPool3D # noqa: F401 from .pooling import MaxUnPool2D # noqa: F401 +from .pooling import MaxUnPool3D # noqa: F401 from .conv import Conv1D # noqa: F401 from .conv import Conv2D # noqa: F401 from .conv import Conv3D # noqa: F401 diff --git a/python/paddle/nn/layer/pooling.py b/python/paddle/nn/layer/pooling.py index cc49db9b2056f..0fe310aaef7d5 100755 --- a/python/paddle/nn/layer/pooling.py +++ b/python/paddle/nn/layer/pooling.py @@ -1214,3 +1214,93 @@ def forward(self, x, indices): def extra_repr(self): return 'output_size={}'.format(self.output_size) + + +class MaxUnPool3D(Layer): + """ + This API implements max unpooling 3d opereation. + + 'max_unpool3d' accepts the output of 'max_pool3d' as input + Including the indices of the maximum value and calculating the partial inverse + All non-maximum values ​​are set to zero. + + + Parameters: + kernel_size (int|list|tuple): The unpool kernel size. If unpool kernel size is a tuple or list, + it must contain an integer. + stride (int|list|tuple): The unpool stride size. If unpool stride size is a tuple or list, + it must contain an integer. + padding (int | tuple): Padding that was added to the input. + output_size(list|tuple, optional): The target output size. If output_size is not specified, + the actual output shape will be automatically calculated by (input_shape, + kernel_size, stride, padding). + data_format (string): The data format of the input and output data. + The default is `"NCDHW"`. When it is `"NCDHW"`, the data is stored in the order of: + `[batch_size, input_channels, input_depth, input_height, input_width]`. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where + .. math:: + D_{out} = (D_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]} + + .. math:: + H_{out} = (H_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]} + + .. math:: + W_{out} = (W_{in} - 1) \times \text{stride[2]} - 2 \times \text{padding[2]} + \text{kernel\_size[2]} + + or as given by :attr:`output_size` in the call operator + + Returns: + A callable object of MaxUnPool3D. + + + + Examples: + .. code-block:: python + + import paddle + import paddle.nn.functional as F + import numpy as np + + data = paddle.rand(shape=[1, 1, 7, 7, 7]) + pool_out, indices = F.max_pool3d(data, kernel_size=2, stride=2, padding=0, return_mask=True) + # pool_out shape: [1, 1, 3, 3, 3], indices shape: [1, 1, 3, 3, 3] + Unpool3D = paddle.nn.MaxUnPool3D(kernel_size=2, padding=0) + unpool_out = Unpool3D(pool_out, indices) + # unpool_out shape: [1, 1, 6, 6, 6] + + """ + + def __init__(self, + kernel_size, + stride=None, + padding=0, + data_format="NCDHW", + output_size=None, + name=None): + super(MaxUnPool3D, self).__init__() + self.ksize = kernel_size + self.stride = stride + self.padding = padding + self.data_format = data_format + self.output_size = output_size + self.name = name + + def forward(self, x, indices): + return F.max_unpool3d( + x, + indices, + kernel_size=self.ksize, + stride=self.stride, + padding=self.padding, + data_format=self.data_format, + output_size=self.output_size, + name=self.name) + + def extra_repr(self): + return 'output_size={}'.format(self.output_size) From 6e99901b7a41c048d7a6bcbfd160ab7a9b491fde Mon Sep 17 00:00:00 2001 From: andyjpaddle Date: Wed, 5 Jan 2022 11:21:50 +0000 Subject: [PATCH 2/7] update doc for maxunpool3d op --- python/paddle/nn/__init__.py | 1 + python/paddle/nn/functional/pooling.py | 42 ++++++++++++----------- python/paddle/nn/layer/pooling.py | 47 +++++++++++++------------- 3 files changed, 46 insertions(+), 44 deletions(-) diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index f817ebaeeffca..60bc5f3111d8f 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -301,5 +301,6 @@ def weight_norm(*args): 'LayerDict', 'ZeroPad2D', 'MaxUnPool2D', + 'MaxUnPool3D', 'HingeEmbeddingLoss', ] diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index 52ced060775d8..8a360b7270fb5 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -789,7 +789,23 @@ def max_unpool3d(x, name=None): """ This API implements max unpooling 3d opereation. - See more details in :ref:`api_nn_pooling_MaxUnPool3D` . + 'max_unpool3d' accepts the output of 'max_pool3d' as input + Including the indices of the maximum value and calculating the partial inverse + All non-maximum values ​​are set to zero. + + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where + + .. math:: + D_{out} = (D_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]} + + .. math:: + H_{out} = (H_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]} + + .. math:: + W_{out} = (W_{in} - 1) \times \text{stride[2]} - 2 \times \text{padding[2]} + \text{kernel\_size[2]} + + or as given by :attr:`output_size` in the call operator Args: @@ -818,26 +834,12 @@ def max_unpool3d(x, to :ref:`api_guide_Name`. Usually name is no need to set and None by default. + Returns: + Tensor: The output tensor of unpooling result. - - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` - - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where - .. math:: - D_{out} = (D_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]} - - .. math:: - H_{out} = (H_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]} - - .. math:: - W_{out} = (W_{in} - 1) \times \text{stride[2]} - 2 \times \text{padding[2]} + \text{kernel\_size[2]} - - or as given by :attr:`output_size` in the call operator - - Returns: - Tensor: The output tensor of unpooling result. - - Examples: - .. code-block:: python - + Examples: + .. code-block:: python + import paddle import paddle.nn.functional as F diff --git a/python/paddle/nn/layer/pooling.py b/python/paddle/nn/layer/pooling.py index 0fe310aaef7d5..a2deddea77222 100755 --- a/python/paddle/nn/layer/pooling.py +++ b/python/paddle/nn/layer/pooling.py @@ -1223,8 +1223,22 @@ class MaxUnPool3D(Layer): 'max_unpool3d' accepts the output of 'max_pool3d' as input Including the indices of the maximum value and calculating the partial inverse All non-maximum values ​​are set to zero. + + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where + .. math:: + D_{out} = (D_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]} + + .. math:: + H_{out} = (H_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]} + + .. math:: + W_{out} = (W_{in} - 1) \times \text{stride[2]} - 2 \times \text{padding[2]} + \text{kernel\_size[2]} + or as given by :attr:`output_size` in the call operator + + Parameters: kernel_size (int|list|tuple): The unpool kernel size. If unpool kernel size is a tuple or list, it must contain an integer. @@ -1242,37 +1256,22 @@ class MaxUnPool3D(Layer): None by default. - - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` - - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where - .. math:: - D_{out} = (D_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]} - - .. math:: - H_{out} = (H_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]} - - .. math:: - W_{out} = (W_{in} - 1) \times \text{stride[2]} - 2 \times \text{padding[2]} + \text{kernel\_size[2]} - - or as given by :attr:`output_size` in the call operator - Returns: A callable object of MaxUnPool3D. - - Examples: .. code-block:: python - import paddle - import paddle.nn.functional as F - import numpy as np + import paddle + import paddle.nn.functional as F + import numpy as np - data = paddle.rand(shape=[1, 1, 7, 7, 7]) - pool_out, indices = F.max_pool3d(data, kernel_size=2, stride=2, padding=0, return_mask=True) - # pool_out shape: [1, 1, 3, 3, 3], indices shape: [1, 1, 3, 3, 3] - Unpool3D = paddle.nn.MaxUnPool3D(kernel_size=2, padding=0) - unpool_out = Unpool3D(pool_out, indices) - # unpool_out shape: [1, 1, 6, 6, 6] + data = paddle.rand(shape=[1, 1, 7, 7, 7]) + pool_out, indices = F.max_pool3d(data, kernel_size=2, stride=2, padding=0, return_mask=True) + # pool_out shape: [1, 1, 3, 3, 3], indices shape: [1, 1, 3, 3, 3] + Unpool3D = paddle.nn.MaxUnPool3D(kernel_size=2, padding=0) + unpool_out = Unpool3D(pool_out, indices) + # unpool_out shape: [1, 1, 6, 6, 6] """ From 499418c6c42e4d95c63a0dd648c94b259a58cd79 Mon Sep 17 00:00:00 2001 From: andyjpaddle Date: Wed, 5 Jan 2022 11:57:24 +0000 Subject: [PATCH 3/7] update doc for maxunpool3d op --- python/paddle/nn/functional/pooling.py | 12 ++++++------ python/paddle/nn/layer/pooling.py | 10 +++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index 8a360b7270fb5..87bb4e4f40192 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -789,21 +789,21 @@ def max_unpool3d(x, name=None): """ This API implements max unpooling 3d opereation. - 'max_unpool3d' accepts the output of 'max_pool3d' as input - Including the indices of the maximum value and calculating the partial inverse + `max_unpool3d` accepts the output of `max_pool3d` as input, + including the indices of the maximum value and calculate the partial inverse. All non-maximum values ​​are set to zero. - + - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where .. math:: - D_{out} = (D_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]} + D_{out} = (D_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]} .. math:: - H_{out} = (H_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]} + H_{out} = (H_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]} .. math:: - W_{out} = (W_{in} - 1) \times \text{stride[2]} - 2 \times \text{padding[2]} + \text{kernel\_size[2]} + W_{out} = (W_{in} - 1) \times \text{stride[2]} - 2 \times \text{padding[2]} + \text{kernel\_size[2]} or as given by :attr:`output_size` in the call operator diff --git a/python/paddle/nn/layer/pooling.py b/python/paddle/nn/layer/pooling.py index a2deddea77222..2578ad9226d53 100755 --- a/python/paddle/nn/layer/pooling.py +++ b/python/paddle/nn/layer/pooling.py @@ -1220,21 +1220,21 @@ class MaxUnPool3D(Layer): """ This API implements max unpooling 3d opereation. - 'max_unpool3d' accepts the output of 'max_pool3d' as input - Including the indices of the maximum value and calculating the partial inverse + `max_unpool3d` accepts the output of `max_pool3d` as input, + including the indices of the maximum value and calculate the partial inverse. All non-maximum values ​​are set to zero. - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where .. math:: - D_{out} = (D_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]} + D_{out} = (D_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]} .. math:: - H_{out} = (H_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]} + H_{out} = (H_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]} .. math:: - W_{out} = (W_{in} - 1) \times \text{stride[2]} - 2 \times \text{padding[2]} + \text{kernel\_size[2]} + W_{out} = (W_{in} - 1) \times \text{stride[2]} - 2 \times \text{padding[2]} + \text{kernel\_size[2]} or as given by :attr:`output_size` in the call operator From 9b2599453cd6dbb1db724f08658e54e17ff02bf6 Mon Sep 17 00:00:00 2001 From: andyjpaddle Date: Wed, 5 Jan 2022 12:29:40 +0000 Subject: [PATCH 4/7] update doc for maxunpool3d op --- python/paddle/nn/functional/pooling.py | 6 +++--- python/paddle/nn/layer/pooling.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index 87bb4e4f40192..cc62875ffc9aa 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -797,13 +797,13 @@ def max_unpool3d(x, - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where .. math:: - D_{out} = (D_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]} + D_{out} = (D_{in} - 1) * stride[0] - 2 * padding[0] + kernel\_size[0] .. math:: - H_{out} = (H_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]} + H_{out} = (H_{in} - 1) * stride[1] - 2 * padding[1] + kernel\_size[1] .. math:: - W_{out} = (W_{in} - 1) \times \text{stride[2]} - 2 \times \text{padding[2]} + \text{kernel\_size[2]} + W_{out} = (W_{in} - 1) * stride[2] - 2 * padding[2] + kernel\_size[2] or as given by :attr:`output_size` in the call operator diff --git a/python/paddle/nn/layer/pooling.py b/python/paddle/nn/layer/pooling.py index 2578ad9226d53..fd31bba922e9e 100755 --- a/python/paddle/nn/layer/pooling.py +++ b/python/paddle/nn/layer/pooling.py @@ -1228,13 +1228,13 @@ class MaxUnPool3D(Layer): - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where .. math:: - D_{out} = (D_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]} + D_{out} = (D_{in} - 1) * stride[0] - 2 * padding[0] + kernel\_size[0] .. math:: - H_{out} = (H_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]} + H_{out} = (H_{in} - 1) * stride[1] - 2 * padding[1] + kernel\_size[1] .. math:: - W_{out} = (W_{in} - 1) \times \text{stride[2]} - 2 \times \text{padding[2]} + \text{kernel\_size[2]} + W_{out} = (W_{in} - 1) * stride[2] - 2 * padding[2] + kernel\_size[2] or as given by :attr:`output_size` in the call operator From a2280c5d6b00575b249d45cde2e9f2496db98571 Mon Sep 17 00:00:00 2001 From: andyjpaddle Date: Thu, 6 Jan 2022 02:15:44 +0000 Subject: [PATCH 5/7] update sample code for maxunpool3d --- python/paddle/nn/functional/pooling.py | 6 +++--- python/paddle/nn/layer/pooling.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index cc62875ffc9aa..11b25265d7057 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -843,11 +843,11 @@ def max_unpool3d(x, import paddle import paddle.nn.functional as F - data = paddle.rand(shape=[1, 1, 6, 6, 6]) + data = paddle.rand(shape=[1, 1, 4, 4, 6]) pool_out, indices = F.max_pool3d(data, kernel_size=2, stride=2, padding=0, return_mask=True) - # pool_out shape: [1, 1, 3, 3, 3], indices shape: [1, 1, 3, 3, 3] + # pool_out shape: [1, 1, 2, 2, 3], indices shape: [1, 1, 2, 2, 3] unpool_out = F.max_unpool3d(pool_out, indices, kernel_size=2, padding=0) - # unpool_out shape: [1, 1, 6, 6, 6] + # unpool_out shape: [1, 1, 4, 4, 6] """ kernel_size = utils.convert_to_list(kernel_size, 3, 'pool_size') diff --git a/python/paddle/nn/layer/pooling.py b/python/paddle/nn/layer/pooling.py index fd31bba922e9e..dd3feee535657 100755 --- a/python/paddle/nn/layer/pooling.py +++ b/python/paddle/nn/layer/pooling.py @@ -1266,12 +1266,12 @@ class MaxUnPool3D(Layer): import paddle.nn.functional as F import numpy as np - data = paddle.rand(shape=[1, 1, 7, 7, 7]) + data = paddle.rand(shape=[1, 1, 4, 4, 6]) pool_out, indices = F.max_pool3d(data, kernel_size=2, stride=2, padding=0, return_mask=True) - # pool_out shape: [1, 1, 3, 3, 3], indices shape: [1, 1, 3, 3, 3] + # pool_out shape: [1, 1, 2, 2, 3], indices shape: [1, 1, 2, 2, 3] Unpool3D = paddle.nn.MaxUnPool3D(kernel_size=2, padding=0) unpool_out = Unpool3D(pool_out, indices) - # unpool_out shape: [1, 1, 6, 6, 6] + # unpool_out shape: [1, 1, 4, 4, 6] """ From 1bf3f97bd622cbc550b0416672f5ef505531351b Mon Sep 17 00:00:00 2001 From: andyjpaddle Date: Thu, 6 Jan 2022 09:04:52 +0000 Subject: [PATCH 6/7] add maxunpool1d op --- .../fluid/tests/unittests/test_unpool1d_op.py | 156 ++++++++++++++++++ python/paddle/nn/__init__.py | 2 + python/paddle/nn/functional/__init__.py | 2 + python/paddle/nn/functional/pooling.py | 110 ++++++++++++ python/paddle/nn/layer/__init__.py | 1 + python/paddle/nn/layer/pooling.py | 82 +++++++++ 6 files changed, 353 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_unpool1d_op.py diff --git a/python/paddle/fluid/tests/unittests/test_unpool1d_op.py b/python/paddle/fluid/tests/unittests/test_unpool1d_op.py new file mode 100644 index 0000000000000..95d19210acb72 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_unpool1d_op.py @@ -0,0 +1,156 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.nn.functional as F + +paddle.enable_static() +paddle.seed(2022) + + +def _unpool_output_size(x, kernel_size, stride, padding, output_size): + input_size = x.shape + default_size = [] + for d in range(len(kernel_size)): + default_size.append((input_size[-len(kernel_size) + d] - 1) * stride[d] + + kernel_size[d] - 2 * padding[d]) + if output_size is None: + ret = default_size + else: + ret = output_size + return ret + + +def unpool1dmax_forward_naive(input, indices, ksize, strides, paddings, + output_size): + s0, s1, s2 = input.shape + output_size = _unpool_output_size(input, ksize, strides, paddings, + output_size) + out_lsize = output_size[0] + out = np.zeros((s0, s1, out_lsize)) + for nidx in range(s0): + for cidx in range(s1): + for l in range(s2): + index = indices[nidx, cidx, l] + lidx = index % out_lsize + out[nidx, cidx, lidx] = input[nidx, cidx, l] + + return out + + +class TestUnpool1DOpAPI_dygraph(unittest.TestCase): + def test_case(self): + places = [paddle.CPUPlace()] + if paddle.fluid.core.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + for place in places: + paddle.disable_static() + input_data = np.random.rand(1, 3, 16) + input_x = paddle.to_tensor(input_data) + output, indices = F.max_pool1d( + input_x, kernel_size=2, stride=2, return_mask=True) + output_unpool = F.max_unpool1d( + output, indices, kernel_size=2, stride=2) + expected_output_unpool = unpool1dmax_forward_naive( + output.numpy(), indices.numpy(), [2], [2], [0], [16]) + self.assertTrue( + np.allclose(output_unpool.numpy(), expected_output_unpool)) + + paddle.enable_static() + + +class TestUnpool1DOpAPI_dygraph2(unittest.TestCase): + def test_case(self): + places = [paddle.CPUPlace()] + if paddle.fluid.core.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + for place in places: + paddle.disable_static() + input_data = np.random.rand(1, 3, 16) + input_x = paddle.to_tensor(input_data) + output, indices = F.max_pool1d( + input_x, kernel_size=2, stride=2, return_mask=True) + output_unpool = F.max_unpool1d( + output, indices, kernel_size=2, stride=None) + expected_output_unpool = unpool1dmax_forward_naive( + output.numpy(), indices.numpy(), [2], [2], [0], [16]) + self.assertTrue( + np.allclose(output_unpool.numpy(), expected_output_unpool)) + + paddle.enable_static() + + +class TestUnpool1DOpAPI_dygraph3(unittest.TestCase): + def test_case(self): + places = [paddle.CPUPlace()] + if paddle.fluid.core.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + for place in places: + paddle.disable_static() + input_data = np.random.rand(1, 3, 16) + input_x = paddle.to_tensor(input_data) + Pool1d = paddle.nn.MaxPool1D( + kernel_size=2, stride=2, return_mask=True) + UnPool1d = paddle.nn.MaxUnPool1D(kernel_size=2, stride=2) + + output, indices = Pool1d(input_x) + output_unpool = UnPool1d(output, indices) + expected_output_unpool = unpool1dmax_forward_naive( + output.numpy(), indices.numpy(), [2], [2], [0], [16]) + self.assertTrue( + np.allclose(output_unpool.numpy(), expected_output_unpool)) + + paddle.enable_static() + + +class TestUnpool1DOpAPI_static(unittest.TestCase): + def test_case(self): + paddle.enable_static() + places = [paddle.CPUPlace()] + if paddle.fluid.core.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + for place in places: + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + + input_data = np.array([[[1, 2, 3, 4], [5, 6, 7, 8], + [9, 10, 11, 12]]]).astype("float32") + x = paddle.fluid.data( + name='x', shape=[1, 3, 4], dtype='float32') + output, indices = F.max_pool1d( + x, kernel_size=2, stride=2, return_mask=True) + output_unpool = F.max_unpool1d( + output, indices, kernel_size=2, stride=None) + + exe = paddle.fluid.Executor(place) + fetches = exe.run(paddle.fluid.default_main_program(), + feed={"x": input_data}, + fetch_list=[output_unpool], + return_numpy=True) + pool1d_out_np = np.array( + [[[2., 4.], [6., 8.], [10., 12.]]]).astype("float32") + indices_np = np.array( + [[[1, 3], [1, 3], [1, 3]]]).astype("int32") + expected_output_unpool = unpool1dmax_forward_naive( + pool1d_out_np, indices_np, [2], [2], [0], [4]) + self.assertTrue(np.allclose(fetches[0], expected_output_unpool)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 60bc5f3111d8f..57e1b710cab0d 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -76,6 +76,7 @@ from .layer.pooling import MaxPool1D # noqa: F401 from .layer.pooling import MaxPool2D # noqa: F401 from .layer.pooling import MaxPool3D # noqa: F401 +from .layer.pooling import MaxUnPool1D # noqa: F401 from .layer.pooling import MaxUnPool2D # noqa: F401 from .layer.pooling import MaxUnPool3D # noqa: F401 from .layer.pooling import AdaptiveAvgPool1D # noqa: F401 @@ -300,6 +301,7 @@ def weight_norm(*args): 'ReLU6', 'LayerDict', 'ZeroPad2D', + 'MaxUnPool1D', 'MaxUnPool2D', 'MaxUnPool3D', 'HingeEmbeddingLoss', diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 7611d06a8f957..683d7ad01b6b8 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -106,6 +106,7 @@ from .pooling import adaptive_avg_pool1d # noqa: F401 from .pooling import adaptive_avg_pool2d # noqa: F401 from .pooling import adaptive_avg_pool3d # noqa: F401 +from .pooling import max_unpool1d # noqa: F401 from .pooling import max_unpool2d # noqa: F401 from .pooling import max_unpool3d # noqa: F401 @@ -178,6 +179,7 @@ 'max_pool1d', 'max_pool2d', 'max_pool3d', + 'max_unpool1d', 'max_unpool2d', 'max_unpool3d', 'adaptive_avg_pool1d', diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index 11b25265d7057..aebd36426bc71 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -664,6 +664,116 @@ def _unpool_output_size(x, kernel_size, stride, padding, output_size): return ret +def max_unpool1d(x, + indices, + kernel_size, + stride=None, + padding=0, + data_format="NCL", + output_size=None, + name=None): + """ + This API implements max unpooling 1d opereation. + `max_unpool1d` accepts the output of `max_pool1d` as input, + including the indices of the maximum value and calculate the partial inverse. + All non-maximum values ​​are set to zero. + + - Input: :math:`(N, C, L_{in})` + - Output: :math:`(N, C, L_{out})`, where + + .. math:: + L_{out} = (L_{in} - 1) * stride - 2 * padding + kernel\_size + + or as given by :attr:`output_size` in the call operator. + + + Args: + x (Tensor): The input tensor of unpooling operator which is a 3-D tensor with + shape [N, C, L]. The format of input tensor is `"NCL"`, + where `N` is batch size, `C` is the number of channels, `L` is + the length of the feature. The data type is float32 or float64. + indices (Tensor): The indices given out by maxpooling1d which is a 3-D tensor with + shape [N, C, L]. The format of input tensor is `"NCL"` , + where `N` is batch size, `C` is the number of channels, `L` is + the length of the featuree. The data type is float32 or float64. + kernel_size (int|list|tuple): The unpool kernel size. If unpool kernel size is a tuple or list, + it must contain an integer. + stride (int|list|tuple): The unpool stride size. If unpool stride size is a tuple or list, + it must contain an integer. + padding (int | tuple): Padding that was added to the input. + output_size(list|tuple, optional): The target output size. If output_size is not specified, + the actual output shape will be automatically calculated by (input_shape, + kernel_size, stride, padding). + data_format (string): The data format of the input and output data. + The default is `"NCL"`. When it is `"NCL"`, the data is stored in the order of: + `[batch_size, input_channels, input_length]`. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + Tensor: The output tensor of unpooling result. + + Examples: + .. code-block:: python + + import paddle + import paddle.nn.functional as F + + data = paddle.rand(shape=[1, 3, 16]) + pool_out, indices = F.max_pool1d(data, kernel_size=2, stride=2, padding=0, return_mask=True) + # pool_out shape: [1, 3, 8], indices shape: [1, 3, 8] + unpool_out = F.max_unpool1d(pool_out, indices, kernel_size=2, padding=0) + # unpool_out shape: [1, 3, 16] + + """ + """NCL to NCHW""" + data_format = "NCHW" + x = unsqueeze(x, [2]) + indices = unsqueeze(indices, [2]) + kernel_size = [1] + utils.convert_to_list(kernel_size, 1, 'pool_size') + if stride is None: + stride = kernel_size + else: + stride = [1] + utils.convert_to_list(stride, 1, 'pool_stride') + padding, padding_algorithm = _update_padding_nd(padding, 1) + # use 2d to implenment 1d should expand padding in advance. + padding = _expand_low_nd_padding(padding) + + if data_format not in ["NCHW"]: + raise ValueError("Attr(data_format) should be 'NCHW'. Received " + "Attr(data_format): %s." % str(data_format)) + + output_size = _unpool_output_size(x, kernel_size, stride, padding, + output_size) + + if in_dygraph_mode(): + output = _C_ops.unpool(x, indices, 'unpooling_type', 'max', 'ksize', + kernel_size, 'strides', stride, 'paddings', + padding, "output_size", output_size, + "data_format", data_format) + return squeeze(output, [2]) + + op_type = "unpool" + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype(input_param_name="x") + unpool_out = helper.create_variable_for_type_inference(dtype) + + helper.append_op( + type=op_type, + inputs={"X": x, + "Indices": indices}, + outputs={"Out": unpool_out}, + attrs={ + "unpooling_type": "max", + "ksize": kernel_size, + "strides": stride, + "paddings": padding, + "output_size": output_size + }) + return squeeze(unpool_out, [2]) + + def max_unpool2d(x, indices, kernel_size, diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 772d6d390bf44..2b50508065605 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -57,6 +57,7 @@ from .pooling import AdaptiveMaxPool1D # noqa: F401 from .pooling import AdaptiveMaxPool2D # noqa: F401 from .pooling import AdaptiveMaxPool3D # noqa: F401 +from .pooling import MaxUnPool1D # noqa: F401 from .pooling import MaxUnPool2D # noqa: F401 from .pooling import MaxUnPool3D # noqa: F401 from .conv import Conv1D # noqa: F401 diff --git a/python/paddle/nn/layer/pooling.py b/python/paddle/nn/layer/pooling.py index dd3feee535657..96942f5c8500a 100755 --- a/python/paddle/nn/layer/pooling.py +++ b/python/paddle/nn/layer/pooling.py @@ -1130,6 +1130,88 @@ def extra_repr(self): self._return_mask) +class MaxUnPool1D(Layer): + """ + This API implements max unpooling 1d opereation. + + `max_unpool1d` accepts the output of `max_pool1d` as input, + including the indices of the maximum value and calculate the partial inverse. + All non-maximum values ​​are set to zero. + + - Input: :math:`(N, C, L_{in})` + - Output: :math:`(N, C, L_{out})`, where + + .. math:: + L_{out} = (L_{in} - 1) * stride - 2 * padding + kernel\_size + + or as given by :attr:`output_size` in the call operator. + + Parameters: + kernel_size (int|list|tuple): The unpool kernel size. If unpool kernel size is a tuple or list, + it must contain an integer. + stride (int|list|tuple): The unpool stride size. If unpool stride size is a tuple or list, + it must contain an integer. + padding (int | tuple): Padding that was added to the input. + output_size(list|tuple, optional): The target output size. If output_size is not specified, + the actual output shape will be automatically calculated by (input_shape, + kernel_size, stride, padding). + data_format (string): The data format of the input and output data. + The default is `"NCL"`. When it is `"NCL"`, the data is stored in the order of: + `[batch_size, input_channels, input_length]`. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + + Returns: + A callable object of MaxUnPool1D. + + Examples: + .. code-block:: python + + import paddle + import paddle.nn.functional as F + import numpy as np + + data = paddle.rand(shape=[1, 3, 16]) + pool_out, indices = F.max_pool1d(data, kernel_size=2, stride=2, padding=0, return_mask=True) + # pool_out shape: [1, 3, 8], indices shape: [1, 3, 8] + Unpool1D = paddle.nn.MaxUnPool1D(kernel_size=2, padding=0) + unpool_out = Unpool1D(pool_out, indices) + # unpool_out shape: [1, 3, 16] + + """ + + def __init__(self, + kernel_size, + stride=None, + padding=0, + data_format="NCL", + output_size=None, + name=None): + super(MaxUnPool1D, self).__init__() + self.ksize = kernel_size + self.stride = stride + self.padding = padding + self.data_format = data_format + self.output_size = output_size + self.name = name + + def forward(self, x, indices): + return F.max_unpool1d( + x, + indices, + kernel_size=self.ksize, + stride=self.stride, + padding=self.padding, + data_format=self.data_format, + output_size=self.output_size, + name=self.name) + + def extra_repr(self): + return 'output_size={}'.format(self.output_size) + + class MaxUnPool2D(Layer): """ This API implements max unpooling 2d opereation. From c15be4cd2c2cd975cb201b817f2882b3f1474086 Mon Sep 17 00:00:00 2001 From: andyjpaddle Date: Fri, 7 Jan 2022 06:02:39 +0000 Subject: [PATCH 7/7] update some code for maxunpool1d --- python/paddle/nn/functional/pooling.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index aebd36426bc71..db9665f7a32c4 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -728,6 +728,9 @@ def max_unpool1d(x, """ """NCL to NCHW""" + if data_format not in ["NCL"]: + raise ValueError("Attr(data_format) should be 'NCL'. Received " + "Attr(data_format): %s." % str(data_format)) data_format = "NCHW" x = unsqueeze(x, [2]) indices = unsqueeze(indices, [2]) @@ -740,10 +743,6 @@ def max_unpool1d(x, # use 2d to implenment 1d should expand padding in advance. padding = _expand_low_nd_padding(padding) - if data_format not in ["NCHW"]: - raise ValueError("Attr(data_format) should be 'NCHW'. Received " - "Attr(data_format): %s." % str(data_format)) - output_size = _unpool_output_size(x, kernel_size, stride, padding, output_size)