diff --git a/paddle/fluid/operators/roi_pool_op.cc b/paddle/fluid/operators/roi_pool_op.cc index a512e7dcd682b..9fd66590cb729 100644 --- a/paddle/fluid/operators/roi_pool_op.cc +++ b/paddle/fluid/operators/roi_pool_op.cc @@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/roi_pool_op.h" #include +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/kernels/roi_pool_kernel.h" namespace paddle { namespace operators { @@ -57,7 +58,7 @@ class ROIPoolOp : public framework::OperatorWithKernel { "%d-dimensional LoDTensor", rois_dims.size())); PADDLE_ENFORCE_EQ( - rois_dims[1], kROISize, + rois_dims[1], phi::kROISize, platform::errors::InvalidArgument( "ROIs should be a 2-D LoDTensor with shape (num_rois, 4)" "given as [[x1, y1, x2, y2], ...]. But the second dimension of " @@ -216,16 +217,7 @@ REGISTER_OPERATOR(roi_pool, ops::ROIPoolOp, ops::ROIPoolOpMaker, ops::ROIPoolGradMaker, ops::ROIPoolGradMaker); REGISTER_OPERATOR(roi_pool_grad, ops::ROIPoolGradOp); -REGISTER_OP_CPU_KERNEL( - roi_pool, - ops::CPUROIPoolOpKernel, - ops::CPUROIPoolOpKernel, - ops::CPUROIPoolOpKernel); -REGISTER_OP_CPU_KERNEL( - roi_pool_grad, - ops::CPUROIPoolGradOpKernel, - ops::CPUROIPoolGradOpKernel, - ops::CPUROIPoolGradOpKernel); + REGISTER_OP_VERSION(roi_pool) .AddCheckpoint( R"ROC( diff --git a/paddle/fluid/operators/roi_pool_op.cu b/paddle/fluid/operators/roi_pool_op.cu deleted file mode 100644 index b907b1114bbc0..0000000000000 --- a/paddle/fluid/operators/roi_pool_op.cu +++ /dev/null @@ -1,306 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#include -#include "paddle/fluid/memory/memory.h" -#include "paddle/fluid/operators/roi_pool_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; - -static constexpr int kNumCUDAThreads = 512; -static constexpr int kNumMaxinumNumBlocks = 4096; - -static inline int NumBlocks(const int N) { - return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, - kNumMaxinumNumBlocks); -} - -template -__global__ void GPUROIPoolForward( - const int nthreads, const T* input_data, const T* input_rois, - const float spatial_scale, const int channels, const int height, - const int width, const int pooled_height, const int pooled_width, - int* roi_batch_id_data, T* output_data, int64_t* argmax_data) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int offset = blockDim.x * gridDim.x; - for (size_t i = index; i < nthreads; i += offset) { - int pw = i % pooled_width; - int ph = (i / pooled_width) % pooled_height; - int c = (i / pooled_width / pooled_height) % channels; - int n = i / pooled_width / pooled_height / channels; - - const T* offset_input_rois = input_rois + n * kROISize; - int roi_batch_ind = roi_batch_id_data[n]; - int roi_start_w = round(offset_input_rois[0] * spatial_scale); - int roi_start_h = round(offset_input_rois[1] * spatial_scale); - int roi_end_w = round(offset_input_rois[2] * spatial_scale); - int roi_end_h = round(offset_input_rois[3] * spatial_scale); - - int roi_width = max(roi_end_w - roi_start_w + 1, 1); - int roi_height = max(roi_end_h - roi_start_h + 1, 1); - - int hstart = static_cast(floor(static_cast(ph) * - static_cast(roi_height) / - static_cast(pooled_height))); - int wstart = static_cast(floor(static_cast(pw) * - static_cast(roi_width) / - static_cast(pooled_width))); - int hend = static_cast(ceil(static_cast(ph + 1) * - static_cast(roi_height) / - static_cast(pooled_height))); - int wend = static_cast(ceil(static_cast(pw + 1) * - static_cast(roi_width) / - static_cast(pooled_width))); - hstart = min(max(hstart + roi_start_h, 0), height); - hend = min(max(hend + roi_start_h, 0), height); - wstart = min(max(wstart + roi_start_w, 0), width); - wend = min(max(wend + roi_start_w, 0), width); - bool is_empty = (hend <= hstart) || (wend <= wstart); - - T maxval = is_empty ? 0 : -std::numeric_limits::max(); - int maxidx = -1; - const T* offset_input_data = - input_data + (roi_batch_ind * channels + c) * height * width; - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - int input_data_index = h * width + w; - if (offset_input_data[input_data_index] > maxval) { - maxval = offset_input_data[input_data_index]; - maxidx = input_data_index; - } - } - } - output_data[i] = maxval; - if (argmax_data) { - argmax_data[i] = maxidx; - } - } -} - -template -__global__ void GPUROIPoolBackward( - const int nthreads, const T* input_rois, const T* output_grad, - const int64_t* argmax_data, const int num_rois, const float spatial_scale, - const int channels, const int height, const int width, - const int pooled_height, const int pooled_width, int* roi_batch_id_data, - T* input_grad) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int offset = blockDim.x * gridDim.x; - for (int i = index; i < nthreads; i += offset) { - int pw = i % pooled_width; - int ph = (i / pooled_width) % pooled_height; - int c = (i / pooled_width / pooled_height) % channels; - int n = i / pooled_width / pooled_height / channels; - - int roi_batch_ind = roi_batch_id_data[n]; - int input_offset = (roi_batch_ind * channels + c) * height * width; - int output_offset = (n * channels + c) * pooled_height * pooled_width; - const T* offset_output_grad = output_grad + output_offset; - T* offset_input_grad = input_grad + input_offset; - const int64_t* offset_argmax_data = argmax_data + output_offset; - - int argmax = offset_argmax_data[ph * pooled_width + pw]; - if (argmax != -1) { - platform::CudaAtomicAdd( - offset_input_grad + argmax, - static_cast(offset_output_grad[ph * pooled_width + pw])); - } - } -} - -template -class GPUROIPoolOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* rois = ctx.Input("ROIs"); - auto* out = ctx.Output("Out"); - auto* argmax = ctx.Output("Argmax"); - - auto pooled_height = ctx.Attr("pooled_height"); - auto pooled_width = ctx.Attr("pooled_width"); - auto spatial_scale = ctx.Attr("spatial_scale"); - - auto in_dims = in->dims(); - int batch_size = in_dims[0]; - auto in_stride = phi::stride(in_dims); - int channels = in_dims[1]; - int height = in_dims[2]; - int width = in_dims[3]; - - int rois_num = rois->dims()[0]; - - if (rois_num == 0) return; - - int output_size = out->numel(); - int blocks = NumBlocks(output_size); - int threads = kNumCUDAThreads; - - framework::Tensor roi_batch_id_list; - roi_batch_id_list.Resize({rois_num}); - auto cplace = platform::CPUPlace(); - int* roi_batch_id_data = roi_batch_id_list.mutable_data(cplace); - auto& dev_ctx = ctx.cuda_device_context(); - auto gplace = ctx.GetPlace(); - if (ctx.HasInput("RoisNum")) { - auto* rois_num_t = ctx.Input("RoisNum"); - int rois_batch_size = rois_num_t->numel(); - - PADDLE_ENFORCE_EQ( - rois_batch_size, batch_size, - platform::errors::InvalidArgument( - "The batch size of input(ROIs) and input(X) must be the same but " - "received batch size of input(ROIs) and input(X) is %d and %d " - "respectively.", - rois_batch_size, batch_size)); - std::vector rois_num_list(rois_batch_size); - memory::Copy(cplace, rois_num_list.data(), gplace, - rois_num_t->data(), sizeof(int) * rois_batch_size, 0); - int start = 0; - for (int n = 0; n < rois_batch_size; ++n) { - for (int i = start; i < start + rois_num_list[n]; ++i) { - roi_batch_id_data[i] = n; - } - start += rois_num_list[n]; - } - } else { - auto rois_lod = rois->lod().back(); - int rois_batch_size = rois_lod.size() - 1; - PADDLE_ENFORCE_EQ( - rois_batch_size, batch_size, - platform::errors::InvalidArgument( - "The batch size of input(ROIs) and input(X) must be the same but " - "received batch size of input(ROIs) and input(X) is %d and %d " - "respectively.", - rois_batch_size, batch_size)); - - int rois_num_with_lod = rois_lod[rois_batch_size]; - PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod, - platform::errors::InvalidArgument( - "The number of rois from input(ROIs) and its LOD " - "must be the same. Received rois %d of input(ROIs) " - "but the number of rois %d from its LOD is %d", - rois_num, rois_num_with_lod)); - for (int n = 0; n < rois_batch_size; ++n) { - for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { - roi_batch_id_data[i] = n; - } - } - } - int bytes = roi_batch_id_list.numel() * sizeof(int); - auto roi_ptr = memory::Alloc(dev_ctx, bytes); - int* roi_id_data = reinterpret_cast(roi_ptr->ptr()); - memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes, - dev_ctx.stream()); - - GPUROIPoolForward<<>>( - output_size, in->data(), rois->data(), spatial_scale, channels, - height, width, pooled_height, pooled_width, roi_id_data, - out->mutable_data(ctx.GetPlace()), - argmax->mutable_data(ctx.GetPlace())); - } -}; - -template -class GPUROIPoolGradOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* rois = ctx.Input("ROIs"); - auto* rois_lod = ctx.Input("RoisNum"); - auto* argmax = ctx.Input("Argmax"); - - auto* out_grad = ctx.Input(framework::GradVarName("Out")); - auto* x_grad = ctx.Output(framework::GradVarName("X")); - - auto pooled_height = ctx.Attr("pooled_height"); - auto pooled_width = ctx.Attr("pooled_width"); - auto spatial_scale = ctx.Attr("spatial_scale"); - - int rois_num = rois->dims()[0]; - int channels = in->dims()[1]; - int height = in->dims()[2]; - int width = in->dims()[3]; - - if (x_grad) { - framework::Tensor roi_batch_id_list; - roi_batch_id_list.Resize({rois_num}); - auto cplace = platform::CPUPlace(); - int* roi_batch_id_data = roi_batch_id_list.mutable_data(cplace); - - auto& dev_ctx = ctx.cuda_device_context(); - auto gplace = ctx.GetPlace(); - if (ctx.HasInput("RoisNum")) { - auto* rois_num_t = ctx.Input("RoisNum"); - int rois_batch_size = rois_num_t->numel(); - std::vector rois_num_list(rois_batch_size); - memory::Copy(cplace, rois_num_list.data(), gplace, - rois_num_t->data(), sizeof(int) * rois_batch_size, 0); - int start = 0; - for (int n = 0; n < rois_batch_size; ++n) { - for (int i = start; i < start + rois_num_list[n]; ++i) { - roi_batch_id_data[i] = n; - } - start += rois_num_list[n]; - } - } else { - auto rois_lod = rois->lod().back(); - int rois_batch_size = rois_lod.size() - 1; - for (int n = 0; n < rois_batch_size; ++n) { - for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { - roi_batch_id_data[i] = n; - } - } - } - int bytes = roi_batch_id_list.numel() * sizeof(int); - auto roi_ptr = memory::Alloc(dev_ctx, bytes); - int* roi_id_data = reinterpret_cast(roi_ptr->ptr()); - memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes, - dev_ctx.stream()); - - x_grad->mutable_data(ctx.GetPlace()); - phi::funcs::SetConstant set_zero; - set_zero(dev_ctx, x_grad, static_cast(0)); - - int output_grad_size = out_grad->numel(); - int blocks = NumBlocks(output_grad_size); - int threads = kNumCUDAThreads; - - if (output_grad_size > 0) { - GPUROIPoolBackward<<>>( - output_grad_size, rois->data(), out_grad->data(), - argmax->data(), rois_num, spatial_scale, channels, height, - width, pooled_height, pooled_width, roi_id_data, - x_grad->mutable_data(ctx.GetPlace())); - } - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - roi_pool, - ops::GPUROIPoolOpKernel, - ops::GPUROIPoolOpKernel); -REGISTER_OP_CUDA_KERNEL( - roi_pool_grad, - ops::GPUROIPoolGradOpKernel, - ops::GPUROIPoolGradOpKernel); diff --git a/paddle/fluid/operators/roi_pool_op.h b/paddle/fluid/operators/roi_pool_op.h deleted file mode 100644 index a104fd49eb3e0..0000000000000 --- a/paddle/fluid/operators/roi_pool_op.h +++ /dev/null @@ -1,250 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/memory/memcpy.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -static constexpr int kROISize = 4; - -template -class CPUROIPoolOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* rois = ctx.Input("ROIs"); - auto* out = ctx.Output("Out"); - auto* argmax = ctx.Output("Argmax"); - - auto pooled_height = ctx.Attr("pooled_height"); - auto pooled_width = ctx.Attr("pooled_width"); - auto spatial_scale = ctx.Attr("spatial_scale"); - - auto in_dims = in->dims(); - int batch_size = in_dims[0]; - int channels = in_dims[1]; - int height = in_dims[2]; - int width = in_dims[3]; - int rois_num = rois->dims()[0]; - - auto in_stride = phi::stride(in_dims); - auto argmax_stride = phi::stride(argmax->dims()); - auto roi_stride = phi::stride(rois->dims()); - auto out_stride = phi::stride(out->dims()); - - const T* input_data = in->data(); - - framework::Tensor roi_batch_id_list; - roi_batch_id_list.Resize({rois_num}); - int* roi_batch_id_data = - roi_batch_id_list.mutable_data(ctx.GetPlace()); - - int rois_batch_size; - if (ctx.HasInput("RoisNum")) { - auto* rois_num_t = ctx.Input("RoisNum"); - rois_batch_size = rois_num_t->numel(); - PADDLE_ENFORCE_EQ( - rois_batch_size, batch_size, - platform::errors::InvalidArgument("The rois_batch_size and imgs " - "batch_size must be the same.")); - auto* rois_num_data = rois_num_t->data(); - int start = 0; - for (int n = 0; n < rois_batch_size; ++n) { - for (int i = start; i < start + rois_num_data[n]; ++i) { - roi_batch_id_data[i] = n; - } - start += rois_num_data[n]; - } - } else { - auto rois_lod = rois->lod().back(); - rois_batch_size = rois_lod.size() - 1; - PADDLE_ENFORCE_EQ( - rois_batch_size, batch_size, - platform::errors::InvalidArgument("The rois_batch_size and imgs " - "batch_size must be the same.")); - int rois_num_with_lod = rois_lod[rois_batch_size]; - PADDLE_ENFORCE_EQ( - rois_num, rois_num_with_lod, - platform::errors::InvalidArgument("The rois_num from input " - "and lod must be the same.")); - for (int n = 0; n < rois_batch_size; ++n) { - for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { - roi_batch_id_data[i] = n; - } - } - } - - T* output_data = out->mutable_data(ctx.GetPlace()); - int64_t* argmax_data = argmax->mutable_data(ctx.GetPlace()); - - const T* rois_data = rois->data(); - for (int n = 0; n < rois_num; ++n) { - int roi_batch_id = roi_batch_id_data[n]; - int roi_start_w = round(rois_data[0] * spatial_scale); - int roi_start_h = round(rois_data[1] * spatial_scale); - int roi_end_w = round(rois_data[2] * spatial_scale); - int roi_end_h = round(rois_data[3] * spatial_scale); - - // Force malformed ROIs to be 1x1 - int roi_height = std::max(roi_end_h - roi_start_h + 1, 1); - int roi_width = std::max(roi_end_w - roi_start_w + 1, 1); - - const float bin_size_h = - static_cast(roi_height) / static_cast(pooled_height); - const float bin_size_w = - static_cast(roi_width) / static_cast(pooled_width); - - const T* batch_data = input_data + roi_batch_id * in_stride[0]; - - for (int c = 0; c < channels; ++c) { - for (int ph = 0; ph < pooled_height; ++ph) { - for (int pw = 0; pw < pooled_width; ++pw) { - // Compute pooling region for this output unit: - // start (included) = floor(ph * roi_height / pooled_height_) - // end (excluded) = ceil((ph + 1) * roi_height / pooled_height_) - int hstart = - static_cast(floor(static_cast(ph) * bin_size_h)); - int wstart = - static_cast(floor(static_cast(pw) * bin_size_w)); - int hend = - static_cast(ceil(static_cast(ph + 1) * bin_size_h)); - int wend = - static_cast(ceil(static_cast(pw + 1) * bin_size_w)); - - hstart = std::min(std::max(hstart + roi_start_h, 0), height); - hend = std::min(std::max(hend + roi_start_h, 0), height); - wstart = std::min(std::max(wstart + roi_start_w, 0), width); - wend = std::min(std::max(wend + roi_start_w, 0), width); - - const int pool_index = ph * pooled_width + pw; - - // Define an empty pooling region to be zero - bool is_empty = (hend <= hstart) || (wend <= wstart); - output_data[pool_index] = - is_empty ? 0 : -std::numeric_limits::max(); - argmax_data[pool_index] = -1; - - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - const int index = h * width + w; - if (batch_data[index] > output_data[pool_index]) { - output_data[pool_index] = batch_data[index]; - argmax_data[pool_index] = index; - } - } - } - } - } - - batch_data += in_stride[1]; - output_data += out_stride[1]; - argmax_data += argmax_stride[1]; - } - // Increment ROI data pointer - rois_data += roi_stride[0]; - } - return; - } -}; - -template -class CPUROIPoolGradOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto* rois = ctx.Input("ROIs"); - auto* argmax = ctx.Input("Argmax"); - auto* out_grad = - ctx.Input(framework::GradVarName("Out")); - auto* in_grad = ctx.Output(framework::GradVarName("X")); - - auto pooled_height = ctx.Attr("pooled_height"); - auto pooled_width = ctx.Attr("pooled_width"); - - if (in_grad) { - int rois_num = rois->dims()[0]; - framework::Tensor roi_batch_id_list; - roi_batch_id_list.Resize({rois_num}); - int* roi_batch_id_data = - roi_batch_id_list.mutable_data(ctx.GetPlace()); - - int rois_batch_size; - if (ctx.HasInput("RoisNum")) { - auto* rois_num_t = ctx.Input("RoisNum"); - rois_batch_size = rois_num_t->numel(); - auto* rois_num_data = rois_num_t->data(); - int start = 0; - for (int n = 0; n < rois_batch_size; ++n) { - for (int i = start; i < start + rois_num_data[n]; ++i) { - roi_batch_id_data[i] = n; - } - start += rois_num_data[n]; - } - } else { - auto rois_lod = rois->lod().back(); - rois_batch_size = rois_lod.size() - 1; - for (int n = 0; n < rois_batch_size; ++n) { - for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { - roi_batch_id_data[i] = n; - } - } - } - - const T* rois_data = rois->data(); - const T* out_grad_data = out_grad->data(); - const int64_t* argmax_data = argmax->data(); - T* in_grad_data = in_grad->mutable_data(ctx.GetPlace()); - phi::funcs::SetConstant set_zero; - set_zero(ctx.template device_context(), in_grad, - static_cast(0)); - - auto in_stride = phi::stride(in->dims()); - auto argmax_stride = phi::stride(argmax->dims()); - auto roi_stride = phi::stride(rois->dims()); - auto out_stride = phi::stride(out_grad->dims()); - - int channels = in->dims()[1]; - - for (int n = 0; n < rois_num; ++n) { - int roi_batch_idx = roi_batch_id_data[n]; - T* batch_grad_data = in_grad_data + roi_batch_idx * in_stride[0]; - for (int c = 0; c < channels; ++c) { - for (int ph = 0; ph < pooled_height; ++ph) { - for (int pw = 0; pw < pooled_width; ++pw) { - int pool_index = ph * pooled_width + pw; - if (argmax_data[pool_index] >= 0) { - auto index = argmax_data[pool_index]; - batch_grad_data[index] += out_grad_data[pool_index]; - } - } - } - batch_grad_data += in_stride[1]; - out_grad_data += out_stride[1]; - argmax_data += argmax_stride[1]; - } - rois_data += roi_stride[0]; - } - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/kernels/cpu/roi_pool_grad_kernel.cc b/paddle/phi/kernels/cpu/roi_pool_grad_kernel.cc new file mode 100644 index 0000000000000..0eaa873590eb0 --- /dev/null +++ b/paddle/phi/kernels/cpu/roi_pool_grad_kernel.cc @@ -0,0 +1,108 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/roi_pool_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void RoiPoolGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& boxes, + paddle::optional boxes_num, + const DenseTensor& arg_max, + const DenseTensor& out_grad, + int pooled_height, + int pooled_width, + float spatial_scale, + DenseTensor* dx) { + if (dx) { + int rois_num = boxes.dims()[0]; + DenseTensor box_batch_id_list = Empty(dev_ctx, {rois_num}); + int* box_batch_id_data = box_batch_id_list.data(); + + int boxes_batch_size; + if (boxes_num) { + boxes_batch_size = boxes_num->numel(); + auto* boxes_num_data = boxes_num->data(); + int start = 0; + for (int n = 0; n < boxes_batch_size; ++n) { + for (int i = start; i < start + boxes_num_data[n]; ++i) { + box_batch_id_data[i] = n; + } + start += boxes_num_data[n]; + } + } else { + auto boxes_lod = boxes.lod().back(); + boxes_batch_size = boxes_lod.size() - 1; + for (int n = 0; n < boxes_batch_size; ++n) { + for (size_t i = boxes_lod[n]; i < boxes_lod[n + 1]; ++i) { + box_batch_id_data[i] = n; + } + } + } + + const T* boxes_data = boxes.data(); + const T* out_grad_data = out_grad.data(); + const int64_t* arg_max_data = arg_max.data(); + T* dx_data = dev_ctx.template Alloc(dx); + + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, dx, static_cast(0)); + + auto in_stride = phi::stride(x.dims()); + auto arg_max_stride = phi::stride(arg_max.dims()); + auto roi_stride = phi::stride(boxes.dims()); + auto out_stride = phi::stride(out_grad.dims()); + + int channels = x.dims()[1]; + + for (int n = 0; n < rois_num; ++n) { + int roi_batch_idx = box_batch_id_data[n]; + T* batch_grad_data = dx_data + roi_batch_idx * in_stride[0]; + for (int c = 0; c < channels; ++c) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int pool_index = ph * pooled_width + pw; + if (arg_max_data[pool_index] >= 0) { + auto index = arg_max_data[pool_index]; + batch_grad_data[index] += out_grad_data[pool_index]; + } + } + } + batch_grad_data += in_stride[1]; + out_grad_data += out_stride[1]; + arg_max_data += arg_max_stride[1]; + } + boxes_data += roi_stride[0]; + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(roi_pool_grad, + CPU, + ALL_LAYOUT, + phi::RoiPoolGradKernel, + float, + double, + int) { + kernel->InputAt(3).SetDataType(phi::DataType::INT64); +} diff --git a/paddle/phi/kernels/cpu/roi_pool_kernel.cc b/paddle/phi/kernels/cpu/roi_pool_kernel.cc new file mode 100644 index 0000000000000..02020354cd357 --- /dev/null +++ b/paddle/phi/kernels/cpu/roi_pool_kernel.cc @@ -0,0 +1,163 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/roi_pool_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/empty_kernel.h" + +namespace phi { + +template +void RoiPoolKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& boxes, + paddle::optional boxes_num, + int pooled_height, + int pooled_width, + float spatial_scale, + DenseTensor* out, + DenseTensor* arg_max) { + auto x_dims = x.dims(); + int batch_size = x_dims[0]; + int channels = x_dims[1]; + int height = x_dims[2]; + int width = x_dims[3]; + int rois_num = boxes.dims()[0]; + + auto in_stride = phi::stride(x_dims); + auto arg_max_stride = phi::stride(arg_max->dims()); + auto box_stride = phi::stride(boxes.dims()); + auto out_stride = phi::stride(out->dims()); + + const T* input_data = x.data(); + + DenseTensor box_batch_id_list = Empty(dev_ctx, {rois_num}); + int* box_batch_id_data = box_batch_id_list.data(); + + int boxes_batch_size; + if (boxes_num) { + boxes_batch_size = boxes_num->numel(); + PADDLE_ENFORCE_EQ( + boxes_batch_size, + batch_size, + phi::errors::InvalidArgument("The boxes_batch_size and imgs " + "batch_size must be the same.")); + auto* boxes_num_data = boxes_num->data(); + int start = 0; + for (int n = 0; n < boxes_batch_size; ++n) { + for (int i = start; i < start + boxes_num_data[n]; ++i) { + box_batch_id_data[i] = n; + } + start += boxes_num_data[n]; + } + } else { + auto boxes_lod = boxes.lod().back(); + boxes_batch_size = boxes_lod.size() - 1; + PADDLE_ENFORCE_EQ( + boxes_batch_size, + batch_size, + phi::errors::InvalidArgument("The boxes_batch_size and imgs " + "batch_size must be the same.")); + int rois_num_with_lod = boxes_lod[boxes_batch_size]; + PADDLE_ENFORCE_EQ( + rois_num, + rois_num_with_lod, + phi::errors::InvalidArgument("The rois_num from input " + "and lod must be the same.")); + for (int n = 0; n < boxes_batch_size; ++n) { + for (size_t i = boxes_lod[n]; i < boxes_lod[n + 1]; ++i) { + box_batch_id_data[i] = n; + } + } + } + + T* output_data = dev_ctx.template Alloc(out); + int64_t* arg_max_data = dev_ctx.template Alloc(arg_max); + + const T* boxes_data = boxes.data(); + for (int n = 0; n < rois_num; ++n) { + int box_batch_id = box_batch_id_data[n]; + int box_start_w = round(boxes_data[0] * spatial_scale); + int box_start_h = round(boxes_data[1] * spatial_scale); + int box_end_w = round(boxes_data[2] * spatial_scale); + int box_end_h = round(boxes_data[3] * spatial_scale); + + // Force malformed ROIs to be 1x1 + int box_height = std::max(box_end_h - box_start_h + 1, 1); + int box_width = std::max(box_end_w - box_start_w + 1, 1); + + const float bin_size_h = + static_cast(box_height) / static_cast(pooled_height); + const float bin_size_w = + static_cast(box_width) / static_cast(pooled_width); + + const T* batch_data = input_data + box_batch_id * in_stride[0]; + + for (int c = 0; c < channels; ++c) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + // Compute pooling region for this output unit: + // start (included) = floor(ph * box_height / pooled_height_) + // end (excluded) = ceil((ph + 1) * box_height / pooled_height_) + int hstart = + static_cast(floor(static_cast(ph) * bin_size_h)); + int wstart = + static_cast(floor(static_cast(pw) * bin_size_w)); + int hend = + static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + int wend = + static_cast(ceil(static_cast(pw + 1) * bin_size_w)); + + hstart = std::min(std::max(hstart + box_start_h, 0), height); + hend = std::min(std::max(hend + box_start_h, 0), height); + wstart = std::min(std::max(wstart + box_start_w, 0), width); + wend = std::min(std::max(wend + box_start_w, 0), width); + + const int pool_index = ph * pooled_width + pw; + + // Define an empty pooling region to be zero + bool is_empty = (hend <= hstart) || (wend <= wstart); + output_data[pool_index] = + is_empty ? 0 : -std::numeric_limits::max(); + arg_max_data[pool_index] = -1; + + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + const int index = h * width + w; + if (batch_data[index] > output_data[pool_index]) { + output_data[pool_index] = batch_data[index]; + arg_max_data[pool_index] = index; + } + } + } + } + } + + batch_data += in_stride[1]; + output_data += out_stride[1]; + arg_max_data += arg_max_stride[1]; + } + // Increment ROI data pointer + boxes_data += box_stride[0]; + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + roi_pool, CPU, ALL_LAYOUT, phi::RoiPoolKernel, float, double, int) { + kernel->OutputAt(1).SetDataType(phi::DataType::INT64); +} diff --git a/paddle/phi/kernels/gpu/roi_align_kernel.cu b/paddle/phi/kernels/gpu/roi_align_kernel.cu index cd4ed29cdd1dd..cb3375dee95a5 100644 --- a/paddle/phi/kernels/gpu/roi_align_kernel.cu +++ b/paddle/phi/kernels/gpu/roi_align_kernel.cu @@ -18,7 +18,6 @@ #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/empty_kernel.h" #include "paddle/fluid/memory/memory.h" diff --git a/paddle/phi/kernels/gpu/roi_pool_grad_kernel.cu b/paddle/phi/kernels/gpu/roi_pool_grad_kernel.cu new file mode 100644 index 0000000000000..d093a71d23f4e --- /dev/null +++ b/paddle/phi/kernels/gpu/roi_pool_grad_kernel.cu @@ -0,0 +1,165 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/roi_pool_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +#include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" + +namespace phi { + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaxinumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaxinumNumBlocks); +} + +template +__global__ void GPURoiPoolBackward(const int nthreads, + const T* input_rois, + const T* output_grad, + const int64_t* arg_max_data, + const int num_rois, + const float spatial_scale, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + int* box_batch_id_data, + T* input_grad) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (int i = index; i < nthreads; i += offset) { + int pw = i % pooled_width; + int ph = (i / pooled_width) % pooled_height; + int c = (i / pooled_width / pooled_height) % channels; + int n = i / pooled_width / pooled_height / channels; + + int roi_batch_ind = box_batch_id_data[n]; + int input_offset = (roi_batch_ind * channels + c) * height * width; + int output_offset = (n * channels + c) * pooled_height * pooled_width; + const T* offset_output_grad = output_grad + output_offset; + T* offset_input_grad = input_grad + input_offset; + const int64_t* offset_arg_max_data = arg_max_data + output_offset; + + int arg_max = offset_arg_max_data[ph * pooled_width + pw]; + if (arg_max != -1) { + paddle::platform::CudaAtomicAdd( + offset_input_grad + arg_max, + static_cast(offset_output_grad[ph * pooled_width + pw])); + } + } +} + +template +void RoiPoolGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& boxes, + paddle::optional boxes_num, + const DenseTensor& arg_max, + const DenseTensor& out_grad, + int pooled_height, + int pooled_width, + float spatial_scale, + DenseTensor* dx) { + auto x_dims = x.dims(); + int channels = x_dims[1]; + int height = x_dims[2]; + int width = x_dims[3]; + int rois_num = boxes.dims()[0]; + + if (dx) { + DenseTensor box_batch_id_list; + box_batch_id_list.Resize({rois_num}); + int* box_batch_id_data = + dev_ctx.template HostAlloc(&box_batch_id_list); + + auto gplace = dev_ctx.GetPlace(); + if (boxes_num) { + int boxes_batch_size = boxes_num->numel(); + std::vector boxes_num_list(boxes_batch_size); + paddle::memory::Copy(phi::CPUPlace(), + boxes_num_list.data(), + gplace, + boxes_num->data(), + sizeof(int) * boxes_batch_size, + 0); + int start = 0; + for (int n = 0; n < boxes_batch_size; ++n) { + for (int i = start; i < start + boxes_num_list[n]; ++i) { + box_batch_id_data[i] = n; + } + start += boxes_num_list[n]; + } + } else { + auto boxes_lod = boxes.lod().back(); + int boxes_batch_size = boxes_lod.size() - 1; + for (int n = 0; n < boxes_batch_size; ++n) { + for (size_t i = boxes_lod[n]; i < boxes_lod[n + 1]; ++i) { + box_batch_id_data[i] = n; + } + } + } + int bytes = box_batch_id_list.numel() * sizeof(int); + auto roi_ptr = paddle::memory::Alloc(dev_ctx, bytes); + int* roi_id_data = reinterpret_cast(roi_ptr->ptr()); + paddle::memory::Copy(gplace, + roi_id_data, + phi::CPUPlace(), + box_batch_id_data, + bytes, + dev_ctx.stream()); + + dev_ctx.template Alloc(dx); + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, dx, static_cast(0)); + + int output_grad_size = out_grad.numel(); + int blocks = NumBlocks(output_grad_size); + int threads = kNumCUDAThreads; + + if (output_grad_size > 0) { + GPURoiPoolBackward<<>>( + output_grad_size, + boxes.data(), + out_grad.data(), + arg_max.data(), + rois_num, + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + roi_id_data, + dx->data()); + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + roi_pool_grad, GPU, ALL_LAYOUT, phi::RoiPoolGradKernel, float, double) { + kernel->InputAt(3).SetDataType(phi::DataType::INT64); +} diff --git a/paddle/phi/kernels/gpu/roi_pool_kernel.cu b/paddle/phi/kernels/gpu/roi_pool_kernel.cu new file mode 100644 index 0000000000000..ab33e2cf64751 --- /dev/null +++ b/paddle/phi/kernels/gpu/roi_pool_kernel.cu @@ -0,0 +1,220 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/roi_pool_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/fluid/memory/memory.h" + +namespace phi { + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaxinumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaxinumNumBlocks); +} + +template +__global__ void GPURoiPoolForward(const int nthreads, + const T* input_data, + const T* input_rois, + const float spatial_scale, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + int* box_batch_id_data, + T* output_data, + int64_t* arg_max_data) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (size_t i = index; i < nthreads; i += offset) { + int pw = i % pooled_width; + int ph = (i / pooled_width) % pooled_height; + int c = (i / pooled_width / pooled_height) % channels; + int n = i / pooled_width / pooled_height / channels; + + const T* offset_input_rois = input_rois + n * kROISize; + int box_batch_ind = box_batch_id_data[n]; + int box_start_w = round(offset_input_rois[0] * spatial_scale); + int box_start_h = round(offset_input_rois[1] * spatial_scale); + int box_end_w = round(offset_input_rois[2] * spatial_scale); + int box_end_h = round(offset_input_rois[3] * spatial_scale); + + int box_width = max(box_end_w - box_start_w + 1, 1); + int box_height = max(box_end_h - box_start_h + 1, 1); + + int hstart = static_cast(floor(static_cast(ph) * + static_cast(box_height) / + static_cast(pooled_height))); + int wstart = static_cast(floor(static_cast(pw) * + static_cast(box_width) / + static_cast(pooled_width))); + int hend = static_cast(ceil(static_cast(ph + 1) * + static_cast(box_height) / + static_cast(pooled_height))); + int wend = static_cast(ceil(static_cast(pw + 1) * + static_cast(box_width) / + static_cast(pooled_width))); + hstart = min(max(hstart + box_start_h, 0), height); + hend = min(max(hend + box_start_h, 0), height); + wstart = min(max(wstart + box_start_w, 0), width); + wend = min(max(wend + box_start_w, 0), width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + T maxval = is_empty ? 0 : -std::numeric_limits::max(); + int maxidx = -1; + const T* offset_input_data = + input_data + (box_batch_ind * channels + c) * height * width; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int input_data_index = h * width + w; + if (offset_input_data[input_data_index] > maxval) { + maxval = offset_input_data[input_data_index]; + maxidx = input_data_index; + } + } + } + output_data[i] = maxval; + if (arg_max_data) { + arg_max_data[i] = maxidx; + } + } +} + +template +void RoiPoolKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& boxes, + paddle::optional boxes_num, + int pooled_height, + int pooled_width, + float spatial_scale, + DenseTensor* out, + DenseTensor* arg_max) { + auto x_dims = x.dims(); + int batch_size = x_dims[0]; + auto in_stride = phi::stride(x_dims); + int channels = x_dims[1]; + int height = x_dims[2]; + int width = x_dims[3]; + + int rois_num = boxes.dims()[0]; + + if (rois_num == 0) return; + + int output_size = out->numel(); + int blocks = NumBlocks(output_size); + int threads = kNumCUDAThreads; + + DenseTensor box_batch_id_list; + box_batch_id_list.Resize({rois_num}); + int* box_batch_id_data = dev_ctx.template HostAlloc(&box_batch_id_list); + auto gplace = dev_ctx.GetPlace(); + + if (boxes_num) { + int boxes_batch_size = boxes_num->numel(); + PADDLE_ENFORCE_EQ( + boxes_batch_size, + batch_size, + phi::errors::InvalidArgument( + "The batch size of input(ROIs) and input(X) must be the same but " + "received batch size of input(ROIs) and input(X) is %d and %d " + "respectively.", + boxes_batch_size, + batch_size)); + std::vector boxes_num_list(boxes_batch_size); + paddle::memory::Copy(phi::CPUPlace(), + boxes_num_list.data(), + gplace, + boxes_num->data(), + sizeof(int) * boxes_batch_size, + 0); + int start = 0; + for (int n = 0; n < boxes_batch_size; ++n) { + for (int i = start; i < start + boxes_num_list[n]; ++i) { + box_batch_id_data[i] = n; + } + start += boxes_num_list[n]; + } + } else { + auto boxes_lod = boxes.lod().back(); + int boxes_batch_size = boxes_lod.size() - 1; + PADDLE_ENFORCE_EQ( + boxes_batch_size, + batch_size, + phi::errors::InvalidArgument( + "The batch size of input(ROIs) and input(X) must be the same but " + "received batch size of input(ROIs) and input(X) is %d and %d " + "respectively.", + boxes_batch_size, + batch_size)); + + int boxes_num_with_lod = boxes_lod[boxes_batch_size]; + PADDLE_ENFORCE_EQ(rois_num, + boxes_num_with_lod, + phi::errors::InvalidArgument( + "The number of rois from input(ROIs) and its LOD " + "must be the same. Received rois %d of input(ROIs) " + "but the number of rois %d from its LOD is %d", + rois_num, + boxes_num_with_lod)); + for (int n = 0; n < boxes_batch_size; ++n) { + for (size_t i = boxes_lod[n]; i < boxes_lod[n + 1]; ++i) { + box_batch_id_data[i] = n; + } + } + } + + int bytes = box_batch_id_list.numel() * sizeof(int); + auto box_ptr = paddle::memory::Alloc(dev_ctx, bytes); + int* box_id_data = reinterpret_cast(box_ptr->ptr()); + paddle::memory::Copy(gplace, + box_id_data, + phi::CPUPlace(), + box_batch_id_data, + bytes, + dev_ctx.stream()); + + T* output_data = dev_ctx.template Alloc(out); + int64_t* arg_max_data = dev_ctx.template Alloc(arg_max); + + GPURoiPoolForward<<>>( + output_size, + x.data(), + boxes.data(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + box_id_data, + output_data, + arg_max_data); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + roi_pool, GPU, ALL_LAYOUT, phi::RoiPoolKernel, float, double) { + kernel->OutputAt(1).SetDataType(phi::DataType::INT64); +} diff --git a/paddle/phi/kernels/roi_pool_grad_kernel.h b/paddle/phi/kernels/roi_pool_grad_kernel.h new file mode 100644 index 0000000000000..d7f1c378f75c3 --- /dev/null +++ b/paddle/phi/kernels/roi_pool_grad_kernel.h @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/utils/optional.h" + +namespace phi { + +template +void RoiPooGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& boxes, + paddle::optional boxes_num, + const DenseTensor& arg_max, + const DenseTensor& out_grad, + int pooled_height, + int pooled_width, + float spatial_scale, + DenseTensor* dx); + +} // namespace phi diff --git a/paddle/phi/kernels/roi_pool_kernel.h b/paddle/phi/kernels/roi_pool_kernel.h new file mode 100644 index 0000000000000..c6ff6f223612a --- /dev/null +++ b/paddle/phi/kernels/roi_pool_kernel.h @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/utils/optional.h" + +namespace phi { + +static constexpr int kROISize = 4; + +template +void RoiPoolKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& boxes, + paddle::optional boxes_num, + int pooled_height, + int pooled_width, + float spatial_scale, + DenseTensor* out, + DenseTensor* arg_max); + +} // namespace phi diff --git a/paddle/phi/ops/compat/roi_pool_sig.cc b/paddle/phi/ops/compat/roi_pool_sig.cc new file mode 100644 index 0000000000000..d04c645f183c6 --- /dev/null +++ b/paddle/phi/ops/compat/roi_pool_sig.cc @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature RoiPoolOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("roi_pool", + {"X", "ROIs", "RoisNum"}, + {"pooled_height", "pooled_width", "spatial_scale"}, + {"Out", "Argmax"}); +} + +KernelSignature RoiPoolOpGradArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("roi_pool_grad", + {"X", "ROIs", "RoisNum", "Argmax", GradVarName("Out")}, + {"pooled_height", "pooled_width", "spatial_scale"}, + {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(roi_pool, phi::RoiPoolOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(roi_pool_grad, phi::RoiPoolOpGradArgumentMapping);