-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Hackathon 3rd No.22 ] add paddle.incubate.sparse.reshape #46694
Changes from 66 commits
b2d2ea5
5277531
ab3e871
a8a4960
8eb27b8
234122f
64f98b0
0f4660d
4761494
aa72cc7
c957c8d
fb434fd
281a3d4
64a3503
84f51db
90bfea3
7e80110
8ad2f55
21d3538
687778a
11ee8c3
de6d903
497e27a
8b09f24
4420f4f
953b88e
d74bda9
17ec7c3
ca08918
cdd78a1
471c648
5f33291
f1a9d9b
6aa3061
499b78a
2e91335
cc21e67
1d3ed6d
b4e6d2f
b5d6dbc
46247a7
0661bd2
c01cb39
3df911c
eaca3a2
d80cf1a
1da6d60
bbf4b48
dd2961e
6db2391
b01ebb9
dd30c2e
64d7de2
01ec505
118c429
4a510d2
f011729
9db2c6a
659c034
213145b
0d77ea7
eb33c53
67cfa50
32cd23c
a96baa6
8c7f144
518038c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -489,3 +489,14 @@ | |
func : sync_batch_norm_coo{sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense} | ||
data_type : x | ||
backward : sync_batch_norm_grad | ||
|
||
- op : reshape | ||
args : (Tensor x, IntArray shape) | ||
output : Tensor(out) | ||
infer_meta : | ||
func : ReshapeInferMeta | ||
kernel : | ||
func : reshape_coo{sparse_coo -> sparse_coo}, | ||
reshape_csr{sparse_csr -> sparse_csr} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the design of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you mean I need to modify the RFC to make the content of the RFC agree with my actual implementation ? Ok, I will do it. |
||
layout : x | ||
backward : reshape_grad |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/sparse/unary_grad_kernel.h" | ||
#include "paddle/phi/kernels/sparse/unary_kernel.h" | ||
|
||
#include "paddle/phi/backends/cpu/cpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/sparse/empty_kernel.h" | ||
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h" | ||
|
||
namespace phi { | ||
namespace sparse { | ||
|
||
template <typename T, typename Context> | ||
void ReshapeCooGradKernel(const Context& dev_ctx, | ||
const SparseCooTensor& x, | ||
const SparseCooTensor& dout, | ||
SparseCooTensor* dx) { | ||
EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx); | ||
phi::IntArray x_shape(phi::vectorize(x.dims())); | ||
ReshapeCooKernel<T, Context>(dev_ctx, dout, x_shape, dx); | ||
} | ||
|
||
template <typename T, typename Context> | ||
void ReshapeCsrGradKernel(const Context& dev_ctx, | ||
const SparseCsrTensor& x, | ||
const SparseCsrTensor& dout, | ||
SparseCsrTensor* dx) { | ||
EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx); | ||
phi::IntArray x_shape(phi::vectorize(x.dims())); | ||
ReshapeCsrKernel<T, Context>(dev_ctx, dout, x_shape, dx); | ||
} | ||
|
||
} // namespace sparse | ||
} // namespace phi | ||
|
||
PD_REGISTER_KERNEL(reshape_coo_grad, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::sparse::ReshapeCooGradKernel, | ||
float, | ||
double, | ||
int8_t, | ||
uint8_t, | ||
int16_t, | ||
int, | ||
int64_t, | ||
bool) {} | ||
|
||
PD_REGISTER_KERNEL(reshape_csr_grad, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::sparse::ReshapeCsrGradKernel, | ||
float, | ||
double, | ||
int8_t, | ||
uint8_t, | ||
int16_t, | ||
int, | ||
int64_t, | ||
bool) {} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/sparse/unary_kernel.h" | ||
|
||
#include "paddle/phi/core/ddim.h" | ||
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h" | ||
|
||
#include "paddle/phi/backends/cpu/cpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/funcs/eigen/common.h" | ||
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" | ||
#include "paddle/phi/kernels/sparse/empty_kernel.h" | ||
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h" | ||
#include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h" | ||
|
||
namespace phi { | ||
namespace sparse { | ||
|
||
template <typename T, typename Context> | ||
void ReshapeCooKernel(const Context& dev_ctx, | ||
const SparseCooTensor& x, | ||
const phi::IntArray& shape, | ||
SparseCooTensor* out) { | ||
// TODO(OccupyMars2025): Currently, reshape is only applicable to sparse dims | ||
int64_t x_nnz = x.nnz(); | ||
|
||
// Use DDim::reshape to handle -1 and 0 in the argument "shape" | ||
std::vector<int> new_shape(shape.GetData().begin(), shape.GetData().end()); | ||
phi::DDim out_dims = x.dims().reshape(new_shape); | ||
// get sparse part dimensions of x and out | ||
std::vector<int64_t> x_sparse_part_dims; | ||
std::vector<int64_t> out_sparse_part_dims; | ||
for (int i = 0; i < x.sparse_dim(); ++i) { | ||
x_sparse_part_dims.push_back(x.dims()[i]); | ||
} | ||
for (int i = 0; i < out_dims.size() - x.dense_dim(); ++i) { | ||
out_sparse_part_dims.push_back(out_dims[i]); | ||
} | ||
DenseTensor out_indices = Empty<int64_t, Context>( | ||
dev_ctx, {static_cast<int64_t>(out_sparse_part_dims.size()), x_nnz}); | ||
DenseTensor out_values(x.values()); | ||
out->SetMember(out_indices, out_values, out_dims, x.coalesced()); | ||
|
||
// compute values of indices | ||
const DenseTensor& x_indices = x.indices(); | ||
const auto* x_indices_data = x_indices.data<int64_t>(); | ||
auto* out_indices_data = out_indices.data<int64_t>(); | ||
|
||
const phi::DDim& x_sparse_part_strides = | ||
phi::stride(phi::make_ddim(x_sparse_part_dims)); | ||
const phi::DDim& out_sparse_part_strides = | ||
phi::stride(phi::make_ddim(out_sparse_part_dims)); | ||
int64_t location = 0; | ||
for (int64_t j = 0; j < x_nnz; ++j) { | ||
location = 0; | ||
for (int i = 0; i < x.sparse_dim(); ++i) { | ||
location += x_indices_data[i * x_nnz + j] * x_sparse_part_strides[i]; | ||
} | ||
for (size_t i = 0; i < out_sparse_part_dims.size(); ++i) { | ||
out_indices_data[i * x_nnz + j] = location / out_sparse_part_strides[i]; | ||
location %= out_sparse_part_strides[i]; | ||
} | ||
} | ||
} | ||
|
||
template <typename T, typename Context> | ||
void ReshapeCsrKernel(const Context& dev_ctx, | ||
const SparseCsrTensor& x, | ||
const phi::IntArray& shape, | ||
SparseCsrTensor* out) { | ||
// transform csr format to coo format, and then use coo kernel | ||
const SparseCooTensor x_coo = CsrToCoo<T, Context>(dev_ctx, x); | ||
SparseCooTensor out_coo; | ||
ReshapeCooKernel<T, Context>(dev_ctx, x_coo, shape, &out_coo); | ||
CooToCsrKernel<T, Context>(dev_ctx, out_coo, out); | ||
} | ||
|
||
} // namespace sparse | ||
} // namespace phi | ||
|
||
PD_REGISTER_KERNEL(reshape_coo, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::sparse::ReshapeCooKernel, | ||
float, | ||
double, | ||
int8_t, | ||
uint8_t, | ||
int16_t, | ||
int, | ||
int64_t, | ||
bool) {} | ||
|
||
PD_REGISTER_KERNEL(reshape_csr, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::sparse::ReshapeCsrKernel, | ||
float, | ||
double, | ||
int8_t, | ||
uint8_t, | ||
int16_t, | ||
int, | ||
int64_t, | ||
bool) {} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/sparse/unary_grad_kernel.h" | ||
#include "paddle/phi/kernels/sparse/unary_kernel.h" | ||
|
||
#include "paddle/phi/backends/gpu/gpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/sparse/empty_kernel.h" | ||
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h" | ||
|
||
namespace phi { | ||
namespace sparse { | ||
|
||
// just copy from paddle\phi\kernels\sparse\cpu\reshape_grad_kernel.cc | ||
template <typename T, typename Context> | ||
void ReshapeCooGradKernel(const Context& dev_ctx, | ||
const SparseCooTensor& x, | ||
const SparseCooTensor& dout, | ||
SparseCooTensor* dx) { | ||
EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx); | ||
phi::IntArray x_shape(phi::vectorize(x.dims())); | ||
ReshapeCooKernel<T, Context>(dev_ctx, dout, x_shape, dx); | ||
} | ||
|
||
// just copy from paddle\phi\kernels\sparse\cpu\reshape_grad_kernel.cc | ||
template <typename T, typename Context> | ||
void ReshapeCsrGradKernel(const Context& dev_ctx, | ||
const SparseCsrTensor& x, | ||
const SparseCsrTensor& dout, | ||
SparseCsrTensor* dx) { | ||
EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx); | ||
phi::IntArray x_shape(phi::vectorize(x.dims())); | ||
ReshapeCsrKernel<T, Context>(dev_ctx, dout, x_shape, dx); | ||
} | ||
|
||
} // namespace sparse | ||
} // namespace phi | ||
|
||
PD_REGISTER_KERNEL(reshape_coo_grad, | ||
GPU, | ||
ALL_LAYOUT, | ||
phi::sparse::ReshapeCooGradKernel, | ||
phi::dtype::float16, | ||
float, | ||
double, | ||
int8_t, | ||
uint8_t, | ||
int16_t, | ||
int, | ||
int64_t, | ||
bool) {} | ||
|
||
PD_REGISTER_KERNEL(reshape_csr_grad, | ||
GPU, | ||
ALL_LAYOUT, | ||
phi::sparse::ReshapeCsrGradKernel, | ||
phi::dtype::float16, | ||
float, | ||
double, | ||
int8_t, | ||
uint8_t, | ||
int16_t, | ||
int, | ||
int64_t, | ||
bool) {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"func : ReshapeInferMeta" , I have used the infermeta function for dense tensor, so there is no need to write new ones.