
Add operator for dot(dns, csr) = csr #8938

Merged · 15 commits · Jan 4, 2018 · Changes from 9 commits
51 changes: 36 additions & 15 deletions benchmark/python/sparse/dot.py
@@ -275,7 +275,10 @@ def bench_dot(lhs_shape, rhs_shape, lhs_stype, rhs_stype,
     # Create matrix instances
     lhs_nd = rand_ndarray(lhs_shape, lhs_stype, density=lhs_den, distribution=distribution)
     # only uniform distribution supported for rhs
-    rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=rhs_den, distribution="uniform")
+    if rhs_stype == 'csr':
+        rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=rhs_den, distribution=distribution)
+    else:
+        rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=rhs_den, distribution="uniform")
     lhs_dns = None
     rhs_dns = None
     dense_cost = None
@@ -337,27 +340,41 @@ def print_benchmark_info(lhs, rhs, lhs_trans, fw):
 
 def run_benchmark(ctx=None, lhs="csr", lhs_trans=False, rhs="dns", fw="mxnet", rhs_density=1,
                   distribution="uniform"):
-    if lhs != "csr":
-        raise ValueError("Value other than csr for lhs not supported")
-
     if rhs_density > 1 or rhs_density < 0:
         raise ValueError("rhs_density has to be between 0 and 1")
 
     print_benchmark_info(lhs, rhs, lhs_trans, fw)
 
-    lhs_stype = "csr"
-    rhs_stype = "row_sparse" if rhs == "rsp" else "default"
-
-    feature_dim_list = data_dict['feature_dim']
-    output_dim_list = data_dict['m']
-    batch_size_list = data_dict['batch_size']
-    density_list = data_dict['density']
-
-    default_output_index = data_dict['default_index']['output_dim']
-    default_batch_size_index = data_dict['default_index']['batch_size']
-    default_feature_index = data_dict['default_index']['feature_dim']
-    default_density_index = data_dict['default_index']['density']
-    num_repeat = data_dict['num_repeat']
+    if rhs == "csr":
+        lhs_stype = "default"
+        rhs_stype = "csr"
+        assert (lhs_stype == 'default'), "Only dot(default, csr) supported"
+        # Arrange dimensions according to use case. For the case below, csr
+        # will have num_rows << num_cols
+        feature_dim_list = data_dict['batch_size']
+        batch_size_list = data_dict['m']
+        output_dim_list = data_dict['feature_dim']
+        density_list = data_dict['density']
+        default_output_index = data_dict['default_index']['feature_dim']
+        default_density_index = data_dict['default_index']['density']
+        default_feature_index = data_dict['default_index']['batch_size']
+        default_batch_size_index = data_dict['default_index']['output_dim']
+        num_repeat = data_dict['num_repeat']
+    else:
+        lhs_stype = "csr"
+        rhs_stype = "row_sparse" if rhs == "rsp" else "default"
+
+        feature_dim_list = data_dict['feature_dim']
+        output_dim_list = data_dict['m']
+        batch_size_list = data_dict['batch_size']
+        density_list = data_dict['density']
+
+        default_output_index = data_dict['default_index']['output_dim']
+        default_batch_size_index = data_dict['default_index']['batch_size']
+        default_feature_index = data_dict['default_index']['feature_dim']
+        default_density_index = data_dict['default_index']['density']
+        num_repeat = data_dict['num_repeat']
 
     for output_dim in output_dim_list:
         if lhs_trans:
@@ -403,7 +420,7 @@ def run_benchmark(ctx=None, lhs="csr", lhs_trans=False, rhs="dns", fw="mxnet", r
                        feature_dim_list[default_feature_index]),
                       (output_row_dim,
                        output_dim_list[default_output_index]),
-                      lhs_stype, rhs_stype, density, rhs_density, lhs_trans, ctx,
+                      lhs_stype, rhs_stype, density, density, lhs_trans, ctx,
                       num_repeat=num_repeat, fw=fw, distribution=distribution)
 
 check_call(_LIB.MXSetNumOMPThreads(ctypes.c_int(ARGS.num_omp_threads)))
@@ -423,6 +440,10 @@ def run_benchmark(ctx=None, lhs="csr", lhs_trans=False, rhs="dns", fw="mxnet", r
                           rhs="rsp", lhs_trans=False,
                           fw="mxnet", rhs_density=0.05,
                           distribution=distribution)
+            run_benchmark(context, lhs="default",
+                          rhs="csr", lhs_trans=False,
+                          fw="mxnet", rhs_density=0.001,
+                          distribution=distribution)
             if not ARGS.gpu:
                 run_benchmark(context, lhs="csr",
                               rhs="default", lhs_trans=False,
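For orientation, here is the operation the new `lhs="default", rhs="csr"` benchmark case exercises — dot(dense, csr) producing a csr result — as a minimal scipy/numpy sketch. The shapes and density below are illustrative choices, not the benchmark's defaults:

import numpy as np
import scipy.sparse as sp

batch_size, feature_dim, output_dim = 64, 128, 1000
lhs = np.random.rand(batch_size, feature_dim)                          # dense lhs
rhs = sp.random(feature_dim, output_dim, density=0.001, format='csr')  # sparse rhs

out = sp.csr_matrix(lhs).dot(rhs)       # sparse (csr) result
dense_ref = lhs.dot(rhs.toarray())      # dense reference
assert np.allclose(out.toarray(), dense_ref)

Because the rhs has num_rows << num_cols and very low density, the result keeps only the few non-zero columns, which is why a csr output pays off.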
202 changes: 193 additions & 9 deletions src/operator/tensor/dot-inl.h
@@ -202,22 +202,25 @@ void DotBackward_(const nnvm::NodeAttrs& attrs,
 inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                        const int dev_mask,
                                        DispatchMode* dispatch_mode,
-                                       std::vector<int> *in_attrs,
-                                       std::vector<int> *out_attrs) {
+                                       std::vector<int>* in_attrs,
+                                       std::vector<int>* out_attrs) {
   CHECK_EQ(in_attrs->size(), 2U);
   CHECK_EQ(out_attrs->size(), 1U);
   const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
-  // csr has many zero columns, so the result of dot(csr.T, matrix) should be rsp
+  // csr has many zero columns, so the result of dot(csr.T, matrix) should be
+  // rsp
   const auto& lhs_stype = in_attrs->at(0);
   const auto& rhs_stype = in_attrs->at(1);
   auto& out_stype = out_attrs->at(0);
   bool dispatched = false;
   bool only_lhs_transpose = param.transpose_a && !param.transpose_b;
-  bool rhs_rsp_or_dns = rhs_stype == kRowSparseStorage || rhs_stype == kDefaultStorage;
-  if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kDefaultStorage) {
+  bool rhs_rsp_or_dns =
+      rhs_stype == kRowSparseStorage || rhs_stype == kDefaultStorage;
+  if (!dispatched && lhs_stype == kDefaultStorage &&
+      rhs_stype == kDefaultStorage) {
     // dns, dns -> dns
-    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
-                                     dispatch_mode, DispatchMode::kFCompute);
+    dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+                                     DispatchMode::kFCompute);
   }
   if (!dispatched && lhs_stype == kCSRStorage && only_lhs_transpose &&
       (rhs_stype == kRowSparseStorage || rhs_stype == kDefaultStorage)) {
@@ -228,8 +231,16 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
   if (!dispatched && lhs_stype == kCSRStorage && rhs_rsp_or_dns &&
       !param.transpose_a && !param.transpose_b) {
     // csr, rsp/dns -> dns
-    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
-                                     dispatch_mode, DispatchMode::kFComputeEx);
+    dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+                                     DispatchMode::kFComputeEx);
   }
+  if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
+      !param.transpose_a && !param.transpose_b) {
+    // dns, csr -> csr
+    if (dev_mask == mshadow::cpu::kDevMask) {
+      dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode,
+                                       DispatchMode::kFComputeEx);
+    }
+  }
   if (!dispatched) {
     dispatch_fallback(out_attrs, dispatch_mode);

[Review comment — Member, on the new `dns, csr -> csr` branch]
Is the implementation only available on CPU? No fallback on GPU ctx?

[Reply — Member Author]
I have added a check for CPU. It will fall back to default storage for GPU.

[Review comment — Member, on the kCSRStorage assignment]
Is the output stype consistent on CPU and GPU? The output stype should be consistent to avoid confusing users (see https://github.com/apache/incubator-mxnet/blob/d2a856a3a2abb4e72edc301b8b821f0b75f30722/src/operator/tensor/matrix_op-inl.h#L400-L418). The only difference is that on GPU it performs fallback: if the output stype infers sparse, it first produces a dense output, then casts it to sparse. The fallback is handled in the executor already.

[Reply — Member Author]
Thanks! Fixed.

[Review comment — Member, on the fallback path]
Hmm, we should log storage fallback as long as the dispatch mode is dispatch_fallback:
https://github.com/apache/incubator-mxnet/blob/d2a856a3a2abb4e72edc301b8b821f0b75f30722/src/operator/elemwise_op_common.h#L79-L81
Maybe I should move this logic to the common path instead of letting developers specify it in operators:
https://github.com/apache/incubator-mxnet/blob/master/src/executor/infer_graph_attr_pass.cc#L45-L54

[Reply — Member Author (@anirudh2290), Jan 3, 2018]

[Reply — Member]
Yes. We can fix that in a separate PR.
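At the Python level, the inference added above means a dense lhs with a csr rhs (no transposes) should now yield a csr output on CPU — and, per the review exchange, the same stype on GPU via fallback. A small usage sketch (my example, not from the PR; `tostype` is just a convenient way to build a csr operand):

import mxnet as mx

dns = mx.nd.random.uniform(shape=(3, 4))
csr = mx.nd.random.uniform(shape=(4, 5)).tostype('csr')
out = mx.nd.sparse.dot(dns, csr)   # dispatched to the FComputeEx path on CPU
print(out.stype)                   # expected: 'csr' with this change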
@@ -527,6 +538,80 @@ struct DotCsrTransRspRspByRowBlocks {
  }
};

/*!
 * \brief CPU Kernel of PopulateCsrForNNC
 * Parallelization by individual rows
 * Populates the indptr and indices arrays
 * based on the number of non-zero columns
 */
struct PopulateCsrForNNC {
  /*!
   * \brief
   * \param i the i-th thread
   * \param nnc_idx all non-zero column indices
   * \param indptr_out indptr array for output
   * \param col_idx_out column indices for output
   * \param nnc number of non-zero columns in the output
   * \param num_rows_l number of rows in lhs
   */
  template <typename IType, typename CType>
  MSHADOW_CINLINE static void Map(int i, const CType* nnc_idx,
                                  IType* indptr_out, CType* col_idx_out,
                                  const nnvm::dim_t nnc,
                                  const nnvm::dim_t num_rows_l) {
    const CType start_idx = i * nnc;
    nnvm::dim_t cur = 0;
    indptr_out[i] = start_idx;
    if (i == static_cast<int>(num_rows_l - 1)) indptr_out[i + 1] = indptr_out[i] + nnc;
    for (IType idx = start_idx; idx < (start_idx + nnc); idx++) {
      col_idx_out[idx] = nnc_idx[cur++];
    }
  }
};

[Review comment — Member, on struct PopulateCsrForNNC]
Can you add a brief description of what this kernel is for?

[Reply — Member Author]
Added.

[Review comment — Member, on the static_cast in Map]
As we are adding large-array support in the future, it's more appropriate to cast i up to dim_t instead of casting num_rows_l down to int.
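A small numpy sketch (mine, not from the diff) of what PopulateCsrForNNC lays out: every output row carries the same `nnc` column indices, so `indptr_out` is a uniform ramp and `col_idx_out` tiles `nnc_idx` once per row:

import numpy as np

nnc_idx = np.array([1, 4, 7])                   # non-zero columns of the rhs
num_rows_l, nnc = 4, len(nnc_idx)
indptr_out = np.arange(num_rows_l + 1) * nnc    # [0, 3, 6, 9, 12]
col_idx_out = np.tile(nnc_idx, num_rows_l)      # [1, 4, 7, 1, 4, 7, ...]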

/*!
 * \brief CPU Kernel of dot(dns, csr) = csr
 * Parallelization by row blocks
 */
struct DotDnsCsrCsrByRowBlocks {
  /*!
   * \brief
   * \param i the i-th thread
   * \param out data array of the output csr
   * \param data_l data of lhs
   * \param indptr_r indptr of rhs
   * \param col_idx_r column indices of rhs
   * \param data_r data of rhs
   * \param seg_len number of lhs rows handled per thread
   * \param num_rows_r number of rows in rhs
   * \param num_rows_l number of rows in lhs
   * \param num_cols number of columns in output
   * \param nnc number of non-zero columns
   * \param prefix_sum inclusive prefix sum over the non-zero-column flags
   */
  template <typename DType, typename IType, typename CType>
  MSHADOW_CINLINE static void Map(
      int i, DType* out, const DType* data_l, const IType* indptr_r,
      const CType* col_idx_r, const DType* data_r, const nnvm::dim_t seg_len,
      const IType num_rows_r, const IType num_rows_l,
      const nnvm::dim_t num_cols, const nnvm::dim_t nnc,
      const CType* prefix_sum) {
    using nnvm::dim_t;
    const dim_t seg_start = i * seg_len;
    if (seg_start >= num_rows_l) return;
    const dim_t seg_end = std::min(seg_start + seg_len, num_rows_l);

    for (dim_t j = seg_start; j < seg_end; j++) {
      for (dim_t k = 0; k < num_rows_r; k++) {
        const dim_t working_idx = j * num_rows_r + k;
        const DType val = data_l[working_idx];
        if (indptr_r[k] == indptr_r[k + 1]) continue;
        const dim_t row_start = j * nnc;
        for (dim_t cur = indptr_r[k]; cur < indptr_r[k + 1]; cur++) {
          dim_t cur_col_idx_r = col_idx_r[cur];
          const dim_t out_idx = row_start + prefix_sum[cur_col_idx_r] - 1;
          out[out_idx] += val * data_r[cur];
        }
      }
    }
  }
};
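The subtle part of the kernel above is the scatter target: `prefix_sum[col] - 1` maps a rhs column index to its slot among the `nnc` kept columns. A straight-line Python transcription for reference (assumptions mine: single-threaded, no seg_len blocking):

import numpy as np

def dot_dns_csr_rows(data_l, indptr_r, col_idx_r, data_r, prefix_sum, nnc):
    # Returns the flattened csr data array of dot(data_l, rhs_csr),
    # restricted to the nnc columns that hold any non-zero.
    num_rows_l, num_rows_r = data_l.shape
    out = np.zeros(num_rows_l * nnc)
    for j in range(num_rows_l):          # each lhs row owns a block of nnc slots
        row_start = j * nnc
        for k in range(num_rows_r):      # walk rhs rows; empty rows contribute nothing
            val = data_l[j, k]
            for cur in range(indptr_r[k], indptr_r[k + 1]):
                out[row_start + prefix_sum[col_idx_r[cur]] - 1] += val * data_r[cur]
    return out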



/*!
* \brief CPU Impl of dot(csr, dns1) = dns2 and dot(csr.T, dns1) = dns2
*/
@@ -811,6 +896,100 @@ inline void DotCsrRspRspImpl(const OpContext& ctx,
  });
}

/*!
 * \brief CPU Impl of dot(dns, csr) = csr
 */
template<typename xpu>
inline void DotDnsCsrCsrImpl(const OpContext& ctx,
                             const TBlob& lhs, const NDArray& rhs,
                             const OpReqType req, NDArray* ret) {
  if (kNullOp == req) return;

  CHECK_EQ(req, kWriteTo);
  CHECK_EQ(rhs.storage_type(), kCSRStorage);

  using namespace mshadow;
  using namespace mshadow::expr;
  using nnvm::dim_t;

  /*Initialize data structures*/
  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
  const NDArray& out = *ret;
  const TBlob data_l = lhs;
  const TBlob data_r = rhs.data();
  const TBlob indptr_r = rhs.aux_data(csr::kIndPtr);
  const TBlob col_idx_r = rhs.aux_data(csr::kIdx);
  if (!rhs.storage_initialized()) {
    FillZerosCsrImpl(s, *ret);
    return;
  }

  MSHADOW_SGL_DBL_TYPE_SWITCH(data_r.type_flag_, DType, {     // data type
    MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, {     // indptr type
      MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, {  // colidx type
        /* Allocate workspace */
        CType num_cols_out = out.shape()[1];
        CType rhs_data_size = static_cast<CType>(col_idx_r.shape_.Size());
        size_t workspace_size = 2 * num_cols_out * sizeof(CType);
        Tensor<cpu, 1, char> workspace =
            ctx.requested[0].get_space_typed<cpu, 1, char>(
                Shape1(workspace_size), s);
        CType* col_flg = reinterpret_cast<CType*>(workspace.dptr_);

        CType* prefix_sum = col_flg;
        CType* nnc_idx = prefix_sum + num_cols_out;

        /* Set the column flags for nnz columns */
        mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, num_cols_out,
                                                          col_flg);
        mxnet_op::Kernel<MarkRowFlgKernel, cpu>::Launch(
            s, rhs_data_size, col_flg, col_idx_r.dptr<CType>());

        /* 1. Calculate prefix sum from col flgs
         * 2. Store all non-zero column indices in nnc_idx
         */
        CType cur = 0;
        prefix_sum[0] = col_flg[0];
        if (prefix_sum[0]) nnc_idx[cur++] = 0;
        for (CType i = 1; i < num_cols_out; i++) {
          prefix_sum[i] = prefix_sum[i - 1] + col_flg[i];
          if (prefix_sum[i] > prefix_sum[i - 1]) nnc_idx[cur++] = i;
        }

        /* Allocate aux data for out */
        IType num_rows_l = lhs.shape_[0];
        dim_t nnc = prefix_sum[num_cols_out - 1];
        dim_t nnz = nnc * num_rows_l;
        out.CheckAndAllocAuxData(csr::kIndPtr, Shape1(num_rows_l + 1));
        out.CheckAndAllocAuxData(csr::kIdx, Shape1(nnz));
        out.CheckAndAllocData(Shape1(nnz));

        /* Set csr indptr and index according to nnc_idx */
        IType* indptr_out = out.aux_data(csr::kIndPtr).dptr<IType>();
        CType* col_idx_out = out.aux_data(csr::kIdx).dptr<CType>();
        DType* data_out = out.data().dptr<DType>();
        mxnet_op::Kernel<PopulateCsrForNNC, cpu>::Launch(
            s, num_rows_l, nnc_idx, indptr_out, col_idx_out, nnc, num_rows_l);
        mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, nnz, data_out);

        if (nnc == 0) {
          return;
        }

        const dim_t num_threads = mxnet_op::get_num_threads<cpu>(num_rows_l);
        const dim_t seg_len = (num_rows_l + num_threads - 1) / num_threads;

        IType num_rows_r = rhs.shape()[0];
        mxnet_op::Kernel<DotDnsCsrCsrByRowBlocks, cpu>::Launch(
            s, num_threads, data_out, data_l.dptr<DType>(),
            indptr_r.dptr<IType>(), col_idx_r.dptr<CType>(),
            data_r.dptr<DType>(), seg_len, num_rows_r, num_rows_l, num_cols_out,
            nnc, prefix_sum);
      });
    });
  });
}

[Review comment — Member, on `if (kNullOp == req) return;`]
Are kAddTo and kWriteInplace not checked?

[Reply — Member Author]
Thanks for pointing this out. Fixed.

[Review comment — Member, on `/*Initialize data structures*/`]
nit: space after /*

[Review comment — Member, on `if (nnc == 0)`]
Shouldn't nnc never be 0 here?

[Reply — Member Author]
Why should nnc never be 0? This is possible when the number of non-zero columns in the rhs is zero (a matrix of all zeros). In this case we return the output correctly.

[Reply — Member]
Because you already checked rhs.storage_initialized() in line 922?

[Reply — Member Author]
I have removed the if and also added some documentation for storage_initialized.
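To make the workspace bookkeeping concrete, here is the column-flag / prefix-sum step of DotDnsCsrCsrImpl in numpy (values illustrative): the inclusive prefix sum assigns each kept column a 1-based slot, which is why the row-block kernel subtracts 1 when scattering:

import numpy as np

col_idx_r = np.array([2, 5, 2, 0])       # csr column indices of the rhs
num_cols_out = 6
col_flg = np.zeros(num_cols_out, dtype=int)
col_flg[col_idx_r] = 1                   # flag columns holding any non-zero
prefix_sum = np.cumsum(col_flg)          # [1, 1, 2, 2, 2, 3]
nnc_idx = np.flatnonzero(col_flg)        # [0, 2, 5]
nnc = int(prefix_sum[-1])                # 3; the output allocates nnz = nnc * num_rows_l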

inline bool DotShape(const nnvm::NodeAttrs& attrs,
                     std::vector<TShape> *in_attrs,
                     std::vector<TShape> *out_attrs) {
@@ -886,6 +1065,11 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs,
             && out_stype == kRowSparseStorage && !param.transpose_b) {
    NDArray ret = outputs[0];
    DotCsrRspRspImpl(ctx, xpu(), inputs[0], inputs[1], req[0], param.transpose_a, &ret);
  } else if (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
             out_stype == kCSRStorage &&
             !(param.transpose_a || param.transpose_b)) {
    NDArray ret = outputs[0];
    DotDnsCsrCsrImpl<xpu>(ctx, inputs[0].data(), inputs[1], req[0], &ret);
  } else {
    LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
  }
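One consequence of the dispatch chain worth noting: dns/csr combinations the new branch does not cover (for example `transpose_b=True`) are not routed to DotDnsCsrCsrImpl; they should fall back to the dense path, so the result comes back dense. A hedged illustration (my example, not from the PR):

import mxnet as mx

dns = mx.nd.random.uniform(shape=(3, 4))
csr = mx.nd.random.uniform(shape=(5, 4)).tostype('csr')
out = mx.nd.sparse.dot(dns, csr, transpose_b=True)   # no csr branch matches
print(out.stype)                                     # expected: 'default' via fallback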
27 changes: 27 additions & 0 deletions tests/python/unittest/test_sparse_operator.py
@@ -1223,6 +1223,31 @@ def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, lhs_density, rhs_de
                                grad_req={'lhs': 'null', 'rhs': 'write'},
                                rtol=1e-3, atol=1e-4)

def test_dot_dns_csr(lhs_shape, rhs_shape, lhs_density, rhs_density, trans_lhs=False, trans_rhs=False):
    lhs_nd = rand_ndarray(lhs_shape, stype='default', density=lhs_density)
    rhs_nd = rand_ndarray(rhs_shape, stype='csr', density=rhs_density)
    rhs_dns = rhs_nd.tostype('default')

    out = mx.nd.sparse.dot(lhs_nd, rhs_nd, transpose_a=trans_lhs, transpose_b=trans_rhs)
    out_dns = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs, transpose_b=trans_rhs)
    out_np = out_dns.asnumpy()
    assert_almost_equal(out.asnumpy(), out_np, rtol=1e-4, atol=1e-5)

    # test symbolic forward
    lhs = mx.symbol.Variable('lhs', stype='default')
    rhs = mx.symbol.Variable('rhs', stype='csr')
    out = mx.symbol.sparse.dot(lhs, rhs, transpose_a=trans_lhs, transpose_b=trans_rhs)
    location = {'lhs': lhs_nd, 'rhs': rhs_nd}
    check_symbolic_forward(out, location, [out_np], rtol=1e-3, atol=1e-4)

    # test symbolic backward
    backward_trans = not trans_lhs
    rhs_backward_grad = mx.nd.dot(lhs_nd, out_dns, transpose_a=backward_trans).asnumpy()
    expected = {'rhs': rhs_backward_grad}
    check_symbolic_backward(out, location, [out_np], expected,
                            grad_req={'lhs': 'null', 'rhs': 'write'},
                            rtol=1e-3, atol=1e-4)
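The backward check relies on the identity that for C = A·B with upstream gradient G, the gradient w.r.t. B is Aᵀ·G (hence `backward_trans = not trans_lhs`). A quick numpy spot-check of that identity (mine, not part of the test file):

import numpy as np

A, B = np.random.rand(3, 4), np.random.rand(4, 5)
G = np.random.rand(3, 5)       # upstream gradient dL/dC for L = (C * G).sum()
dB = A.T.dot(G)                # analytic gradient w.r.t. B
eps, (i, j) = 1e-6, (1, 2)
B2 = B.copy(); B2[i, j] += eps
numeric = ((A.dot(B2) - A.dot(B)) * G).sum() / eps
assert abs(numeric - dB[i, j]) < 1e-4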

def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols):
    """Test for nnr_out = 0. Before the fix, the test would fail."""
    lhs = mx.nd.zeros(lhs_shape)
@@ -1248,10 +1273,12 @@ def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols):
        test_dot_csr(lhs_shape, (lhs_shape[0], 1), 'default', True, lhs_d, rhs_d)  # (vector kernel)
        test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(5, 10)), 'default', False, lhs_d, rhs_d)  # test gpu SpMM
        test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(5, 10)), 'default', True, lhs_d, rhs_d)  # (scalar kernel)
        test_dot_dns_csr(lhs_shape, (lhs_shape[1], rnd.randint(500, 1000)), lhs_d, lhs_d)
[Review comment — Member, on the test_dot_dns_csr call]
randint(50, 200) is large (and slow) enough for testing. No need to increase the dim to 1000.

    for rhs_d in density:
        test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False, lhs_d, rhs_d)
        test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True, lhs_d, rhs_d)


    test_sparse_dot_zero_output(rand_shape_2d(50, 200), False, 40)
    test_sparse_dot_zero_output(rand_shape_2d(50, 200), True, 40)
