diff --git a/cinn/frontend/net_builder.cc b/cinn/frontend/net_builder.cc
index 68a88833db..f14496e32a 100644
--- a/cinn/frontend/net_builder.cc
+++ b/cinn/frontend/net_builder.cc
@@ -383,6 +383,24 @@ Variable NetBuilder::Conv(const Variable& lhs,
       .front();
 }
 
+Variable NetBuilder::ArgSort(const Variable& operand, const int& axis, const bool& is_ascend) {
+  Instruction instr("argsort", {operand});
+  instr.SetAttr("axis", axis);
+  instr.SetAttr("is_ascend", is_ascend);
+  InferShape(instr);
+  AppendInstruction(instr);
+  return instr.GetOutput(0);
+}
+
+Variable NetBuilder::Sort(const Variable& operand, const int& axis, const bool& is_ascend) {
+  Instruction instr("sort", {operand});
+  instr.SetAttr("axis", axis);
+  instr.SetAttr("is_ascend", is_ascend);
+  InferShape(instr);
+  AppendInstruction(instr);
+  return instr.GetOutput(0);
+}
+
 Variable NetBuilder::Conv2d(const Variable& a,
                             const Variable& b,
                             const std::vector<int>& strides,
diff --git a/cinn/frontend/net_builder.h b/cinn/frontend/net_builder.h
index be2b8800c7..e213c83b9c 100644
--- a/cinn/frontend/net_builder.h
+++ b/cinn/frontend/net_builder.h
@@ -792,6 +792,26 @@ class NetBuilder {
                         const float epsilon            = 1e-5f,
                         const std::string& data_layout = "NCHW");
 
+  /**
+   * @brief Sort Variable x along the given axis and return the indices of the
+   *        sorted values. The original Variable x will not be changed.
+   * @param operand The variable that will be sorted.
+   * @param axis Specify the axis to operate on the input. Default: 0.
+   * @param is_ascend Sort in ascending order if true, descending otherwise. Default: true.
+   * @return The indices that sort `operand` along `axis`.
+   */
+  Variable ArgSort(const Variable& operand, const int& axis, const bool& is_ascend = true);
+
+  /**
+   * @brief Sort Variable x along the given axis. The original Variable x will
+   *        not be changed.
+   * @param operand The variable that will be sorted.
+   * @param axis Specify the axis to operate on the input. Default: 0.
+   * @param is_ascend Sort in ascending order if true, descending otherwise. Default: true.
+   * @return The sorted variable.
+   */
+  Variable Sort(const Variable& operand, const int& axis, const bool& is_ascend = true);
+
  private:
   CINN_DISALLOW_COPY_AND_ASSIGN(NetBuilder);
 };
diff --git a/cinn/frontend/net_builder_test.cc b/cinn/frontend/net_builder_test.cc
index 55244574a7..4920f994d7 100755
--- a/cinn/frontend/net_builder_test.cc
+++ b/cinn/frontend/net_builder_test.cc
@@ -453,6 +453,116 @@ TEST(net_build, program_execute_squeeze_case3) {
   }
 }
 
+TEST(net_build, program_execute_argsort) {
+  const int B = 4;
+  const int H = 7;
+
+  NetBuilder builder("net_builder");
+  Placeholder input = builder.CreateInput(Float(32), {B, H}, "In");
+  Variable output   = builder.ArgSort(input, 0, true);
+  auto program      = builder.Build();
+
+  Target target = common::DefaultHostTarget();
+
+  auto graph = std::make_shared<hlir::framework::Graph>(program, target);
+  auto scope = BuildScope(target, graph);
+  hlir::framework::GraphCompiler gc(target, scope, graph);
+  auto runtime_program = gc.Build();
+
+  scope->Var<hlir::framework::Tensor>(std::string(input.id()));
+  scope->Var<hlir::framework::Tensor>(std::string(output->id));
+
+  auto input_tensor = scope->GetTensor(std::string(input.id()));
+  SetRandData<float>(input_tensor, target);
+  auto* input_data = input_tensor->mutable_data<float>(target);
+
+  runtime_program->Execute();
+
+  auto output_tensor                   = scope->GetTensor(std::string(output->id));
+  const std::vector<int>& output_shape = output_tensor->shape().data();
+  EXPECT_EQ(output_tensor->type(), Int(32));
+  EXPECT_EQ(output_shape.size(), 2UL);
+  EXPECT_EQ(output_shape[0], B);
+  EXPECT_EQ(output_shape[1], H);
+
+  int* output_data = output_tensor->mutable_data<int>(target);
+  VLOG(6) << "Visualize output_data";
+  for (int h = 0; h < H; ++h) {
+    std::vector<float> sorted_data;
+    std::vector<float> out_sorted_data(H);
+    for (int b = 0; b < B; ++b) {
+      int index = h + H * b;
+      sorted_data.push_back(input_data[index]);
+      out_sorted_data[output_data[index]] = input_data[index];
+    }
+    std::sort(sorted_data.begin(), sorted_data.begin() + B);
+
+    for (int b = 0; b < B; ++b) {
+      std::string line;
+      int index       = h + H * b;
+      float true_data = sorted_data[b];
+      float out_data  = out_sorted_data[b];
+      line += (std::to_string(out_data) + ", ");
+      EXPECT_EQ(true_data, out_data);
+      VLOG(6) << line;
+    }
+  }
+}
+
+TEST(net_build, program_execute_sort) {
+  const int B = 4;
+  const int H = 7;
+
+  NetBuilder builder("net_builder");
+  Placeholder input = builder.CreateInput(Float(32), {B, H}, "In");
+  Variable output   = builder.Sort(input, 0, true);
+  auto program      = builder.Build();
+
+  Target target = common::DefaultHostTarget();
+
+  auto graph = std::make_shared<hlir::framework::Graph>(program, target);
+  auto scope = BuildScope(target, graph);
+  hlir::framework::GraphCompiler gc(target, scope, graph);
+  auto runtime_program = gc.Build();
+
+  scope->Var<hlir::framework::Tensor>(std::string(input.id()));
+  scope->Var<hlir::framework::Tensor>(std::string(output->id));
+
+  auto input_tensor = scope->GetTensor(std::string(input.id()));
+  SetRandData<float>(input_tensor, target);
+  auto* input_data = input_tensor->mutable_data<float>(target);
+
+  runtime_program->Execute();
+
+  auto output_tensor                   = scope->GetTensor(std::string(output->id));
+  const std::vector<int>& output_shape = output_tensor->shape().data();
+  EXPECT_EQ(output_tensor->type(), Float(32));
+  EXPECT_EQ(output_shape.size(), 2UL);
+  EXPECT_EQ(output_shape[0], B);
+  EXPECT_EQ(output_shape[1], H);
+
+  float* output_data = output_tensor->mutable_data<float>(target);
+  VLOG(6) << "Visualize output_data";
+  for (int h = 0; h < H; ++h) {
+    std::vector<float> sorted_data;
+    for (int b = 0; b < B; ++b) {
+      int index = h + H * b;
+      sorted_data.push_back(input_data[index]);
+    }
+    std::sort(sorted_data.begin(), sorted_data.begin() + B);
+
+    for (int b = 0; b < B; ++b) {
+      std::string line;
+      int index       = h + H * b;
+      float true_data = sorted_data[b];
+      float out_data  = output_data[index];
+      line += (std::to_string(out_data) + ", ");
+      EXPECT_EQ(true_data, out_data);
+      VLOG(6) << line;
+    }
+  }
+}
+
 TEST(net_build, program_execute_arange_float) {
   const float start = 1.5F;
   const float stop  = 31.5F;
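For reference, a standalone sketch (not part of the patch; all helper names and values here are mine) of the rank-based argsort the two tests above assert. Per the `out_sorted_data` check, CINN's argsort writes, at every input position, the rank of that element within its slice along `axis` (the number of strictly smaller elements), so scattering values by rank reproduces the sorted slice. The sketch assumes distinct values; equal elements would receive equal ranks.

#include <cstdio>
#include <vector>

int main() {
  const int B = 4, H = 7;  // same shape as the tests, row-major [B, H]
  std::vector<float> in(B * H);
  for (int i = 0; i < B * H; ++i) in[i] = static_cast<float>((i * 37) % 101);  // distinct values

  std::vector<int> rank(B * H);
  for (int h = 0; h < H; ++h) {    // one slice per column, axis = 0
    for (int b = 0; b < B; ++b) {
      int self = h + H * b;        // strided address: begin = h, stride = H
      int r    = 0;
      for (int k = 0; k < B; ++k) {
        if (in[h + H * k] < in[self]) ++r;  // the lt_num count over the slice
      }
      rank[self] = r;              // == argsort output at (b, h)
    }
  }
  // Scatter by rank: sorted[rank] = value yields column 0 in ascending order,
  // which is exactly the check done with out_sorted_data above.
  std::vector<float> sorted(B);
  for (int b = 0; b < B; ++b) sorted[rank[H * b]] = in[H * b];
  for (int b = 0; b < B; ++b) std::printf("%.0f ", sorted[b]);
  std::printf("\n");
  return 0;
}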
diff --git a/cinn/hlir/op/contrib/CMakeLists.txt b/cinn/hlir/op/contrib/CMakeLists.txt
index c6aeac910a..7dd7945d5c 100644
--- a/cinn/hlir/op/contrib/CMakeLists.txt
+++ b/cinn/hlir/op/contrib/CMakeLists.txt
@@ -5,9 +5,11 @@ gather_srcs(cinnapi_src SRCS
     squeeze.cc
     clip.cc
     arange.cc
+    sort.cc
 )
 
 cc_test(test_cast SRCS cast_test.cc DEPS cinncore)
 cc_test(test_squeeze SRCS squeeze_test.cc DEPS cinncore)
 cc_test(test_clip SRCS clip_test.cc DEPS cinncore)
+cc_test(test_sort SRCS sort_test.cc DEPS cinncore)
 cc_test(test_arange SRCS arange_test.cc DEPS cinncore)
diff --git a/cinn/hlir/op/contrib/sort.cc b/cinn/hlir/op/contrib/sort.cc
new file mode 100644
index 0000000000..75f0c20f51
--- /dev/null
+++ b/cinn/hlir/op/contrib/sort.cc
@@ -0,0 +1,299 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/hlir/op/contrib/sort.h"
+
+#include <gflags/gflags.h>
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "cinn/common/cas.h"
+#include "cinn/common/common.h"
+#include "cinn/common/context.h"
+#include "cinn/common/macros.h"
+#include "cinn/hlir/framework/node.h"
+#include "cinn/hlir/framework/op.h"
+#include "cinn/hlir/framework/op_strategy.h"
+#include "cinn/hlir/pe/elementwise.h"
+#include "cinn/hlir/pe/transform.h"
+#include "cinn/ir/ir.h"
+#include "cinn/ir/ir_base.h"
+#include "cinn/ir/tensor.h"
+#include "cinn/lang/builtin.h"
+#include "cinn/lang/compute.h"
+
+DECLARE_bool(cinn_ir_schedule);
+
+namespace cinn {
+namespace hlir {
+namespace op {
+
+using common::CINNValue;
+using common::CINNValuePack;
+
+ir::Tensor ArgSort(const ir::Tensor &A,
+                   const common::Target &target,
+                   const int &axis,
+                   const bool &is_ascend,
+                   const std::string &name) {
+  std::string extern_fun_name;
+  if (target.arch == common::Target::Arch::NVGPU) {
+    extern_fun_name.assign("cinn_cuda_");
+  } else if (target.arch == common::Target::Arch::X86) {
+    extern_fun_name.assign("cinn_host_");
+  } else {
+    LOG(FATAL) << "ArgSort only supports X86 and NVGPU! Please check.\n";
+  }
+  if (is_ascend) {
+    extern_fun_name.append("lt_num_float");
+  } else {
+    extern_fun_name.append("gt_num_float");
+  }
+  int pos_axis = axis;
+  if (pos_axis < 0) {
+    pos_axis += A->shape.size();
+  }
+  auto res = Compute(
+      A->shape,
+      [=](const std::vector<Expr> &indices) {
+        Expr offset(0);
+        Expr stride(1);
+        for (int i = 0; i < indices.size(); i++) {
+          if (i < pos_axis) {
+            offset = offset * A->shape[i] + indices[i];
+          } else if (i == pos_axis) {
+            offset = offset * A->shape[i];
+          } else {
+            offset = offset * A->shape[i] + indices[i];
+            stride = stride * A->shape[i];
+          }
+        }
+        offset            = common::AutoSimplify(offset);
+        stride            = common::AutoSimplify(stride);
+        auto A_shape_axis = A->shape[pos_axis];
+        return lang::CallExtern(extern_fun_name, {A, A_shape_axis, A(indices), offset, stride});
+      },
+      name);
+  return res;
+}
+
+std::vector<ir::Tensor> Sort(const ir::Tensor &A,
+                             const common::Target &target,
+                             const int &axis,
+                             const bool &is_ascend,
+                             const std::string &name) {
+  std::string extern_fun_name;
+  if (target.arch == common::Target::Arch::NVGPU) {
+    extern_fun_name.assign("cinn_cuda_find_int_nd");
+  } else if (target.arch == common::Target::Arch::X86) {
+    extern_fun_name.assign("cinn_host_find_int_nd");
+  } else {
+    LOG(FATAL) << "Sort only supports X86 and NVGPU! Please check.\n";
+  }
+  int pos_axis = axis;
+  if (pos_axis < 0) {
+    pos_axis += A->shape.size();
+  }
+  auto sort_index = ArgSort(A, target, pos_axis, is_ascend, name + "_index");
+  auto res        = Compute(
+      A->shape,
+      [=](const std::vector<Expr> &indices) {
+        Expr offset(0);
+        Expr stride(1);
+        for (int i = 0; i < indices.size(); i++) {
+          if (i < pos_axis) {
+            offset = offset * A->shape[i] + indices[i];
+          } else if (i == pos_axis) {
+            offset = offset * A->shape[i];
+          } else {
+            offset = offset * A->shape[i] + indices[i];
+            stride = stride * A->shape[i];
+          }
+        }
+        offset = common::AutoSimplify(offset);
+        stride = common::AutoSimplify(stride);
+
+        auto A_shape_axis = A->shape[pos_axis];
+        auto idx = lang::CallExtern(extern_fun_name, {sort_index, A_shape_axis, indices[pos_axis], offset, stride});
+        std::vector<Expr> A_indices(indices);
+        A_indices[pos_axis] = idx;
+        return A(A_indices);
+      },
+      name);
+  return {sort_index, res};
+}
+
+std::shared_ptr<framework::OpStrategy> StrategyForSort(const framework::NodeAttr &attrs,
+                                                       const std::vector<ir::Tensor> &inputs,
+                                                       const std::vector<Type> &out_type,
+                                                       const std::vector<std::vector<int>> &output_shapes,
+                                                       const Target &target) {
+  auto attr_store = attrs.attr_store;
+  std::string op_name("sort");
+
+  CHECK(attr_store.count("axis")) << "Cannot find attribute \"axis\" in op " << op_name << "! Please check.";
+  int axis       = absl::get<int>(attr_store.at("axis"));
+  bool is_ascend = true;
+  if (attr_store.count("is_ascend")) {
+    is_ascend = absl::get<bool>(attr_store.at("is_ascend"));
+  }
+
+  framework::CINNCompute sort_compute([=](lang::Args args, lang::RetValue *ret) {
+    CHECK(!args.empty()) << "The input arguments of Sort compute are empty! Please check.\n";
+    CINNValuePack pack_args = args[0];
+    CHECK_GE(pack_args.size(), 1U) << "At least 1 input tensor is required for Sort compute\n";
+    Expr A = pack_args[0];
+    CHECK(A.as_tensor());
+    CHECK(!output_shapes.empty());
+    auto tensor_A = A.as_tensor_ref();
+    auto stages   = CreateStages({tensor_A});
+    VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
+            << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
+    auto tensor_name = UniqName("Sort_out");
+    if (FLAGS_cinn_ir_schedule) {
+      CHECK_EQ(pack_args.size(), 2U);
+      CHECK(pack_args[1].is_string());
+      tensor_name = pack_args[1].operator std::string();
+    }
+    std::vector<ir::Tensor> outputs = Sort(tensor_A, target, axis, is_ascend, tensor_name);
+    ir::Tensor sort_index           = outputs[0];
+    ir::Tensor out                  = outputs[1];
+    std::vector<CINNValue> res;
+    stages->InsertLazily(sort_index);
+    stages->InsertLazily(out);
+    res.push_back(CINNValue(out));
+    CHECK(!out_type.empty()) << "Output type of Sort is empty! Please check.\n";
+    res.push_back(CINNValue(stages));
+    *ret = CINNValuePack{res};
+  });
+
+  framework::CINNSchedule sort_schedule([=](lang::Args args, lang::RetValue *ret) {
+    CHECK(!args.empty()) << "The input argument of sort schedule is empty! Please check.\n";
+    CINNValuePack arg_pack = args[0];
+    Expr out               = arg_pack[0];
+    CHECK(out.as_tensor());
+    *ret = arg_pack;
+  });
+
+  auto strategy = std::make_shared<framework::OpStrategy>();
+  strategy->AddImpl(sort_compute, sort_schedule, "strategy.sort.x86", 1);
+  return strategy;
+}
+
+std::shared_ptr<framework::OpStrategy> StrategyForArgSort(const framework::NodeAttr &attrs,
+                                                          const std::vector<ir::Tensor> &inputs,
+                                                          const std::vector<Type> &out_type,
+                                                          const std::vector<std::vector<int>> &output_shapes,
+                                                          const Target &target) {
+  auto attr_store = attrs.attr_store;
+  CHECK(attr_store.count("axis")) << "Cannot find attribute \"axis\" in op argsort! Please check.";
+  int axis       = absl::get<int>(attr_store.at("axis"));
+  bool is_ascend = true;
+  if (attr_store.count("is_ascend")) {
+    is_ascend = absl::get<bool>(attr_store.at("is_ascend"));
+  }
+
+  framework::CINNCompute argsort_compute([=](lang::Args args, lang::RetValue *ret) {
+    CHECK(!args.empty()) << "The input arguments of ArgSort compute are empty! Please check.\n";
+    CINNValuePack pack_args = args[0];
+    CHECK_GE(pack_args.size(), 1U) << "At least 1 input tensor is required for ArgSort compute\n";
+    Expr A = pack_args[0];
+    CHECK(A.as_tensor());
+    CHECK(!output_shapes.empty());
+    auto tensor_A = A.as_tensor_ref();
+    auto stages   = CreateStages({tensor_A});
+    VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
+            << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
+    auto tensor_name = UniqName("ArgSort_out");
+    if (FLAGS_cinn_ir_schedule) {
+      CHECK_EQ(pack_args.size(), 2U);
+      CHECK(pack_args[1].is_string());
+      tensor_name = pack_args[1].operator std::string();
+    }
+    ir::Tensor out = ArgSort(tensor_A, target, axis, is_ascend, tensor_name);
+    std::vector<CINNValue> res;
+    stages->InsertLazily(out);
+    res.push_back(CINNValue(out));
+    CHECK(!out_type.empty()) << "Output type of ArgSort is empty! Please check.\n";
+    res.push_back(CINNValue(stages));
+    *ret = CINNValuePack{res};
+  });
+
+  framework::CINNSchedule argsort_schedule([=](lang::Args args, lang::RetValue *ret) {
+    CHECK(!args.empty()) << "The input argument of argsort schedule is empty! Please check.\n";
+    CINNValuePack arg_pack = args[0];
+    Expr out               = arg_pack[0];
+    CHECK(out.as_tensor());
+    *ret = arg_pack;
+  });
+
+  auto strategy = std::make_shared<framework::OpStrategy>();
+  strategy->AddImpl(argsort_compute, argsort_schedule, "strategy.argsort.x86", 1);
+  return strategy;
+}
+
+std::vector<std::vector<int>> InferShapeForSort(const std::vector<std::vector<int>> &inputs_shape,
+                                                const framework::AttrMapType &attrs) {
+  CHECK_EQ(inputs_shape.size(), 1UL) << "The input's shape size should be 1! Please check again.";
+  int axis = 0;
+  for (auto &iter : attrs) {
+    if (iter.first == "axis") {
+      axis = absl::get<int>(iter.second);
+      break;
+    }
+  }
+  CHECK_GT(inputs_shape[0].size(), axis) << "The input's dim should be greater than axis! ";
+  std::vector<std::vector<int>> res{inputs_shape[0]};
+  return res;
+}
+
+std::vector<Type> InferDtypeForSort(const std::vector<Type> &inputs_type, const framework::AttrMapType &attrs) {
+  CHECK_EQ(inputs_type.size(), 1UL) << "The input's type size should be 1! Please check again.";
+  std::vector<Type> res{inputs_type[0]};
+  return res;
+}
+
+std::vector<Type> InferDtypeForArgSort(const std::vector<Type> &inputs_type, const framework::AttrMapType &attrs) {
+  CHECK_EQ(inputs_type.size(), 1UL) << "The input's type size should be 1! Please check again.";
+  return {Int(32)};
+}
+
+}  // namespace op
+}  // namespace hlir
+}  // namespace cinn
+
+CINN_REGISTER_HELPER(sort_ops) {
+  CINN_REGISTER_OP(sort)
+      .describe("Sort a variable x along the given axis and return the sorted Variable.")
+      .set_num_inputs(1)
+      .set_num_outputs(1)
+      .set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForSort)
+      .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForSort))
+      .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForSort))
+      .set_support_level(4);
+
+  CINN_REGISTER_OP(argsort)
+      .describe("Sort a variable x along the given axis and return the indices.")
+      .set_num_inputs(1)
+      .set_num_outputs(1)
+      .set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForArgSort)
+      .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForSort))
+      .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForArgSort))
+      .set_support_level(4);
+
+  return true;
+}
diff --git a/cinn/hlir/op/contrib/sort.h b/cinn/hlir/op/contrib/sort.h
new file mode 100644
index 0000000000..8ac93ad57c
--- /dev/null
+++ b/cinn/hlir/op/contrib/sort.h
@@ -0,0 +1,36 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include <vector>
+
+#include "cinn/ir/ir.h"
+#include "cinn/ir/ir_base.h"
+#include "cinn/ir/tensor.h"
+
+namespace cinn {
+namespace hlir {
+namespace op {
+
+ir::Tensor ArgSort(
+    const ir::Tensor& A, const common::Target& target, const int& axis, const bool& is_ascend, const std::string& name);
+
+std::vector<ir::Tensor> Sort(
+    const ir::Tensor& A, const common::Target& target, const int& axis, const bool& is_ascend, const std::string& name);
+
+}  // namespace op
+}  // namespace hlir
+}  // namespace cinn
diff --git a/cinn/hlir/op/contrib/sort_test.cc b/cinn/hlir/op/contrib/sort_test.cc
new file mode 100644
index 0000000000..2618c6e3a9
--- /dev/null
+++ b/cinn/hlir/op/contrib/sort_test.cc
@@ -0,0 +1,107 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/hlir/op/contrib/sort.h"
+
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+
+#include <string>
+#include <vector>
+
+#include "cinn/backends/codegen_c.h"
+#include "cinn/backends/codegen_c_x86.h"
+#include "cinn/backends/codegen_cuda_dev.h"
+#include "cinn/common/context.h"
+#include "cinn/lang/lower.h"
+#include "cinn/lang/placeholder.h"
+#include "cinn/poly/stage.h"
+
+namespace cinn {
+namespace hlir {
+namespace op {
+
+TEST(GenerateCode_Cpu, ArgSort) {
+  common::Context::Global().ResetNameId();
+
+#ifdef CINN_WITH_CUDA
+  Target target = common::DefaultNVGPUTarget();
+#else
+  Target target = common::DefaultHostTarget();
+#endif
+
+  ir::Expr n(4);
+  ir::Expr h(28);
+
+  lang::Placeholder<float> in("in", {n, h});
+  ir::Tensor res = ArgSort(in.tensor(), target, 1, true, "test_arg_sort_out");
+
+  poly::StageMap stages = poly::CreateStages({in, res});
+  std::vector<ir::LoweredFunc> funcs =
+      lang::LowerVec("TestGenerateCodeCpu_ArgSort", stages, {in, res}, {}, {}, nullptr, target, true);
+
+  VLOG(6) << "Expr before CPU codegen:";
+  VLOG(6) << funcs[0]->body;
+
+  ir::Module::Builder builder("ArgSort_Module", target);
+  for (auto& f : funcs) {
+    builder.AddFunction(f);
+  }
+
+  backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512);
+  codegen.SetInlineBuiltinCodes(false);
+  std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl);
+  VLOG(6) << "Cpu Codegen result:";
+  VLOG(6) << code << std::endl;
+}
+
+TEST(GenerateCode_Cpu, Sort) {
+  common::Context::Global().ResetNameId();
+
+#ifdef CINN_WITH_CUDA
+  Target target = common::DefaultNVGPUTarget();
+#else
+  Target target = common::DefaultHostTarget();
+#endif
+
+  ir::Expr n(4);
+  ir::Expr h(28);
+
+  lang::Placeholder<float> in("in", {n, h});
+  std::vector<ir::Tensor> outputs = Sort(in.tensor(), target, 1, true, "test_sort_out");
+  ir::Tensor index                = outputs[0];
+  ir::Tensor out                  = outputs[1];
+
+  poly::StageMap stages = poly::CreateStages({in, index, out});
+  std::vector<ir::LoweredFunc> funcs =
+      lang::LowerVec("TestGenerateCodeCpu_Sort", stages, {in, index, out}, {}, {}, nullptr, target, true);
+
+  VLOG(6) << "Expr before CPU codegen:";
+  VLOG(6) << funcs[0]->body;
+
+  ir::Module::Builder builder("Sort_Module", target);
+  for (auto& f : funcs) {
+    builder.AddFunction(f);
+  }
+
+  backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512);
+  codegen.SetInlineBuiltinCodes(false);
+  std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl);
+  VLOG(6) << "Cpu Codegen result:";
+  VLOG(6) << code << std::endl;
+}
+
+}  // namespace op
+}  // namespace hlir
+}  // namespace cinn
diff --git a/cinn/hlir/op/use_ops.h b/cinn/hlir/op/use_ops.h
index 068e530756..5c3e6b3c0b 100644
--- a/cinn/hlir/op/use_ops.h
+++ b/cinn/hlir/op/use_ops.h
@@ -23,6 +23,7 @@ CINN_USE_REGISTER(broadcast_grad_ops)
 CINN_USE_REGISTER(elementwise_ops)
 CINN_USE_REGISTER(transform_ops)
 CINN_USE_REGISTER(cast_ops)
+CINN_USE_REGISTER(sort_ops)
 CINN_USE_REGISTER(squeeze_ops)
 CINN_USE_REGISTER(reduce_ops)
 CINN_USE_REGISTER(clip_ops)
diff --git a/cinn/runtime/cpu/host_intrinsics.cc b/cinn/runtime/cpu/host_intrinsics.cc
index 5ba5d6c7e2..ea486da55a 100644
--- a/cinn/runtime/cpu/host_intrinsics.cc
+++ b/cinn/runtime/cpu/host_intrinsics.cc
@@ -36,23 +36,73 @@ void __cinn_host_tanh_v(const cinn_buffer_t* x, cinn_buffer_t* out) {
   }
 }
 
-#define __cinn_host_find_kernel(buf, size, num, type)                       \
-  do {                                                                      \
-    for (int i = size - 1; i >= 0; --i) {                                   \
-      if (reinterpret_cast<type*>(buf->memory)[i] == num) return i;         \
-    }                                                                       \
-    return -1;                                                              \
+#define __cinn_host_find_kernel(buf, size, num, type, begin, stride)                    \
+  do {                                                                                  \
+    for (int i = (size - 1) * stride + begin; i >= begin; i -= stride) {                \
+      if (reinterpret_cast<type*>(buf->memory)[i] == num) return (i - begin) / stride;  \
+    }                                                                                   \
+    return -1;                                                                          \
   } while (0)
 
 inline int cinn_host_find_int(const cinn_buffer_t* buf, int size, int num) {
-  __cinn_host_find_kernel(buf, size, num, int);
+  __cinn_host_find_kernel(buf, size, num, int, 0, 1);
 }
 
 inline int cinn_host_find_float(const cinn_buffer_t* buf, int size, float num) {
-  __cinn_host_find_kernel(buf, size, num, float);
+  __cinn_host_find_kernel(buf, size, num, float, 0, 1);
+}
+
+inline int cinn_host_find_int_nd(const cinn_buffer_t* buf, int size, int num, int begin, int stride) {
+  __cinn_host_find_kernel(buf, size, num, int, begin, stride);
+}
+
+inline int cinn_host_find_float_nd(const cinn_buffer_t* buf, int size, float num, int begin, int stride) {
+  __cinn_host_find_kernel(buf, size, num, float, begin, stride);
 }
 
 #undef __cinn_host_find_kernel
+
+#define __cinn_host_lt_num_kernel(buf, size, num, offset, stride, type)     \
+  do {                                                                      \
+    int out = 0;                                                            \
+    for (int i = (size - 1) * stride + offset; i >= offset; i -= stride) {  \
+      if (reinterpret_cast<type*>(buf->memory)[i] < num) out++;             \
+    }                                                                       \
+    return out;                                                             \
+  } while (0)
+
+inline int cinn_host_lt_num_float(
+    const cinn_buffer_t* buf, const int size, const float num, const int offset, const int stride) {
+  __cinn_host_lt_num_kernel(buf, size, num, offset, stride, float);
+}
+
+inline int cinn_host_lt_num_int(
+    const cinn_buffer_t* buf, const int size, const int num, const int offset, const int stride) {
+  __cinn_host_lt_num_kernel(buf, size, num, offset, stride, int);
+}
+
+#undef __cinn_host_lt_num_kernel
+
+#define __cinn_host_gt_num_kernel(buf, size, num, offset, stride, type)     \
+  do {                                                                      \
+    int out = 0;                                                            \
+    for (int i = (size - 1) * stride + offset; i >= offset; i -= stride) {  \
+      if (reinterpret_cast<type*>(buf->memory)[i] > num) out++;             \
+    }                                                                       \
+    return out;                                                             \
+  } while (0)
+
+inline int cinn_host_gt_num_float(
+    const cinn_buffer_t* buf, const int size, const float num, const int offset, const int stride) {
+  __cinn_host_gt_num_kernel(buf, size, num, offset, stride, float);
+}
+
+inline int cinn_host_gt_num_int(
+    const cinn_buffer_t* buf, const int size, const int num, const int offset, const int stride) {
+  __cinn_host_gt_num_kernel(buf, size, num, offset, stride, int);
+}
+
+#undef __cinn_host_gt_num_kernel
 }
 
 CINN_REGISTER_HELPER(host_intrinsics) {
@@ -86,5 +136,59 @@ CINN_REGISTER_HELPER(host_intrinsics) {
       .AddInputType<float>()
       .End();
 
+  REGISTER_EXTERN_FUNC_HELPER(cinn_host_find_int_nd, host_target)
+      .SetRetType<int>()
+      .AddInputType<cinn_buffer_t*>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .End();
+
+  REGISTER_EXTERN_FUNC_HELPER(cinn_host_find_float_nd, host_target)
+      .SetRetType<int>()
+      .AddInputType<cinn_buffer_t*>()
+      .AddInputType<int>()
+      .AddInputType<float>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .End();
+
+  REGISTER_EXTERN_FUNC_HELPER(cinn_host_lt_num_int, host_target)
+      .SetRetType<int>()
+      .AddInputType<cinn_buffer_t*>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .End();
+
+  REGISTER_EXTERN_FUNC_HELPER(cinn_host_lt_num_float, host_target)
+      .SetRetType<int>()
+      .AddInputType<cinn_buffer_t*>()
+      .AddInputType<int>()
+      .AddInputType<float>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .End();
+
+  REGISTER_EXTERN_FUNC_HELPER(cinn_host_gt_num_int, host_target)
+      .SetRetType<int>()
+      .AddInputType<cinn_buffer_t*>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .End();
+
+  REGISTER_EXTERN_FUNC_HELPER(cinn_host_gt_num_float, host_target)
+      .SetRetType<int>()
+      .AddInputType<cinn_buffer_t*>()
+      .AddInputType<int>()
+      .AddInputType<float>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .End();
+
   return true;
 }
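For reference, a plain restatement (the `_ref` name is mine; this is not part of the patch) of the `__cinn_host_lt_num_kernel` macro above: it counts elements of the strided slice {buf[offset + k*stride] : k = 0..size-1} that are strictly less than num. The kernel walks the slice backwards (i -= stride), but since it only counts, the direction does not affect the result.

#include <cassert>

int lt_num_float_ref(const float* buf, int size, float num, int offset, int stride) {
  int out = 0;
  for (int k = 0; k < size; ++k) {
    if (buf[offset + k * stride] < num) ++out;  // strict comparison, as in the macro
  }
  return out;
}

int main() {
  // Two interleaved columns; the slice with offset 1 and stride 2 is {9, 7, 2}.
  const float col[] = {3.f, 9.f, 1.f, 7.f, 5.f, 2.f};
  assert(lt_num_float_ref(col, 3, 7.f, 1, 2) == 1);  // only 2 < 7
  return 0;
}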
diff --git a/cinn/runtime/cpu/host_intrinsics.h b/cinn/runtime/cpu/host_intrinsics.h
index eb904c29d4..a463be4843 100644
--- a/cinn/runtime/cpu/host_intrinsics.h
+++ b/cinn/runtime/cpu/host_intrinsics.h
@@ -28,4 +28,20 @@ void __cinn_host_tanh_v(const cinn_buffer_t* x, cinn_buffer_t* out);
 inline int cinn_host_find_int(const cinn_buffer_t* buf, int size, int num);
 
 inline int cinn_host_find_float(const cinn_buffer_t* buf, int size, float num);
+
+inline int cinn_host_find_int_nd(const cinn_buffer_t* buf, int size, int num, int begin, int stride);
+
+inline int cinn_host_find_float_nd(const cinn_buffer_t* buf, int size, float num, int begin, int stride);
+
+inline int cinn_host_lt_num_float(
+    const cinn_buffer_t* buf, const int size, const float num, const int offset, const int stride);
+
+inline int cinn_host_lt_num_int(
+    const cinn_buffer_t* buf, const int size, const int num, const int offset, const int stride);
+
+inline int cinn_host_gt_num_float(
+    const cinn_buffer_t* buf, const int size, const float num, const int offset, const int stride);
+
+inline int cinn_host_gt_num_int(
+    const cinn_buffer_t* buf, const int size, const int num, const int offset, const int stride);
 }
diff --git a/cinn/runtime/cpu/host_intrinsics_test.cc b/cinn/runtime/cpu/host_intrinsics_test.cc
index 0b88019266..762cca09dc 100644
--- a/cinn/runtime/cpu/host_intrinsics_test.cc
+++ b/cinn/runtime/cpu/host_intrinsics_test.cc
@@ -66,6 +66,143 @@ TEST(tanh, basic) {
   }
 }
 
+TEST(find_value_nd, basic) {
+  Expr M(10), N(20);
+  Placeholder<float> x("x", {M, N});
+  auto y = Compute(
+      {N},
+      [&](Expr i) { return CallExtern("cinn_host_find_float_nd", {x, M, x({Expr(5), Expr(3)}), i, N}); },
+      "y");
+
+  auto stages = CreateStages({y});
+
+  auto jit = backends::SimpleJIT::Create();
+
+  ir::Module::Builder builder("module1", common::DefaultHostTarget());
+
+  auto fn = Lower("fn", stages, {x, y});
+  LOG(INFO) << "fn:\n" << fn;
+
+  builder.AddFunction(fn);
+
+  jit->Link(builder.Build());
+
+  auto fn_ptr = jit->Lookup("fn");
+  auto fnp    = reinterpret_cast<void (*)(void*, int32_t)>(fn_ptr);
+  ASSERT_TRUE(fnp);
+
+  auto* x_buf   = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build();
+  auto* out_buf = common::BufferBuilder(Int(32), {N.as_int32()}).set_zero().Build();
+  auto args     = common::ArgsBuilder().Add(x_buf).Add(out_buf).Build();
+  fnp(args.data(), args.size());
+
+  auto* x_buf_data   = reinterpret_cast<float*>(x_buf->memory);
+  auto* out_buf_data = reinterpret_cast<int*>(out_buf->memory);
+
+  for (int i = 0; i < out_buf->num_elements(); i++) {
+    LOG_FIRST_N(INFO, 3) << out_buf_data[i];
+    if (out_buf_data[i] != -1) {
+      ASSERT_NEAR(x_buf_data[out_buf_data[i] * 20 + i], x_buf_data[5 * 20 + 3], 1e-5);
+    }
+  }
+}
+
+TEST(cinn_host_lt_num_float, basic) {
+  Expr M(10), N(20);
+  Placeholder<float> x("x", {M, N});
+  auto y = Compute(
+      {N},
+      [&](Expr j) { return CallExtern("cinn_host_lt_num_float", {x, M, x({Expr(0), j}), j, N}); },
+      "y");
+
+  auto stages = CreateStages({y});
+
+  auto jit = backends::SimpleJIT::Create();
+
+  ir::Module::Builder builder("module1", common::DefaultHostTarget());
+
+  auto fn = Lower("fn", stages, {x, y});
+  LOG(INFO) << "fn:\n" << fn;
+
+  builder.AddFunction(fn);
+
+  jit->Link(builder.Build());
+
+  auto fn_ptr = jit->Lookup("fn");
+  auto fnp    = reinterpret_cast<void (*)(void*, int32_t)>(fn_ptr);
+  ASSERT_TRUE(fnp);
+
+  auto* x_buf   = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build();
+  auto* out_buf = common::BufferBuilder(Int(32), {N.as_int32()}).set_zero().Build();
+  auto args     = common::ArgsBuilder().Add(x_buf).Add(out_buf).Build();
+  fnp(args.data(), args.size());
+
+  auto* x_buf_data   = reinterpret_cast<float*>(x_buf->memory);
+  auto* out_buf_data = reinterpret_cast<int*>(out_buf->memory);
+
+  for (int j = 0; j < 20; j++) {
+    int out = 0;
+    for (int i = 0; i < 10; i++) {
+      int index = i * 20 + j;
+      if (x_buf_data[index] < x_buf_data[j]) {
+        out++;
+      }
+    }
+    ASSERT_NEAR(out_buf_data[j], out, 1e-5);
+  }
+}
+
+TEST(cinn_host_gt_num_float, basic) {
+  Expr M(10), N(20);
+  Placeholder<float> x("x", {M, N});
+  auto y = Compute(
+      {N},
+      [&](Expr j) { return CallExtern("cinn_host_gt_num_float", {x, M, x({Expr(0), j}), j, N}); },
+      "y");
+
+  auto stages = CreateStages({y});
+
+  auto jit = backends::SimpleJIT::Create();
+
+  ir::Module::Builder builder("module1", common::DefaultHostTarget());
+
+  auto fn = Lower("fn", stages, {x, y});
+  LOG(INFO) << "fn:\n" << fn;
+
+  builder.AddFunction(fn);
+
+  jit->Link(builder.Build());
+
+  auto fn_ptr = jit->Lookup("fn");
+  auto fnp    = reinterpret_cast<void (*)(void*, int32_t)>(fn_ptr);
+  ASSERT_TRUE(fnp);
+
+  auto* x_buf   = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build();
+  auto* out_buf = common::BufferBuilder(Int(32), {N.as_int32()}).set_zero().Build();
+  auto args     = common::ArgsBuilder().Add(x_buf).Add(out_buf).Build();
+  fnp(args.data(), args.size());
+
+  auto* x_buf_data   = reinterpret_cast<float*>(x_buf->memory);
+  auto* out_buf_data = reinterpret_cast<int*>(out_buf->memory);
+
+  for (int j = 0; j < 20; j++) {
+    int out = 0;
+    for (int i = 0; i < 10; i++) {
+      int index = i * 20 + j;
+      if (x_buf_data[index] > x_buf_data[j]) {
+        out++;
+      }
+    }
+    ASSERT_NEAR(out_buf_data[j], out, 1e-5);
+  }
+}
+
 }  // namespace cpu
 }  // namespace runtime
 }  // namespace cinn
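For reference, a plain restatement (the `_ref` name is mine, not part of the patch) of the `find_*_nd` kernels exercised above: scan the strided slice from its last element toward `begin` and return the slice-relative position `(i - begin) / stride` of the first match found, i.e. the last occurrence of `num` in the slice, or -1 if absent. This is how `Sort` recovers, for output position k, the input position whose rank equals k.

#include <cassert>

int find_int_nd_ref(const int* buf, int size, int num, int begin, int stride) {
  for (int i = (size - 1) * stride + begin; i >= begin; i -= stride) {
    if (buf[i] == num) return (i - begin) / stride;  // slice-relative position
  }
  return -1;
}

int main() {
  const int ranks[] = {2, 0, 1, 0};                   // a 4-element slice, begin = 0, stride = 1
  assert(find_int_nd_ref(ranks, 4, 0, 0, 1) == 3);    // backwards scan: last occurrence wins
  assert(find_int_nd_ref(ranks, 4, 5, 0, 1) == -1);   // absent value -> -1
  return 0;
}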
diff --git a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh
index 76d5c311f7..9bcdee96e1 100644
--- a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh
+++ b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh
@@ -206,18 +206,28 @@ __device__ inline bool cinn_block_reduce_any(const bool *buf, int offset, int extend) {
   return cinn_block_reduce_any_internal(tmp_val);
 }
 
-#define __cinn_cuda_find_kernel(buf, size, num) \
-  do {                                          \
-    for (int i = size - 1; i >= 0; --i) {       \
-      if (buf[i] == num) return i;              \
-    }                                           \
-    return -1;                                  \
+#define __cinn_cuda_find_kernel(buf, size, num, begin, stride)           \
+  do {                                                                   \
+    for (int i = (size - 1) * stride + begin; i >= begin; i -= stride) { \
+      if (buf[i] == num) return (i - begin) / stride;                    \
+    }                                                                    \
+    return -1;                                                           \
   } while (0)
 
-__device__ inline int cinn_cuda_find_int(const int *buf, int size, int num) { __cinn_cuda_find_kernel(buf, size, num); }
+__device__ inline int cinn_cuda_find_int(const int *buf, int size, int num) {
+  __cinn_cuda_find_kernel(buf, size, num, 0, 1);
+}
 
 __device__ inline int cinn_cuda_find_float(const float *buf, int size, float num) {
-  __cinn_cuda_find_kernel(buf, size, num);
+  __cinn_cuda_find_kernel(buf, size, num, 0, 1);
+}
+
+__device__ inline int cinn_cuda_find_int_nd(const int *buf, int size, int num, int begin, int stride) {
+  __cinn_cuda_find_kernel(buf, size, num, begin, stride);
+}
+
+__device__ inline int cinn_cuda_find_float_nd(const float *buf, int size, float num, int begin, int stride) {
+  __cinn_cuda_find_kernel(buf, size, num, begin, stride);
 }
 
 #undef __cinn_cuda_find_kernel
@@ -240,6 +250,48 @@ __device__ inline int cinn_cuda_find_float_from(const float *buf, int size, float num, int begin) {
 
 #undef __cinn_cuda_find_from_kernel
 
+#define __cinn_cuda_lt_num_kernel(buf, size, num, offset, stride)          \
+  do {                                                                     \
+    int out = 0;                                                           \
+    for (int i = (size - 1) * stride + offset; i >= offset; i -= stride) { \
+      if (buf[i] < num) out++;                                             \
+    }                                                                      \
+    return out;                                                            \
+  } while (0)
+
+__device__ inline int cinn_cuda_lt_num_float(
+    const float *buf, const int size, const float num, const int offset, const int stride) {
+  __cinn_cuda_lt_num_kernel(buf, size, num, offset, stride);
+}
+
+__device__ inline int cinn_cuda_lt_num_int(
+    const int *buf, const int size, const int num, const int offset, const int stride) {
+  __cinn_cuda_lt_num_kernel(buf, size, num, offset, stride);
+}
+
+#undef __cinn_cuda_lt_num_kernel
+
+#define __cinn_cuda_gt_num_kernel(buf, size, num, offset, stride)          \
+  do {                                                                     \
+    int out = 0;                                                           \
+    for (int i = (size - 1) * stride + offset; i >= offset; i -= stride) { \
+      if (buf[i] > num) out++;                                             \
+    }                                                                      \
+    return out;                                                            \
+  } while (0)
+
+__device__ inline int cinn_cuda_gt_num_float(
+    const float *buf, const int size, const float num, const int offset, const int stride) {
+  __cinn_cuda_gt_num_kernel(buf, size, num, offset, stride);
+}
+
+__device__ inline int cinn_cuda_gt_num_int(
+    const int *buf, const int size, const int num, const int offset, const int stride) {
+  __cinn_cuda_gt_num_kernel(buf, size, num, offset, stride);
+}
+
+#undef __cinn_cuda_gt_num_kernel
+
 __device__ inline float cinn_cuda_index_add(const float x,
                                             const int axis_indice,
                                             const float *__restrict__ y,
@@ -260,11 +312,11 @@ __device__ inline float cinn_cuda_index_add(const float x,
 
 #define block_shuffle_kernel(TYPE, name, op, init_value)                               \
   __device__ inline TYPE block_shuffle_##name(const TYPE *buf, int line, int stride) { \
-    TYPE val = init_value;                                                            \
-    for (int idx = threadIdx.x; idx < line; idx += stride) {                          \
-      val = op(val, buf[idx]);                                                        \
-    }                                                                                 \
-    return val;                                                                       \
+    TYPE val = init_value;                                                             \
+    for (int idx = threadIdx.x; idx < line; idx += stride) {                           \
+      val = op(val, buf[idx]);                                                         \
+    }                                                                                  \
+    return val;                                                                        \
   }
 
 block_shuffle_kernel(float, sum, cinn_sum, 0.0f);
diff --git a/cinn/runtime/cuda/cuda_intrinsics.cc b/cinn/runtime/cuda/cuda_intrinsics.cc
index da5cb594a1..77cacc4f6d 100644
--- a/cinn/runtime/cuda/cuda_intrinsics.cc
+++ b/cinn/runtime/cuda/cuda_intrinsics.cc
@@ -178,6 +178,24 @@ CINN_REGISTER_HELPER(cuda_intrinsics) {
       .AddInputType<float>()
       .End();
 
+  REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_cuda_find_int_nd, target)
+      .SetRetType<int>()
+      .AddInputType<cinn_buffer_t*>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .End();
+
+  REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_cuda_find_float_nd, target)
+      .SetRetType<int>()
+      .AddInputType<cinn_buffer_t*>()
+      .AddInputType<int>()
+      .AddInputType<float>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .End();
+
   REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_cuda_find_int_from, target)
       .SetRetType<int>()
       .AddInputType<cinn_buffer_t*>()
@@ -194,6 +212,42 @@ CINN_REGISTER_HELPER(cuda_intrinsics) {
       .AddInputType<int>()
       .End();
 
+  REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_cuda_lt_num_int, target)
+      .SetRetType<int>()
+      .AddInputType<cinn_buffer_t*>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .End();
+
+  REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_cuda_lt_num_float, target)
+      .SetRetType<int>()
+      .AddInputType<cinn_buffer_t*>()
+      .AddInputType<int>()
+      .AddInputType<float>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .End();
+
+  REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_cuda_gt_num_int, target)
+      .SetRetType<int>()
+      .AddInputType<cinn_buffer_t*>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .End();
+
+  REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_cuda_gt_num_float, target)
+      .SetRetType<int>()
+      .AddInputType<cinn_buffer_t*>()
+      .AddInputType<int>()
+      .AddInputType<float>()
+      .AddInputType<int>()
+      .AddInputType<int>()
+      .End();
+
   REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_cuda_index_add, target)
       .SetRetType<float>()
       .AddInputType<float>()
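For reference, a host-side sketch (illustrative only; helper names are mine and simplified to offset 0, stride 1) tying the intrinsics together the way Sort/ArgSort lower them: ranks via lt_num, then a gather via find. Note each output element costs one linear scan, so sorting a slice of length n costs quadratic extern-call work.

#include <cassert>

// Count of elements strictly less than num: the ascending-order rank.
int lt_num(const float* buf, int n, float num) {
  int out = 0;
  for (int k = 0; k < n; ++k) out += buf[k] < num;
  return out;
}

// Position of num in buf, scanning from the end, or -1 if absent.
int find(const int* buf, int n, int num) {
  for (int i = n - 1; i >= 0; --i) {
    if (buf[i] == num) return i;
  }
  return -1;
}

int main() {
  const int n      = 4;
  const float a[n] = {3.f, 1.f, 4.f, 2.f};
  int rank[n];
  for (int j = 0; j < n; ++j) rank[j] = lt_num(a, n, a[j]);     // argsort output: {2, 0, 3, 1}
  float sorted[n];
  for (int k = 0; k < n; ++k) sorted[k] = a[find(rank, n, k)];  // sort output via gather
  const float want[n] = {1.f, 2.f, 3.f, 4.f};
  for (int k = 0; k < n; ++k) assert(sorted[k] == want[k]);
  return 0;
}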