Skip to content

Commit

Permalink
Add primitive layer. add function primitive::add (PaddlePaddle#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
haozech committed Jul 25, 2020
1 parent 96ca5bd commit cdb90c1
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
35 changes: 35 additions & 0 deletions cinn/hlir/pe/add.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once
#include <string>
#include <vector>
#include "cinn/common/common.h"
#include "cinn/ir/ir.h"
#include "cinn/ir/node.h"
#include "cinn/lang/compute.h"
#include "cinn/lang/tensor.h"

using cinn::ir::Expr;
using cinn::ir::Tensor;
using cinn::ir::Var;
using cinn::lang::Compute;

namespace cinn {
namespace hlir {
namespace pe {

Tensor Add(const Tensor &A, const Tensor &B, const std::string output_name = "") {
CHECK(A->SameShapeWith(B)) << "The 2 inputs have different shapes with each other. "
"The Add fucntion needs two inputs to have identical shape.";
const std::vector<Expr> output_shape = A->shape;
CHECK_GE(output_shape.size(), 1) << "The input shape of pe::Add function is " << output_shape.size()
<< " and it should be >= 1.";
CHECK_LE(output_shape.size(), 4) << "The input shape of pe::Add function is " << output_shape.size()
<< " and it should be <= 4.";

Tensor output = Compute(
output_shape, [&](const std::vector<Expr> &indice) { return A(indice) + B(indice); }, output_name);
return output;
}

} // namespace pe
} // namespace hlir
} // namespace cinn
10 changes: 5 additions & 5 deletions tests/test01_elementwise_add_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

#include "cinn/cinn.h"
#include "cinn/common/ir_util.h"
#include "cinn/hlir/pe/add.h"
#include "cinn/optim/optimize.h"

namespace cinn {

TEST(test01_elementwise_add, basic) {
Expand All @@ -12,8 +12,9 @@ TEST(test01_elementwise_add, basic) {
Placeholder<float> A("A", {M, N});
Placeholder<float> B("B", {M, N});

auto C = Compute(
{M, N}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C");
Buffer C_buf(Float(32));
auto C = hlir::pe::Add(A, B, "C");
C->Bind(C_buf);

Target target;
target.arch = Target::Arch ::X86;
Expand All @@ -38,8 +39,7 @@ TEST(test01_elementwise_add, vectorize) {
Placeholder<float> A("A", {M, N});
Placeholder<float> B("B", {M, N});

auto C = Compute(
{M, N}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C");
auto C = hlir::pe::Add(A, B, "C");
C->stage()->Vectorize(1, 8);

Target target;
Expand Down

0 comments on commit cdb90c1

Please sign in to comment.