Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] Base IRBuilder methods for Block #12748

Merged
merged 1 commit into from
Sep 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions include/tvm/script/ir_builder/tir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,76 @@ class PrimFuncFrame : public TIRFrame {
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode);
};

/*!
* \brief A frame that represents the block.
*
* \sa BlockFrame
*/
class BlockFrameNode : public TIRFrameNode {
public:
/*! \brief The name of the block. */
String name;
/*! \brief The variables of the block. */
Array<tvm::tir::IterVar> iter_vars;
/*! \brief The read buffer regions of the block. */
Optional<Array<tvm::tir::BufferRegion>> reads;
/*! \brief The write buffer regions of the block. */
Optional<Array<tvm::tir::BufferRegion>> writes;
/*! \brief The init statement of the bolck. */
Optional<tvm::tir::Stmt> init;
/*! \brief The buffer allocated in the block. */
Array<tvm::tir::Buffer> alloc_buffers;
/*! \brief The match buffer regions. */
Array<tvm::tir::MatchBufferRegion> match_buffers;
/*! \brief The annotation of the block. */
Optional<Map<String, ObjectRef>> annotations;
/*! \brief The corresponding values of the iter vars. */
Array<PrimExpr> iter_values;
/*!
* \brief The predicate of the block realization, the block will only be executed when the
* predicate is true.
*/
Optional<PrimExpr> predicate;
/*! \brief The flag whether to construct BlockRealize or Block. */
bool no_realize;

void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("name", &name);
v->Visit("iter_vars", &iter_vars);
v->Visit("reads", &reads);
v->Visit("writes", &writes);
v->Visit("init", &init);
v->Visit("alloc_buffers", &alloc_buffers);
v->Visit("match_buffers", &match_buffers);
v->Visit("annotations", &annotations);
v->Visit("iter_values", &iter_values);
v->Visit("predicate", &predicate);
v->Visit("no_realize", &no_realize);
}

static constexpr const char* _type_key = "script.ir_builder.tir.BlockFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to BlockFrameNode.
*
* \sa BlockFrameNode
*/

class BlockFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode);
};

/*!
* \brief A frame that represents the assert statement. Proceeds if the condition is true,
* otherwise aborts with the message.
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ namespace tir {
*/
PrimFuncFrame PrimFunc();

/*!
* \brief The block declaration statement.
* \param name The name of the block.
* \param no_realize The flag whether to construct BlockRealize or Block.
* \return The BlockFrame.
*/
BlockFrame Block(String name, bool no_realize = false);

/*!
* \brief Evaluate the input expression.
* \param value The input expression to evaluate.
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/script/ir_builder/tir/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,8 @@ class TIRFrame(IRBuilderFrame):
@_register_object("script.ir_builder.tir.PrimFuncFrame")
class PrimFuncFrame(TIRFrame):
...


@_register_object("script.ir_builder.tir.BlockFrame")
class BlockFrame(TIRFrame):
...
20 changes: 20 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,25 @@ def prim_func() -> frame.PrimFuncFrame:
return _ffi_api.PrimFunc() # pylint: disable=no-member # type: ignore


def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame:
"""The block declaration statement.

Parameters
----------
name : str
The name of the block.

no_realize : bool
The flag whether to construct BlockRealize or Block.

Returns
-------
res : frame.BlockFrame
The BlockFrame.
"""
return _ffi_api.Block(name, no_realize) # pylint: disable=no-member # type: ignore


def evaluate(value: PrimExpr) -> None:
"""Evaluate the input expression.

Expand All @@ -50,6 +69,7 @@ def evaluate(value: PrimExpr) -> None:


__all__ = [
"block",
"evaluate",
"prim_func",
]
24 changes: 24 additions & 0 deletions src/script/ir_builder/tir/frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,32 @@ void PrimFuncFrameNode::ExitWithScope() {
}
}

void BlockFrameNode::ExitWithScope() {
TIRFrameNode::ExitWithScope();
Array<tvm::tir::Buffer> tir_alloc_buffers;
for (const tvm::tir::Buffer& buffer : alloc_buffers) {
tir_alloc_buffers.push_back(buffer);
}
Map<String, ObjectRef> attrs = annotations.value_or({});
if (int detect_access = (!reads.defined()) | (!writes.defined() << 1)) {
attrs.Set("tir.script_parsing_detect_access", tvm::IntImm(DataType::Int(64), detect_access));
}
tvm::tir::Block block(iter_vars, reads.value_or(Array<tvm::tir::BufferRegion>()),
writes.value_or(Array<tvm::tir::BufferRegion>()), name, AsStmt(stmts), init,
tir_alloc_buffers, match_buffers, attrs);
if (no_realize) {
CHECK(iter_values.empty())
<< "ValueError: Block bindings are not allowed when `no_realize=True`";
CHECK(!predicate.defined()) << "ValueError: `T.where` is not allowed when `no_realize=True`";
AddToParent(block);
} else {
AddToParent(tvm::tir::BlockRealize(iter_values, predicate.value_or(Bool(true)), block));
}
}

TVM_REGISTER_NODE_TYPE(TIRFrameNode);
TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode);
TVM_REGISTER_NODE_TYPE(BlockFrameNode);

} // namespace tir
} // namespace ir_builder
Expand Down
17 changes: 17 additions & 0 deletions src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,25 @@ PrimFuncFrame PrimFunc() {
return PrimFuncFrame(n);
}

BlockFrame Block(String name, bool no_realize) {
ObjectPtr<BlockFrameNode> n = make_object<BlockFrameNode>();
n->name = name;
n->iter_vars.clear();
n->reads = NullOpt;
n->writes = NullOpt;
n->init = NullOpt;
n->alloc_buffers.clear();
n->match_buffers.clear();
n->annotations = NullOpt;
n->iter_values.clear();
n->predicate = NullOpt;
n->no_realize = no_realize;
return BlockFrame(n);
}

void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); }
TVM_REGISTER_GLOBAL("script.ir_builder.tir.PrimFunc").set_body_typed(PrimFunc);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Block").set_body_typed(Block);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate);
} // namespace tir
} // namespace ir_builder
Expand Down
9 changes: 9 additions & 0 deletions src/script/ir_builder/tir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ inline PrimFuncFrame FindPrimFuncFrame(const String& method) {
throw;
}

inline BlockFrame FindBlockFrame(const String& method) {
if (Optional<BlockFrame> frame = IRBuilder::Current()->GetLastFrame<BlockFrame>()) {
return frame.value();
}
LOG(FATAL) << "ValueError: Block frame not find. Please ensure '" << method
<< "' is called under T.block()";
throw;
}

} // namespace tir
} // namespace ir_builder
} // namespace script
Expand Down
27 changes: 27 additions & 0 deletions tests/python/unittest/test_tvmscript_ir_builder_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,32 @@ def test_ir_builder_tir_primfunc():
assert_structural_equal(prim_func_actual, prim_func_expected, map_free_vars=True)


def test_ir_builder_tir_block():
with IRBuilder() as ib:
with T.block("block"):
T.evaluate(0)
# the block generated by IRBuilder
block_realize_actual = ib.get()

# the expected block
block_expected = tir.Block(
iter_vars=[],
reads=[],
writes=[],
name_hint="block",
body=tir.Evaluate(0),
alloc_buffers=None,
match_buffers=None,
annotations={"tir.script_parsing_detect_access": tir.IntImm("int64", 3)},
)
block_realize_expected = tir.BlockRealize(
iter_values=[],
predicate=True,
block=block_expected,
)
# Check if the generated ir is expected
assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True)


if __name__ == "__main__":
tvm.testing.main()