-
Notifications
You must be signed in to change notification settings - Fork 6.8k
MXNet Extensions enhancements #17885
Changes from 32 commits
60942d0
5ff9fee
0809dc2
bc4855f
b16a257
8a32e69
4c5cf56
176b38d
9fb1368
83296ae
ae9fd4b
4bcb6bf
6950ac2
6fff166
32e9458
4658e79
0223f22
9674973
9b46e9d
0a3aa90
333ce11
e620e60
af16501
c61e5a5
a28747b
8efe836
c2efe3b
17f4906
60d5210
2c370c2
8e7fb91
1d33d7e
b4633bb
c611f2b
3392610
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -726,18 +726,39 @@ endif() | |
|
||
# extension libraries (custom operators, custom subgraphs) are built by default | ||
add_library(customop_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/gemm_lib.cc) | ||
add_library(transposecsr_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/transposecsr_lib.cc) | ||
add_library(transposerowsp_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/transposerowsp_lib.cc) | ||
add_library(subgraph_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_subgraph/subgraph_lib.cc) | ||
add_library(pass_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_pass/pass_lib.cc) | ||
target_include_directories(customop_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet) | ||
target_include_directories(transposecsr_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet) | ||
target_include_directories(transposerowsp_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet) | ||
target_include_directories(subgraph_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet) | ||
target_include_directories(pass_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet) | ||
if(USE_CUDA) | ||
add_library(customop_gpu_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/relu_lib.cu) | ||
target_include_directories(customop_gpu_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet) | ||
endif() | ||
if(MSVC) | ||
if(UNIX) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Those things can be deleted. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I mean you don't need to add -shared at all, even for customop_gpu_lib; just change the lines inside the MSVC block. |
||
target_compile_options(customop_lib PUBLIC -shared) | ||
target_compile_options(transposecsr_lib PUBLIC -shared) | ||
target_compile_options(transposerowsp_lib PUBLIC -shared) | ||
target_compile_options(subgraph_lib PUBLIC -shared) | ||
target_compile_options(pass_lib PUBLIC -shared) | ||
if (USE_CUDA) | ||
target_compile_options(customop_gpu_lib PUBLIC -shared) | ||
endif() | ||
elseif(MSVC) | ||
target_compile_options(customop_lib PUBLIC /LD) | ||
target_compile_options(transposecsr_lib PUBLIC /LD) | ||
target_compile_options(transposerowsp_lib PUBLIC /LD) | ||
target_compile_options(subgraph_lib PUBLIC /LD) | ||
target_compile_options(pass_lib PUBLIC /LD) | ||
set_target_properties(customop_lib PROPERTIES PREFIX "lib") | ||
set_target_properties(transposecsr_lib PROPERTIES PREFIX "lib") | ||
set_target_properties(transposerowsp_lib PROPERTIES PREFIX "lib") | ||
set_target_properties(subgraph_lib PROPERTIES PREFIX "lib") | ||
set_target_properties(pass_lib PROPERTIES PREFIX "lib") | ||
if(USE_CUDA) | ||
target_compile_options(customop_gpu_lib PUBLIC "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-fPIC>") | ||
set_target_properties(customop_gpu_lib PROPERTIES PREFIX "lib") | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,23 +53,23 @@ void transpose(const float* A, float* At, const unsigned n, const unsigned m) { | |
* Executes C = A * B | ||
* inputs[0] = A; inputs[1] = B; outputs[0] = C | ||
*/ | ||
MXReturnValue forward(std::map<std::string, std::string> attrs, | ||
std::vector<MXTensor> inputs, | ||
std::vector<MXTensor> outputs, | ||
OpResource res) { | ||
MXReturnValue forward(const std::unordered_map<std::string, std::string>& attrs, | ||
std::vector<MXTensor>* inputs, | ||
std::vector<MXTensor>* outputs, | ||
const OpResource& res) { | ||
// simple example of using runtime data type | ||
if (inputs[0].dtype == kFloat32) { | ||
if (inputs->at(0).dtype == kFloat32) { | ||
typedef float DType; | ||
// extract data pointers from tensors | ||
// if using dltensor repr, below lines can be changed to something like | ||
// DType* A = reinterpret_cast<DType*>(inputs[0].dltensor.data); | ||
DType* A = inputs[0].data<DType>(); | ||
DType* B = inputs[1].data<DType>(); | ||
DType* C = outputs[0].data<DType>(); | ||
DType* A = inputs->at(0).data<DType>(); | ||
DType* B = inputs->at(1).data<DType>(); | ||
DType* C = outputs->at(0).data<DType>(); | ||
// set tensor shapes | ||
unsigned n = inputs[0].shape[0]; | ||
unsigned k = inputs[0].shape[1]; | ||
unsigned m = inputs[1].shape[1]; | ||
unsigned n = inputs->at(0).shape[0]; | ||
unsigned k = inputs->at(0).shape[1]; | ||
unsigned m = inputs->at(1).shape[1]; | ||
|
||
gemm(A, B, C, n, k, m); | ||
} | ||
|
@@ -87,20 +87,20 @@ MXReturnValue forward(std::map<std::string, std::string> attrs, | |
***** gradient outputs | ||
* outputs[0] = dA; outputs[1] = dB | ||
*/ | ||
MXReturnValue backward(std::map<std::string, std::string> attrs, | ||
std::vector<MXTensor> inputs, | ||
std::vector<MXTensor> outputs, | ||
OpResource res) { | ||
MXReturnValue backward(const std::unordered_map<std::string, std::string>& attrs, | ||
std::vector<MXTensor>* inputs, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Don't forget to apply this change to the lib_custom_op README as well. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. |
||
std::vector<MXTensor>* outputs, | ||
const OpResource& res) { | ||
// extract data pointers from tensors | ||
float* dC = inputs[0].data<float>(); | ||
float* A = inputs[1].data<float>(); | ||
float* B = inputs[2].data<float>(); | ||
float* dA = outputs[0].data<float>(); | ||
float* dB = outputs[1].data<float>(); | ||
float* dC = inputs->at(0).data<float>(); | ||
float* A = inputs->at(1).data<float>(); | ||
float* B = inputs->at(2).data<float>(); | ||
float* dA = outputs->at(0).data<float>(); | ||
float* dB = outputs->at(1).data<float>(); | ||
// set tensor shapes | ||
unsigned n = inputs[1].shape[0]; | ||
unsigned k = inputs[1].shape[1]; | ||
unsigned m = inputs[2].shape[1]; | ||
unsigned n = inputs->at(1).shape[0]; | ||
unsigned k = inputs->at(1).shape[1]; | ||
unsigned m = inputs->at(2).shape[1]; | ||
// allocate temporary workspace memory through resource manager | ||
// for multiple arrays better to request a big memory pool | ||
void *workspace = res.alloc_cpu((k*n + m*k) * sizeof(float)); | ||
|
@@ -115,15 +115,16 @@ MXReturnValue backward(std::map<std::string, std::string> attrs, | |
return MX_SUCCESS; | ||
} | ||
|
||
MXReturnValue parseAttrs(std::map<std::string, std::string> attrs, int* num_in, int* num_out) { | ||
MXReturnValue parseAttrs(const std::unordered_map<std::string, std::string>& attrs, | ||
int* num_in, int* num_out) { | ||
*num_in = 2; | ||
*num_out = 1; | ||
return MX_SUCCESS; | ||
} | ||
|
||
MXReturnValue inferType(std::map<std::string, std::string> attrs, | ||
std::vector<int> &intypes, | ||
std::vector<int> &outtypes) { | ||
MXReturnValue inferType(const std::unordered_map<std::string, std::string>& attrs, | ||
const std::vector<int> &intypes, | ||
std::vector<int> *outtypes) { | ||
// validate inputs | ||
if (intypes.size() != 2) { | ||
std::cout << "Expected 2 inputs to inferType" << std::endl; | ||
|
@@ -136,13 +137,13 @@ MXReturnValue inferType(std::map<std::string, std::string> attrs, | |
} | ||
} | ||
|
||
outtypes[0] = intypes[0]; | ||
outtypes->at(0) = intypes[0]; | ||
return MX_SUCCESS; | ||
} | ||
|
||
MXReturnValue inferShape(std::map<std::string, std::string> attrs, | ||
std::vector<std::vector<unsigned int>> &inshapes, | ||
std::vector<std::vector<unsigned int>> &outshapes) { | ||
MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& attrs, | ||
const std::vector<std::vector<unsigned int>>& inshapes, | ||
ptrendx marked this conversation as resolved.
Show resolved
Hide resolved
|
||
std::vector<std::vector<unsigned int>>* outshapes) { | ||
// validate inputs | ||
if (inshapes.size() != 2) { | ||
std::cout << "Expected 2 inputs to inferShape" << std::endl; | ||
|
@@ -162,7 +163,7 @@ MXReturnValue inferShape(std::map<std::string, std::string> attrs, | |
return MX_FAIL; | ||
} | ||
|
||
outshapes[0] = {n, m}; | ||
outshapes->at(0) = {n, m}; | ||
return MX_SUCCESS; | ||
} | ||
|
||
|
@@ -177,41 +178,42 @@ REGISTER_OP(my_gemm) | |
|
||
class MyStatefulGemm : public CustomStatefulOp { | ||
public: | ||
explicit MyStatefulGemm(int count) : count(count) {} | ||
explicit MyStatefulGemm(int count, | ||
const std::unordered_map<std::string, std::string>& attrs) | ||
: count(count), attrs_(attrs) {} | ||
|
||
MXReturnValue Forward(std::vector<MXTensor> inputs, | ||
std::vector<MXTensor> outputs, | ||
OpResource op_res) { | ||
MXReturnValue Forward(std::vector<MXTensor>* inputs, | ||
std::vector<MXTensor>* outputs, | ||
const OpResource& op_res) { | ||
std::cout << "Info: keyword + number of forward: " << ++count << std::endl; | ||
std::map<std::string, std::string> attrs; | ||
return forward(attrs, inputs, outputs, op_res); | ||
return forward(attrs_, inputs, outputs, op_res); | ||
} | ||
|
||
MXReturnValue Backward(std::vector<MXTensor> inputs, | ||
std::vector<MXTensor> outputs, | ||
OpResource op_res) { | ||
std::map<std::string, std::string> attrs; | ||
return backward(attrs, inputs, outputs, op_res); | ||
MXReturnValue Backward(std::vector<MXTensor>* inputs, | ||
std::vector<MXTensor>* outputs, | ||
const OpResource& op_res) { | ||
return backward(attrs_, inputs, outputs, op_res); | ||
} | ||
|
||
~MyStatefulGemm() {} | ||
|
||
private: | ||
int count; | ||
const std::unordered_map<std::string, std::string> attrs_; | ||
}; | ||
|
||
MXReturnValue createOpState(std::map<std::string, std::string> attrs, | ||
MXReturnValue createOpState(const std::unordered_map<std::string, std::string>& attrs, | ||
CustomStatefulOp** op_inst) { | ||
// testing passing of keyword arguments | ||
int count = attrs.count("test_kw") > 0 ? std::stoi(attrs["test_kw"]) : 0; | ||
int count = attrs.count("test_kw") > 0 ? std::stoi(attrs.at("test_kw")) : 0; | ||
// creating stateful operator instance | ||
*op_inst = new MyStatefulGemm(count); | ||
*op_inst = new MyStatefulGemm(count, attrs); | ||
std::cout << "Info: stateful operator created" << std::endl; | ||
return MX_SUCCESS; | ||
} | ||
|
||
MXReturnValue mutateInputs(std::map<std::string, std::string> attrs, | ||
std::vector<int> &input_indices) { | ||
MXReturnValue mutateInputs(const std::unordered_map<std::string, std::string>& attrs, | ||
std::vector<int>* input_indices) { | ||
// input_indices->push_back(1); // mark mutate input | ||
return MX_SUCCESS; | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there some way to automatically discover all the .cc files in the lib_custom_op folder instead of hardcoding all the paths?