From d1ca7b6e424718ee534646d0b45bd939f195f8f4 Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Tue, 6 Jun 2017 22:19:55 +0000 Subject: [PATCH] draft for sgd rsp rsp support sgd(rsp, rsp) support dot(csr, rsp) when rsp is full add ref to const ndarray params support sparse embedding with rsp weight' fix lint modify embedding backward to produce dense grad remove invalid_rid for rsp->dns remove previous embedding op changes pass sparse embedding test add STORAGE_TYPE_ASSIGN_CHECK remove backward storage infer --- python/mxnet/optimizer.py | 4 +- python/mxnet/sparse_ndarray.py | 6 +- src/operator/operator_common.h | 31 +++ src/operator/optimizer_op-inl.h | 159 +++++++++---- src/operator/tensor/elemwise_unary_op.h | 9 +- src/operator/tensor/indexing_op.cc | 26 ++- src/operator/tensor/indexing_op.h | 208 ++++++++---------- src/operator/tensor/matrix_op-inl.h | 79 +++++-- tests/python/unittest/test_optimizer.py | 20 +- tests/python/unittest/test_sparse_ndarray.py | 1 + tests/python/unittest/test_sparse_operator.py | 46 ++-- 11 files changed, 365 insertions(+), 224 deletions(-) diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py index 1f7b1d3aed1b..04107128cf4b 100644 --- a/python/mxnet/optimizer.py +++ b/python/mxnet/optimizer.py @@ -4,6 +4,7 @@ import logging from .ndarray import NDArray, zeros, clip, sqrt, sign from .ndarray import sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update +from .sparse_ndarray import zeros as sparse_zeros from .random import normal @@ -332,7 +333,8 @@ def create_state(self, index, weight): if self.momentum == 0.0: return None else: - return zeros(weight.shape, weight.context, dtype=weight.dtype) + return sparse_zeros(weight.storage_type, weight.shape, + weight.context, dtype=weight.dtype) def update(self, index, weight, grad, state): assert(isinstance(weight, NDArray)) diff --git a/python/mxnet/sparse_ndarray.py b/python/mxnet/sparse_ndarray.py index 79351b1eb371..bc06fc1d1113 100644 --- a/python/mxnet/sparse_ndarray.py +++ b/python/mxnet/sparse_ndarray.py @@ -571,7 +571,7 @@ def to_dense(source): """ return ndarray.cast_storage(source, storage_type='default') -def zeros(storage_type, shape, ctx=None, dtype=None, aux_types=None): +def zeros(storage_type, shape, ctx=None, dtype=None, aux_types=None, **kwargs): """Return a new array of given shape and type, filled with zeros. Parameters @@ -599,6 +599,8 @@ def zeros(storage_type, shape, ctx=None, dtype=None, aux_types=None): >>> mx.sparse_nd.zeros('row_sparse', (1,2), mx.gpu(0), 'float16').asnumpy() array([[ 0., 0.]], dtype=float16) """ + if storage_type == 'default': + return ndarray.zeros(shape, ctx, dtype, **kwargs) if ctx is None: ctx = Context.default_ctx dtype = mx_real_t if dtype is None else dtype @@ -609,7 +611,7 @@ def zeros(storage_type, shape, ctx=None, dtype=None, aux_types=None): raise Exception("unknown storage type") assert(len(aux_types) == len(_STORAGE_AUX_TYPES[storage_type])) out = _ndarray_cls(_new_alloc_handle(storage_type, shape, ctx, True, dtype, aux_types)) - return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, out=out) + return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, out=out, **kwargs) def _ndarray_cls(handle, writable=True): stype = _storage_type(handle) diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index ecfb9c76acb3..a6d78c2558be 100755 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -110,6 +110,19 @@ inline std::string type_string(const int& x) { return "unknown"; } +/*! \brief get string representation of storage_type */ +inline std::string stype_string(const int& x) { + switch (x) { + case kDefaultStorage: + return "default"; + case kCSRStorage: + return "csr"; + case kRowSparseStorage: + return "row_sparse"; + } + return "unknown"; +} + /*! * \brief Assign x to y. Checks for compatiblity when y is not empty. * Allow missing dim in both x and y (as 0). @@ -186,6 +199,24 @@ inline bool type_assign(int *y, const int& x) { } \ } +/*! + * \brief macro assign type to out if out is unknown (-1) otherwise check consistency + * Use macro so we can see the error file more clearly + * \param type_array the storage type array to store the result + * \param index the index of in the array + * \param type the inferred storage type + */ +#define STORAGE_TYPE_ASSIGN_CHECK(type_array, index, type) \ + { \ + if (!type_assign(&(type_array)[index], type)) { \ + std::ostringstream os; \ + os << "Storage type inconsistent, Provided=" \ + << stype_string((type_array)[index]) << ',' \ + << " inferred storage type=" << stype_string(type); \ + throw ::mxnet::op::InferTypeError(os.str(), index); \ + } \ + } + // helper macro to implement bind dispatch #if MXNET_USE_CUDA #define DO_BIND_DISPATCH(Method, ...) \ diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index 83a4a9cfccbb..d6d8ccc37c53 100755 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -112,32 +112,31 @@ struct SGDDnsRspKernel { template inline void SGDUpdateDnsRspImpl(const SGDParam& param, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext &ctx, + const TBlob& weight, + const NDArray& grad, + const OpReqType& req, + TBlob *out) { using namespace mshadow; using namespace mshadow::expr; using namespace mshadow_op; + using namespace mxnet_op; Stream* s = ctx.get_stream(); - auto &weight = inputs[0]; - auto &grad = inputs[1]; - auto &out = outputs[0]; - CHECK_EQ(weight.storage_type(), kDefaultStorage); CHECK_EQ(grad.storage_type(), kRowSparseStorage); - if (!grad.storage_initialized()) return; + // if gradients are zeros, no weights are updated + if (!grad.storage_initialized() || req == kNullOp) return; + CHECK_GT(weight.shape_.Size(), 0); - MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, { + MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, { MSHADOW_INT_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, { - MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { - auto weight_data = weight.data().FlatTo2D(s); - auto grad_idx = grad.aux_data(rowsparse::kIdx).FlatTo1D(s); - auto grad_val = grad.data().FlatTo2D(s); - auto out_data = out.data().FlatTo2D(s); + MXNET_ASSIGN_REQ_SWITCH(req, req_type, { + auto weight_data = weight.dptr(); + auto grad_idx = grad.aux_data(rowsparse::kIdx).dptr(); + auto grad_val = grad.data().dptr(); auto num_rows = grad.aux_shape(rowsparse::kIdx)[0]; - auto width = weight.shape().ProdShape(1, weight.shape().ndim()); - mxnet_op::Kernel, xpu>::Launch(s, num_rows, width, - out_data.dptr_, weight_data.dptr_, grad_idx.dptr_, grad_val.dptr_, + auto width = weight.shape_.ProdShape(1, weight.ndim()); + Kernel, xpu>::Launch(s, num_rows, width, + out->dptr(), weight_data, grad_idx, grad_val, static_cast(param.clip_gradient), static_cast(param.lr), static_cast(param.wd), static_cast(param.rescale_grad)); @@ -146,6 +145,29 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param, }); } +template +inline void SGDUpdateRspRspImpl(const SGDParam& param, + const OpContext& ctx, + const NDArray& weight, + const NDArray& grad, + const OpReqType& req, + NDArray *out) { + if (weight.storage_shape()[0] == weight.shape()[0] && + out->storage_shape()[0] == out->shape()[0]) { + // TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only + // feed in kWriteTo as req for all operators. + // For sgd we don't want to assign zeros to the output values when req == kWriteTo + auto out_req = req; + if (out_req == kWriteTo) out_req = kWriteInplace; + // reuse dns rsp implementation when storage_shape == shape + TBlob out_blob = out->data(); + SGDUpdateDnsRspImpl(param, ctx, weight.data(), grad, out_req, &out_blob); + } else { + LOG(FATAL) << "SGDUpdate for RowSparse weights is only implemented when " + << "weights.values.shape == weights.shape"; + } +} + template inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs, const OpContext &ctx, @@ -159,7 +181,11 @@ inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs, auto weight_stype = inputs[0].storage_type(); auto grad_stype = inputs[1].storage_type(); if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage) { - SGDUpdateDnsRspImpl(param, ctx, inputs, req, outputs); + TBlob out = outputs[0].data(); + SGDUpdateDnsRspImpl(param, ctx, inputs[0].data(), inputs[1], req[0], &out); + } else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage) { + NDArray out = outputs[0]; + SGDUpdateRspRspImpl(param, ctx, inputs[0], inputs[1], req[0], &out); } else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage) { FCompExFallback(attrs, ctx, inputs, req, outputs, SGDUpdate, "SGDUpdate"); } @@ -262,30 +288,31 @@ struct SGDMomDnsRspDnsKernel { template inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext& ctx, + const TBlob& weight, + const NDArray& grad, + const TBlob& mom, + const OpReqType& req, + TBlob *out) { using namespace mxnet_op; + using namespace rowsparse; Stream* s = ctx.get_stream(); - auto &weight = inputs[0]; - auto &grad = inputs[1]; - auto &mom = inputs[2]; - auto &out = outputs[0]; - if (!grad.storage_initialized()) return; + if (!grad.storage_initialized() || req == kNullOp) return; + CHECK_GT(weight.shape_.Size(), 0); + CHECK_GT(mom.shape_.Size(), 0); - MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, { - MSHADOW_INT_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, { - MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { - auto weight_data = weight.data().FlatTo2D(s); - auto grad_idx = grad.aux_data(rowsparse::kIdx).FlatTo1D(s); - auto grad_val = grad.data().FlatTo2D(s); - auto mom_data = mom.data().FlatTo2D(s); - auto out_data = out.data().FlatTo2D(s); - auto num_rows = grad.aux_shape(rowsparse::kIdx)[0]; - auto width = weight.shape().ProdShape(1, weight.shape().ndim()); + MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, { + MSHADOW_INT_TYPE_SWITCH(grad.aux_type(kIdx), IType, { + MXNET_ASSIGN_REQ_SWITCH(req, req_type, { + auto weight_data = weight.dptr(); + auto grad_idx = grad.aux_data(kIdx).dptr(); + auto grad_val = grad.data().dptr(); + auto mom_data = mom.dptr(); + auto out_data = out->dptr(); + auto num_rows = grad.aux_shape(kIdx)[0]; + auto width = weight.shape_.ProdShape(1, weight.ndim()); Kernel, xpu>::Launch(s, num_rows, width, - out_data.dptr_, mom_data.dptr_, weight_data.dptr_, grad_idx.dptr_, grad_val.dptr_, + out_data, mom_data, weight_data, grad_idx, grad_val, static_cast(param.clip_gradient), static_cast(param.momentum), static_cast(param.lr), static_cast(param.wd), static_cast(param.rescale_grad)); @@ -294,6 +321,50 @@ inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param, }); } +template +inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param, + const OpContext& ctx, + const NDArray& weight, + const NDArray& grad, + const NDArray& mom, + const OpReqType& req, + NDArray *out) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mxnet_op; + using namespace rowsparse; + if (weight.storage_shape()[0] == weight.shape()[0] && + out->storage_shape()[0] == out->shape()[0]) { + Stream* s = ctx.get_stream(); + // fill mom with zero values in order to reuse the sgd mom dns impl + if (!mom.storage_initialized()) { + MSHADOW_REAL_TYPE_SWITCH(mom.dtype(), DType, { + MSHADOW_INT_TYPE_SWITCH(mom.aux_type(kIdx), IType, { + auto num_rows = mom.shape()[0]; + mom.CheckAndAlloc({Shape1(num_rows)}); + auto mom_idx = mom.aux_data(kIdx).FlatTo1D(s); + auto mom_val = mom.data(); + // TODO(haibin) this is single-thread execution + Kernel::Launch(s, mom_val.Size(), mom_val.dptr()); + ASSIGN_DISPATCH(mom_idx, kWriteTo, range(0, num_rows, 1, 1)) + }); + }); + } + // TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only + // feed in kWriteTo as req for all operators. + // For sgd we don't want to assign zeros to the output values when req == kWriteTo + auto out_req = req; + if (out_req == kWriteTo) out_req = kWriteInplace; + TBlob out_blob = out->data(); + // reuse dns rsp implementation when storage_shape == shape + SGDMomUpdateDnsRspDnsImpl(param, ctx, weight.data(), grad, + mom.data(), out_req, &out_blob); + } else { + LOG(FATAL) << "SGDUpdate for RowSparse weights is only implemented when " + << "weights.values.shape == weights.shape"; + } +} + template inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs, const OpContext &ctx, @@ -305,10 +376,16 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs, auto weight_stype = inputs[0].storage_type(); auto grad_stype = inputs[1].storage_type(); auto mom_stype = inputs[2].storage_type(); - if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage && mom_stype == kDefaultStorage) { - SGDMomUpdateDnsRspDnsImpl(param, ctx, inputs, req, outputs); + TBlob out = outputs[0].data(); + SGDMomUpdateDnsRspDnsImpl(param, ctx, inputs[0].data(), inputs[1], + inputs[2].data(), req[0], &out); + } else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage && + mom_stype == kRowSparseStorage) { + NDArray out = outputs[0]; + SGDMomUpdateRspRspRspImpl(param, ctx, inputs[0], inputs[1], + inputs[2], req[0], &out); } else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage && mom_stype == kDefaultStorage) { FCompExFallback(attrs, ctx, inputs, req, outputs, diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index 996a25d5a647..64b7c34359b9 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -324,10 +324,8 @@ inline void CastStorageDnsRspImpl(mshadow::Stream* s, const TBlob& dns, NDA struct CastStorageRspDnsKernel { template MSHADOW_XINLINE static void Map(int i, const index_t width, const IType* idx, const DType *data, - DType* dns, const index_t invalid_rid) { + DType* dns) { auto rid = idx[i]; - // skip invalid rows - if (rid == invalid_rid) return; auto dns_offset = rid * width; auto rsp_offset = i * width; for (size_t col = 0; col < width; col++) { @@ -356,10 +354,9 @@ void CastStorageRspDnsImpl(mshadow::Stream* s, const NDArray& rsp, TBlob* d auto out_data = dns->FlatTo2D(s).dptr_; auto num_rows = rsp.aux_shape(rowsparse::kIdx).Size(); auto rsp_shape = rsp.shape(); - auto invalid_rid = rsp_shape[0]; auto width = rsp_shape.ProdShape(1, rsp_shape.ndim()); - mxnet_op::Kernel::Launch(s, num_rows, width, in_idx, in_data, - out_data, invalid_rid); + mxnet_op::Kernel::Launch(s, num_rows, width, in_idx, + in_data, out_data); } }); }); diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index 8cf00c0eb7b4..da20cf49f1a0 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -87,8 +87,17 @@ NNVM_REGISTER_OP(_backward_Embedding) .set_attr("FCompute", EmbeddingOpBackward); NNVM_REGISTER_OP(SparseEmbedding) -.describe(R"code(Maps integer indices to vector representations (embeddings) with sparse weight update -)code" ADD_FILELINE) +.describe(R"doc(Represents words or other sparse inputs by dense continuous vectors. +It assumes that the input is in one-hot form. E.g., for a vocabulary size of 10,000, + each input vector is expected to have dimension 10,000. +The index of the non-zero entry is the index of the word or item it represents. + +The corresponding embedding vectors are stored as rows of a matrix. +Hence, mapping an input word to its embedding is implemented as a matrix product. + +The gradient of an embedding matrix has the form of gradient vectors that are only + non-zero for words seen in a minibatch. +)doc" ADD_FILELINE) .set_num_inputs(2) .set_num_outputs(1) .set_attr_parser(ParamParser) @@ -96,19 +105,21 @@ NNVM_REGISTER_OP(SparseEmbedding) [](const NodeAttrs& attrs) { return std::vector{"data", "weight"}; }) -.set_attr("FInferShape", EmbeddingOpShape) +.set_attr("FInferShape", SparseEmbeddingShape) .set_attr("FInferType", EmbeddingOpType) +.set_attr("FInferStorageType", SparseEmbeddingForwardStorageType) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) -.set_attr("FCompute", EmbeddingOpForward) +.set_attr(FCOMP_EX_CPU, SparseEmbeddingForwardEx) .set_attr("FGradient", [](const nnvm::NodePtr& n, const std::vector& ograds) { return MakeNonlossGradNode("_backward_SparseEmbedding", n, ograds, {n->inputs[0]}, n->attrs.dict); }) -.add_argument("data", "NDArray-or-Symbol", "The input array to the embedding operator.") +.add_argument("data", "NDArray-or-Symbol", + "The input array to the sparse embedding operator.") .add_argument("weight", "NDArray-or-Symbol", "The embedding weight matrix.") .add_arguments(EmbeddingParam::__FIELDS__()); @@ -116,10 +127,7 @@ NNVM_REGISTER_OP(_backward_SparseEmbedding) .set_num_inputs(2) .set_num_outputs(2) .set_attr("TIsBackward", true) -.set_attr("FInferStorageType", SparseEmbeddingBackwardStorageType) -.set_attr("FComputeEx", SparseEmbeddingOpBackwardEx); -// TODO(haibin) handle dense case -// .set_attr("FCompute", EmbeddingOpBackward); +.set_attr("FComputeEx", SparseEmbeddingBackwardEx); NNVM_REGISTER_OP(take) .describe(R"code(Takes elements from an input array along the given axis. diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 81b219f7c2c9..7387b7dc79f1 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -9,7 +9,6 @@ #include #include -#include #include #include #include @@ -23,6 +22,7 @@ #include "../elemwise_op_common.h" #include "../mxnet_op.h" #include "./sort_op.h" +#include "./matrix_op-inl.h" namespace mxnet { namespace op { @@ -204,6 +204,82 @@ void EmbeddingOpForward(const nnvm::NodeAttrs& attrs, }); } +template +void SparseEmbeddingForwardRspImpl(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const NDArray& data, + const NDArray& weight, + const OpReqType req, + NDArray *out) { + if (weight.storage_shape()[0] == weight.shape()[0]) { + TBlob out_blob = out->data(); + // forward to dns implementation when storage_shape equals shape + bool transpose_a = false; + DotCsrRspDnsImpl(ctx, data, weight, req, transpose_a, &out_blob); + } else { + LOG(FATAL) << "SparseEmbedding for RowSparse weights is only implemented when " + << "weights.values.shape == weights.shape"; + } +} + +template +void SparseEmbeddingForwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(req[embedding::kOut], kWriteTo); + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + + NDArray output = outputs[embedding::kOut]; + auto data_stype = inputs[embedding::kData].storage_type(); + auto weight_stype = inputs[embedding::kWeight].storage_type(); + auto out_stype = outputs[embedding::kOut].storage_type(); + if (data_stype == kCSRStorage && weight_stype == kRowSparseStorage && + out_stype == kDefaultStorage) { + NDArray ret = outputs[embedding::kOut]; + SparseEmbeddingForwardRspImpl(attrs, ctx, inputs[embedding::kData], + inputs[embedding::kWeight], + req[embedding::kOut], &ret); + } else { + LOG(FATAL) << "Not supported SparseEmbedding operation for data.storage_type = " + << data_stype << ", weight.storage_type = " << weight_stype + << ", out.storage_type = " << out_stype; + } +} + +inline bool SparseEmbeddingForwardStorageType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, embedding::kData, kCSRStorage); + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, embedding::kOut, kDefaultStorage); + // override the default storage type generated in nnvm + in_attrs->at(embedding::kWeight) = kRowSparseStorage; + return true; +} + +inline bool SparseEmbeddingShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + using namespace mshadow; + const EmbeddingParam& param = nnvm::get(attrs.parsed); + const TShape &dshape = (*in_attrs)[embedding::kData]; + CHECK_EQ(dshape.ndim(), 2) + << "SparseEmbedding shape error: data is expected to be 2D."; + SHAPE_ASSIGN_CHECK(*in_attrs, embedding::kWeight, + Shape2(param.input_dim, param.output_dim)); + out_attrs->clear(); + std::vector buf(2); + buf[0] = dshape[0]; + buf[1] = param.output_dim; + out_attrs->emplace_back(buf.begin(), buf.end()); + return true; +} + // Returns integer log2(a) rounded up inline int ilog2(unsigned int a) { int k = 1; @@ -316,130 +392,28 @@ void EmbeddingOpBackward(const nnvm::NodeAttrs& attrs, }); } -template -struct EmbeddingBackwardRsp { - template - // each thread i is responsible for target gradient row ids in [segment_start, segment_end) - MSHADOW_XINLINE static void Map(int i, const size_t width, IType* dst_idx, DType* dst_val, - const IType* idx, const size_t num_idx, const DType* src, - const size_t segment_len, const size_t num_rows) { - auto req_type = req; - size_t segment_start = i * segment_len; - size_t segment_end = (i + 1) * segment_len; - for (size_t y = 0; y < num_idx; y++) { - size_t j = idx[y]; - if (j >= num_rows) j = num_rows - 1; - if (j < segment_start || j >= segment_end) continue; - dst_idx[j] = j; - for (size_t k = 0; k < width; k++) { - if (req_type == kWriteTo) req_type = kAddTo; - KERNEL_ASSIGN(dst_val[j * width + k], req_type, src[y * width + k]); - } - } - } -}; - -/* - * for sparse embedding, the storage type for weight gradient is row_sparse. - * we don't care about the storage type for data gradient, since it is not - * differentiable. - */ -inline bool SparseEmbeddingBackwardStorageType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - CHECK_EQ((*in_attrs)[0], kDefaultStorage); - CHECK_EQ((*in_attrs)[1], kDefaultStorage); - (*out_attrs)[0] = kRowSparseStorage; - (*out_attrs)[1] = kRowSparseStorage; - return true; -} - template -void SparseEmbeddingOpBackwardDnsDnsRsp(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace mshadow; - using namespace mxnet_op; - using namespace mshadow::expr; +void SparseEmbeddingBackwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 2U); - if (req[1] == kNullOp) return; - // check storage types - auto idx = inputs[1]; // idx shape (d1, d2 .. dk) - auto grad = inputs[0]; // grad shape (d1, d2, .. dk, out_dim) - auto output = outputs[1]; // weight shape (in_dim, out_dim) - CHECK_EQ(idx.storage_type(), kDefaultStorage); - CHECK_EQ(grad.storage_type(), kDefaultStorage); - CHECK_EQ(output.dtype(), grad.dtype()); - CHECK_EQ(idx.dtype(), output.aux_type(rowsparse::kIdx)) << "Index type doesn't match"; + CHECK_EQ(req.size(), 2U); // CHECK_EQ(req[embedding::kData], kNullOp) - // << "Embedding layer doesn't support calculate data gradient" << req[embedding::kData]; + // << "Embedding layer doesn't support calculate data gradient" << req[0] << " " << req[1]; + // CHECK_NE(req[1], kWriteInplace) << "DotBackwardEx does not support WriteInplace"; - const TShape& ishape = idx.shape(); - const TShape& oshape = grad.shape(); - - Stream *s = ctx.get_stream(); - CHECK_EQ(idx.dtype(), output.aux_type(rowsparse::kIdx)) - << "embedding input index and gradient row sparse type doesn't match!"; - // Alloc dense output - unsigned int num_rows = output.shape()[0]; - output.CheckAndAlloc({mshadow::Shape1(num_rows)}); - MSHADOW_TYPE_SWITCH(output.dtype(), DType, { - MSHADOW_INT_TYPE_SWITCH(idx.dtype(), IType, { - MXNET_ASSIGN_REQ_SWITCH(req[1], req_type, { - // input embedding indice, each idx in [0, input_dim) - auto idx_data = idx.data().FlatTo1D(s); - auto grad_data = grad.data().get_with_shape( - Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s); - auto output_idx = output.aux_data(rowsparse::kIdx).FlatTo1D(s); - auto output_val = output.data().FlatTo2D(s); - int num_threads = omp_get_num_threads(); - size_t width = output.shape()[1]; - size_t segment_len = (num_rows + num_threads - 1) / num_threads; - // fill indices with invalid row ids - Kernel::Launch(s, num_rows, output_idx.dptr_, - static_cast(num_rows)); - // fill zeros if needed - if (req_type == kWriteTo) { - Kernel::Launch(s, output_val.shape_.Size(), output_val.dptr_); - } - Kernel, xpu>::Launch(s, num_threads, width, - output_idx.dptr_, - output_val.dptr_, idx_data.dptr_, - ishape.Size(), grad_data.dptr_, - segment_len, num_rows); - }); - }); - }); -} - -// todo replace xpu with cpu -template -void SparseEmbeddingOpBackwardEx(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace mshadow; - using namespace mxnet_op; - using namespace mshadow::expr; - CHECK_EQ(inputs.size(), 2U); - CHECK_EQ(outputs.size(), 2U); - // CHECK_EQ(req[embedding::kData], kNullOp) - // << "Embedding layer doesn't support calculate data gradient" << req[0] << " " << req[1]; - // idx shape (d1, d2 .. dk) - auto idx_stype = inputs[1].storage_type(); - // grad shape (d1, d2, .. dk, out_dim) + auto data_stype = inputs[1].storage_type(); auto grad_stype = inputs[0].storage_type(); - // weight shape (in_dim, out_dim) auto output_stype = outputs[1].storage_type(); - if (idx_stype == kDefaultStorage && grad_stype == kDefaultStorage && - output_stype == kRowSparseStorage) { - SparseEmbeddingOpBackwardDnsDnsRsp(attrs, ctx, inputs, req, outputs); + if (data_stype == kCSRStorage && grad_stype == kDefaultStorage && + output_stype == kDefaultStorage) { + TBlob ret = outputs[1].data(); + DotCsrDnsDnsImpl(ctx, inputs[1], inputs[0].data(), req[1], true, &ret); } else { - LOG(FATAL) << "Not implemented"; + LOG(FATAL) << "Not supported dot backward for sparse input(s) with sparse gradients"; } } diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 05fba76d0ff3..f01d6428b0d4 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -643,22 +643,20 @@ struct DotCsrTransDnsDnsByRowBlocks { template void DotCsrDnsDnsImpl(const OpContext& ctx, const NDArray& lhs, - const NDArray& rhs, + const TBlob& rhs, const OpReqType req, const bool trans_lhs, - NDArray* ret) { + TBlob* ret) { if (kNullOp == req) return; CHECK_EQ(lhs.storage_type(), kCSRStorage); - CHECK_EQ(rhs.storage_type(), kDefaultStorage); - CHECK_EQ(ret->storage_type(), kDefaultStorage); if (!lhs.storage_initialized()) return; mshadow::Stream *s = ctx.get_stream(); const TBlob data_l = lhs.data(); const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); const TBlob col_idx_l = lhs.aux_data(csr::kIdx); - const TBlob data_r = rhs.data(); - const TBlob data_out = ret->data(); + const TBlob& data_r = rhs; + const TBlob data_out = *ret; MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type MSHADOW_INT_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type @@ -693,7 +691,7 @@ void DotCsrDnsDnsImpl(const OpContext& ctx, MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { mxnet_op::Kernel, xpu>::Launch(s, data_out.Size(), data_out.dptr(), data_l.dptr(), indptr_l.dptr(), - col_idx_l.dptr(), data_r.dptr(), rhs.shape()[1]); + col_idx_l.dptr(), data_r.dptr(), rhs.shape_[1]); }); } } @@ -702,6 +700,21 @@ void DotCsrDnsDnsImpl(const OpContext& ctx, }); } +template +void DotCsrRspDnsImpl(const OpContext& ctx, + const NDArray& lhs, + const NDArray& rhs, + const OpReqType req, + const bool trans_lhs, + TBlob* ret) { + if (rhs.storage_shape()[0] == rhs.shape()[0]) { + // reuse csr dns implementation when storage_shape == shape for rhs + DotCsrDnsDnsImpl(ctx, lhs, rhs.data(), req, trans_lhs, ret); + } else { + LOG(FATAL) << "Dot for RowSparse rhs is only implemented for rhs.values.shape == rhs.shape"; + } +} + template void DotBackwardCsrDnsDns(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -709,8 +722,25 @@ void DotBackwardCsrDnsDns(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { const DotParam& param = nnvm::get(attrs.parsed); - NDArray ret = outputs[1]; - DotCsrDnsDnsImpl(ctx, inputs[1], inputs[0], req[1], !param.transpose_a, &ret); + TBlob ret = outputs[1].data(); + DotCsrDnsDnsImpl(ctx, inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret); +} + +template +void DotBackwardCsrRspDns(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const auto& rhs = inputs[2]; + if (rhs.storage_shape()[0] == rhs.shape()[0]) { + // reuse csr dns implementation when storage_shape == shape for rhs + const DotParam& param = nnvm::get(attrs.parsed); + TBlob ret = outputs[1].data(); + DotCsrDnsDnsImpl(ctx, inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret); + } else { + LOG(FATAL) << "Dot for RowSparse rhs is only implemented for rhs.values.shape == rhs.shape"; + } } inline bool DotShape(const nnvm::NodeAttrs& attrs, @@ -767,12 +797,16 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs, CHECK_EQ(req.size(), 1U); const DotParam& param = nnvm::get(attrs.parsed); CHECK(!param.transpose_b) << "tranposing rhs of the op dot is not supported"; - - NDArray ret = outputs[0]; // get rid of the const qualifier - if (inputs[0].storage_type() == kCSRStorage - && inputs[1].storage_type() == kDefaultStorage - && outputs[0].storage_type() == kDefaultStorage) { - DotCsrDnsDnsImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret); + auto lhs_stype = inputs[0].storage_type(); + auto rhs_stype = inputs[1].storage_type(); + auto out_stype = outputs[0].storage_type(); + if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kDefaultStorage) { + TBlob ret = outputs[0].data(); + DotCsrDnsDnsImpl(ctx, inputs[0], inputs[1].data(), req[0], param.transpose_a, &ret); + } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage && + out_stype == kDefaultStorage) { + TBlob ret = outputs[0].data(); + DotCsrRspDnsImpl(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret); } else { // TODO(junwu): add fallback LOG(FATAL) << "Not supported dot operation for lhs.storage_type = " << inputs[0].storage_type() << ", rhs.storage_type = " << inputs[1].storage_type() @@ -796,12 +830,19 @@ void DotBackwardEx(const nnvm::NodeAttrs& attrs, // TODO(junwu): check whether this CHECK is reasonable const DotParam& param = nnvm::get(attrs.parsed); CHECK(!param.transpose_b) << "sparse dot only supports dot(A, X) and dot(A.T(), X)"; - if (inputs[0].storage_type() == kDefaultStorage // ograd dns format - // dns, csr, dns => *, dns - && inputs[1].storage_type() == kCSRStorage // csr input lhs of the op - && inputs[2].storage_type() == kDefaultStorage // dns input rhs of the op + auto ograd_stype = inputs[0].storage_type(); + auto lhs_stype = inputs[1].storage_type(); + auto rhs_stype = inputs[2].storage_type(); + if (ograd_stype == kDefaultStorage // ograd dns format + && lhs_stype == kCSRStorage // csr input lhs of the op + && rhs_stype == kDefaultStorage // dns input rhs of the op && outputs[1].storage_type() == kDefaultStorage) { // grad(rhs) dns format + // dns, csr, dns => *, dns DotBackwardCsrDnsDns(attrs, ctx, inputs, req, outputs); + } else if (ograd_stype == kDefaultStorage && lhs_stype == kCSRStorage && + rhs_stype == kRowSparseStorage && outputs[1].storage_type() == kDefaultStorage) { + // dns, csr, rsp => *, dns + DotBackwardCsrRspDns(attrs, ctx, inputs, req, outputs); } else { LOG(FATAL) << "Not supported dot backward for sparse input(s) with sparse gradients"; } diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index 6f69828ed9b1..80632c262a8e 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -35,8 +35,8 @@ def compare_optimizer(opt1, opt2, shape, w_stype='default', g_stype='default'): w2 = mx.random.uniform(shape=shape, ctx=default_context()) w1 = w2.copyto(default_context()) elif w_stype == 'row_sparse': - w2 = rand_ndarray(shape, w_stype) - w1 = rand_ndarray(shape, w_stype).to_dense() + w2 = rand_ndarray(shape, w_stype, density=1) + w1 = w2.copyto(default_context()).to_dense() else: raise Exception("type not supported yet") if g_stype == 'default': @@ -51,14 +51,20 @@ def compare_optimizer(opt1, opt2, shape, w_stype='default', g_stype='default'): state1 = opt1.create_state(0, w1) state2 = opt2.create_state(0, w2) if state1 is not None and state2 is not None: - for s1, s2, in zip(state1, state2): - assert(same(s1.asnumpy(), s2.asnumpy())) + if isinstance(state1, tuple): + for s1, s2, in zip(state1, state2): + assert(same(s1.asnumpy(), s2.asnumpy())) + else: + assert_almost_equal(state1.asnumpy(), state2.asnumpy()) opt1.update(0, w1, g1, state1) opt2.update(0, w2, g2, state2) if state1 is not None and state2 is not None: - for s1, s2, in zip(state1, state2): - assert_almost_equal(s1.asnumpy(), s2.asnumpy(), rtol=1e-4, atol=1e-5) + if isinstance(state1, tuple): + for s1, s2, in zip(state1, state2): + assert_almost_equal(s1.asnumpy(), s2.asnumpy(), rtol=1e-4, atol=1e-5) + else: + assert_almost_equal(state1.asnumpy(), state2.asnumpy()) assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=1e-4, atol=1e-5) # SGD @@ -230,7 +236,7 @@ def test_sparse_sgd(): {'clip_gradient': 0.4, 'rescale_grad': 0.14, 'wd': 0.03, 'momentum': 0.9}, {'rescale_grad': 0.8, 'wd': 0.05, 'momentum': 0.9}] for kwarg in kwargs: - compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, w_stype='default', g_stype='row_sparse') + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, w_stype='row_sparse', g_stype='row_sparse') # ADAM diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py index fc27b80f4530..d46a5f7c81a2 100644 --- a/tests/python/unittest/test_sparse_ndarray.py +++ b/tests/python/unittest/test_sparse_ndarray.py @@ -62,6 +62,7 @@ def check_sparse_nd_zeros(stype, shape): shape = rand_shape_2d() check_sparse_nd_zeros('row_sparse', shape) check_sparse_nd_zeros('csr', shape) + check_sparse_nd_zeros('default', shape) def test_sparse_nd_copy(): diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py index d625dfa7906b..ac7be4b41c80 100644 --- a/tests/python/unittest/test_sparse_operator.py +++ b/tests/python/unittest/test_sparse_operator.py @@ -65,8 +65,7 @@ def test_elemwise_add_ex_multiple_stages(): exec_test.backward(out_grads=exec_test.outputs) assert_almost_equal(arr_grads[0].asnumpy(), arr_grads[1].asnumpy()) - -# TODO(haibin) also add test for backward pass. Check if exception is thrown +# TODO(haibin) also add test for backward pass. def test_cast_storage_ex(): def test_rsp_to_dns(shape): rsp, (data, row_idx) = rand_sparse_ndarray(shape, 'row_sparse') @@ -102,52 +101,56 @@ def test_dns_to_csr(dns_in): test_dns_to_csr([[0, 1, 0], [0, 2, 0], [3, 0, 0], [0, 0, 4], [5, 6, 0], [0, 0, 7]]) def test_sparse_dot(): - def test_dot_csr_dns(csr_shape, dns_shape, trans_csr): - dns1 = rand_ndarray(csr_shape, 'default') - dns2 = rand_ndarray(dns_shape, 'default') - csr = mx.nd.cast_storage(dns1, storage_type='csr') - out = mx.nd.dot(csr, dns2, transpose_a=trans_csr) + def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs): + lhs_dns = rand_ndarray(lhs_shape, 'default') + lhs_nd = mx.nd.cast_storage(lhs_dns, storage_type='csr') + rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=1) + rhs_dns = rhs_nd if rhs_stype == 'default' else rhs_nd.to_dense() + out = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs) assert out.storage_type == 'default' - out_expected = mx.nd.dot(dns1, dns2, transpose_a=trans_csr) + out_expected = mx.nd.dot(lhs_dns, rhs_dns, transpose_a=trans_lhs) out_np = out_expected.asnumpy() - backward_trans = not trans_csr - rhs_backward_grad = mx.nd.dot(dns1, out_expected, transpose_a=backward_trans).asnumpy() + backward_trans = not trans_lhs + rhs_backward_grad = mx.nd.dot(lhs_dns, out_expected, transpose_a=backward_trans).asnumpy() assert_almost_equal(out.asnumpy(), out_np, rtol=1e-4, atol=1e-5) # test symbolic forward lhs = mx.symbol.Variable('lhs', storage_type='csr') - rhs = mx.symbol.Variable('rhs', storage_type='default') - test = mx.symbol.dot(lhs, rhs, transpose_a=trans_csr) - location = {'lhs': csr, 'rhs': dns2} + rhs = mx.symbol.Variable('rhs', storage_type=rhs_stype) + test = mx.symbol.dot(lhs, rhs, transpose_a=trans_lhs) + location = {'lhs': lhs_nd, 'rhs': rhs_nd} expected = {'rhs': rhs_backward_grad} - # dot(lhs, rhs) - check_symbolic_forward(test, location, [out_expected.asnumpy()], rtol=1e-3, atol=1e-4) + check_symbolic_forward(test, location, [out_np], rtol=1e-3, atol=1e-4) + # test symbolic backward check_symbolic_backward(test, location, [out_np], expected, grad_req={'lhs': 'null', 'rhs': 'write'}, rtol=1e-3, atol=1e-4) lhs_shape = rand_shape_2d() - test_dot_csr_dns(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), False) - test_dot_csr_dns(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), True) - + test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'default', False) + test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'default', True) + test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False) + test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True) def test_sparse_embedding(): in_dim = 10 out_dim = 4 batch = 24 - data = mx.sym.Variable("data", dtype=np.int32) + data = mx.sym.Variable("data", storage_type='csr') embed = mx.sym.SparseEmbedding(data=data, input_dim=in_dim, output_dim=out_dim, name="embed") exe_test = embed.simple_bind(default_context(), grad_req={'data': 'null', 'embed_weight': 'write'}, - data=(batch,)) + data=(batch, in_dim)) + arg_map = dict(zip(embed.list_arguments(), exe_test.arg_arrays)) grad_map = dict(zip(embed.list_arguments(), exe_test.grad_arrays)) np_data = np.random.randint(low=0, high=in_dim, size=batch) np_weight = np.random.uniform(-0.01, 0.01, arg_map["embed_weight"].shape) np_onehot = np.zeros((batch, in_dim)) np_onehot[np.arange(batch), np_data] = 1.0 + nd_onehot = mx.nd.array(np_onehot).to_csr() # forward - arg_map["data"][:] = np_data + arg_map["data"][:] = nd_onehot arg_map["embed_weight"][:] = np_weight exe_test.forward(is_train=True) assert_almost_equal(exe_test.outputs[0].asnumpy(), np.dot(np_onehot, np_weight)) @@ -197,7 +200,6 @@ def test_sparse_retain(): sym = mx.sym.sparse_retain(data=data, indices=idx) check_numeric_gradient(sym, [rsp, indices], grad_nodes=['data'], grad_stype_dict={'data': 'row_sparse'}) - if __name__ == '__main__': import nose nose.runmodule()