From dbbd3398f320ef72ff8ad8787ef71f9f6451debd Mon Sep 17 00:00:00 2001 From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Date: Wed, 28 Dec 2022 11:14:27 +0800 Subject: [PATCH] fix bugs of paddle.multiplex API (#49368) --- paddle/phi/infermeta/multiary.cc | 8 ++++++++ paddle/phi/kernels/cpu/multiplex_kernel.cc | 2 +- paddle/phi/kernels/gpu/multiplex_kernel.cu | 2 +- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 1ab67ede698d9..375b88493a92b 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2190,6 +2190,14 @@ void MultiplexInferMeta(const std::vector& ins, phi::errors::PreconditionNotMet( "All the candidate tensors must have the same size.")); } + + PADDLE_ENFORCE_GE( + in_dim[0], + ids_dim[0], + phi::errors::InvalidArgument("The 2nd-dim of input cannot be smaller " + "than batchSize of the index tensor.")); + + in_dim[0] = ids_dim[0]; out->set_dims(in_dim); out->set_dtype(ins[0]->dtype()); } diff --git a/paddle/phi/kernels/cpu/multiplex_kernel.cc b/paddle/phi/kernels/cpu/multiplex_kernel.cc index 2d9f4c51a981e..4e60448c6c536 100644 --- a/paddle/phi/kernels/cpu/multiplex_kernel.cc +++ b/paddle/phi/kernels/cpu/multiplex_kernel.cc @@ -37,7 +37,7 @@ void MultiplexKernel(const Context& ctx, auto rows = ins[0]->dims()[0]; auto cols = ins[0]->numel() / rows; auto index = ids.data(); - for (auto i = 0; i < rows; i++) { + for (auto i = 0; i < ids.dims()[0]; i++) { int32_t k = index[i]; PADDLE_ENFORCE_GE( k, 0, errors::PreconditionNotMet("index must be nonnegative.")); diff --git a/paddle/phi/kernels/gpu/multiplex_kernel.cu b/paddle/phi/kernels/gpu/multiplex_kernel.cu index 743448a468666..2a86827bcf475 100644 --- a/paddle/phi/kernels/gpu/multiplex_kernel.cu +++ b/paddle/phi/kernels/gpu/multiplex_kernel.cu @@ -41,7 +41,7 @@ void MultiplexKernel(const Context& ctx, paddle::framework::TensorCopySync(ids, phi::CPUPlace(), &index_t_cpu); auto* index = index_t_cpu.data(); auto stream = ctx.stream(); - for (auto i = 0; i < rows; i++) { + for (auto i = 0; i < ids.dims()[0]; i++) { int32_t k = index[i]; PADDLE_ENFORCE_GE( k, 0, errors::PreconditionNotMet("index must be nonnegative."));