From 1aae5df9334d207072c99e911084d105b27d9a2b Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Sat, 17 Sep 2022 14:18:31 -0700 Subject: [PATCH] [TVMScript] IRBuilder methods for `Stmt` This PR introduces IRBuilder methods for `Assert`, `Let`, `Realize`, `Evaluate`, `LaunchThread`, `EnvThread`. Co-authored-by: yongwww --- include/tvm/script/ir_builder/tir/frame.h | 132 ++++++++++++++++++ include/tvm/script/ir_builder/tir/ir.h | 40 ++++++ python/tvm/script/ir_builder/tir/frame.py | 20 +++ python/tvm/script/ir_builder/tir/ir.py | 131 +++++++++++++++++ src/script/ir_builder/tir/frame.cc | 27 ++++ src/script/ir_builder/tir/ir.cc | 67 +++++++++ .../unittest/test_tvmscript_ir_builder_tir.py | 69 +++++++++ 7 files changed, 486 insertions(+) diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index c76b400d96b4..38fe9009dd61 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -303,6 +303,138 @@ class AssertFrameNode : public TIRFrameNode { void ExitWithScope() final; }; +/*! + * \brief Managed reference to AssertFrameNode. + * + * \sa AssertFrameNode + */ +class AssertFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AssertFrame, TIRFrame, AssertFrameNode); +}; + +/*! + * \brief A frame represents the let binding expression, which binds a var. + * + * \sa LetFrameNode + */ +class LetFrameNode : public TIRFrameNode { + public: + /*! \brief The variable we bind to */ + tvm::tir::Var var; + /*! \brief The value we bind var to */ + PrimExpr value; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("var", &var); + v->Visit("value", &value); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.LetFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(LetFrameNode, TIRFrameNode); + + public: + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to LetFrameNode. + * + * \sa LetFrameNode + */ +class LetFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LetFrame, TIRFrame, LetFrameNode); +}; + +/*! + * \brief The LaunchThreadFrameNode. + * \note It is used only inside a PrimFunc. + */ +class LaunchThreadFrameNode : public TIRFrameNode { + public: + /*! \brief The extent of environment thread. */ + PrimExpr extent; + /*! \brief The attribute key, could be either virtual_thread or thread_extent. */ + String attr_key; + /*! \brief The iteration variable. */ + tvm::tir::IterVar iter_var; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("extent", &extent); + v->Visit("attr_key", &attr_key); + v->Visit("iter_var", &iter_var); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.LaunchThreadFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(LaunchThreadFrameNode, TIRFrameNode); + + public: + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to LaunchThreadFrameNode. + * + * \sa LaunchThreadFrameNode + */ +class LaunchThreadFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LaunchThreadFrame, TIRFrame, + LaunchThreadFrameNode); +}; + +/*! + * \brief A frame that represents realization. + * + * \sa RealizeFrame + */ +class RealizeFrameNode : public TIRFrameNode { + public: + /*! \brief The region of buffer access. */ + tvm::tir::BufferRegion buffer_slice; + /*! \brief The storage scope associated with this realization. */ + String storage_scope; + /*! \brief The condition expression. */ + PrimExpr condition; + + void VisitAttrs(tvm::AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("buffer_slice", &buffer_slice); + v->Visit("storage_scope", &storage_scope); + v->Visit("condition", &condition); + } + + static constexpr const char* _type_key = "script.ir_builder.tir.RealizeFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(RealizeFrameNode, TIRFrameNode); + + public: + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to RealizeFrameNode. + * + * \sa RealizeFrameNode + */ +class RealizeFrame : public TIRFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode); +}; } // namespace tir } // namespace ir_builder } // namespace script diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 191887648dbd..ec1f7f3753d1 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -292,6 +292,46 @@ ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread, */ ForFrame Grid(Array extents); +/*! + * \brief The assertion statement. + * \param condition The assertion condition. + * \param message The error message when the assertion fails. + * \return The AssertFrame. + */ +AssertFrame Assert(PrimExpr condition, String message); + +/*! + * \brief The let binding. + * \param var The variable to bind. + * \param value The value to be bound. + * \return The created LetFrame. + */ +LetFrame Let(Var var, PrimExpr value); + +/*! + * \brief The realization. + * \param buffer_slice The region of buffer access. + * \param storage_scope The storage scope associated with this realization. + * \param condition The condition expression. + * \return The result RealizeFrame. + */ +RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition); + +/*! + * \brief Launch a thread. + * \param var The iteration variable. + * \param extent The extent of environment thread. + * \return The result LaunchThreadFrame. + */ +LaunchThreadFrame LaunchThread(Var var, PrimExpr extent); + +/*! + * \brief Bind a var to thread env. + * \param thread_tag The thread type tag. + * \return The result variable which gets bound to the thread env. + */ +Var EnvThread(String thread_tag); + /*! * \brief Evaluate the input expression. * \param value The input expression to evaluate. diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py index 2ad08f35160d..69bc5bfc9676 100644 --- a/python/tvm/script/ir_builder/tir/frame.py +++ b/python/tvm/script/ir_builder/tir/frame.py @@ -48,3 +48,23 @@ class ForFrame(TIRFrame): def __enter__(self) -> Union[Var, List[Var]]: # type: ignore[override] super().__enter__() return self.vars if len(self.vars) > 1 else self.vars[0] + + +@_register_object("script.ir_builder.tir.AssertFrame") +class AssertFrame(TIRFrame): + ... + + +@_register_object("script.ir_builder.tir.LetFrame") +class LetFrame(TIRFrame): + ... + + +@_register_object("script.ir_builder.tir.RealizeFrame") +class RealizeFrame(TIRFrame): + ... + + +@_register_object("script.ir_builder.tir.LaunchThreadFrame") +class LaunchThreadFrame(TIRFrame): + ... diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index d1dc1c89600d..6db8f40c32c8 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -26,6 +26,8 @@ BufferLoad, BufferRegion, IntImm, + IterVar, + Let, PrimExpr, StringImm, Var, @@ -813,6 +815,130 @@ def grid(*extents: PrimExpr) -> frame.ForFrame: return _ffi_api.Grid(extents) # type: ignore[attr-defined] # pylint: disable=no-member +def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame: # pylint: disable=invalid-name + """Create an assertion statement. + + Parameters + ---------- + condition : PrimExpr + The PrimExpr to test. + + message : str + The output error message when the assertion fails. + + Returns + ------- + res : frame.AssertFrame + The result AssertFrame. + """ + return _ffi_api.Assert(condition, message) # type: ignore[attr-defined] # pylint: disable=no-member + + +def let( + v: Var, + value: PrimExpr, + body: PrimExpr = None, +) -> frame.LetFrame: + """Create a new let binding. + + Parameters + ---------- + v : Var + The variable to bind. + + value : PrimExpr + The value to be bound. + + body : PrimExpr + The body expression, None will be used if it was not specified. + + Returns + ------- + res : frame.LetFrame + The result LetFrame. + """ + if body is None: + return _ffi_api.Let(v, value) # type: ignore[attr-defined] # pylint: disable=no-member + return Let(v, value, body) + + +def realize( + buffer_slice: BufferRegion, + storage_scope: str, + condition: PrimExpr = True, +) -> frame.RealizeFrame: + """Create a realization. + + Parameters + ---------- + buffer_slice : BufferRegion + The region of buffer access. + + storage_scope : str + The storage scope associated with this realization. + + condition: PrimExpr + The condition expression, the default is True. + + Returns + ------- + res : frame.RealizeFrame + The result RealizeFrame. + """ + return _ffi_api.Realize( # type: ignore[attr-defined] # pylint: disable=no-member + buffer_slice, storage_scope, condition + ) + + +def launch_thread( + iter_var: IterVar, # pylint: disable=redefined-outer-name + extent: PrimExpr, +) -> frame.LaunchThreadFrame: + """Launch a thread. + + Parameters + ---------- + iter_var : IterVar + The iteration variable. + + extent : PrimExpr + The extent of environment thread. + + Returns + ------- + res : frame.LaunchThreadFrame + The result LaunchThreadFrame. + + Examples + -------- + + .. code-block:: python + + from tvm.script.ir_builder import tir as T + brow = T.env_thread("blockIdx.y") + T.launch_thread(brow, 1) + + """ + return _ffi_api.LaunchThread(iter_var, extent) # type: ignore[attr-defined] # pylint: disable=no-member + + +def env_thread(thread_tag: str) -> IterVar: + """Bind a var to thread env" + + Parameters + ---------- + thread_tag : str + The thread type tag. + + Returns + ------- + res : IterVar + The result iteration variable gets bound to the thread env. + + """ + return _ffi_api.EnvThread(thread_tag) # type: ignore[attr-defined] # pylint: disable=no-member + + def evaluate(value: PrimExpr) -> None: """Evaluate the input expression. @@ -1159,6 +1285,11 @@ def var(dtype, name="") -> Var: "unroll", "thread_binding", "grid", + "Assert", + "let", + "realize", + "launch_thread", + "env_thread", "evaluate", "int8", "int16", diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index 8b8b2a4d80e0..6c9459e6389c 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -92,11 +92,38 @@ void ForFrameNode::ExitWithScope() { AddToParent(this->f_make_for_loop(vars, doms, AsStmt(stmts))); } +void AssertFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(tvm::tir::AssertStmt(condition, message, AsStmt(stmts))); +} + +void LetFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(tvm::tir::LetStmt(var, value, AsStmt(stmts))); +} + +void RealizeFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(tvm::tir::AttrStmt(buffer_slice->buffer, "realize_scope", + tvm::tir::StringImm(storage_scope), + tvm::tir::BufferRealize(buffer_slice->buffer, buffer_slice->region, + condition, AsStmt(stmts)))); +} + +void LaunchThreadFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(tvm::tir::AttrStmt(iter_var, attr_key, extent, AsStmt(stmts))); +} + TVM_REGISTER_NODE_TYPE(TIRFrameNode); TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode); TVM_REGISTER_NODE_TYPE(BlockFrameNode); TVM_REGISTER_NODE_TYPE(BlockInitFrameNode); TVM_REGISTER_NODE_TYPE(ForFrameNode); +TVM_REGISTER_NODE_TYPE(AssertFrameNode); +TVM_REGISTER_NODE_TYPE(LetFrameNode); +TVM_REGISTER_NODE_TYPE(RealizeFrameNode); +TVM_REGISTER_NODE_TYPE(LaunchThreadFrameNode); } // namespace tir } // namespace ir_builder diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 75e759262655..5951af298f62 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -395,6 +395,67 @@ ForFrame Grid(Array extents) { return ForFrame(n); } +AssertFrame Assert(PrimExpr condition, String message) { + ObjectPtr n = make_object(); + n->condition = condition; + n->message = tvm::tir::StringImm(message); + return AssertFrame(n); +} + +LetFrame Let(Var var, PrimExpr value) { + ObjectPtr n = make_object(); + n->var = var; + n->value = value; + return LetFrame(n); +} + +LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) { + IterVar iter_var{nullptr}; + + if (Optional opt_frame = IRBuilder::Current()->FindFrame()) { + if (Optional opt_iter_var = opt_frame.value()->env_threads.Get(var)) { + iter_var = opt_iter_var.value(); + } else { + LOG(FATAL) << "ValueError: " << var->name_hint + << " is not an env_thread created using T.env_thread."; + } + } else { + LOG(FATAL) << "LaunchThread can only be used inside a PrimFunc"; + } + ObjectPtr n = make_object(); + if (!iter_var->dom.defined()) { + const_cast(iter_var.get())->dom = Range(0, extent); + } else if (!arith::Analyzer().CanProveEqual(iter_var->dom->extent, extent)) { + LOG(FATAL) << "ValueError: Inconsistent extents of environment thread. " + << iter_var->dom->extent << " vs " << extent; + } + n->iter_var = iter_var; + n->extent = extent; + n->attr_key = iter_var->thread_tag == "vthread" ? "virtual_thread" : "thread_extent"; + return LaunchThreadFrame(n); +} + +RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, + PrimExpr condition) { + ObjectPtr n = make_object(); + n->buffer_slice = buffer_slice; + n->storage_scope = storage_scope; + n->condition = condition; + return RealizeFrame(n); +} + +Var EnvThread(String thread_tag) { + IterVar iter_var(Range{nullptr}, Var("", DataType::Int(32)), tvm::tir::IterVarType::kThreadIndex, + thread_tag); + Var var = iter_var->var; + if (Optional opt_frame = IRBuilder::Current()->FindFrame()) { + opt_frame.value()->env_threads.Set(var, iter_var); + } else { + LOG(FATAL) << "EnvThread can only be used inside a PrimFunc"; + } + return var; +} + void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); } using tvm::script::ir_builder::details::Namer; @@ -477,6 +538,12 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.Unroll").set_body_typed(Unroll); TVM_REGISTER_GLOBAL("script.ir_builder.tir.ThreadBinding").set_body_typed(ThreadBinding); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Grid").set_body_typed(Grid); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Assert").set_body_typed(Assert); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Let").set_body_typed(Let); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.Realize").set_body_typed(Realize); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread").set_body_typed(LaunchThread); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread").set_body_typed(EnvThread); + TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int8").set_body_typed(Int8); diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py index a5d8c1068064..7f2e6e1a4706 100644 --- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py +++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py @@ -260,5 +260,74 @@ def test_ir_builder_tir_for(): assert_structural_equal(for_actual, for_expected, map_free_vars=True) +def test_ir_builder_tir_assert(): + with IRBuilder() as ib: + with T.Assert(T.var("int32", name="a") == 0, message="a is 0"): + T.evaluate(0) + # the assert generated by IRBuilder + assert_actual = ib.get() + + # the expected assert statement + assert_expected = tir.AssertStmt( + T.var("int32", name="a") == 0, tir.StringImm("a is 0"), tir.Evaluate(0) + ) + # Check if the generated ir is expected + assert_structural_equal(assert_actual, assert_expected, map_free_vars=True) + + +def test_ir_builder_tir_evaluate(): + with IRBuilder() as ib: + T.evaluate(0) + # the evaluate generated by IRBuilder + eval_actual = ib.get() + + # the expected evaluate + eval_expected = tir.Evaluate(0) + # Check if the generated ir is expected + assert_structural_equal(eval_actual, eval_expected, map_free_vars=True) + + +def test_ir_builder_tir_let(): + with IRBuilder() as ib: + with T.let(T.var("int32", name="a"), tir.IntImm("int32", 2)): + T.evaluate(0) + # the let binding generated by IRBuilder + let_actual = ib.get() + + # the expected Let statement + let_expected = tir.LetStmt(T.var("int32", name="a"), tir.IntImm("int32", 2), tir.Evaluate(0)) + assert_structural_equal(let_actual, let_expected, map_free_vars=True) + + +def test_ir_builder_tir_realize(): + buffer_a = T.buffer_decl((128, 128), "float32") + with IRBuilder() as ib: + with T.realize(buffer_a[0:128, 0:128], "test_storage_scope", True): + T.evaluate(0) + realize_actual = ib.get() + + # the expected buffer realization + buffer_realize = tir.BufferRealize( + buffer_a, [tvm.ir.Range(0, 128), tvm.ir.Range(0, 128)], True, tir.Evaluate(0) + ) + expected_realize = tir.AttrStmt( + buffer_a, "realize_scope", tir.StringImm("test_storage_scope"), buffer_realize + ) + assert_structural_equal(realize_actual, expected_realize, map_free_vars=True) + + +def test_ir_builder_tir_thread(): + with IRBuilder() as ib: + with T.prim_func(): + brow = T.env_thread("blockIdx.y") + with T.launch_thread(brow, 1): + T.evaluate(0) + ir_actual = ib.get() + iter_var = tir.IterVar((0, 1), "v", iter_type=1, thread_tag="blockIdx.y") + attr_stmt = tir.AttrStmt(iter_var, "thread_extent", 1, tir.Evaluate(0)) + func = tir.PrimFunc([], attr_stmt) + assert_structural_equal(ir_actual, func, map_free_vars=True) + + if __name__ == "__main__": tvm.testing.main()