Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

【PaddlePaddle Hackathon No.75】add sort op #891

Merged
merged 65 commits into from
Sep 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
65 commits
Select commit Hold shift + click to select a range
d342891
add cast op
zrr1999 Aug 8, 2022
2604fc4
fix bug
zrr1999 Aug 16, 2022
0953042
fix bug
zrr1999 Aug 16, 2022
9066e2c
fix bug
zrr1999 Aug 17, 2022
d8dfb5a
fix bug
zrr1999 Aug 19, 2022
ab8e434
fix bug
zrr1999 Aug 19, 2022
03e948f
fix bug
zrr1999 Aug 19, 2022
b30a7da
add gather and scatter op
zrr1999 Aug 8, 2022
9675f38
fix bug
zrr1999 Aug 21, 2022
6d02455
fix bug
zrr1999 Aug 22, 2022
0a8872a
fix bug
zrr1999 Aug 22, 2022
dabe72c
fix bug
zrr1999 Aug 23, 2022
3b4e976
fix bug
zrr1999 Aug 24, 2022
cce1792
fix bug
zrr1999 Aug 24, 2022
abca2db
fix bug
zrr1999 Aug 24, 2022
417b389
fix bug
zrr1999 Aug 24, 2022
f3a607c
fix bug
zrr1999 Aug 24, 2022
99c0223
fix bug
zrr1999 Aug 24, 2022
adb099b
fix bug
zrr1999 Aug 24, 2022
1a39c01
fix bug
zrr1999 Aug 24, 2022
107c519
fix bug
zrr1999 Aug 24, 2022
bd6cd48
fix bug
zrr1999 Aug 24, 2022
97d9578
fix bug
zrr1999 Aug 24, 2022
4e0c3d5
fix bug
zrr1999 Aug 24, 2022
ba87eca
fix bug
zrr1999 Aug 24, 2022
94834cf
fix bug
zrr1999 Aug 24, 2022
6520697
fix bug
zrr1999 Aug 24, 2022
b5cfedb
fix bug
zrr1999 Aug 24, 2022
8b08876
fix bug
zrr1999 Aug 24, 2022
feaa87b
fix bug
zrr1999 Aug 24, 2022
42c1754
fix bug
zrr1999 Aug 25, 2022
6fde20f
fix bug
zrr1999 Aug 25, 2022
0234492
fix bug
zrr1999 Aug 25, 2022
fab011e
fix bug
zrr1999 Aug 25, 2022
31d73b7
fix bug
zrr1999 Aug 25, 2022
5388e4d
fix bug
zrr1999 Aug 25, 2022
b484bbb
fix bug
zrr1999 Aug 25, 2022
078e3de
fix bug
zrr1999 Aug 25, 2022
f719a86
Merge branch 'develop' into sort
zrr1999 Aug 25, 2022
5e943f4
fix bug
zrr1999 Aug 25, 2022
99854cb
fix bug
zrr1999 Aug 25, 2022
c04bdb7
fix bug
zrr1999 Aug 25, 2022
41cb210
fix bug
zrr1999 Aug 25, 2022
2e616dc
fix bug
zrr1999 Aug 26, 2022
a9936b6
fix bug
zrr1999 Aug 26, 2022
9c98aa6
fix bug
zrr1999 Aug 27, 2022
914e206
fix bug
zrr1999 Aug 27, 2022
a7b4414
fix bug
zrr1999 Aug 28, 2022
82eb881
fix bug
zrr1999 Aug 28, 2022
097165f
fix bug
zrr1999 Aug 28, 2022
c3b61bb
Merge branch 'gather' into sort
zrr1999 Aug 28, 2022
ae4f9ca
fix bug
zrr1999 Aug 28, 2022
1514638
Merge branch 'develop' into sort
zrr1999 Sep 1, 2022
958367a
fix bug
zrr1999 Sep 1, 2022
e3facf0
Merge branch 'develop' into sort
zrr1999 Sep 1, 2022
49a135e
remove scatter and gather
zrr1999 Sep 5, 2022
bd79ec5
modified
zrr1999 Sep 7, 2022
2ca7fb1
Merge branch 'develop' into sort
zrr1999 Sep 7, 2022
0dfb89c
fix tensor_name
zrr1999 Sep 9, 2022
5e7195f
fix bugs
zrr1999 Sep 9, 2022
e714d03
Merge branch 'develop' into sort
zrr1999 Sep 9, 2022
6d05ac3
fix bugs
zrr1999 Sep 9, 2022
3b8a6b2
Merge branch 'develop' into sort
zrr1999 Sep 14, 2022
209b333
fix bugs
zrr1999 Sep 14, 2022
e83b9eb
Merge branch 'develop' into sort
zrr1999 Sep 15, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions cinn/frontend/net_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,24 @@ Variable NetBuilder::Conv(const Variable& lhs,
.front();
}

Variable NetBuilder::ArgSort(const Variable& operand, const int& axis, const bool& is_ascend) {
Instruction instr("argsort", {operand});
instr.SetAttr("axis", axis);
instr.SetAttr("is_ascend", is_ascend);
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable NetBuilder::Sort(const Variable& operand, const int& axis, const bool& is_ascend) {
Instruction instr("sort", {operand});
instr.SetAttr("axis", axis);
instr.SetAttr("is_ascend", is_ascend);
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable NetBuilder::Conv2d(const Variable& a,
const Variable& b,
const std::vector<int>& strides,
Expand Down
20 changes: 20 additions & 0 deletions cinn/frontend/net_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,26 @@ class NetBuilder {
const float epsilon = 1e-5f,
const std::string& data_layout = "NCHW");

/**
* @brief Sort Variable x along the given axis. The original Variable x will not be changed.
* @param operand The variable that will be sorted.
* @param axis Specify the axis to operate on the input. Default: 0.
* @param is_ascend Sort mode.
* Defalut “NCHW”.
* @return `Sorted variable index`.
*/
Variable ArgSort(const Variable& operand, const int& axis, const bool& is_ascend = true);

/**
* @brief Sort Variable x along the given axis. The original Variable x will not be changed.
* @param operand The variable that will be sorted.
* @param axis Specify the axis to operate on the input. Default: 0.
* @param is_ascend Sort mode.
* Defalut “NCHW”.
* @return `Sorted variable`.
*/
Variable Sort(const Variable& operand, const int& axis, const bool& is_ascend = true);

private:
CINN_DISALLOW_COPY_AND_ASSIGN(NetBuilder);
};
Expand Down
110 changes: 110 additions & 0 deletions cinn/frontend/net_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,116 @@ TEST(net_build, program_execute_squeeze_case3) {
}
}

TEST(net_build, program_execute_argsort) {
const int B = 4;
const int H = 7;

NetBuilder builder("net_builder");
Placeholder input = builder.CreateInput(Float(32), {B, H}, "In");
Variable output = builder.ArgSort(input, 0, true);
auto program = builder.Build();

Target target = common::DefaultHostTarget();

auto graph = std::make_shared<hlir::framework::Graph>(program, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build();

scope->Var<hlir::framework::Tensor>(std::string(input.id()));
scope->Var<hlir::framework::Tensor>(std::string(output->id));

auto input_tensor = scope->GetTensor(std::string(input.id()));
SetRandData<float>(input_tensor, target);
auto* input_data = input_tensor->mutable_data<float>(target);

runtime_program->Execute();

auto output_tensor = scope->GetTensor(std::string(output->id));
const std::vector<int>& output_shape = output_tensor->shape().data();
EXPECT_EQ(output_tensor->type(), Int(32));
EXPECT_EQ(output_shape.size(), 2UL);
EXPECT_EQ(output_shape[0], B);
EXPECT_EQ(output_shape[1], H);

int* output_data = output_tensor->mutable_data<int>(target);
VLOG(6) << "Visualize output_data";
for (int h = 0; h < H; ++h) {
std::vector<float> sorted_data;
std::vector<float> out_sorted_data(H);
for (int b = 0; b < B; ++b) {
int index = h + H * b;
sorted_data.push_back(input_data[index]);
out_sorted_data[output_data[index]] = input_data[index];
}
std::sort(sorted_data.begin(), sorted_data.begin() + B);

for (int b = 0; b < B; ++b) {
std::string line;
int index = h + H * b;
float true_data = sorted_data[b];
float out_data = out_sorted_data[b];
line += (std::to_string(out_data) + ", ");
EXPECT_EQ(true_data, out_data);
VLOG(6) << line;
}
}
}

TEST(net_build, program_execute_sort) {
const int B = 4;
const int H = 7;

NetBuilder builder("net_builder");
Placeholder input = builder.CreateInput(Float(32), {B, H}, "In");
Variable output = builder.Sort(input, 0, true);
auto program = builder.Build();

Target target = common::DefaultHostTarget();

auto graph = std::make_shared<hlir::framework::Graph>(program, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build();

scope->Var<hlir::framework::Tensor>(std::string(input.id()));
scope->Var<hlir::framework::Tensor>(std::string(output->id));

auto input_tensor = scope->GetTensor(std::string(input.id()));
SetRandData<float>(input_tensor, target);
auto* input_data = input_tensor->mutable_data<float>(target);

runtime_program->Execute();

auto output_tensor = scope->GetTensor(std::string(output->id));
const std::vector<int>& output_shape = output_tensor->shape().data();
EXPECT_EQ(output_tensor->type(), Float(32));
EXPECT_EQ(output_shape.size(), 2UL);
EXPECT_EQ(output_shape[0], B);
EXPECT_EQ(output_shape[1], H);

float* output_data = output_tensor->mutable_data<float>(target);
VLOG(6) << "Visualize output_data";
for (int h = 0; h < H; ++h) {
std::vector<float> sorted_data;
for (int b = 0; b < B; ++b) {
int index = h + H * b;
sorted_data.push_back(input_data[index]);
}
std::sort(sorted_data.begin(), sorted_data.begin() + B);

for (int b = 0; b < B; ++b) {
std::string line;
int index = h + H * b;
float true_data = sorted_data[b];
float out_data = output_data[index];
line += (std::to_string(out_data) + ", ");
EXPECT_EQ(true_data, out_data);
VLOG(6) << line;
}
}
}

TEST(net_build, program_execute_arange_float) {
const float start = 1.5F;
const float stop = 31.5F;
Expand Down
3 changes: 3 additions & 0 deletions cinn/hlir/op/contrib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@ gather_srcs(cinnapi_src SRCS
squeeze.cc
clip.cc
arange.cc
sort.cc
squeeze.cc
)

cc_test(test_cast SRCS cast_test.cc DEPS cinncore)
cc_test(test_squeeze SRCS squeeze_test.cc DEPS cinncore)
cc_test(test_clip SRCS clip_test.cc DEPS cinncore)
cc_test(test_sort SRCS sort_test.cc DEPS cinncore)
cc_test(test_arange SRCS arange_test.cc DEPS cinncore)
Loading