Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#49 from Superjomn/fea/compile_c_lower…
Browse files Browse the repository at this point in the history
…ed_func

support CodeGenC for LoweredFunc
  • Loading branch information
Superjomn committed Feb 27, 2020
2 parents 18a20ce + 1a34662 commit 198ad34
Show file tree
Hide file tree
Showing 11 changed files with 150 additions and 73 deletions.
3 changes: 3 additions & 0 deletions cinn/backends/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ foreach(cpp ${srcs})
"${core_src};cinn/backends/${cpp}"
CACHE INTERNAL "")
endforeach()


cc_test(test_codegen_c SRCS codegen_c_test.cc DEPS core)
97 changes: 65 additions & 32 deletions cinn/backends/codegen_c.cc
Original file line number Diff line number Diff line change
@@ -1,49 +1,39 @@
#include "cinn/backends/codegen_c.h"

#include "cinn/ir/lowered_func.h"
#include "cinn/utils/string.h"

namespace cinn {
namespace backends {

CodeGenC::CodeGenC(std::ostream &os, Target target) : ir::IrPrinter(os), target_(target) {}

void CodeGenC::Compile(const lang::Module &module) {}
void CodeGenC::Compile(const ir::LoweredFunc &function) {
os() << "void " << function->name;

// output arguments
os() << "(";

auto print_arg = [&](const ir::Argument &arg) {
if (arg.is_buffer()) {
os() << "struct cinn_buffer_t *";
} else if (arg.is_scalar()) {
os() << PrintType(arg.type) << " ";
os() << arg.name;
}
os() << arg.name;
};

for (int i = 0; i < function->args.size() - 1; i++) {
print_arg(function->args[i]);
os() << ", ";
void CodeGenC::Compile(const ir::LoweredFunc &function) { Print(function); }
void CodeGenC::Compile(const ir::Buffer &buffer) {}
std::string CodeGenC::PrintType(Type type) {
if (type == Int(8)) {
return "int8_t";
}
if (function->args.size() >= 1) {
print_arg(function->args.back());
if (type == Int(32)) {
return "int32_t";
}
if (type == Int(64)) {
return "int64_t";
}
if (type == Bool()) {
return "bool";
}
if (type == Float(32)) {
return "float";
}
if (type == Float(64)) {
return "double";
}

os() << ")";

DoIndent();
os() << "{\n";

Print(function->body);

DoIndent();
os() << "}";
LOG(ERROR) << type;
NOT_IMPLEMENTED
}
void CodeGenC::Compile(const ir::Buffer &buffer) {}
std::string CodeGenC::PrintType(Type type) { return std::__cxx11::string(); }
void CodeGenC::Visit(const ir::IntImm *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::UIntImm *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::FloatImm *op) { IrPrinter::Visit(op); }
Expand Down Expand Up @@ -72,10 +62,16 @@ void CodeGenC::Visit(const ir::Cast *op) { PrintCastExpr(op->type(), op->v); }
void CodeGenC::Visit(const ir::For *op) { LOG(FATAL) << "Not Implemented"; }
void CodeGenC::Visit(const ir::PolyFor *op) {
os() << "for (";
os() << PrintType(Int(32));
os() << " " << op->iterator->name;
os() << " = ";
Print(op->init);
os() << "; ";
Print(op->condition);
os() << "; ";

os() << op->iterator->name;
os() << " += ";
Print(op->inc);
os() << ")";

Expand Down Expand Up @@ -141,5 +137,42 @@ void CodeGenC::PrintCastExpr(const Type &type, Expr e) {
os() << ")";
}

void CodeGenC::Visit(const ir::_LoweredFunc_ *op) {
os() << "void " << op->name;

// output arguments
os() << "(";

auto print_arg = [&](const ir::Argument &arg) {
if (arg.is_buffer()) {
os() << "struct cinn_buffer_t *";
} else if (arg.is_scalar()) {
os() << PrintType(arg.type) << " ";
os() << arg.name;
} else {
NOT_IMPLEMENTED
}
os() << arg.name;
};

for (int i = 0; i < op->args.size() - 1; i++) {
print_arg(op->args[i]);
os() << ", ";
}
if (op->args.size() >= 1) {
print_arg(op->args.back());
}

os() << ")";

DoIndent();
os() << "{\n";

Print(op->body);

DoIndent();
os() << "}";
}

} // namespace backends
} // namespace cinn
3 changes: 1 addition & 2 deletions cinn/backends/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ class CodeGenC : public ir::IrPrinter {
CodeGenC(std::ostream& os, Target target);

void Compile(const lang::Module& module);

protected:
void Compile(const ir::LoweredFunc& function);
void Compile(const ir::Buffer& buffer);

protected:
std::string PrintType(Type type);
void PrintCastExpr(const Type& type, Expr e);

Expand Down
37 changes: 37 additions & 0 deletions cinn/backends/codegen_c_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#include "cinn/backends/codegen_c.h"

#include <gtest/gtest.h>

#include <sstream>

#include "cinn/lang/compute.h"
#include "cinn/lang/lower.h"
#include "cinn/lang/placeholder.h"

namespace cinn {
namespace backends {

TEST(CodeGenC, basic) {
std::stringstream ss;
Target target;
CodeGenC codegen(ss, target);

lang::Placeholder<float> A("A", {100, 20});
lang::Placeholder<float> B("B", {100, 20});

lang::Buffer C_buf;
auto C = lang::Compute({100, 20}, [&](Var i, Var j) { return A(i, j) + B(i, j); });
C->Bind(C_buf);

auto funcs = lang::Lower("func_C", {A, B, C});
ASSERT_EQ(funcs.size(), 1UL);

codegen.Compile(funcs.front());

auto out = ss.str();

std::cout << "codegen C:" << std::endl << out << std::endl;
}

} // namespace backends
} // namespace cinn
16 changes: 2 additions & 14 deletions cinn/ir/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
namespace cinn {
namespace ir {

void IrPrinter::Print(Expr e) { e.Accept(reinterpret_cast<IRVisitor *>(this)); }
void IrPrinter::Print(Expr e) { IRVisitor::Visit(&e); }
void IrPrinter::Print(const std::vector<Expr> &exprs, const std::string &splitter) {
for (int i = 0; i < exprs.size() - 1; i++) {
Print(exprs[i]);
Expand Down Expand Up @@ -183,19 +183,9 @@ void IrPrinter::Visit(const _LoweredFunc_ *f) {
for (auto &arg : f->args) {
arg_names.push_back(arg.name);
}
os_ << "(" << utils::Join(arg_names, ", ");

DoIndent();
os_ << "{";

IncIndent();
os_ << "(" << utils::Join(arg_names, ", ") << ")\n";

Print(f->body);

DecIndent();

DoIndent();
os_ << "}";
}
std::ostream &operator<<(std::ostream &os, Expr a) {
std::stringstream ss;
Expand All @@ -205,8 +195,6 @@ std::ostream &operator<<(std::ostream &os, Expr a) {
return os;
}

std::ostream &operator<<(std::ostream &os, const ir::LoweredFunc &f) {}

std::ostream &operator<<(std::ostream &os, const lang::Module &m);

} // namespace ir
Expand Down
17 changes: 10 additions & 7 deletions cinn/ir/ir_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,7 @@ struct IrPrinter : public IRVisitor {
void Print(const std::vector<Expr> &exprs, const std::string &splitter = ", ");
//! Emit a binary operator
template <typename IRN>
void PrintBinaryOp(const std::string &op, const BinaryOpNode<IRN> *x) {
os_ << "(";
Print(x->a);
os_ << " " + op + " ";
Print(x->b);
os_ << ")";
}
void PrintBinaryOp(const std::string &op, const BinaryOpNode<IRN> *x);

//! Prefix the current line with `indent_` spaces.
void DoIndent();
Expand Down Expand Up @@ -89,5 +83,14 @@ struct IrPrinter : public IRVisitor {
std::ostream &operator<<(std::ostream &os, Expr a);
std::ostream &operator<<(std::ostream &os, const lang::Module &m);

template <typename IRN>
void IrPrinter::PrintBinaryOp(const std::string &op, const BinaryOpNode<IRN> *x) {
os_ << "(";
Print(x->a);
os_ << " " + op + " ";
Print(x->b);
os_ << ")";
}

} // namespace ir
} // namespace cinn
1 change: 1 addition & 0 deletions cinn/ir/ir_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ struct IRVisitorBase {
struct IRVisitor : public IRVisitorBase<void> {
IRVisitor() = default;

void Visit(const Expr* x) { IRVisitorBase::Visit(x); }
#define __m(t__) virtual void Visit(const t__* x) = 0;
NODETY_FORALL(__m)
#undef __m
Expand Down
2 changes: 2 additions & 0 deletions cinn/ir/lowered_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class LoweredFunc : public IrNodeRef {
LoweredFunc() = default;
explicit LoweredFunc(IrNode* n) : IrNodeRef(n) {}

operator Expr() const { return Expr(ptr()); }

const _LoweredFunc_* operator->() const;
_LoweredFunc_* operator->();
};
Expand Down
6 changes: 4 additions & 2 deletions cinn/lang/compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,10 @@ ir::Tensor Compute(const std::vector<int> &dims,
std::vector<Expr> shape;
for (int v : dims) shape.emplace_back(v);

auto op = ir::ComputeOp::Make(name, "" /*tag*/, {}, axis, {expr}, shape);
return ir::_Tensor_::Make(name, shape, op);
auto unique_name = name.empty() ? Context::Global().NewName("tensor") : name;

auto op = ir::ComputeOp::Make(unique_name, "" /*tag*/, {}, axis, {expr}, shape);
return ir::_Tensor_::Make(unique_name, shape, op);
}

} // namespace lang
Expand Down
5 changes: 1 addition & 4 deletions cinn/lang/lower_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,7 @@ TEST(lower, more_complex) {
auto lower_funcs = Lower("cal_C", {A, B, C});

LOG(INFO) << "lower_size " << lower_funcs.size();

#define TEST_SOUTPUT(x, out) \
LOG(INFO) << "\n" << x; \
EXPECT_EQ(utils::GetStreamCnt(x), utils::Trim(out));
LOG(INFO) << "func:\n" << Expr(lower_funcs.front()->self());
}

} // namespace lang
Expand Down
36 changes: 24 additions & 12 deletions cinn/lang/placeholder.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ using ir::Expr;
template <typename T>
class Placeholder {
public:
Placeholder(const std::string &name, const std::vector<int> &shape);
Placeholder(const std::string &name, const std::vector<Expr> &shape);

//! Get a slice.
Expand All @@ -36,6 +37,21 @@ class Placeholder {
operator ir::Tensor() { return tensor_; }

private:
void Init(const std::string &name, const std::vector<Expr> &shape) {
ir::Var buffer_ptr(Context::Global().NewName("buffer"));
buffer_ptr->set_type(type_of<T>());

std::vector<Expr> strides(shape.size(), Expr(1));
Expr offset(0);

std::vector<ir::Var> axis;
for (int i = 0; i < shape.size(); i++) axis.emplace_back(common::axis_name(i));

auto op = ir::PlaceholderOp::Make(name, shape, type_of<T>());

tensor_ = ir::_Tensor_::Make(name, shape, op);
}

ir::Tensor tensor_;
};

Expand All @@ -45,19 +61,15 @@ Expr Placeholder<T>::operator()(const std::vector<Expr> &indices) const {
}

template <typename T>
Placeholder<T>::Placeholder(const std::string &name, const std::vector<Expr> &shape) {
ir::Var buffer_ptr(Context::Global().NewName("buffer"));
buffer_ptr->set_type(type_of<T>());

std::vector<Expr> strides(shape.size(), Expr(1));
Expr offset(0);

std::vector<ir::Var> axis;
for (int i = 0; i < shape.size(); i++) axis.emplace_back(common::axis_name(i));

auto op = ir::PlaceholderOp::Make(name, shape, type_of<T>());
Placeholder<T>::Placeholder(const std::string &name, const std::vector<int> &shape) {
std::vector<Expr> _shape;
for (int v : shape) _shape.push_back(Expr(v));
Init(name, _shape);
}

tensor_ = ir::_Tensor_::Make(name, shape, op);
template <typename T>
Placeholder<T>::Placeholder(const std::string &name, const std::vector<Expr> &shape) {
Init(name, shape);
}

} // namespace lang
Expand Down

0 comments on commit 198ad34

Please sign in to comment.