Skip to content

Commit

Permalink
[PHI] Migrate where_index op (#40255)
Browse files Browse the repository at this point in the history
* [PHI] Migrate where_index op

* [PHI] Fix where_index infermate

* [Phi] set where_index out data type
  • Loading branch information
xiaolao authored Mar 10, 2022
1 parent 2747de2 commit 857069f
Show file tree
Hide file tree
Showing 11 changed files with 399 additions and 352 deletions.
30 changes: 11 additions & 19 deletions paddle/fluid/operators/where_index_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/where_index_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {
Expand All @@ -21,16 +24,6 @@ class WhereIndexOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Condition"), "Input", "Condition", "where");
PADDLE_ENFORCE_GE(
ctx->GetInputDim("Condition").size(), 1UL,
platform::errors::InvalidArgument(
"Input(Condition) should have number of dimension at least 1"));
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "where");
ctx->SetOutputDim("Out", {-1, ctx->GetInputDim("Condition").size()});
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
Expand All @@ -53,11 +46,10 @@ class WhereIndexOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(where_index, ops::WhereIndexOp,
ops::WhereIndexOpMaker);
REGISTER_OP_CPU_KERNEL(where_index, ops::CPUWhereIndexKernel<int64_t>,
ops::CPUWhereIndexKernel<int>,
ops::CPUWhereIndexKernel<int16_t>,
ops::CPUWhereIndexKernel<bool>,
ops::CPUWhereIndexKernel<float>,
ops::CPUWhereIndexKernel<double>);
DECLARE_INFER_SHAPE_FUNCTOR(where_index, WhereIndexInferShapeFunctor,
PD_INFER_META(phi::WhereIndexInferMeta));
REGISTER_OPERATOR(
where_index, ops::WhereIndexOp, ops::WhereIndexOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
WhereIndexInferShapeFunctor);
164 changes: 0 additions & 164 deletions paddle/fluid/operators/where_index_op.cu

This file was deleted.

95 changes: 0 additions & 95 deletions paddle/fluid/operators/where_index_op.h

This file was deleted.

5 changes: 4 additions & 1 deletion paddle/fluid/operators/where_index_op_npu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/where_index_op.h"
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
namespace operators {
Expand Down
73 changes: 0 additions & 73 deletions paddle/fluid/operators/where_index_op_xpu.cc

This file was deleted.

Loading

0 comments on commit 857069f

Please sign in to comment.