Skip to content

Commit

Permalink
refine fused_elemwise_activation error message, test=develop
Browse files Browse the repository at this point in the history
  • Loading branch information
juncaipeng committed Sep 30, 2020
1 parent 96daa25 commit 1bbf446
Showing 1 changed file with 38 additions and 18 deletions.
56 changes: 38 additions & 18 deletions paddle/fluid/operators/fused/fused_elemwise_activation_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,8 @@ static void RunFunctors(const framework::ExecutionContext &ctx,
ctx, paddle::operators::math::MulFunctor<T>(),
paddle::operators::math::SigmoidFunctor<T>(), in_x, in_y, outputs);
} else {
PADDLE_THROW("%s has not been implemented.", funcs_str);
PADDLE_THROW(platform::errors::InvalidArgument(
"%s has not been implemented.", funcs_str));
}
}

Expand Down Expand Up @@ -374,7 +375,8 @@ static void RunGradFunctors(
paddle::operators::math::SigmoidGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else {
PADDLE_THROW("%s has not been implemented.", funcs_str);
PADDLE_THROW(platform::errors::InvalidArgument(
"%s has not been implemented.", funcs_str));
}
}

Expand All @@ -386,16 +388,21 @@ class FusedElemwiseActivationKernel : public framework::OpKernel<T> {
"X", "FusedElemwiseActivation");
auto &in_y = GET_DATA_SAFELY(ctx.Input<framework::Tensor>("Y"), "Input",
"Y", "FusedElemwiseActivation");
PADDLE_ENFORCE(ctx.HasOutput("Out"), "The output(Out) should not be empty");

PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"The output(Out) should not be empty"));
auto output = ctx.Output<framework::Tensor>("Out");

std::vector<framework::Tensor *> outputs;
outputs.emplace_back(output);

if (ctx.Attr<bool>("save_intermediate_out")) {
PADDLE_ENFORCE(ctx.HasOutput("IntermediateOut"),
"The save_intermediate_out is enable, so the "
"IntermediateOut should not be empty.");
PADDLE_ENFORCE_EQ(ctx.HasOutput("IntermediateOut"), true,
platform::errors::InvalidArgument(
"The save_intermediate_out is enable, so the "
"IntermediateOut should not be empty."));

auto intermediate_out = ctx.Output<framework::Tensor>("IntermediateOut");
outputs.emplace_back(intermediate_out);
} else {
Expand All @@ -411,13 +418,20 @@ class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto in_y = ctx.Input<framework::Tensor>("Y");
PADDLE_ENFORCE(in_y != nullptr, "Input(Y) should not be nullptr.");
PADDLE_ENFORCE_NE(in_y, nullptr,
platform::errors::InvalidArgument(
"Input(Y) should not be nullptr.");
auto in_out = ctx.Input<framework::Tensor>("Out");
PADDLE_ENFORCE(in_out != nullptr, "Input(Out) should not be nullptr.");
PADDLE_ENFORCE_NE(in_out, nullptr,
platform::errors::InvalidArgument(
"Input(Out) should not be nullptr.");
auto in_out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE(in_out_grad != nullptr,
"Input(Out@Grad) should not be nullptr.");
PADDLE_ENFORCE_NE(in_out_grad, nullptr,
platform::errors::InvalidArgument(
"Input(Out@Grad) should not be nullptr.");


framework::Tensor *in_x =
const_cast<framework::Tensor *>(ctx.Input<framework::Tensor>("X"));
framework::Tensor *x_grad =
Expand All @@ -437,24 +451,30 @@ class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
// recompute.
in_intermediate_out = const_cast<framework::Tensor *>(
ctx.Input<framework::Tensor>("IntermediateOut"));
PADDLE_ENFORCE(in_intermediate_out != nullptr,
"The option of 'save_intermediate_out' is opened, "
"so the number of 'Out' should be two.");
PADDLE_ENFORCE_NE(in_intermediate_out, nullptr,
platform::errors::InvalidArgument(
"The option of 'save_intermediate_out' is opened,"
" so the number of 'Out' should be two.");
} else {
if (!InputXCanBeAbsent(functor_list)) {
PADDLE_ENFORCE(in_x != nullptr, "Input(X) should not be null.");
PADDLE_ENFORCE_NE(in_x, nullptr,
platform::errors::InvalidArgument(
"Input(X) should not be null.");
}
}

// Get in_x
if (ctx.HasInput("X")) {
PADDLE_ENFORCE(in_x != nullptr, "Input(X) should not be nullptr.");
PADDLE_ENFORCE_NE(in_x, nullptr,
platform::errors::InvalidArgument(
"Input(X) should not be null.");
} else {
// If functor_list contains elementwise_add, the backward doesn't use
// in_x, in_y and in_out.
PADDLE_ENFORCE(InputXCanBeAbsent(functor_list),
"Only when the compoundfunctor contains "
"elementwise_add_grad, the 'X' could be absent.");
PADDLE_ENFORCE_EQ(InputXCanBeAbsent(functor_list), true,
platform::errors::InvalidArgument(
"Only when the compoundfunctor contains "
"elementwise_add_grad, the 'X' could be absent."));
in_x = const_cast<framework::Tensor *>(in_out_grad);
}

Expand Down

0 comments on commit 1bbf446

Please sign in to comment.