Fix the performance problem of enforce #6085

Merged 4 commits on Dec 1, 2017
4 changes: 2 additions & 2 deletions paddle/operators/concat_op.cc
@@ -25,7 +25,7 @@ class ConcatOp : public framework::OperatorWithKernel {

void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
"Inputs(X) of ConcatOp should be empty.")
"Inputs(X) of ConcatOp should be empty.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ConcatOp should not be null.");

@@ -45,7 +45,7 @@ class ConcatOp : public framework::OperatorWithKernel {
}
PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
"Input tensors should have the same "
"elements except the specify axis.")
"elements except the specify axis.");
}
}
ctx->SetOutputDim("Out", out_dims);
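
The semicolons added at these call sites, and at the similar ones in the files below, are needed because the reworked macros in paddle/platform/enforce.h (shown at the end of this diff) now expand to a do { ... } while (0) block; the old definitions ended with a ';' of their own, visible in the removed lines at the bottom of this diff. A minimal standalone sketch of that idiom, using a hypothetical CHECK_GE macro rather than Paddle's, shows why the wrapper forces the caller to write the ';' and stays safe under an unbraced if/else:

#include <cstdio>

// Hypothetical macro (not Paddle's): do/while(0) turns the body into a
// single statement, so the call site must end with ';' and the macro
// nests correctly under an unbraced if/else.
#define CHECK_GE(a, b)                                          \
  do {                                                          \
    if (!((a) >= (b))) {                                        \
      std::fprintf(stderr, "check failed: %s >= %s\n", #a, #b); \
    }                                                           \
  } while (0)

int main() {
  int rank = 2;
  if (rank > 0)
    CHECK_GE(rank, 1);  // this ';' completes the do/while statement
  else
    std::puts("rank must be positive");
  return 0;
}
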
4 changes: 2 additions & 2 deletions paddle/operators/elementwise_op.h
@@ -35,7 +35,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputDim("Y");
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
"Rank of first input must >= rank of second input.")
"Rank of first input must >= rank of second input.");
ctx->SetOutputDim("Out", x_dim);
ctx->ShareLoD("X", /*->*/ "Out");
}
@@ -120,7 +120,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));

PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.")
"Rank of first input must >= rank of second input.");

auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
2 changes: 1 addition & 1 deletion paddle/operators/elementwise_op_function.h
@@ -106,7 +106,7 @@ void ElementwiseCompute(const framework::ExecutionContext& ctx) {
auto x_dims = x->dims();
auto y_dims = y->dims();
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.")
"Rank of first input must >= rank of second input.");

if (x_dims == y_dims) {
functor f;
10 changes: 5 additions & 5 deletions paddle/operators/sequence_slice_op.h
@@ -54,10 +54,10 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
PADDLE_ENFORCE_EQ(
n, static_cast<size_t>(length->dims()[0]),
"The size of input-sequence and length-array should be the same")
"The size of input-sequence and length-array should be the same");
PADDLE_ENFORCE_EQ(
n, static_cast<size_t>(offset->dims()[0]),
"The size of input-sequence and offset-array should be the same")
"The size of input-sequence and offset-array should be the same");

const int64_t* offset_data = offset->data<int64_t>();
const int64_t* length_data = length->data<int64_t>();
@@ -78,11 +78,11 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {

for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_LT(0, offset_data[i],
"The offset[%d] must greater than zero.", i)
"The offset[%d] must greater than zero.", i);
PADDLE_ENFORCE_LT(0, length_data[i],
"The length[%d] must greater than zero.", i)
"The length[%d] must greater than zero.", i);
PADDLE_ENFORCE_LT(lod[0][i] + offset_data[i] + length_data[i],
lod[0][i + 1], "The target tensor's length overflow.")
lod[0][i + 1], "The target tensor's length overflow.");
}

out->mutable_data<T>(ctx.GetPlace());
2 changes: 1 addition & 1 deletion paddle/operators/sum_op.h
@@ -84,7 +84,7 @@ class SumKernel : public framework::OpKernel<T> {
int64_t offset = 0;
for (int i = 0; i < N; i++) {
PADDLE_ENFORCE_EQ(out->height(),
- in_vars[i]->Get<SelectedRows>().height())
+ in_vars[i]->Get<SelectedRows>().height());
functor(context.device_context(), in_vars[i]->Get<SelectedRows>(),
offset, out);
offset += in_vars[i]->Get<SelectedRows>().value().numel();
28 changes: 18 additions & 10 deletions paddle/platform/enforce.h
@@ -234,16 +234,24 @@ inline void throw_on_error(T e) {
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <, >=, __VA_ARGS__)
#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__)
-#define PADDLE_ENFORCE_NOT_NULL(__VAL, ...) \
-  PADDLE_ENFORCE(nullptr != (__VAL), #__VAL " should not be null\n%s", \
-                 paddle::string::Sprintf("" __VA_ARGS__));
-
-#define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \
-  PADDLE_ENFORCE(__VAL0 __CMP __VAL1, \
-                 "enforce %s " #__CMP " %s failed, %s " #__INV_CMP " %s\n%s", \
-                 #__VAL0, #__VAL1, paddle::string::to_string(__VAL0), \
-                 paddle::string::to_string(__VAL1), \
-                 paddle::string::Sprintf("" __VA_ARGS__));
+#define PADDLE_ENFORCE_NOT_NULL(__VAL, ...) \
+  do { \
+    if (UNLIKELY(nullptr == (__VAL))) { \
+      PADDLE_THROW(#__VAL " should not be null\n%s", \
+                   paddle::string::Sprintf("" __VA_ARGS__)); \
+    } \
+  } while (0)
+
+#define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \
+  do { \
+    if (!UNLIKELY((__VAL0)__CMP(__VAL1))) { \
+      PADDLE_THROW("enforce %s " #__CMP " %s failed, %s " #__INV_CMP \
+                   " %s\n%s", \
+                   #__VAL0, #__VAL1, paddle::string::to_string(__VAL0), \
+                   paddle::string::to_string(__VAL1), \
+                   paddle::string::Sprintf("" __VA_ARGS__)); \
+    } \
+  } while (0)

} // namespace platform
} // namespace paddle
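
The speedup comes from the enforce.h change above: the old macros expanded to a plain PADDLE_ENFORCE call whose arguments, paddle::string::to_string(...) on the operands and paddle::string::Sprintf("" __VA_ARGS__), were evaluated on every check, even a passing one. The rewritten macros evaluate only the comparison on the hot path and build the message inside the UNLIKELY branch, that is, only when the check fails. A minimal standalone sketch of the same pattern, using illustrative names (ENFORCE_EQ, FormatMessage) rather than Paddle's API:

#include <cstdio>
#include <sstream>
#include <stdexcept>
#include <string>

#if defined(__GNUC__) || defined(__clang__)
#define UNLIKELY(x) __builtin_expect(!!(x), 0)
#else
#define UNLIKELY(x) (x)
#endif

// Runs only on the failure path, so a passing check costs just the compare.
template <typename A, typename B>
std::string FormatMessage(const char* a_str, const char* b_str, const A& a,
                          const B& b) {
  std::ostringstream os;
  os << "enforce " << a_str << " == " << b_str << " failed, " << a
     << " != " << b;
  return os.str();
}

// Hypothetical macro mirroring the structure of the new __PADDLE_BINARY_COMPARE:
// check first, format the message lazily inside the unlikely branch.
#define ENFORCE_EQ(a, b)                                         \
  do {                                                           \
    if (UNLIKELY(!((a) == (b)))) {                               \
      throw std::runtime_error(FormatMessage(#a, #b, (a), (b))); \
    }                                                            \
  } while (0)

int main() {
  int height = 4;
  ENFORCE_EQ(height, 4);    // passes: no message string is constructed
  try {
    ENFORCE_EQ(height, 5);  // fails: the message is formatted only here
  } catch (const std::exception& e) {
    std::puts(e.what());
  }
  return 0;
}

The sketch only keeps the shape of the fast path; Paddle's real macros also stringify both operands, reuse the inverse-comparison text in the error message, and accept an optional printf-style user message.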