From 2b81676ab022d4986cfd756c4c5079912b748aa5 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 21 Sep 2022 16:17:09 +0800 Subject: [PATCH] [Sparse] add_coo_dense (#46322) * for add_bias --- paddle/phi/api/lib/api_gen_utils.cc | 3 +++ paddle/phi/api/yaml/sparse_backward.yaml | 15 ++++++----- paddle/phi/api/yaml/sparse_ops.yaml | 5 ++-- .../sparse/cpu/elementwise_grad_kernel.cc | 11 ++++++++ .../kernels/sparse/cpu/elementwise_kernel.cc | 26 +++++++++++++++++++ .../kernels/sparse/elementwise_grad_kernel.h | 25 ++++++++++++++++++ .../phi/kernels/sparse/elementwise_kernel.h | 20 ++++++++++++++ .../sparse/gpu/elementwise_grad_kernel.cu | 15 +++++++++++ .../kernels/sparse/gpu/elementwise_kernel.cu | 14 ++++++++++ .../unittests/test_sparse_elementwise_op.py | 26 +++++++++++++++++++ python/paddle/incubate/sparse/binary.py | 2 +- .../incubate/sparse/nn/functional/conv.py | 9 +++---- 12 files changed, 155 insertions(+), 16 deletions(-) diff --git a/paddle/phi/api/lib/api_gen_utils.cc b/paddle/phi/api/lib/api_gen_utils.cc index 39f9fa93918d7..e1795edf5002c 100644 --- a/paddle/phi/api/lib/api_gen_utils.cc +++ b/paddle/phi/api/lib/api_gen_utils.cc @@ -230,6 +230,9 @@ phi::SelectedRows* SetSelectedRowsKernelOutput(Tensor* out) { } phi::TensorBase* SetSparseKernelOutput(Tensor* out, TensorType type) { + if (!out) { + return nullptr; + } if (!out->initialized()) { if (type == TensorType::SPARSE_COO) { auto sparse_tensor = std::make_shared( diff --git a/paddle/phi/api/yaml/sparse_backward.yaml b/paddle/phi/api/yaml/sparse_backward.yaml index 41816898c3a50..8347ee200e815 100644 --- a/paddle/phi/api/yaml/sparse_backward.yaml +++ b/paddle/phi/api/yaml/sparse_backward.yaml @@ -36,11 +36,12 @@ args : (Tensor x, Tensor y, Tensor out_grad) output : Tensor(x_grad), Tensor(y_grad) infer_meta : - func : GeneralBinaryGradInferMeta + func : GeneralBinaryGradInferMeta param : [x, y] kernel : func : add_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo}, - add_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr} + add_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr}, + add_coo_dense_grad{sparse_coo, dense, sparse_coo -> sparse_coo, dense} - backward_op : addmm_grad forward : addmm(Tensor input, Tensor x, Tensor y, float alpha=1.0, float beta=1.0) -> Tensor(out) @@ -104,7 +105,7 @@ args : (Tensor x, Tensor out_grad, DataType value_dtype) output : Tensor(x_grad) infer_meta : - func : UnchangedInferMeta + func : UnchangedInferMeta param: [x] kernel : func : cast_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, @@ -126,7 +127,7 @@ args : (Tensor x, Tensor y, Tensor out, Tensor out_grad) output : Tensor(x_grad), Tensor(y_grad) infer_meta : - func : GeneralBinaryGradInferMeta + func : GeneralBinaryGradInferMeta param : [x, y] kernel : func : divide_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo}, @@ -209,7 +210,7 @@ args : (Tensor x, Tensor y, Tensor out_grad) output : Tensor(x_grad), Tensor(y_grad) infer_meta : - func : GeneralBinaryGradInferMeta + func : GeneralBinaryGradInferMeta param : [x, y] kernel : func : multiply_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo}, @@ -337,7 +338,7 @@ args : (Tensor x, Tensor y, Tensor out_grad) output : Tensor(x_grad), Tensor(y_grad) infer_meta : - func : GeneralBinaryGradInferMeta + func : GeneralBinaryGradInferMeta param : [x, y] kernel : func : subtract_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo}, @@ -399,7 +400,7 @@ args: (Tensor query, Tensor key, Tensor value, Tensor softmax, Tensor out_grad) output : Tensor(query_grad), Tensor(key_grad), Tensor(value_grad) infer_meta : - func : sparse::FusedAttentionGradInferMeta + func : sparse::FusedAttentionGradInferMeta kernel : func : fused_attention_csr_grad{dense, dense, dense, sparse_csr, dense -> dense, dense, dense} layout : softmax diff --git a/paddle/phi/api/yaml/sparse_ops.yaml b/paddle/phi/api/yaml/sparse_ops.yaml index 043c12615fb7f..a917012b2f791 100644 --- a/paddle/phi/api/yaml/sparse_ops.yaml +++ b/paddle/phi/api/yaml/sparse_ops.yaml @@ -35,10 +35,11 @@ args : (Tensor x, Tensor y) output : Tensor(out) infer_meta : - func : ElementwiseInferMeta + func : ElementwiseInferMeta kernel : func : add_coo_coo{sparse_coo, sparse_coo -> sparse_coo}, add_csr_csr{sparse_csr, sparse_csr -> sparse_csr} + add_coo_dense{sparse_coo, dense -> sparse_coo}, layout : x backward : add_grad @@ -114,7 +115,7 @@ args : (Tensor x, Tensor y) output : Tensor(out) infer_meta : - func : ElementwiseInferMeta + func : ElementwiseInferMeta kernel : func : divide_coo_coo{sparse_coo, sparse_coo -> sparse_coo}, divide_csr_csr{sparse_csr, sparse_csr -> sparse_csr} diff --git a/paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc index 58ed3f2d6b0b6..98afed84d6643 100644 --- a/paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc @@ -415,3 +415,14 @@ PD_REGISTER_KERNEL(divide_coo_coo_grad, kernel->InputAt(2).SetDataLayout(phi::DataLayout::SPARSE_COO); kernel->InputAt(3).SetDataLayout(phi::DataLayout::SPARSE_COO); } + +PD_REGISTER_KERNEL(add_coo_dense_grad, + CPU, + ALL_LAYOUT, + phi::sparse::ElementWiseAddDenseGradKernel, + float, + double, + int, + int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); +} diff --git a/paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc b/paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc index 4e0eb90d7816d..0e46efc0e8673 100644 --- a/paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc @@ -156,6 +156,21 @@ void ElementWiseCooKernelImpl(const Context& dev_ctx, "shape = [%s], Y's shape = [%s].", x.dims(), y.dims())); + + // temporary policy: for broadcast add + // TODO(zhangkaihuo): implement a correct function + const bool is_add = std::is_same>::value; + if (is_add && x.indices().numel() == y.indices().numel()) { + int compare_indices = memcmp(x.indices().data(), + y.indices().data(), + sizeof(IntT) * x.indices().numel()); + if (compare_indices == 0) { + EmptyLikeCooKernel(dev_ctx, x, out); + phi::AddKernel( + dev_ctx, x.values(), y.values(), out->mutable_values()); + return; + } + } int64_t element_size = 1; for (auto j = 1; j < x.values().dims().size(); ++j) { element_size *= x.values().dims()[j]; @@ -435,3 +450,14 @@ PD_REGISTER_KERNEL(divide_coo_coo, kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO); } + +PD_REGISTER_KERNEL(add_coo_dense, + CPU, + ALL_LAYOUT, + phi::sparse::ElementWiseAddDenseKernel, + float, + double, + int, + int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); +} diff --git a/paddle/phi/kernels/sparse/elementwise_grad_kernel.h b/paddle/phi/kernels/sparse/elementwise_grad_kernel.h index 86eb3b4381dc0..f16e2f95d47eb 100644 --- a/paddle/phi/kernels/sparse/elementwise_grad_kernel.h +++ b/paddle/phi/kernels/sparse/elementwise_grad_kernel.h @@ -14,6 +14,9 @@ limitations under the License. */ #pragma once +#include "paddle/phi/kernels/elementwise_add_grad_kernel.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" + #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h" @@ -119,5 +122,27 @@ std::vector ElementWiseDivideCooGrad( return std::vector{dx, dy}; } +template +void ElementWiseAddDenseGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& y, + const SparseCooTensor& dout, + SparseCooTensor* dx, + DenseTensor* dy) { + DenseTensor* x_values_grad = nullptr; + DenseTensor* y_grad = nullptr; + if (dx) { + EmptyLikeCooKernel(dev_ctx, x, dx); + x_values_grad = dx->mutable_values(); + } + + if (dy) { + *dy = phi::EmptyLike(dev_ctx, y); + y_grad = dy; + } + phi::AddGradKernel( + dev_ctx, x.values(), y, dout.values(), -1, x_values_grad, y_grad); +} + } // namespace sparse } // namespace phi diff --git a/paddle/phi/kernels/sparse/elementwise_kernel.h b/paddle/phi/kernels/sparse/elementwise_kernel.h index 59a554348cfea..515644d4fcfce 100644 --- a/paddle/phi/kernels/sparse/elementwise_kernel.h +++ b/paddle/phi/kernels/sparse/elementwise_kernel.h @@ -14,6 +14,10 @@ limitations under the License. */ #pragma once +#include "paddle/phi/kernels/elementwise_add_kernel.h" +#include "paddle/phi/kernels/sparse/elementwise_kernel.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" + #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h" @@ -78,5 +82,21 @@ DEFINE_ELEMENTWISE_KERNEL_FUNC(Subtract) DEFINE_ELEMENTWISE_KERNEL_FUNC(Multiply) DEFINE_ELEMENTWISE_KERNEL_FUNC(Divide) +template +void ElementWiseAddDenseKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& y, + SparseCooTensor* out) { + // TODO(zhangkaiuo): to support universal sparse + dense + if (y.dims().size() == 1 && y.dims()[0] == x.dims()[x.dims().size() - 1]) { + EmptyLikeCooKernel(dev_ctx, x, out); + phi::AddKernel(dev_ctx, x.values(), y, out->mutable_values()); + out->SetIndicesDict(x.GetIndicesDict()); + } else { + PADDLE_THROW( + errors::Unimplemented("Not support Sparse + Dense in GPU mode")); + } +} + } // namespace sparse } // namespace phi diff --git a/paddle/phi/kernels/sparse/gpu/elementwise_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/elementwise_grad_kernel.cu index e434dad588e13..e7f0c9d96e920 100644 --- a/paddle/phi/kernels/sparse/gpu/elementwise_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/elementwise_grad_kernel.cu @@ -15,6 +15,9 @@ limitations under the License. */ #include "paddle/phi/kernels/sparse/elementwise_grad_kernel.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/elementwise_grad_base.h" +#include "paddle/phi/kernels/funcs/reduce_function.h" #include "paddle/phi/kernels/sparse/empty_kernel.h" namespace phi { @@ -54,3 +57,15 @@ PD_REGISTER_KERNEL(add_coo_coo_grad, kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO); } + +PD_REGISTER_KERNEL(add_coo_dense_grad, + GPU, + ALL_LAYOUT, + phi::sparse::ElementWiseAddDenseGradKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); +} diff --git a/paddle/phi/kernels/sparse/gpu/elementwise_kernel.cu b/paddle/phi/kernels/sparse/gpu/elementwise_kernel.cu index 7496f47de8948..47daa1eae19ed 100644 --- a/paddle/phi/kernels/sparse/gpu/elementwise_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/elementwise_kernel.cu @@ -31,6 +31,7 @@ void ElementWiseAddCooGPUKernel(const GPUContext& dev_ctx, const SparseCooTensor& x, const SparseCooTensor& y, SparseCooTensor* out) { + // TODO(zhangkaiuo): to support universal sparse + sparse const auto& x_indices = x.indices(); const auto& y_indices = y.indices(); PADDLE_ENFORCE_EQ( @@ -57,6 +58,7 @@ void ElementWiseAddCooGPUKernel(const GPUContext& dev_ctx, EmptyLikeCooKernel(dev_ctx, x, out); phi::AddKernel( dev_ctx, x.values(), y.values(), out->mutable_values()); + out->SetIndicesDict(x.GetIndicesDict()); } template @@ -86,3 +88,15 @@ PD_REGISTER_KERNEL(add_coo_coo, kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO); } + +PD_REGISTER_KERNEL(add_coo_dense, + GPU, + ALL_LAYOUT, + phi::sparse::ElementWiseAddDenseKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); +} diff --git a/python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py b/python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py index 20f66e5f9a65e..9acad42a9b8ee 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py @@ -163,6 +163,32 @@ def test_add_same_indices(self): np.testing.assert_allclose(sp_b.grad.values().numpy(), values2.grad.numpy()) + def test_add_bias(self): + indices_data = [[0, 1], [0, 3]] + values_data = [[1.0, 1.0], [2.0, 2.0]] + shape = [2, 4, 2] + + sp_a = sparse.sparse_coo_tensor(indices_data, + values_data, + shape, + stop_gradient=False) + + bias_values = [1.0, 2.0] + + values1 = paddle.to_tensor(values_data, stop_gradient=False) + values2 = paddle.to_tensor(bias_values, stop_gradient=False) + values3 = paddle.to_tensor(bias_values, stop_gradient=False) + + #c.values() = a.values() + b + sp_c = sparse.add(sp_a, values2) + sp_c.backward() + ref_c = values1 + values3 + ref_c.backward() + np.testing.assert_allclose(sp_c.values().numpy(), ref_c.numpy()) + np.testing.assert_allclose(sp_a.grad.values().numpy(), + values1.grad.numpy()) + np.testing.assert_allclose(values2.grad.numpy(), values3.grad.numpy()) + if __name__ == "__main__": paddle.device.set_device('cpu') diff --git a/python/paddle/incubate/sparse/binary.py b/python/paddle/incubate/sparse/binary.py index 93ce90c9f021a..626a24c95a072 100644 --- a/python/paddle/incubate/sparse/binary.py +++ b/python/paddle/incubate/sparse/binary.py @@ -253,7 +253,7 @@ def add(x, y, name=None): """ if y.dtype != x.dtype: - y = _C_ops.sparse_cast(y, None, x.dtype) + y = cast(y, None, x.dtype) return _C_ops.sparse_add(x, y) diff --git a/python/paddle/incubate/sparse/nn/functional/conv.py b/python/paddle/incubate/sparse/nn/functional/conv.py index cd3e8e3551f5b..b6e492109577c 100644 --- a/python/paddle/incubate/sparse/nn/functional/conv.py +++ b/python/paddle/incubate/sparse/nn/functional/conv.py @@ -18,6 +18,8 @@ from paddle.fluid.layers.utils import convert_to_list from paddle.fluid.layers.nn import elementwise_add from ...creation import sparse_coo_tensor +from ...binary import add +from paddle.tensor import arange from paddle.nn.functional.conv import _update_padding_nd @@ -67,12 +69,7 @@ def _conv3d(x, groups, subm, key if key is not None else "") if bias is not None: - values = pre_bias.values() - add_bias = elementwise_add(values, bias, axis=1) - return sparse_coo_tensor(pre_bias.indices(), - add_bias, - shape=pre_bias.shape, - stop_gradient=pre_bias.stop_gradient) + return add(pre_bias, bias) else: return pre_bias