[Sparse]Sparse add support gpu (PaddlePaddle#45974)
zhangkaihuo committed Sep 19, 2022
1 parent 7f0c1f0 commit b334852
Showing 4 changed files with 175 additions and 1 deletion.
1 change: 1 addition & 0 deletions paddle/phi/kernels/sparse/elementwise_grad_kernel.h
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/kernels/empty_kernel.h"

56 changes: 56 additions & 0 deletions paddle/phi/kernels/sparse/gpu/elementwise_grad_kernel.cu
@@ -0,0 +1,56 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/sparse/elementwise_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void ElementWiseAddCooGradKernel(const Context& dev_ctx,
                                 const SparseCooTensor& x,
                                 const SparseCooTensor& y,
                                 const SparseCooTensor& dout,
                                 SparseCooTensor* dx,
                                 SparseCooTensor* dy) {
  if (dx) {
    EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx);
    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
  }

  if (dy) {
    EmptyLikeCooKernel<T, Context>(dev_ctx, y, dy);
    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
  }
}

} // namespace sparse
} // namespace phi

PD_REGISTER_KERNEL(add_coo_coo_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::ElementWiseAddCooGradKernel,
                   float,
                   double,
                   int16_t,
                   int,
                   int64_t,
                   phi::dtype::float16) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
  kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
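Because this kernel only handles inputs that share a sparsity pattern, the backward pass for add is a pass-through: dx and dy are allocated with the input's COO metadata and dout is copied into them, since d(x + y)/dx = d(x + y)/dy = identity. A minimal NumPy sketch of that rule (the names below are illustrative, not Paddle API):

import numpy as np

# Upstream gradient dout in COO form; by construction it shares x's and y's indices.
dout_indices = np.array([[0, 1], [0, 3]])
dout_values = np.array([[0.5], [1.5]], dtype=np.float32)

# Both input gradients are plain copies of dout, mirroring EmptyLikeCooKernel + Copy above.
dx_indices, dx_values = dout_indices.copy(), dout_values.copy()
dy_indices, dy_values = dout_indices.copy(), dout_values.copy()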
88 changes: 88 additions & 0 deletions paddle/phi/kernels/sparse/gpu/elementwise_kernel.cu
@@ -0,0 +1,88 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/equal.h>
#include <thrust/execution_policy.h>

#include "paddle/phi/kernels/elementwise_add_kernel.h"
#include "paddle/phi/kernels/sparse/elementwise_kernel.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"

#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"

namespace phi {
namespace sparse {

template <typename T, typename IntT>
void ElementWiseAddCooGPUKernel(const GPUContext& dev_ctx,
                                const SparseCooTensor& x,
                                const SparseCooTensor& y,
                                SparseCooTensor* out) {
  const auto& x_indices = x.indices();
  const auto& y_indices = y.indices();
  PADDLE_ENFORCE_EQ(
      x_indices.numel(),
      y_indices.numel(),
      phi::errors::PreconditionNotMet(
          "The numel of x.indices() and y.indices() should be equal"));
  const IntT* x_indices_ptr = x_indices.data<IntT>();
  const IntT* y_indices_ptr = y_indices.data<IntT>();
#ifdef PADDLE_WITH_HIP
  bool is_same = thrust::equal(thrust::hip::par.on(dev_ctx.stream()),
#else
  bool is_same = thrust::equal(thrust::cuda::par.on(dev_ctx.stream()),
#endif
                               x_indices_ptr,
                               x_indices_ptr + x_indices.numel(),
                               y_indices_ptr);
  PADDLE_ENFORCE_EQ(
      is_same,
      true,
      phi::errors::PreconditionNotMet(
          "Currently, ElementWiseAddCooKernel only supports the case "
          "where x and y have the same indices"));
  EmptyLikeCooKernel<T, GPUContext>(dev_ctx, x, out);
  phi::AddKernel<T, GPUContext>(
      dev_ctx, x.values(), y.values(), out->mutable_values());
}

template <typename T, typename Context>
void ElementWiseAddCooKernel(const Context& dev_ctx,
                             const SparseCooTensor& x,
                             const SparseCooTensor& y,
                             SparseCooTensor* out) {
  PD_VISIT_BASE_INTEGRAL_TYPES(x.indices().dtype(), "VerifyIndices", ([&] {
                                 ElementWiseAddCooGPUKernel<T, data_t>(
                                     dev_ctx, x, y, out);
                               }));
}

} // namespace sparse
} // namespace phi

PD_REGISTER_KERNEL(add_coo_coo,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::ElementWiseAddCooKernel,
                   float,
                   double,
                   int16_t,
                   int,
                   int64_t,
                   phi::dtype::float16) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
  kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
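The forward kernel is restricted to operands with identical indices: it first checks that the two index buffers have the same numel and compare equal element-wise (thrust::equal on the device stream), then adds the dense value tensors with phi::AddKernel. A NumPy sketch of the same logic, with illustrative names rather than Paddle API:

import numpy as np

def coo_add_same_indices(x_indices, x_values, y_indices, y_values):
    # Mirrors the two PADDLE_ENFORCE_EQ checks: equal numel and identical indices.
    if x_indices.size != y_indices.size or not np.array_equal(x_indices, y_indices):
        raise ValueError("sparse add currently requires x and y to have identical indices")
    # Mirrors EmptyLikeCooKernel + AddKernel: the output reuses x's indices,
    # and its values are the element-wise sum of the input values.
    return x_indices.copy(), x_values + y_values

indices = np.array([[0, 1], [0, 3]])  # same layout as the unit test below
out_indices, out_values = coo_add_same_indices(
    indices, np.array([[1.0], [2.0]]), indices, np.array([[1.0], [2.0]]))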
31 changes: 30 additions & 1 deletion (Python unit test for sparse element-wise ops)
@@ -18,7 +18,7 @@

import numpy as np
import paddle
from paddle.fluid.framework import _test_eager_guard
import paddle.incubate.sparse as sparse

op_list = [__add__, __sub__, __mul__, __truediv__]

@@ -134,6 +134,35 @@ def test_support_dtypes_coo(self):
        for op in op_list:
            self.func_test_coo(op)

    def test_add_same_indices(self):
        indices_data = [[0, 1], [0, 3]]
        values1_data = [[1.0], [2.0]]
        values2_data = [[1.0], [2.0]]
        shape = [2, 4, 2]

        sp_a = sparse.sparse_coo_tensor(indices_data,
                                        values1_data,
                                        shape,
                                        stop_gradient=False)
        sp_b = sparse.sparse_coo_tensor(indices_data,
                                        values2_data,
                                        shape,
                                        stop_gradient=False)

        values1 = paddle.to_tensor(values1_data, stop_gradient=False)
        values2 = paddle.to_tensor(values2_data, stop_gradient=False)

        # c.values() = a.values() + b.values()
        sp_c = sparse.add(sp_a, sp_b)
        sp_c.backward()
        ref_c = values1 + values2
        ref_c.backward()
        np.testing.assert_allclose(sp_c.values().numpy(), ref_c.numpy())
        np.testing.assert_allclose(sp_a.grad.values().numpy(),
                                   values1.grad.numpy())
        np.testing.assert_allclose(sp_b.grad.values().numpy(),
                                   values2.grad.numpy())


if __name__ == "__main__":
    paddle.device.set_device('cpu')
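The committed entry point pins the device to 'cpu'; a hedged variant that instead exercises the GPU kernels registered in this commit, assuming a CUDA-enabled Paddle build and that unittest is imported by the test module:

if __name__ == "__main__":
    # Dispatch paddle.incubate.sparse.add to the new add_coo_coo /
    # add_coo_coo_grad GPU kernels (assumes a CUDA build of Paddle).
    paddle.device.set_device('gpu')
    unittest.main()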
