Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
added num in/out
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam Skalicky committed Aug 18, 2019
1 parent e680d1c commit cef93e8
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 8 deletions.
16 changes: 16 additions & 0 deletions include/mxnet/lib_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ typedef int (*opRegGet_t)(int, const char**, fcomp_t*,
parseAttrs_t*, inferType_t*,
inferShape_t*);

#define MXLIB_OPCALLPARSEATTRS_STR "_opCallParseAttrs"
typedef int (*opCallParseAttrs_t)(parseAttrs_t, const char* const*, const char* const*, int,
int*, int*);

#define MXLIB_OPCALLFCOMP_STR "_opCallFCompute"
typedef int (*opCallFComp_t)(fcomp_t, const char* const*, const char* const*, int,
const int64_t**, int*, void**, int*, int,
Expand Down Expand Up @@ -220,6 +224,18 @@ extern "C" {
*shape = op.infer_shape;
}

int _opCallParseAttrs(parseAttrs_t parseAttrs, const char* const* keys, const char* const* vals, int num,
int* num_in, int* num_out) {
//create map of attributes from list
std::map<std::string,std::string> attrs;
for(int i=0; i<num; i++) {
attrs[std::string(keys[i])] = std::string(vals[i]);
}

return parseAttrs(attrs,num_in,num_out);
}


int _opCallFCompute(fcomp_t fcomp, const char* const* keys, const char* const* vals, int num,
const int64_t** inshapes, int* indims, void** indata, int* intypes, int num_in,
const int64_t** outshapes, int* outdims, void** outdata, int* outtypes, int num_out) {
Expand Down
53 changes: 45 additions & 8 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ int MXLoadLib(const char *path) {
if (!initialize(static_cast<int>(MXNET_VERSION)))
LOG(FATAL) << "Library failed to initialize";

//get function to call fcompute
//get call functions
opCallParseAttrs_t callParseAttrs = get_func<opCallParseAttrs_t>(lib, const_cast<char*>(MXLIB_OPCALLPARSEATTRS_STR));
opCallFComp_t callFComp = get_func<opCallFComp_t>(lib, const_cast<char*>(MXLIB_OPCALLFCOMP_STR));

//get number of operators registered in the library
Expand All @@ -132,21 +133,55 @@ int MXLoadLib(const char *path) {
CHECK(shape != nullptr) << "Error loading '" << name << "' custom op, InferShape function was not set.";

LOG(INFO) << "\tOp[" << i << "] " << name;

std::string name_str(name);
//generate lambda functions to convert from MXNet types to external types
auto fcomp_conv = [=](const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {

auto num_inputs = [=](const NodeAttrs& attrs) {
//convert attributes to vector of char
std::vector<const char*> attr_keys, attr_vals;
for(auto kv : attrs.dict) {
attr_keys.push_back(kv.first.c_str());
attr_vals.push_back(kv.second.c_str());
}

int num_in=-1;
int num_out=-1;
CHECK(callParseAttrs(parse, attr_keys.data(), attr_vals.data(), attr_keys.size(),
&num_in, &num_out))
<< "Error calling ParseAttrs for custom operator '" << name_str << "'";

return num_in;
};

auto num_outputs = [=](const NodeAttrs& attrs) {
//convert attributes to vector of char*
std::vector<const char*> attr_keys,attr_vals;
for(auto kv : attrs.dict) {
attr_keys.push_back(kv.first.c_str());
attr_vals.push_back(kv.second.c_str());
}

int num_in=-1;
int num_out=-1;
CHECK(callParseAttrs(parse, attr_keys.data(), attr_vals.data(), attr_keys.size(),
&num_in, &num_out))
<< "Error calling ParseAttrs for custom operator '" << name_str << "'";

return num_out;
};

// lambda function to convert from external fcompute to internal MXNet types
auto fcomp_conv = [=](const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
//convert attributes to vector of char*
std::vector<const char*> attr_keys, attr_vals;
for(auto kv : attrs.dict) {
attr_keys.push_back(kv.first.c_str());
attr_vals.push_back(kv.second.c_str());
}

std::vector<void*> in_data, out_data;
std::vector<const int64_t *> in_shapes, out_shapes;
std::vector<int> in_dims, out_dims;
Expand Down Expand Up @@ -180,6 +215,8 @@ int MXLoadLib(const char *path) {
contrib_name += name;
nnvm::Op &regOp = dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(contrib_name.c_str());
regOp.set_attr<FCompute>("FCompute<cpu>",fcomp_conv);
regOp.set_num_inputs(num_inputs);
regOp.set_num_outputs(num_outputs);
}

API_END();
Expand Down

0 comments on commit cef93e8

Please sign in to comment.