Optimized gpu dot kernels (#6937)
* pulled update to mshadow

* mshadow update

* added optimized gpu kernels for dot(csr,dns)=dns and dot(csr.T,dns)=dns, and unit test

* added __syncwarp to vector kernel and reduced number of writes to shared memory
stefanhenneking authored and piiswrong committed Jul 11, 2017
1 parent 69a75b6 commit 038fd31
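For context, here is a minimal usage sketch (not part of this commit) of the two operations these kernels cover. It reuses the mx.nd.dot and mx.nd.cast_storage calls that appear in the tests below; the shapes and the GPU device id are illustrative assumptions.

import mxnet as mx

ctx = mx.gpu(0)
lhs_dns = mx.nd.ones((50, 200), ctx=ctx)                    # dense lhs
lhs_csr = mx.nd.cast_storage(lhs_dns, storage_type='csr')   # sparse (csr) lhs
rhs = mx.nd.ones((200, 4), ctx=ctx)
out = mx.nd.dot(lhs_csr, rhs)                               # dot(csr, dns) = dns
rhs_t = mx.nd.ones((50, 4), ctx=ctx)
out_t = mx.nd.dot(lhs_csr, rhs_t, transpose_a=True)         # dot(csr.T, dns) = dns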
Showing 6 changed files with 249 additions and 23 deletions.
2 changes: 1 addition & 1 deletion mshadow
1 change: 1 addition & 0 deletions src/io/inst_vector.h
@@ -12,6 +12,7 @@
#include <mxnet/base.h>
#include <dmlc/base.h>
#include <mshadow/tensor.h>
#include <mshadow/tensor_blob.h>
#include <vector>
#include <string>

257 changes: 239 additions & 18 deletions src/operator/tensor/dot-inl.cuh
@@ -11,13 +11,14 @@

namespace mxnet {
namespace op {
using mshadow::cuda::kBaseThreadNum;

/*!
- * \brief Kernel of dot(csr, dns1) = dns2
- * Parallelization by output matrix elements
* \brief Scalar kernel of dot(csr, dns1) = dns2
* Parallelization by output matrix elements: 1 thread/element
*/
template<int req>
-struct DotCsrDnsDns {
struct DotCsrDnsDnsScalarKernel {
/*!
* \brief This function represents performing an inner product between a row of lhs
* and a column of rhs and then assigning the value to out[i].
@@ -45,11 +45,52 @@ struct DotCsrDnsDns {
};

/*!
- * \brief Kernel of dot(csr.T(), dns1) = dns2
- * Parallelization by output matrix elements
* \brief Vector kernel of dot(csr, dns1) = dns2
* Parallelization by output matrix elements: 1 warp/element
*/
template<int req>
-struct DotCsrTransDnsDns {
struct DotCsrDnsDnsVectorKernel {
template<typename DType, typename IType, typename CType>
__device__ __forceinline__ static void Map(int tid, DType* out, const DType* data_l, const IType* indptr_l,
const CType* col_idx_l, const DType* data_r,
const int num_cols_r) {
__shared__ volatile DType vals[kBaseThreadNum];
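// Note: vals is indexed by threadIdx.x below, so this assumes the kernel is launched
// with at most kBaseThreadNum threads per block.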

const int warp_id = tid / 32; // global warp id
const int lane = tid & (32-1); // local thread id within warp
const int irow = warp_id / num_cols_r; // lhs row that this warp computes
const int kcol = warp_id % num_cols_r; // rhs column that this warp computes

// Range of nnz elements in this row
const int low = static_cast<int>(indptr_l[irow]);
const int high = static_cast<int>(indptr_l[irow+1]);

// Compute running sum per thread
DType sum = 0;
for (int j = low+lane; j < high; j+=32) {
sum += data_l[j] * data_r[col_idx_l[j]*num_cols_r + kcol];
}
vals[threadIdx.x] = sum;

// Parallel reduction in shared memory
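// (__syncwarp() after each step makes the reduction explicitly warp-synchronous instead of
//  relying on implicit lockstep execution; this is the addition noted in the commit message)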
if (lane < 16) {vals[threadIdx.x] += vals[threadIdx.x+16];} __syncwarp();
if (lane < 8) {vals[threadIdx.x] += vals[threadIdx.x+ 8];} __syncwarp();
if (lane < 4) {vals[threadIdx.x] += vals[threadIdx.x+ 4];} __syncwarp();
if (lane < 2) {vals[threadIdx.x] += vals[threadIdx.x+ 2];} __syncwarp();
if (lane < 1) {vals[threadIdx.x] += vals[threadIdx.x+ 1];} __syncwarp();

if (lane == 0) {
KERNEL_ASSIGN(out[irow*num_cols_r+kcol], req, vals[threadIdx.x]);
}
}
};

/*!
* \brief Scalar kernel of dot(csr.T(), dns1) = dns2
* Parallelization by output matrix elements: 1 thread/element
*/
template<int req>
struct DotCsrTransDnsDnsScalarKernel {
/*!
* \brief This function represents performing an inner product between a column of lhs
* and a column of rhs and then assigning the value to out[i].
@@ -69,6 +111,8 @@ struct DotCsrTransDnsDns {
const int irow = i / num_cols; // col id of the lhs
const int icol = i % num_cols; // col id of the rhs
DType sum = 0;

// Each thread scans every lhs row and binary searches that row's nnz for its lhs column (irow)
for (int k = 0; k < num_rows_l; ++k) {
const IType low = indptr_l[k];
const IType high = indptr_l[k+1];
@@ -93,6 +137,98 @@
}
};

/*!
* \brief Warp kernel of dot(csr.T(), dns1) = dns2
* Parallelization by columns: 1 warp computes one lhs column for one rhs column
*/
template<int req>
struct DotCsrTransDnsDnsWarpKernel {
template<typename DType, typename IType, typename CType>
__device__ __forceinline__ static void Map(int tid, DType* out, const DType* data_l, const IType* indptr_l,
const CType* col_idx_l, const DType* data_r,
const int num_cols_r) {
const int warp_id = tid / 32; // global warp id
const int lane = tid & (32-1); // local thread id within warp
const int icol = warp_id / num_cols_r; // lhs column that this warp computes
const int kcol = warp_id % num_cols_r; // rhs column that this warp computes

// Compute range of nnz elements in this column
const int low = static_cast<int>(indptr_l[icol]);
const int high = static_cast<int>(indptr_l[icol+1]);

// Iterate through the nnz elements in this column
for (int j = low+lane; j < high; j+=32) {
const int irow = static_cast<int>(col_idx_l[j]);
const DType val = data_l[j]*data_r[icol*num_cols_r+kcol];
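// Warps handling different columns (icol) can update the same output row (irow),
// so the contribution is accumulated atomically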
atomicAdd(static_cast<DType *>(&(out[irow*num_cols_r+kcol])), val);
}
}
};

/*!
* \brief Thread block kernel of dot(csr.T(), dns1) = dns2
* Parallelization by columns: 1 thread block computes one lhs column for all rhs columns
*/
template<int req>
struct DotCsrTransDnsDnsThreadBlockKernel {
template<typename DType, typename IType, typename CType>
__device__ __forceinline__ static void Map(int tid, DType* out, const DType* data_l, const IType* indptr_l,
const CType* col_idx_l, const DType* data_r,
const int num_cols_r) {
const int warps_per_block = blockDim.x / 32; // number of warps in this thread block
const int warp_id = tid / 32; // global warp id
const int lane = tid & (32-1); // local thread id within warp
const int icol = blockIdx.x; // lhs column that this thread block computes
const int kcol = warp_id % warps_per_block; // rhs column where warp starts computing (offset)

// Compute range of nnz elements in this lhs column
const int low = static_cast<int>(indptr_l[icol]);
const int high = static_cast<int>(indptr_l[icol+1]);

// Iterate through the nnz elements in this lhs column
for (int j = low+lane; j < high; j+=32) {
const int irow = static_cast<int>(col_idx_l[j]);
const DType datum_l = data_l[j];
// Iterate over rhs columns that this warp computes
for (int k = kcol; k < num_cols_r; k+=warps_per_block) {
const DType val = datum_l*data_r[icol*num_cols_r+k];
atomicAdd(static_cast<DType *>(&(out[irow*num_cols_r+k])), val);
}
}
}
};

/*!
* \brief Warp block kernel of dot(csr.T(), dns1) = dns2
* Parallelization by columns: 1 warp computes one lhs column for all rhs columns
*/
template<int req>
struct DotCsrTransDnsDnsWarpBlockKernel {
template<typename DType, typename IType, typename CType>
__device__ __forceinline__ static void Map(int tid, DType* out, const DType* data_l, const IType* indptr_l,
const CType* col_idx_l, const DType* data_r,
const int num_cols_r) {
const int warp_id = tid / 32; // global warp id
const int lane = tid & (32-1); // local thread id within warp
const int icol = warp_id; // lhs column that this warp computes

// Compute range of nnz elements in this column
const int low = static_cast<int>(indptr_l[icol]);
const int high = static_cast<int>(indptr_l[icol+1]);

// Iterate through the nnz elements in lhs column
for (int j = low+lane; j < high; j+=32) {
const int irow = static_cast<int>(col_idx_l[j]);
const DType datum_l = data_l[j];
// Iterate over all rhs columns
for (int k = 0; k < num_cols_r; k++) {
const DType val = datum_l*data_r[icol*num_cols_r+k];
atomicAdd(static_cast<DType *>(&(out[irow*num_cols_r+k])), val);
}
}
}
};

inline void DotCsrDnsDnsImpl(mshadow::Stream<gpu>* s,
const NDArray& lhs,
const TBlob& rhs,
@@ -109,22 +245,107 @@ inline void DotCsrDnsDnsImpl(mshadow::Stream<gpu>* s,
const TBlob& data_r = rhs;
const TBlob data_out = *ret;

-MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type
MSHADOW_SGL_DBL_TYPE_SWITCH(data_l.type_flag_, DType, { // data type
MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type
if (kWriteTo == req) {
mxnet_op::Kernel<mxnet_op::set_zero, gpu>::Launch(s, data_out.Size(), data_out.dptr<DType>());
}
int num_threads;
const int threads_per_warp = 32;
const int threads_per_block = kBaseThreadNum;
const int num_rows_l = lhs.shape()[0];
const int num_cols_r = rhs.shape_[1];
if (trans_lhs) {
-MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
-mxnet_op::Kernel<DotCsrTransDnsDns<ReqType>, gpu>::Launch(s, data_out.Size(),
-data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
-col_idx_l.dptr<CType>(), data_r.dptr<DType>(), lhs.shape()[0],
-data_out.shape_[1]);
-});
// Different kernel versions are optimized for different matrix instances
// TODO: switch between kernel versions depending on input
// (1) 'Scalar kernel' (one thread computing one output element )
// (2) 'Warp kernel' (one warp computing one lhs column for one rhs column )
// (3) 'Thread block kernel' (one thread block computing one lhs column for all rhs columns)
// (4) 'Warp block kernel' (one warp computing one lhs column for all rhs columns)
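// kernel_version == 0 below falls through to the default case, which currently launches the warp kernel (2)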
const int kernel_version = 0;
switch (kernel_version) {
case 1:
num_threads = data_out.Size();
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrTransDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_rows_l, num_cols_r);
});
break;
case 2:
num_threads = threads_per_warp * num_rows_l * num_cols_r;
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrTransDnsDnsWarpKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
});
break;
case 3:
num_threads = threads_per_block * num_rows_l;
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrTransDnsDnsThreadBlockKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
});
break;
case 4:
num_threads = threads_per_warp * num_rows_l;
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrTransDnsDnsWarpBlockKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
});
break;
default:
num_threads = threads_per_warp * num_rows_l * num_cols_r;
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrTransDnsDnsWarpKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
});
break;
}
} else {
-MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
-mxnet_op::Kernel<DotCsrDnsDns<ReqType>, gpu>::Launch(s, data_out.Size(),
-data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
-col_idx_l.dptr<CType>(), data_r.dptr<DType>(), rhs.shape_[1]);
-});
// Different kernel versions are optimized for different matrix instances
// (1) 'Scalar kernel' (one thread computing one output element)
// (2) 'Vector kernel' (one warp computing one output element)
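// kernel_version == 0 below falls through to the default case, which picks the scalar or vector kernel based on num_cols_r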
const int kernel_version = 0;
switch (kernel_version) {
case 1:
num_threads = data_out.Size();
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
});
break;
case 2:
num_threads = threads_per_warp * num_rows_l * num_cols_r;
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrDnsDnsVectorKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
});
break;
default:
if (num_cols_r > 4) {
num_threads = data_out.Size();
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
});
} else {
num_threads = threads_per_warp * num_rows_l * num_cols_r;
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrDnsDnsVectorKernel<ReqType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
});
}
break;
}
}
});
});
3 changes: 2 additions & 1 deletion src/operator/tensor/dot-inl.h
@@ -187,7 +187,8 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
CHECK_EQ(out_attrs->size(), 1U);
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
// csr has many zero columns, so the result of dot(csr.T, matrix) should be rsp
-if (param.transpose_a && kCSRStorage == (*in_attrs)[0]) {
// dot(csr.T,dns)=rsp not yet implemented on gpu
if (param.transpose_a && kCSRStorage == (*in_attrs)[0] && ctx.dev_type != Context::kGPU) {
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kRowSparseStorage);
} else {
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage);
1 change: 1 addition & 0 deletions tests/python/gpu/test_operator_gpu.py
@@ -5,6 +5,7 @@
from test_operator import *
from test_optimizer import *
from test_random import *
from test_sparse_operator import test_sparse_dot
import mxnet as mx
import numpy as np
from mxnet.test_utils import check_consistency, set_default_context
8 changes: 5 additions & 3 deletions tests/python/unittest/test_sparse_operator.py
@@ -103,12 +103,12 @@ def test_dns_to_csr(dns_in):

def test_sparse_dot():
def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, density=1):
-lhs_dns = rand_ndarray(lhs_shape, 'default')
-lhs_nd = mx.nd.cast_storage(lhs_dns, storage_type='csr')
lhs_nd = rand_ndarray(lhs_shape, 'csr', 1)
lhs_dns = lhs_nd.todense()
rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=density)
rhs_dns = rhs_nd if rhs_stype == 'default' else rhs_nd.todense()
out = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs)
-if trans_lhs:
if trans_lhs and default_context().device_type == 'cpu':
assert out.storage_type == 'row_sparse'
else:
assert out.storage_type == 'default'
@@ -131,6 +131,8 @@ def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, density=1):
rtol=1e-3, atol=1e-4)

lhs_shape = rand_shape_2d(50, 200)
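# the 1-column rhs cases below exercise the small-num_cols_r gpu code paths (e.g. the vector kernel when not transposed)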
test_dot_csr(lhs_shape, (lhs_shape[1], 1), 'default', False)
test_dot_csr(lhs_shape, (lhs_shape[0], 1), 'default', True)
test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'default', False)
test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'default', True)
test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False)
