Skip to content

Commit

Permalink
draft for sgd rsp rsp (#75)
Browse files Browse the repository at this point in the history
support sgd(rsp, rsp)

support dot(csr, rsp) when rsp is full

add ref to const ndarray params

support sparse embedding with rsp weight'

fix lint

modify embedding backward to produce dense grad

remove invalid_rid for rsp->dns

remove previous embedding op changes

pass sparse embedding test

add STORAGE_TYPE_ASSIGN_CHECK

remove backward storage infer
  • Loading branch information
eric-haibin-lin committed Jun 10, 2017
1 parent f98912b commit a880bc7
Show file tree
Hide file tree
Showing 11 changed files with 365 additions and 224 deletions.
4 changes: 3 additions & 1 deletion python/mxnet/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from .ndarray import NDArray, zeros, clip, sqrt, sign
from .ndarray import sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update
from .sparse_ndarray import zeros as sparse_zeros
from .random import normal


Expand Down Expand Up @@ -332,7 +333,8 @@ def create_state(self, index, weight):
if self.momentum == 0.0:
return None
else:
return zeros(weight.shape, weight.context, dtype=weight.dtype)
return sparse_zeros(weight.storage_type, weight.shape,
weight.context, dtype=weight.dtype)

def update(self, index, weight, grad, state):
assert(isinstance(weight, NDArray))
Expand Down
6 changes: 4 additions & 2 deletions python/mxnet/sparse_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ def to_dense(source):
"""
return ndarray.cast_storage(source, storage_type='default')

def zeros(storage_type, shape, ctx=None, dtype=None, aux_types=None):
def zeros(storage_type, shape, ctx=None, dtype=None, aux_types=None, **kwargs):
"""Return a new array of given shape and type, filled with zeros.
Parameters
Expand Down Expand Up @@ -599,6 +599,8 @@ def zeros(storage_type, shape, ctx=None, dtype=None, aux_types=None):
>>> mx.sparse_nd.zeros('row_sparse', (1,2), mx.gpu(0), 'float16').asnumpy()
array([[ 0., 0.]], dtype=float16)
"""
if storage_type == 'default':
return ndarray.zeros(shape, ctx, dtype, **kwargs)
if ctx is None:
ctx = Context.default_ctx
dtype = mx_real_t if dtype is None else dtype
Expand All @@ -609,7 +611,7 @@ def zeros(storage_type, shape, ctx=None, dtype=None, aux_types=None):
raise Exception("unknown storage type")
assert(len(aux_types) == len(_STORAGE_AUX_TYPES[storage_type]))
out = _ndarray_cls(_new_alloc_handle(storage_type, shape, ctx, True, dtype, aux_types))
return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, out=out)
return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, out=out, **kwargs)

def _ndarray_cls(handle, writable=True):
stype = _storage_type(handle)
Expand Down
31 changes: 31 additions & 0 deletions src/operator/operator_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,19 @@ inline std::string type_string(const int& x) {
return "unknown";
}

/*! \brief get string representation of storage_type */
inline std::string stype_string(const int& x) {
switch (x) {
case kDefaultStorage:
return "default";
case kCSRStorage:
return "csr";
case kRowSparseStorage:
return "row_sparse";
}
return "unknown";
}

/*!
* \brief Assign x to y. Checks for compatiblity when y is not empty.
* Allow missing dim in both x and y (as 0).
Expand Down Expand Up @@ -186,6 +199,24 @@ inline bool type_assign(int *y, const int& x) {
} \
}

/*!
* \brief macro assign type to out if out is unknown (-1) otherwise check consistency
* Use macro so we can see the error file more clearly
* \param type_array the storage type array to store the result
* \param index the index of in the array
* \param type the inferred storage type
*/
#define STORAGE_TYPE_ASSIGN_CHECK(type_array, index, type) \
{ \
if (!type_assign(&(type_array)[index], type)) { \
std::ostringstream os; \
os << "Storage type inconsistent, Provided=" \
<< stype_string((type_array)[index]) << ',' \
<< " inferred storage type=" << stype_string(type); \
throw ::mxnet::op::InferTypeError(os.str(), index); \
} \
}

// helper macro to implement bind dispatch
#if MXNET_USE_CUDA
#define DO_BIND_DISPATCH(Method, ...) \
Expand Down
159 changes: 118 additions & 41 deletions src/operator/optimizer_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,32 +112,31 @@ struct SGDDnsRspKernel {

template<typename xpu>
inline void SGDUpdateDnsRspImpl(const SGDParam& param,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
const OpContext &ctx,
const TBlob& weight,
const NDArray& grad,
const OpReqType& req,
TBlob *out) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace mshadow_op;
using namespace mxnet_op;
Stream<xpu>* s = ctx.get_stream<xpu>();
auto &weight = inputs[0];
auto &grad = inputs[1];
auto &out = outputs[0];
CHECK_EQ(weight.storage_type(), kDefaultStorage);
CHECK_EQ(grad.storage_type(), kRowSparseStorage);
if (!grad.storage_initialized()) return;
// if gradients are zeros, no weights are updated
if (!grad.storage_initialized() || req == kNullOp) return;
CHECK_GT(weight.shape_.Size(), 0);

MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, {
MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
MSHADOW_INT_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
auto weight_data = weight.data().FlatTo2D<xpu, DType>(s);
auto grad_idx = grad.aux_data(rowsparse::kIdx).FlatTo1D<xpu, IType>(s);
auto grad_val = grad.data().FlatTo2D<xpu, DType>(s);
auto out_data = out.data().FlatTo2D<xpu, DType>(s);
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
auto weight_data = weight.dptr<DType>();
auto grad_idx = grad.aux_data(rowsparse::kIdx).dptr<IType>();
auto grad_val = grad.data().dptr<DType>();
auto num_rows = grad.aux_shape(rowsparse::kIdx)[0];
auto width = weight.shape().ProdShape(1, weight.shape().ndim());
mxnet_op::Kernel<SGDDnsRspKernel<req_type>, xpu>::Launch(s, num_rows, width,
out_data.dptr_, weight_data.dptr_, grad_idx.dptr_, grad_val.dptr_,
auto width = weight.shape_.ProdShape(1, weight.ndim());
Kernel<SGDDnsRspKernel<req_type>, xpu>::Launch(s, num_rows, width,
out->dptr<DType>(), weight_data, grad_idx, grad_val,
static_cast<DType>(param.clip_gradient),
static_cast<DType>(param.lr), static_cast<DType>(param.wd),
static_cast<DType>(param.rescale_grad));
Expand All @@ -146,6 +145,29 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param,
});
}

template<typename xpu>
inline void SGDUpdateRspRspImpl(const SGDParam& param,
const OpContext& ctx,
const NDArray& weight,
const NDArray& grad,
const OpReqType& req,
NDArray *out) {
if (weight.storage_shape()[0] == weight.shape()[0] &&
out->storage_shape()[0] == out->shape()[0]) {
// TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only
// feed in kWriteTo as req for all operators.
// For sgd we don't want to assign zeros to the output values when req == kWriteTo
auto out_req = req;
if (out_req == kWriteTo) out_req = kWriteInplace;
// reuse dns rsp implementation when storage_shape == shape
TBlob out_blob = out->data();
SGDUpdateDnsRspImpl<xpu>(param, ctx, weight.data(), grad, out_req, &out_blob);
} else {
LOG(FATAL) << "SGDUpdate for RowSparse weights is only implemented when "
<< "weights.values.shape == weights.shape";
}
}

template<typename xpu>
inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
Expand All @@ -159,7 +181,11 @@ inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs,
auto weight_stype = inputs[0].storage_type();
auto grad_stype = inputs[1].storage_type();
if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage) {
SGDUpdateDnsRspImpl<xpu>(param, ctx, inputs, req, outputs);
TBlob out = outputs[0].data();
SGDUpdateDnsRspImpl<xpu>(param, ctx, inputs[0].data(), inputs[1], req[0], &out);
} else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage) {
NDArray out = outputs[0];
SGDUpdateRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], req[0], &out);
} else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage) {
FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs, SGDUpdate<xpu>, "SGDUpdate");
}
Expand Down Expand Up @@ -262,30 +288,31 @@ struct SGDMomDnsRspDnsKernel {

template<typename xpu>
inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
const OpContext& ctx,
const TBlob& weight,
const NDArray& grad,
const TBlob& mom,
const OpReqType& req,
TBlob *out) {
using namespace mxnet_op;
using namespace rowsparse;
Stream<xpu>* s = ctx.get_stream<xpu>();
auto &weight = inputs[0];
auto &grad = inputs[1];
auto &mom = inputs[2];
auto &out = outputs[0];
if (!grad.storage_initialized()) return;
if (!grad.storage_initialized() || req == kNullOp) return;
CHECK_GT(weight.shape_.Size(), 0);
CHECK_GT(mom.shape_.Size(), 0);

MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, {
MSHADOW_INT_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
auto weight_data = weight.data().FlatTo2D<xpu, DType>(s);
auto grad_idx = grad.aux_data(rowsparse::kIdx).FlatTo1D<xpu, IType>(s);
auto grad_val = grad.data().FlatTo2D<xpu, DType>(s);
auto mom_data = mom.data().FlatTo2D<xpu, DType>(s);
auto out_data = out.data().FlatTo2D<xpu, DType>(s);
auto num_rows = grad.aux_shape(rowsparse::kIdx)[0];
auto width = weight.shape().ProdShape(1, weight.shape().ndim());
MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
MSHADOW_INT_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
auto weight_data = weight.dptr<DType>();
auto grad_idx = grad.aux_data(kIdx).dptr<IType>();
auto grad_val = grad.data().dptr<DType>();
auto mom_data = mom.dptr<DType>();
auto out_data = out->dptr<DType>();
auto num_rows = grad.aux_shape(kIdx)[0];
auto width = weight.shape_.ProdShape(1, weight.ndim());
Kernel<SGDMomDnsRspDnsKernel<req_type>, xpu>::Launch(s, num_rows, width,
out_data.dptr_, mom_data.dptr_, weight_data.dptr_, grad_idx.dptr_, grad_val.dptr_,
out_data, mom_data, weight_data, grad_idx, grad_val,
static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
static_cast<DType>(param.lr), static_cast<DType>(param.wd),
static_cast<DType>(param.rescale_grad));
Expand All @@ -294,6 +321,50 @@ inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param,
});
}

template<typename xpu>
inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param,
const OpContext& ctx,
const NDArray& weight,
const NDArray& grad,
const NDArray& mom,
const OpReqType& req,
NDArray *out) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace mxnet_op;
using namespace rowsparse;
if (weight.storage_shape()[0] == weight.shape()[0] &&
out->storage_shape()[0] == out->shape()[0]) {
Stream<xpu>* s = ctx.get_stream<xpu>();
// fill mom with zero values in order to reuse the sgd mom dns impl
if (!mom.storage_initialized()) {
MSHADOW_REAL_TYPE_SWITCH(mom.dtype(), DType, {
MSHADOW_INT_TYPE_SWITCH(mom.aux_type(kIdx), IType, {
auto num_rows = mom.shape()[0];
mom.CheckAndAlloc({Shape1(num_rows)});
auto mom_idx = mom.aux_data(kIdx).FlatTo1D<xpu, IType>(s);
auto mom_val = mom.data();
// TODO(haibin) this is single-thread execution
Kernel<set_zero, xpu>::Launch(s, mom_val.Size(), mom_val.dptr<DType>());
ASSIGN_DISPATCH(mom_idx, kWriteTo, range<IType>(0, num_rows, 1, 1))
});
});
}
// TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only
// feed in kWriteTo as req for all operators.
// For sgd we don't want to assign zeros to the output values when req == kWriteTo
auto out_req = req;
if (out_req == kWriteTo) out_req = kWriteInplace;
TBlob out_blob = out->data();
// reuse dns rsp implementation when storage_shape == shape
SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
mom.data(), out_req, &out_blob);
} else {
LOG(FATAL) << "SGDUpdate for RowSparse weights is only implemented when "
<< "weights.values.shape == weights.shape";
}
}

template<typename xpu>
inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
Expand All @@ -305,10 +376,16 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
auto weight_stype = inputs[0].storage_type();
auto grad_stype = inputs[1].storage_type();
auto mom_stype = inputs[2].storage_type();

if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage &&
mom_stype == kDefaultStorage) {
SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, inputs, req, outputs);
TBlob out = outputs[0].data();
SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, inputs[0].data(), inputs[1],
inputs[2].data(), req[0], &out);
} else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage &&
mom_stype == kRowSparseStorage) {
NDArray out = outputs[0];
SGDMomUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1],
inputs[2], req[0], &out);
} else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage &&
mom_stype == kDefaultStorage) {
FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs,
Expand Down
9 changes: 3 additions & 6 deletions src/operator/tensor/elemwise_unary_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,10 +324,8 @@ inline void CastStorageDnsRspImpl(mshadow::Stream<cpu>* s, const TBlob& dns, NDA
struct CastStorageRspDnsKernel {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, const index_t width, const IType* idx, const DType *data,
DType* dns, const index_t invalid_rid) {
DType* dns) {
auto rid = idx[i];
// skip invalid rows
if (rid == invalid_rid) return;
auto dns_offset = rid * width;
auto rsp_offset = i * width;
for (size_t col = 0; col < width; col++) {
Expand Down Expand Up @@ -356,10 +354,9 @@ void CastStorageRspDnsImpl(mshadow::Stream<xpu>* s, const NDArray& rsp, TBlob* d
auto out_data = dns->FlatTo2D<xpu, DType>(s).dptr_;
auto num_rows = rsp.aux_shape(rowsparse::kIdx).Size();
auto rsp_shape = rsp.shape();
auto invalid_rid = rsp_shape[0];
auto width = rsp_shape.ProdShape(1, rsp_shape.ndim());
mxnet_op::Kernel<CastStorageRspDnsKernel, xpu>::Launch(s, num_rows, width, in_idx, in_data,
out_data, invalid_rid);
mxnet_op::Kernel<CastStorageRspDnsKernel, xpu>::Launch(s, num_rows, width, in_idx,
in_data, out_data);
}
});
});
Expand Down
26 changes: 17 additions & 9 deletions src/operator/tensor/indexing_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,39 +87,47 @@ NNVM_REGISTER_OP(_backward_Embedding)
.set_attr<FCompute>("FCompute<cpu>", EmbeddingOpBackward<cpu>);

NNVM_REGISTER_OP(SparseEmbedding)
.describe(R"code(Maps integer indices to vector representations (embeddings) with sparse weight update
)code" ADD_FILELINE)
.describe(R"doc(Represents words or other sparse inputs by dense continuous vectors.
It assumes that the input is in one-hot form. E.g., for a vocabulary size of 10,000,
each input vector is expected to have dimension 10,000.
The index of the non-zero entry is the index of the word or item it represents.
The corresponding embedding vectors are stored as rows of a matrix.
Hence, mapping an input word to its embedding is implemented as a matrix product.
The gradient of an embedding matrix has the form of gradient vectors that are only
non-zero for words seen in a minibatch.
)doc" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<EmbeddingParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "weight"};
})
.set_attr<nnvm::FInferShape>("FInferShape", EmbeddingOpShape)
.set_attr<nnvm::FInferShape>("FInferShape", SparseEmbeddingShape)
.set_attr<nnvm::FInferType>("FInferType", EmbeddingOpType)
.set_attr<nnvm::FInferStorageType>("FInferStorageType", SparseEmbeddingForwardStorageType)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", EmbeddingOpForward<cpu>)
.set_attr<FComputeEx>(FCOMP_EX_CPU, SparseEmbeddingForwardEx<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
return MakeNonlossGradNode("_backward_SparseEmbedding", n, ograds,
{n->inputs[0]}, n->attrs.dict);
})
.add_argument("data", "NDArray-or-Symbol", "The input array to the embedding operator.")
.add_argument("data", "NDArray-or-Symbol",
"The input array to the sparse embedding operator.")
.add_argument("weight", "NDArray-or-Symbol", "The embedding weight matrix.")
.add_arguments(EmbeddingParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_SparseEmbedding)
.set_num_inputs(2)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInferStorageType>("FInferStorageType", SparseEmbeddingBackwardStorageType)
.set_attr<FComputeEx>("FComputeEx<cpu>", SparseEmbeddingOpBackwardEx<cpu>);
// TODO(haibin) handle dense case
// .set_attr<FCompute>("FCompute<cpu>", EmbeddingOpBackward<cpu>);
.set_attr<FComputeEx>("FComputeEx<cpu>", SparseEmbeddingBackwardEx<cpu>);

NNVM_REGISTER_OP(take)
.describe(R"code(Takes elements from an input array along the given axis.
Expand Down
Loading

0 comments on commit a880bc7

Please sign in to comment.