Skip to content

Commit

Permalink
fix error mesage for negative_positive_pair_op and nce_op
Browse files Browse the repository at this point in the history
  • Loading branch information
Feiyu committed Oct 9, 2020
1 parent c826bcb commit 27b28f6
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 41 deletions.
77 changes: 43 additions & 34 deletions paddle/fluid/operators/nce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,25 +104,29 @@ class NCEKernel : public framework::OpKernel<T> {

PADDLE_ENFORCE_EQ(
dist_probs->numel(), num_total_classes,
"ShapeError: The number of elements in Input(CustomDistProbs) "
"should be equal to the number of total classes. But Received: "
"Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) "
"= %d.",
dist_probs->numel(), num_total_classes);
platform::errors::InvalidArgument(
"ShapeError: The number of elements in Input(CustomDistProbs) "
"should be equal to the number of total classes. But Received: "
"Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) "
"= %d.",
dist_probs->numel(), num_total_classes));
PADDLE_ENFORCE_EQ(
dist_alias->numel(), num_total_classes,
"ShapeError: The number of elements in Input(CustomDistAlias) "
"should be equal to the number of total classes. But Received: "
"Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) "
"= %d.",
dist_alias->numel(), num_total_classes);
platform::errors::InvalidArgument(
"ShapeError: The number of elements in Input(CustomDistAlias) "
"should be equal to the number of total classes. But Received: "
"Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) "
"= %d.",
dist_alias->numel(), num_total_classes));
PADDLE_ENFORCE_EQ(
dist_alias_probs->numel(), num_total_classes,
"ShapeError: The number of elements in Input(CustomDistAliasProbs) "
"should be equal to the number of total classes. But Received: "
"Input(CustomDistAliasProbs).numel() = %d, "
"Attr(num_total_classes) = %d.",
dist_alias_probs->numel(), num_total_classes);
platform::errors::InvalidArgument(
"ShapeError: The number of elements in "
"Input(CustomDistAliasProbs) "
"should be equal to the number of total classes. But Received: "
"Input(CustomDistAliasProbs).numel() = %d, "
"Attr(num_total_classes) = %d.",
dist_alias_probs->numel(), num_total_classes));

const float *probs_data = dist_probs->data<float>();
const int *alias_data = dist_alias->data<int>();
Expand All @@ -140,10 +144,11 @@ class NCEKernel : public framework::OpKernel<T> {

for (int x = 0; x < sample_labels->numel(); x++) {
PADDLE_ENFORCE_GE(sample_labels_data[x], 0,
"ValueError: Every sample label should be "
"non-negative. But received: "
"Input(SampleLabels)[%d] = %d",
x, sample_labels_data[x]);
platform::errors::InvalidArgument(
"ValueError: Every sample label should be "
"non-negative. But received: "
"Input(SampleLabels)[%d] = %d",
x, sample_labels_data[x]));
}

auto sample_out = context.Output<Tensor>("SampleLogits");
Expand Down Expand Up @@ -311,25 +316,29 @@ class NCEGradKernel : public framework::OpKernel<T> {

PADDLE_ENFORCE_EQ(
dist_probs->numel(), num_total_classes,
"ShapeError: The number of elements in Input(CustomDistProbs) "
"should be equal to the number of total classes. But Received: "
"Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) "
"= %d.",
dist_probs->numel(), num_total_classes);
platform::errors::InvalidArgument(
"ShapeError: The number of elements in Input(CustomDistProbs) "
"should be equal to the number of total classes. But Received: "
"Input(CustomDistProbs).numel() = %d, Attr(num_total_classes) "
"= %d.",
dist_probs->numel(), num_total_classes));
PADDLE_ENFORCE_EQ(
dist_alias->numel(), num_total_classes,
"ShapeError: The number of elements in Input(CustomDistAlias) "
"should be equal to the number of total classes. But Received: "
"Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) "
"= %d.",
dist_alias->numel(), num_total_classes);
platform::errors::InvalidArgument(
"ShapeError: The number of elements in Input(CustomDistAlias) "
"should be equal to the number of total classes. But Received: "
"Input(CustomDistAlias).numel() = %d, Attr(num_total_classes) "
"= %d.",
dist_alias->numel(), num_total_classes));
PADDLE_ENFORCE_EQ(
dist_alias_probs->numel(), num_total_classes,
"ShapeError: The number of elements in Input(CustomDistAliasProbs) "
"should be equal to the number of total classes. But Received: "
"Input(CustomDistAliasProbs).numel() = %d, "
"Attr(num_total_classes) = %d.",
dist_alias_probs->numel(), num_total_classes);
platform::errors::InvalidArgument(
"ShapeError: The number of elements in "
"Input(CustomDistAliasProbs) "
"should be equal to the number of total classes. But Received: "
"Input(CustomDistAliasProbs).numel() = %d, "
"Attr(num_total_classes) = %d.",
dist_alias_probs->numel(), num_total_classes));

const float *probs_data = dist_probs->data<float>();
const int *alias_data = dist_alias->data<int>();
Expand Down
16 changes: 9 additions & 7 deletions paddle/fluid/operators/positive_negative_pair_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,15 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
if (ctx->HasInput("AccumulatePositivePair") ||
ctx->HasInput("AccumulateNegativePair") ||
ctx->HasInput("AccumulateNeutralPair")) {
PADDLE_ENFORCE(ctx->HasInput("AccumulatePositivePair") &&
ctx->HasInput("AccumulateNegativePair") &&
ctx->HasInput("AccumulateNeutralPair"),
"All optional inputs(AccumulatePositivePair, "
"AccumulateNegativePair, AccumulateNeutralPair) of "
"PositiveNegativePairOp are required if one of them is "
"specified.");
PADDLE_ENFORCE_EQ(
ctx->HasInput("AccumulatePositivePair") &&
ctx->HasInput("AccumulateNegativePair") &&
ctx->HasInput("AccumulateNeutralPair"),
true, platform::errors::InvalidArgument(
"All optional inputs(AccumulatePositivePair, "
"AccumulateNegativePair, AccumulateNeutralPair) of "
"PositiveNegativePairOp are required if one of them "
"is specified."));
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("AccumulatePositivePair"), scalar_dim,
platform::errors::InvalidArgument(
Expand Down

1 comment on commit 27b28f6

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.