diff --git a/cinn/hlir/pe/add.h b/cinn/hlir/pe/add.h
new file mode 100644
index 0000000000000..a0bc1f666b372
--- /dev/null
+++ b/cinn/hlir/pe/add.h
@@ -0,0 +1,48 @@
+#pragma once
+#include <string>
+#include <vector>
+#include "cinn/common/common.h"
+#include "cinn/ir/ir.h"
+#include "cinn/ir/node.h"
+#include "cinn/lang/compute.h"
+#include "cinn/lang/tensor.h"
+
+// NOTE(review): these using-declarations are at global scope in a header, so every includer inherits the
+// names. They are kept to leave the header's visible names unchanged; consider scoping them to cinn::hlir::pe.
+using cinn::ir::Expr;
+using cinn::ir::Tensor;
+using cinn::ir::Var;
+using cinn::lang::Compute;
+
+namespace cinn {
+namespace hlir {
+namespace pe {
+
+/**
+ * @brief Element-wise addition of two tensors with identical shapes.
+ *
+ * @param A First operand.
+ * @param B Second operand; must satisfy A->SameShapeWith(B).
+ * @param output_name Name forwarded to Compute for the resulting tensor.
+ * @return Tensor built by Compute over A's shape whose value at each index is A(index) + B(index).
+ *
+ * CHECK-fails if the shapes differ or if the number of dimensions is outside [1, 4].
+ * Declared inline because the definition lives in a header that several translation units may include.
+ */
+inline Tensor Add(const Tensor &A, const Tensor &B, const std::string &output_name = "") {
+  CHECK(A->SameShapeWith(B)) << "The 2 inputs have different shapes with each other. "
+                                "The Add function needs two inputs to have identical shape.";
+  const std::vector<Expr> &output_shape = A->shape;
+  CHECK_GE(output_shape.size(), 1U) << "The input shape of pe::Add function is " << output_shape.size()
+                                    << " and it should be >= 1.";
+  CHECK_LE(output_shape.size(), 4U) << "The input shape of pe::Add function is " << output_shape.size()
+                                    << " and it should be <= 4.";
+
+  // The callback captures A and B by reference. NOTE(review): assumes Compute does not keep it past this call.
+  return Compute(
+      output_shape, [&](const std::vector<Expr> &indice) { return A(indice) + B(indice); }, output_name);
+}
+
+}  // namespace pe
+}  // namespace hlir
+}  // namespace cinn
diff --git a/tests/test01_elementwise_add_main.cc b/tests/test01_elementwise_add_main.cc
index 3981422c71efe..c78ffe4cc90eb 100644
--- a/tests/test01_elementwise_add_main.cc
+++ b/tests/test01_elementwise_add_main.cc
@@ -2,8 +2,8 @@
 
 #include "cinn/cinn.h"
 #include "cinn/common/ir_util.h"
+#include "cinn/hlir/pe/add.h"
 #include "cinn/optim/optimize.h"
-
 namespace cinn {
 
 TEST(test01_elementwise_add, basic) {
@@ -12,8 +12,9 @@ TEST(test01_elementwise_add, basic) {
   Placeholder A("A", {M, N});
   Placeholder B("B", {M, N});
 
-  auto C = Compute(
-      {M, N}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C");
+  Buffer C_buf(Float(32));
+  auto C = hlir::pe::Add(A, B, "C");
+  C->Bind(C_buf);
 
   Target target;
   target.arch = Target::Arch ::X86;
@@ -38,8 +39,7 @@ TEST(test01_elementwise_add, vectorize) {
   Placeholder A("A", {M, N});
   Placeholder B("B", {M, N});
 
-  auto C = Compute(
-      {M, N}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C");
+  auto C = hlir::pe::Add(A, B, "C");
   C->stage()->Vectorize(1, 8);
 
   Target target;