Skip to content

Commit

Permalink
重构 TopK 类中的排序逻辑,使用结构体替代 Lambda 表达式以提高兼容性
Browse files Browse the repository at this point in the history
  • Loading branch information
Baiyuetribe committed Dec 21, 2024
1 parent 13dacbb commit 7f732af
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions src/layer/topk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,36 +49,41 @@ int TopK::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons
vec.push_back(std::make_pair(ptr[i], i));
}

if (largest == 1)
// [](const std::pair<float, int>& a, const std::pair<float, int>& b) {return a.first > b.first;}); // fix Lambda with lower version of C++
struct CompareGreater
{
std::partial_sort(vec.begin(), vec.begin() + k_, vec.end(),
[](const std::pair<float, int>& a, const std::pair<float, int>& b) {
bool operator()(const std::pair<float, int>& a, const std::pair<float, int>& b) const
{
return a.first > b.first;
});
}
};

struct CompareLess
{
bool operator()(const std::pair<float, int>& a, const std::pair<float, int>& b) const
{
return a.first < b.first;
}
};

if (largest == 1)
{
std::partial_sort(vec.begin(), vec.begin() + k_, vec.end(), CompareGreater());
}
else
{
std::partial_sort(vec.begin(), vec.begin() + k_, vec.end(),
[](const std::pair<float, int>& a, const std::pair<float, int>& b) {
return a.first < b.first;
});
std::partial_sort(vec.begin(), vec.begin() + k_, vec.end(), CompareLess());
}

if (sorted)
{
if (largest == 1)
{
std::sort(vec.begin(), vec.begin() + k_,
[](const std::pair<float, int>& a, const std::pair<float, int>& b) {
return a.first > b.first;
});
std::sort(vec.begin(), vec.begin() + k_, CompareGreater());
}
else
{
std::sort(vec.begin(), vec.begin() + k_,
[](const std::pair<float, int>& a, const std::pair<float, int>& b) {
return a.first < b.first;
});
std::sort(vec.begin(), vec.begin() + k_, CompareLess());
}
}

Expand Down

0 comments on commit 7f732af

Please sign in to comment.