From f0a877e052b136fa868c99fcca9a7f69506be9cd Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Fri, 9 Sep 2022 14:16:47 -0700 Subject: [PATCH] Base IRBuilder methods for `Block` This PR introduces base IRBuilder methods for `Block`. Co-authored-by: yongwww --- include/tvm/script/ir_builder/tir/frame.h | 70 +++++++++++++++++++ include/tvm/script/ir_builder/tir/ir.h | 8 +++ python/tvm/script/ir_builder/tir/frame.py | 5 ++ python/tvm/script/ir_builder/tir/ir.py | 20 ++++++ src/script/ir_builder/tir/frame.cc | 24 +++++++ src/script/ir_builder/tir/ir.cc | 17 +++++ src/script/ir_builder/tir/utils.h | 9 +++ .../unittest/test_tvmscript_ir_builder_tir.py | 27 +++++++ 8 files changed, 180 insertions(+) diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index 4bfd022af27a..15ab77863e5e 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -117,6 +117,76 @@ class PrimFuncFrame : public TIRFrame { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode); }; +/*! + * \brief A frame that represents the block. + * + * \sa BlockFrame + */ +class BlockFrameNode : public TIRFrameNode { + public: + /*! \brief The name of the block. */ + String name; + /*! \brief The variables of the block. */ + Array iter_vars; + /*! \brief The read buffer regions of the block. */ + Optional> reads; + /*! \brief The write buffer regions of the block. */ + Optional> writes; + /*! \brief The init statement of the bolck. */ + Optional init; + /*! \brief The buffer allocated in the block. */ + Array alloc_buffers; + /*! \brief The match buffer regions. */ + Array match_buffers; + /*! \brief The annotation of the block. */ + Optional> annotations; + /*! \brief The corresponding values of the iter vars. */ + Array iter_values; + /*! + * \brief The predicate of the block realization, the block will only be executed when the + * predicate is true. + */ + Optional predicate; + /*! \brief The flag whether to construct BlockRealize or Block. */ + bool no_realize; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("name", &name); + v->Visit("iter_vars", &iter_vars); + v->Visit("reads", &reads); + v->Visit("writes", &writes); + v->Visit("init", &init); + v->Visit("alloc_buffers", &alloc_buffers); + v->Visit("match_buffers", &match_buffers); + v->Visit("annotations", &annotations); + v->Visit("iter_values", &iter_values); + v->Visit("predicate", &predicate); + v->Visit("no_realize", &no_realize); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.BlockFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, TIRFrameNode); + + public: + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to BlockFrameNode. + * + * \sa BlockFrameNode + */ + +class BlockFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode); +}; + /*! * \brief A frame that represents the assert statement. Proceeds if the condition is true, * otherwise aborts with the message. diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index cee60ad4f827..615ce90383dd 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -34,6 +34,14 @@ namespace tir { */ PrimFuncFrame PrimFunc(); +/*! + * \brief The block declaration statement. + * \param name The name of the block. + * \param no_realize The flag whether to construct BlockRealize or Block. + * \return The BlockFrame. + */ +BlockFrame Block(String name, bool no_realize = false); + /*! * \brief Evaluate the input expression. * \param value The input expression to evaluate. diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py index 61418e0b2aa6..0e7eb2bb4720 100644 --- a/python/tvm/script/ir_builder/tir/frame.py +++ b/python/tvm/script/ir_builder/tir/frame.py @@ -29,3 +29,8 @@ class TIRFrame(IRBuilderFrame): @_register_object("script.ir_builder.tir.PrimFuncFrame") class PrimFuncFrame(TIRFrame): ... + + +@_register_object("script.ir_builder.tir.BlockFrame") +class BlockFrame(TIRFrame): + ... diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index ae5d5b260f65..7ba2f6df9418 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -33,6 +33,25 @@ def prim_func() -> frame.PrimFuncFrame: return _ffi_api.PrimFunc() # pylint: disable=no-member # type: ignore +def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame: + """The block declaration statement. + + Parameters + ---------- + name : str + The name of the block. + + no_realize : bool + The flag whether to construct BlockRealize or Block. + + Returns + ------- + res : frame.BlockFrame + The BlockFrame. + """ + return _ffi_api.Block(name, no_realize) # pylint: disable=no-member # type: ignore + + def evaluate(value: PrimExpr) -> None: """Evaluate the input expression. @@ -50,6 +69,7 @@ def evaluate(value: PrimExpr) -> None: __all__ = [ + "block", "evaluate", "prim_func", ] diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index 139c8193b0ba..dd3097e388b7 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -50,8 +50,32 @@ void PrimFuncFrameNode::ExitWithScope() { } } +void BlockFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + Array tir_alloc_buffers; + for (const tvm::tir::Buffer& buffer : alloc_buffers) { + tir_alloc_buffers.push_back(buffer); + } + Map attrs = annotations.value_or({}); + if (int detect_access = (!reads.defined()) | (!writes.defined() << 1)) { + attrs.Set("tir.script_parsing_detect_access", tvm::IntImm(DataType::Int(64), detect_access)); + } + tvm::tir::Block block(iter_vars, reads.value_or(Array()), + writes.value_or(Array()), name, AsStmt(stmts), init, + tir_alloc_buffers, match_buffers, attrs); + if (no_realize) { + CHECK(iter_values.empty()) + << "ValueError: Block bindings are not allowed when `no_realize=True`"; + CHECK(!predicate.defined()) << "ValueError: `T.where` is not allowed when `no_realize=True`"; + AddToParent(block); + } else { + AddToParent(tvm::tir::BlockRealize(iter_values, predicate.value_or(Bool(true)), block)); + } +} + TVM_REGISTER_NODE_TYPE(TIRFrameNode); TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode); +TVM_REGISTER_NODE_TYPE(BlockFrameNode); } // namespace tir } // namespace ir_builder diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 5f994d71ca0a..4c2679ae6b56 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -41,8 +41,25 @@ PrimFuncFrame PrimFunc() { return PrimFuncFrame(n); } +BlockFrame Block(String name, bool no_realize) { + ObjectPtr n = make_object(); + n->name = name; + n->iter_vars.clear(); + n->reads = NullOpt; + n->writes = NullOpt; + n->init = NullOpt; + n->alloc_buffers.clear(); + n->match_buffers.clear(); + n->annotations = NullOpt; + n->iter_values.clear(); + n->predicate = NullOpt; + n->no_realize = no_realize; + return BlockFrame(n); +} + void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); } TVM_REGISTER_GLOBAL("script.ir_builder.tir.PrimFunc").set_body_typed(PrimFunc); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Block").set_body_typed(Block); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate); } // namespace tir } // namespace ir_builder diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h index 47557917cca5..4f8b3f77c6e1 100644 --- a/src/script/ir_builder/tir/utils.h +++ b/src/script/ir_builder/tir/utils.h @@ -60,6 +60,15 @@ inline PrimFuncFrame FindPrimFuncFrame(const String& method) { throw; } +inline BlockFrame FindBlockFrame(const String& method) { + if (Optional frame = IRBuilder::Current()->GetLastFrame()) { + return frame.value(); + } + LOG(FATAL) << "ValueError: Block frame not find. Please ensure '" << method + << "' is called under T.block()"; + throw; +} + } // namespace tir } // namespace ir_builder } // namespace script diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py index 70a8f3565d03..85080c7c65fc 100644 --- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py +++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py @@ -45,5 +45,32 @@ def test_ir_builder_tir_primfunc(): assert_structural_equal(prim_func_actual, prim_func_expected, map_free_vars=True) +def test_ir_builder_tir_block(): + with IRBuilder() as ib: + with T.block("block"): + T.evaluate(0) + # the block generated by IRBuilder + block_realize_actual = ib.get() + + # the expected block + block_expected = tir.Block( + iter_vars=[], + reads=[], + writes=[], + name_hint="block", + body=tir.Evaluate(0), + alloc_buffers=None, + match_buffers=None, + annotations={"tir.script_parsing_detect_access": tir.IntImm("int64", 3)}, + ) + block_realize_expected = tir.BlockRealize( + iter_values=[], + predicate=True, + block=block_expected, + ) + # Check if the generated ir is expected + assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True) + + if __name__ == "__main__": tvm.testing.main()