Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix pool max with index.(Mask type should be int, not float) #5749

Merged
merged 3 commits into from
Nov 20, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 30 additions & 30 deletions paddle/operators/math/pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -498,8 +498,8 @@ template class Pool3dGradFunctor<
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <typename T>
class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
template <typename T1, typename T2>
class MaxPool2dWithIndexFunctor<platform::CPUPlace, T1, T2> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, std::vector<int>& ksize,
Expand All @@ -520,9 +520,9 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
const int input_stride = input_height * input_width;
const int output_stride = output_height * output_width;

const T* input_data = input.data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());
T* mask_data = mask->mutable_data<T>(context.GetPlace());
const T1* input_data = input.data<T1>();
T1* output_data = output->mutable_data<T1>(context.GetPlace());
T2* mask_data = mask->mutable_data<T2>(context.GetPlace());

for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
Expand All @@ -535,7 +535,7 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);

T ele = static_cast<T>(-FLT_MAX);
T1 ele = static_cast<T1>(-FLT_MAX);
int index = -1;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
Expand Down Expand Up @@ -563,8 +563,8 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <typename T>
class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
template <typename T1, typename T2>
class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T1, T2> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& output_grad,
Expand All @@ -580,9 +580,9 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
const int input_stride = input_height * input_width;
const int output_stride = output_height * output_width;

const T* mask_data = mask.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
const T2* mask_data = mask.data<T2>();
const T1* output_grad_data = output_grad.data<T1>();
T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());

for (int n = 0; n < batch_size; ++n) {
for (int c = 0; c < output_channels; ++c) {
Expand All @@ -602,18 +602,18 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
}
};

template class MaxPool2dWithIndexFunctor<platform::CPUPlace, float>;
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, float>;
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, double>;
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, double>;
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, float, int>;
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, float, int>;
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, double, int>;
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, double, int>;

/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <typename T>
class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
template <typename T1, typename T2>
class MaxPool3dWithIndexFunctor<platform::CPUPlace, T1, T2> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, std::vector<int>& ksize,
Expand All @@ -639,9 +639,9 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
const int input_stride = input_depth * input_height * input_width;
const int output_stride = output_depth * output_height * output_width;

const T* input_data = input.data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());
T* mask_data = mask->mutable_data<T>(context.GetPlace());
const T1* input_data = input.data<T1>();
T1* output_data = output->mutable_data<T1>(context.GetPlace());
T2* mask_data = mask->mutable_data<T2>(context.GetPlace());

for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
Expand All @@ -659,7 +659,7 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
wstart = std::max(wstart, 0);

int output_idx = (pd * output_height + ph) * output_width + pw;
T ele = static_cast<T>(-FLT_MAX);
T1 ele = static_cast<T1>(-FLT_MAX);
int index = -1;
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
Expand Down Expand Up @@ -691,8 +691,8 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <typename T>
class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
template <typename T1, typename T2>
class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T1, T2> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& output_grad,
Expand All @@ -710,9 +710,9 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
const int input_stride = input_depth * input_height * input_width;
const int output_stride = output_depth * output_height * output_width;

const T* mask_data = mask.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
const T2* mask_data = mask.data<T2>();
const T1* output_grad_data = output_grad.data<T1>();
T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());

for (int n = 0; n < batch_size; ++n) {
for (int c = 0; c < output_channels; ++c) {
Expand All @@ -735,10 +735,10 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
}
};

template class MaxPool3dWithIndexFunctor<platform::CPUPlace, float>;
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, float>;
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, double>;
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, double>;
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, float, int>;
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, float, int>;
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, double, int>;
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, double, int>;
} // namespace math
} // namespace operators
} // namespace paddle
Loading