From 68e86f2128cfb6dfe23201f2804b7f336e68c295 Mon Sep 17 00:00:00 2001 From: reminisce Date: Mon, 6 May 2019 16:56:36 -0700 Subject: [PATCH] Enable np op compat check with name prefix (#14897) --- src/c_api/c_api_common.h | 17 ++++++++++++++++- .../numpy/np_broadcast_reduce_op_value.cc | 3 +-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h index 118341d4ef1f..ab1f5f71da99 100644 --- a/src/c_api/c_api_common.h +++ b/src/c_api/c_api_common.h @@ -163,10 +163,25 @@ inline void CopyAttr(const nnvm::IndexedGraph& idx, extern const std::vector kHiddenKeys; } // namespace mxnet +/*! + * An operator is considered as numpy compatible if it satisfies either one + * of the following conditions. + * 1. The op has the attribute mxnet::TIsNumpyCompatible> registered as True. + * 2. The op's name starts with the prefix _numpy_. + * The first condition is usually for the ops registered as internal ops, such + * as _np_add, _true_divide, etc. They are wrapped by some user-facing op + * APIs in the Python end. + * The second condition is for the ops registered in the backend while exposed + * directly to users as is, such as _numpy_sum etc. + */ inline bool IsNumpyCompatOp(const nnvm::Op* op) { static const auto& is_np_compat = nnvm::Op::GetAttr("TIsNumpyCompatible"); - return is_np_compat.get(op, false); + if (is_np_compat.get(op, false)) { + return true; + } + static const std::string prefix = "_numpy_"; + return op->name.find(prefix.c_str(), 0, prefix.size()) != std::string::npos; } #endif // MXNET_C_API_C_API_COMMON_H_ diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc index 13b575a6674a..6c81bf6e5de8 100644 --- a/src/operator/numpy/np_broadcast_reduce_op_value.cc +++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc @@ -65,8 +65,7 @@ NNVM_REGISTER_OP(_numpy_sum) [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) -.set_attr("FGradient", ElemwiseGradUseNone{"_backward_numpy_sum"}) -.set_attr("TIsNumpyCompatible", true); +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_numpy_sum"}); NNVM_REGISTER_OP(_backward_numpy_sum) .set_num_outputs(1)