Skip to content

Commit

Permalink
use CastCsrKernel
Browse files Browse the repository at this point in the history
  • Loading branch information
MayYouBeProsperous committed Dec 21, 2023
1 parent 0ee3545 commit c8ff3d5
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 135 deletions.
73 changes: 57 additions & 16 deletions paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -159,27 +159,66 @@ void MatmulCsrCsrGradKernel(const Context& dev_ctx,
perm = {0, 2, 1};
}

// dx{SparseCsr} = dout{Dense} * y'{Dense}
// cusparseSpGEMM only supports 32-bit indices.
SparseCsrTensor dout_tmp;
CastCsrKernel<T, Context>(
dev_ctx, dout, DataType::INT32, dout.values().dtype(), &dout_tmp);

// dx{SparseCsr} = dout{SparseCsr} * y'{SparseCsr}
if (dx) {
// InferMeta of SparseCsrTensor 'dx', CreateLikeInferMeta
EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx);
// cusparseSPGEMM only support CUSPARSE_OPERATION_NON_TRANSPOSE.
SparseCsrTensor trans_y;
TransposeCsrKernel<T, Context>(dev_ctx, y, perm, &trans_y);
SparseCsrTensor x_tmp, dx_tmp;
CastCsrKernel<T, Context>(
dev_ctx, x, DataType::INT32, x.values().dtype(), &x_tmp);

EmptyLikeCsrKernel<T, Context>(dev_ctx, x_tmp, &dx_tmp);

sparse_blas.SPGEMM(
false, false, static_cast<T>(1), dout, trans_y, static_cast<T>(0), dx);
// cusparseSpGEMM only supports CUSPARSE_OPERATION_NON_TRANSPOSE.
SparseCsrTensor trans_y, trans_y_tmp;
TransposeCsrKernel<T, Context>(dev_ctx, y, perm, &trans_y);
CastCsrKernel<T, Context>(dev_ctx,
trans_y,
DataType::INT32,
trans_y.values().dtype(),
&trans_y_tmp);

sparse_blas.SPGEMM(false,
false,
static_cast<T>(1),
dout_tmp,
trans_y_tmp,
static_cast<T>(0),
&dx_tmp);

CastCsrKernel<T, Context>(
dev_ctx, dx_tmp, DataType::INT64, dx_tmp.values().dtype(), dx);
}

// dy{Dense} = x'{SparseCsr} * dout{Dense}
// dy{SparseCsr} = x'{SparseCsr} * dout{SparseCsr}
if (dy) {
// InferMeta of DenseTensor 'dy'
EmptyLikeCsrKernel<T, Context>(dev_ctx, y, dy);
SparseCsrTensor trans_x;
TransposeCsrKernel<T, Context>(dev_ctx, x, perm, &trans_x);
SparseCsrTensor y_tmp, dy_tmp;
CastCsrKernel<T, Context>(
dev_ctx, y, DataType::INT32, y.values().dtype(), &y_tmp);
EmptyLikeCsrKernel<T, Context>(dev_ctx, y_tmp, &dy_tmp);

sparse_blas.SPGEMM(
false, false, static_cast<T>(1), trans_x, dout, static_cast<T>(0), dy);
// cusparseSpGEMM only supports CUSPARSE_OPERATION_NON_TRANSPOSE.
SparseCsrTensor trans_x, trans_x_tmp;
TransposeCsrKernel<T, Context>(dev_ctx, x, perm, &trans_x);
CastCsrKernel<T, Context>(dev_ctx,
trans_x,
DataType::INT32,
trans_x.values().dtype(),
&trans_x_tmp);

sparse_blas.SPGEMM(false,
false,
static_cast<T>(1),
trans_x_tmp,
dout_tmp,
static_cast<T>(0),
&dy_tmp);

CastCsrKernel<T, Context>(
dev_ctx, dy_tmp, DataType::INT64, dy_tmp.values().dtype(), dy);
}
#else
#ifdef PADDLE_WITH_CUDA
Expand All @@ -197,7 +236,7 @@ void MatmulCooCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& dout,
SparseCooTensor* dx,
SparseCooTensor* dy) {
// 'cusparseSPGEMM' only support CSR now, so use COO->CSR->COO,
// cusparseSpGEMM only supports CSR for now, so convert COO->CSR->COO
SparseCsrTensor x_csr = CooToCsr<T, Context>(dev_ctx, x);
SparseCsrTensor y_csr = CooToCsr<T, Context>(dev_ctx, y);
SparseCsrTensor dout_csr = CooToCsr<T, Context>(dev_ctx, dout);
Expand Down Expand Up @@ -288,6 +327,7 @@ PD_REGISTER_KERNEL(matmul_csr_csr_grad,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}

PD_REGISTER_KERNEL(matmul_coo_coo_grad,
Expand All @@ -297,6 +337,7 @@ PD_REGISTER_KERNEL(matmul_coo_coo_grad,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

PD_REGISTER_KERNEL(masked_matmul_csr_grad,
Expand Down
26 changes: 22 additions & 4 deletions paddle/phi/kernels/sparse/gpu/matmul_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/math_function_impl.h"
#include "paddle/phi/kernels/funcs/sparse/sparse_blas.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h"
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"

namespace phi {
Expand Down Expand Up @@ -158,21 +159,36 @@ void MatmulCsrCsrKernel(const Context& dev_ctx,
"The shape of Input(x) and Input(y) is not suitable for matmul "
"opetation, x_dim[-1] must be eaqual to y_dim[-2]."));

// cusparseSpGEMM only supports 32-bit indices.
SparseCsrTensor x_tmp, y_tmp, out_tmp;
CastCsrKernel<T, Context>(
dev_ctx, x, DataType::INT32, x.values().dtype(), &x_tmp);
CastCsrKernel<T, Context>(
dev_ctx, y, DataType::INT32, y.values().dtype(), &y_tmp);

std::vector<int64_t> out_dim_vec = phi::vectorize(out->dims());
int batch_size = 1;
for (int i = 0; i < out_dim_vec.size() - 2; i++) {
batch_size *= out_dim_vec[i];
}

int64_t out_crows_size = batch_size * (xdim_vec[x_ndims - 2] + 1);
DenseTensor out_crows = phi::Empty<int32_t>(dev_ctx, {out_crows_size});
DenseTensor out_cols = phi::Empty<int32_t>(dev_ctx, {0});
DenseTensor out_values = phi::Empty<T>(dev_ctx, {0});
out->SetMember(out_crows, out_cols, out_values, out->dims());
out_tmp.SetMember(out_crows, out_cols, out_values, out->dims());

auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
sparse_blas.SPGEMM(
false, false, static_cast<T>(1), x, y, static_cast<T>(0), out);
sparse_blas.SPGEMM(false,
false,
static_cast<T>(1),
x_tmp,
y_tmp,
static_cast<T>(0),
&out_tmp);

CastCsrKernel<T, Context>(
dev_ctx, out_tmp, DataType::INT64, out_tmp.values().dtype(), out);

#else
#ifdef PADDLE_WITH_CUDA
PADDLE_THROW(phi::errors::Unimplemented(
Expand Down Expand Up @@ -307,6 +323,7 @@ PD_REGISTER_KERNEL(matmul_coo_coo,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

PD_REGISTER_KERNEL(matmul_csr_csr,
Expand All @@ -316,6 +333,7 @@ PD_REGISTER_KERNEL(matmul_csr_csr,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}

PD_REGISTER_KERNEL(masked_matmul_csr,
Expand Down
Loading

0 comments on commit c8ff3d5

Please sign in to comment.