
Commit

Handle ograd_stype='row_sparse' for square_sum backward (#143)
* Add one kernel for square_sum backward pass to take rsp ograd

* Add kNullOp and change to use type_assign in infer stype fallback
reminisce authored and eric-haibin-lin committed Aug 3, 2017
1 parent 56b5a63 commit 325f4db
Showing 2 changed files with 134 additions and 29 deletions.
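For orientation: _square_sum computes the sum of squares of a row-sparse matrix along one axis, and its input gradient is 2 * data * ograd broadcast along the reduced axis. The short NumPy sketch below (illustrative only, not part of the diff) shows the dense math for the axis=1, keepdims=True case, which is the case where the output — and, after this commit, the incoming ograd — can itself be row-sparse.

import numpy as np

x = np.array([[1., 2.], [0., 0.], [3., 4.]])     # conceptually row-sparse: row 1 is all zeros
y = np.sum(np.square(x), axis=1, keepdims=True)  # forward: per-row sum of squares -> [[5.], [0.], [25.]]
ograd = np.ones_like(y)                          # gradient arriving at y
igrad = 2.0 * x * ograd                          # backward: d(sum(x^2))/dx = 2*x, scaled by ograd
# Rows of x that are all zeros produce all-zero gradient rows,
# which is why igrad can be stored as row_sparse as well.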
143 changes: 116 additions & 27 deletions src/operator/tensor/square_sum-inl.h
@@ -28,13 +28,16 @@ inline bool SquareSumForwardInferStorageType(const nnvm::NodeAttrs& attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
CHECK_EQ((*in_attrs)[0], kRowSparseStorage)
<< "_square_sum only supports row-sparse ndarray as input";
const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
if (param.axis[0] == 1 && param.keepdims) { // sum per row and keep dims
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kRowSparseStorage);
} else {
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage);
if (in_attrs->at(0) == kRowSparseStorage) { // current impl
if (param.axis[0] == 1 && param.keepdims) { // sum per row and keep dims
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kRowSparseStorage);
} else {
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage);
}
} else { // fallback
type_assign(&((*in_attrs)[0]), kDefaultStorage);
type_assign(&((*out_attrs)[0]), kDefaultStorage);
}
return true;
}
@@ -45,9 +48,15 @@ inline bool SquareSumBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, 0, kDefaultStorage);
STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, 1, kRowSparseStorage);
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kRowSparseStorage);
const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
if (in_attrs->at(0) == kDefaultStorage || in_attrs->at(0) == kRowSparseStorage) {
STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, 1, kRowSparseStorage);
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kRowSparseStorage);
} else { // fallback
type_assign(&((*in_attrs)[0]), kDefaultStorage);
type_assign(&((*in_attrs)[1]), kDefaultStorage);
type_assign(&((*out_attrs)[0]), kDefaultStorage);
}
return true;
}
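Read together, the two infer-storage-type hunks above encode a small decision table: the forward pass keeps a row-sparse output only for axis=1 with keepdims (dense otherwise), and the backward pass accepts either a dense or row-sparse ograd and produces a row-sparse input gradient; any other combination now falls back to dense via type_assign instead of failing a hard CHECK. A rough Python paraphrase of the backward rule (the function and dictionary names are illustrative, not MXNet API):

def backward_stype_rule(ograd_stype):
    # mirrors SquareSumBackwardInferStorageType above
    if ograd_stype in ('default', 'row_sparse'):
        return {'ograd': ograd_stype, 'data': 'row_sparse', 'igrad': 'row_sparse'}
    # fallback: assign dense storage to everything
    return {'ograd': 'default', 'data': 'default', 'igrad': 'default'}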

@@ -125,7 +134,7 @@ struct SquareSumRspKernel<req, 1, true> {
}
};

template<int req, int axis>
template<int req, int axis, int ograd_stype = kDefaultStorage>
struct SquareSumRspGradKernel;

template<int req>
@@ -168,12 +177,42 @@ struct SquareSumRspGradKernel<req, 1> {
}
};

/*!
* This kernel assumes that the ograd and in_data
* are both rsp and have equal row_idx arrays.
* TODO(junwu): make the kernel general to support
* the cases when ograd and in_data have different
* row_idx arrays.
*/
template<int req>
struct SquareSumRspGradKernel<req, 1, kRowSparseStorage> {
/*!
* \param i index of out_grad_row_idx
* \param in_grad_row_idx row_idx of the gradient of the op's input
* \param in_grad gradient of the op's input
* \param out_grad_row_idx row_idx of the gradient of the op's output
* \param out_grad gradient of the op's output
* \param in_row_idx row idx of the op's input
* \param in_data op's input
*/
template<typename IType, typename DType>
MSHADOW_XINLINE static void Map(int i, IType* in_grad_row_idx, DType* in_grad,
const IType* out_grad_row_idx, const DType* out_grad,
const IType* in_row_idx, const DType* in_data,
const int64_t num_cols) {
const int64_t row = i / num_cols;
in_grad_row_idx[row] = in_row_idx[row];
KERNEL_ASSIGN(in_grad[i], req, 2*in_data[i]*out_grad[row]);
}
};

template<typename xpu>
void SquareSumRspImpl(const nnvm::NodeAttrs& attrs,
mshadow::Stream<xpu>* s,
const NDArray& input,
const OpReqType req,
NDArray* output) {
if (req == kNullOp) return;
const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
CHECK_EQ(param.axis.ndim(), 1U) << "_square_sum(row_sparse_matrix) only supports axis=0 or 1";
CHECK(param.axis[0] == 0 || param.axis[0] == 1)
@@ -261,39 +300,88 @@ void SquareSumRspGradImpl(const nnvm::NodeAttrs& attrs,
const NDArray& input,
const OpReqType req,
NDArray* igrad) {
if (req == kNullOp) return;
const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
CHECK_EQ(param.axis.ndim(), 1U) << "_square_sum(row_sparse_matrix) only supports axis=0";
CHECK_EQ(param.axis.ndim(), 1U) << "_square_sum(row_sparse_matrix) only supports axis=0/1";
CHECK(param.axis[0] == 0 || param.axis[0] == 1)
<< "_square_sum(row_sparse_matrix) only supports axis=0 or 1";
CHECK_EQ(ograd.storage_type(), kDefaultStorage);
CHECK(ograd.storage_type() == kDefaultStorage || ograd.storage_type() == kRowSparseStorage);
CHECK_EQ(input.storage_type(), kRowSparseStorage);
CHECK_EQ(igrad->storage_type(), kRowSparseStorage);
CHECK_NE(req, kWriteInplace);
if (!input.storage_initialized()) return;
CHECK_EQ(req, kWriteTo);
if (!input.storage_initialized()) {
FillZerosRspImpl<xpu>(s, igrad);
return;
}

using namespace mxnet_op;
// TODO(junwu) change the input of CheckAndAlloc
// if we want to support different row idx arrays
// for ograd and input when they are both row-sparse ndarrays
igrad->CheckAndAlloc({input.aux_shape(rowsparse::kIdx)});
const int64_t num_cols = input.storage_shape()[1];
const TBlob& igrad_data = igrad->data();
const TBlob igrad_row_idx = igrad->aux_data(rowsparse::kIdx);
const TBlob& ograd_data = ograd.data();
const TBlob in_data = input.data();
const TBlob in_row_idx = input.aux_data(rowsparse::kIdx);
MSHADOW_TYPE_SWITCH(igrad_data.type_flag_, DType, {
if (ograd.storage_type() == kDefaultStorage) {
if (0 == param.axis[0]) { // forward is sum per column
MSHADOW_TYPE_SWITCH(igrad_data.type_flag_, DType, {
MSHADOW_IDX_TYPE_SWITCH(igrad_row_idx.type_flag_, IType, {
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
Kernel<SquareSumRspGradKernel<req_type, 0, kDefaultStorage>, xpu>::Launch(
s, igrad_data.Size(), igrad_row_idx.dptr<IType>(),
igrad_data.dptr<DType>(), ograd_data.dptr<DType>(),
in_row_idx.dptr<IType>(), in_data.dptr<DType>(), num_cols);
})
})
})
} else { // forward is sum per row
MSHADOW_TYPE_SWITCH(igrad_data.type_flag_, DType, {
MSHADOW_IDX_TYPE_SWITCH(igrad_row_idx.type_flag_, IType, {
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
Kernel<SquareSumRspGradKernel<req_type, 1, kDefaultStorage>, xpu>::Launch(
s, igrad_data.Size(), igrad_row_idx.dptr<IType>(),
igrad_data.dptr<DType>(), ograd_data.dptr<DType>(),
in_row_idx.dptr<IType>(), in_data.dptr<DType>(), num_cols);
})
})
})
}
} else if (ograd.storage_type() == kRowSparseStorage) {
CHECK_EQ(1, param.axis[0]) << "SquareSumRspGradImpl only supports axis = 1"
" when ograd_stype = kRowSparseStorage";
CHECK_EQ(ograd.shape().ndim(), 2U);
const TBlob ograd_row_idx = ograd.aux_data(rowsparse::kIdx);
CHECK_EQ(ograd_row_idx.Size(), in_row_idx.Size());
MSHADOW_IDX_TYPE_SWITCH(igrad_row_idx.type_flag_, IType, {
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
if (0 == param.axis[0]) { // forward is sum per column
Kernel<SquareSumRspGradKernel<req_type, 0>, xpu>::Launch(s, igrad_data.Size(),
igrad_row_idx.dptr<IType>(), igrad_data.dptr<DType>(), ograd_data.dptr<DType>(),
in_row_idx.dptr<IType>(), in_data.dptr<DType>(), num_cols);
} else { // forward is sum per row
Kernel<SquareSumRspGradKernel<req_type, 1>, xpu>::Launch(s, igrad_data.Size(),
igrad_row_idx.dptr<IType>(), igrad_data.dptr<DType>(), ograd_data.dptr<DType>(),
in_row_idx.dptr<IType>(), in_data.dptr<DType>(), num_cols);
}
if (std::is_same<xpu, cpu>::value) {
const IType* first1 = ograd_row_idx.dptr<IType>();
const IType* last1 = first1 + ograd_row_idx.Size();
const IType* first2 = in_row_idx.dptr<IType>();
CHECK(std::equal(first1, last1, first2)) << "SquareSumRspGradImpl only supports"
" equal ograd_row_idx and input_row_idx"
" when ograd and input are both"
" row-sparse";
} else {
LOG(FATAL) << "SquareSumRspGradImpl has not implemented GPU version when"
" ograd and input are both row-sparse";
}
MSHADOW_TYPE_SWITCH(igrad_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
Kernel<SquareSumRspGradKernel<req_type, 1, kRowSparseStorage>, xpu>::Launch(
s, igrad_data.Size(), igrad_row_idx.dptr<IType>(),
igrad_data.dptr<DType>(), ograd_row_idx.dptr<IType>(),
ograd_data.dptr<DType>(), in_row_idx.dptr<IType>(),
in_data.dptr<DType>(), num_cols);
})
})
})
})
} else {
LOG(FATAL) << "SquareSumRspGradImpl only supports ograd_stype"
<< " = kDefaultStorage/kRowSparseStorage";
}
}
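The new ograd_stype = kRowSparseStorage branch above relies on ograd and the input sharing an identical row_idx array (verified with std::equal on CPU), so the kernel can walk both compacted data arrays with the same index. A NumPy sketch of what SquareSumRspGradKernel<req, 1, kRowSparseStorage> produces, using illustrative array names:

import numpy as np

num_cols = 2
in_row_idx = np.array([0, 2])               # rows actually stored in the row-sparse input
in_data = np.array([[1., 2.], [3., 4.]])    # compacted input values, shape (nnz_rows, num_cols)
ograd_row_idx = in_row_idx.copy()           # must equal in_row_idx for this kernel
ograd_data = np.array([[5.], [25.]])        # per-row ograd (axis=1, keepdims=True)

# per element i: in_grad[i] = 2 * in_data[i] * out_grad[i // num_cols]
igrad_row_idx = in_row_idx.copy()           # in_grad_row_idx[row] = in_row_idx[row]
igrad_data = 2.0 * in_data * ograd_data     # broadcasting applies out_grad[row] across each stored row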

template<typename xpu>
@@ -331,7 +419,8 @@ void SquareSumOpBackwardEx(const nnvm::NodeAttrs& attrs,
mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
const NDArrayStorageType ograd_stype = inputs[0].storage_type();
const NDArrayStorageType input_stype = inputs[1].storage_type();
if (input_stype == kRowSparseStorage && ograd_stype == kDefaultStorage) {
if (input_stype == kRowSparseStorage
&& (ograd_stype == kDefaultStorage || ograd_stype == kRowSparseStorage)) {
CHECK_EQ(inputs[1].shape().ndim(), 2U) << "_square_sum op only supports"
" 2D ndarray as input";
NDArray output = outputs[0];
20 changes: 18 additions & 2 deletions tests/python/unittest/test_sparse_operator.py
@@ -229,9 +229,25 @@ def test_sparse_square_sum():
# check forward result
assert same(ret.asnumpy(), ret_expected.asnumpy())

rsp_data = mx.sym.Variable('data', stype='row_sparse')
test = mx._symbol_internal._square_sum(rsp_data, axis=axis, keepdims=keepdim)

# check symbolic backward since ograd can be an rsp,
# a case that cannot be covered by check_numeric_gradient
# because it adds a loss layer as the output layer,
# which makes the ograd of square_sum dense
if axis == 1 and keepdim:
dns_data = mx.sym.Variable('data')
baseline = mx.sym.sum(mx.sym.square(dns_data), axis=axis, keepdims=keepdim)
igrad_expected = mx.nd.empty(dns.shape)
baseline_exec = baseline.bind(default_context(), args=[dns],
args_grad=[igrad_expected])
baseline_exec.forward(is_train=True)
baseline_exec.backward([ret_expected])
check_symbolic_backward(test, [rsp], [ret], [igrad_expected.asnumpy()],
grad_stypes={'data': 'row_sparse'})

# check numeric gradient
data = mx.sym.Variable('data', stype='row_sparse')
test = mx._symbol_internal._square_sum(data, axis=axis, keepdims=keepdim)
check_numeric_gradient(test, [rsp], grad_stype_dict={'data': 'row_sparse'},
atol=1e-2, rtol=0.1)
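Because the test feeds ret_expected back in as the ograd, the expected input gradient computed by baseline_exec has a simple closed form: each element of the dense data times twice its row's forward output. A small NumPy equivalent, with dns_np and ret_np standing in for dns.asnumpy() and ret_expected.asnumpy() (names are illustrative):

import numpy as np

dns_np = np.random.rand(5, 3)                               # stand-in for the dense data
ret_np = np.sum(np.square(dns_np), axis=1, keepdims=True)   # forward output for axis=1, keepdims=True
igrad_np = 2.0 * dns_np * ret_np                            # what igrad_expected holds after backward()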
