From 34c6d6f9e2c606a7dd7d75aa6036abab16132936 Mon Sep 17 00:00:00 2001 From: megemini Date: Fri, 27 Oct 2023 17:57:30 +0800 Subject: [PATCH 01/18] [Init] init commit --- paddle/phi/api/yaml/backward.yaml | 8 +++--- paddle/phi/api/yaml/ops.yaml | 4 +-- paddle/phi/infermeta/backward.cc | 1 + paddle/phi/infermeta/backward.h | 1 + paddle/phi/infermeta/unary.cc | 3 ++- paddle/phi/infermeta/unary.h | 1 + paddle/phi/kernels/funcs/pooling.cc | 14 ++++++++++ paddle/phi/kernels/funcs/pooling.cu | 4 +++ paddle/phi/kernels/funcs/pooling.h | 21 +++++++++++++++ .../phi/kernels/impl/pool_grad_kernel_impl.h | 27 ++++++++++++++++--- paddle/phi/kernels/impl/pool_kernel_impl.h | 27 ++++++++++++++++--- paddle/phi/kernels/pool_grad_kernel.h | 2 ++ paddle/phi/kernels/pool_kernel.h | 2 ++ 13 files changed, 100 insertions(+), 15 deletions(-) diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 5e39b764fa96d..8a639f0742169 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -1471,8 +1471,8 @@ func : matrix_power_grad - backward_op : max_pool2d_with_index_grad - forward : max_pool2d_with_index(Tensor x, int[] kernel_size, int[] strides = {1, 1}, int[] paddings = {0, 0}, bool global_pooling = false, bool adaptive = false) -> Tensor(out), Tensor(mask) - args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive) + forward : max_pool2d_with_index(Tensor x, int[] kernel_size, int[] strides = {1, 1}, int[] paddings = {0, 0}, bool global_pooling = false, bool adaptive = false, bool fractional = false) -> Tensor(out), Tensor(mask) + args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive, bool fractional) output : Tensor(x_grad) infer_meta : func : MaxPoolWithIndexGradInferMeta @@ -1480,8 +1480,8 @@ func : max_pool2d_with_index_grad - backward_op : max_pool3d_with_index_grad - forward : max_pool3d_with_index(Tensor x, int[] kernel_size, int[] strides = {1, 1, 1}, int[] paddings = {0, 0, 0}, bool global_pooling = false, bool adaptive = false) -> Tensor(out), Tensor(mask) - args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive) + forward : max_pool3d_with_index(Tensor x, int[] kernel_size, int[] strides = {1, 1, 1}, int[] paddings = {0, 0, 0}, bool global_pooling = false, bool adaptive = false, bool fractional = false) -> Tensor(out), Tensor(mask) + args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive, bool fractional) output : Tensor(x_grad) infer_meta : func : MaxPoolWithIndexGradInferMeta diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index b3c6d31c710ec..6e9bf535d32bc 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1673,7 +1673,7 @@ backward : matrix_power_grad - op : max_pool2d_with_index - args : (Tensor x, int[] kernel_size, int[] strides= {1, 1}, int[] paddings = {0, 0}, bool global_pooling = false, bool adaptive = false) + args : (Tensor x, int[] kernel_size, int[] strides= {1, 1}, int[] paddings = {0, 0}, bool global_pooling = false, bool adaptive = false, bool fractional = false) output : Tensor(out), Tensor(mask) infer_meta : func : MaxPoolWithIndexInferMeta @@ -1682,7 +1682,7 @@ backward : max_pool2d_with_index_grad - op : max_pool3d_with_index - args : (Tensor x, int[] 
kernel_size, int[] strides = {1, 1, 1}, int[] paddings = {0, 0, 0}, bool global_pooling = false, bool adaptive = false) + args : (Tensor x, int[] kernel_size, int[] strides = {1, 1, 1}, int[] paddings = {0, 0, 0}, bool global_pooling = false, bool adaptive = false, bool fractional = false) output : Tensor(out), Tensor(mask) infer_meta : func : MaxPoolWithIndexInferMeta diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 4c5e130aab7a0..481b49bdb482d 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -655,6 +655,7 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x, const std::vector& paddings, bool global_pooling, bool adaptive, + bool fractional, MetaTensor* dx) { dx->share_meta(x); } diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 13dd392344f97..612af367e944c 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -308,6 +308,7 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x, const std::vector& paddings, bool global_pooling, bool adaptive, + bool fractional, MetaTensor* dx); void MeshgridGradInferMeta(const std::vector& inputs, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 8873a617ef303..b4382476491c9 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2224,6 +2224,7 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x, const std::vector& paddings, bool global_pooling, bool adaptive, + bool fractional, MetaTensor* out, MetaTensor* mask, MetaConfig config) { @@ -2270,7 +2271,7 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x, kernel_size_.size())); std::vector output_shape({x_dims[0], x_dims[1]}); - if (adaptive) { + if (adaptive || fractional) { output_shape.insert( output_shape.end(), kernel_size_.begin(), kernel_size_.end()); } else { diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 8a28d454e42f7..e151ffda065b2 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -333,6 +333,7 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x, const std::vector& paddings, bool global_pooling, bool adaptive, + bool fractional, MetaTensor* out, MetaTensor* mask, MetaConfig config = MetaConfig()); diff --git a/paddle/phi/kernels/funcs/pooling.cc b/paddle/phi/kernels/funcs/pooling.cc index 0573430c2010c..4e876b666b8e4 100644 --- a/paddle/phi/kernels/funcs/pooling.cc +++ b/paddle/phi/kernels/funcs/pooling.cc @@ -1571,6 +1571,7 @@ class MaxPool2dWithIndexFunctor { const std::vector& strides, const std::vector& paddings, bool adaptive, + bool fractional, DenseTensor* output, DenseTensor* mask) { const int batch_size = static_cast(input.dims()[0]); @@ -1600,6 +1601,8 @@ class MaxPool2dWithIndexFunctor { if (adaptive) { hstart = AdaptStartIndex(ph, input_height, output_height); hend = AdaptEndIndex(ph, input_height, output_height); + } else if (fractional) { + // TODO(megemini) } else { hstart = ph * stride_height - padding_height; hend = std::min(hstart + ksize_height, input_height); @@ -1609,6 +1612,8 @@ class MaxPool2dWithIndexFunctor { if (adaptive) { wstart = AdaptStartIndex(pw, input_width, output_width); wend = AdaptEndIndex(pw, input_width, output_width); + } else if (fractional) { + // TODO(megemini) } else { wstart = pw * stride_width - padding_width; wend = std::min(wstart + ksize_width, input_width); @@ -1653,6 +1658,7 @@ class MaxPool2dWithIndexGradFunctor { const std::vector& strides UNUSED, const std::vector& 
paddings UNUSED, bool adaptive UNUSED, + bool fractional UNUSED, DenseTensor* input_grad) { const int batch_size = static_cast(input_grad->dims()[0]); const int input_height = static_cast(input_grad->dims()[2]); @@ -1704,6 +1710,7 @@ class MaxPool3dWithIndexFunctor { const std::vector& strides, const std::vector& paddings, bool adaptive, + bool fractional, DenseTensor* output, DenseTensor* mask) { const int batch_size = static_cast(input.dims()[0]); @@ -1739,6 +1746,8 @@ class MaxPool3dWithIndexFunctor { if (adaptive) { dstart = AdaptStartIndex(pd, input_depth, output_depth); dend = AdaptEndIndex(pd, input_depth, output_depth); + } else if (fractional) { + /* TODO(megemini) */ } else { dstart = pd * stride_depth - padding_depth; dend = std::min(dstart + ksize_depth, input_depth); @@ -1748,6 +1757,8 @@ class MaxPool3dWithIndexFunctor { if (adaptive) { hstart = AdaptStartIndex(ph, input_height, output_height); hend = AdaptEndIndex(ph, input_height, output_height); + } else if (fractional) { + /* TODO(megemini) */ } else { hstart = ph * stride_height - padding_height; hend = std::min(hstart + ksize_height, input_height); @@ -1757,6 +1768,8 @@ class MaxPool3dWithIndexFunctor { if (adaptive) { wstart = AdaptStartIndex(pw, input_width, output_width); wend = AdaptEndIndex(pw, input_width, output_width); + } else if (fractional) { + // TODO(megemini) } else { wstart = pw * stride_width - padding_width; wend = std::min(wstart + ksize_width, input_width); @@ -1806,6 +1819,7 @@ class MaxPool3dWithIndexGradFunctor { const std::vector& strides UNUSED, const std::vector& paddings UNUSED, bool adaptive UNUSED, + bool fractional UNUSED, DenseTensor* input_grad) { const int batch_size = static_cast(input_grad->dims()[0]); const int input_depth = static_cast(input_grad->dims()[2]); diff --git a/paddle/phi/kernels/funcs/pooling.cu b/paddle/phi/kernels/funcs/pooling.cu index 2f89b51815e64..6d378ac6f3c54 100644 --- a/paddle/phi/kernels/funcs/pooling.cu +++ b/paddle/phi/kernels/funcs/pooling.cu @@ -2115,6 +2115,7 @@ class MaxPool2dWithIndexFunctor { const std::vector& strides, const std::vector& paddings, bool adaptive, + bool fractional, DenseTensor* output, DenseTensor* mask) { const int batch_size = input.dims()[0]; @@ -2209,6 +2210,7 @@ class MaxPool2dWithIndexGradFunctor { const std::vector& strides, const std::vector& paddings, bool adaptive, + bool fractional, DenseTensor* input_grad) { const int batch_size = input_grad->dims()[0]; const int input_channels = input_grad->dims()[1]; @@ -2415,6 +2417,7 @@ class MaxPool3dWithIndexFunctor { const std::vector& strides, const std::vector& paddings, bool adaptive, + bool fractional, DenseTensor* output, DenseTensor* mask) { const int batch_size = input.dims()[0]; @@ -2498,6 +2501,7 @@ class MaxPool3dWithIndexGradFunctor { const std::vector& strides, const std::vector& paddings, bool adaptive, + bool fractional, DenseTensor* input_grad) { const int batch_size = input_grad->dims()[0]; const int input_channels = input_grad->dims()[1]; diff --git a/paddle/phi/kernels/funcs/pooling.h b/paddle/phi/kernels/funcs/pooling.h index bf2409d2e502b..fc25d1562fe45 100644 --- a/paddle/phi/kernels/funcs/pooling.h +++ b/paddle/phi/kernels/funcs/pooling.h @@ -100,6 +100,23 @@ HOSTDEVICE inline int AdaptEndIndex(int ph, int input_size, int output_size) { ceil(static_cast((ph + 1) * input_size) / output_size)); } +/* used for fractional pool to calculate start and end index of each divided + * grid + */ +HOSTDEVICE inline int FractionalStartIndex(int ph, + int input_size, + int 
output_size) { + return static_cast( + floor(static_cast(ph * input_size) / output_size)); +} + +HOSTDEVICE inline int FractionalEndIndex(int ph, + int input_size, + int output_size) { + return static_cast( + ceil(static_cast((ph + 1) * input_size) / output_size)); +} + /* * \brief Getting pooling results, and calculating gradient. * @@ -322,6 +339,7 @@ class MaxPool2dWithIndexFunctor { const std::vector& strides, const std::vector& paddings, bool adaptive, + bool fractional, DenseTensor* output, DenseTensor* mask); }; @@ -336,6 +354,7 @@ class MaxPool2dWithIndexGradFunctor { const std::vector& strides, const std::vector& paddings, bool adaptive, + bool fractional, DenseTensor* input_grad); }; @@ -348,6 +367,7 @@ class MaxPool3dWithIndexFunctor { const std::vector& strides, const std::vector& paddings, bool adaptive, + bool fractional, DenseTensor* output, DenseTensor* mask); }; @@ -362,6 +382,7 @@ class MaxPool3dWithIndexGradFunctor { const std::vector& strides, const std::vector& paddings, bool adaptive, + bool fractional, DenseTensor* input_grad); }; diff --git a/paddle/phi/kernels/impl/pool_grad_kernel_impl.h b/paddle/phi/kernels/impl/pool_grad_kernel_impl.h index e3e19370c86bf..ec7edc9892940 100644 --- a/paddle/phi/kernels/impl/pool_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/pool_grad_kernel_impl.h @@ -150,6 +150,7 @@ void MaxPoolWithIndexGradRawKernel(const Context& ctx, const std::vector& paddings, bool global_pooling, bool adaptive, + bool fractional, DenseTensor* dx) { std::vector paddings_ = paddings; std::vector kernel_size_ = kernel_size; @@ -168,13 +169,27 @@ void MaxPoolWithIndexGradRawKernel(const Context& ctx, switch (kernel_size_.size()) { case 2: { funcs::MaxPool2dWithIndexGradFunctor pool2d_backward; - pool2d_backward( - ctx, dout, mask, kernel_size_, strides, paddings_, adaptive, dx); + pool2d_backward(ctx, + dout, + mask, + kernel_size_, + strides, + paddings_, + adaptive, + fractional, + dx); } break; case 3: { funcs::MaxPool3dWithIndexGradFunctor pool3d_backward; - pool3d_backward( - ctx, dout, mask, kernel_size_, strides, paddings_, adaptive, dx); + pool3d_backward(ctx, + dout, + mask, + kernel_size_, + strides, + paddings_, + adaptive, + fractional, + dx); } break; default: { PADDLE_THROW( @@ -262,6 +277,7 @@ void MaxPool2dWithIndexGradKernel(const Context& ctx, const std::vector& paddings, bool global_pooling, bool adaptive, + bool fractional, DenseTensor* dx) { MaxPoolWithIndexGradRawKernel(ctx, x, @@ -272,6 +288,7 @@ void MaxPool2dWithIndexGradKernel(const Context& ctx, paddings, global_pooling, adaptive, + fractional, dx); } @@ -317,6 +334,7 @@ void MaxPool3dWithIndexGradKernel(const Context& ctx, const std::vector& paddings, bool global_pooling, bool adaptive, + bool fractional, DenseTensor* dx) { MaxPoolWithIndexGradRawKernel(ctx, x, @@ -327,6 +345,7 @@ void MaxPool3dWithIndexGradKernel(const Context& ctx, paddings, global_pooling, adaptive, + fractional, dx); } diff --git a/paddle/phi/kernels/impl/pool_kernel_impl.h b/paddle/phi/kernels/impl/pool_kernel_impl.h index a2a6705a68302..8acf268fd8665 100644 --- a/paddle/phi/kernels/impl/pool_kernel_impl.h +++ b/paddle/phi/kernels/impl/pool_kernel_impl.h @@ -192,6 +192,7 @@ void MaxPoolWithIndexRawKernel(const Context& ctx, const std::vector& paddings, bool global_pooling, bool adaptive, + bool fractional, DenseTensor* out, DenseTensor* mask) { std::vector paddings_ = paddings; @@ -207,13 +208,27 @@ void MaxPoolWithIndexRawKernel(const Context& ctx, switch (kernel_size_.size()) { case 2: { 
funcs::MaxPool2dWithIndexFunctor pool2d_forward; - pool2d_forward( - ctx, x, kernel_size_, strides, paddings_, adaptive, out, mask); + pool2d_forward(ctx, + x, + kernel_size_, + strides, + paddings_, + adaptive, + fractional, + out, + mask); } break; case 3: { funcs::MaxPool3dWithIndexFunctor pool3d_forward; - pool3d_forward( - ctx, x, kernel_size_, strides, paddings_, adaptive, out, mask); + pool3d_forward(ctx, + x, + kernel_size_, + strides, + paddings_, + adaptive, + fractional, + out, + mask); } break; default: { PADDLE_THROW( @@ -260,6 +275,7 @@ void MaxPool2dWithIndexKernel(const Context& ctx, const std::vector& paddings, bool global_pooling, bool adaptive, + bool fractional, DenseTensor* out, DenseTensor* mask) { MaxPoolWithIndexRawKernel(ctx, @@ -269,6 +285,7 @@ void MaxPool2dWithIndexKernel(const Context& ctx, paddings, global_pooling, adaptive, + fractional, out, mask); } @@ -309,6 +326,7 @@ void MaxPool3dWithIndexKernel(const Context& ctx, const std::vector& paddings, bool global_pooling, bool adaptive, + bool fractional, DenseTensor* out, DenseTensor* mask) { MaxPoolWithIndexRawKernel(ctx, @@ -318,6 +336,7 @@ void MaxPool3dWithIndexKernel(const Context& ctx, paddings, global_pooling, adaptive, + fractional, out, mask); } diff --git a/paddle/phi/kernels/pool_grad_kernel.h b/paddle/phi/kernels/pool_grad_kernel.h index 64ad99a6d3eae..5a87ede3f3737 100644 --- a/paddle/phi/kernels/pool_grad_kernel.h +++ b/paddle/phi/kernels/pool_grad_kernel.h @@ -96,6 +96,7 @@ void MaxPool2dWithIndexGradKernel(const Context& ctx, const std::vector& paddings, bool global_pooling, bool adaptive, + bool fractional, DenseTensor* dx); template @@ -142,6 +143,7 @@ void MaxPool3dWithIndexGradKernel(const Context& ctx, const std::vector& paddings, bool global_pooling, bool adaptive, + bool fractional, DenseTensor* dx); } // namespace phi diff --git a/paddle/phi/kernels/pool_kernel.h b/paddle/phi/kernels/pool_kernel.h index c1a7dd471a02f..49cd0a4955e59 100644 --- a/paddle/phi/kernels/pool_kernel.h +++ b/paddle/phi/kernels/pool_kernel.h @@ -60,6 +60,7 @@ void MaxPool2dWithIndexKernel(const Context& ctx, const std::vector& paddings, bool global_pooling, bool adaptive, + bool fractional, DenseTensor* out, DenseTensor* mask); @@ -101,6 +102,7 @@ void MaxPool3dWithIndexKernel(const Context& ctx, const std::vector& paddings, bool global_pooling, bool adaptive, + bool fractional, DenseTensor* out, DenseTensor* mask); From 8efe56a277c145acc600eb0d2df76ff1984698d7 Mon Sep 17 00:00:00 2001 From: megemini Date: Sat, 28 Oct 2023 15:34:26 +0800 Subject: [PATCH 02/18] [Add] fractional max pool python api --- paddle/phi/kernels/funcs/pooling.cu | 14 ++++ python/paddle/nn/functional/pooling.py | 103 +++++++++++++++++++++++++ 2 files changed, 117 insertions(+) diff --git a/paddle/phi/kernels/funcs/pooling.cu b/paddle/phi/kernels/funcs/pooling.cu index 6d378ac6f3c54..c4f0a612a9a02 100644 --- a/paddle/phi/kernels/funcs/pooling.cu +++ b/paddle/phi/kernels/funcs/pooling.cu @@ -1927,6 +1927,7 @@ __global__ void KernelMaxPool2dWithIdx(const int nthreads, const int padding_height, const int padding_width, bool adaptive, + bool fractional, T1* output_data, T2* mask_data, FastDivModForPooling divmods) { @@ -1953,6 +1954,8 @@ __global__ void KernelMaxPool2dWithIdx(const int nthreads, wstart = AdaptStartIndex(w_offset, input_width, output_width); wend = AdaptEndIndex(w_offset, input_width, output_width); + } else if (fractional) { + // TODO(megemini) } else { hstart = h_offset * stride_height - padding_height; hend = min(hstart + 
ksize_height, input_height); @@ -2048,6 +2051,7 @@ __global__ void KernelMaxPool2DWithIdxGrad(const int nthreads, const int padding_height, const int padding_width, bool adaptive, + bool fractional, T1* input_grad, FastDivModForPooling divmods) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; @@ -2075,6 +2079,8 @@ __global__ void KernelMaxPool2DWithIdxGrad(const int nthreads, pwstart = w_offset * output_width / input_width; pwend = min((w_offset + 1) * output_width / input_width + 1, output_width); + } else if (fractional) { + // TODO(megemini) } else { phstart = (h_offset + padding_height < ksize_height) @@ -2188,6 +2194,7 @@ class MaxPool2dWithIndexFunctor { padding_height, padding_width, adaptive, + fractional, output_data, mask_data, pool_divmods); @@ -2252,6 +2259,7 @@ class MaxPool2dWithIndexGradFunctor { padding_height, padding_width, adaptive, + fractional, input_grad_data, pool_divmods); } @@ -2290,6 +2298,7 @@ __global__ void KernelMaxPool3DWithIdx(const int ncd, const int padding_height, const int padding_width, bool adaptive, + bool fractional, T1* output_data, T2* mask_data, FastDivModForPooling3D divmods_output) { @@ -2322,6 +2331,8 @@ __global__ void KernelMaxPool3DWithIdx(const int ncd, wstart = AdaptStartIndex(w_offset, input_width, output_width); wend = AdaptEndIndex(w_offset, input_width, output_width); + } else if (fractional) { + // TODO(megemini) } else { dstart = d_offset * stride_depth - padding_depth; hstart = h_offset * stride_height - padding_height; @@ -2375,6 +2386,7 @@ __global__ void KernelMaxPool3DWithIdxGrad( const int padding_height, const int padding_width, bool adaptive, + bool fractional, T1* input_grad, FastDivModForPooling3D divmods_output) { int w_offset, h_offset, d_offset, nc_offset; @@ -2480,6 +2492,7 @@ class MaxPool3dWithIndexFunctor { padding_height, padding_width, adaptive, + fractional, output_data, mask_data, pool_divmods_output); @@ -2563,6 +2576,7 @@ class MaxPool3dWithIndexGradFunctor { padding_height, padding_width, adaptive, + fractional, input_grad_data, pool_divmods_output); } diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index 3e78f585881b7..3b46a47bbbd22 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -2095,3 +2095,106 @@ def adaptive_max_pool3d(x, output_size, return_mask=False, name=None): ) return (pool_out, mask) if return_mask else pool_out + + +def fractional_max_pool2d(x, output_size, return_mask=False, name=None): + """ + TODO(megemini) + """ + _check_input(x, 4) + + in_h, in_w = x.shape[2:4] + if isinstance(output_size, int): + output_size = convert_to_list(output_size, 2, 'output_size') + else: + output_size = list(output_size) + if output_size[0] is None: + output_size[0] = in_h + if output_size[1] is None: + output_size[1] = in_w + + if in_dygraph_mode(): + pool_out = _C_ops.max_pool2d_with_index( + x, output_size, [1, 1], [0, 0], False, False, True + ) + return pool_out if return_mask else pool_out[0] + else: + l_type = 'max_pool2d_with_index' + + check_variable_and_dtype( + x, 'x', ['float32', 'float64'], 'fractional_max_pool2d' + ) + check_type(return_mask, 'return_mask', bool, 'fractional_max_pool2d') + + helper = LayerHelper(l_type, **locals()) + dtype = helper.input_dtype(input_param_name='x') + pool_out = helper.create_variable_for_type_inference(dtype) + + mask = helper.create_variable_for_type_inference('int32') + outputs = {"Out": pool_out, "Mask": mask} + + helper.append_op( + 
type=l_type, + inputs={"X": x}, + outputs=outputs, + attrs={ + "pooling_type": 'max', + "ksize": output_size, + "fractional": True, + }, + ) + + return (pool_out, mask) if return_mask else pool_out + + +def fractional_max_pool3d(x, output_size, return_mask=False, name=None): + """ + TODO(megemini) + """ + _check_input(x, 5) + + in_l, in_h, in_w = x.shape[2:5] + if isinstance(output_size, int): + output_size = convert_to_list(output_size, 3, 'output_size') + else: + output_size = list(output_size) + if output_size[0] is None: + output_size[0] = in_l + if output_size[1] is None: + output_size[1] = in_h + if output_size[2] is None: + output_size[2] = in_w + + if in_dygraph_mode(): + # By default, strides is [1,1,1] and paddings is [0, 0, 0] + pool_out = _C_ops.max_pool3d_with_index( + x, output_size, [1, 1, 1], [0, 0, 0], False, False, True + ) + return pool_out if return_mask else pool_out[0] + else: + l_type = 'max_pool3d_with_index' + + check_variable_and_dtype( + x, 'x', ['float32', 'float64'], 'fractional_max_pool3d' + ) + check_type(return_mask, 'return_mask', bool, 'fractional_max_pool3d') + + helper = LayerHelper(l_type, **locals()) + dtype = helper.input_dtype(input_param_name='x') + pool_out = helper.create_variable_for_type_inference(dtype) + + mask = helper.create_variable_for_type_inference('int32') + outputs = {"Out": pool_out, "Mask": mask} + + helper.append_op( + type=l_type, + inputs={"X": x}, + outputs=outputs, + attrs={ + "pooling_type": 'max', + "ksize": output_size, + "fractional": True, + }, + ) + + return (pool_out, mask) if return_mask else pool_out From 88362e839d375e969b1e2514c8d2e1f65b1120a4 Mon Sep 17 00:00:00 2001 From: megemini Date: Sat, 28 Oct 2023 17:52:45 +0800 Subject: [PATCH 03/18] [Add] FractionalMaxPool layer --- python/paddle/nn/layer/pooling.py | 50 +++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/python/paddle/nn/layer/pooling.py b/python/paddle/nn/layer/pooling.py index 3108aeebeded4..8db539fe7f763 100755 --- a/python/paddle/nn/layer/pooling.py +++ b/python/paddle/nn/layer/pooling.py @@ -1141,6 +1141,56 @@ def extra_repr(self): ) +class FractionalMaxPool2D(Layer): + """ + TODO(megemini) + """ + + def __init__(self, output_size, return_mask=False, name=None): + super().__init__() + self._output_size = output_size + self._return_mask = return_mask + self._name = name + + def forward(self, x): + return F.fractional_max_pool2d( + x, + output_size=self._output_size, + return_mask=self._return_mask, + name=self._name, + ) + + def extra_repr(self): + return ( + f'output_size={self._output_size}, return_mask={self._return_mask}' + ) + + +class FractionalMaxPool3D(Layer): + """ + TODO(megemini) + """ + + def __init__(self, output_size, return_mask=False, name=None): + super().__init__() + self._output_size = output_size + self._return_mask = return_mask + self._name = name + + def forward(self, x): + return F.fractional_max_pool3d( + x, + output_size=self._output_size, + return_mask=self._return_mask, + name=self._name, + ) + + def extra_repr(self): + return ( + f'output_size={self._output_size}, return_mask={self._return_mask}' + ) + + class MaxUnPool1D(Layer): r""" This API implements max unpooling 1d opereation. 
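Note on patches 02/03: they wire up the Python surface but leave the docstrings as TODO placeholders. A minimal usage sketch of the intended API, assuming the __init__ exports added later in this series have landed, and with the input shape and output_size chosen purely for illustration (the fractional index generation itself only arrives in the next patch):

    import paddle
    import paddle.nn.functional as F

    x = paddle.rand([2, 3, 32, 32])  # NCHW input

    # Functional form: pseudo-randomly partition the 32x32 spatial dims into a
    # 16x16 output grid; with return_mask=True it also returns the int32 index mask.
    out, mask = F.fractional_max_pool2d(x, output_size=16, return_mask=True)
    # out.shape == [2, 3, 16, 16]

    # Layer form, wrapping the same functional call.
    pool = paddle.nn.FractionalMaxPool2D(output_size=(16, 16))
    y = pool(x)  # [2, 3, 16, 16]
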
From f2652035c4229cd707eeb1948493ac7113a09dbf Mon Sep 17 00:00:00 2001 From: megemini Date: Wed, 15 Nov 2023 19:36:37 +0800 Subject: [PATCH 04/18] [Add] fractional index generation --- paddle/phi/kernels/funcs/pooling.cc | 57 ++++++++++++++++++-- paddle/phi/kernels/funcs/pooling.cu | 82 +++++++++++++++++++++++++++-- paddle/phi/kernels/funcs/pooling.h | 28 ++++++---- 3 files changed, 149 insertions(+), 18 deletions(-) diff --git a/paddle/phi/kernels/funcs/pooling.cc b/paddle/phi/kernels/funcs/pooling.cc index 4e876b666b8e4..0e25ae82130e0 100644 --- a/paddle/phi/kernels/funcs/pooling.cc +++ b/paddle/phi/kernels/funcs/pooling.cc @@ -1593,6 +1593,21 @@ class MaxPool2dWithIndexFunctor { T1* output_data = context.template Alloc(output); T2* mask_data = context.template Alloc(mask); + float alpha_height = 0, alpha_width = 0; + float u_height = 0, u_width = 0; + if (fractional) { + std::uniform_real_distribution dist(0, 1); + auto engine = phi::GetCPURandomEngine(0); + float u = dist(*engine); + + alpha_height = static_cast(input_height) / output_height; + alpha_width = static_cast(input_width) / output_width; + + u_height = + FractionalRationalU(u, alpha_height, input_height, output_height); + u_width = FractionalRationalU(u, alpha_width, input_width, output_width); + } + int hstart = 0, hend = 0; int wstart = 0, wend = 0; for (int i = 0; i < batch_size; i++) { @@ -1602,7 +1617,10 @@ class MaxPool2dWithIndexFunctor { hstart = AdaptStartIndex(ph, input_height, output_height); hend = AdaptEndIndex(ph, input_height, output_height); } else if (fractional) { - // TODO(megemini) + hstart = FractionalStartIndex(ph, alpha_height, u_height); + hend = FractionalEndIndex(ph, alpha_height, u_height); + hstart = std::max(hstart, 0); + hend = std::min(hend, input_height); } else { hstart = ph * stride_height - padding_height; hend = std::min(hstart + ksize_height, input_height); @@ -1613,7 +1631,10 @@ class MaxPool2dWithIndexFunctor { wstart = AdaptStartIndex(pw, input_width, output_width); wend = AdaptEndIndex(pw, input_width, output_width); } else if (fractional) { - // TODO(megemini) + wstart = FractionalStartIndex(pw, alpha_width, u_width); + wend = FractionalEndIndex(pw, alpha_width, u_width); + wstart = std::max(wstart, 0); + wend = std::min(wend, input_width); } else { wstart = pw * stride_width - padding_width; wend = std::min(wstart + ksize_width, input_width); @@ -1737,6 +1758,23 @@ class MaxPool3dWithIndexFunctor { T1* output_data = context.template Alloc(output); T2* mask_data = context.template Alloc(mask); + float alpha_height = 0, alpha_width = 0, alpha_depth = 0; + float u_height = 0, u_width = 0, u_depth = 0; + if (fractional) { + std::uniform_real_distribution dist(0, 1); + auto engine = phi::GetCPURandomEngine(0); + float u = dist(*engine); + + alpha_depth = static_cast(input_depth) / output_depth; + alpha_height = static_cast(input_height) / output_height; + alpha_width = static_cast(input_width) / output_width; + + u_depth = FractionalRationalU(u, alpha_depth, input_depth, output_depth); + u_height = + FractionalRationalU(u, alpha_height, input_height, output_height); + u_width = FractionalRationalU(u, alpha_width, input_width, output_width); + } + int dstart = 0, dend = 0; int hstart = 0, hend = 0; int wstart = 0, wend = 0; @@ -1747,7 +1785,10 @@ class MaxPool3dWithIndexFunctor { dstart = AdaptStartIndex(pd, input_depth, output_depth); dend = AdaptEndIndex(pd, input_depth, output_depth); } else if (fractional) { - /* TODO(megemini) */ + dstart = FractionalStartIndex(pd, alpha_depth, 
u_depth); + dend = FractionalEndIndex(pd, alpha_depth, u_depth); + dstart = std::max(dstart, 0); + dend = std::min(dend, input_depth); } else { dstart = pd * stride_depth - padding_depth; dend = std::min(dstart + ksize_depth, input_depth); @@ -1758,7 +1799,10 @@ class MaxPool3dWithIndexFunctor { hstart = AdaptStartIndex(ph, input_height, output_height); hend = AdaptEndIndex(ph, input_height, output_height); } else if (fractional) { - /* TODO(megemini) */ + hstart = FractionalStartIndex(ph, alpha_height, u_height); + hend = FractionalEndIndex(ph, alpha_height, u_height); + hstart = std::max(hstart, 0); + hend = std::min(hend, input_height); } else { hstart = ph * stride_height - padding_height; hend = std::min(hstart + ksize_height, input_height); @@ -1769,7 +1813,10 @@ class MaxPool3dWithIndexFunctor { wstart = AdaptStartIndex(pw, input_width, output_width); wend = AdaptEndIndex(pw, input_width, output_width); } else if (fractional) { - // TODO(megemini) + wstart = FractionalStartIndex(pw, alpha_width, u_width); + wend = FractionalEndIndex(pw, alpha_width, u_width); + wstart = std::max(wstart, 0); + wend = std::min(wend, input_width); } else { wstart = pw * stride_width - padding_width; wend = std::min(wstart + ksize_width, input_width); diff --git a/paddle/phi/kernels/funcs/pooling.cu b/paddle/phi/kernels/funcs/pooling.cu index c4f0a612a9a02..08f8203c894b6 100644 --- a/paddle/phi/kernels/funcs/pooling.cu +++ b/paddle/phi/kernels/funcs/pooling.cu @@ -1931,6 +1931,21 @@ __global__ void KernelMaxPool2dWithIdx(const int nthreads, T1* output_data, T2* mask_data, FastDivModForPooling divmods) { + float alpha_height = 0, alpha_width = 0, alpha_depth = 0; + float u_height = 0, u_width = 0, u_depth = 0; + if (fractional) { + std::uniform_real_distribution dist(0, 1); + auto engine = phi::GetCPURandomEngine(0); + float u = dist(*engine); + + alpha_height = static_cast(input_height) / output_height; + alpha_width = static_cast(input_width) / output_width; + + u_height = + FractionalRationalU(u, alpha_height, input_height, output_height); + u_width = FractionalRationalU(u, alpha_width, input_width, output_width); + } + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int hstart, hend, wstart, wend; @@ -1955,7 +1970,15 @@ __global__ void KernelMaxPool2dWithIdx(const int nthreads, wstart = AdaptStartIndex(w_offset, input_width, output_width); wend = AdaptEndIndex(w_offset, input_width, output_width); } else if (fractional) { - // TODO(megemini) + hstart = FractionalStartIndex(h_offset, alpha_height, u_height); + hend = FractionalEndIndex(h_offset, alpha_height, u_height); + hstart = std::max(hstart, 0); + hend = std::min(hend, input_height); + + wstart = FractionalStartIndex(w_offset, alpha_width, u_width); + wend = FractionalEndIndex(w_offset, alpha_width, u_width); + wstart = std::max(wstart, 0); + wend = std::min(wend, input_width); } else { hstart = h_offset * stride_height - padding_height; hend = min(hstart + ksize_height, input_height); @@ -2054,6 +2077,21 @@ __global__ void KernelMaxPool2DWithIdxGrad(const int nthreads, bool fractional, T1* input_grad, FastDivModForPooling divmods) { + float alpha_height = 0, alpha_width = 0, alpha_depth = 0; + float u_height = 0, u_width = 0, u_depth = 0; + if (fractional) { + std::uniform_real_distribution dist(0, 1); + auto engine = phi::GetCPURandomEngine(0); + float u = dist(*engine); + + alpha_height = static_cast(input_height) / output_height; + alpha_width = static_cast(input_width) / 
output_width; + + u_height = + FractionalRationalU(u, alpha_height, input_height, output_height); + u_width = FractionalRationalU(u, alpha_width, input_width, output_width); + } + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int phstart, phend, pwstart, pwend; @@ -2080,7 +2118,15 @@ __global__ void KernelMaxPool2DWithIdxGrad(const int nthreads, pwend = min((w_offset + 1) * output_width / input_width + 1, output_width); } else if (fractional) { - // TODO(megemini) + phstart = FractionalStartIndex(h_offset, alpha_height, u_height); + phend = FractionalEndIndex(h_offset, alpha_height, u_height); + phstart = std::max(phstart, 0); + phend = std::min(phend, input_height); + + pwstart = FractionalStartIndex(w_offset, alpha_width, u_width); + pwend = FractionalEndIndex(w_offset, alpha_width, u_width); + pwstart = std::max(pwstart, 0); + pwend = std::min(pwend, input_width); } else { phstart = (h_offset + padding_height < ksize_height) @@ -2302,6 +2348,23 @@ __global__ void KernelMaxPool3DWithIdx(const int ncd, T1* output_data, T2* mask_data, FastDivModForPooling3D divmods_output) { + float alpha_height = 0, alpha_width = 0, alpha_depth = 0; + float u_height = 0, u_width = 0, u_depth = 0; + if (fractional) { + std::uniform_real_distribution dist(0, 1); + auto engine = phi::GetCPURandomEngine(0); + float u = dist(*engine); + + alpha_depth = static_cast(input_depth) / output_depth; + alpha_height = static_cast(input_height) / output_height; + alpha_width = static_cast(input_width) / output_width; + + u_depth = FractionalRationalU(u, alpha_depth, input_depth, output_depth); + u_height = + FractionalRationalU(u, alpha_height, input_height, output_height); + u_width = FractionalRationalU(u, alpha_width, input_width, output_width); + } + int w_offset, h_offset, d_offset, nc_offset; int dstart, dend, hstart, hend, wstart, wend; const T1* input_data_cur; @@ -2332,7 +2395,20 @@ __global__ void KernelMaxPool3DWithIdx(const int ncd, wstart = AdaptStartIndex(w_offset, input_width, output_width); wend = AdaptEndIndex(w_offset, input_width, output_width); } else if (fractional) { - // TODO(megemini) + dstart = FractionalStartIndex(d_offset, alpha_depth, u_depth); + dend = FractionalEndIndex(d_offset, alpha_depth, u_depth); + dstart = std::max(dstart, 0); + dend = std::min(dend, input_depth); + + hstart = FractionalStartIndex(h_offset, alpha_height, u_height); + hend = FractionalEndIndex(h_offset, alpha_height, u_height); + hstart = std::max(hstart, 0); + hend = std::min(hend, input_height); + + wstart = FractionalStartIndex(w_offset, alpha_width, u_width); + wend = FractionalEndIndex(w_offset, alpha_width, u_width); + wstart = std::max(wstart, 0); + wend = std::min(wend, input_width); } else { dstart = d_offset * stride_depth - padding_depth; hstart = h_offset * stride_height - padding_height; diff --git a/paddle/phi/kernels/funcs/pooling.h b/paddle/phi/kernels/funcs/pooling.h index fc25d1562fe45..00e6b23dffc03 100644 --- a/paddle/phi/kernels/funcs/pooling.h +++ b/paddle/phi/kernels/funcs/pooling.h @@ -103,18 +103,26 @@ HOSTDEVICE inline int AdaptEndIndex(int ph, int input_size, int output_size) { /* used for fractional pool to calculate start and end index of each divided * grid */ -HOSTDEVICE inline int FractionalStartIndex(int ph, - int input_size, - int output_size) { - return static_cast( - floor(static_cast(ph * input_size) / output_size)); +HOSTDEVICE inline float FractionalRationalU(float u, + float alpha, + int input, + int output) { + int 
base = input / output; + + float u_max1 = static_cast(base + 2) / alpha - 1; + float u_max2 = static_cast(input + 1 - base) / alpha - + static_cast(output - 1); + float max_u = std::min(u_max1, u_max2); + + return u * max_u; } -HOSTDEVICE inline int FractionalEndIndex(int ph, - int input_size, - int output_size) { - return static_cast( - ceil(static_cast((ph + 1) * input_size) / output_size)); +HOSTDEVICE inline int FractionalStartIndex(int ph, float alpha, float u) { + return static_cast(ceil(alpha * (ph + u) - 1)); +} + +HOSTDEVICE inline int FractionalEndIndex(int ph, float alpha, float u) { + return static_cast(ceil(alpha * (ph + 1 + u) - 1)); } /* From d7457e076d79fa89701b0e5abde23a9fd335d766 Mon Sep 17 00:00:00 2001 From: megemini Date: Thu, 16 Nov 2023 12:15:51 +0800 Subject: [PATCH 05/18] [Add] add fractional max pooling to __init__ --- python/paddle/nn/__init__.py | 5 ++++- python/paddle/nn/functional/__init__.py | 4 ++++ python/paddle/nn/layer/__init__.py | 2 ++ 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index db4c9adf3327c..2dba33712bac2 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -86,7 +86,8 @@ from .layer.pooling import AdaptiveMaxPool1D # noqa: F401 from .layer.pooling import AdaptiveMaxPool2D # noqa: F401 from .layer.pooling import AdaptiveMaxPool3D # noqa: F401 - +from .layer.pooling import FractionalMaxPool2D # noqa: F401 +from .layer.pooling import FractionalMaxPool3D # noqa: F401 from .layer.conv import Conv1D # noqa: F401 from .layer.conv import Conv2D # noqa: F401 from .layer.conv import Conv3D # noqa: F401 @@ -219,6 +220,8 @@ 'SmoothL1Loss', 'MaxPool3D', 'AdaptiveMaxPool2D', + 'FractionalMaxPool2D', + 'FractionalMaxPool3D', 'Hardshrink', 'Softplus', 'KLDivLoss', diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 453627b4cf049..5608d345df9ee 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -125,6 +125,8 @@ from .pooling import max_unpool1d # noqa: F401 from .pooling import max_unpool2d # noqa: F401 from .pooling import max_unpool3d # noqa: F401 +from .pooling import fractional_max_pool2d # noqa: F401 +from .pooling import fractional_max_pool3d # noqa: F401 from .vision import affine_grid # noqa: F401 from .vision import grid_sample # noqa: F401 @@ -212,6 +214,8 @@ 'adaptive_max_pool1d', 'adaptive_max_pool2d', 'adaptive_max_pool3d', + 'fractional_max_pool2d', + 'fractional_max_pool3d', 'binary_cross_entropy', 'binary_cross_entropy_with_logits', 'cross_entropy', diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index f83b8454456ff..ef3dbc195dd7f 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -63,6 +63,8 @@ from .pooling import MaxUnPool1D # noqa: F401 from .pooling import MaxUnPool2D # noqa: F401 from .pooling import MaxUnPool3D # noqa: F401 +from .pooling import FractionalMaxPool2D # noqa: F401 +from .pooling import FractionalMaxPool3D # noqa: F401 from .conv import Conv1D # noqa: F401 from .conv import Conv2D # noqa: F401 from .conv import Conv3D # noqa: F401 From f79b8d29389689eeea0ec23c13fd1f5c029507f1 Mon Sep 17 00:00:00 2001 From: megemini Date: Thu, 16 Nov 2023 12:57:37 +0800 Subject: [PATCH 06/18] [Fix] add default value False to max_poolNd_with_index --- python/paddle/nn/functional/pooling.py | 12 ++++++------ test/legacy_test/test_pool_max_op.py | 6 ++++-- 2 files 
changed, 10 insertions(+), 8 deletions(-) diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index 3b46a47bbbd22..1df41a0ef1bd7 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -634,7 +634,7 @@ def max_pool1d( if in_dygraph_mode(): if return_mask: pool_out = _C_ops.max_pool2d_with_index( - x, kernel_size, stride, padding, False, False + x, kernel_size, stride, padding, False, False, False ) return ( (squeeze(pool_out[0], [2]), squeeze(pool_out[1], [2])) @@ -1261,7 +1261,7 @@ def max_pool2d( if in_dynamic_or_pir_mode(): if return_mask: output = _C_ops.max_pool2d_with_index( - x, kernel_size, stride, padding, False, False + x, kernel_size, stride, padding, False, False, False ) return output if return_mask else output[0] else: @@ -1428,7 +1428,7 @@ def max_pool3d( if in_dygraph_mode(): if return_mask: output = _C_ops.max_pool3d_with_index( - x, kernel_size, stride, padding, False, False + x, kernel_size, stride, padding, False, False, False ) return output if return_mask else output[0] else: @@ -1877,7 +1877,7 @@ def adaptive_max_pool1d(x, output_size, return_mask=False, name=None): x = unsqueeze(x, [2]) if in_dygraph_mode(): pool_out = _C_ops.max_pool2d_with_index( - x, pool_size, [1, 1], [0, 0], False, True + x, pool_size, [1, 1], [0, 0], False, True, False ) return ( (squeeze(pool_out[0], [2]), squeeze(pool_out[1], [2])) @@ -1971,7 +1971,7 @@ def adaptive_max_pool2d(x, output_size, return_mask=False, name=None): output_size[1] = in_w if in_dygraph_mode(): pool_out = _C_ops.max_pool2d_with_index( - x, output_size, [1, 1], [0, 0], False, True + x, output_size, [1, 1], [0, 0], False, True, False ) return pool_out if return_mask else pool_out[0] else: @@ -2064,7 +2064,7 @@ def adaptive_max_pool3d(x, output_size, return_mask=False, name=None): if in_dygraph_mode(): # By default, strides is [1,1,1] and paddings is [0, 0, 0] pool_out = _C_ops.max_pool3d_with_index( - x, output_size, [1, 1, 1], [0, 0, 0], False, True + x, output_size, [1, 1, 1], [0, 0, 0], False, True, False ) return pool_out if return_mask else pool_out[0] else: diff --git a/test/legacy_test/test_pool_max_op.py b/test/legacy_test/test_pool_max_op.py index 23740d39b8ef3..cdfd52b200ec4 100644 --- a/test/legacy_test/test_pool_max_op.py +++ b/test/legacy_test/test_pool_max_op.py @@ -143,9 +143,10 @@ def max_pool3d_with_index_wapper( paddings=[], global_pooling=False, adaptive=False, + fractional=False, ): return paddle._C_ops.max_pool3d_with_index( - x, kernel_size, strides, paddings, global_pooling, adaptive + x, kernel_size, strides, paddings, global_pooling, adaptive, fractional ) @@ -342,9 +343,10 @@ def max_pool2d_with_index_wapper( paddings=[], global_pooling=False, adaptive=False, + fractional=False, ): return paddle._C_ops.max_pool2d_with_index( - x, kernel_size, strides, paddings, global_pooling, adaptive + x, kernel_size, strides, paddings, global_pooling, adaptive, fractional ) From 594e130a57a1c8419c3a9e32d4a0632abef0fb86 Mon Sep 17 00:00:00 2001 From: megemini Date: Thu, 16 Nov 2023 15:36:55 +0800 Subject: [PATCH 07/18] [Update] __init__ add fractional funcs --- python/paddle/nn/__init__.py | 2 ++ python/paddle/nn/functional/__init__.py | 2 ++ python/paddle/nn/layer/__init__.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index a284ca2805f4b..1956190161a18 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -129,6 +129,8 @@ 
AdaptiveMaxPool1D, AdaptiveMaxPool2D, AdaptiveMaxPool3D, + FractionalMaxPool2D, + FractionalMaxPool3D, AvgPool1D, AvgPool2D, AvgPool3D, diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index b962bb394b790..ce26107ff900a 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -129,6 +129,8 @@ adaptive_max_pool1d, adaptive_max_pool2d, adaptive_max_pool3d, + fractional_max_pool2d, + fractional_max_pool3d, avg_pool1d, avg_pool2d, avg_pool3d, diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 9271c5ecc10e1..6eda37c48ba5b 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -97,6 +97,8 @@ AdaptiveMaxPool1D, AdaptiveMaxPool2D, AdaptiveMaxPool3D, + FractionalMaxPool2D, + FractionalMaxPool3D, AvgPool1D, AvgPool2D, AvgPool3D, From e667f099114ddf0f63d3fe347c82815a9e8aecde Mon Sep 17 00:00:00 2001 From: megemini Date: Sun, 19 Nov 2023 17:32:27 +0800 Subject: [PATCH 08/18] [Add] test file --- test/legacy_test/test_fractional_max_pool2d.py | 13 +++++++++++++ test/legacy_test/test_fractional_max_pool3d.py | 13 +++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 test/legacy_test/test_fractional_max_pool2d.py create mode 100644 test/legacy_test/test_fractional_max_pool3d.py diff --git a/test/legacy_test/test_fractional_max_pool2d.py b/test/legacy_test/test_fractional_max_pool2d.py new file mode 100644 index 0000000000000..595add0aed9e1 --- /dev/null +++ b/test/legacy_test/test_fractional_max_pool2d.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/test/legacy_test/test_fractional_max_pool3d.py b/test/legacy_test/test_fractional_max_pool3d.py new file mode 100644 index 0000000000000..595add0aed9e1 --- /dev/null +++ b/test/legacy_test/test_fractional_max_pool3d.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
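The two test files above are still license stubs; the next patch adds NumPy reference helpers (fractional_rational_u, fractional_start_index, fractional_end_index) that mirror FractionalRationalU / FractionalStartIndex / FractionalEndIndex from pooling.h. A self-contained sketch of that index math, with input_size=10, output_size=4 and the uniform sample fixed at 0.5 purely for illustration:

    import math

    def fractional_rational_u(u, alpha, input, output):
        # Scale the uniform sample u so every pooling window stays in-bounds.
        base = input // output
        u_max1 = (base + 2) / alpha - 1
        u_max2 = (input + 1 - base) / alpha - (output - 1)
        return u * min(u_max1, u_max2)

    def fractional_start_index(idx, alpha, u):
        return math.ceil(alpha * (idx + u) - 1)

    def fractional_end_index(idx, alpha, u):
        return math.ceil(alpha * (idx + 1 + u) - 1)

    input_size, output_size = 10, 4
    alpha = input_size / output_size  # pooling ratio alpha = 2.5
    u = fractional_rational_u(0.5, alpha, input_size, output_size)

    for i in range(output_size):
        start = max(fractional_start_index(i, alpha, u), 0)
        end = min(fractional_end_index(i, alpha, u), input_size)
        print(i, start, end)
    # Prints windows [0,3), [3,5), [5,8), [8,10): sizes 3, 2, 3, 2.
    # end_i always equals start_{i+1}, since both evaluate ceil(alpha*(i+1+u) - 1),
    # so the pseudo-random windows tile the input without overlapping.
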
From ab256ac740bb2508500baf0d2ca290e1e5277455 Mon Sep 17 00:00:00 2001 From: megemini Date: Sun, 19 Nov 2023 18:48:14 +0800 Subject: [PATCH 09/18] [Add] test_pool_max_op for fractional --- test/legacy_test/test_pool_max_op.py | 137 ++++++++++++++++++++++++++- 1 file changed, 133 insertions(+), 4 deletions(-) diff --git a/test/legacy_test/test_pool_max_op.py b/test/legacy_test/test_pool_max_op.py index cdfd52b200ec4..2285a4a305946 100644 --- a/test/legacy_test/test_pool_max_op.py +++ b/test/legacy_test/test_pool_max_op.py @@ -35,26 +35,86 @@ def adaptive_end_index(index, input_size, output_size): return int(np.ceil((index + 1) * input_size / output_size)) +def fractional_rational_u(u, alpha, input, output): + base = input // output + + u_max1 = (base + 2) / alpha - 1 + u_max2 = (input + 1 - base) / alpha - (output - 1) + max_u = min(u_max1, u_max2) + + return u * max_u + + +def fractional_start_index(ph, alpha, u): + return np.ceil(alpha * (ph + u) - 1) + + +def fractional_end_index(ph, alpha, u): + return np.ceil(alpha * (ph + 1 + u) - 1) + + def max_pool3D_forward_naive( - x, ksize, strides, paddings, global_pool=False, adaptive=False + x, + ksize, + strides, + paddings, + global_pool=False, + adaptive=False, + fractional=False, ): N, C, D, H, W = x.shape if global_pool: ksize = [D, H, W] paddings = [0, 0, 0] - if adaptive: + if adaptive or fractional: D_out, H_out, W_out = ksize else: D_out = (D - ksize[0] + 2 * paddings[0]) // strides[0] + 1 H_out = (H - ksize[1] + 2 * paddings[1]) // strides[1] + 1 W_out = (W - ksize[2] + 2 * paddings[2]) // strides[2] + 1 + + alpha_height = 0 + alpha_width = 0 + alpha_depth = 0 + u_height = 0 + u_width = 0 + u_depth = 0 + input_depth = D + output_depth = D_out + input_height = H + output_height = H_out + input_width = W + output_width = W_out + if fractional: + np.random.seed(2023) + u = np.random.uniform() + + alpha_depth = input_depth / output_depth + alpha_height = input_height / output_height + alpha_width = input_width / output_width + + u_depth = fractional_rational_u( + u, alpha_depth, input_depth, output_depth + ) + u_height = fractional_rational_u( + u, alpha_height, input_height, output_height + ) + u_width = fractional_rational_u( + u, alpha_width, input_width, output_width + ) + out = np.zeros((N, C, D_out, H_out, W_out)) mask = np.zeros((N, C, D_out, H_out, W_out)) for k in range(D_out): if adaptive: d_start = adaptive_start_index(k, D, ksize[0]) d_end = adaptive_end_index(k, D, ksize[0]) + elif fractional: + d_start = fractional_start_index(k, alpha_depth, u_depth) + d_end = fractional_end_index(k, alpha_depth, u_depth) + d_start = max(d_start, 0) + d_end = min(d_end, input_depth) else: d_start = np.max((k * strides[0] - paddings[0], 0)) d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D)) @@ -62,6 +122,11 @@ def max_pool3D_forward_naive( if adaptive: h_start = adaptive_start_index(i, H, ksize[1]) h_end = adaptive_end_index(i, H, ksize[1]) + elif fractional: + h_start = fractional_start_index(i, alpha_height, u_height) + h_end = fractional_end_index(i, alpha_height, u_height) + h_start = max(h_start, 0) + h_end = min(h_end, input_height) else: h_start = np.max((i * strides[1] - paddings[1], 0)) h_end = np.min((i * strides[1] + ksize[1] - paddings[1], H)) @@ -69,6 +134,11 @@ def max_pool3D_forward_naive( if adaptive: w_start = adaptive_start_index(j, W, ksize[2]) w_end = adaptive_end_index(j, W, ksize[2]) + elif fractional: + w_start = fractional_start_index(j, alpha_width, u_width) + w_end = fractional_end_index(j, alpha_width, 
u_width) + w_start = max(w_start, 0) + w_end = min(w_end, input_width) else: w_start = np.max((j * strides[2] - paddings[2], 0)) w_end = np.min((j * strides[2] + ksize[2] - paddings[2], W)) @@ -94,18 +164,47 @@ def max_pool3D_forward_naive( def max_pool2D_forward_naive( - x, ksize, strides, paddings, global_pool=False, adaptive=False + x, + ksize, + strides, + paddings, + global_pool=False, + adaptive=False, + fractional=False, ): N, C, H, W = x.shape if global_pool: ksize = [H, W] paddings = [0, 0] - if adaptive: + if adaptive or fractional: H_out, W_out = ksize else: H_out = (H - ksize[0] + 2 * paddings[0]) // strides[0] + 1 W_out = (W - ksize[1] + 2 * paddings[1]) // strides[1] + 1 + + alpha_height = 0 + alpha_width = 0 + u_height = 0 + u_width = 0 + input_height = H + output_height = H_out + input_width = W + output_width = W_out + if fractional: + np.random.seed(2023) + u = np.random.uniform() + + alpha_height = input_height / output_height + alpha_width = input_width / output_width + + u_height = fractional_rational_u( + u, alpha_height, input_height, output_height + ) + u_width = fractional_rational_u( + u, alpha_width, input_width, output_width + ) + out = np.zeros((N, C, H_out, W_out)) mask = np.zeros((N, C, H_out, W_out)) for i in range(H_out): @@ -115,6 +214,16 @@ def max_pool2D_forward_naive( r_end = adaptive_end_index(i, H, ksize[0]) c_start = adaptive_start_index(j, W, ksize[1]) c_end = adaptive_end_index(j, W, ksize[1]) + elif fractional: + r_start = fractional_start_index(i, alpha_height, u_height) + r_end = fractional_end_index(i, alpha_height, u_height) + r_start = max(r_start, 0) + r_end = min(r_end, input_height) + + c_start = fractional_start_index(j, alpha_width, u_width) + c_end = fractional_end_index(j, alpha_width, u_width) + c_start = max(c_start, 0) + c_end = min(c_end, input_width) else: r_start = np.max((i * strides[0] - paddings[0], 0)) r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H)) @@ -155,6 +264,7 @@ def setUp(self): self.init_test_case() self.init_global() self.init_adaptive() + self.init_fractional() self.init_dtype() if self.is_bfloat16_op(): @@ -174,6 +284,7 @@ def setUp(self): self.paddings, self.global_pool, self.adaptive, + self.fractional, ) mask = mask.astype("int32") if self.is_bfloat16_op(): @@ -187,6 +298,7 @@ def setUp(self): 'ksize': self.ksize, 'global_pooling': self.global_pool, 'adaptive': self.adaptive, + 'fractional': self.fractional, } if self.is_bfloat16_op(): @@ -225,6 +337,9 @@ def init_global(self): def init_adaptive(self): self.adaptive = False + def init_fractional(self): + self.fractional = False + class TestCase1(TestMaxPoolWithIndex_Op): def init_global(self): @@ -255,6 +370,11 @@ def init_adaptive(self): self.adaptive = True +class TestCastFractional3d(TestMaxPoolWithIndex_Op): + def init_fractional(self): + self.fractional = True + + # ----------------max_pool3d_with_index_fp16---------------- def create_test_fp16_class(parent): @unittest.skipIf( @@ -285,6 +405,7 @@ def test_check_grad(self): create_test_fp16_class(TestCase2) create_test_fp16_class(TestCase3) create_test_fp16_class(TestCastAdaptive3d) +create_test_fp16_class(TestCastFractional3d) # ----------------max_pool3d_with_index_bf16---------------- @@ -333,6 +454,7 @@ def test_check_grad(self): create_test_bf16_class(TestCase2) create_test_bf16_class(TestCase3) create_test_bf16_class(TestCastAdaptive3d) +create_test_bf16_class(TestCastFractional3d) # ----------------max_pool2d_with_index---------------- @@ -393,6 +515,11 @@ def init_adaptive(self): 
self.adaptive = True +class TestCastFractional2d(TestCase6): + def init_fractional(self): + self.fractional = True + + # ----------------max_pool2d_with_index_fp16---------------- def create_test_fp16_class(parent): @unittest.skipIf( @@ -423,6 +550,7 @@ def test_check_grad(self): create_test_fp16_class(TestCase6) create_test_fp16_class(TestCase7) create_test_fp16_class(TestCastAdaptive2d) +create_test_fp16_class(TestCastFractional2d) # ----------------max_pool2d_with_index_bf16---------------- @@ -469,6 +597,7 @@ def test_check_grad(self): create_test_bf16_class(TestCase6) create_test_bf16_class(TestCase7) create_test_bf16_class(TestCastAdaptive2d) +create_test_bf16_class(TestCastFractional2d) if __name__ == '__main__': From bf7a64071a10c8a361ccf01144b713714500d68b Mon Sep 17 00:00:00 2001 From: megemini Date: Mon, 20 Nov 2023 13:48:31 +0800 Subject: [PATCH 10/18] [Fix] test index cast to int --- paddle/phi/kernels/funcs/pooling.h | 8 ++++---- test/legacy_test/test_pool_max_op.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/paddle/phi/kernels/funcs/pooling.h b/paddle/phi/kernels/funcs/pooling.h index 00e6b23dffc03..53f0e274d99bb 100644 --- a/paddle/phi/kernels/funcs/pooling.h +++ b/paddle/phi/kernels/funcs/pooling.h @@ -117,12 +117,12 @@ HOSTDEVICE inline float FractionalRationalU(float u, return u * max_u; } -HOSTDEVICE inline int FractionalStartIndex(int ph, float alpha, float u) { - return static_cast(ceil(alpha * (ph + u) - 1)); +HOSTDEVICE inline int FractionalStartIndex(int idx, float alpha, float u) { + return static_cast(ceil(alpha * (idx + u) - 1)); } -HOSTDEVICE inline int FractionalEndIndex(int ph, float alpha, float u) { - return static_cast(ceil(alpha * (ph + 1 + u) - 1)); +HOSTDEVICE inline int FractionalEndIndex(int idx, float alpha, float u) { + return static_cast(ceil(alpha * (idx + 1 + u) - 1)); } /* diff --git a/test/legacy_test/test_pool_max_op.py b/test/legacy_test/test_pool_max_op.py index 2285a4a305946..a833b090d449f 100644 --- a/test/legacy_test/test_pool_max_op.py +++ b/test/legacy_test/test_pool_max_op.py @@ -45,12 +45,12 @@ def fractional_rational_u(u, alpha, input, output): return u * max_u -def fractional_start_index(ph, alpha, u): - return np.ceil(alpha * (ph + u) - 1) +def fractional_start_index(idx, alpha, u): + return int(np.ceil(alpha * (idx + u) - 1)) -def fractional_end_index(ph, alpha, u): - return np.ceil(alpha * (ph + 1 + u) - 1) +def fractional_end_index(idx, alpha, u): + return int(np.ceil(alpha * (idx + 1 + u) - 1)) def max_pool3D_forward_naive( From 44f44c04c0cac6a752628cdba36a3447649b547f Mon Sep 17 00:00:00 2001 From: megemini Date: Mon, 20 Nov 2023 19:00:28 +0800 Subject: [PATCH 11/18] [Change] pooling cu rand --- paddle/phi/kernels/funcs/pooling.cu | 84 +++++++++++++++++++++++++---- 1 file changed, 75 insertions(+), 9 deletions(-) diff --git a/paddle/phi/kernels/funcs/pooling.cu b/paddle/phi/kernels/funcs/pooling.cu index 08f8203c894b6..aa5e6869fc482 100644 --- a/paddle/phi/kernels/funcs/pooling.cu +++ b/paddle/phi/kernels/funcs/pooling.cu @@ -14,6 +14,12 @@ limitations under the License. */ #include #include +#ifdef __NVCC__ +#include +#endif +#ifdef __HIPCC__ +#include +#endif #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" @@ -21,6 +27,9 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/funcs/reduce_function.h" #include "paddle/phi/kernels/primitive/datamover_primitives.h" +#include "paddle/phi/kernels/funcs/distribution_helper.h" +#include "paddle/phi/kernels/funcs/random.cuh" + namespace phi { namespace funcs { @@ -1934,9 +1943,28 @@ __global__ void KernelMaxPool2dWithIdx(const int nthreads, float alpha_height = 0, alpha_width = 0, alpha_depth = 0; float u_height = 0, u_width = 0, u_depth = 0; if (fractional) { - std::uniform_real_distribution dist(0, 1); - auto engine = phi::GetCPURandomEngine(0); - float u = dist(*engine); + // std::uniform_real_distribution dist(0, 1); + // auto engine = phi::GetCPURandomEngine(0); + // float u = dist(*engine); + + // auto gen_cuda = phi::DefaultCUDAGenerator(0); + // int seed = static_cast(gen_cuda->Random64()); + int seed = 0; + + size_t thread_idx = + static_cast(blockIdx.x * blockDim.x + threadIdx.x); + +#if defined(__NVCC__) + curandStatePhilox4_32_10_t state; + curand_init(seed, thread_idx, 0, &state); +#else + hiprandStatePhilox4_32_10_t state; + hiprand_init(seed, thread_idx, offset, &state); +#endif + + phi::funcs::uniform_distribution dist; + float4 rand = dist(&state); + float u = (&rand.x)[0]; alpha_height = static_cast(input_height) / output_height; alpha_width = static_cast(input_width) / output_width; @@ -2080,9 +2108,28 @@ __global__ void KernelMaxPool2DWithIdxGrad(const int nthreads, float alpha_height = 0, alpha_width = 0, alpha_depth = 0; float u_height = 0, u_width = 0, u_depth = 0; if (fractional) { - std::uniform_real_distribution dist(0, 1); - auto engine = phi::GetCPURandomEngine(0); - float u = dist(*engine); + // std::uniform_real_distribution dist(0, 1); + // auto engine = phi::GetCPURandomEngine(0); + // float u = dist(*engine); + + // auto gen_cuda = phi::DefaultCUDAGenerator(0); + // int seed = static_cast(gen_cuda->Random64()); + int seed = 0; + + size_t thread_idx = + static_cast(blockIdx.x * blockDim.x + threadIdx.x); + +#if defined(__NVCC__) + curandStatePhilox4_32_10_t state; + curand_init(seed, thread_idx, 0, &state); +#else + hiprandStatePhilox4_32_10_t state; + hiprand_init(seed, thread_idx, offset, &state); +#endif + + phi::funcs::uniform_distribution dist; + float4 rand = dist(&state); + float u = (&rand.x)[0]; alpha_height = static_cast(input_height) / output_height; alpha_width = static_cast(input_width) / output_width; @@ -2351,9 +2398,28 @@ __global__ void KernelMaxPool3DWithIdx(const int ncd, float alpha_height = 0, alpha_width = 0, alpha_depth = 0; float u_height = 0, u_width = 0, u_depth = 0; if (fractional) { - std::uniform_real_distribution dist(0, 1); - auto engine = phi::GetCPURandomEngine(0); - float u = dist(*engine); + // std::uniform_real_distribution dist(0, 1); + // auto engine = phi::GetCPURandomEngine(0); + // float u = dist(*engine); + + // auto gen_cuda = phi::DefaultCUDAGenerator(0); + // int seed = static_cast(gen_cuda->Random64()); + int seed = 0; + + size_t thread_idx = + static_cast(blockIdx.x * blockDim.x + threadIdx.x); + +#if defined(__NVCC__) + curandStatePhilox4_32_10_t state; + curand_init(seed, thread_idx, 0, &state); +#else + hiprandStatePhilox4_32_10_t state; + hiprand_init(seed, thread_idx, offset, &state); +#endif + + phi::funcs::uniform_distribution dist; + float4 rand = dist(&state); + float u = (&rand.x)[0]; alpha_depth = static_cast(input_depth) / output_depth; alpha_height = static_cast(input_height) / output_height; From e0679f07e76443c50f983a561addb665373da386 Mon Sep 17 00:00:00 2001 From: megemini Date: Tue, 21 Nov 
2023 12:28:59 +0800 Subject: [PATCH 12/18] [Change] skip fractional op test --- test/legacy_test/test_pool_max_op.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/test/legacy_test/test_pool_max_op.py b/test/legacy_test/test_pool_max_op.py index a833b090d449f..c058b1361b0d9 100644 --- a/test/legacy_test/test_pool_max_op.py +++ b/test/legacy_test/test_pool_max_op.py @@ -370,9 +370,9 @@ def init_adaptive(self): self.adaptive = True -class TestCastFractional3d(TestMaxPoolWithIndex_Op): - def init_fractional(self): - self.fractional = True +# class TestCastFractional3d(TestMaxPoolWithIndex_Op): +# def init_fractional(self): +# self.fractional = True # ----------------max_pool3d_with_index_fp16---------------- @@ -405,7 +405,7 @@ def test_check_grad(self): create_test_fp16_class(TestCase2) create_test_fp16_class(TestCase3) create_test_fp16_class(TestCastAdaptive3d) -create_test_fp16_class(TestCastFractional3d) +# create_test_fp16_class(TestCastFractional3d) # ----------------max_pool3d_with_index_bf16---------------- @@ -454,7 +454,7 @@ def test_check_grad(self): create_test_bf16_class(TestCase2) create_test_bf16_class(TestCase3) create_test_bf16_class(TestCastAdaptive3d) -create_test_bf16_class(TestCastFractional3d) +# create_test_bf16_class(TestCastFractional3d) # ----------------max_pool2d_with_index---------------- @@ -515,9 +515,9 @@ def init_adaptive(self): self.adaptive = True -class TestCastFractional2d(TestCase6): - def init_fractional(self): - self.fractional = True +# class TestCastFractional2d(TestCase6): +# def init_fractional(self): +# self.fractional = True # ----------------max_pool2d_with_index_fp16---------------- @@ -550,7 +550,7 @@ def test_check_grad(self): create_test_fp16_class(TestCase6) create_test_fp16_class(TestCase7) create_test_fp16_class(TestCastAdaptive2d) -create_test_fp16_class(TestCastFractional2d) +# create_test_fp16_class(TestCastFractional2d) # ----------------max_pool2d_with_index_bf16---------------- @@ -597,7 +597,7 @@ def test_check_grad(self): create_test_bf16_class(TestCase6) create_test_bf16_class(TestCase7) create_test_bf16_class(TestCastAdaptive2d) -create_test_bf16_class(TestCastFractional2d) +# create_test_bf16_class(TestCastFractional2d) if __name__ == '__main__': From 99196b22606cfdf351c9c86b978ece25ce275db7 Mon Sep 17 00:00:00 2001 From: megemini Date: Tue, 21 Nov 2023 15:07:25 +0800 Subject: [PATCH 13/18] [Update] cu seed --- paddle/phi/kernels/funcs/pooling.cu | 65 +++++++++++++++++------------ 1 file changed, 38 insertions(+), 27 deletions(-) diff --git a/paddle/phi/kernels/funcs/pooling.cu b/paddle/phi/kernels/funcs/pooling.cu index aa5e6869fc482..1880873721780 100644 --- a/paddle/phi/kernels/funcs/pooling.cu +++ b/paddle/phi/kernels/funcs/pooling.cu @@ -1937,26 +1937,20 @@ __global__ void KernelMaxPool2dWithIdx(const int nthreads, const int padding_width, bool adaptive, bool fractional, + uint64_t seed, + uint64_t offset, T1* output_data, T2* mask_data, FastDivModForPooling divmods) { float alpha_height = 0, alpha_width = 0, alpha_depth = 0; float u_height = 0, u_width = 0, u_depth = 0; if (fractional) { - // std::uniform_real_distribution dist(0, 1); - // auto engine = phi::GetCPURandomEngine(0); - // float u = dist(*engine); - - // auto gen_cuda = phi::DefaultCUDAGenerator(0); - // int seed = static_cast(gen_cuda->Random64()); - int seed = 0; - size_t thread_idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); #if defined(__NVCC__) curandStatePhilox4_32_10_t state; - curand_init(seed, 
thread_idx, 0, &state); + curand_init(seed, thread_idx, offset, &state); #else hiprandStatePhilox4_32_10_t state; hiprand_init(seed, thread_idx, offset, &state); @@ -2103,25 +2097,19 @@ __global__ void KernelMaxPool2DWithIdxGrad(const int nthreads, const int padding_width, bool adaptive, bool fractional, + uint64_t seed, + uint64_t offset, T1* input_grad, FastDivModForPooling divmods) { float alpha_height = 0, alpha_width = 0, alpha_depth = 0; float u_height = 0, u_width = 0, u_depth = 0; if (fractional) { - // std::uniform_real_distribution dist(0, 1); - // auto engine = phi::GetCPURandomEngine(0); - // float u = dist(*engine); - - // auto gen_cuda = phi::DefaultCUDAGenerator(0); - // int seed = static_cast(gen_cuda->Random64()); - int seed = 0; - size_t thread_idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); #if defined(__NVCC__) curandStatePhilox4_32_10_t state; - curand_init(seed, thread_idx, 0, &state); + curand_init(seed, thread_idx, offset, &state); #else hiprandStatePhilox4_32_10_t state; hiprand_init(seed, thread_idx, offset, &state); @@ -2272,6 +2260,14 @@ class MaxPool2dWithIndexFunctor { int blocks = (nthreads + thread_num - 1) / thread_num; dim3 threads(thread_num, 1); dim3 grid(blocks, 1); + + // generate seed for fractional pool + auto gen_cuda = context.GetGenerator(); + constexpr int increment_offset = 1 * 4; // one seed with multiple of 4 + auto seed_offset = gen_cuda->IncrementOffset(increment_offset); + uint64_t seed = seed_offset.first; + uint64_t offset = seed_offset.second; + KernelMaxPool2dWithIdx <<>>(nthreads, input_data, @@ -2288,6 +2284,8 @@ class MaxPool2dWithIndexFunctor { padding_width, adaptive, fractional, + seed, + offset, output_data, mask_data, pool_divmods); @@ -2336,6 +2334,14 @@ class MaxPool2dWithIndexGradFunctor { auto pool_divmods = FastDivModForPooling(input_channels, input_width, input_height); + + // generate seed for fractional pool + auto gen_cuda = context.GetGenerator(); + constexpr int increment_offset = 1 * 4; // one seed with multiple of 4 + auto seed_offset = gen_cuda->IncrementOffset(increment_offset); + uint64_t seed = seed_offset.first; + uint64_t offset = seed_offset.second; + KernelMaxPool2DWithIdxGrad <<>>(nthreads, output_grad_data, @@ -2353,6 +2359,8 @@ class MaxPool2dWithIndexGradFunctor { padding_width, adaptive, fractional, + seed, + offset, input_grad_data, pool_divmods); } @@ -2392,26 +2400,20 @@ __global__ void KernelMaxPool3DWithIdx(const int ncd, const int padding_width, bool adaptive, bool fractional, + uint64_t seed, + uint64_t offset, T1* output_data, T2* mask_data, FastDivModForPooling3D divmods_output) { float alpha_height = 0, alpha_width = 0, alpha_depth = 0; float u_height = 0, u_width = 0, u_depth = 0; if (fractional) { - // std::uniform_real_distribution dist(0, 1); - // auto engine = phi::GetCPURandomEngine(0); - // float u = dist(*engine); - - // auto gen_cuda = phi::DefaultCUDAGenerator(0); - // int seed = static_cast(gen_cuda->Random64()); - int seed = 0; - size_t thread_idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); #if defined(__NVCC__) curandStatePhilox4_32_10_t state; - curand_init(seed, thread_idx, 0, &state); + curand_init(seed, thread_idx, offset, &state); #else hiprandStatePhilox4_32_10_t state; hiprand_init(seed, thread_idx, offset, &state); @@ -2614,6 +2616,13 @@ class MaxPool3dWithIndexFunctor { auto pool_divmods_output = FastDivModForPooling3D( input_channels, output_width, output_height, output_depth); + // generate seed for fractional pool + auto gen_cuda = 
context.GetGenerator(); + constexpr int increment_offset = 1 * 4; // one seed with multiple of 4 + auto seed_offset = gen_cuda->IncrementOffset(increment_offset); + uint64_t seed = seed_offset.first; + uint64_t offset = seed_offset.second; + KernelMaxPool3DWithIdx <<>>(ncd, input_data, @@ -2635,6 +2644,8 @@ class MaxPool3dWithIndexFunctor { padding_width, adaptive, fractional, + seed, + offset, output_data, mask_data, pool_divmods_output); From 89297c81d28563a7ebd36bea3edda02f06787824 Mon Sep 17 00:00:00 2001 From: megemini Date: Tue, 21 Nov 2023 18:15:45 +0800 Subject: [PATCH 14/18] [Add] add param random_u for solid fractional --- paddle/phi/api/yaml/backward.yaml | 8 +- paddle/phi/api/yaml/ops.yaml | 4 +- paddle/phi/infermeta/backward.cc | 1 + paddle/phi/infermeta/backward.h | 1 + paddle/phi/infermeta/unary.cc | 1 + paddle/phi/infermeta/unary.h | 1 + paddle/phi/kernels/funcs/pooling.cc | 26 +++- paddle/phi/kernels/funcs/pooling.cu | 135 +++++++++++------- paddle/phi/kernels/funcs/pooling.h | 4 + .../phi/kernels/impl/pool_grad_kernel_impl.h | 7 + paddle/phi/kernels/impl/pool_kernel_impl.h | 7 + paddle/phi/kernels/pool_grad_kernel.h | 2 + paddle/phi/kernels/pool_kernel.h | 2 + python/paddle/nn/functional/pooling.py | 63 ++++++-- python/paddle/nn/layer/pooling.py | 12 +- test/legacy_test/test_pool_max_op.py | 53 ++++--- 16 files changed, 236 insertions(+), 91 deletions(-) diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 0c5df6c0bef81..cf32813d55086 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -1480,8 +1480,8 @@ func : matrix_power_grad - backward_op : max_pool2d_with_index_grad - forward : max_pool2d_with_index(Tensor x, int[] kernel_size, int[] strides = {1, 1}, int[] paddings = {0, 0}, bool global_pooling = false, bool adaptive = false, bool fractional = false) -> Tensor(out), Tensor(mask) - args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive, bool fractional) + forward : max_pool2d_with_index(Tensor x, int[] kernel_size, int[] strides = {1, 1}, int[] paddings = {0, 0}, bool global_pooling = false, bool adaptive = false, bool fractional = false, float random_u = 0.0) -> Tensor(out), Tensor(mask) + args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive, bool fractional, float random_u) output : Tensor(x_grad) infer_meta : func : MaxPoolWithIndexGradInferMeta @@ -1489,8 +1489,8 @@ func : max_pool2d_with_index_grad - backward_op : max_pool3d_with_index_grad - forward : max_pool3d_with_index(Tensor x, int[] kernel_size, int[] strides = {1, 1, 1}, int[] paddings = {0, 0, 0}, bool global_pooling = false, bool adaptive = false, bool fractional = false) -> Tensor(out), Tensor(mask) - args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive, bool fractional) + forward : max_pool3d_with_index(Tensor x, int[] kernel_size, int[] strides = {1, 1, 1}, int[] paddings = {0, 0, 0}, bool global_pooling = false, bool adaptive = false, bool fractional = false, float random_u = 0.0) -> Tensor(out), Tensor(mask) + args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive, bool fractional, float random_u) output : Tensor(x_grad) infer_meta : func : MaxPoolWithIndexGradInferMeta diff --git a/paddle/phi/api/yaml/ops.yaml 
b/paddle/phi/api/yaml/ops.yaml index 3ada51318570a..6b872d59dfc3c 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1684,7 +1684,7 @@ backward : matrix_power_grad - op : max_pool2d_with_index - args : (Tensor x, int[] kernel_size, int[] strides= {1, 1}, int[] paddings = {0, 0}, bool global_pooling = false, bool adaptive = false, bool fractional = false) + args : (Tensor x, int[] kernel_size, int[] strides= {1, 1}, int[] paddings = {0, 0}, bool global_pooling = false, bool adaptive = false, bool fractional = false, float random_u = 0.0) output : Tensor(out), Tensor(mask) infer_meta : func : MaxPoolWithIndexInferMeta @@ -1693,7 +1693,7 @@ backward : max_pool2d_with_index_grad - op : max_pool3d_with_index - args : (Tensor x, int[] kernel_size, int[] strides = {1, 1, 1}, int[] paddings = {0, 0, 0}, bool global_pooling = false, bool adaptive = false, bool fractional = false) + args : (Tensor x, int[] kernel_size, int[] strides = {1, 1, 1}, int[] paddings = {0, 0, 0}, bool global_pooling = false, bool adaptive = false, bool fractional = false, float random_u = 0.0) output : Tensor(out), Tensor(mask) infer_meta : func : MaxPoolWithIndexInferMeta diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 481b49bdb482d..b2eff914a0fe9 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -656,6 +656,7 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x, bool global_pooling, bool adaptive, bool fractional, + float random_u, MetaTensor* dx) { dx->share_meta(x); } diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 612af367e944c..c96cbe9026a42 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -309,6 +309,7 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x, bool global_pooling, bool adaptive, bool fractional, + float random_u, MetaTensor* dx); void MeshgridGradInferMeta(const std::vector& inputs, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 53f3e17f8ea20..4a088ee860f34 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2230,6 +2230,7 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x, bool global_pooling, bool adaptive, bool fractional, + float random_u, MetaTensor* out, MetaTensor* mask, MetaConfig config) { diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 0898660e1e1d6..d3981325edc64 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -336,6 +336,7 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x, bool global_pooling, bool adaptive, bool fractional, + float random_u, MetaTensor* out, MetaTensor* mask, MetaConfig config = MetaConfig()); diff --git a/paddle/phi/kernels/funcs/pooling.cc b/paddle/phi/kernels/funcs/pooling.cc index 0e25ae82130e0..851b74314fff6 100644 --- a/paddle/phi/kernels/funcs/pooling.cc +++ b/paddle/phi/kernels/funcs/pooling.cc @@ -1572,6 +1572,7 @@ class MaxPool2dWithIndexFunctor { const std::vector& paddings, bool adaptive, bool fractional, + float random_u, DenseTensor* output, DenseTensor* mask) { const int batch_size = static_cast(input.dims()[0]); @@ -1596,9 +1597,14 @@ class MaxPool2dWithIndexFunctor { float alpha_height = 0, alpha_width = 0; float u_height = 0, u_width = 0; if (fractional) { - std::uniform_real_distribution dist(0, 1); - auto engine = phi::GetCPURandomEngine(0); - float u = dist(*engine); + float u = 0; + if (random_u == 0) { + 
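        // (review sketch, not patch content) random_u == 0 selects the
        // stochastic path: u is drawn once from the CPU engine below.
        // Any caller-supplied value in (0, 1) bypasses the RNG, so tests
        // can pin random_u and reproduce one fixed pooling layout.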
std::uniform_real_distribution dist(0, 1); + auto engine = phi::GetCPURandomEngine(0); + u = dist(*engine); + } else { + u = random_u; + } alpha_height = static_cast(input_height) / output_height; alpha_width = static_cast(input_width) / output_width; @@ -1680,6 +1686,7 @@ class MaxPool2dWithIndexGradFunctor { const std::vector& paddings UNUSED, bool adaptive UNUSED, bool fractional UNUSED, + float random_u UNUSED, DenseTensor* input_grad) { const int batch_size = static_cast(input_grad->dims()[0]); const int input_height = static_cast(input_grad->dims()[2]); @@ -1732,6 +1739,7 @@ class MaxPool3dWithIndexFunctor { const std::vector& paddings, bool adaptive, bool fractional, + float random_u, DenseTensor* output, DenseTensor* mask) { const int batch_size = static_cast(input.dims()[0]); @@ -1761,9 +1769,14 @@ class MaxPool3dWithIndexFunctor { float alpha_height = 0, alpha_width = 0, alpha_depth = 0; float u_height = 0, u_width = 0, u_depth = 0; if (fractional) { - std::uniform_real_distribution dist(0, 1); - auto engine = phi::GetCPURandomEngine(0); - float u = dist(*engine); + float u = 0; + if (fractional) { + std::uniform_real_distribution dist(0, 1); + auto engine = phi::GetCPURandomEngine(0); + u = dist(*engine); + } else { + u = random_u; + } alpha_depth = static_cast(input_depth) / output_depth; alpha_height = static_cast(input_height) / output_height; @@ -1867,6 +1880,7 @@ class MaxPool3dWithIndexGradFunctor { const std::vector& paddings UNUSED, bool adaptive UNUSED, bool fractional UNUSED, + float random_u UNUSED, DenseTensor* input_grad) { const int batch_size = static_cast(input_grad->dims()[0]); const int input_depth = static_cast(input_grad->dims()[2]); diff --git a/paddle/phi/kernels/funcs/pooling.cu b/paddle/phi/kernels/funcs/pooling.cu index 1880873721780..b50de21893c2f 100644 --- a/paddle/phi/kernels/funcs/pooling.cu +++ b/paddle/phi/kernels/funcs/pooling.cu @@ -1937,6 +1937,7 @@ __global__ void KernelMaxPool2dWithIdx(const int nthreads, const int padding_width, bool adaptive, bool fractional, + float random_u, uint64_t seed, uint64_t offset, T1* output_data, @@ -1945,20 +1946,23 @@ __global__ void KernelMaxPool2dWithIdx(const int nthreads, float alpha_height = 0, alpha_width = 0, alpha_depth = 0; float u_height = 0, u_width = 0, u_depth = 0; if (fractional) { - size_t thread_idx = - static_cast(blockIdx.x * blockDim.x + threadIdx.x); - + float u = 0; + if (random_u == 0) { + size_t thread_idx = + static_cast(blockIdx.x * blockDim.x + threadIdx.x); #if defined(__NVCC__) - curandStatePhilox4_32_10_t state; - curand_init(seed, thread_idx, offset, &state); + curandStatePhilox4_32_10_t state; + curand_init(seed, thread_idx, offset, &state); #else - hiprandStatePhilox4_32_10_t state; - hiprand_init(seed, thread_idx, offset, &state); + hiprandStatePhilox4_32_10_t state; + hiprand_init(seed, thread_idx, offset, &state); #endif - - phi::funcs::uniform_distribution dist; - float4 rand = dist(&state); - float u = (&rand.x)[0]; + phi::funcs::uniform_distribution dist; + float4 rand = dist(&state); + u = (&rand.x)[0]; + } else { + u = random_u; + } alpha_height = static_cast(input_height) / output_height; alpha_width = static_cast(input_width) / output_width; @@ -2097,6 +2101,7 @@ __global__ void KernelMaxPool2DWithIdxGrad(const int nthreads, const int padding_width, bool adaptive, bool fractional, + float random_u, uint64_t seed, uint64_t offset, T1* input_grad, @@ -2104,20 +2109,23 @@ __global__ void KernelMaxPool2DWithIdxGrad(const int nthreads, float alpha_height = 0, alpha_width = 
0, alpha_depth = 0; float u_height = 0, u_width = 0, u_depth = 0; if (fractional) { - size_t thread_idx = - static_cast(blockIdx.x * blockDim.x + threadIdx.x); - + float u = 0; + if (random_u == 0) { + size_t thread_idx = + static_cast(blockIdx.x * blockDim.x + threadIdx.x); #if defined(__NVCC__) - curandStatePhilox4_32_10_t state; - curand_init(seed, thread_idx, offset, &state); + curandStatePhilox4_32_10_t state; + curand_init(seed, thread_idx, offset, &state); #else - hiprandStatePhilox4_32_10_t state; - hiprand_init(seed, thread_idx, offset, &state); + hiprandStatePhilox4_32_10_t state; + hiprand_init(seed, thread_idx, offset, &state); #endif - - phi::funcs::uniform_distribution dist; - float4 rand = dist(&state); - float u = (&rand.x)[0]; + phi::funcs::uniform_distribution dist; + float4 rand = dist(&state); + u = (&rand.x)[0]; + } else { + u = random_u; + } alpha_height = static_cast(input_height) / output_height; alpha_width = static_cast(input_width) / output_width; @@ -2203,6 +2211,7 @@ class MaxPool2dWithIndexFunctor { const std::vector& paddings, bool adaptive, bool fractional, + float random_u, DenseTensor* output, DenseTensor* mask) { const int batch_size = input.dims()[0]; @@ -2261,12 +2270,16 @@ class MaxPool2dWithIndexFunctor { dim3 threads(thread_num, 1); dim3 grid(blocks, 1); - // generate seed for fractional pool - auto gen_cuda = context.GetGenerator(); - constexpr int increment_offset = 1 * 4; // one seed with multiple of 4 - auto seed_offset = gen_cuda->IncrementOffset(increment_offset); - uint64_t seed = seed_offset.first; - uint64_t offset = seed_offset.second; + uint64_t seed = 0; + uint64_t offset = 0; + if (fractional) { + // generate seed for fractional pool + auto gen_cuda = context.GetGenerator(); + constexpr int increment_offset = 1 * 4; // one seed with multiple of 4 + auto seed_offset = gen_cuda->IncrementOffset(increment_offset); + seed = seed_offset.first; + offset = seed_offset.second; + } KernelMaxPool2dWithIdx <<>>(nthreads, @@ -2284,6 +2297,7 @@ class MaxPool2dWithIndexFunctor { padding_width, adaptive, fractional, + random_u, seed, offset, output_data, @@ -2309,6 +2323,7 @@ class MaxPool2dWithIndexGradFunctor { const std::vector& paddings, bool adaptive, bool fractional, + float random_u, DenseTensor* input_grad) { const int batch_size = input_grad->dims()[0]; const int input_channels = input_grad->dims()[1]; @@ -2335,12 +2350,16 @@ class MaxPool2dWithIndexGradFunctor { auto pool_divmods = FastDivModForPooling(input_channels, input_width, input_height); - // generate seed for fractional pool - auto gen_cuda = context.GetGenerator(); - constexpr int increment_offset = 1 * 4; // one seed with multiple of 4 - auto seed_offset = gen_cuda->IncrementOffset(increment_offset); - uint64_t seed = seed_offset.first; - uint64_t offset = seed_offset.second; + uint64_t seed = 0; + uint64_t offset = 0; + if (fractional) { + // generate seed for fractional pool + auto gen_cuda = context.GetGenerator(); + constexpr int increment_offset = 1 * 4; // one seed with multiple of 4 + auto seed_offset = gen_cuda->IncrementOffset(increment_offset); + seed = seed_offset.first; + offset = seed_offset.second; + } KernelMaxPool2DWithIdxGrad <<>>(nthreads, @@ -2359,6 +2378,7 @@ class MaxPool2dWithIndexGradFunctor { padding_width, adaptive, fractional, + random_u, seed, offset, input_grad_data, @@ -2400,6 +2420,7 @@ __global__ void KernelMaxPool3DWithIdx(const int ncd, const int padding_width, bool adaptive, bool fractional, + float random_u, uint64_t seed, uint64_t offset, T1* 
output_data, @@ -2408,20 +2429,23 @@ __global__ void KernelMaxPool3DWithIdx(const int ncd, float alpha_height = 0, alpha_width = 0, alpha_depth = 0; float u_height = 0, u_width = 0, u_depth = 0; if (fractional) { - size_t thread_idx = - static_cast(blockIdx.x * blockDim.x + threadIdx.x); - + float u = 0; + if (random_u == 0) { + size_t thread_idx = + static_cast(blockIdx.x * blockDim.x + threadIdx.x); #if defined(__NVCC__) - curandStatePhilox4_32_10_t state; - curand_init(seed, thread_idx, offset, &state); + curandStatePhilox4_32_10_t state; + curand_init(seed, thread_idx, offset, &state); #else - hiprandStatePhilox4_32_10_t state; - hiprand_init(seed, thread_idx, offset, &state); + hiprandStatePhilox4_32_10_t state; + hiprand_init(seed, thread_idx, offset, &state); #endif - - phi::funcs::uniform_distribution dist; - float4 rand = dist(&state); - float u = (&rand.x)[0]; + phi::funcs::uniform_distribution dist; + float4 rand = dist(&state); + u = (&rand.x)[0]; + } else { + u = random_u; + } alpha_depth = static_cast(input_depth) / output_depth; alpha_height = static_cast(input_height) / output_height; @@ -2531,6 +2555,7 @@ __global__ void KernelMaxPool3DWithIdxGrad( const int padding_width, bool adaptive, bool fractional, + float random_u, T1* input_grad, FastDivModForPooling3D divmods_output) { int w_offset, h_offset, d_offset, nc_offset; @@ -2574,6 +2599,7 @@ class MaxPool3dWithIndexFunctor { const std::vector& paddings, bool adaptive, bool fractional, + float random_u, DenseTensor* output, DenseTensor* mask) { const int batch_size = input.dims()[0]; @@ -2616,12 +2642,16 @@ class MaxPool3dWithIndexFunctor { auto pool_divmods_output = FastDivModForPooling3D( input_channels, output_width, output_height, output_depth); - // generate seed for fractional pool - auto gen_cuda = context.GetGenerator(); - constexpr int increment_offset = 1 * 4; // one seed with multiple of 4 - auto seed_offset = gen_cuda->IncrementOffset(increment_offset); - uint64_t seed = seed_offset.first; - uint64_t offset = seed_offset.second; + uint64_t seed = 0; + uint64_t offset = 0; + if (fractional) { + // generate seed for fractional pool + auto gen_cuda = context.GetGenerator(); + constexpr int increment_offset = 1 * 4; // one seed with multiple of 4 + auto seed_offset = gen_cuda->IncrementOffset(increment_offset); + seed = seed_offset.first; + offset = seed_offset.second; + } KernelMaxPool3DWithIdx <<>>(ncd, @@ -2644,6 +2674,7 @@ class MaxPool3dWithIndexFunctor { padding_width, adaptive, fractional, + random_u, seed, offset, output_data, @@ -2668,6 +2699,7 @@ class MaxPool3dWithIndexGradFunctor { const std::vector& paddings, bool adaptive, bool fractional, + float random_u, DenseTensor* input_grad) { const int batch_size = input_grad->dims()[0]; const int input_channels = input_grad->dims()[1]; @@ -2730,6 +2762,7 @@ class MaxPool3dWithIndexGradFunctor { padding_width, adaptive, fractional, + random_u, input_grad_data, pool_divmods_output); } diff --git a/paddle/phi/kernels/funcs/pooling.h b/paddle/phi/kernels/funcs/pooling.h index 53f0e274d99bb..4897e54aa08ac 100644 --- a/paddle/phi/kernels/funcs/pooling.h +++ b/paddle/phi/kernels/funcs/pooling.h @@ -348,6 +348,7 @@ class MaxPool2dWithIndexFunctor { const std::vector& paddings, bool adaptive, bool fractional, + float random_u, DenseTensor* output, DenseTensor* mask); }; @@ -363,6 +364,7 @@ class MaxPool2dWithIndexGradFunctor { const std::vector& paddings, bool adaptive, bool fractional, + float random_u, DenseTensor* input_grad); }; @@ -376,6 +378,7 @@ class 
MaxPool3dWithIndexFunctor { const std::vector& paddings, bool adaptive, bool fractional, + float random_u, DenseTensor* output, DenseTensor* mask); }; @@ -391,6 +394,7 @@ class MaxPool3dWithIndexGradFunctor { const std::vector& paddings, bool adaptive, bool fractional, + float random_u, DenseTensor* input_grad); }; diff --git a/paddle/phi/kernels/impl/pool_grad_kernel_impl.h b/paddle/phi/kernels/impl/pool_grad_kernel_impl.h index ec7edc9892940..8977a656e7318 100644 --- a/paddle/phi/kernels/impl/pool_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/pool_grad_kernel_impl.h @@ -151,6 +151,7 @@ void MaxPoolWithIndexGradRawKernel(const Context& ctx, bool global_pooling, bool adaptive, bool fractional, + float random_u, DenseTensor* dx) { std::vector paddings_ = paddings; std::vector kernel_size_ = kernel_size; @@ -177,6 +178,7 @@ void MaxPoolWithIndexGradRawKernel(const Context& ctx, paddings_, adaptive, fractional, + random_u, dx); } break; case 3: { @@ -189,6 +191,7 @@ void MaxPoolWithIndexGradRawKernel(const Context& ctx, paddings_, adaptive, fractional, + random_u, dx); } break; default: { @@ -278,6 +281,7 @@ void MaxPool2dWithIndexGradKernel(const Context& ctx, bool global_pooling, bool adaptive, bool fractional, + float random_u, DenseTensor* dx) { MaxPoolWithIndexGradRawKernel(ctx, x, @@ -289,6 +293,7 @@ void MaxPool2dWithIndexGradKernel(const Context& ctx, global_pooling, adaptive, fractional, + random_u, dx); } @@ -335,6 +340,7 @@ void MaxPool3dWithIndexGradKernel(const Context& ctx, bool global_pooling, bool adaptive, bool fractional, + float random_u, DenseTensor* dx) { MaxPoolWithIndexGradRawKernel(ctx, x, @@ -346,6 +352,7 @@ void MaxPool3dWithIndexGradKernel(const Context& ctx, global_pooling, adaptive, fractional, + random_u, dx); } diff --git a/paddle/phi/kernels/impl/pool_kernel_impl.h b/paddle/phi/kernels/impl/pool_kernel_impl.h index 8acf268fd8665..7fbe8f44481cb 100644 --- a/paddle/phi/kernels/impl/pool_kernel_impl.h +++ b/paddle/phi/kernels/impl/pool_kernel_impl.h @@ -193,6 +193,7 @@ void MaxPoolWithIndexRawKernel(const Context& ctx, bool global_pooling, bool adaptive, bool fractional, + float random_u, DenseTensor* out, DenseTensor* mask) { std::vector paddings_ = paddings; @@ -215,6 +216,7 @@ void MaxPoolWithIndexRawKernel(const Context& ctx, paddings_, adaptive, fractional, + random_u, out, mask); } break; @@ -227,6 +229,7 @@ void MaxPoolWithIndexRawKernel(const Context& ctx, paddings_, adaptive, fractional, + random_u, out, mask); } break; @@ -276,6 +279,7 @@ void MaxPool2dWithIndexKernel(const Context& ctx, bool global_pooling, bool adaptive, bool fractional, + float random_u, DenseTensor* out, DenseTensor* mask) { MaxPoolWithIndexRawKernel(ctx, @@ -286,6 +290,7 @@ void MaxPool2dWithIndexKernel(const Context& ctx, global_pooling, adaptive, fractional, + random_u, out, mask); } @@ -327,6 +332,7 @@ void MaxPool3dWithIndexKernel(const Context& ctx, bool global_pooling, bool adaptive, bool fractional, + float random_u, DenseTensor* out, DenseTensor* mask) { MaxPoolWithIndexRawKernel(ctx, @@ -337,6 +343,7 @@ void MaxPool3dWithIndexKernel(const Context& ctx, global_pooling, adaptive, fractional, + random_u, out, mask); } diff --git a/paddle/phi/kernels/pool_grad_kernel.h b/paddle/phi/kernels/pool_grad_kernel.h index 5a87ede3f3737..96349d663c0c6 100644 --- a/paddle/phi/kernels/pool_grad_kernel.h +++ b/paddle/phi/kernels/pool_grad_kernel.h @@ -97,6 +97,7 @@ void MaxPool2dWithIndexGradKernel(const Context& ctx, bool global_pooling, bool adaptive, bool fracional, + float random_u, 
DenseTensor* dx); template @@ -144,6 +145,7 @@ void MaxPool3dWithIndexGradKernel(const Context& ctx, bool global_pooling, bool adaptive, bool fracional, + float random_u, DenseTensor* dx); } // namespace phi diff --git a/paddle/phi/kernels/pool_kernel.h b/paddle/phi/kernels/pool_kernel.h index 49cd0a4955e59..8363890eaeff2 100644 --- a/paddle/phi/kernels/pool_kernel.h +++ b/paddle/phi/kernels/pool_kernel.h @@ -61,6 +61,7 @@ void MaxPool2dWithIndexKernel(const Context& ctx, bool global_pooling, bool adaptive, bool fractional, + float random_u, DenseTensor* out, DenseTensor* mask); @@ -103,6 +104,7 @@ void MaxPool3dWithIndexKernel(const Context& ctx, bool global_pooling, bool adaptive, bool fractional, + float random_u, DenseTensor* out, DenseTensor* mask); diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index f243280f58f6b..e19233839101e 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -632,7 +632,7 @@ def max_pool1d( if in_dygraph_mode(): if return_mask: pool_out = _C_ops.max_pool2d_with_index( - x, kernel_size, stride, padding, False, False, False + x, kernel_size, stride, padding, False, False, False, 0.0 ) return ( (squeeze(pool_out[0], [2]), squeeze(pool_out[1], [2])) @@ -1259,7 +1259,7 @@ def max_pool2d( if in_dynamic_or_pir_mode(): if return_mask: output = _C_ops.max_pool2d_with_index( - x, kernel_size, stride, padding, False, False, False + x, kernel_size, stride, padding, False, False, False, 0.0 ) return output if return_mask else output[0] else: @@ -1426,7 +1426,7 @@ def max_pool3d( if in_dygraph_mode(): if return_mask: output = _C_ops.max_pool3d_with_index( - x, kernel_size, stride, padding, False, False, False + x, kernel_size, stride, padding, False, False, False, 0.0 ) return output if return_mask else output[0] else: @@ -1877,7 +1877,7 @@ def adaptive_max_pool1d(x, output_size, return_mask=False, name=None): x = unsqueeze(x, [2]) if in_dygraph_mode(): pool_out = _C_ops.max_pool2d_with_index( - x, pool_size, [1, 1], [0, 0], False, True, False + x, pool_size, [1, 1], [0, 0], False, True, False, 0.0 ) return ( (squeeze(pool_out[0], [2]), squeeze(pool_out[1], [2])) @@ -1971,7 +1971,7 @@ def adaptive_max_pool2d(x, output_size, return_mask=False, name=None): output_size[1] = in_w if in_dygraph_mode(): pool_out = _C_ops.max_pool2d_with_index( - x, output_size, [1, 1], [0, 0], False, True, False + x, output_size, [1, 1], [0, 0], False, True, False, 0.0 ) return pool_out if return_mask else pool_out[0] else: @@ -2063,7 +2063,7 @@ def adaptive_max_pool3d(x, output_size, return_mask=False, name=None): if in_dygraph_mode(): # By default, strides is [1,1,1] and paddings is [0, 0, 0] pool_out = _C_ops.max_pool3d_with_index( - x, output_size, [1, 1, 1], [0, 0, 0], False, True, False + x, output_size, [1, 1, 1], [0, 0, 0], False, True, False, 0.0 ) return pool_out if return_mask else pool_out[0] else: @@ -2096,12 +2096,22 @@ def adaptive_max_pool3d(x, output_size, return_mask=False, name=None): return (pool_out, mask) if return_mask else pool_out -def fractional_max_pool2d(x, output_size, return_mask=False, name=None): +def fractional_max_pool2d( + x, output_size, random_u=None, return_mask=False, name=None +): """ TODO(megemini) """ _check_input(x, 4) + if random_u is None: + random_u = 0.0 + else: + if random_u <= 0 or random_u >= 1: + raise ValueError( + "The param `random_u` should be a `float` in (0, 1)." 
+ ) + in_h, in_w = x.shape[2:4] if isinstance(output_size, int): output_size = convert_to_list(output_size, 2, 'output_size') @@ -2114,7 +2124,7 @@ def fractional_max_pool2d(x, output_size, return_mask=False, name=None): if in_dygraph_mode(): pool_out = _C_ops.max_pool2d_with_index( - x, output_size, [1, 1], [0, 0], False, False, True + x, output_size, [1, 1], [0, 0], False, False, True, float(random_u) ) return pool_out if return_mask else pool_out[0] else: @@ -2125,6 +2135,13 @@ def fractional_max_pool2d(x, output_size, return_mask=False, name=None): ) check_type(return_mask, 'return_mask', bool, 'fractional_max_pool2d') + check_variable_and_dtype( + random_u, + 'random_u', + ['float32', 'float64'], + 'fractional_max_pool2d', + ) + helper = LayerHelper(l_type, **locals()) dtype = helper.input_dtype(input_param_name='x') pool_out = helper.create_variable_for_type_inference(dtype) @@ -2140,18 +2157,29 @@ def fractional_max_pool2d(x, output_size, return_mask=False, name=None): "pooling_type": 'max', "ksize": output_size, "fractional": True, + "random_u": float(random_u), }, ) return (pool_out, mask) if return_mask else pool_out -def fractional_max_pool3d(x, output_size, return_mask=False, name=None): +def fractional_max_pool3d( + x, output_size, random_u=None, return_mask=False, name=None +): """ TODO(megemini) """ _check_input(x, 5) + if random_u is None: + random_u = 0.0 + else: + if random_u <= 0 or random_u >= 1: + raise ValueError( + "The param `random_u` should be a `float` in (0, 1)." + ) + in_l, in_h, in_w = x.shape[2:5] if isinstance(output_size, int): output_size = convert_to_list(output_size, 3, 'output_size') @@ -2167,7 +2195,14 @@ def fractional_max_pool3d(x, output_size, return_mask=False, name=None): if in_dygraph_mode(): # By default, strides is [1,1,1] and paddings is [0, 0, 0] pool_out = _C_ops.max_pool3d_with_index( - x, output_size, [1, 1, 1], [0, 0, 0], False, False, True + x, + output_size, + [1, 1, 1], + [0, 0, 0], + False, + False, + True, + float(random_u), ) return pool_out if return_mask else pool_out[0] else: @@ -2178,6 +2213,13 @@ def fractional_max_pool3d(x, output_size, return_mask=False, name=None): ) check_type(return_mask, 'return_mask', bool, 'fractional_max_pool3d') + check_variable_and_dtype( + random_u, + 'random_u', + ['float32', 'float64'], + 'fractional_max_pool2d', + ) + helper = LayerHelper(l_type, **locals()) dtype = helper.input_dtype(input_param_name='x') pool_out = helper.create_variable_for_type_inference(dtype) @@ -2193,6 +2235,7 @@ def fractional_max_pool3d(x, output_size, return_mask=False, name=None): "pooling_type": 'max', "ksize": output_size, "fractional": True, + "random_u": float(random_u), }, ) diff --git a/python/paddle/nn/layer/pooling.py b/python/paddle/nn/layer/pooling.py index 8db539fe7f763..f4d07d5b5eb4b 100755 --- a/python/paddle/nn/layer/pooling.py +++ b/python/paddle/nn/layer/pooling.py @@ -1146,9 +1146,12 @@ class FractionalMaxPool2D(Layer): TODO(megemini) """ - def __init__(self, output_size, return_mask=False, name=None): + def __init__( + self, output_size, random_u=None, return_mask=False, name=None + ): super().__init__() self._output_size = output_size + self._random_u = random_u self._return_mask = return_mask self._name = name @@ -1156,6 +1159,7 @@ def forward(self, x): return F.fractional_max_pool2d( x, output_size=self._output_size, + random_u=self._random_u, return_mask=self._return_mask, name=self._name, ) @@ -1171,9 +1175,12 @@ class FractionalMaxPool3D(Layer): TODO(megemini) """ - def __init__(self, 
output_size, return_mask=False, name=None): + def __init__( + self, output_size, random_u=None, return_mask=False, name=None + ): super().__init__() self._output_size = output_size + self._random_u = random_u self._return_mask = return_mask self._name = name @@ -1181,6 +1188,7 @@ def forward(self, x): return F.fractional_max_pool3d( x, output_size=self._output_size, + random_u=self._random_u, return_mask=self._return_mask, name=self._name, ) diff --git a/test/legacy_test/test_pool_max_op.py b/test/legacy_test/test_pool_max_op.py index c058b1361b0d9..d1a2b55bbfd38 100644 --- a/test/legacy_test/test_pool_max_op.py +++ b/test/legacy_test/test_pool_max_op.py @@ -61,6 +61,7 @@ def max_pool3D_forward_naive( global_pool=False, adaptive=False, fractional=False, + random_u=None, ): N, C, D, H, W = x.shape if global_pool: @@ -87,8 +88,7 @@ def max_pool3D_forward_naive( input_width = W output_width = W_out if fractional: - np.random.seed(2023) - u = np.random.uniform() + u = random_u alpha_depth = input_depth / output_depth alpha_height = input_height / output_height @@ -171,6 +171,7 @@ def max_pool2D_forward_naive( global_pool=False, adaptive=False, fractional=False, + random_u=None, ): N, C, H, W = x.shape if global_pool: @@ -192,8 +193,7 @@ def max_pool2D_forward_naive( input_width = W output_width = W_out if fractional: - np.random.seed(2023) - u = np.random.uniform() + u = random_u alpha_height = input_height / output_height alpha_width = input_width / output_width @@ -253,9 +253,17 @@ def max_pool3d_with_index_wapper( global_pooling=False, adaptive=False, fractional=False, + random_u=None, ): return paddle._C_ops.max_pool3d_with_index( - x, kernel_size, strides, paddings, global_pooling, adaptive, fractional + x, + kernel_size, + strides, + paddings, + global_pooling, + adaptive, + fractional, + random_u, ) @@ -285,6 +293,7 @@ def setUp(self): self.global_pool, self.adaptive, self.fractional, + self.random_u, ) mask = mask.astype("int32") if self.is_bfloat16_op(): @@ -299,6 +308,7 @@ def setUp(self): 'global_pooling': self.global_pool, 'adaptive': self.adaptive, 'fractional': self.fractional, + 'random_u': self.random_u, } if self.is_bfloat16_op(): @@ -339,6 +349,7 @@ def init_adaptive(self): def init_fractional(self): self.fractional = False + self.random_u = None class TestCase1(TestMaxPoolWithIndex_Op): @@ -370,9 +381,10 @@ def init_adaptive(self): self.adaptive = True -# class TestCastFractional3d(TestMaxPoolWithIndex_Op): -# def init_fractional(self): -# self.fractional = True +class TestCastFractional3d(TestMaxPoolWithIndex_Op): + def init_fractional(self): + self.fractional = True + self.random_u = 0.3 # ----------------max_pool3d_with_index_fp16---------------- @@ -405,7 +417,7 @@ def test_check_grad(self): create_test_fp16_class(TestCase2) create_test_fp16_class(TestCase3) create_test_fp16_class(TestCastAdaptive3d) -# create_test_fp16_class(TestCastFractional3d) +create_test_fp16_class(TestCastFractional3d) # ----------------max_pool3d_with_index_bf16---------------- @@ -454,7 +466,7 @@ def test_check_grad(self): create_test_bf16_class(TestCase2) create_test_bf16_class(TestCase3) create_test_bf16_class(TestCastAdaptive3d) -# create_test_bf16_class(TestCastFractional3d) +create_test_bf16_class(TestCastFractional3d) # ----------------max_pool2d_with_index---------------- @@ -466,9 +478,17 @@ def max_pool2d_with_index_wapper( global_pooling=False, adaptive=False, fractional=False, + random_u=None, ): return paddle._C_ops.max_pool2d_with_index( - x, kernel_size, strides, paddings, 
global_pooling, adaptive, fractional + x, + kernel_size, + strides, + paddings, + global_pooling, + adaptive, + fractional, + random_u, ) @@ -515,9 +535,10 @@ def init_adaptive(self): self.adaptive = True -# class TestCastFractional2d(TestCase6): -# def init_fractional(self): -# self.fractional = True +class TestCastFractional2d(TestCase6): + def init_fractional(self): + self.fractional = True + self.random_u = 0.3 # ----------------max_pool2d_with_index_fp16---------------- @@ -550,7 +571,7 @@ def test_check_grad(self): create_test_fp16_class(TestCase6) create_test_fp16_class(TestCase7) create_test_fp16_class(TestCastAdaptive2d) -# create_test_fp16_class(TestCastFractional2d) +create_test_fp16_class(TestCastFractional2d) # ----------------max_pool2d_with_index_bf16---------------- @@ -597,7 +618,7 @@ def test_check_grad(self): create_test_bf16_class(TestCase6) create_test_bf16_class(TestCase7) create_test_bf16_class(TestCastAdaptive2d) -# create_test_bf16_class(TestCastFractional2d) +create_test_bf16_class(TestCastFractional2d) if __name__ == '__main__': From ba9d5831c00828efc178fcdc19533d2167874e0a Mon Sep 17 00:00:00 2001 From: megemini Date: Wed, 22 Nov 2023 17:06:50 +0800 Subject: [PATCH 15/18] [Fix] fractional pool 3d random_u --- paddle/phi/kernels/funcs/pooling.cc | 2 +- test/legacy_test/test_pool_max_op.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/funcs/pooling.cc b/paddle/phi/kernels/funcs/pooling.cc index 851b74314fff6..891d2c8eea3e0 100644 --- a/paddle/phi/kernels/funcs/pooling.cc +++ b/paddle/phi/kernels/funcs/pooling.cc @@ -1770,7 +1770,7 @@ class MaxPool3dWithIndexFunctor { float u_height = 0, u_width = 0, u_depth = 0; if (fractional) { float u = 0; - if (fractional) { + if (random_u == 0) { std::uniform_real_distribution dist(0, 1); auto engine = phi::GetCPURandomEngine(0); u = dist(*engine); diff --git a/test/legacy_test/test_pool_max_op.py b/test/legacy_test/test_pool_max_op.py index d1a2b55bbfd38..dbedaef7a023d 100644 --- a/test/legacy_test/test_pool_max_op.py +++ b/test/legacy_test/test_pool_max_op.py @@ -349,7 +349,7 @@ def init_adaptive(self): def init_fractional(self): self.fractional = False - self.random_u = None + self.random_u = 0.3 class TestCase1(TestMaxPoolWithIndex_Op): From 6d6f93b86bc7f60fb2755a53b042fc21a4bf0366 Mon Sep 17 00:00:00 2001 From: megemini Date: Thu, 23 Nov 2023 12:10:18 +0800 Subject: [PATCH 16/18] [Add] add xpu support for fractional max pool 2d with index --- paddle/phi/kernels/xpu/pool_grad_kernel.cc | 2 ++ paddle/phi/kernels/xpu/pool_kernel.cc | 2 ++ 2 files changed, 4 insertions(+) diff --git a/paddle/phi/kernels/xpu/pool_grad_kernel.cc b/paddle/phi/kernels/xpu/pool_grad_kernel.cc index b03be5dd9449c..531079ff1d12d 100644 --- a/paddle/phi/kernels/xpu/pool_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/pool_grad_kernel.cc @@ -386,6 +386,8 @@ void MaxPool2dWithIndexGradKernel(const Context& ctx, const std::vector& paddings_t, bool global_pooling, bool adaptive, + bool fracional, + float random_u, DenseTensor* dx) { using XPUType = typename XPUTypeTrait::Type; diff --git a/paddle/phi/kernels/xpu/pool_kernel.cc b/paddle/phi/kernels/xpu/pool_kernel.cc index 466adade072c7..c6076e20b35da 100644 --- a/paddle/phi/kernels/xpu/pool_kernel.cc +++ b/paddle/phi/kernels/xpu/pool_kernel.cc @@ -308,6 +308,8 @@ void MaxPool2dWithIndexKernel(const Context& ctx, const std::vector& paddings_t, bool global_pooling, bool adaptive, + bool fractional, + float random_u, DenseTensor* out, DenseTensor* mask) { using XPUType 
= typename XPUTypeTrait::Type; From 8ea0e52ad8e2169b90f494dd60ab4d1a0644c2b7 Mon Sep 17 00:00:00 2001 From: megemini Date: Sun, 26 Nov 2023 18:46:25 +0800 Subject: [PATCH 17/18] [Add] fractional api unittest --- python/paddle/nn/functional/pooling.py | 10 +- .../legacy_test/test_fractional_max_pool2d.py | 411 ++++++++++++++++ .../legacy_test/test_fractional_max_pool3d.py | 438 +++++++++++++++++- 3 files changed, 853 insertions(+), 6 deletions(-) diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index e19233839101e..a044a74aa0f46 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -2135,10 +2135,10 @@ def fractional_max_pool2d( ) check_type(return_mask, 'return_mask', bool, 'fractional_max_pool2d') - check_variable_and_dtype( + check_type( random_u, 'random_u', - ['float32', 'float64'], + float, 'fractional_max_pool2d', ) @@ -2213,11 +2213,11 @@ def fractional_max_pool3d( ) check_type(return_mask, 'return_mask', bool, 'fractional_max_pool3d') - check_variable_and_dtype( + check_type( random_u, 'random_u', - ['float32', 'float64'], - 'fractional_max_pool2d', + float, + 'fractional_max_pool3d', ) helper = LayerHelper(l_type, **locals()) diff --git a/test/legacy_test/test_fractional_max_pool2d.py b/test/legacy_test/test_fractional_max_pool2d.py index 595add0aed9e1..d190bd17d3c65 100644 --- a/test/legacy_test/test_fractional_max_pool2d.py +++ b/test/legacy_test/test_fractional_max_pool2d.py @@ -11,3 +11,414 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import unittest + +import numpy as np +from op_test import check_out_dtype + +import paddle +import paddle.nn.functional as F +from paddle import base +from paddle.base import core + + +def fractional_rational_u(u, alpha, input, output): + base = input // output + + u_max1 = (base + 2) / alpha - 1 + u_max2 = (input + 1 - base) / alpha - (output - 1) + max_u = min(u_max1, u_max2) + + return u * max_u + + +def fractional_start_index(idx, alpha, u): + return int(np.ceil(alpha * (idx + u) - 1)) + + +def fractional_end_index(idx, alpha, u): + return int(np.ceil(alpha * (idx + 1 + u) - 1)) + + +def fractional_pool2d_forward( + x, + output_size, + random_u=None, + data_format='NCHW', + pool_type="max", +): + N = x.shape[0] + C, H, W = ( + [x.shape[1], x.shape[2], x.shape[3]] + if data_format == 'NCHW' + else [x.shape[3], x.shape[1], x.shape[2]] + ) + + if isinstance(output_size, int) or output_size is None: + H_out = output_size + W_out = output_size + output_size = [H_out, W_out] + else: + H_out, W_out = output_size + + if output_size[0] is None: + output_size[0] = H + H_out = H + if output_size[1] is None: + output_size[1] = W + W_out = W + + out = ( + np.zeros((N, C, H_out, W_out)) + if data_format == 'NCHW' + else np.zeros((N, H_out, W_out, C)) + ) + + input_height = H + output_height = H_out + input_width = W + output_width = W_out + + u = random_u + + alpha_height = input_height / output_height + alpha_width = input_width / output_width + + u_height = fractional_rational_u( + u, alpha_height, input_height, output_height + ) + u_width = fractional_rational_u(u, alpha_width, input_width, output_width) + + for i in range(H_out): + in_h_start = fractional_start_index(i, alpha_height, u_height) + in_h_end = fractional_end_index(i, alpha_height, u_height) + in_h_start = max(in_h_start, 0) + in_h_end = min(in_h_end, input_height) 
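            # (review sketch, not patch content) the max/min clamps keep the
            # height window inside [0, input_height); the width loop below
            # repeats the same start/end computation with u_width, so H and
            # W get independently rescaled offsets from the single draw u.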
+ + for j in range(W_out): + in_w_start = fractional_start_index(j, alpha_width, u_width) + in_w_end = fractional_end_index(j, alpha_width, u_width) + in_w_start = max(in_w_start, 0) + in_w_end = min(in_w_end, input_width) + + if data_format == 'NCHW': + x_masked = x[:, :, in_h_start:in_h_end, in_w_start:in_w_end] + if pool_type == 'avg': + field_size = (in_h_end - in_h_start) * ( + in_w_end - in_w_start + ) + out[:, :, i, j] = np.sum(x_masked, axis=(2, 3)) / field_size + elif pool_type == 'max': + out[:, :, i, j] = np.max(x_masked, axis=(2, 3)) + elif data_format == 'NHWC': + x_masked = x[:, in_h_start:in_h_end, in_w_start:in_w_end, :] + if pool_type == 'avg': + field_size = (in_h_end - in_h_start) * ( + in_w_end - in_w_start + ) + out[:, i, j, :] = np.sum(x_masked, axis=(1, 2)) / field_size + elif pool_type == 'max': + out[:, i, j, :] = np.max(x_masked, axis=(1, 2)) + return out + + +class TestFractionalMaxPool2DAPI(unittest.TestCase): + def setUp(self): + self.x_np = np.random.random([2, 3, 7, 7]).astype("float32") + self.res_1_np = fractional_pool2d_forward( + x=self.x_np, output_size=[3, 3], random_u=0.3 + ) + + self.res_2_np = fractional_pool2d_forward( + x=self.x_np, output_size=5, random_u=0.5 + ) + + self.res_3_np = fractional_pool2d_forward( + x=self.x_np, output_size=[2, 5], random_u=0.7 + ) + + # self.res_4_np = fractional_pool2d_forward( + # x=self.x_np, + # output_size=[3, 3], + # pool_type="max", + # data_format="NHWC", + # random_u=0.1) + + self.res_5_np = fractional_pool2d_forward( + x=self.x_np, output_size=[None, 3], random_u=0.6 + ) + + def test_static_graph(self): + for use_cuda in ( + [False, True] if core.is_compiled_with_cuda() else [False] + ): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + paddle.enable_static() + x = paddle.static.data( + name="x", shape=[2, 3, 7, 7], dtype="float32" + ) + + out_1 = paddle.nn.functional.fractional_max_pool2d( + x=x, output_size=[3, 3], random_u=0.3 + ) + + out_2 = paddle.nn.functional.fractional_max_pool2d( + x=x, output_size=5, random_u=0.5 + ) + + out_3 = paddle.nn.functional.fractional_max_pool2d( + x=x, output_size=[2, 5], random_u=0.7 + ) + + # out_4 = paddle.nn.functional.fractional_max_pool2d( + # x=x, output_size=[3, 3], data_format="NHWC", random_u=0.1) + + out_5 = paddle.nn.functional.fractional_max_pool2d( + x=x, output_size=[None, 3], random_u=0.6 + ) + + exe = paddle.static.Executor(place=place) + [res_1, res_2, res_3, res_5] = exe.run( + base.default_main_program(), + feed={"x": self.x_np}, + fetch_list=[out_1, out_2, out_3, out_5], + ) + + np.testing.assert_allclose(res_1, self.res_1_np) + + np.testing.assert_allclose(res_2, self.res_2_np) + + np.testing.assert_allclose(res_3, self.res_3_np) + + # np.testing.assert_allclose(res_4, self.res_4_np) + + np.testing.assert_allclose(res_5, self.res_5_np) + + def test_static_graph_return_mask(self): + for use_cuda in ( + [False, True] if core.is_compiled_with_cuda() else [False] + ): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + paddle.enable_static() + x = paddle.static.data( + name="x", shape=[2, 3, 7, 7], dtype="float32" + ) + + out_1 = paddle.nn.functional.fractional_max_pool2d( + x=x, output_size=[3, 3], return_mask=True, random_u=0.3 + ) + + out_2 = paddle.nn.functional.fractional_max_pool2d( + x=x, output_size=5, return_mask=True, random_u=0.5 + ) + + out_3 = paddle.nn.functional.fractional_max_pool2d( + x=x, output_size=[2, 5], return_mask=True, random_u=0.7 + ) + + # out_4 = paddle.nn.functional.fractional_max_pool2d( + # 
x=x, output_size=[3, 3], data_format="NHWC", return_mask=True, random_u=0.1) + + out_5 = paddle.nn.functional.fractional_max_pool2d( + x=x, output_size=[None, 3], return_mask=True, random_u=0.6 + ) + + exe = paddle.static.Executor(place=place) + [ + res_1, + mask_1, + res_2, + mask_2, + res_3, + mask_3, + res_5, + mask_5, + ] = exe.run( + base.default_main_program(), + feed={"x": self.x_np}, + fetch_list=[out_1, out_2, out_3, out_5], + ) + + self.assertEqual(res_1.shape, mask_1.shape) + + self.assertEqual(res_2.shape, mask_2.shape) + + self.assertEqual(res_3.shape, mask_3.shape) + + # self.assertEqual(res_4.shape, mask_4.shape) + + self.assertEqual(res_5.shape, mask_5.shape) + + def test_dynamic_graph(self): + for use_cuda in ( + [False, True] if core.is_compiled_with_cuda() else [False] + ): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + paddle.disable_static(place=place) + x = paddle.to_tensor(self.x_np) + + out_1 = paddle.nn.functional.fractional_max_pool2d( + x=x, return_mask=False, output_size=[3, 3], random_u=0.3 + ) + + out_2 = paddle.nn.functional.fractional_max_pool2d( + x=x, output_size=5, random_u=0.5 + ) + + out_3 = paddle.nn.functional.fractional_max_pool2d( + x=x, output_size=[2, 5], random_u=0.7 + ) + + # out_4 = paddle.nn.functional.fractional_max_pool2d( + # x=x, output_size=[3, 3], data_format="NHWC", random_u=0.1) + + out_5 = paddle.nn.functional.fractional_max_pool2d( + x=x, output_size=[None, 3], random_u=0.6 + ) + + np.testing.assert_allclose(out_1.numpy(), self.res_1_np) + + np.testing.assert_allclose(out_2.numpy(), self.res_2_np) + + np.testing.assert_allclose(out_3.numpy(), self.res_3_np) + + # np.testing.assert_allclose(out_4.numpy(), self.res_4_np) + + np.testing.assert_allclose(out_5.numpy(), self.res_5_np) + + +class TestFractionalMaxPool2DClassAPI(unittest.TestCase): + def setUp(self): + self.x_np = np.random.random([2, 3, 7, 7]).astype("float32") + self.res_1_np = fractional_pool2d_forward( + x=self.x_np, output_size=[3, 3], random_u=0.3 + ) + + self.res_2_np = fractional_pool2d_forward( + x=self.x_np, output_size=5, random_u=0.5 + ) + + self.res_3_np = fractional_pool2d_forward( + x=self.x_np, output_size=[2, 5], random_u=0.7 + ) + + # self.res_4_np = fractional_pool2d_forward( + # x=self.x_np, + # output_size=[3, 3], + # pool_type="max", + # data_format="NHWC", + # random_u=0.1) + + self.res_5_np = fractional_pool2d_forward( + x=self.x_np, output_size=[None, 3], random_u=0.6 + ) + + def test_static_graph(self): + for use_cuda in ( + [False, True] if core.is_compiled_with_cuda() else [False] + ): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + paddle.enable_static() + x = paddle.static.data( + name="x", shape=[2, 3, 7, 7], dtype="float32" + ) + + fractional_max_pool = paddle.nn.FractionalMaxPool2D( + output_size=[3, 3], random_u=0.3 + ) + out_1 = fractional_max_pool(x=x) + + fractional_max_pool = paddle.nn.FractionalMaxPool2D( + output_size=5, random_u=0.5 + ) + out_2 = fractional_max_pool(x=x) + + fractional_max_pool = paddle.nn.FractionalMaxPool2D( + output_size=[2, 5], random_u=0.7 + ) + out_3 = fractional_max_pool(x=x) + + # fractional_max_pool = paddle.nn.FractionalMaxPool2D( + # output_size=[3, 3], data_format="NHWC", random_u=0.1) + # out_4 = fractional_max_pool(x=x) + + fractional_max_pool = paddle.nn.FractionalMaxPool2D( + output_size=[None, 3], random_u=0.6 + ) + out_5 = fractional_max_pool(x=x) + + exe = paddle.static.Executor(place=place) + [res_1, res_2, res_3, res_5] = exe.run( + 
base.default_main_program(), + feed={"x": self.x_np}, + fetch_list=[out_1, out_2, out_3, out_5], + ) + + np.testing.assert_allclose(res_1, self.res_1_np) + + np.testing.assert_allclose(res_2, self.res_2_np) + + np.testing.assert_allclose(res_3, self.res_3_np) + + # np.testing.assert_allclose(res_4, self.res_4_np) + + np.testing.assert_allclose(res_5, self.res_5_np) + + def test_dynamic_graph(self): + for use_cuda in ( + [False, True] if core.is_compiled_with_cuda() else [False] + ): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + paddle.disable_static(place=place) + x = paddle.to_tensor(self.x_np) + + fractional_max_pool = paddle.nn.FractionalMaxPool2D( + output_size=[3, 3], random_u=0.3 + ) + out_1 = fractional_max_pool(x=x) + + fractional_max_pool = paddle.nn.FractionalMaxPool2D( + output_size=5, random_u=0.5 + ) + out_2 = fractional_max_pool(x=x) + + fractional_max_pool = paddle.nn.FractionalMaxPool2D( + output_size=[2, 5], random_u=0.7 + ) + out_3 = fractional_max_pool(x=x) + + # fractional_max_pool = paddle.nn.FractionalMaxPool2D( + # output_size=[3, 3], data_format="NHWC", random_u=0.1) + # out_4 = fractional_max_pool(x=x) + + fractional_max_pool = paddle.nn.FractionalMaxPool2D( + output_size=[None, 3], random_u=0.6 + ) + out_5 = fractional_max_pool(x=x) + + np.testing.assert_allclose(out_1.numpy(), self.res_1_np) + + np.testing.assert_allclose(out_2.numpy(), self.res_2_np) + + np.testing.assert_allclose(out_3.numpy(), self.res_3_np) + + # np.testing.assert_allclose(out_4.numpy(), self.res_4_np) + + np.testing.assert_allclose(out_5.numpy(), self.res_5_np) + + +class TestOutDtype(unittest.TestCase): + def test_max_pool(self): + api_fn = F.fractional_max_pool2d + shape = [1, 3, 32, 32] + check_out_dtype( + api_fn, + in_specs=[(shape,)], + expect_dtypes=['float32', 'float64'], + output_size=16, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/legacy_test/test_fractional_max_pool3d.py b/test/legacy_test/test_fractional_max_pool3d.py index 595add0aed9e1..5890aa2ab55d4 100644 --- a/test/legacy_test/test_fractional_max_pool3d.py +++ b/test/legacy_test/test_fractional_max_pool3d.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,3 +11,439 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
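
(Both test files define the same fractional helpers; a minimal standalone check of fractional_rational_u with assumed toy values, not patch output:)

    def fractional_rational_u(u, alpha, input, output):
        base = input // output
        u_max1 = (base + 2) / alpha - 1
        u_max2 = (input + 1 - base) / alpha - (output - 1)
        return u * min(u_max1, u_max2)

    # the 7 -> 3 case these tests exercise, with a raw draw u = 0.5
    alpha = 7 / 3
    u = fractional_rational_u(0.5, alpha, 7, 3)
    print(round(u, 4))  # 0.2857

With this rescaled u the last window ends at ceil(alpha * (3 + u) - 1) == 7, exactly the input extent, so the rescale is what keeps every window in-bounds for this shape.
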
+ +import unittest + +import numpy as np +from op_test import check_out_dtype + +import paddle +import paddle.nn.functional as F +from paddle import base +from paddle.base import core + + +def fractional_rational_u(u, alpha, input, output): + base = input // output + + u_max1 = (base + 2) / alpha - 1 + u_max2 = (input + 1 - base) / alpha - (output - 1) + max_u = min(u_max1, u_max2) + + return u * max_u + + +def fractional_start_index(idx, alpha, u): + return int(np.ceil(alpha * (idx + u) - 1)) + + +def fractional_end_index(idx, alpha, u): + return int(np.ceil(alpha * (idx + 1 + u) - 1)) + + +def fractional_pool3d_forward( + x, output_size, random_u=None, data_format='NCDHW', pool_type='max' +): + N = x.shape[0] + C, D, H, W = ( + [x.shape[1], x.shape[2], x.shape[3], x.shape[4]] + if data_format == 'NCDHW' + else [x.shape[4], x.shape[1], x.shape[2], x.shape[3]] + ) + + if isinstance(output_size, int) or output_size is None: + H_out = output_size + W_out = output_size + D_out = output_size + output_size = [D_out, H_out, W_out] + else: + D_out, H_out, W_out = output_size + + if output_size[0] is None: + output_size[0] = D + D_out = D + if output_size[1] is None: + output_size[1] = H + H_out = H + if output_size[2] is None: + output_size[2] = W + W_out = W + + out = ( + np.zeros((N, C, D_out, H_out, W_out)) + if data_format == 'NCDHW' + else np.zeros((N, D_out, H_out, W_out, C)) + ) + + input_depth = D + output_depth = D_out + input_height = H + output_height = H_out + input_width = W + output_width = W_out + + u = random_u + + alpha_depth = input_depth / output_depth + alpha_height = input_height / output_height + alpha_width = input_width / output_width + + u_depth = fractional_rational_u(u, alpha_depth, input_depth, output_depth) + u_height = fractional_rational_u( + u, alpha_height, input_height, output_height + ) + u_width = fractional_rational_u(u, alpha_width, input_width, output_width) + + for k in range(D_out): + d_start = fractional_start_index(k, alpha_depth, u_depth) + d_end = fractional_end_index(k, alpha_depth, u_depth) + d_start = max(d_start, 0) + d_end = min(d_end, input_depth) + + for i in range(H_out): + h_start = fractional_start_index(i, alpha_height, u_height) + h_end = fractional_end_index(i, alpha_height, u_height) + h_start = max(h_start, 0) + h_end = min(h_end, input_height) + + for j in range(W_out): + w_start = fractional_start_index(j, alpha_width, u_width) + w_end = fractional_end_index(j, alpha_width, u_width) + w_start = max(w_start, 0) + w_end = min(w_end, input_width) + + if data_format == 'NCDHW': + x_masked = x[ + :, :, d_start:d_end, h_start:h_end, w_start:w_end + ] + if pool_type == 'avg': + field_size = ( + (d_end - d_start) + * (h_end - h_start) + * (w_end - w_start) + ) + out[:, :, k, i, j] = ( + np.sum(x_masked, axis=(2, 3, 4)) / field_size + ) + elif pool_type == 'max': + out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4)) + + elif data_format == 'NDHWC': + x_masked = x[ + :, d_start:d_end, h_start:h_end, w_start:w_end, : + ] + if pool_type == 'avg': + field_size = ( + (d_end - d_start) + * (h_end - h_start) + * (w_end - w_start) + ) + out[:, k, i, j, :] = ( + np.sum(x_masked, axis=(1, 2, 3)) / field_size + ) + elif pool_type == 'max': + out[:, k, i, j, :] = np.max(x_masked, axis=(1, 2, 3)) + return out + + +class TestFractionalMaxPool3DAPI(unittest.TestCase): + def setUp(self): + self.x_np = np.random.random([2, 3, 5, 7, 7]).astype("float32") + self.res_1_np = fractional_pool3d_forward( + x=self.x_np, output_size=[3, 3, 3], random_u=0.3 + ) + + 
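        # (review sketch) each reference below pins random_u to a constant
        # so the NumPy forward above and the paddle op trace identical
        # windows; varying the constant (0.3, 0.5, 0.7, ...) also varies
        # the window layout being checked.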
self.res_2_np = fractional_pool3d_forward( + x=self.x_np, output_size=5, random_u=0.5 + ) + + self.res_3_np = fractional_pool3d_forward( + x=self.x_np, output_size=[2, 3, 5], random_u=0.7 + ) + + self.res_4_np = fractional_pool3d_forward( + x=self.x_np, + output_size=[3, 3, 3], + pool_type="max", + data_format="NDHWC", + random_u=0.1, + ) + + self.res_5_np = fractional_pool3d_forward( + x=self.x_np, output_size=[None, 3, None], random_u=0.6 + ) + + def test_static_graph(self): + for use_cuda in ( + [False, True] if core.is_compiled_with_cuda() else [False] + ): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + paddle.enable_static() + x = paddle.static.data( + name="x", shape=[2, 3, 5, 7, 7], dtype="float32" + ) + + out_1 = paddle.nn.functional.fractional_max_pool3d( + x=x, output_size=[3, 3, 3], random_u=0.3 + ) + + out_2 = paddle.nn.functional.fractional_max_pool3d( + x=x, output_size=5, random_u=0.5 + ) + + out_3 = paddle.nn.functional.fractional_max_pool3d( + x=x, output_size=[2, 3, 5], random_u=0.7 + ) + + # out_4 = paddle.nn.functional.fractional_max_pool3d( + # x=x, output_size=[3, 3, 3], data_format="NDHWC", random_u=0.1) + + out_5 = paddle.nn.functional.fractional_max_pool3d( + x=x, output_size=[None, 3, None], random_u=0.6 + ) + + exe = paddle.static.Executor(place=place) + [res_1, res_2, res_3, res_5] = exe.run( + base.default_main_program(), + feed={"x": self.x_np}, + fetch_list=[out_1, out_2, out_3, out_5], + ) + + np.testing.assert_allclose(res_1, self.res_1_np) + + np.testing.assert_allclose(res_2, self.res_2_np) + + np.testing.assert_allclose(res_3, self.res_3_np) + + # np.testing.assert_allclose(res_4, self.res_4_np) + + np.testing.assert_allclose(res_5, self.res_5_np) + + def test_static_graph_return_mask(self): + for use_cuda in ( + [False, True] if core.is_compiled_with_cuda() else [False] + ): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + paddle.enable_static() + x = paddle.static.data( + name="x", shape=[2, 3, 5, 7, 7], dtype="float32" + ) + + out_1 = paddle.nn.functional.fractional_max_pool3d( + x=x, output_size=[3, 3, 3], return_mask=True, random_u=0.3 + ) + + out_2 = paddle.nn.functional.fractional_max_pool3d( + x=x, output_size=5, return_mask=True, random_u=0.5 + ) + + out_3 = paddle.nn.functional.fractional_max_pool3d( + x=x, output_size=[2, 3, 5], return_mask=True, random_u=0.7 + ) + + # out_4 = paddle.nn.functional.fractional_max_pool3d( + # x=x, output_size=[3, 3, 3], data_format="NHWC", return_mask=True, random_u=0.1) + + out_5 = paddle.nn.functional.fractional_max_pool3d( + x=x, output_size=[None, 3, None], return_mask=True, random_u=0.6 + ) + + exe = paddle.static.Executor(place=place) + [ + res_1, + mask_1, + res_2, + mask_2, + res_3, + mask_3, + res_5, + mask_5, + ] = exe.run( + base.default_main_program(), + feed={"x": self.x_np}, + fetch_list=[out_1, out_2, out_3, out_5], + ) + + self.assertEqual(res_1.shape, mask_1.shape) + + self.assertEqual(res_2.shape, mask_2.shape) + + self.assertEqual(res_3.shape, mask_3.shape) + + # self.assertEqual(res_4.shape, mask_4.shape) + + self.assertEqual(res_5.shape, mask_5.shape) + + def test_dynamic_graph(self): + for use_cuda in ( + [False, True] if core.is_compiled_with_cuda() else [False] + ): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + paddle.disable_static(place=place) + x = paddle.to_tensor(self.x_np) + + out_1 = paddle.nn.functional.fractional_max_pool3d( + x=x, output_size=[3, 3, 3], random_u=0.3 + ) + + out_2 = 
paddle.nn.functional.fractional_max_pool3d( + x=x, output_size=5, random_u=0.5 + ) + + out_3 = paddle.nn.functional.fractional_max_pool3d( + x=x, output_size=[2, 3, 5], random_u=0.7 + ) + + # out_4 = paddle.nn.functional.fractional_max_pool3d( + # x=x, output_size=[3, 3, 3], data_format="NDHWC", random_u=0.1) + + out_5 = paddle.nn.functional.fractional_max_pool3d( + x=x, output_size=[None, 3, None], random_u=0.6 + ) + + np.testing.assert_allclose(out_1.numpy(), self.res_1_np) + + np.testing.assert_allclose(out_2.numpy(), self.res_2_np) + + np.testing.assert_allclose(out_3.numpy(), self.res_3_np) + + # np.testing.assert_allclose(out_4.numpy(), self.res_4_np) + + np.testing.assert_allclose(out_5.numpy(), self.res_5_np) + + +class TestFractionalMaxPool3DClassAPI(unittest.TestCase): + def setUp(self): + self.x_np = np.random.random([2, 3, 5, 7, 7]).astype("float32") + self.res_1_np = fractional_pool3d_forward( + x=self.x_np, output_size=[3, 3, 3], random_u=0.3 + ) + + self.res_2_np = fractional_pool3d_forward( + x=self.x_np, output_size=5, random_u=0.5 + ) + + self.res_3_np = fractional_pool3d_forward( + x=self.x_np, output_size=[2, 3, 5], random_u=0.7 + ) + + # self.res_4_np = fractional_pool3d_forward( + # x=self.x_np, + # output_size=[3, 3, 3], + # pool_type="max", + # data_format="NDHWC", + # random_u=0.1 + # ) + + self.res_5_np = fractional_pool3d_forward( + x=self.x_np, output_size=[None, 3, None], random_u=0.6 + ) + + def test_static_graph(self): + for use_cuda in ( + [False, True] if core.is_compiled_with_cuda() else [False] + ): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + paddle.enable_static() + x = paddle.static.data( + name="x", shape=[2, 3, 5, 7, 7], dtype="float32" + ) + + fractional_max_pool = paddle.nn.FractionalMaxPool3D( + output_size=[3, 3, 3], random_u=0.3 + ) + out_1 = fractional_max_pool(x=x) + + fractional_max_pool = paddle.nn.FractionalMaxPool3D( + output_size=5, random_u=0.5 + ) + out_2 = fractional_max_pool(x=x) + + fractional_max_pool = paddle.nn.FractionalMaxPool3D( + output_size=[2, 3, 5], random_u=0.7 + ) + out_3 = fractional_max_pool(x=x) + + # fractional_max_pool = paddle.nn.FractionalMaxPool3D( + # output_size=[3, 3, 3], data_format="NDHWC", random_u=0.1) + # out_4 = fractional_max_pool(x=x) + + fractional_max_pool = paddle.nn.FractionalMaxPool3D( + output_size=[None, 3, None], random_u=0.6 + ) + out_5 = fractional_max_pool(x=x) + + exe = paddle.static.Executor(place=place) + [res_1, res_2, res_3, res_5] = exe.run( + base.default_main_program(), + feed={"x": self.x_np}, + fetch_list=[out_1, out_2, out_3, out_5], + ) + + np.testing.assert_allclose(res_1, self.res_1_np) + + np.testing.assert_allclose(res_2, self.res_2_np) + + np.testing.assert_allclose(res_3, self.res_3_np) + + # assert np.allclose(res_4, self.res_4_np) + + np.testing.assert_allclose(res_5, self.res_5_np) + + def test_dynamic_graph(self): + for use_cuda in ( + [False, True] if core.is_compiled_with_cuda() else [False] + ): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + paddle.disable_static(place=place) + x = paddle.to_tensor(self.x_np) + + fractional_max_pool = paddle.nn.FractionalMaxPool3D( + output_size=[3, 3, 3], random_u=0.3 + ) + out_1 = fractional_max_pool(x=x) + + fractional_max_pool = paddle.nn.FractionalMaxPool3D( + output_size=5, random_u=0.5 + ) + out_2 = fractional_max_pool(x=x) + + fractional_max_pool = paddle.nn.FractionalMaxPool3D( + output_size=[2, 3, 5], random_u=0.7 + ) + out_3 = fractional_max_pool(x=x) + + # fractional_max_pool = 
paddle.nn.FractionalMaxPool3D(
+            #     output_size=[3, 3, 3], data_format="NDHWC", random_u=0.1)
+            # out_4 = fractional_max_pool(x=x)
+
+            fractional_max_pool = paddle.nn.FractionalMaxPool3D(
+                output_size=[None, 3, None], random_u=0.6
+            )
+            out_5 = fractional_max_pool(x=x)
+
+            np.testing.assert_allclose(out_1.numpy(), self.res_1_np)
+
+            np.testing.assert_allclose(out_2.numpy(), self.res_2_np)
+
+            np.testing.assert_allclose(out_3.numpy(), self.res_3_np)
+
+            # assert np.allclose(out_4.numpy(), self.res_4_np)
+
+            np.testing.assert_allclose(out_5.numpy(), self.res_5_np)
+
+
+class TestOutDtype(unittest.TestCase):
+    def test_max_pool(self):
+        api_fn = F.fractional_max_pool3d
+        shape = [1, 3, 32, 32, 32]
+        check_out_dtype(
+            api_fn,
+            in_specs=[(shape,)],
+            expect_dtypes=['float32', 'float64'],
+            output_size=16,
+        )
+
+
+if __name__ == '__main__':
+    unittest.main()

From 62b1c13a829f40be3187fafe723c0680f33b8898 Mon Sep 17 00:00:00 2001
From: megemini
Date: Mon, 27 Nov 2023 18:34:07 +0800
Subject: [PATCH 18/18] [Add] docstring for fractional max pool

---
 python/paddle/nn/functional/pooling.py        | 130 +++++++++++++++-
 python/paddle/nn/layer/pooling.py             | 144 +++++++++++++++++-
 .../legacy_test/test_fractional_max_pool2d.py |  43 ++----
 .../legacy_test/test_fractional_max_pool3d.py |  27 ++--
 test/legacy_test/test_pool_max_op.py          |  60 ++------
 5 files changed, 310 insertions(+), 94 deletions(-)

diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py
index a044a74aa0f46..656971fe27e21 100755
--- a/python/paddle/nn/functional/pooling.py
+++ b/python/paddle/nn/functional/pooling.py
@@ -2100,7 +2100,70 @@ def fractional_max_pool2d(
     x, output_size, random_u=None, return_mask=False, name=None
 ):
     """
-    TODO(megemini)
+    This operation applies 2D fractional max pooling on the input tensor, as described in the paper:
+
+    [1] Ben Graham, Fractional Max-Pooling. 2015. http://arxiv.org/abs/1412.6071
+
+    The h and w dimensions of the output tensor are determined by the parameter ``output_size``.
+
+    For each dimension, the fractional max pooling is:
+
+    .. math::
+
+        \alpha &= size_{input} / size_{output}
+
+        index_{start} &= ceil( \alpha * (i + u) - 1)
+
+        index_{end} &= ceil( \alpha * (i + 1 + u) - 1)
+
+        Output &= max(Input[index_{start}:index_{end}])
+
+    where ``u`` is in the range ``(0, 1)`` and ``i = 0, 1, ..., size_{output} - 1``.
+
+    The ``u`` in the formula is the parameter ``random_u``; ``1`` is subtracted inside ``ceil`` because
+    indices start from ``0`` rather than ``1``.
+
+    For instance, given an input sequence ``[2, 4, 3, 1, 5, 2, 3]`` of length ``7``, an ``output_size`` of ``5``
+    and a ``random_u`` of ``0.3``: ``alpha = 7/5 = 1.4``, the start indices are ``[0, 1, 3, 4, 6]`` and the end
+    indices are ``[1, 3, 4, 6, 7]``, which yields the random sequence from the paper,
+    ``index_end - index_start = [1, 2, 1, 2, 1]``. The strides and kernel sizes both equal this random
+    sequence, giving the final pooling output ``[2, 4, 1, 5, 3]``.
+
+    Parameters:
+        x (Tensor): The input tensor of the fractional max pool2d operator, which is a 4-D tensor. The data type can be float32 or float64.
+        output_size(int|list|tuple): The output size. If the output size is a tuple or list, it must contain
+            two elements, (H, W). H and W can each be either an int, or None, which means the size will be the
+            same as that of the input.
+        random_u(float): A random float number in the range (0, 1) for the fractional pooling.
+            Default is None, meaning the value is randomly generated by the framework and can be fixed by ``paddle.seed``.
+        return_mask(bool, optional): If true, the indices of the max pooling points will be returned along with the outputs. Default False.
+        name(str, optional): For detailed information, please refer to :ref:`api_guide_Name`.
+            Usually name does not need to be set and is None by default.
+
+    Returns:
+        Tensor: The output tensor of the fractional max pool2d result, which is a 4-D tensor. The data type is the same as the input tensor.
+
+    Examples:
+        .. code-block:: python
+
+            >>> # fractional max pool2d
+            >>> # suppose the input data is in the shape of [N, C, H, W] and `output_size` is [m, n],
+            >>> # the output shape is [N, C, m, n]; fractional pooling divides the H and W dimensions
+            >>> # of the input data into m * n grids and performs pooling in each
+            >>> # grid to get the output.

+            >>> import paddle
+
+            >>> x = paddle.rand([2, 3, 32, 32])
+
+            >>> pool_out = paddle.nn.functional.fractional_max_pool2d(x, output_size=3)
+            >>> print(pool_out.shape)
+            [2, 3, 3, 3]
+
+            >>> pool_out, indices = paddle.nn.functional.fractional_max_pool2d(x, output_size=[2, 3], return_mask=True)
+            >>> print(pool_out.shape)
+            [2, 3, 2, 3]
+            >>> print(indices.shape)
+            [2, 3, 2, 3]
     """
     _check_input(x, 4)
@@ -2168,7 +2231,70 @@ def fractional_max_pool3d(
     x, output_size, random_u=None, return_mask=False, name=None
 ):
     """
-    TODO(megemini)
+    This operation applies 3D fractional max pooling on the input tensor, as described in the paper:
+
+    [1] Ben Graham, Fractional Max-Pooling. 2015. http://arxiv.org/abs/1412.6071
+
+    The d, h and w dimensions of the output tensor are determined by the parameter ``output_size``.
+
+    For each dimension, the fractional max pooling is:
+
+    .. math::
+
+        \alpha &= size_{input} / size_{output}
+
+        index_{start} &= ceil( \alpha * (i + u) - 1)
+
+        index_{end} &= ceil( \alpha * (i + 1 + u) - 1)
+
+        Output &= max(Input[index_{start}:index_{end}])
+
+    where ``u`` is in the range ``(0, 1)`` and ``i = 0, 1, ..., size_{output} - 1``.
+
+    The ``u`` in the formula is the parameter ``random_u``; ``1`` is subtracted inside ``ceil`` because
+    indices start from ``0`` rather than ``1``.
+
+    For instance, given an input sequence ``[2, 4, 3, 1, 5, 2, 3]`` of length ``7``, an ``output_size`` of ``5``
+    and a ``random_u`` of ``0.3``: ``alpha = 7/5 = 1.4``, the start indices are ``[0, 1, 3, 4, 6]`` and the end
+    indices are ``[1, 3, 4, 6, 7]``, which yields the random sequence from the paper,
+    ``index_end - index_start = [1, 2, 1, 2, 1]``. The strides and kernel sizes both equal this random
+    sequence, giving the final pooling output ``[2, 4, 1, 5, 3]``.
+
+    Parameters:
+        x (Tensor): The input tensor of the fractional max pool3d operator, which is a 5-D tensor. The data type can be float32 or float64.
+        output_size(int|list|tuple): The output size. If the output size is a tuple or list, it must contain
+            three elements, (D, H, W). D, H and W can each be either an int, or None, which means the size will
+            be the same as that of the input.
+        random_u(float): A random float number in the range (0, 1) for the fractional pooling.
+            Default is None, meaning the value is randomly generated by the framework and can be fixed by ``paddle.seed``.
+        return_mask(bool, optional): If true, the indices of the max pooling points will be returned along with the outputs. Default False.
+        name(str, optional): For detailed information, please refer to :ref:`api_guide_Name`.
+            Usually name does not need to be set and is None by default.
+
+    Returns:
+        Tensor: The output tensor of the fractional max pool3d result, which is a 5-D tensor. The data type is the same as the input tensor.
+
+    Examples:
+        .. code-block:: python
+
+            >>> # fractional max pool3d
+            >>> # suppose the input data is in the shape of [N, C, D, H, W] and `output_size` is [l, m, n],
+            >>> # the output shape is [N, C, l, m, n]; fractional pooling divides the D, H and W dimensions
+            >>> # of the input data into l * m * n grids and performs pooling in each
+            >>> # grid to get the output.

+            >>> import paddle
+
+            >>> x = paddle.rand([2, 3, 8, 32, 32])
+
+            >>> pool_out = paddle.nn.functional.fractional_max_pool3d(x, output_size=3)
+            >>> print(pool_out.shape)
+            [2, 3, 3, 3, 3]
+
+            >>> pool_out, indices = paddle.nn.functional.fractional_max_pool3d(x, output_size=[2, 3, 3], return_mask=True)
+            >>> print(pool_out.shape)
+            [2, 3, 2, 3, 3]
+            >>> print(indices.shape)
+            [2, 3, 2, 3, 3]
     """
     _check_input(x, 5)
diff --git a/python/paddle/nn/layer/pooling.py b/python/paddle/nn/layer/pooling.py
index f4d07d5b5eb4b..ebb98c567a461 100755
--- a/python/paddle/nn/layer/pooling.py
+++ b/python/paddle/nn/layer/pooling.py
@@ -1143,7 +1143,77 @@ def extra_repr(self):
 
 class FractionalMaxPool2D(Layer):
     """
-    TODO(megemini)
+    This operation applies 2D fractional max pooling on the input tensor, as described in the paper:
+
+    [1] Ben Graham, Fractional Max-Pooling. 2015. http://arxiv.org/abs/1412.6071
+
+    The h and w dimensions of the output tensor are determined by the parameter ``output_size``.
+
+    For each dimension, the fractional max pooling is:
+
+    .. math::
+
+        \alpha &= size_{input} / size_{output}
+
+        index_{start} &= ceil( \alpha * (i + u) - 1)
+
+        index_{end} &= ceil( \alpha * (i + 1 + u) - 1)
+
+        Output &= max(Input[index_{start}:index_{end}])
+
+    where ``u`` is in the range ``(0, 1)`` and ``i = 0, 1, ..., size_{output} - 1``.
+
+    The ``u`` in the formula is the parameter ``random_u``; ``1`` is subtracted inside ``ceil`` because
+    indices start from ``0`` rather than ``1``.
+
+    For instance, given an input sequence ``[2, 4, 3, 1, 5, 2, 3]`` of length ``7``, an ``output_size`` of ``5``
+    and a ``random_u`` of ``0.3``: ``alpha = 7/5 = 1.4``, the start indices are ``[0, 1, 3, 4, 6]`` and the end
+    indices are ``[1, 3, 4, 6, 7]``, which yields the random sequence from the paper,
+    ``index_end - index_start = [1, 2, 1, 2, 1]``. The strides and kernel sizes both equal this random
+    sequence, giving the final pooling output ``[2, 4, 1, 5, 3]``.
+
+    Parameters:
+        output_size(int|list|tuple): The output size. If the output size is a tuple or list, it must contain
+            two elements, (H, W). H and W can each be either an int, or None, which means the size will be the
+            same as that of the input.
+        random_u(float): A random float number in the range (0, 1) for the fractional pooling.
+            Default is None, meaning the value is randomly generated by the framework and can be fixed by ``paddle.seed``.
+        return_mask(bool, optional): If true, the indices of the max pooling points will be returned along with the outputs. Default False.
+        name(str, optional): For detailed information, please refer to :ref:`api_guide_Name`.
+            Usually name does not need to be set and is None by default.
+
+    Shape:
+        - x(Tensor): The input tensor of the fractional max pool2d operator, which is a 4-D tensor.
+          The data type can be float32 or float64.
+        - output(Tensor): The output tensor of the fractional max pool2d operator, which is a 4-D tensor.
+          The data type is the same as the input x.
+
+    Returns:
+        A callable object of FractionalMaxPool2D.
+
+    Examples:
+        .. code-block:: python
+
+            >>> # fractional max pool2d
+            >>> # suppose the input data is in the shape of [N, C, H, W] and `output_size` is [m, n],
+            >>> # the output shape is [N, C, m, n]; fractional pooling divides the H and W dimensions
+            >>> # of the input data into m * n grids and performs pooling in each
+            >>> # grid to get the output.

+            >>> import paddle
+
+            >>> x = paddle.rand([2, 3, 32, 32])
+
+            >>> fractional_max_pool = paddle.nn.FractionalMaxPool2D(output_size=3)
+            >>> pool_out = fractional_max_pool(x=x)
+            >>> print(pool_out.shape)
+            [2, 3, 3, 3]
+
+            >>> fractional_max_pool = paddle.nn.FractionalMaxPool2D(output_size=[2, 3], return_mask=True)
+            >>> pool_out, indices = fractional_max_pool(x=x)
+            >>> print(pool_out.shape)
+            [2, 3, 2, 3]
+            >>> print(indices.shape)
+            [2, 3, 2, 3]
     """
 
     def __init__(
@@ -1172,7 +1242,77 @@ def extra_repr(self):
 
 class FractionalMaxPool3D(Layer):
     """
-    TODO(megemini)
+    This operation applies 3D fractional max pooling on the input tensor, as described in the paper:
+
+    [1] Ben Graham, Fractional Max-Pooling. 2015. http://arxiv.org/abs/1412.6071
+
+    The d, h and w dimensions of the output tensor are determined by the parameter ``output_size``.
+
+    For each dimension, the fractional max pooling is:
+
+    .. math::
+
+        \alpha &= size_{input} / size_{output}
+
+        index_{start} &= ceil( \alpha * (i + u) - 1)
+
+        index_{end} &= ceil( \alpha * (i + 1 + u) - 1)
+
+        Output &= max(Input[index_{start}:index_{end}])
+
+    where ``u`` is in the range ``(0, 1)`` and ``i = 0, 1, ..., size_{output} - 1``.
+
+    The ``u`` in the formula is the parameter ``random_u``; ``1`` is subtracted inside ``ceil`` because
+    indices start from ``0`` rather than ``1``.
+
+    For instance, given an input sequence ``[2, 4, 3, 1, 5, 2, 3]`` of length ``7``, an ``output_size`` of ``5``
+    and a ``random_u`` of ``0.3``: ``alpha = 7/5 = 1.4``, the start indices are ``[0, 1, 3, 4, 6]`` and the end
+    indices are ``[1, 3, 4, 6, 7]``, which yields the random sequence from the paper,
+    ``index_end - index_start = [1, 2, 1, 2, 1]``. The strides and kernel sizes both equal this random
+    sequence, giving the final pooling output ``[2, 4, 1, 5, 3]``.
+
+    Parameters:
+        output_size(int|list|tuple): The output size. If the output size is a tuple or list, it must contain
+            three elements, (D, H, W). D, H and W can each be either an int, or None, which means the size will
+            be the same as that of the input.
+        random_u(float): A random float number in the range (0, 1) for the fractional pooling.
+            Default is None, meaning the value is randomly generated by the framework and can be fixed by ``paddle.seed``.
+        return_mask(bool, optional): If true, the indices of the max pooling points will be returned along with the outputs. Default False.
+        name(str, optional): For detailed information, please refer to :ref:`api_guide_Name`.
+            Usually name does not need to be set and is None by default.
+
+    Shape:
+        - x(Tensor): The input tensor of the fractional max pool3d operator, which is a 5-D tensor.
+          The data type can be float32 or float64.
+        - output(Tensor): The output tensor of the fractional max pool3d operator, which is a 5-D tensor.
+          The data type is the same as the input x.
+
+    Returns:
+        A callable object of FractionalMaxPool3D.
+
+    Examples:
+        .. code-block:: python
+
+            >>> # fractional max pool3d
+            >>> # suppose the input data is in the shape of [N, C, D, H, W] and `output_size` is [l, m, n],
+            >>> # the output shape is [N, C, l, m, n]; fractional pooling divides the D, H and W dimensions
+            >>> # of the input data into l * m * n grids and performs pooling in each
+            >>> # grid to get the output.
+ + >>> import paddle + + >>> x = paddle.rand([2, 3, 8, 32, 32]) + + >>> fractional_max_pool = paddle.nn.FractionalMaxPool3D(output_size=3) + >>> pool_out = fractional_max_pool(x=x) + >>> print(pool_out.shape) + [2, 3, 3, 3, 3] + + >>> fractional_max_pool = paddle.nn.FractionalMaxPool3D(output_size=[2, 3, 3], return_mask=True) + >>> pool_out, indices = fractional_max_pool(x=x) + >>> print(pool_out.shape) + [2, 3, 2, 3, 3] + >>> print(indices.shape) + [2, 3, 2, 3, 3] """ def __init__( diff --git a/test/legacy_test/test_fractional_max_pool2d.py b/test/legacy_test/test_fractional_max_pool2d.py index d190bd17d3c65..fc79d78ad88d5 100644 --- a/test/legacy_test/test_fractional_max_pool2d.py +++ b/test/legacy_test/test_fractional_max_pool2d.py @@ -75,48 +75,37 @@ def fractional_pool2d_forward( else np.zeros((N, H_out, W_out, C)) ) - input_height = H - output_height = H_out - input_width = W - output_width = W_out - u = random_u - alpha_height = input_height / output_height - alpha_width = input_width / output_width + alpha_height = H / H_out + alpha_width = W / W_out - u_height = fractional_rational_u( - u, alpha_height, input_height, output_height - ) - u_width = fractional_rational_u(u, alpha_width, input_width, output_width) + u_height = fractional_rational_u(u, alpha_height, H, H_out) + u_width = fractional_rational_u(u, alpha_width, W, W_out) for i in range(H_out): - in_h_start = fractional_start_index(i, alpha_height, u_height) - in_h_end = fractional_end_index(i, alpha_height, u_height) - in_h_start = max(in_h_start, 0) - in_h_end = min(in_h_end, input_height) + h_start = fractional_start_index(i, alpha_height, u_height) + h_end = fractional_end_index(i, alpha_height, u_height) + h_start = max(h_start, 0) + h_end = min(h_end, H) for j in range(W_out): - in_w_start = fractional_start_index(j, alpha_width, u_width) - in_w_end = fractional_end_index(j, alpha_width, u_width) - in_w_start = max(in_w_start, 0) - in_w_end = min(in_w_end, input_width) + w_start = fractional_start_index(j, alpha_width, u_width) + w_end = fractional_end_index(j, alpha_width, u_width) + w_start = max(w_start, 0) + w_end = min(w_end, W) if data_format == 'NCHW': - x_masked = x[:, :, in_h_start:in_h_end, in_w_start:in_w_end] + x_masked = x[:, :, h_start:h_end, w_start:w_end] if pool_type == 'avg': - field_size = (in_h_end - in_h_start) * ( - in_w_end - in_w_start - ) + field_size = (h_end - h_start) * (w_end - w_start) out[:, :, i, j] = np.sum(x_masked, axis=(2, 3)) / field_size elif pool_type == 'max': out[:, :, i, j] = np.max(x_masked, axis=(2, 3)) elif data_format == 'NHWC': - x_masked = x[:, in_h_start:in_h_end, in_w_start:in_w_end, :] + x_masked = x[:, h_start:h_end, w_start:w_end, :] if pool_type == 'avg': - field_size = (in_h_end - in_h_start) * ( - in_w_end - in_w_start - ) + field_size = (h_end - h_start) * (w_end - w_start) out[:, i, j, :] = np.sum(x_masked, axis=(1, 2)) / field_size elif pool_type == 'max': out[:, i, j, :] = np.max(x_masked, axis=(1, 2)) diff --git a/test/legacy_test/test_fractional_max_pool3d.py b/test/legacy_test/test_fractional_max_pool3d.py index 5890aa2ab55d4..8c454282f9c32 100644 --- a/test/legacy_test/test_fractional_max_pool3d.py +++ b/test/legacy_test/test_fractional_max_pool3d.py @@ -75,42 +75,33 @@ def fractional_pool3d_forward( else np.zeros((N, D_out, H_out, W_out, C)) ) - input_depth = D - output_depth = D_out - input_height = H - output_height = H_out - input_width = W - output_width = W_out - u = random_u - alpha_depth = input_depth / output_depth - alpha_height = input_height 
/ output_height - alpha_width = input_width / output_width + alpha_depth = D / D_out + alpha_height = H / H_out + alpha_width = W / W_out - u_depth = fractional_rational_u(u, alpha_depth, input_depth, output_depth) - u_height = fractional_rational_u( - u, alpha_height, input_height, output_height - ) - u_width = fractional_rational_u(u, alpha_width, input_width, output_width) + u_depth = fractional_rational_u(u, alpha_depth, D, D_out) + u_height = fractional_rational_u(u, alpha_height, H, H_out) + u_width = fractional_rational_u(u, alpha_width, W, W_out) for k in range(D_out): d_start = fractional_start_index(k, alpha_depth, u_depth) d_end = fractional_end_index(k, alpha_depth, u_depth) d_start = max(d_start, 0) - d_end = min(d_end, input_depth) + d_end = min(d_end, D) for i in range(H_out): h_start = fractional_start_index(i, alpha_height, u_height) h_end = fractional_end_index(i, alpha_height, u_height) h_start = max(h_start, 0) - h_end = min(h_end, input_height) + h_end = min(h_end, H) for j in range(W_out): w_start = fractional_start_index(j, alpha_width, u_width) w_end = fractional_end_index(j, alpha_width, u_width) w_start = max(w_start, 0) - w_end = min(w_end, input_width) + w_end = min(w_end, W) if data_format == 'NCDHW': x_masked = x[ diff --git a/test/legacy_test/test_pool_max_op.py b/test/legacy_test/test_pool_max_op.py index dbedaef7a023d..be12c58ec3487 100644 --- a/test/legacy_test/test_pool_max_op.py +++ b/test/legacy_test/test_pool_max_op.py @@ -75,34 +75,16 @@ def max_pool3D_forward_naive( H_out = (H - ksize[1] + 2 * paddings[1]) // strides[1] + 1 W_out = (W - ksize[2] + 2 * paddings[2]) // strides[2] + 1 - alpha_height = 0 - alpha_width = 0 - alpha_depth = 0 - u_height = 0 - u_width = 0 - u_depth = 0 - input_depth = D - output_depth = D_out - input_height = H - output_height = H_out - input_width = W - output_width = W_out if fractional: u = random_u - alpha_depth = input_depth / output_depth - alpha_height = input_height / output_height - alpha_width = input_width / output_width + alpha_depth = D / D_out + alpha_height = H / H_out + alpha_width = W / W_out - u_depth = fractional_rational_u( - u, alpha_depth, input_depth, output_depth - ) - u_height = fractional_rational_u( - u, alpha_height, input_height, output_height - ) - u_width = fractional_rational_u( - u, alpha_width, input_width, output_width - ) + u_depth = fractional_rational_u(u, alpha_depth, D, D_out) + u_height = fractional_rational_u(u, alpha_height, H, H_out) + u_width = fractional_rational_u(u, alpha_width, W, W_out) out = np.zeros((N, C, D_out, H_out, W_out)) mask = np.zeros((N, C, D_out, H_out, W_out)) @@ -114,7 +96,7 @@ def max_pool3D_forward_naive( d_start = fractional_start_index(k, alpha_depth, u_depth) d_end = fractional_end_index(k, alpha_depth, u_depth) d_start = max(d_start, 0) - d_end = min(d_end, input_depth) + d_end = min(d_end, D) else: d_start = np.max((k * strides[0] - paddings[0], 0)) d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D)) @@ -126,7 +108,7 @@ def max_pool3D_forward_naive( h_start = fractional_start_index(i, alpha_height, u_height) h_end = fractional_end_index(i, alpha_height, u_height) h_start = max(h_start, 0) - h_end = min(h_end, input_height) + h_end = min(h_end, H) else: h_start = np.max((i * strides[1] - paddings[1], 0)) h_end = np.min((i * strides[1] + ksize[1] - paddings[1], H)) @@ -138,7 +120,7 @@ def max_pool3D_forward_naive( w_start = fractional_start_index(j, alpha_width, u_width) w_end = fractional_end_index(j, alpha_width, u_width) w_start = max(w_start, 
0) - w_end = min(w_end, input_width) + w_end = min(w_end, W) else: w_start = np.max((j * strides[2] - paddings[2], 0)) w_end = np.min((j * strides[2] + ksize[2] - paddings[2], W)) @@ -184,26 +166,14 @@ def max_pool2D_forward_naive( H_out = (H - ksize[0] + 2 * paddings[0]) // strides[0] + 1 W_out = (W - ksize[1] + 2 * paddings[1]) // strides[1] + 1 - alpha_height = 0 - alpha_width = 0 - u_height = 0 - u_width = 0 - input_height = H - output_height = H_out - input_width = W - output_width = W_out if fractional: u = random_u - alpha_height = input_height / output_height - alpha_width = input_width / output_width + alpha_height = H / H_out + alpha_width = W / W_out - u_height = fractional_rational_u( - u, alpha_height, input_height, output_height - ) - u_width = fractional_rational_u( - u, alpha_width, input_width, output_width - ) + u_height = fractional_rational_u(u, alpha_height, H, H_out) + u_width = fractional_rational_u(u, alpha_width, W, W_out) out = np.zeros((N, C, H_out, W_out)) mask = np.zeros((N, C, H_out, W_out)) @@ -218,12 +188,12 @@ def max_pool2D_forward_naive( r_start = fractional_start_index(i, alpha_height, u_height) r_end = fractional_end_index(i, alpha_height, u_height) r_start = max(r_start, 0) - r_end = min(r_end, input_height) + r_end = min(r_end, H) c_start = fractional_start_index(j, alpha_width, u_width) c_end = fractional_end_index(j, alpha_width, u_width) c_start = max(c_start, 0) - c_end = min(c_end, input_width) + c_end = min(c_end, W) else: r_start = np.max((i * strides[0] - paddings[0], 0)) r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
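
The worked example repeated in the new docstrings can be sanity-checked with plain
NumPy, independently of Paddle. The sketch below is illustrative and not part of the
patch: it applies the docstring's start/end index formula directly to the example
sequence, reusing the helper names from the test utilities above, and it deliberately
uses ``random_u`` as-is, skipping the ``fractional_rational_u`` rescaling that the
kernels and tests apply to ``u``.

    import numpy as np

    def fractional_start_index(idx, alpha, u):
        # index_start = ceil(alpha * (i + u) - 1)
        return int(np.ceil(alpha * (idx + u) - 1))

    def fractional_end_index(idx, alpha, u):
        # index_end = ceil(alpha * (i + 1 + u) - 1)
        return int(np.ceil(alpha * (idx + 1 + u) - 1))

    seq = np.array([2, 4, 3, 1, 5, 2, 3])
    output_size, u = 5, 0.3
    alpha = len(seq) / output_size  # 7 / 5 = 1.4

    # Clamp the windows to the valid input range, as the kernels do.
    starts = [max(fractional_start_index(i, alpha, u), 0) for i in range(output_size)]
    ends = [min(fractional_end_index(i, alpha, u), len(seq)) for i in range(output_size)]

    print(starts)                     # [0, 1, 3, 4, 6]
    print(ends)                       # [1, 3, 4, 6, 7]
    print(np.subtract(ends, starts))  # [1 2 1 2 1] -- the paper's random pooling sequence
    print([int(seq[s:e].max()) for s, e in zip(starts, ends)])  # [2, 4, 1, 5, 3]

Each half-open window [start, end) acts as one pooling region, so the window lengths
play the role of the paper's pseudo-random kernel sizes.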