Skip to content

Commit

Permalink
[utest] fix bugs in topk_v2 utest (#6629)
Browse files Browse the repository at this point in the history
* fix bugs in topkv2_host_test test=develop

* fix bugs in topk_v2 kernel test=develop

* fix codestyle issues test=develop
  • Loading branch information
zhenlin-work authored Aug 9, 2021
1 parent 966c0e2 commit da4e3a9
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 31 deletions.
29 changes: 15 additions & 14 deletions lite/kernels/host/topk_v2_compute.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -44,19 +44,20 @@ void TopkV2Compute::Run() {
int inner_size = x_dims.count(axis + 1, dim_size);
int sum_size = axis_size * inner_size;
int out_sum_size = k * inner_size;
for (int n = 0; n < outer_size; n++) {
const float* in_data = x_data + n * sum_size;
float* out_data = out_val + n * out_sum_size;
int64_t* out_ind_data = out_ind + n * out_sum_size;
for (int i = 0; i < inner_size; i++) {
std::vector<std::pair<float, int>> vec;
for (int j = 0; j < axis_size; j++) {
vec.push_back(std::make_pair(in_data[j * inner_size + i], j));
}
std::partial_sort(vec.begin(), vec.begin() + k, vec.end(), comp_func);
for (int j = 0; j < k; j++) {
out_data[j * inner_size + i] = vec[j].first;
out_ind_data[j * inner_size + i] = vec[j].second;

for (int i = 0; i < outer_size; i++) {
int glb_in_off = i * sum_size;
int glb_out_off = i * out_sum_size;
std::vector<std::pair<float, int>> vec;
for (int j = 0; j < axis_size; j++) {
vec.push_back(std::make_pair(x_data[glb_in_off + j * inner_size], j));
}
std::partial_sort(vec.begin(), vec.begin() + k, vec.end(), comp_func);
for (int j = 0; j < k; j++) {
for (int k = 0; k < inner_size; k++) {
int cur_off = glb_in_off + vec[j].second * inner_size + k;
out_val[glb_out_off + j * inner_size + k] = x_data[cur_off];
out_ind[glb_out_off + j * inner_size + k] = vec[j].second;
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions lite/tests/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ lite_cc_test(test_kernel_sequence_softmax_compute SRCS sequence_softmax_compute_
lite_cc_test(test_kernel_compare_compute SRCS compare_compute_test.cc DEPS ${test_kernel_deps})
lite_cc_test(test_kernel_logical_compute SRCS logical_compute_test.cc DEPS ${test_kernel_deps})
lite_cc_test(test_kernel_topk_compute SRCS topk_compute_test.cc DEPS ${test_kernel_deps})
lite_cc_test(test_kernel_topk_v2_compute SRCS topk_v2_compute_test.cc DEPS ${test_kernel_deps})

lite_cc_test(test_kernel_increment_compute SRCS increment_compute_test.cc DEPS ${test_kernel_deps})
lite_cc_test(test_kernel_write_to_array_compute SRCS write_to_array_compute_test.cc DEPS ${test_kernel_deps})
lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS ${test_kernel_deps})
Expand Down
56 changes: 39 additions & 17 deletions lite/tests/kernels/topk_v2_compute_test.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -48,36 +48,41 @@ class TopkV2ComputeTester : public arena::TestCase {
void RunBaseline(Scope* scope) override {
auto* out_val = scope->NewTensor(out_);
auto* out_ind = scope->NewTensor(indices_);

DDim out_dims = x_dims_;
if (axis_ < 0) {
axis_ += x_dims_.size();
}
out_dims[axis_] = k_;

out_val->Resize(out_dims);
out_ind->Resize(out_dims);
auto* out_val_data = out_val->template mutable_data<T1>();
auto* out_ind_data = out_ind->template mutable_data<T2>();

auto* x = scope->FindTensor(x_);
const auto* x_data = x->template data<T1>();

int inner_size = x_dims_.count(axis_ + 1, x_dims_.size());
int axis_size = x_dims_[axis_];
int outer_size = x_dims_.count(0, axis_);
int out_sum_size = k * inner_size;
for (int n = 0; n < outer_size; n++) {
const float* in_data = x_data + n * sum_size;
float* out_val_data1 = out_val_data + n * out_sum_size;
int64_t* out_ind_data1 = out_ind_data + n * out_sum_size;
for (int i = 0; i < inner_size; i++) {
std::vector<std::pair<float, int>> vec;
for (int j = 0; j < axis_size; j++) {
vec.push_back(std::make_pair(in_data[j * outer_size + i], j));
}
std::partial_sort(
vec.begin(), vec.begin() + k_, vec.end(), comp_func<T1, T2>);
for (int j = 0; j < k_; j++) {
out_val_data1[j * outer_size + i] = vec[j].first;
out_ind_data1[j * outer_size + i] = vec[j].second;
int out_sum_size = k_ * inner_size;
int sum_size = axis_size * inner_size;

for (int i = 0; i < outer_size; i++) {
int glb_in_off = i * sum_size;
int glb_out_off = i * out_sum_size;
std::vector<std::pair<float, int>> vec;
for (int j = 0; j < axis_size; j++) {
vec.push_back(std::make_pair(x_data[glb_in_off + j * inner_size], j));
}
std::partial_sort(
vec.begin(), vec.begin() + k_, vec.end(), comp_func<T1, T2>);
for (int j = 0; j < k_; j++) {
for (int k = 0; k < inner_size; k++) {
int cur_off = glb_in_off + vec[j].second * inner_size + k;
out_val_data[glb_out_off + j * inner_size + k] = x_data[cur_off];
out_ind_data[glb_out_off + j * inner_size + k] = vec[j].second;
}
}
}
Expand All @@ -101,10 +106,27 @@ class TopkV2ComputeTester : public arena::TestCase {

template <typename T1, typename T2>
void test_topk_v2(Place place, float abs_error) {
int caseNum = 0;
for (auto x_shape :
std::vector<std::vector<int64_t>>{{2, 3, 4, 5}, {3, 4, 5}, {4, 5}}) {
for (int axis : {-1, -2}) {
for (int k : {2, 5}) {
std::cout << "start case " << caseNum++ << ":" << std::endl;
auto axis_valid = ((axis >= (-1 * (int)x_shape.size())) &&
(axis < (int)x_shape.size()));
if (!axis_valid) {
LOG(INFO) << "the axis of topk_v2 must be [" << (-1 * x_shape.size())
<< ", " << x_shape.size() << "but you set axis is" << axis;
continue;
}
if (axis < 0) {
axis += x_shape.size();
}
if (x_shape[axis] < k) {
LOG(INFO) << "input of topk_v2 op must have >=" << k
<< " columns in axis of " << x_shape[axis];
continue;
}
std::unique_ptr<arena::TestCase> tester(new TopkV2ComputeTester<T1, T2>(
place, "def", DDim(x_shape), axis, k));
arena::Arena arena(std::move(tester), place, abs_error);
Expand All @@ -116,9 +138,9 @@ void test_topk_v2(Place place, float abs_error) {

TEST(Topk, precision) {
Place place;
float abs_error = 2e-5;
#if defined(LITE_WITH_ARM)
place = TARGET(kHost);
float abs_error = 2e-5;
test_topk_v2<float, int64_t>(place, abs_error);
#else
return;
Expand Down

0 comments on commit da4e3a9

Please sign in to comment.