Skip to content

Commit

Permalink
[Sparse] add_coo_dense (#46322)
Browse files Browse the repository at this point in the history
* for add_bias
  • Loading branch information
zhangkaihuo authored Sep 21, 2022
1 parent 0f9dde4 commit 55d3198
Show file tree
Hide file tree
Showing 12 changed files with 155 additions and 16 deletions.
3 changes: 3 additions & 0 deletions paddle/phi/api/lib/api_gen_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,9 @@ phi::SelectedRows* SetSelectedRowsKernelOutput(Tensor* out) {
}

phi::TensorBase* SetSparseKernelOutput(Tensor* out, TensorType type) {
if (!out) {
return nullptr;
}
if (!out->initialized()) {
if (type == TensorType::SPARSE_COO) {
auto sparse_tensor = std::make_shared<phi::SparseCooTensor>(
Expand Down
15 changes: 8 additions & 7 deletions paddle/phi/api/yaml/sparse_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@
args : (Tensor x, Tensor y, Tensor out_grad)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : add_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo},
add_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr}
add_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr},
add_coo_dense_grad{sparse_coo, dense, sparse_coo -> sparse_coo, dense}

- backward_op : addmm_grad
forward : addmm(Tensor input, Tensor x, Tensor y, float alpha=1.0, float beta=1.0) -> Tensor(out)
Expand Down Expand Up @@ -104,7 +105,7 @@
args : (Tensor x, Tensor out_grad, DataType value_dtype)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
func : UnchangedInferMeta
param: [x]
kernel :
func : cast_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
Expand All @@ -126,7 +127,7 @@
args : (Tensor x, Tensor y, Tensor out, Tensor out_grad)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : divide_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo},
Expand Down Expand Up @@ -209,7 +210,7 @@
args : (Tensor x, Tensor y, Tensor out_grad)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : multiply_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo},
Expand Down Expand Up @@ -337,7 +338,7 @@
args : (Tensor x, Tensor y, Tensor out_grad)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : subtract_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo},
Expand Down Expand Up @@ -399,7 +400,7 @@
args: (Tensor query, Tensor key, Tensor value, Tensor softmax, Tensor out_grad)
output : Tensor(query_grad), Tensor(key_grad), Tensor(value_grad)
infer_meta :
func : sparse::FusedAttentionGradInferMeta
func : sparse::FusedAttentionGradInferMeta
kernel :
func : fused_attention_csr_grad{dense, dense, dense, sparse_csr, dense -> dense, dense, dense}
layout : softmax
Expand Down
5 changes: 3 additions & 2 deletions paddle/phi/api/yaml/sparse_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
func : ElementwiseInferMeta
kernel :
func : add_coo_coo{sparse_coo, sparse_coo -> sparse_coo},
add_csr_csr{sparse_csr, sparse_csr -> sparse_csr}
add_coo_dense{sparse_coo, dense -> sparse_coo},
layout : x
backward : add_grad

Expand Down Expand Up @@ -114,7 +115,7 @@
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
func : ElementwiseInferMeta
kernel :
func : divide_coo_coo{sparse_coo, sparse_coo -> sparse_coo},
divide_csr_csr{sparse_csr, sparse_csr -> sparse_csr}
Expand Down
11 changes: 11 additions & 0 deletions paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -415,3 +415,14 @@ PD_REGISTER_KERNEL(divide_coo_coo_grad,
kernel->InputAt(2).SetDataLayout(phi::DataLayout::SPARSE_COO);
kernel->InputAt(3).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

// Registers the CPU kernel named add_coo_dense_grad, implemented by
// phi::sparse::ElementWiseAddDenseGradKernel, for the element types
// float, double, int and int64_t.
PD_REGISTER_KERNEL(add_coo_dense_grad,
CPU,
ALL_LAYOUT,
phi::sparse::ElementWiseAddDenseGradKernel,
float,
double,
int,
int64_t) {
// Only input 0 is tagged with the SPARSE_COO data layout here; the layouts
// of the remaining inputs are not set in this registration body.
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
26 changes: 26 additions & 0 deletions paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,21 @@ void ElementWiseCooKernelImpl(const Context& dev_ctx,
"shape = [%s], Y's shape = [%s].",
x.dims(),
y.dims()));

// temporary policy: for broadcast add
// TODO(zhangkaihuo): implement a correct function
const bool is_add = std::is_same<Functor, funcs::AddFunctor<T>>::value;
if (is_add && x.indices().numel() == y.indices().numel()) {
int compare_indices = memcmp(x.indices().data<IntT>(),
y.indices().data<IntT>(),
sizeof(IntT) * x.indices().numel());
if (compare_indices == 0) {
EmptyLikeCooKernel<T, Context>(dev_ctx, x, out);
phi::AddKernel<T, Context>(
dev_ctx, x.values(), y.values(), out->mutable_values());
return;
}
}
int64_t element_size = 1;
for (auto j = 1; j < x.values().dims().size(); ++j) {
element_size *= x.values().dims()[j];
Expand Down Expand Up @@ -435,3 +450,14 @@ PD_REGISTER_KERNEL(divide_coo_coo,
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

// Registers the CPU kernel named add_coo_dense, implemented by
// phi::sparse::ElementWiseAddDenseKernel, for the element types
// float, double, int and int64_t.
PD_REGISTER_KERNEL(add_coo_dense,
CPU,
ALL_LAYOUT,
phi::sparse::ElementWiseAddDenseKernel,
float,
double,
int,
int64_t) {
// Only input 0 is tagged with the SPARSE_COO data layout here; the layout
// of input 1 is not set in this registration body.
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
25 changes: 25 additions & 0 deletions paddle/phi/kernels/sparse/elementwise_grad_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ limitations under the License. */

#pragma once

#include "paddle/phi/kernels/elementwise_add_grad_kernel.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
Expand Down Expand Up @@ -119,5 +122,27 @@ std::vector<SparseCooTensor> ElementWiseDivideCooGrad(
return std::vector<SparseCooTensor>{dx, dy};
}

template <typename T, typename Context>
void ElementWiseAddDenseGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& y,
const SparseCooTensor& dout,
SparseCooTensor* dx,
DenseTensor* dy) {
DenseTensor* x_values_grad = nullptr;
DenseTensor* y_grad = nullptr;
if (dx) {
EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx);
x_values_grad = dx->mutable_values();
}

if (dy) {
*dy = phi::EmptyLike<T>(dev_ctx, y);
y_grad = dy;
}
phi::AddGradKernel<T, Context>(
dev_ctx, x.values(), y, dout.values(), -1, x_values_grad, y_grad);
}

} // namespace sparse
} // namespace phi
20 changes: 20 additions & 0 deletions paddle/phi/kernels/sparse/elementwise_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ limitations under the License. */

#pragma once

#include "paddle/phi/kernels/elementwise_add_kernel.h"
#include "paddle/phi/kernels/sparse/elementwise_kernel.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
Expand Down Expand Up @@ -78,5 +82,21 @@ DEFINE_ELEMENTWISE_KERNEL_FUNC(Subtract)
DEFINE_ELEMENTWISE_KERNEL_FUNC(Multiply)
DEFINE_ELEMENTWISE_KERNEL_FUNC(Divide)

// Forward kernel for "sparse COO x + dense y", producing a sparse COO out.
//
// Only one case is handled: y is 1-D and its length equals the last
// dimension of x. In that case out is prepared from x with
// EmptyLikeCooKernel, phi::AddKernel adds y to x.values() and writes the
// result into out's values, and x's indices dict is copied onto out.
//
// Any other shape of y raises an Unimplemented error. This template is
// shared by every backend that registers it, so the error text describes
// the unsupported shapes instead of naming a specific device.
template <typename T, typename Context>
void ElementWiseAddDenseKernel(const Context& dev_ctx,
                               const SparseCooTensor& x,
                               const DenseTensor& y,
                               SparseCooTensor* out) {
  // TODO(zhangkaihuo): to support universal sparse + dense
  if (y.dims().size() == 1 && y.dims()[0] == x.dims()[x.dims().size() - 1]) {
    EmptyLikeCooKernel<T, Context>(dev_ctx, x, out);
    phi::AddKernel<T, Context>(dev_ctx, x.values(), y, out->mutable_values());
    out->SetIndicesDict(x.GetIndicesDict());
  } else {
    PADDLE_THROW(errors::Unimplemented(
        "Sparse + Dense add only supports a 1-D dense Y whose length equals "
        "the last dimension of X, but received X's shape = [%s], "
        "Y's shape = [%s].",
        x.dims(),
        y.dims()));
  }
}

} // namespace sparse
} // namespace phi
15 changes: 15 additions & 0 deletions paddle/phi/kernels/sparse/gpu/elementwise_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ limitations under the License. */
#include "paddle/phi/kernels/sparse/elementwise_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_grad_base.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"

namespace phi {
Expand Down Expand Up @@ -54,3 +57,15 @@ PD_REGISTER_KERNEL(add_coo_coo_grad,
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

// Registers the GPU kernel named add_coo_dense_grad, implemented by
// phi::sparse::ElementWiseAddDenseGradKernel, for the element types
// float, double, int, int64_t and phi::dtype::float16.
PD_REGISTER_KERNEL(add_coo_dense_grad,
GPU,
ALL_LAYOUT,
phi::sparse::ElementWiseAddDenseGradKernel,
float,
double,
int,
int64_t,
phi::dtype::float16) {
// Only input 0 is tagged with the SPARSE_COO data layout here; the layouts
// of the remaining inputs are not set in this registration body.
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
14 changes: 14 additions & 0 deletions paddle/phi/kernels/sparse/gpu/elementwise_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ void ElementWiseAddCooGPUKernel(const GPUContext& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& y,
SparseCooTensor* out) {
// TODO(zhangkaihuo): to support universal sparse + sparse
const auto& x_indices = x.indices();
const auto& y_indices = y.indices();
PADDLE_ENFORCE_EQ(
Expand All @@ -57,6 +58,7 @@ void ElementWiseAddCooGPUKernel(const GPUContext& dev_ctx,
EmptyLikeCooKernel<T, GPUContext>(dev_ctx, x, out);
phi::AddKernel<T, GPUContext>(
dev_ctx, x.values(), y.values(), out->mutable_values());
out->SetIndicesDict(x.GetIndicesDict());
}

template <typename T, typename Context>
Expand Down Expand Up @@ -86,3 +88,15 @@ PD_REGISTER_KERNEL(add_coo_coo,
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

// Registers the GPU kernel named add_coo_dense, implemented by
// phi::sparse::ElementWiseAddDenseKernel, for the element types
// float, double, int, int64_t and phi::dtype::float16.
PD_REGISTER_KERNEL(add_coo_dense,
GPU,
ALL_LAYOUT,
phi::sparse::ElementWiseAddDenseKernel,
float,
double,
int,
int64_t,
phi::dtype::float16) {
// Only input 0 is tagged with the SPARSE_COO data layout here; the layout
// of input 1 is not set in this registration body.
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
26 changes: 26 additions & 0 deletions python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,32 @@ def test_add_same_indices(self):
np.testing.assert_allclose(sp_b.grad.values().numpy(),
values2.grad.numpy())

def test_add_bias(self):
    """Sparse COO tensor + 1-D dense bias must match dense ``values + bias``.

    Checks the forward values, the gradient w.r.t. the sparse tensor's
    values, and the gradient w.r.t. the bias against a dense reference.
    """
    coo_indices = [[0, 1], [0, 3]]
    coo_values = [[1.0, 1.0], [2.0, 2.0]]
    dense_shape = [2, 4, 2]

    sp_a = sparse.sparse_coo_tensor(coo_indices,
                                    coo_values,
                                    dense_shape,
                                    stop_gradient=False)

    bias_data = [1.0, 2.0]

    # values1 mirrors sp_a's values; values2 is the bias fed to the sparse
    # op, values3 is a separate tensor with the same data for the reference.
    values1 = paddle.to_tensor(coo_values, stop_gradient=False)
    values2 = paddle.to_tensor(bias_data, stop_gradient=False)
    values3 = paddle.to_tensor(bias_data, stop_gradient=False)

    # Sparse path: c.values() = a.values() + b
    sp_c = sparse.add(sp_a, values2)
    sp_c.backward()

    # Dense reference path.
    ref_c = values1 + values3
    ref_c.backward()

    # (actual, expected) pairs: forward output, grad of x, grad of bias.
    comparisons = (
        (sp_c.values().numpy(), ref_c.numpy()),
        (sp_a.grad.values().numpy(), values1.grad.numpy()),
        (values2.grad.numpy(), values3.grad.numpy()),
    )
    for actual, expected in comparisons:
        np.testing.assert_allclose(actual, expected)


if __name__ == "__main__":
paddle.device.set_device('cpu')
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/incubate/sparse/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def add(x, y, name=None):
"""
if y.dtype != x.dtype:
y = _C_ops.sparse_cast(y, None, x.dtype)
y = cast(y, None, x.dtype)
return _C_ops.sparse_add(x, y)


Expand Down
9 changes: 3 additions & 6 deletions python/paddle/incubate/sparse/nn/functional/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from paddle.fluid.layers.utils import convert_to_list
from paddle.fluid.layers.nn import elementwise_add
from ...creation import sparse_coo_tensor
from ...binary import add
from paddle.tensor import arange
from paddle.nn.functional.conv import _update_padding_nd


Expand Down Expand Up @@ -67,12 +69,7 @@ def _conv3d(x,
groups, subm,
key if key is not None else "")
if bias is not None:
values = pre_bias.values()
add_bias = elementwise_add(values, bias, axis=1)
return sparse_coo_tensor(pre_bias.indices(),
add_bias,
shape=pre_bias.shape,
stop_gradient=pre_bias.stop_gradient)
return add(pre_bias, bias)
else:
return pre_bias

Expand Down

0 comments on commit 55d3198

Please sign in to comment.