From cef93e81ff2883e15cd5e04ab395701c90136ff7 Mon Sep 17 00:00:00 2001 From: Sam Skalicky Date: Sun, 18 Aug 2019 07:41:27 +0000 Subject: [PATCH] added num in/out --- include/mxnet/lib_api.h | 16 +++++++++++++ src/c_api/c_api.cc | 53 ++++++++++++++++++++++++++++++++++------- 2 files changed, 61 insertions(+), 8 deletions(-) diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h index 4ea284e39754..47ed086333fd 100644 --- a/include/mxnet/lib_api.h +++ b/include/mxnet/lib_api.h @@ -190,6 +190,10 @@ typedef int (*opRegGet_t)(int, const char**, fcomp_t*, parseAttrs_t*, inferType_t*, inferShape_t*); +#define MXLIB_OPCALLPARSEATTRS_STR "_opCallParseAttrs" +typedef int (*opCallParseAttrs_t)(parseAttrs_t, const char* const*, const char* const*, int, + int*, int*); + #define MXLIB_OPCALLFCOMP_STR "_opCallFCompute" typedef int (*opCallFComp_t)(fcomp_t, const char* const*, const char* const*, int, const int64_t**, int*, void**, int*, int, @@ -220,6 +224,18 @@ extern "C" { *shape = op.infer_shape; } + int _opCallParseAttrs(parseAttrs_t parseAttrs, const char* const* keys, const char* const* vals, int num, + int* num_in, int* num_out) { + //create map of attributes from list + std::map attrs; + for(int i=0; i(MXNET_VERSION))) LOG(FATAL) << "Library failed to initialize"; - //get function to call fcompute + //get call functions + opCallParseAttrs_t callParseAttrs = get_func(lib, const_cast(MXLIB_OPCALLPARSEATTRS_STR)); opCallFComp_t callFComp = get_func(lib, const_cast(MXLIB_OPCALLFCOMP_STR)); //get number of operators registered in the library @@ -132,14 +133,26 @@ int MXLoadLib(const char *path) { CHECK(shape != nullptr) << "Error loading '" << name << "' custom op, InferShape function was not set."; LOG(INFO) << "\tOp[" << i << "] " << name; - std::string name_str(name); - //generate lambda functions to convert from MXNet types to external types - auto fcomp_conv = [=](const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { + + auto num_inputs = [=](const NodeAttrs& attrs) { + //convert attributes to vector of char + std::vector attr_keys, attr_vals; + for(auto kv : attrs.dict) { + attr_keys.push_back(kv.first.c_str()); + attr_vals.push_back(kv.second.c_str()); + } + + int num_in=-1; + int num_out=-1; + CHECK(callParseAttrs(parse, attr_keys.data(), attr_vals.data(), attr_keys.size(), + &num_in, &num_out)) + << "Error calling ParseAttrs for custom operator '" << name_str << "'"; + + return num_in; + }; + + auto num_outputs = [=](const NodeAttrs& attrs) { //convert attributes to vector of char* std::vector attr_keys,attr_vals; for(auto kv : attrs.dict) { @@ -147,6 +160,28 @@ int MXLoadLib(const char *path) { attr_vals.push_back(kv.second.c_str()); } + int num_in=-1; + int num_out=-1; + CHECK(callParseAttrs(parse, attr_keys.data(), attr_vals.data(), attr_keys.size(), + &num_in, &num_out)) + << "Error calling ParseAttrs for custom operator '" << name_str << "'"; + + return num_out; + }; + + // lambda function to convert from external fcompute to internal MXNet types + auto fcomp_conv = [=](const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + //convert attributes to vector of char* + std::vector attr_keys, attr_vals; + for(auto kv : attrs.dict) { + attr_keys.push_back(kv.first.c_str()); + attr_vals.push_back(kv.second.c_str()); + } + std::vector in_data, out_data; std::vector in_shapes, out_shapes; std::vector in_dims, out_dims; @@ -180,6 +215,8 @@ int MXLoadLib(const char *path) { contrib_name += name; nnvm::Op ®Op = dmlc::Registry::Get()->__REGISTER_OR_GET__(contrib_name.c_str()); regOp.set_attr("FCompute",fcomp_conv); + regOp.set_num_inputs(num_inputs); + regOp.set_num_outputs(num_outputs); } API_END();