Skip to content

Commit

Permalink
fix conflict (#40851)
Browse files Browse the repository at this point in the history
  • Loading branch information
csy0225 authored Mar 31, 2022
1 parent e559fe4 commit 74894cd
Show file tree
Hide file tree
Showing 47 changed files with 2,641 additions and 1,310 deletions.
57 changes: 8 additions & 49 deletions paddle/fluid/operators/range_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ limitations under the License. */

#include "paddle/fluid/operators/range_op.h"
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/ternary.h"

namespace paddle {
namespace operators {
Expand All @@ -22,51 +26,6 @@ class RangeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
if (ctx->HasInput("Start")) {
auto s_dims = ctx->GetInputDim("Start");
PADDLE_ENFORCE_EQ(
s_dims.size(), 1,
platform::errors::InvalidArgument(
"The dim of the shape of Input(Start) should be 1, but got %d",
s_dims.size()));

PADDLE_ENFORCE_EQ(s_dims[0], 1,
platform::errors::InvalidArgument(
"The first dim of the shape of Input(Start) should "
"be 1, but got %d",
s_dims[0]));
}
if (ctx->HasInput("End")) {
auto e_dims = ctx->GetInputDim("End");
PADDLE_ENFORCE_EQ(
e_dims.size(), 1,
platform::errors::InvalidArgument(
"The dim of the shape of Input(End) should be 1, but got %d",
e_dims.size()));

PADDLE_ENFORCE_EQ(e_dims[0], 1, platform::errors::InvalidArgument(
"The first dim of the shape of "
"Input(End) should be 1, but got %d",
e_dims[0]));
}
if (ctx->HasInput("Step")) {
auto step_dims = ctx->GetInputDim("Step");
PADDLE_ENFORCE_EQ(
step_dims.size(), 1,
platform::errors::InvalidArgument(
"The dim of the shape of Input(Step) should be 1, but got %d",
step_dims.size()));

PADDLE_ENFORCE_EQ(step_dims[0], 1,
platform::errors::InvalidArgument(
"The first dim of the shape of Input(Step) should "
"be 1, but got %d",
step_dims[0]));
}
ctx->SetOutputDim("Out", {-1});
}

protected:
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
Expand Down Expand Up @@ -101,7 +60,7 @@ class RangeOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(range, ops::RangeOp, ops::RangeOpMaker);
REGISTER_OP_CPU_KERNEL(range, ops::CPURangeKernel<int>,
ops::CPURangeKernel<float>, ops::CPURangeKernel<double>,
ops::CPURangeKernel<int64_t>);
DECLARE_INFER_SHAPE_FUNCTOR(range, RangeInferMetaFunctor,
PD_INFER_META(phi::RangeInferMeta));
REGISTER_OP_WITHOUT_GRADIENT(range, ops::RangeOp, ops::RangeOpMaker,
RangeInferMetaFunctor);
61 changes: 0 additions & 61 deletions paddle/fluid/operators/range_op.cu

This file was deleted.

2 changes: 1 addition & 1 deletion paddle/fluid/operators/range_op_npu_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ limitations under the License. */
namespace f = paddle::framework;
namespace p = paddle::platform;

USE_OP(range);
USE_OP_ITSELF(range);
USE_OP_DEVICE_KERNEL(range, NPU);

template <typename T>
Expand Down
70 changes: 8 additions & 62 deletions paddle/fluid/operators/stack_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/stack_op.h"
#include <memory>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"

namespace plat = paddle::platform;
namespace ops = paddle::operators;
Expand All @@ -26,52 +29,6 @@ class StackOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 0,
platform::errors::InvalidArgument(
"Number of Inputs(X) must be larger than 0, but"
" received value is:%d.",
ctx->Inputs("X").size()));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Y"), true,
platform::errors::InvalidArgument(
"Output(Y) of stack_op should not be null."));

auto input_dims = ctx->GetInputsDim("X");
for (size_t i = 1; i < input_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0],
platform::errors::InvalidArgument(
"Dims of all Inputs(X) must be the same, but"
" received input %d dim is:%d not equal to input 0"
" dim:%d.",
i, input_dims[i], input_dims[0]));
}

// Only lod of X[0] would be shared with Y
ctx->ShareLoD("X", /*->*/ "Y");

int axis = ctx->Attrs().Get<int>("axis");
int rank = input_dims[0].size();
PADDLE_ENFORCE_GE(
axis, -(rank + 1),
platform::errors::InvalidArgument(
"Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d, "
"but received axis is:%d.",
rank, axis));

PADDLE_ENFORCE_LT(
axis, rank + 1,
platform::errors::InvalidArgument(
"Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d, "
"but received axis is:%d",
rank, axis));

if (axis < 0) axis += (rank + 1);

auto vec = phi::vectorize<int>(input_dims[0]);
vec.insert(vec.begin() + axis, input_dims.size());
ctx->SetOutputDim("Y", phi::make_ddim(vec));
}

framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
Expand Down Expand Up @@ -168,21 +125,10 @@ class StackGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace operators
} // namespace paddle

DECLARE_INFER_SHAPE_FUNCTOR(stack, StackInferMetaFunctor,
PD_INFER_META(phi::StackInferMeta));
REGISTER_OPERATOR(stack, ops::StackOp, ops::StackOpMaker,
ops::StackGradOpMaker<paddle::framework::OpDesc>,
ops::StackGradOpMaker<paddle::imperative::OpBase>);
ops::StackGradOpMaker<paddle::imperative::OpBase>,
StackInferMetaFunctor);
REGISTER_OPERATOR(stack_grad, ops::StackOpGrad);

REGISTER_OP_CPU_KERNEL(
stack, ops::StackKernel<plat::CPUDeviceContext, float>,
ops::StackKernel<plat::CPUDeviceContext, double>,
ops::StackKernel<plat::CPUDeviceContext, int>,
ops::StackKernel<plat::CPUDeviceContext, int64_t>,
ops::StackKernel<plat::CPUDeviceContext, paddle::platform::bfloat16>);

REGISTER_OP_CPU_KERNEL(
stack_grad, ops::StackGradKernel<plat::CPUDeviceContext, float>,
ops::StackGradKernel<plat::CPUDeviceContext, double>,
ops::StackGradKernel<plat::CPUDeviceContext, int>,
ops::StackGradKernel<plat::CPUDeviceContext, int64_t>,
ops::StackGradKernel<plat::CPUDeviceContext, paddle::platform::bfloat16>);
Loading

0 comments on commit 74894cd

Please sign in to comment.