Skip to content

Commit

Permalink
Sgd with row_sparse weight, dns gradient (apache#83)
Browse files Browse the repository at this point in the history
* sgd rsp dns draft

* support sgd_mom(rsp, dns, rsp)

* update doc

* remove cast storage for kv updater

* code refactoring
  • Loading branch information
eric-haibin-lin authored and Olivier committed Jun 13, 2017
1 parent 204e1e8 commit 88d1f4c
Show file tree
Hide file tree
Showing 8 changed files with 236 additions and 105 deletions.
2 changes: 1 addition & 1 deletion mshadow
Submodule mshadow updated 2 files
+4 −0 mshadow/half.h
+8 −4 mshadow/half2.h
3 changes: 0 additions & 3 deletions python/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,6 @@ def _update_params(param_arrays, grad_arrays, updater, num_device,
# state for the same index but on diff devs, TODO(mli)
# use a better solution later
w, g = p
# cast storage type if stype doesn't match
if g.storage_type != w.storage_type:
g = nd.cast_storage(g, w.storage_type)
updater(index*num_device+k, g, w)


Expand Down
9 changes: 9 additions & 0 deletions src/operator/operator_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,15 @@ void FCompExFallback(const nnvm::NodeAttrs& attrs,
CastNonDefaultStorage<xpu>(outputs, temp_out, ctx, true);
}

// Assert that a RowSparse ndarray is "dense in disguise", i.e. every row is
// physically stored (values.shape[0] == shape[0]). `func` and `param` are
// human-readable names used only to build the failure message.
// Wrapped in do { } while (0) so the macro expands to exactly one statement:
// a bare `{ ... }` block followed by the caller's `;` leaves a dangling empty
// statement and breaks when used in an un-braced if/else.
#define CHECK_RSP_ALL_ROWS_NON_ZERO(rsp, func, param) \
  do { \
    CHECK(rsp.storage_shape()[0] == rsp.shape()[0]) << func \
      << " for RowSparse " << param << " is only implemented for " \
      << "RowSparse " << param << " with all rows containing non-zeros. " \
      << "Expects " << param << ".values.shape[0] (" << rsp.storage_shape()[0] \
      << ") == " << param << ".shape[0] (" << rsp.shape()[0] << ")."; \
  } while (0)


} // namespace op
} // namespace mxnet
Expand Down
276 changes: 209 additions & 67 deletions src/operator/optimizer_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,29 +145,84 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param,
});
}

/*! \brief SGD kernel for row_sparse weight with dense gradient.
 *
 *  One kernel invocation handles one weight row (`i`). A row whose gradient
 *  contains no non-zero entry is left untouched; every other row receives a
 *  full SGD step with optional gradient clipping.
 */
template<int req>
struct SGDRspDnsKernel {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, size_t num_cols, DType* out, const DType* weight,
                                  const DType *grad, const DType clip_gradient, const DType lr,
                                  const DType wd, const DType rescale_grad) {
    const index_t row_start = i * num_cols;
    // First pass: detect whether this row's gradient carries any information.
    bool row_is_zero = true;
    for (index_t j = 0; j < num_cols; ++j) {
      if (grad[row_start + j] != 0) {
        row_is_zero = false;
        break;
      }
    }
    if (row_is_zero) return;
    // Second pass: apply the update  out = (1 - lr*wd) * w - lr * g'
    // where g' is the (optionally clipped) rescaled gradient.
    const DType rate = 1.f - lr * wd;
    for (index_t j = 0; j < num_cols; ++j) {
      const index_t idx = row_start + j;
      if (clip_gradient >= 0.0f) {
        KERNEL_ASSIGN(out[idx], req, rate * weight[idx] -
                      lr * mshadow_op::clip::Map(rescale_grad * grad[idx], clip_gradient));
      } else {
        KERNEL_ASSIGN(out[idx], req, rate * weight[idx] -
                      lr * rescale_grad * grad[idx]);
      }
    }
  }
};

/*! \brief SGD update for RowSparse weight with a dense gradient.
 *
 *  Only implemented for the degenerate RowSparse layout where all rows are
 *  stored (enforced by CHECK_RSP_ALL_ROWS_NON_ZERO); the kernel can then
 *  treat weight/out as plain row-major buffers.
 *
 *  \param param  SGD hyper-parameters (lr, wd, rescale_grad, clip_gradient)
 *  \param ctx    operator context supplying the device stream
 *  \param weight RowSparse weight; must be initialized
 *  \param grad   dense gradient blob, same dtype as weight
 *  \param req    output request; kNullOp returns immediately
 *  \param out    output ndarray whose values the kernel writes
 */
template<typename xpu>
inline void SGDUpdateRspDnsImpl(const SGDParam& param,
                                const OpContext &ctx,
                                const NDArray& weight,
                                const TBlob& grad,
                                const OpReqType req,
                                NDArray *out) {
  using namespace mshadow;
  using namespace mxnet_op;
  using namespace rowsparse;
  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDUpdate", "weights");
  CHECK_EQ(weight.storage_type(), kRowSparseStorage);
  if (req == kNullOp) return;
  CHECK(weight.storage_initialized());
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, {
    MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
      auto weight_data = weight.data().dptr<DType>();
      auto grad_data = grad.dptr<DType>();
      // One kernel thread per stored row; num_cols is the product of all
      // trailing dimensions, so >2-D weights are handled as flattened rows.
      auto num_rows = weight.aux_shape(kIdx)[0];
      auto num_cols = weight.shape().ProdShape(1, weight.shape().ndim());
      // NOTE(review): unlike the rsp/rsp path, req is NOT remapped from
      // kWriteTo to kWriteInplace here. The kernel skips rows whose gradient
      // is all-zero, so under kWriteTo those output rows keep their prior
      // contents — presumably out aliases weight in practice; confirm.
      Kernel<SGDRspDnsKernel<req_type>, xpu>::Launch(s, num_rows, num_cols,
        out->data().dptr<DType>(), weight_data, grad_data,
        static_cast<DType>(param.clip_gradient),
        static_cast<DType>(param.lr), static_cast<DType>(param.wd),
        static_cast<DType>(param.rescale_grad));
    });
  });
}

/*! \brief SGD update with both weight and gradient stored as RowSparse.
 *
 *  Requires the RowSparse weight (and hence out) to have every row stored,
 *  so the dense/rsp implementation can be reused on the raw value blobs.
 *
 *  \param param  SGD hyper-parameters
 *  \param ctx    operator context
 *  \param weight RowSparse weight with all rows present
 *  \param grad   RowSparse gradient
 *  \param req    output request
 *  \param out    RowSparse output ndarray
 */
template<typename xpu>
inline void SGDUpdateRspRspImpl(const SGDParam& param,
                                const OpContext& ctx,
                                const NDArray& weight,
                                const NDArray& grad,
                                const OpReqType& req,
                                NDArray *out) {
  // Fails loudly unless weight.values.shape[0] == weight.shape[0]; the
  // scraped span duplicated the pre-refactor if/else + LOG(FATAL) body
  // alongside this one — only this macro-based version is kept.
  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDUpdate", "weights");
  // TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only
  // feed in kWriteTo as req for all operators.
  // For sgd we don't want to assign zeros to the output values when req == kWriteTo
  auto out_req = req;
  if (out_req == kWriteTo) out_req = kWriteInplace;
  // reuse dns rsp implementation when storage_shape == shape
  TBlob out_blob = out->data();
  SGDUpdateDnsRspImpl<xpu>(param, ctx, weight.data(), grad, out_req, &out_blob);
}

template<typename xpu>
Expand All @@ -188,6 +243,9 @@ inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs,
} else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage) {
NDArray out = outputs[0];
SGDUpdateRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], req[0], &out);
} else if (weight_stype == kRowSparseStorage && grad_stype == kDefaultStorage) {
NDArray out = outputs[0];
SGDUpdateRspDnsImpl<xpu>(param, ctx, inputs[0], inputs[1].data(), req[0], &out);
} else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage) {
FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs, SGDUpdate<xpu>, "SGDUpdate");
}
Expand Down Expand Up @@ -267,21 +325,22 @@ struct SGDMomDnsRspDnsKernel {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, size_t width, DType* out_data,
DType* mom_data, const DType* weight_data, const IType* grad_idx,
const DType* grad_data, const DType param_clip_gradient, const DType param_momentum,
const DType param_lr, const DType param_wd, const DType param_rescale_grad) {
const DType* grad_data, const DType clip_gradient, const DType momentum,
const DType lr, const DType wd, const DType rescale_grad) {
const DType rate = lr * wd;
for (size_t j = 0; j < width; j++) {
uint64_t data_i = grad_idx[i] * width + j;
uint64_t grad_i = i * width + j;
if (param_clip_gradient >= 0.0f) {
mom_data[data_i] = param_momentum * mom_data[data_i]
- param_lr * param_wd * weight_data[data_i]
- param_lr *
mshadow_op::clip::Map(param_rescale_grad * grad_data[grad_i],
param_clip_gradient);
if (clip_gradient >= 0.0f) {
mom_data[data_i] = momentum * mom_data[data_i]
- rate * weight_data[data_i]
- lr *
mshadow_op::clip::Map(rescale_grad * grad_data[grad_i],
clip_gradient);
} else {
mom_data[data_i] = param_momentum * mom_data[data_i]
- param_lr * param_wd * weight_data[data_i]
- param_lr * param_rescale_grad * grad_data[grad_i];
mom_data[data_i] = momentum * mom_data[data_i]
- rate * weight_data[data_i]
- lr * rescale_grad * grad_data[grad_i];
}
KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] + mom_data[data_i]);
}
Expand Down Expand Up @@ -323,6 +382,100 @@ inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param,
});
}

/*! \brief SGD-with-momentum kernel for row_sparse weight and dense gradient.
 *
 *  One kernel invocation handles one weight row. Rows whose gradient is
 *  entirely zero are skipped, leaving both the weight and the momentum
 *  buffer untouched for that row; other rows get a full momentum update.
 */
template<int req>
struct SGDMomRspDnsKernel {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, size_t num_cols, DType* out, DType* mom,
                                  const DType* weight, const DType *grad,
                                  const DType clip_gradient, const DType momentum,
                                  const DType lr, const DType wd, const DType rescale_grad) {
    const index_t row_start = i * num_cols;
    // First pass: bail out when this row's gradient has no non-zero entry.
    bool row_is_zero = true;
    for (index_t j = 0; j < num_cols; ++j) {
      if (grad[row_start + j] != 0) {
        row_is_zero = false;
        break;
      }
    }
    if (row_is_zero) return;
    // Second pass:  mom = momentum*mom - lr*wd*w - lr*g'  ;  out = w + mom
    // where g' is the (optionally clipped) rescaled gradient.
    const DType rate = lr * wd;
    for (index_t j = 0; j < num_cols; ++j) {
      const index_t idx = row_start + j;
      if (clip_gradient >= 0.0f) {
        mom[idx] = momentum * mom[idx] - rate * weight[idx]
                   - lr * mshadow_op::clip::Map(rescale_grad * grad[idx], clip_gradient);
      } else {
        mom[idx] = momentum * mom[idx] - rate * weight[idx]
                   - lr * rescale_grad * grad[idx];
      }
      KERNEL_ASSIGN(out[idx], req, weight[idx] + mom[idx]);
    }
  }
};

/*! \brief Allocate a RowSparse ndarray in its "all rows stored" layout and
 *  zero-fill it.
 *
 *  After this call, out's aux index array holds 0..num_rows-1 (every row
 *  present) and all values are zero, so it can be consumed by kernels that
 *  assume the dense-in-disguise RowSparse layout.
 *
 *  \param s   device stream used for the fill kernels
 *  \param out RowSparse ndarray to allocate and zero
 */
template<typename xpu>
inline void InitDnsZeros(mshadow::Stream<xpu> *s, NDArray *out) {
  using namespace rowsparse;
  using namespace mshadow::expr;
  using namespace mshadow;
  using namespace mxnet_op;
  CHECK_EQ(out->storage_type(), kRowSparseStorage);
  MSHADOW_REAL_TYPE_SWITCH(out->dtype(), DType, {
    MSHADOW_INT_TYPE_SWITCH(out->aux_type(kIdx), IType, {
      auto num_rows = out->shape()[0];
      // Allocate aux (row index) storage for one entry per row.
      out->CheckAndAlloc({Shape1(num_rows)});
      auto idx = out->aux_data(kIdx).FlatTo1D<xpu, IType>(s);
      auto val = out->data();
      // Zero every value element.
      Kernel<set_zero, xpu>::Launch(s, val.Size(), val.dptr<DType>());
      // Row indices 0, 1, ..., num_rows-1: mark every row as stored.
      ASSIGN_DISPATCH(idx, kWriteTo, range<IType>(0, num_rows, 1, 1))
    });
  });
}

/*! \brief SGD momentum update with RowSparse weight, dense gradient and
 *  RowSparse momentum.
 *
 *  Only implemented when the RowSparse weight has every row stored
 *  (enforced by CHECK_RSP_ALL_ROWS_NON_ZERO). An uninitialized momentum
 *  array is first allocated and zero-filled in the same all-rows layout.
 *
 *  \param param  SGD momentum hyper-parameters
 *  \param ctx    operator context supplying the device stream
 *  \param weight RowSparse weight; must be initialized
 *  \param grad   dense gradient blob
 *  \param mom    RowSparse momentum state; zero-initialized here if empty
 *  \param req    output request; kNullOp returns immediately
 *  \param out    output ndarray written by the kernel
 */
template<typename xpu>
inline void SGDMomUpdateRspDnsImpl(const SGDMomParam& param,
                                   const OpContext &ctx,
                                   const NDArray& weight,
                                   const TBlob& grad,
                                   const NDArray& mom,
                                   const OpReqType req,
                                   NDArray *out) {
  using namespace mshadow;
  using namespace mxnet_op;
  using namespace rowsparse;
  CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDMomUpdate", "weights");
  Stream<xpu>* s = ctx.get_stream<xpu>();
  CHECK_EQ(weight.storage_type(), kRowSparseStorage);
  if (req == kNullOp) return;
  CHECK(weight.storage_initialized());
  // fill mom with zero values if not initialized yet
  if (!mom.storage_initialized()) {
    // NOTE(review): the copy appears intended to drop const-ness while
    // sharing the underlying storage with `mom` — presumably NDArray copies
    // are shallow here; confirm, otherwise the zeros never reach `mom`.
    NDArray mom_zeros = mom;
    InitDnsZeros(s, &mom_zeros);
  }
  // TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only
  // feed in kWriteTo as req for all operators.
  // For sgd we don't want to assign zeros to the output values when req == kWriteTo
  auto out_req = req;
  if (out_req == kWriteTo) out_req = kWriteInplace;
  MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, {
    MXNET_ASSIGN_REQ_SWITCH(out_req, req_type, {
      auto weight_data = weight.data().dptr<DType>();
      auto grad_data = grad.dptr<DType>();
      auto mom_data = mom.data().dptr<DType>();
      // One kernel thread per stored row; trailing dims flattened into cols.
      auto num_rows = weight.aux_shape(kIdx)[0];
      auto num_cols = weight.shape().ProdShape(1, weight.shape().ndim());
      Kernel<SGDMomRspDnsKernel<req_type>, xpu>::Launch(s, num_rows, num_cols,
        out->data().dptr<DType>(), mom_data, weight_data, grad_data,
        static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
        static_cast<DType>(param.lr), static_cast<DType>(param.wd),
        static_cast<DType>(param.rescale_grad));
    });
  });
}


template<typename xpu>
inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param,
const OpContext& ctx,
Expand All @@ -335,38 +488,22 @@ inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param,
using namespace mshadow::expr;
using namespace mxnet_op;
using namespace rowsparse;
if (weight.storage_shape()[0] == weight.shape()[0] &&
out->storage_shape()[0] == out->shape()[0]) {
Stream<xpu>* s = ctx.get_stream<xpu>();
// fill mom with zero values in order to reuse the sgd mom dns impl
if (!mom.storage_initialized()) {
MSHADOW_REAL_TYPE_SWITCH(mom.dtype(), DType, {
MSHADOW_INT_TYPE_SWITCH(mom.aux_type(kIdx), IType, {
auto num_rows = mom.shape()[0];
mom.CheckAndAlloc({Shape1(num_rows)});
auto mom_idx = mom.aux_data(kIdx).FlatTo1D<xpu, IType>(s);
auto mom_val = mom.data();
// TODO(haibin) this is single-thread execution
Kernel<set_zero, xpu>::Launch(s, mom_val.Size(), mom_val.dptr<DType>());
ASSIGN_DISPATCH(mom_idx, kWriteTo, range<IType>(0, num_rows, 1, 1))
});
});
}
// TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only
// feed in kWriteTo as req for all operators.
// For sgd we don't want to assign zeros to the output values when req == kWriteTo
auto out_req = req;
if (out_req == kWriteTo) out_req = kWriteInplace;
TBlob out_blob = out->data();
// reuse dns rsp implementation when storage_shape == shape
SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
mom.data(), out_req, &out_blob);
} else {
LOG(FATAL) << "SGDUpdate for RowSparse weights is only implemented for "
<< "RowSparse weights with all rows containing non-zeros. "
<< "Expects weights.values.shape[0] (" << weight.storage_shape()[0]
<< ") == weights.shape[0] (" << weight.shape()[0] << ").";
CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDMomUpdate", "weights");
Stream<xpu>* s = ctx.get_stream<xpu>();
// fill mom with zero values in order to reuse the sgd mom dns impl
if (!mom.storage_initialized()) {
NDArray mom_zeros = mom;
InitDnsZeros(s, &mom_zeros);
}
// TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only
// feed in kWriteTo as req for all operators.
// For sgd we don't want to assign zeros to the output values when req == kWriteTo
auto out_req = req;
if (out_req == kWriteTo) out_req = kWriteInplace;
TBlob out_blob = out->data();
// reuse dns rsp implementation when storage_shape == shape
SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
mom.data(), out_req, &out_blob);
}

template<typename xpu>
Expand All @@ -377,23 +514,28 @@ inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray> &outputs) {
using namespace mxnet_op;
const SGDMomParam& param = nnvm::get<SGDMomParam>(attrs.parsed);
auto weight_stype = inputs[0].storage_type();
auto grad_stype = inputs[1].storage_type();
auto mom_stype = inputs[2].storage_type();
auto &weight = inputs[0];
auto &grad = inputs[1];
auto &mom = inputs[2];
auto weight_stype = weight.storage_type();
auto grad_stype = grad.storage_type();
auto mom_stype = mom.storage_type();
if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage &&
mom_stype == kDefaultStorage) {
TBlob out = outputs[0].data();
SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, inputs[0].data(), inputs[1],
inputs[2].data(), req[0], &out);
SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
mom.data(), req[0], &out);
} else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage &&
mom_stype == kRowSparseStorage) {
NDArray out = outputs[0];
SGDMomUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1],
inputs[2], req[0], &out);
SGDMomUpdateRspRspRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
} else if (weight_stype == kRowSparseStorage && grad_stype == kDefaultStorage &&
mom_stype == kRowSparseStorage) {
NDArray out = outputs[0];
SGDMomUpdateRspDnsImpl<xpu>(param, ctx, weight, grad.data(), mom, req[0], &out);
} else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage &&
mom_stype == kDefaultStorage) {
FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs,
SGDMomUpdate<xpu>, "SGDMomUpdate");
FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs, SGDMomUpdate<xpu>, "SGDMomUpdate");
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/operator/optimizer_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ It updates the weights using::
weight = weight - learning_rate * gradient
If gradients are stored with `row_sparse` storage,
where update is applied only to rows whose gradient has non-zero entries.
If weights are stored with `row_sparse` storage,
update is applied only to rows whose gradient has non-zero entries.
)code" ADD_FILELINE)
.set_num_inputs(2)
Expand Down Expand Up @@ -56,7 +56,7 @@ It updates the weights using::
Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.
If gradients are stored with `row_sparse` storage,
If weights are stored with `row_sparse` storage,
only rows whose gradients contain non-zero entries are updated (for both weight and momentum).
)code" ADD_FILELINE)
Expand Down
Loading

0 comments on commit 88d1f4c

Please sign in to comment.