Skip to content

Commit

Permalink
Enable np op compat check with name prefix (apache#14897)
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce authored and haojin2 committed Jul 26, 2019
1 parent fbf0151 commit 3adf63e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
17 changes: 16 additions & 1 deletion src/c_api/c_api_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,25 @@ inline void CopyAttr(const nnvm::IndexedGraph& idx,
extern const std::vector<std::string> kHiddenKeys;
} // namespace mxnet

/*!
* An operator is considered as numpy compatible if it satisfies either one
* of the following conditions.
* 1. The op has the attribute mxnet::TIsNumpyCompatible> registered as True.
* 2. The op's name starts with the prefix _numpy_.
* The first condition is usually for the ops registered as internal ops, such
* as _np_add, _true_divide, etc. They are wrapped by some user-facing op
* APIs in the Python end.
* The second condition is for the ops registered in the backend while exposed
* directly to users as is, such as _numpy_sum etc.
*/
inline bool IsNumpyCompatOp(const nnvm::Op* op) {
static const auto& is_np_compat =
nnvm::Op::GetAttr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible");
return is_np_compat.get(op, false);
if (is_np_compat.get(op, false)) {
return true;
}
static const std::string prefix = "_numpy_";
return op->name.find(prefix.c_str(), 0, prefix.size()) != std::string::npos;
}

#endif // MXNET_C_API_C_API_COMMON_H_
3 changes: 1 addition & 2 deletions src/operator/numpy/np_broadcast_reduce_op_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ NNVM_REGISTER_OP(_numpy_sum)
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_numpy_sum"})
.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_numpy_sum"});

NNVM_REGISTER_OP(_backward_numpy_sum)
.set_num_outputs(1)
Expand Down

0 comments on commit 3adf63e

Please sign in to comment.