Skip to content

Commit

Permalink
Add gpu compile
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed Aug 3, 2017
1 parent 5c4e9c1 commit 9da68f3
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/c_api/c_api_ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ void PushOperator(const OpStatePtr& state,
#if MXNET_USE_CUDA
CastNonDefaultStorage<gpu>(temp_in_src, temp_in_dst, opctx);
fcompute(state, opctx, input_blobs, req, output_blobs);
CastNonDefaultStorage<gpu>(temp_our_dst, temp_out_src, opctx);
CastNonDefaultStorage<gpu>(temp_out_dst, temp_out_src, opctx);
#else
LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
#endif
Expand Down
18 changes: 18 additions & 0 deletions src/operator/tensor/sparse_retain.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*!
* Copyright (c) 2017 by Contributors
* \file sparse_retain.cu
* \brief
*/

#include "./sparse_retain-inl.h"
namespace mxnet {
namespace op {

NNVM_REGISTER_OP(sparse_retain)
.set_attr<FComputeEx>("FComputeEx<gpu>", SparseRetainOpForwardEx<gpu>);

NNVM_REGISTER_OP(_backward_sparse_retain)
.set_attr<FComputeEx>("FComputeEx<gpu>", SparseRetainOpBackwardEx<gpu>);

} // namespace op
} // namespace mxnet

0 comments on commit 9da68f3

Please sign in to comment.