diff --git a/.gitignore b/.gitignore index 9b07e4811bec5..9ee75776ba362 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ cmake-build* build* .idea* +*.html \ No newline at end of file diff --git a/cinn/backends/codegen_c.cc b/cinn/backends/codegen_c.cc index bae707d9bd911..94b7efd9baf2a 100644 --- a/cinn/backends/codegen_c.cc +++ b/cinn/backends/codegen_c.cc @@ -1,18 +1,20 @@ #include "cinn/backends/codegen_c.h" +#include "cinn/ir/lowered_func.h" + namespace cinn { namespace backends { CodeGenC::CodeGenC(std::ostream &os, Target target) : ir::IrPrinter(os), target_(target) {} void CodeGenC::Compile(const lang::Module &module) {} -void CodeGenC::Compile(const lang::LoweredFunc &function) { - os() << "void " << function.name; +void CodeGenC::Compile(const ir::LoweredFunc &function) { + os() << "void " << function->name; // output arguments os() << "("; - auto print_arg = [&](const lang::Argument &arg) { + auto print_arg = [&](const ir::Argument &arg) { if (arg.is_buffer()) { os() << "struct cinn_buffer_t *"; } else if (arg.is_scalar()) { @@ -22,12 +24,12 @@ void CodeGenC::Compile(const lang::LoweredFunc &function) { os() << arg.name; }; - for (int i = 0; i < function.args.size() - 1; i++) { - print_arg(function.args[i]); + for (int i = 0; i < function->args.size() - 1; i++) { + print_arg(function->args[i]); os() << ", "; } - if (function.args.size() >= 1) { - print_arg(function.args.back()); + if (function->args.size() >= 1) { + print_arg(function->args.back()); } os() << ")"; @@ -35,7 +37,7 @@ void CodeGenC::Compile(const lang::LoweredFunc &function) { DoIndent(); os() << "{\n"; - Print(function.body); + Print(function->body); DoIndent(); os() << "}"; diff --git a/cinn/backends/codegen_c.h b/cinn/backends/codegen_c.h index 9fe240b086281..0b5041c4ecf50 100644 --- a/cinn/backends/codegen_c.h +++ b/cinn/backends/codegen_c.h @@ -7,6 +7,7 @@ #include "cinn/ir/function.h" #include "cinn/ir/ir.h" #include "cinn/ir/ir_printer.h" +#include "cinn/ir/lowered_func.h" #include "cinn/lang/module.h" namespace cinn { @@ -24,7 +25,7 @@ class CodeGenC : public ir::IrPrinter { void Compile(const lang::Module& module); protected: - void Compile(const lang::LoweredFunc& function); + void Compile(const ir::LoweredFunc& function); void Compile(const ir::Buffer& buffer); std::string PrintType(Type type); diff --git a/cinn/ir/CMakeLists.txt b/cinn/ir/CMakeLists.txt index 031dccd32ad05..055443a5f937a 100644 --- a/cinn/ir/CMakeLists.txt +++ b/cinn/ir/CMakeLists.txt @@ -7,6 +7,7 @@ set(srcs ir_mutator.cc function.cc function_definition.cc + lowered_func.cc ir_operators.cc buffer.cc function_base.cc diff --git a/cinn/ir/ir_mutator.cc b/cinn/ir/ir_mutator.cc index 9eb9c4e988a70..12939cdc758e4 100644 --- a/cinn/ir/ir_mutator.cc +++ b/cinn/ir/ir_mutator.cc @@ -114,5 +114,10 @@ void IRMutator::Visit(const _Tensor_ *expr, Expr *op) { } } +void IRMutator::Visit(const _LoweredFunc_ *expr, Expr *op) { + auto *node = op->As<_LoweredFunc_>(); + IRVisitorBase::Visit(&node->body, &node->body); +} + } // namespace ir } // namespace cinn diff --git a/cinn/ir/ir_printer.cc b/cinn/ir/ir_printer.cc index ec540751ffe22..cc1fbc045531b 100644 --- a/cinn/ir/ir_printer.cc +++ b/cinn/ir/ir_printer.cc @@ -2,7 +2,10 @@ #include +#include "cinn/ir/lowered_func.h" +#include "cinn/lang/module.h" #include "cinn/lang/tensor.h" +#include "cinn/utils/string.h" namespace cinn { namespace ir { @@ -173,6 +176,27 @@ void IrPrinter::Visit(const _Tensor_ *x) { } os_ << ")"; } +void IrPrinter::Visit(const _LoweredFunc_ *f) { + os_ << "function " << f->name << " "; + + std::vector arg_names; + for (auto &arg : f->args) { + arg_names.push_back(arg.name); + } + os_ << "(" << utils::Join(arg_names, ", "); + + DoIndent(); + os_ << "{"; + + IncIndent(); + + Print(f->body); + + DecIndent(); + + DoIndent(); + os_ << "}"; +} std::ostream &operator<<(std::ostream &os, Expr a) { std::stringstream ss; IrPrinter printer(ss); @@ -181,5 +205,9 @@ std::ostream &operator<<(std::ostream &os, Expr a) { return os; } +std::ostream &operator<<(std::ostream &os, const ir::LoweredFunc &f) {} + +std::ostream &operator<<(std::ostream &os, const lang::Module &m); + } // namespace ir } // namespace cinn diff --git a/cinn/ir/ir_printer.h b/cinn/ir/ir_printer.h index 200e310f7b629..44eb68d885d18 100644 --- a/cinn/ir/ir_printer.h +++ b/cinn/ir/ir_printer.h @@ -7,6 +7,12 @@ #include "cinn/ir/ir_visitor.h" namespace cinn { + +namespace lang { +class Module; +class LoweredFunc; +} // namespace lang + namespace ir { struct IrPrinter : public IRVisitor { @@ -72,6 +78,7 @@ struct IrPrinter : public IRVisitor { void Visit(const _IterVar_ *x) override {} void Visit(const _Buffer_ *x) override; void Visit(const _Tensor_ *x) override; + void Visit(const _LoweredFunc_ *x) override; private: std::ostream &os_; @@ -80,6 +87,7 @@ struct IrPrinter : public IRVisitor { }; std::ostream &operator<<(std::ostream &os, Expr a); +std::ostream &operator<<(std::ostream &os, const lang::Module &m); } // namespace ir } // namespace cinn diff --git a/cinn/ir/ir_visitor.h b/cinn/ir/ir_visitor.h index 817aa1ddcf0fd..6cb345478db84 100644 --- a/cinn/ir/ir_visitor.h +++ b/cinn/ir/ir_visitor.h @@ -4,6 +4,7 @@ #include "cinn/ir/buffer.h" #include "cinn/ir/ir.h" +#include "cinn/ir/lowered_func.h" #include "cinn/lang/tensor.h" namespace cinn { diff --git a/cinn/ir/lowered_func.cc b/cinn/ir/lowered_func.cc new file mode 100644 index 0000000000000..bb80ea448ddae --- /dev/null +++ b/cinn/ir/lowered_func.cc @@ -0,0 +1,30 @@ +#include "cinn/ir/lowered_func.h" + +#include "cinn/common/common.h" + +namespace cinn { +namespace ir { + +const _LoweredFunc_* LoweredFunc::operator->() const { return As<_LoweredFunc_>(); } +_LoweredFunc_* LoweredFunc::operator->() { return As<_LoweredFunc_>(); } + +LoweredFunc _LoweredFunc_::Make(const std::string& name, const std::vector& args, const Expr& body) { + auto* n = make_shared<_LoweredFunc_>(); + n->name = name; + n->args = args; + n->body = body; + return LoweredFunc(n); +} + +LoweredFunc _LoweredFunc_::Make(const std::string& name, + const std::vector& args, + const std::vector& body) { + CHECK_EQ(body.size(), 1); + return Make(name, args, body.front()); +} + +std::vector _LoweredFunc_::expr_fields() { return {&body}; } +std::vector _LoweredFunc_::expr_fields() const { return {&body}; } + +} // namespace ir +} // namespace cinn diff --git a/cinn/ir/lowered_func.h b/cinn/ir/lowered_func.h new file mode 100644 index 0000000000000..845a658249b24 --- /dev/null +++ b/cinn/ir/lowered_func.h @@ -0,0 +1,70 @@ +#pragma once +#include "cinn/ir/buffer.h" +#include "cinn/ir/node.h" + +namespace cinn { +namespace ir { + +class _LoweredFunc_; + +/** + * A struct representing an argument to a lowered function. Used for specifying the function signature of generated + * code. + */ +struct Argument { + //! The name of the argument. + std::string name; + + enum class Kind { kScalar = 0, kBuffer } kind{Kind::kScalar}; + + //! Number of the dimensions of buffer. + uint32_t ndims{0}; + + //! The type of the buffer or scalar. + Type type; + + bool is_buffer() const { return kind == Kind::kBuffer; } + bool is_scalar() const { return kind == Kind::kScalar; } + + Argument() {} + Argument(const std::string& name, Kind kind, const Type& type, int ndims) + : name(name), kind(kind), type(type), ndims(ndims) {} + + explicit Argument(const ir::Buffer& buffer) : name(buffer->name), type(buffer->type()), ndims(buffer->shape.size()) {} +}; + +//! Wrapper for _LoweredFunc_ +class LoweredFunc : public IrNodeRef { + public: + LoweredFunc() = default; + explicit LoweredFunc(IrNode* n) : IrNodeRef(n) {} + + const _LoweredFunc_* operator->() const; + _LoweredFunc_* operator->(); +}; + +/** + * Definition of a lowered function. Note that, it should be functional. + */ +struct _LoweredFunc_ : ExprNode<_LoweredFunc_> { + //! The name of this function. + std::string name; + + //! The Arguments used in the body of the function. + std::vector args; + + //! Body of this function. + Expr body; + + static LoweredFunc Make(const std::string& name, const std::vector& args, const Expr& body); + + static LoweredFunc Make(const std::string& name, const std::vector& args, const std::vector& body); + + std::vector expr_fields() override; + std::vector expr_fields() const override; + + static const IrNodeTy _node_type_ = IrNodeTy::_LoweredFunc_; +}; + +} // namespace ir +} // namespace cinn diff --git a/cinn/ir/node.h b/cinn/ir/node.h index 85d48f4fba504..ee3e104e5e8b6 100644 --- a/cinn/ir/node.h +++ b/cinn/ir/node.h @@ -66,6 +66,7 @@ class IRVisitor; macro__(_IterVar_) \ macro__(_Buffer_) \ macro__(_Tensor_) \ + macro__(_LoweredFunc_) \ #define NODETY_FORALL(__m) \ NODETY_PRIMITIVE_TYPE_FOR_EACH(__m) \ diff --git a/cinn/lang/lower.cc b/cinn/lang/lower.cc index a96e8b8033a48..fa7be633bf974 100644 --- a/cinn/lang/lower.cc +++ b/cinn/lang/lower.cc @@ -45,7 +45,7 @@ Expr LowerGroup(const poly::detail::Group& group, const std::map Lower(const std::string& name, const std::vector& args) { +std::vector Lower(const std::string& name, const std::vector& args) { // make sure the graph's start-points in the args. auto stages = poly::GatherStagesInTensors(args); @@ -106,12 +106,12 @@ std::vector Lower(const std::string& name, const std::vector arguments; + std::vector arguments; for (auto& arg : args) { - arguments.emplace_back(arg->name, Argument::Kind::kBuffer, arg->type(), arg->shape.size()); + arguments.emplace_back(arg->name, ir::Argument::Kind::kBuffer, arg->type(), arg->shape.size()); } - return {LoweredFunc(name, arguments, block)}; + return {ir::_LoweredFunc_::Make(name, arguments, block)}; } } // namespace lang diff --git a/cinn/lang/lower.h b/cinn/lang/lower.h index 7c9d443f867a8..d39d8c55000e5 100644 --- a/cinn/lang/lower.h +++ b/cinn/lang/lower.h @@ -16,7 +16,7 @@ namespace cinn { namespace lang { using ir::Tensor; -std::vector Lower(const std::string& name, const std::vector& args); +std::vector Lower(const std::string& name, const std::vector& args); } // namespace lang } // namespace cinn diff --git a/cinn/lang/lower_test.cc b/cinn/lang/lower_test.cc index 4acf953d3f1cf..212ca81e54b2a 100644 --- a/cinn/lang/lower_test.cc +++ b/cinn/lang/lower_test.cc @@ -42,7 +42,7 @@ TEST(lower, basic) { } } )ROC"; - TEST_SOUTPUT(lower_funcs.front().body, out); + TEST_SOUTPUT(lower_funcs.front()->body, out); } TEST(lower, more_complex) { diff --git a/cinn/lang/module.cc b/cinn/lang/module.cc index ef06e67379308..5b6d9774fe7e8 100644 --- a/cinn/lang/module.cc +++ b/cinn/lang/module.cc @@ -10,7 +10,7 @@ struct _Module_ : Object { std::string name; Target target; std::vector buffers; - std::vector functions; + std::vector functions; std::vector submodules; const char *type_info() const override { return "_Module_"; } @@ -30,23 +30,17 @@ const std::string &Module::name() const { return self()->name; } const std::vector &Module::buffers() const { return self()->buffers; } -const std::vector &Module::functions() const { return self()->functions; } +const std::vector &Module::functions() const { return self()->functions; } const std::vector &Module::submodules() const { return self()->submodules; } -void Module::Append(const ir::Buffer &buffer) { self()->buffers.push_back(buffer); } +void Module::Append(const Buffer &buffer) { self()->buffers.push_back(buffer.buffer()); } -void Module::Append(const ir::PackedFunc &function) { self()->functions.push_back(function); } +void Module::Append(const ir::LoweredFunc &function) { self()->functions.push_back(function); } void Module::Append(const Module &module) { self()->submodules.push_back(module); } void Module::Compile(const backends::Outputs &outputs) const {} -LoweredFunc::LoweredFunc(const std::string &name, const std::vector &args, const std::vector &body) { - this->name = name; - this->args = args; - this->body = ir::Block::Make(body); -} - } // namespace lang } // namespace cinn diff --git a/cinn/lang/module.h b/cinn/lang/module.h index 81792763b4736..4f37429452d21 100644 --- a/cinn/lang/module.h +++ b/cinn/lang/module.h @@ -4,8 +4,8 @@ #include "cinn/backends/outputs.h" #include "cinn/common/common.h" -#include "cinn/ir/buffer.h" -#include "cinn/ir/function.h" +#include "cinn/ir/lowered_func.h" +#include "cinn/lang/buffer.h" namespace cinn { namespace lang { @@ -28,14 +28,14 @@ class Module { //! The members in the module. // @{ const std::vector& buffers() const; - const std::vector& functions() const; + const std::vector& functions() const; const std::vector& submodules() const; // @} //! Add something to this module. // @{ - void Append(const ir::Buffer& buffer); - void Append(const ir::PackedFunc& function); + void Append(const Buffer& buffer); + void Append(const ir::LoweredFunc& function); void Append(const Module& module); // @} @@ -49,49 +49,5 @@ class Module { Shared<_Module_> module_; }; -/** - * A struct representing an argument to a lowered function. Used for specifying the function signature of generated - * code. - */ -struct Argument { - //! The name of the argument. - std::string name; - - enum class Kind { kScalar = 0, kBuffer } kind{Kind::kScalar}; - - //! Number of the dimensions of buffer. - uint32_t ndims{0}; - - //! The type of the buffer or scalar. - Type type; - - bool is_buffer() const { return kind == Kind::kBuffer; } - bool is_scalar() const { return kind == Kind::kScalar; } - - Argument() {} - Argument(const std::string& name, Kind kind, const Type& type, int ndims) - : name(name), kind(kind), type(type), ndims(ndims) {} - - explicit Argument(const ir::Buffer& buffer) : name(buffer->name), type(buffer->type()), ndims(buffer->shape.size()) {} -}; - -/** - * Definition of a lowered function. Note that, it should be functional. - */ -struct LoweredFunc { - //! The name of this function. - std::string name; - - //! The Arguments used in the body of the function. - std::vector args; - - //! Body of this function. - Expr body; - - LoweredFunc(const std::string& name, const std::vector& args, const Expr& body) - : name(name), args(args), body(body) {} - LoweredFunc(const std::string& name, const std::vector& args, const std::vector& body); -}; - } // namespace lang } // namespace cinn diff --git a/cinn/optim/ir_copy.cc b/cinn/optim/ir_copy.cc index fca30a69e92dc..47a6603f8351f 100644 --- a/cinn/optim/ir_copy.cc +++ b/cinn/optim/ir_copy.cc @@ -114,6 +114,11 @@ struct IRCopyVisitor : public ir::IRVisitorBase { return Expr(); } + Expr Visit(const _LoweredFunc_* op) override { + LOG(FATAL) << "not implemented"; + return Expr(); + } + Expr Visit(const _IterVar_* op) override { LOG(FATAL) << "not implemented"; return Expr(); diff --git a/docs/design.md b/docs/design.org similarity index 50% rename from docs/design.md rename to docs/design.org index 5d2c85c99fc8c..d3b45391e6c40 100644 --- a/docs/design.md +++ b/docs/design.org @@ -1,129 +1,120 @@ -# Design of CINN - -## Multi-layer architecture -To enable the compiler to handle the general DNN tasks, a multi-layer system is -introduced. +#+title: Design of CINN +This document describe the design of the project CINN (Compiler Infrastructure for Neural Networks). +* Multi-layer architecture +To enable the compiler to handle the general DNN tasks, a multi-layer system is designed. The layers are as follows 1. NN compatible layer - - Add operator wrapper for DNN platform such as PaddlePaddle or TensorFlow. + - Add operator wrapper for DNN platform such as PaddlePaddle or TensorFlow. 2. Virtual graph layer - - Add virtual node and graph to utilize both the compiler and third-party computational library such as CUDNN or MKLDNN. + - Add virtual node and graph to utilize both the compiler and third-party computational library such as CUDNN or MKLDNN. 3. DSL layer - - Export a friendly domain language to programming with the underlying compiler. + - Export a friendly domain language to programming with the underlying compiler. 4. Compiler layer - - A NN compiler which can optimize affine forloop. - -## DSL layer + - A NN compiler which can optimize affine forloop. +* DSL layer **Matrix multiplication with blocking** -```c++ -Var i("i"), j("j"), k("k"); -Constant N("N"), M("M"), K("K"); +#+BEGIN_SRC C++ + Var i("i"), j("j"), k("k"); + Constant N("N"), M("M"), K("K"); -PlaceHolder x("x", {M, K}); -PlaceHolder y("x", {K, N}); + PlaceHolder x("x", {M, K}); + PlaceHolder y("x", {K, N}); -Tensor C = compute({M, N, K}/*dims*/, [&](Var i, Var j, Var k){ - return x(i,k) * y(k,j); -}); + Tensor C = compute({M, N, K} /*dims*/, [&](Var i, Var j, Var k) { return x(i, k) * y(k, j); }); -Schedule s = ComputeSchedule({C}/*outputs*/, {A,B}/*inputs*/); -{ // schedule C's computation + Schedule s = ComputeSchedule({C} /*outputs*/, {A, B} /*inputs*/); + { // schedule C's computation - // tile i, j with factor 4 - Var i0,i1,j0,j1; - std::tie(i0,i1,j0,j1) = s[C].tile(s[C].axis("i"), s[C].axis("j"), 4, 4); - + Var i0, i1, j0, j1; + std::tie(i0, i1, j0, j1) = s[C].tile(s[C].axis("i"), s[C].axis("j"), 4, 4); + // tile k with factor 4 - Var k0,k1; - std::tie(k0,k1) = s[C].split(k, 4); - - s[C].reorder(i0,j0,k0,k1,i1,j1); // swap(i1,j0) -} + Var k0, k1; + std::tie(k0, k1) = s[C].split(k, 4); + + s[C].reorder(i0, j0, k0, k1, i1, j1); // swap(i1,j0) + } -auto module = Build(s, {A, B, C}, "llvm", "host", "func"); -void* func = module.GetFuncByName("func"); -``` + auto module = Build(s, {A, B, C}, "llvm", "host", "func"); + void* func = module.GetFuncByName("func"); +#+END_SRC **Matrix with Vectorization** -```c++ -Schedule S = ComputeSchedule({C}, {A,B}); +#+BEGIN_SRC C++ + Schedule S = ComputeSchedule({C}, {A, B}); -Var k0,k1; -std::tie(k0,k1) = S[C].split(k,4); -Var x0,x1,y0,y1; -std::tie(x0, x1, y0, y1) = S[C].tile(x, y, 4, 4); + Var k0, k1; + std::tie(k0, k1) = S[C].split(k, 4); + Var x0, x1, y0, y1; + std::tie(x0, x1, y0, y1) = S[C].tile(x, y, 4, 4); -S[C].reorder(x0, y0, k0, k1, x1, y1); + S[C].reorder(x0, y0, k0, k1, x1, y1); -S[C].vectorize(y1); -``` + S[C].vectorize(y1); +#+END_SRC **Matrix with Packing** -```c++ -Tensor packedB = compute((N/bn, K, bn), [&](Var i, Var j, Var k) { - return B(j, i*bn+k); -}); +#+BEGIN_SRC C++ + Tensor packedB = compute((N / bn, K, bn), [&](Var i, Var j, Var k) { return B(j, i * bn + k); }); -Tensor C = compute({M, N}, [&](Var i, Var j, Var k) { + Tensor C = compute({M, N}, [&](Var i, Var j, Var k) { // reduce sum(need initialize) - return sum(A(i,k) * packedB(y/bn, k, y%bn), k); -}); + return sum(A(i, k) * packedB(y / bn, k, y % bn), k); + }); -Schedule S = compute_schedule({C}, {A,B}); + Schedule S = compute_schedule({C}, {A, B}); -Var i0,j0,i1,j1; -Var k0,k1; -std::tie(i0,i1,j0,j1) = S[C].tile(S[C].axis(0), S.axis(1), 4, 4); -std::tie(k0,k1) = S[C].split(S[C].axis(k, 4); + Var i0, j0, i1, j1; + Var k0, k1; + std::tie(i0, i1, j0, j1) = S[C].tile(S[C].axis(0), S.axis(1), 4, 4); + std::tie(k0, k1) = S[C].split(S[C].axis(k, 4)); -S[C].reorder(i0, j0, k0, i1, k1, j1); -S[C].vectorize(j1); + S[C].reorder(i0, j0, k0, i1, k1, j1); + S[C].vectorize(j1); -{ + { Var i, j, k; - std::tie(i,j,k) = S[packedB].axis(); + std::tie(i, j, k) = S[packedB].axis(); S[packedB].vectorize(k); S[packedB].parallel(i); -} -``` + } +#+END_SRC -## Compiler Layer +* Compiler Layer +** IR -### IR The IR is similar to Halide. +*** Basic elements -#### Basic elements The IR has following basic elements: - Expr, the expression in the IR(which represents a value or returns a value). - Stmt, the statement in the IR. - Tensor (the input or temporary value) - Buffer (the memory buffer) +** Tensor -### Tensor Tensor represents the input or temporary variable. Each tensor is assigned a buffer by default, but `store_in` can change the relation. +** Polyhedral usage - -### Polyhedral usage The polyhedral technology is used to simplify the forloop analysis and transform. +** schedule -### schedule The original tensor-based computation forms a SSA graph. Each tensor is assign a `Stage`, which is the basic schedule element. A stage has a domain(isl.Set) and a schedule(isl.Map), all the schedule is performed on them. - -#### Schedule the stages +*** Schedule the stages We use the ideas from Tiramisu project, and walk through the dependency graph, split the graph into several groups. @@ -135,24 +126,24 @@ There are several rules to split the graph, the naive one is - if two statement is marked by `compute_at`, merge to the same group too. - this period is like a union find. - for each group, use a different `ast_build` to generate ISL IR(so that we can set iterators separately) +*** Scheduler module - -#### Scheduler module The Scheduler take the stages as input, and do the previous mentioned graph partition, and finally output several schedule elements. Each schedule element owns an (ISL)iteration domain and a (ISL)schedule, and one can pass it to a ast_gen and generate code. - -### Lower output Tensors to LoweredFuncs +*** Lower output Tensors to LoweredFuncs First, given the output tensors, the `Lower` function will collect all the depended inputs, and lower them to a function. The lower interface is -```c++ -std::vector Lower(vector& args, DeviceAPI device); -``` -### Buffer +#+BEGIN_SRC C++ + std::vector Lower(vector& args, DeviceAPI device); +#+END_SRC + +** Buffer + Buffer represents the actual memory in host or devices. The `Buffer` node in IR represents a buffer, it can be used by binding to a Tensor. @@ -161,22 +152,28 @@ The Tensor will be noninlined only if it binds to some buffer. NOTE A buffer can be reused in multiple tensors(TODO the write-read correctness should be considered). -```c++ -Buffer buffer0; -Tensor x = Compute(...); -// x will write the result to buffer0 -x->Bind(buffer0); +#+BEGIN_SRC C++ + Buffer buffer0; -Tensor y = Compute(..., [](Var i) { - return x(i) * 2; // here it will read the buffer instead, x is just a alias. -}); -``` + Tensor x = Compute(...); + // x will write the result to buffer0 + x->Bind(buffer0); + + Tensor y = Compute(..., [](Var i) { + return x(i) * 2; // here it will read the buffer instead, x is just a alias. + }); +#+END_SRC The size of the buffer will be inferenced from the shape and data type of tensor. It by default can be resized to proper shape by binding to multiple tensors. +*** Buffer in CodeGen -#### Buffer in CodeGen All the buffers will be maintained in global scope, and alloc or dealloc in local scopes. The benefit is buffer is easy to shared accross multiple statements. +** Module +Module is the container of LoweredFuncs and Buffers. +There might be more than one module in an generated execution. + +The Module can compile to a backends.