From 1a3466212793a505b8e7cb091281ab0c709a9e35 Mon Sep 17 00:00:00 2001 From: Superjomn Date: Thu, 27 Feb 2020 14:53:12 +0800 Subject: [PATCH] support CodeGenC for LoweredFunc --- cinn/backends/CMakeLists.txt | 3 + cinn/backends/codegen_c.cc | 97 ++++++++++++++++++++++----------- cinn/backends/codegen_c.h | 3 +- cinn/backends/codegen_c_test.cc | 37 +++++++++++++ cinn/ir/ir_printer.cc | 16 +----- cinn/ir/ir_printer.h | 17 +++--- cinn/ir/ir_visitor.h | 1 + cinn/ir/lowered_func.h | 2 + cinn/lang/compute.cc | 6 +- cinn/lang/lower_test.cc | 5 +- cinn/lang/placeholder.h | 36 ++++++++---- 11 files changed, 150 insertions(+), 73 deletions(-) create mode 100644 cinn/backends/codegen_c_test.cc diff --git a/cinn/backends/CMakeLists.txt b/cinn/backends/CMakeLists.txt index ee890a007bdfa..dfcd247dc3d6b 100644 --- a/cinn/backends/CMakeLists.txt +++ b/cinn/backends/CMakeLists.txt @@ -5,3 +5,6 @@ foreach(cpp ${srcs}) "${core_src};cinn/backends/${cpp}" CACHE INTERNAL "") endforeach() + + +cc_test(test_codegen_c SRCS codegen_c_test.cc DEPS core) diff --git a/cinn/backends/codegen_c.cc b/cinn/backends/codegen_c.cc index 94b7efd9baf2a..41ea279e47671 100644 --- a/cinn/backends/codegen_c.cc +++ b/cinn/backends/codegen_c.cc @@ -1,6 +1,7 @@ #include "cinn/backends/codegen_c.h" #include "cinn/ir/lowered_func.h" +#include "cinn/utils/string.h" namespace cinn { namespace backends { @@ -8,42 +9,31 @@ namespace backends { CodeGenC::CodeGenC(std::ostream &os, Target target) : ir::IrPrinter(os), target_(target) {} void CodeGenC::Compile(const lang::Module &module) {} -void CodeGenC::Compile(const ir::LoweredFunc &function) { - os() << "void " << function->name; - - // output arguments - os() << "("; - - auto print_arg = [&](const ir::Argument &arg) { - if (arg.is_buffer()) { - os() << "struct cinn_buffer_t *"; - } else if (arg.is_scalar()) { - os() << PrintType(arg.type) << " "; - os() << arg.name; - } - os() << arg.name; - }; - - for (int i = 0; i < function->args.size() - 1; i++) { - print_arg(function->args[i]); - os() << ", "; +void CodeGenC::Compile(const ir::LoweredFunc &function) { Print(function); } +void CodeGenC::Compile(const ir::Buffer &buffer) {} +std::string CodeGenC::PrintType(Type type) { + if (type == Int(8)) { + return "int8_t"; } - if (function->args.size() >= 1) { - print_arg(function->args.back()); + if (type == Int(32)) { + return "int32_t"; + } + if (type == Int(64)) { + return "int64_t"; + } + if (type == Bool()) { + return "bool"; + } + if (type == Float(32)) { + return "float"; + } + if (type == Float(64)) { + return "double"; } - os() << ")"; - - DoIndent(); - os() << "{\n"; - - Print(function->body); - - DoIndent(); - os() << "}"; + LOG(ERROR) << type; + NOT_IMPLEMENTED } -void CodeGenC::Compile(const ir::Buffer &buffer) {} -std::string CodeGenC::PrintType(Type type) { return std::__cxx11::string(); } void CodeGenC::Visit(const ir::IntImm *op) { IrPrinter::Visit(op); } void CodeGenC::Visit(const ir::UIntImm *op) { IrPrinter::Visit(op); } void CodeGenC::Visit(const ir::FloatImm *op) { IrPrinter::Visit(op); } @@ -72,10 +62,16 @@ void CodeGenC::Visit(const ir::Cast *op) { PrintCastExpr(op->type(), op->v); } void CodeGenC::Visit(const ir::For *op) { LOG(FATAL) << "Not Implemented"; } void CodeGenC::Visit(const ir::PolyFor *op) { os() << "for ("; + os() << PrintType(Int(32)); + os() << " " << op->iterator->name; + os() << " = "; Print(op->init); os() << "; "; Print(op->condition); os() << "; "; + + os() << op->iterator->name; + os() << " += "; Print(op->inc); os() << ")"; @@ -141,5 +137,42 @@ void CodeGenC::PrintCastExpr(const Type &type, Expr e) { os() << ")"; } +void CodeGenC::Visit(const ir::_LoweredFunc_ *op) { + os() << "void " << op->name; + + // output arguments + os() << "("; + + auto print_arg = [&](const ir::Argument &arg) { + if (arg.is_buffer()) { + os() << "struct cinn_buffer_t *"; + } else if (arg.is_scalar()) { + os() << PrintType(arg.type) << " "; + os() << arg.name; + } else { + NOT_IMPLEMENTED + } + os() << arg.name; + }; + + for (int i = 0; i < op->args.size() - 1; i++) { + print_arg(op->args[i]); + os() << ", "; + } + if (op->args.size() >= 1) { + print_arg(op->args.back()); + } + + os() << ")"; + + DoIndent(); + os() << "{\n"; + + Print(op->body); + + DoIndent(); + os() << "}"; +} + } // namespace backends } // namespace cinn diff --git a/cinn/backends/codegen_c.h b/cinn/backends/codegen_c.h index 0b5041c4ecf50..0aae505004d4c 100644 --- a/cinn/backends/codegen_c.h +++ b/cinn/backends/codegen_c.h @@ -23,11 +23,10 @@ class CodeGenC : public ir::IrPrinter { CodeGenC(std::ostream& os, Target target); void Compile(const lang::Module& module); - - protected: void Compile(const ir::LoweredFunc& function); void Compile(const ir::Buffer& buffer); + protected: std::string PrintType(Type type); void PrintCastExpr(const Type& type, Expr e); diff --git a/cinn/backends/codegen_c_test.cc b/cinn/backends/codegen_c_test.cc new file mode 100644 index 0000000000000..b772f18fcf2c3 --- /dev/null +++ b/cinn/backends/codegen_c_test.cc @@ -0,0 +1,37 @@ +#include "cinn/backends/codegen_c.h" + +#include + +#include + +#include "cinn/lang/compute.h" +#include "cinn/lang/lower.h" +#include "cinn/lang/placeholder.h" + +namespace cinn { +namespace backends { + +TEST(CodeGenC, basic) { + std::stringstream ss; + Target target; + CodeGenC codegen(ss, target); + + lang::Placeholder A("A", {100, 20}); + lang::Placeholder B("B", {100, 20}); + + lang::Buffer C_buf; + auto C = lang::Compute({100, 20}, [&](Var i, Var j) { return A(i, j) + B(i, j); }); + C->Bind(C_buf); + + auto funcs = lang::Lower("func_C", {A, B, C}); + ASSERT_EQ(funcs.size(), 1UL); + + codegen.Compile(funcs.front()); + + auto out = ss.str(); + + std::cout << "codegen C:" << std::endl << out << std::endl; +} + +} // namespace backends +} // namespace cinn diff --git a/cinn/ir/ir_printer.cc b/cinn/ir/ir_printer.cc index cc1fbc045531b..920625626f2ff 100644 --- a/cinn/ir/ir_printer.cc +++ b/cinn/ir/ir_printer.cc @@ -10,7 +10,7 @@ namespace cinn { namespace ir { -void IrPrinter::Print(Expr e) { e.Accept(reinterpret_cast(this)); } +void IrPrinter::Print(Expr e) { IRVisitor::Visit(&e); } void IrPrinter::Print(const std::vector &exprs, const std::string &splitter) { for (int i = 0; i < exprs.size() - 1; i++) { Print(exprs[i]); @@ -183,19 +183,9 @@ void IrPrinter::Visit(const _LoweredFunc_ *f) { for (auto &arg : f->args) { arg_names.push_back(arg.name); } - os_ << "(" << utils::Join(arg_names, ", "); - - DoIndent(); - os_ << "{"; - - IncIndent(); + os_ << "(" << utils::Join(arg_names, ", ") << ")\n"; Print(f->body); - - DecIndent(); - - DoIndent(); - os_ << "}"; } std::ostream &operator<<(std::ostream &os, Expr a) { std::stringstream ss; @@ -205,8 +195,6 @@ std::ostream &operator<<(std::ostream &os, Expr a) { return os; } -std::ostream &operator<<(std::ostream &os, const ir::LoweredFunc &f) {} - std::ostream &operator<<(std::ostream &os, const lang::Module &m); } // namespace ir diff --git a/cinn/ir/ir_printer.h b/cinn/ir/ir_printer.h index 44eb68d885d18..e00f56d411870 100644 --- a/cinn/ir/ir_printer.h +++ b/cinn/ir/ir_printer.h @@ -24,13 +24,7 @@ struct IrPrinter : public IRVisitor { void Print(const std::vector &exprs, const std::string &splitter = ", "); //! Emit a binary operator template - void PrintBinaryOp(const std::string &op, const BinaryOpNode *x) { - os_ << "("; - Print(x->a); - os_ << " " + op + " "; - Print(x->b); - os_ << ")"; - } + void PrintBinaryOp(const std::string &op, const BinaryOpNode *x); //! Prefix the current line with `indent_` spaces. void DoIndent(); @@ -89,5 +83,14 @@ struct IrPrinter : public IRVisitor { std::ostream &operator<<(std::ostream &os, Expr a); std::ostream &operator<<(std::ostream &os, const lang::Module &m); +template +void IrPrinter::PrintBinaryOp(const std::string &op, const BinaryOpNode *x) { + os_ << "("; + Print(x->a); + os_ << " " + op + " "; + Print(x->b); + os_ << ")"; +} + } // namespace ir } // namespace cinn diff --git a/cinn/ir/ir_visitor.h b/cinn/ir/ir_visitor.h index 6cb345478db84..d5c5f0562c6c3 100644 --- a/cinn/ir/ir_visitor.h +++ b/cinn/ir/ir_visitor.h @@ -49,6 +49,7 @@ struct IRVisitorBase { struct IRVisitor : public IRVisitorBase { IRVisitor() = default; + void Visit(const Expr* x) { IRVisitorBase::Visit(x); } #define __m(t__) virtual void Visit(const t__* x) = 0; NODETY_FORALL(__m) #undef __m diff --git a/cinn/ir/lowered_func.h b/cinn/ir/lowered_func.h index 845a658249b24..7c22833d6d9a6 100644 --- a/cinn/ir/lowered_func.h +++ b/cinn/ir/lowered_func.h @@ -39,6 +39,8 @@ class LoweredFunc : public IrNodeRef { LoweredFunc() = default; explicit LoweredFunc(IrNode* n) : IrNodeRef(n) {} + operator Expr() const { return Expr(ptr()); } + const _LoweredFunc_* operator->() const; _LoweredFunc_* operator->(); }; diff --git a/cinn/lang/compute.cc b/cinn/lang/compute.cc index 7e1ab4e3fc015..e8491b20d5886 100644 --- a/cinn/lang/compute.cc +++ b/cinn/lang/compute.cc @@ -69,8 +69,10 @@ ir::Tensor Compute(const std::vector &dims, std::vector shape; for (int v : dims) shape.emplace_back(v); - auto op = ir::ComputeOp::Make(name, "" /*tag*/, {}, axis, {expr}, shape); - return ir::_Tensor_::Make(name, shape, op); + auto unique_name = name.empty() ? Context::Global().NewName("tensor") : name; + + auto op = ir::ComputeOp::Make(unique_name, "" /*tag*/, {}, axis, {expr}, shape); + return ir::_Tensor_::Make(unique_name, shape, op); } } // namespace lang diff --git a/cinn/lang/lower_test.cc b/cinn/lang/lower_test.cc index 212ca81e54b2a..4602b063f527f 100644 --- a/cinn/lang/lower_test.cc +++ b/cinn/lang/lower_test.cc @@ -61,10 +61,7 @@ TEST(lower, more_complex) { auto lower_funcs = Lower("cal_C", {A, B, C}); LOG(INFO) << "lower_size " << lower_funcs.size(); - -#define TEST_SOUTPUT(x, out) \ - LOG(INFO) << "\n" << x; \ - EXPECT_EQ(utils::GetStreamCnt(x), utils::Trim(out)); + LOG(INFO) << "func:\n" << Expr(lower_funcs.front()->self()); } } // namespace lang diff --git a/cinn/lang/placeholder.h b/cinn/lang/placeholder.h index 9e32c49be8074..4f551be1a2fcf 100644 --- a/cinn/lang/placeholder.h +++ b/cinn/lang/placeholder.h @@ -22,6 +22,7 @@ using ir::Expr; template class Placeholder { public: + Placeholder(const std::string &name, const std::vector &shape); Placeholder(const std::string &name, const std::vector &shape); //! Get a slice. @@ -36,6 +37,21 @@ class Placeholder { operator ir::Tensor() { return tensor_; } private: + void Init(const std::string &name, const std::vector &shape) { + ir::Var buffer_ptr(Context::Global().NewName("buffer")); + buffer_ptr->set_type(type_of()); + + std::vector strides(shape.size(), Expr(1)); + Expr offset(0); + + std::vector axis; + for (int i = 0; i < shape.size(); i++) axis.emplace_back(common::axis_name(i)); + + auto op = ir::PlaceholderOp::Make(name, shape, type_of()); + + tensor_ = ir::_Tensor_::Make(name, shape, op); + } + ir::Tensor tensor_; }; @@ -45,19 +61,15 @@ Expr Placeholder::operator()(const std::vector &indices) const { } template -Placeholder::Placeholder(const std::string &name, const std::vector &shape) { - ir::Var buffer_ptr(Context::Global().NewName("buffer")); - buffer_ptr->set_type(type_of()); - - std::vector strides(shape.size(), Expr(1)); - Expr offset(0); - - std::vector axis; - for (int i = 0; i < shape.size(); i++) axis.emplace_back(common::axis_name(i)); - - auto op = ir::PlaceholderOp::Make(name, shape, type_of()); +Placeholder::Placeholder(const std::string &name, const std::vector &shape) { + std::vector _shape; + for (int v : shape) _shape.push_back(Expr(v)); + Init(name, _shape); +} - tensor_ = ir::_Tensor_::Make(name, shape, op); +template +Placeholder::Placeholder(const std::string &name, const std::vector &shape) { + Init(name, shape); } } // namespace lang