Skip to content

Commit

Permalink
Base IRBuilder methods for Block
Browse files Browse the repository at this point in the history
This PR introduces base IRBuilder methods for `Block`.

Co-authored-by: yongwww <yongcale@gmail.com>
  • Loading branch information
cyx-6 and yongwww committed Sep 10, 2022
1 parent 2eed663 commit f0a877e
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 0 deletions.
70 changes: 70 additions & 0 deletions include/tvm/script/ir_builder/tir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,76 @@ class PrimFuncFrame : public TIRFrame {
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode);
};

/*!
* \brief A frame that represents the block.
*
* \sa BlockFrame
*/
class BlockFrameNode : public TIRFrameNode {
public:
/*! \brief The name of the block. */
String name;
/*! \brief The variables of the block. */
Array<tvm::tir::IterVar> iter_vars;
/*! \brief The read buffer regions of the block. */
Optional<Array<tvm::tir::BufferRegion>> reads;
/*! \brief The write buffer regions of the block. */
Optional<Array<tvm::tir::BufferRegion>> writes;
/*! \brief The init statement of the bolck. */
Optional<tvm::tir::Stmt> init;
/*! \brief The buffer allocated in the block. */
Array<tvm::tir::Buffer> alloc_buffers;
/*! \brief The match buffer regions. */
Array<tvm::tir::MatchBufferRegion> match_buffers;
/*! \brief The annotation of the block. */
Optional<Map<String, ObjectRef>> annotations;
/*! \brief The corresponding values of the iter vars. */
Array<PrimExpr> iter_values;
/*!
* \brief The predicate of the block realization, the block will only be executed when the
* predicate is true.
*/
Optional<PrimExpr> predicate;
/*! \brief The flag whether to construct BlockRealize or Block. */
bool no_realize;

void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("name", &name);
v->Visit("iter_vars", &iter_vars);
v->Visit("reads", &reads);
v->Visit("writes", &writes);
v->Visit("init", &init);
v->Visit("alloc_buffers", &alloc_buffers);
v->Visit("match_buffers", &match_buffers);
v->Visit("annotations", &annotations);
v->Visit("iter_values", &iter_values);
v->Visit("predicate", &predicate);
v->Visit("no_realize", &no_realize);
}

static constexpr const char* _type_key = "script.ir_builder.tir.BlockFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to BlockFrameNode.
*
* \sa BlockFrameNode
*/

class BlockFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode);
};

/*!
* \brief A frame that represents the assert statement. Proceeds if the condition is true,
* otherwise aborts with the message.
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ namespace tir {
*/
PrimFuncFrame PrimFunc();

/*!
* \brief The block declaration statement.
* \param name The name of the block.
* \param no_realize The flag whether to construct BlockRealize or Block.
* \return The BlockFrame.
*/
BlockFrame Block(String name, bool no_realize = false);

/*!
* \brief Evaluate the input expression.
* \param value The input expression to evaluate.
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/script/ir_builder/tir/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,8 @@ class TIRFrame(IRBuilderFrame):
@_register_object("script.ir_builder.tir.PrimFuncFrame")
class PrimFuncFrame(TIRFrame):
...


@_register_object("script.ir_builder.tir.BlockFrame")
class BlockFrame(TIRFrame):
...
20 changes: 20 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,25 @@ def prim_func() -> frame.PrimFuncFrame:
return _ffi_api.PrimFunc() # pylint: disable=no-member # type: ignore


def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame:
"""The block declaration statement.
Parameters
----------
name : str
The name of the block.
no_realize : bool
The flag whether to construct BlockRealize or Block.
Returns
-------
res : frame.BlockFrame
The BlockFrame.
"""
return _ffi_api.Block(name, no_realize) # pylint: disable=no-member # type: ignore


def evaluate(value: PrimExpr) -> None:
"""Evaluate the input expression.
Expand All @@ -50,6 +69,7 @@ def evaluate(value: PrimExpr) -> None:


__all__ = [
"block",
"evaluate",
"prim_func",
]
24 changes: 24 additions & 0 deletions src/script/ir_builder/tir/frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,32 @@ void PrimFuncFrameNode::ExitWithScope() {
}
}

void BlockFrameNode::ExitWithScope() {
TIRFrameNode::ExitWithScope();
Array<tvm::tir::Buffer> tir_alloc_buffers;
for (const tvm::tir::Buffer& buffer : alloc_buffers) {
tir_alloc_buffers.push_back(buffer);
}
Map<String, ObjectRef> attrs = annotations.value_or({});
if (int detect_access = (!reads.defined()) | (!writes.defined() << 1)) {
attrs.Set("tir.script_parsing_detect_access", tvm::IntImm(DataType::Int(64), detect_access));
}
tvm::tir::Block block(iter_vars, reads.value_or(Array<tvm::tir::BufferRegion>()),
writes.value_or(Array<tvm::tir::BufferRegion>()), name, AsStmt(stmts), init,
tir_alloc_buffers, match_buffers, attrs);
if (no_realize) {
CHECK(iter_values.empty())
<< "ValueError: Block bindings are not allowed when `no_realize=True`";
CHECK(!predicate.defined()) << "ValueError: `T.where` is not allowed when `no_realize=True`";
AddToParent(block);
} else {
AddToParent(tvm::tir::BlockRealize(iter_values, predicate.value_or(Bool(true)), block));
}
}

TVM_REGISTER_NODE_TYPE(TIRFrameNode);
TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode);
TVM_REGISTER_NODE_TYPE(BlockFrameNode);

} // namespace tir
} // namespace ir_builder
Expand Down
17 changes: 17 additions & 0 deletions src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,25 @@ PrimFuncFrame PrimFunc() {
return PrimFuncFrame(n);
}

BlockFrame Block(String name, bool no_realize) {
ObjectPtr<BlockFrameNode> n = make_object<BlockFrameNode>();
n->name = name;
n->iter_vars.clear();
n->reads = NullOpt;
n->writes = NullOpt;
n->init = NullOpt;
n->alloc_buffers.clear();
n->match_buffers.clear();
n->annotations = NullOpt;
n->iter_values.clear();
n->predicate = NullOpt;
n->no_realize = no_realize;
return BlockFrame(n);
}

void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); }
TVM_REGISTER_GLOBAL("script.ir_builder.tir.PrimFunc").set_body_typed(PrimFunc);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Block").set_body_typed(Block);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate);
} // namespace tir
} // namespace ir_builder
Expand Down
9 changes: 9 additions & 0 deletions src/script/ir_builder/tir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ inline PrimFuncFrame FindPrimFuncFrame(const String& method) {
throw;
}

inline BlockFrame FindBlockFrame(const String& method) {
if (Optional<BlockFrame> frame = IRBuilder::Current()->GetLastFrame<BlockFrame>()) {
return frame.value();
}
LOG(FATAL) << "ValueError: Block frame not find. Please ensure '" << method
<< "' is called under T.block()";
throw;
}

} // namespace tir
} // namespace ir_builder
} // namespace script
Expand Down
27 changes: 27 additions & 0 deletions tests/python/unittest/test_tvmscript_ir_builder_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,32 @@ def test_ir_builder_tir_primfunc():
assert_structural_equal(prim_func_actual, prim_func_expected, map_free_vars=True)


def test_ir_builder_tir_block():
with IRBuilder() as ib:
with T.block("block"):
T.evaluate(0)
# the block generated by IRBuilder
block_realize_actual = ib.get()

# the expected block
block_expected = tir.Block(
iter_vars=[],
reads=[],
writes=[],
name_hint="block",
body=tir.Evaluate(0),
alloc_buffers=None,
match_buffers=None,
annotations={"tir.script_parsing_detect_access": tir.IntImm("int64", 3)},
)
block_realize_expected = tir.BlockRealize(
iter_values=[],
predicate=True,
block=block_expected,
)
# Check if the generated ir is expected
assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit f0a877e

Please sign in to comment.