up, test=develop
juncaipeng committed Sep 30, 2020
1 parent 1bbf446 commit d4b597d
Showing 1 changed file with 14 additions and 18 deletions.
paddle/fluid/operators/fused/fused_elemwise_activation_op.h (32 changes: 14 additions & 18 deletions)
@@ -389,7 +389,7 @@ class FusedElemwiseActivationKernel : public framework::OpKernel<T> {
     auto &in_y = GET_DATA_SAFELY(ctx.Input<framework::Tensor>("Y"), "Input",
                                  "Y", "FusedElemwiseActivation");
 
-    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+    PADDLE_ENFORCE_EQ(ctx.HasOutput("Out"), true,
                       platform::errors::InvalidArgument(
                           "The output(Out) should not be empty"));
     auto output = ctx.Output<framework::Tensor>("Out");
@@ -418,19 +418,17 @@ class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
     auto in_y = ctx.Input<framework::Tensor>("Y");
-    PADDLE_ENFORCE_NE(in_y, nullptr,
-                      platform::errors::InvalidArgument(
-                          "Input(Y) should not be nullptr.");
+    PADDLE_ENFORCE_NE(in_y, nullptr, platform::errors::InvalidArgument(
+                                         "Input(Y) should not be nullptr."));
     auto in_out = ctx.Input<framework::Tensor>("Out");
-    PADDLE_ENFORCE_NE(in_out, nullptr,
-                      platform::errors::InvalidArgument(
-                          "Input(Out) should not be nullptr.");
+    PADDLE_ENFORCE_NE(
+        in_out, nullptr,
+        platform::errors::InvalidArgument("Input(Out) should not be nullptr."));
     auto in_out_grad =
         ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
     PADDLE_ENFORCE_NE(in_out_grad, nullptr,
                       platform::errors::InvalidArgument(
-                          "Input(Out@Grad) should not be nullptr.");
-
+                          "Input(Out@Grad) should not be nullptr."));
 
     framework::Tensor *in_x =
         const_cast<framework::Tensor *>(ctx.Input<framework::Tensor>("X"));
@@ -452,22 +450,20 @@ class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
       in_intermediate_out = const_cast<framework::Tensor *>(
           ctx.Input<framework::Tensor>("IntermediateOut"));
       PADDLE_ENFORCE_NE(in_intermediate_out, nullptr,
-                        platform::errors::InvalidArgument(
-                            "The option of 'save_intermediate_out' is opened,"
-                            " so the number of 'Out' should be two.");
+                        platform::errors::InvalidArgument(
+                            "The option of 'save_intermediate_out' is opened,"
+                            " so the number of 'Out' should be two."));
     } else {
       if (!InputXCanBeAbsent(functor_list)) {
-        PADDLE_ENFORCE_NE(in_x, nullptr,
-                          platform::errors::InvalidArgument(
-                              "Input(X) should not be null.");
+        PADDLE_ENFORCE_NE(in_x, nullptr, platform::errors::InvalidArgument(
+                                             "Input(X) should not be null."));
       }
     }
 
     // Get in_x
     if (ctx.HasInput("X")) {
-      PADDLE_ENFORCE_NE(in_x, nullptr,
-                        platform::errors::InvalidArgument(
-                            "Input(X) should not be null.");
+      PADDLE_ENFORCE_NE(in_x, nullptr, platform::errors::InvalidArgument(
+                                           "Input(X) should not be null."));
     } else {
       // If functor_list contains elementwise_add, the backward doesn't use
       // in_x, in_y and in_out.
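
Note on the change: the checks touched above follow Paddle's PADDLE_ENFORCE_* pattern, where the macro receives the values to compare plus an error object built by platform::errors::InvalidArgument, so each call needs one closing parenthesis for the error constructor and a second one for the macro itself. The removed lines end in ");" and lack that outer parenthesis, which the added lines restore. The minimal, self-contained C++ sketch below illustrates the pattern only; DEMO_ENFORCE_NE and demo::InvalidArgument are hypothetical stand-ins, not Paddle's actual implementation.

// Standalone sketch of the enforce-style check repaired above. DEMO_ENFORCE_NE
// and demo::InvalidArgument are illustrative stand-ins for PADDLE_ENFORCE_NE
// and platform::errors::InvalidArgument, not Paddle's real implementation.
#include <cstdio>
#include <stdexcept>
#include <string>

namespace demo {
// Builds the exception object describing the invalid argument.
inline std::runtime_error InvalidArgument(const std::string &msg) {
  return std::runtime_error("InvalidArgument: " + msg);
}
}  // namespace demo

// Throws the supplied error object when the two values compare equal.
#define DEMO_ENFORCE_NE(a, b, error) \
  do {                               \
    if ((a) == (b)) throw(error);    \
  } while (0)

int main() {
  const float *in_y = nullptr;  // stands in for a missing input tensor
  try {
    // Two closing parentheses are needed: one for InvalidArgument(...) and
    // one for the macro call itself; the outer ')' is what the diff above
    // restores in the Paddle sources.
    DEMO_ENFORCE_NE(in_y, nullptr,
                    demo::InvalidArgument("Input(Y) should not be nullptr."));
  } catch (const std::runtime_error &e) {
    std::printf("%s\n", e.what());  // InvalidArgument: Input(Y) should not be nullptr.
    return 1;
  }
  return 0;
}

With the outer ')' dropped, as in the deleted lines of the diff, the macro invocation is left unbalanced and the file does not compile, which is what the restored parentheses fix.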
