Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add operator for dot(dns, csr) = csr #8938

Merged
merged 15 commits into from
Jan 4, 2018

Conversation

anirudh2290
Copy link
Member

@anirudh2290 anirudh2290 commented Dec 4, 2017

Description

Adds operator for dot(dns, csr) = csr. Backward pass will fallback to default implementations.
The performance is better than dot(dns, dns) for sparsity less than 0.5% (c4.8xlarge). Below are the results for tests on c4.8xlarge with OMP_NUM_THREADS set to 32.

========================================================
  mxnet sparse dot benchmark: dot(default, csr) = csr
  (matrix multiplication: (m x k) * (k x n) = m x n)
========================================================
 lhs_density(%)  rhs_density(%)    context        m        k        n  t_sparse(ms)   t_dense(ms)  speedup
            1.0             0.1     cpu(0)     1000      128  1000000        337.74       1810.42     5.36
            1.0             0.1     cpu(0)     1000       64  1000000        172.71       1653.84     9.58
            1.0             0.1     cpu(0)     1000      128  1000000        345.05       1810.87     5.25
            1.0             0.1     cpu(0)      256      128  1000000         89.88        466.65     5.19
            1.0             0.1     cpu(0)     1000      128  1000000        335.76       1785.21     5.32
            0.1             0.1     cpu(0)     1000      128  1000000        332.17       1815.71     5.47
            0.5             0.5     cpu(0)     1000      128  1000000       1718.80       1764.55     1.03
            1.0             1.0     cpu(0)     1000      128  1000000       3434.07       1681.34     0.49
            2.0             2.0     cpu(0)     1000      128  1000000       7621.03       1689.58     0.22
            5.0             5.0     cpu(0)     1000      128  1000000      15715.25       1738.18     0.11
           10.0            10.0     cpu(0)     1000      128  1000000      22354.31       1602.44     0.07
           20.0            20.0     cpu(0)     1000      128  1000000      26913.53       1735.68     0.06

Checklist

Essentials

  • Passed code style checking (make lint)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • For user-facing API changes, API doc string has been updated. For new C++ functions in header files, their functionalities and arguments are well-documented.
  • To my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • dot(dns, csr) = csr, tests, (and when applicable, API doc)

@eric-haibin-lin eric-haibin-lin self-assigned this Dec 4, 2017
if rhs_density > 1 or rhs_density < 0:
raise ValueError("rhs_density has to be between 0 and 1")
raise ValueError("Value other than csr for lhs not supported")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check and the error statement don't seem to match?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Fixed.

* \brief CPU Kernel of PopulateCsrForNNC
* Parallelization by individual rows
*/
struct PopulateCsrForNNC {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add brief description on what this kernel is for?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.

@@ -231,6 +231,12 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
dispatched = storage_type_assign(&out_stype, kDefaultStorage,
dispatch_mode, DispatchMode::kFComputeEx);
}
if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the implementation only available on CPU? No fallback on GPU ctx?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added a check for CPU. will fallback to default storage for gpu

const OpReqType req, NDArray* ret) {
if (kNullOp == req) return;
CHECK_EQ(rhs.storage_type(), kCSRStorage);
if (!rhs.storage_initialized()) return;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we set the result to be ZerosCsrImpl before return?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed!

inline void DotDnsCsrCsrImpl(const OpContext& ctx, const cpu& cpu_dev,
const TBlob& lhs, const NDArray& rhs,
const OpReqType req, NDArray* ret) {
if (kNullOp == req) return;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is kAddTo and kWriteInplace not checked?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing this. Fixed.

return;
}

dim_t num_threads = mxnet_op::get_num_threads<cpu>(num_rows_l);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: const for both

s, num_rows_l, nnc_idx, indptr_out, col_idx_out, nnc, num_rows_l);
mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, nnz, data_out);

if (nnc == 0) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't nnc never be 0 here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why should nnc never be 0 ? This is possible when number of non zero columns are zero(matrix with all zeros) in the rhs. In this case we return the output correctly.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because you already checked rhs.storage_initialized() in line 922?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have removed the if and also added some documentation for storage_initialized

s, num_rows_l, nnc_idx, indptr_out, col_idx_out, nnc, num_rows_l);
mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, nnz, data_out);

if (nnc == 0) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because you already checked rhs.storage_initialized() in line 922?

// dns, csr -> csr
if (dev_mask == mshadow::cpu::kDevMask) {
dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode,
DispatchMode::kFComputeEx);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is output stype consistent on cpu and gpu? The output stype should be consistent to avoid confusion to users (see https://github.com/apache/incubator-mxnet/blob/d2a856a3a2abb4e72edc301b8b821f0b75f30722/src/operator/tensor/matrix_op-inl.h#L400-L418)
The only difference is that on GPU, it performs fallback. If the output stype infers sparse, then it first produce dense output, then cast it to sparse. The fallback is handled in executor already

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Fixed.

Copy link
Member

@eric-haibin-lin eric-haibin-lin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM in general. A few minor comments

@@ -305,7 +305,10 @@ class NDArray {
bool fresh_out_grad() const;
/*! \return updated grad state in entry_ */
void set_fresh_out_grad(bool state) const;
// returns true if a sparse ndarray's aux_data and storage are initialized
/*! \brief Returns true if a sparse ndarray's aux_data and storage are initialized
* Returns false if the indices array shape is inconsistent
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Returns false if the indices array shape is inconsistent" -> it actually throws an exception without returning false

const auto dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback
: DispatchMode::kFComputeEx;
dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode,
dispatch_ex);
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. we should log storage fallback as long as dispatch mode is dispatch_fallback:
https://github.com/apache/incubator-mxnet/blob/d2a856a3a2abb4e72edc301b8b821f0b75f30722/src/operator/elemwise_op_common.h#L79-L81

Maybe I should move this logic to the common path instead of letting developers specify that in operators
https://github.com/apache/incubator-mxnet/blob/master/src/executor/infer_graph_attr_pass.cc#L45-L54

Copy link
Member Author

@anirudh2290 anirudh2290 Jan 3, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. We can fix that in a separate PR.

@@ -1248,10 +1273,12 @@ def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols):
test_dot_csr(lhs_shape, (lhs_shape[0], 1), 'default', True, lhs_d, rhs_d) # (vector kernel)
test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(5, 10)), 'default', False, lhs_d, rhs_d) # test gpu SpMM
test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(5, 10)), 'default', True, lhs_d, rhs_d) # (scalar kernel)
test_dot_dns_csr(lhs_shape, (lhs_shape[1], rnd.randint(500, 1000)), lhs_d, lhs_d)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

randint(50,200) is large (and slow) enough for testing. No need to increase the dim to 1000.

using namespace mshadow::expr;
using nnvm::dim_t;

/*Initialize data structures*/
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: space after /*

const CType start_idx = i * nnc;
nnvm::dim_t cur = 0;
indptr_out[i] = start_idx;
if (i == static_cast<int>(num_rows_l - 1)) indptr_out[i + 1] = indptr_out[i] + nnc;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we are adding large array support in the future, it's more appropriate to cast i up to dim_t instead of cast num_rows_l down to int.

@anirudh2290
Copy link
Member Author

@eric-haibin-lin Thank you for reviewing! I have made the necessary changes.

@eric-haibin-lin
Copy link
Member

Is the operator documentation not updated?

@anirudh2290
Copy link
Member Author

Added dot(dns, csr) = csr to operator doc

@eric-haibin-lin eric-haibin-lin merged commit 8505442 into apache:master Jan 4, 2018
yuxiangw pushed a commit to yuxiangw/incubator-mxnet that referenced this pull request Jan 25, 2018
* Add operator for dot(dns, csr) = csr

* Fix whitespace

* Add comments

* Add comments and fix error message

* Fixes for dot dns csr

* Fixes

* Remove non required statements

* Add fallback for GPU

* Remove unused if

* Fix comments and casting

* Add operator to the documentation
rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018
* Add operator for dot(dns, csr) = csr

* Fix whitespace

* Add comments

* Add comments and fix error message

* Fixes for dot dns csr

* Fixes

* Remove non required statements

* Add fallback for GPU

* Remove unused if

* Fix comments and casting

* Add operator to the documentation
zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018
* Add operator for dot(dns, csr) = csr

* Fix whitespace

* Add comments

* Add comments and fix error message

* Fixes for dot dns csr

* Fixes

* Remove non required statements

* Add fallback for GPU

* Remove unused if

* Fix comments and casting

* Add operator to the documentation
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants