Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.38】为 Paddle 新增 FractionalMaxPool2d / FractionalMaxPool3d API #59130

Closed
wants to merge 27 commits into the base branch from the contributor's branch (branch names not captured in this export)
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select a commit. Hold Shift + click to select a range.
34c6d6f
[Init] init commit
megemini Oct 27, 2023
71bc3c3
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Oct 27, 2023
8efe56a
[Add] fractional max pool python api
megemini Oct 28, 2023
88362e8
[Add] FractionalMaxPool layer
megemini Oct 28, 2023
f265203
[Add] fractional index generation
megemini Nov 15, 2023
d7457e0
[Add] add fractional max pooling to __init__
megemini Nov 16, 2023
f79b8d2
[Fix] add default value False to max_poolNd_with_index
megemini Nov 16, 2023
19cbb42
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Nov 16, 2023
594e130
[Update] __init__ add fractional funcs
megemini Nov 16, 2023
e667f09
[Add] test file
megemini Nov 19, 2023
ab256ac
[Add] test_pool_max_op for fractional
megemini Nov 19, 2023
bf7a640
[Fix] test index cast to int
megemini Nov 20, 2023
44f44c0
[Change] pooling cu rand
megemini Nov 20, 2023
a88aeec
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Nov 20, 2023
e0679f0
[Change] skip fractional op test
megemini Nov 21, 2023
99196b2
[Update] cu seed
megemini Nov 21, 2023
89297c8
[Add] add param random_u for solid fractional
megemini Nov 21, 2023
ba9d583
[Fix] fractional pool 3d random_u
megemini Nov 22, 2023
6d6f93b
[Add] add xpu support for fractional max pool 2d with index
megemini Nov 23, 2023
8ea0e52
[Add] fractional api unittest
megemini Nov 26, 2023
0c7254c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Nov 27, 2023
a302293
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Nov 27, 2023
62b1c13
[Add] docstring for fractional max pool
megemini Nov 27, 2023
4fe49f7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Dec 4, 2023
ae13ac2
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Dec 5, 2023
71ededc
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Dec 5, 2023
1782dda
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Dec 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1482,17 +1482,17 @@
func : matrix_power_grad

- backward_op : max_pool2d_with_index_grad
forward : max_pool2d_with_index(Tensor x, int[] kernel_size, int[] strides = {1, 1}, int[] paddings = {0, 0}, bool global_pooling = false, bool adaptive = false) -> Tensor(out), Tensor(mask)
args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive)
forward : max_pool2d_with_index(Tensor x, int[] kernel_size, int[] strides = {1, 1}, int[] paddings = {0, 0}, bool global_pooling = false, bool adaptive = false, bool fractional = false, float random_u = 0.0) -> Tensor(out), Tensor(mask)
args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive, bool fractional, float random_u)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这儿直接改变原来phi算子定义会有不兼容升级的问题,新增api的话新增一个phi接口,可以复用已有的max_pool2d_with_index接口即可

output : Tensor(x_grad)
infer_meta :
func : MaxPoolWithIndexGradInferMeta
kernel :
func : max_pool2d_with_index_grad

- backward_op : max_pool3d_with_index_grad
forward : max_pool3d_with_index(Tensor x, int[] kernel_size, int[] strides = {1, 1, 1}, int[] paddings = {0, 0, 0}, bool global_pooling = false, bool adaptive = false) -> Tensor(out), Tensor(mask)
args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive)
forward : max_pool3d_with_index(Tensor x, int[] kernel_size, int[] strides = {1, 1, 1}, int[] paddings = {0, 0, 0}, bool global_pooling = false, bool adaptive = false, bool fractional = false, float random_u = 0.0) -> Tensor(out), Tensor(mask)
args : (Tensor x, Tensor mask, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive, bool fractional, float random_u)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

output : Tensor(x_grad)
infer_meta :
func : MaxPoolWithIndexGradInferMeta
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1682,7 +1682,7 @@
backward : matrix_power_grad

- op : max_pool2d_with_index
args : (Tensor x, int[] kernel_size, int[] strides= {1, 1}, int[] paddings = {0, 0}, bool global_pooling = false, bool adaptive = false)
args : (Tensor x, int[] kernel_size, int[] strides= {1, 1}, int[] paddings = {0, 0}, bool global_pooling = false, bool adaptive = false, bool fractional = false, float random_u = 0.0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

output : Tensor(out), Tensor(mask)
infer_meta :
func : MaxPoolWithIndexInferMeta
Expand All @@ -1691,7 +1691,7 @@
backward : max_pool2d_with_index_grad

- op : max_pool3d_with_index
args : (Tensor x, int[] kernel_size, int[] strides = {1, 1, 1}, int[] paddings = {0, 0, 0}, bool global_pooling = false, bool adaptive = false)
args : (Tensor x, int[] kernel_size, int[] strides = {1, 1, 1}, int[] paddings = {0, 0, 0}, bool global_pooling = false, bool adaptive = false, bool fractional = false, float random_u = 0.0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

output : Tensor(out), Tensor(mask)
infer_meta :
func : MaxPoolWithIndexInferMeta
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,8 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
const std::vector<int>& paddings,
bool global_pooling,
bool adaptive,
bool fractional,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

infermeta也需要新增一个复用之前的即可

float random_u,
MetaTensor* dx) {
dx->share_meta(x);
}
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,8 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
const std::vector<int>& paddings,
bool global_pooling,
bool adaptive,
bool fractional,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

float random_u,
MetaTensor* dx);

void MeshgridGradInferMeta(const std::vector<const MetaTensor*>& inputs,
Expand Down
4 changes: 3 additions & 1 deletion paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2263,6 +2263,8 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x,
const std::vector<int>& paddings,
bool global_pooling,
bool adaptive,
bool fractional,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

float random_u,
MetaTensor* out,
MetaTensor* mask,
MetaConfig config) {
Expand Down Expand Up @@ -2309,7 +2311,7 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x,
kernel_size_.size()));

std::vector<int64_t> output_shape({x_dims[0], x_dims[1]});
if (adaptive) {
if (adaptive || fractional) {
output_shape.insert(
output_shape.end(), kernel_size_.begin(), kernel_size_.end());
} else {
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,8 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x,
const std::vector<int>& paddings,
bool global_pooling,
bool adaptive,
bool fractional,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

float random_u,
MetaTensor* out,
MetaTensor* mask,
MetaConfig config = MetaConfig());
Expand Down
75 changes: 75 additions & 0 deletions paddle/phi/kernels/funcs/pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1571,6 +1571,8 @@ class MaxPool2dWithIndexFunctor<CPUContext, T1, T2> {
const std::vector<int>& strides,
const std::vector<int>& paddings,
bool adaptive,
bool fractional,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

直接修改原kernel定义也会有不兼容升级的问题,可以新增一个kernel来复用当前kernel

float random_u,
DenseTensor* output,
DenseTensor* mask) {
const int batch_size = static_cast<int>(input.dims()[0]);
Expand All @@ -1592,6 +1594,26 @@ class MaxPool2dWithIndexFunctor<CPUContext, T1, T2> {
T1* output_data = context.template Alloc<T1>(output);
T2* mask_data = context.template Alloc<T2>(mask);

float alpha_height = 0, alpha_width = 0;
float u_height = 0, u_width = 0;
if (fractional) {
float u = 0;
if (random_u == 0) {
std::uniform_real_distribution<float> dist(0, 1);
auto engine = phi::GetCPURandomEngine(0);
u = dist(*engine);
} else {
u = random_u;
}

alpha_height = static_cast<float>(input_height) / output_height;
alpha_width = static_cast<float>(input_width) / output_width;

u_height =
FractionalRationalU(u, alpha_height, input_height, output_height);
u_width = FractionalRationalU(u, alpha_width, input_width, output_width);
}

int hstart = 0, hend = 0;
int wstart = 0, wend = 0;
for (int i = 0; i < batch_size; i++) {
Expand All @@ -1600,6 +1622,11 @@ class MaxPool2dWithIndexFunctor<CPUContext, T1, T2> {
if (adaptive) {
hstart = AdaptStartIndex(ph, input_height, output_height);
hend = AdaptEndIndex(ph, input_height, output_height);
} else if (fractional) {
hstart = FractionalStartIndex(ph, alpha_height, u_height);
hend = FractionalEndIndex(ph, alpha_height, u_height);
hstart = std::max(hstart, 0);
hend = std::min(hend, input_height);
} else {
hstart = ph * stride_height - padding_height;
hend = std::min(hstart + ksize_height, input_height);
Expand All @@ -1609,6 +1636,11 @@ class MaxPool2dWithIndexFunctor<CPUContext, T1, T2> {
if (adaptive) {
wstart = AdaptStartIndex(pw, input_width, output_width);
wend = AdaptEndIndex(pw, input_width, output_width);
} else if (fractional) {
wstart = FractionalStartIndex(pw, alpha_width, u_width);
wend = FractionalEndIndex(pw, alpha_width, u_width);
wstart = std::max(wstart, 0);
wend = std::min(wend, input_width);
} else {
wstart = pw * stride_width - padding_width;
wend = std::min(wstart + ksize_width, input_width);
Expand Down Expand Up @@ -1653,6 +1685,8 @@ class MaxPool2dWithIndexGradFunctor<CPUContext, T1, T2> {
const std::vector<int>& strides UNUSED,
const std::vector<int>& paddings UNUSED,
bool adaptive UNUSED,
bool fractional UNUSED,
float random_u UNUSED,
DenseTensor* input_grad) {
const int batch_size = static_cast<int>(input_grad->dims()[0]);
const int input_height = static_cast<int>(input_grad->dims()[2]);
Expand Down Expand Up @@ -1704,6 +1738,8 @@ class MaxPool3dWithIndexFunctor<CPUContext, T1, T2> {
const std::vector<int>& strides,
const std::vector<int>& paddings,
bool adaptive,
bool fractional,
float random_u,
DenseTensor* output,
DenseTensor* mask) {
const int batch_size = static_cast<int>(input.dims()[0]);
Expand All @@ -1730,6 +1766,28 @@ class MaxPool3dWithIndexFunctor<CPUContext, T1, T2> {
T1* output_data = context.template Alloc<T1>(output);
T2* mask_data = context.template Alloc<T2>(mask);

float alpha_height = 0, alpha_width = 0, alpha_depth = 0;
float u_height = 0, u_width = 0, u_depth = 0;
if (fractional) {
float u = 0;
if (random_u == 0) {
std::uniform_real_distribution<float> dist(0, 1);
auto engine = phi::GetCPURandomEngine(0);
u = dist(*engine);
} else {
u = random_u;
}

alpha_depth = static_cast<float>(input_depth) / output_depth;
alpha_height = static_cast<float>(input_height) / output_height;
alpha_width = static_cast<float>(input_width) / output_width;

u_depth = FractionalRationalU(u, alpha_depth, input_depth, output_depth);
u_height =
FractionalRationalU(u, alpha_height, input_height, output_height);
u_width = FractionalRationalU(u, alpha_width, input_width, output_width);
}

int dstart = 0, dend = 0;
int hstart = 0, hend = 0;
int wstart = 0, wend = 0;
Expand All @@ -1739,6 +1797,11 @@ class MaxPool3dWithIndexFunctor<CPUContext, T1, T2> {
if (adaptive) {
dstart = AdaptStartIndex(pd, input_depth, output_depth);
dend = AdaptEndIndex(pd, input_depth, output_depth);
} else if (fractional) {
dstart = FractionalStartIndex(pd, alpha_depth, u_depth);
dend = FractionalEndIndex(pd, alpha_depth, u_depth);
dstart = std::max(dstart, 0);
dend = std::min(dend, input_depth);
} else {
dstart = pd * stride_depth - padding_depth;
dend = std::min(dstart + ksize_depth, input_depth);
Expand All @@ -1748,6 +1811,11 @@ class MaxPool3dWithIndexFunctor<CPUContext, T1, T2> {
if (adaptive) {
hstart = AdaptStartIndex(ph, input_height, output_height);
hend = AdaptEndIndex(ph, input_height, output_height);
} else if (fractional) {
hstart = FractionalStartIndex(ph, alpha_height, u_height);
hend = FractionalEndIndex(ph, alpha_height, u_height);
hstart = std::max(hstart, 0);
hend = std::min(hend, input_height);
} else {
hstart = ph * stride_height - padding_height;
hend = std::min(hstart + ksize_height, input_height);
Expand All @@ -1757,6 +1825,11 @@ class MaxPool3dWithIndexFunctor<CPUContext, T1, T2> {
if (adaptive) {
wstart = AdaptStartIndex(pw, input_width, output_width);
wend = AdaptEndIndex(pw, input_width, output_width);
} else if (fractional) {
wstart = FractionalStartIndex(pw, alpha_width, u_width);
wend = FractionalEndIndex(pw, alpha_width, u_width);
wstart = std::max(wstart, 0);
wend = std::min(wend, input_width);
} else {
wstart = pw * stride_width - padding_width;
wend = std::min(wstart + ksize_width, input_width);
Expand Down Expand Up @@ -1806,6 +1879,8 @@ class MaxPool3dWithIndexGradFunctor<CPUContext, T1, T2> {
const std::vector<int>& strides UNUSED,
const std::vector<int>& paddings UNUSED,
bool adaptive UNUSED,
bool fractional UNUSED,
float random_u UNUSED,
DenseTensor* input_grad) {
const int batch_size = static_cast<int>(input_grad->dims()[0]);
const int input_depth = static_cast<int>(input_grad->dims()[2]);
Expand Down
Loading