Skip to content

Commit

Permalink
[PHI] move roi_pool kernel to phi (PaddlePaddle#40574)
Browse files Browse the repository at this point in the history
* move roi_pool forward kernel to phi

* move roi_pool_grad to phi

* fix compile bug

* fix compile bug

* fix register data_type
  • Loading branch information
zyfncg authored and liqitong-a committed Mar 17, 2022
1 parent 596652f commit 3d78c18
Show file tree
Hide file tree
Showing 11 changed files with 766 additions and 569 deletions.
16 changes: 4 additions & 12 deletions paddle/fluid/operators/roi_pool_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/roi_pool_op.h"
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/kernels/roi_pool_kernel.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -57,7 +58,7 @@ class ROIPoolOp : public framework::OperatorWithKernel {
"%d-dimensional LoDTensor",
rois_dims.size()));
PADDLE_ENFORCE_EQ(
rois_dims[1], kROISize,
rois_dims[1], phi::kROISize,
platform::errors::InvalidArgument(
"ROIs should be a 2-D LoDTensor with shape (num_rois, 4)"
"given as [[x1, y1, x2, y2], ...]. But the second dimension of "
Expand Down Expand Up @@ -216,16 +217,7 @@ REGISTER_OPERATOR(roi_pool, ops::ROIPoolOp, ops::ROIPoolOpMaker,
ops::ROIPoolGradMaker<paddle::framework::OpDesc>,
ops::ROIPoolGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(roi_pool_grad, ops::ROIPoolGradOp);
REGISTER_OP_CPU_KERNEL(
roi_pool,
ops::CPUROIPoolOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUROIPoolOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::CPUROIPoolOpKernel<paddle::platform::CPUDeviceContext, int>);
REGISTER_OP_CPU_KERNEL(
roi_pool_grad,
ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, int>);

REGISTER_OP_VERSION(roi_pool)
.AddCheckpoint(
R"ROC(
Expand Down
306 changes: 0 additions & 306 deletions paddle/fluid/operators/roi_pool_op.cu

This file was deleted.

Loading

0 comments on commit 3d78c18

Please sign in to comment.