Skip to content

Commit

Permalink
[Phi]Move topk kernel to phi (PaddlePaddle#40064)
Browse files Browse the repository at this point in the history
* first commit

* cpu kernel

* first version

* fix compile error

* fix compile error

* delete v2

* fix

* fix

* add alias

* fix

* fix

* fix

* fix error

* fix

* fix

* fix

* fix format
  • Loading branch information
ZzSean authored Mar 10, 2022
1 parent 2b6da4d commit 329b095
Show file tree
Hide file tree
Showing 18 changed files with 897 additions and 650 deletions.
1 change: 0 additions & 1 deletion paddle/fluid/operators/kthvalue_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/kthvalue_op.h"
#include "paddle/fluid/operators/top_k_function_cuda.h"
#include "paddle/fluid/operators/top_k_v2_op.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
Expand Down
1 change: 0 additions & 1 deletion paddle/fluid/operators/mode_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mode_op.h"
#include "paddle/fluid/operators/top_k_function_cuda.h"
#include "paddle/fluid/operators/top_k_v2_op.h"

namespace paddle {
namespace operators {
Expand Down
13 changes: 13 additions & 0 deletions paddle/fluid/operators/top_k_function_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,19 @@ namespace operators {

using Tensor = framework::Tensor;

inline void GetDims(const phi::DDim& dim, int axis, int* pre, int* n,
int* post) {
*pre = 1;
*post = 1;
*n = dim[axis];
for (int i = 0; i < axis; ++i) {
(*pre) *= dim[i];
}
for (int i = axis + 1; i < dim.size(); ++i) {
(*post) *= dim[i];
}
}

struct SegmentOffsetIter {
EIGEN_DEVICE_FUNC
explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}
Expand Down
15 changes: 2 additions & 13 deletions paddle/fluid/operators/top_k_v2_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/top_k_v2_op.h"
#include <memory>

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

Expand Down Expand Up @@ -173,15 +174,3 @@ REGISTER_OPERATOR(top_k_v2, ops::TopkV2Op, ops::TopkV2OpMaker,
ops::TopkV2GradOpMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(top_k_v2_grad, ops::TopkV2OpGrad);

REGISTER_OP_CPU_KERNEL(top_k_v2,
ops::TopkV2Kernel<paddle::platform::CPUPlace, float>,
ops::TopkV2Kernel<paddle::platform::CPUPlace, double>,
ops::TopkV2Kernel<paddle::platform::CPUPlace, int32_t>,
ops::TopkV2Kernel<paddle::platform::CPUPlace, int64_t>)

REGISTER_OP_CPU_KERNEL(
top_k_v2_grad, ops::TopkV2GradKernel<paddle::platform::CPUPlace, float>,
ops::TopkV2GradKernel<paddle::platform::CPUPlace, double>,
ops::TopkV2GradKernel<paddle::platform::CPUPlace, int32_t>,
ops::TopkV2GradKernel<paddle::platform::CPUPlace, int64_t>)
296 changes: 0 additions & 296 deletions paddle/fluid/operators/top_k_v2_op.cu

This file was deleted.

Loading

0 comments on commit 329b095

Please sign in to comment.