Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
fixes a bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Shixiaowei02 committed Feb 22, 2023
1 parent 00dccc3 commit 297d270
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
3 changes: 3 additions & 0 deletions cinn/frontend/decomposer/top_k.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ void top_k(const Instruction& instr, const DecomposerContext& context) {
int k = instr.GetAttrs<int>("k");
CHECK_GT(k, 0) << "The attribute k must be greater than 0.";
int axis = instr.GetAttrs<int>("axis");
if (axis < 0) {
axis += x->shape.size();
}

auto sort_tmp = builder->Sort(x, axis, false);
auto sort_out = builder->Slice(sort_tmp, {axis}, {0}, {k});
Expand Down
1 change: 0 additions & 1 deletion python/tests/ops/test_top_k_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def init_case(self):

class TestTopKCase2(TestTopKOp):
def init_case(self):
print("TestTopKCase2 ---------")
self.inputs = {
"x1": np.random.random([
2,
Expand Down

0 comments on commit 297d270

Please sign in to comment.