Skip to content

Commit

Permalink
fix gpu kernel for numel Op (#27085) (#27130)
Browse files Browse the repository at this point in the history
  • Loading branch information
wangchaochaohu authored Sep 7, 2020
1 parent 1e02d26 commit eed05e1
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 5 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/operators/size_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ REGISTER_OPERATOR(
size, ops::SizeOp, ops::SizeOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(size, ops::SizeKernel<int>, ops::SizeKernel<int32_t>,
REGISTER_OP_CPU_KERNEL(size, ops::SizeKernel<int>, ops::SizeKernel<int64_t>,
ops::SizeKernel<paddle::platform::float16>,
ops::SizeKernel<float>, ops::SizeKernel<double>,
ops::SizeKernel<bool>);
2 changes: 1 addition & 1 deletion paddle/fluid/operators/size_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ limitations under the License. */

REGISTER_OP_CUDA_KERNEL(
size, paddle::operators::SizeKernel<int>,
paddle::operators::SizeKernel<int32_t>,
paddle::operators::SizeKernel<int64_t>,
paddle::operators::SizeKernel<paddle::platform::float16>,
paddle::operators::SizeKernel<float>, paddle::operators::SizeKernel<bool>,
paddle::operators::SizeKernel<double>);
14 changes: 12 additions & 2 deletions paddle/fluid/operators/size_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,18 @@ class SizeKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_t = ctx.Input<Tensor>("Input");
auto* out_t = ctx.Output<Tensor>("Out");
auto out_data = out_t->mutable_data<int64_t>(platform::CPUPlace());
out_data[0] = in_t->numel();
auto place = ctx.GetPlace();
auto out_data = out_t->mutable_data<int64_t>(place);
auto cpu_place = platform::CPUPlace();
if (place == cpu_place) {
out_data[0] = in_t->numel();
} else {
Tensor cpu_tensor;
auto cpu_data =
cpu_tensor.mutable_data<int64_t>(out_t->dims(), cpu_place);
cpu_data[0] = in_t->numel();
TensorCopy(cpu_tensor, place, out_t);
}
}
};
} // namespace operators
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,7 @@ def chunk(x, chunks, axis=0, name=None):
x_np = np.random.random([3, 9, 5]).astype("int32")
x = paddle.to_tensor(x_np)
out0, out1, out22 = paddle.chunk(x, chunks=3, axis=1)
out0, out1, out2 = paddle.chunk(x, chunks=3, axis=1)
# out0.shape [3, 3, 5]
# out1.shape [3, 3, 5]
# out2.shape [3, 3, 5]
Expand Down

0 comments on commit eed05e1

Please sign in to comment.