Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Enhancements for MXTensor for custom operators #17204

Merged
merged 16 commits into from
Jan 8, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 35 additions & 28 deletions include/mxnet/lib_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
#include <utility>
#include <stdexcept>

#define MX_LIBRARY_VERSION 1
#define MX_LIBRARY_VERSION 2

/*
* Import from DLPack https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h
Expand Down Expand Up @@ -198,6 +198,7 @@ enum MXDType {
kInt32 = 4,
kInt8 = 5,
kInt64 = 6,
kUNSET = 100,
};

enum MXReturnValue {
Expand All @@ -209,10 +210,15 @@ enum MXReturnValue {
* \brief Tensor data structure used by custom operator
*/
struct MXTensor {
MXTensor() : data_ptr(NULL) {}
MXTensor() : data_ptr(NULL), dtype(kUNSET), version(0) {}

MXTensor(void *data_ptr, const std::vector<int64_t> &shape, MXDType dtype)
: data_ptr(data_ptr), shape(shape), dtype(dtype) {}
MXTensor(void *data_ptr, const std::vector<int64_t> &shape, MXDType dtype,
size_t ID)
: data_ptr(data_ptr), shape(shape), dtype(dtype), version(ID) {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it will be better to unify the naming across all places, like using verID in lib_api.h and here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


void update(void *dptr, MXDType type, size_t ver) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we really need this function? it doesn't have any checks, only copy pointers. I think we can copy them line by line in lib_api.h and keep MXTensor as simple as possible

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

per our discussion, lets move the the for loop to copy shape and call to setDLTensor inside this function. Change name to "setTensor"

data_ptr = dptr; dtype = type; version = ver;
}

/*! \brief populate DLTensor fields */
void setDLTensor() {
Expand Down Expand Up @@ -277,6 +283,14 @@ struct MXTensor {
return size;
}

/*! \brief helper function to compare two MXTensors */
inline bool isSame(const MXTensor &oth) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we override operator==? since we won't support C anyway

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think operator== is confusing. For a tensor object, == usually means value comparison.
In the future, we may add other operators !=, <, >, etc.
It may be better and more consistent with MXNet NDArray API to use ‘isSame’.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comparing object is how c++ is doing for vector and usually for struct, and in NDArray we don't have operators !=, <, > either, so I don't think it is going to be confusing

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I agree with @wkcn here, == should compare the values of the tensor not the "state" of the tensor (data_ptr, versionID, etc)

@mseth10 @eric-haibin-lin @haojin2 what do you guys think?

Copy link
Member

@wkcn wkcn Jan 4, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For C++ vector container, operator== compares the values.

#include <iostream>
#include <vector>
using namespace std;

int main() {
  vector<int> a{1,2,3};
  vector<int> b{1,2,3};
  vector<int> c{1,2,4};
  cout << (a == b) << endl; // 1
  cout << (a == c) << endl; // 0
  cout << (a.data() == b.data()) << endl; // 0
  return 0;
}

Although (a==b) is 1, a.data() is not equal to b.data().

return data_ptr == oth.data_ptr &&
dtype == oth.dtype &&
version == oth.version &&
shape == oth.shape;
}

// data is flatten 1D repr of tensor, elements are in continuous memory
// user can access each element using the shape of tensor
void *data_ptr;
Expand All @@ -287,6 +301,9 @@ struct MXTensor {
// type can only be MXDType enum types
MXDType dtype;

// version number updated if the tensor has changed since the last use by custom op
size_t version;

// corresponding DLTensor repr of MXTensor
// easy way to reuse functions taking DLTensor
DLTensor dltensor;
Expand Down Expand Up @@ -684,15 +701,9 @@ typedef int (*opCallInferType_t)(inferType_t, const char* const*, const char* co

#define MXLIB_OPCALLFCOMP_STR "_opCallFCompute"
typedef int (*opCallFComp_t)(fcomp_t, const char* const*, const char* const*, int,
const int64_t**, int*, void**, int*, int,
const int64_t**, int*, void**, int*, int,
xpu_malloc_t, void*);

#define MXLIB_OPCALLBKWD_STR "_opCallBackward"
typedef int (*opCallBkwd_t)(fcomp_t, const char* const*, const char* const*, int,
const int64_t**, int*, void**, int*, int,
const int64_t**, int*, void**, int*, int,
xpu_malloc_t, void*);
const int64_t**, int*, void**, int*, size_t*, int,
const int64_t**, int*, void**, int*, size_t*, int,
xpu_malloc_t, void*);

#define MXLIB_OPCALLMUTATEINPUTS_STR "_opCallMutateInputs"
typedef int (*opCallMutateInputs_t)(mutateInputs_t, const char* const*, const char* const*, int,
Expand All @@ -703,9 +714,9 @@ typedef int (*opCallCreateOpState_t)(createOpState_t, const char* const*, const
void**);

#define MXLIB_OPCALLFSTATEFULCOMP_STR "_opCallFStatefulCompute"
typedef int (*opCallFStatefulComp_t)(bool, void*, const int64_t**, int*, void**, int*, int,
const int64_t**, int*, void**, int*, int,
xpu_malloc_t, void*);
typedef int (*opCallFStatefulComp_t)(bool, void*, const int64_t**, int*, void**, int*, size_t*,
int, const int64_t**, int*, void**, int*, size_t*,
int, xpu_malloc_t, void*);

#define MXLIB_INITIALIZE_STR "initialize"
typedef int (*initialize_t)(int);
Expand Down Expand Up @@ -876,9 +887,9 @@ extern "C" {
_opCallFCompute(fcomp_t fcomp, const char* const* keys,
const char* const* vals, int num,
const int64_t** inshapes, int* indims,
void** indata, int* intypes, int num_in,
void** indata, int* intypes, size_t* inIDs, int num_in,
const int64_t** outshapes, int* outdims,
void** outdata, int* outtypes, int num_out,
void** outdata, int* outtypes, size_t* outIDs, int num_out,
xpu_malloc_t cpu_malloc, void* cpu_alloc) {
// create map of attributes from list
std::map<std::string, std::string> attrs;
Expand All @@ -889,8 +900,7 @@ extern "C" {
// create a vector of tensors for inputs
std::vector<MXTensor> inputs(num_in);
for (int i = 0; i < num_in; i++) {
inputs[i].data_ptr = indata[i];
inputs[i].dtype = (MXDType)intypes[i];
inputs[i].update(indata[i], (MXDType)intypes[i], inIDs[i]);
for (int j = 0; j < indims[i]; j++) {
inputs[i].shape.push_back(inshapes[i][j]);
}
Expand All @@ -900,8 +910,7 @@ extern "C" {
// create a vector of tensors for outputs
std::vector<MXTensor> outputs(num_out);
for (int i = 0; i < num_out; i++) {
outputs[i].data_ptr = outdata[i];
outputs[i].dtype = (MXDType) outtypes[i];
outputs[i].update(outdata[i], (MXDType)outtypes[i], outIDs[i]);
for (int j = 0; j < outdims[i]; j++) {
outputs[i].shape.push_back(outshapes[i][j]);
}
Expand Down Expand Up @@ -973,15 +982,14 @@ extern "C" {
#endif
_opCallFStatefulCompute(bool is_forward, void* state_op,
const int64_t** inshapes, int* indims,
void** indata, int* intypes, int num_in,
void** indata, int* intypes, size_t* inIDs, int num_in,
const int64_t** outshapes, int* outdims,
void** outdata, int* outtypes, int num_out,
void** outdata, int* outtypes, size_t* outIDs, int num_out,
xpu_malloc_t cpu_malloc, void* cpu_alloc) {
// create a vector of tensors for inputs
std::vector<MXTensor> inputs(num_in);
for (int i = 0; i < num_in; i++) {
inputs[i].data_ptr = indata[i];
inputs[i].dtype = (MXDType)intypes[i];
inputs[i].update(indata[i], (MXDType)intypes[i], inIDs[i]);
for (int j = 0; j < indims[i]; j++) {
inputs[i].shape.push_back(inshapes[i][j]);
}
Expand All @@ -991,8 +999,7 @@ extern "C" {
// create a vector of tensors for outputs
std::vector<MXTensor> outputs(num_out);
for (int i = 0; i < num_out; i++) {
outputs[i].data_ptr = outdata[i];
outputs[i].dtype = (MXDType) outtypes[i];
outputs[i].update(outdata[i], (MXDType)outtypes[i], outIDs[i]);
for (int j = 0; j < outdims[i]; j++) {
outputs[i].shape.push_back(outshapes[i][j]);
}
Expand Down
17 changes: 12 additions & 5 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -395,13 +395,15 @@ int MXLoadLib(const char *path) {
std::vector<const int64_t *> in_shapes, out_shapes;
std::vector<int> in_dims, out_dims;
std::vector<int> in_types, out_types;
std::vector<size_t> in_versions, out_versions;

// convert input tensors to constituent parts
for (size_t i = 0; i < inputs.size(); i++) {
in_data.push_back(inputs[i].data().dptr_);
in_shapes.push_back(inputs[i].shape().data());
in_dims.push_back(inputs[i].shape().ndim());
in_types.push_back(inputs[i].dtype());
in_versions.push_back(inputs[i].version());
}

// convert output tensors to constituent parts
Expand All @@ -410,6 +412,7 @@ int MXLoadLib(const char *path) {
out_shapes.push_back(outputs[i].shape().data());
out_dims.push_back(outputs[i].shape().ndim());
out_types.push_back(outputs[i].dtype());
out_versions.push_back(outputs[i].version());
}

// get memory resource
Expand Down Expand Up @@ -438,9 +441,10 @@ int MXLoadLib(const char *path) {
// call fcompute function
CHECK(callFComp(fcomp_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
in_shapes.data(), in_dims.data(), in_data.data(),
in_types.data(), in_data.size(),
in_types.data(), in_versions.data(), in_data.size(),
out_shapes.data(), out_dims.data(), out_data.data(),
out_types.data(), out_data.size(), cpu_malloc, &cpu_alloc))
out_types.data(), out_versions.data(), out_data.size(),
cpu_malloc, &cpu_alloc))
<< "Error calling FCompute for custom operator '" << name_str << "'";

// return type void
Expand Down Expand Up @@ -570,13 +574,15 @@ int MXLoadLib(const char *path) {
std::vector<const int64_t *> in_shapes, out_shapes;
std::vector<int> in_dims, out_dims;
std::vector<int> in_types, out_types;
std::vector<size_t> in_versions, out_versions;

// convert input tensors to constituent parts
for (size_t i = 0; i < inputs.size(); i++) {
in_data.push_back(inputs[i].data().dptr_);
in_shapes.push_back(inputs[i].shape().data());
in_dims.push_back(inputs[i].shape().ndim());
in_types.push_back(inputs[i].dtype());
in_versions.push_back(inputs[i].version());
}

// convert output tensors to constituent parts
Expand All @@ -585,6 +591,7 @@ int MXLoadLib(const char *path) {
out_shapes.push_back(outputs[i].shape().data());
out_dims.push_back(outputs[i].shape().ndim());
out_types.push_back(outputs[i].dtype());
out_versions.push_back(outputs[i].version());
}

// get memory resource
Expand Down Expand Up @@ -618,9 +625,9 @@ int MXLoadLib(const char *path) {

// call fcompute function
CHECK(callFStatefulComp(is_forward, state_op_inst, in_shapes.data(), in_dims.data(),
in_data.data(), in_types.data(), in_data.size(),
out_shapes.data(), out_dims.data(), out_data.data(),
out_types.data(), out_data.size(), cpu_malloc, &cpu_alloc))
in_data.data(), in_types.data(), in_versions.data(), in_data.size(),
out_shapes.data(), out_dims.data(), out_data.data(), out_types.data(),
out_versions.data(), out_data.size(), cpu_malloc, &cpu_alloc))
<< "Error calling FStatefulCompute for custom operator '" << name_str << "'";
};

Expand Down