refine fused_elemwise_activation error message #27734

Merged
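The pattern this PR applies: bare PADDLE_ENFORCE(cond, msg) and PADDLE_THROW(msg, ...) calls are replaced with the explicit comparison macros PADDLE_ENFORCE_EQ / PADDLE_ENFORCE_NE plus a typed error object, platform::errors::InvalidArgument, so a failed check reports an error class as well as a message. The sketch below is a minimal, self-contained C++ analogue of that idea; EnforceEq and InvalidArgumentError are hypothetical stand-ins for Paddle's macros and error types, not part of Paddle's API.

// Minimal, self-contained sketch of the pattern this PR applies.
// The real code uses Paddle's PADDLE_ENFORCE_EQ / PADDLE_ENFORCE_NE macros
// and platform::errors::InvalidArgument; EnforceEq and InvalidArgumentError
// below are hypothetical stand-ins.
#include <sstream>
#include <stdexcept>
#include <string>

// Stand-in for platform::errors::InvalidArgument: a distinct error type,
// so callers can match on the error class instead of parsing a message.
struct InvalidArgumentError : std::invalid_argument {
  using std::invalid_argument::invalid_argument;
};

// Stand-in for PADDLE_ENFORCE_EQ: check an explicit comparison and throw
// a typed error carrying a descriptive message.
template <typename A, typename B>
void EnforceEq(const A &a, const B &b, const std::string &msg) {
  if (!(a == b)) {
    std::ostringstream os;
    os << "InvalidArgument: " << msg;
    throw InvalidArgumentError(os.str());
  }
}

int main() {
  bool has_out = false;  // pretend ctx.HasOutput("Out") returned false
  try {
    EnforceEq(has_out, true, "The output(Out) should not be empty.");
  } catch (const InvalidArgumentError &) {
    // Typed catch: the error category is part of the contract,
    // not just the text of the message.
  }
  return 0;
}

The same motivation applies to the PADDLE_THROW sites in the diff: wrapping the message in platform::errors::InvalidArgument attaches an error class to the failure instead of a bare formatted string.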
52 changes: 34 additions & 18 deletions paddle/fluid/operators/fused/fused_elemwise_activation_op.h
@@ -276,7 +276,8 @@ static void RunFunctors(const framework::ExecutionContext &ctx,
         ctx, paddle::operators::math::MulFunctor<T>(),
         paddle::operators::math::SigmoidFunctor<T>(), in_x, in_y, outputs);
   } else {
-    PADDLE_THROW("%s has not been implemented.", funcs_str);
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "%s has not been implemented.", funcs_str));
   }
 }

@@ -374,7 +375,8 @@ static void RunGradFunctors(
         paddle::operators::math::SigmoidGradFunctor<T>(), in_x, in_y, in_out,
         in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
   } else {
-    PADDLE_THROW("%s has not been implemented.", funcs_str);
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "%s has not been implemented.", funcs_str));
   }
 }

@@ -386,16 +388,21 @@ class FusedElemwiseActivationKernel : public framework::OpKernel<T> {
"X", "FusedElemwiseActivation");
auto &in_y = GET_DATA_SAFELY(ctx.Input<framework::Tensor>("Y"), "Input",
"Y", "FusedElemwiseActivation");
PADDLE_ENFORCE(ctx.HasOutput("Out"), "The output(Out) should not be empty");

PADDLE_ENFORCE_EQ(ctx.HasOutput("Out"), true,
platform::errors::InvalidArgument(
"The output(Out) should not be empty"));
auto output = ctx.Output<framework::Tensor>("Out");

std::vector<framework::Tensor *> outputs;
outputs.emplace_back(output);

if (ctx.Attr<bool>("save_intermediate_out")) {
PADDLE_ENFORCE(ctx.HasOutput("IntermediateOut"),
"The save_intermediate_out is enable, so the "
"IntermediateOut should not be empty.");
PADDLE_ENFORCE_EQ(ctx.HasOutput("IntermediateOut"), true,
platform::errors::InvalidArgument(
"The save_intermediate_out is enable, so the "
"IntermediateOut should not be empty."));

auto intermediate_out = ctx.Output<framework::Tensor>("IntermediateOut");
outputs.emplace_back(intermediate_out);
} else {
@@ -411,13 +418,18 @@ class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
     auto in_y = ctx.Input<framework::Tensor>("Y");
-    PADDLE_ENFORCE(in_y != nullptr, "Input(Y) should not be nullptr.");
+    PADDLE_ENFORCE_NE(in_y, nullptr, platform::errors::InvalidArgument(
+                                         "Input(Y) should not be nullptr."));
     auto in_out = ctx.Input<framework::Tensor>("Out");
-    PADDLE_ENFORCE(in_out != nullptr, "Input(Out) should not be nullptr.");
+    PADDLE_ENFORCE_NE(
+        in_out, nullptr,
+        platform::errors::InvalidArgument("Input(Out) should not be nullptr."));
     auto in_out_grad =
         ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-    PADDLE_ENFORCE(in_out_grad != nullptr,
-                   "Input(Out@Grad) should not be nullptr.");
+    PADDLE_ENFORCE_NE(in_out_grad, nullptr,
+                      platform::errors::InvalidArgument(
+                          "Input(Out@Grad) should not be nullptr."));
+
     framework::Tensor *in_x =
         const_cast<framework::Tensor *>(ctx.Input<framework::Tensor>("X"));
     framework::Tensor *x_grad =
@@ -437,24 +449,28 @@ class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
       // recompute.
       in_intermediate_out = const_cast<framework::Tensor *>(
           ctx.Input<framework::Tensor>("IntermediateOut"));
-      PADDLE_ENFORCE(in_intermediate_out != nullptr,
-                     "The option of 'save_intermediate_out' is opened, "
-                     "so the number of 'Out' should be two.");
+      PADDLE_ENFORCE_NE(in_intermediate_out, nullptr,
+                        platform::errors::InvalidArgument(
+                            "The option 'save_intermediate_out' is enabled, "
+                            "so the number of 'Out' should be two."));
     } else {
       if (!InputXCanBeAbsent(functor_list)) {
-        PADDLE_ENFORCE(in_x != nullptr, "Input(X) should not be null.");
+        PADDLE_ENFORCE_NE(in_x, nullptr,
+                          platform::errors::InvalidArgument(
+                              "Input(X) should not be nullptr."));
       }
     }

     // Get in_x
     if (ctx.HasInput("X")) {
-      PADDLE_ENFORCE(in_x != nullptr, "Input(X) should not be nullptr.");
+      PADDLE_ENFORCE_NE(in_x, nullptr,
+                        platform::errors::InvalidArgument(
+                            "Input(X) should not be nullptr."));
     } else {
       // If functor_list contains elementwise_add, the backward doesn't use
       // in_x, in_y and in_out.
-      PADDLE_ENFORCE(InputXCanBeAbsent(functor_list),
-                     "Only when the compoundfunctor contains "
-                     "elementwise_add_grad, the 'X' could be absent.");
+      PADDLE_ENFORCE_EQ(InputXCanBeAbsent(functor_list), true,
+                        platform::errors::InvalidArgument(
+                            "Only when the compound functor contains "
+                            "elementwise_add_grad can the input 'X' be absent."));
       in_x = const_cast<framework::Tensor *>(in_out_grad);
     }
