Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] IRBuilder methods for Stmt #12830

Merged
merged 1 commit into from
Sep 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions include/tvm/script/ir_builder/tir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,138 @@ class AssertFrameNode : public TIRFrameNode {
void ExitWithScope() final;
};

/*!
* \brief Managed reference to AssertFrameNode.
*
* \sa AssertFrameNode
*/
class AssertFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AssertFrame, TIRFrame, AssertFrameNode);
};

/*!
* \brief A frame represents the let binding expression, which binds a var.
*
* \sa LetFrameNode
*/
class LetFrameNode : public TIRFrameNode {
public:
/*! \brief The variable we bind to */
tvm::tir::Var var;
/*! \brief The value we bind var to */
PrimExpr value;

void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("var", &var);
v->Visit("value", &value);
}

static constexpr const char* _type_key = "script.ir_builder.tir.LetFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(LetFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to LetFrameNode.
*
* \sa LetFrameNode
*/
class LetFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LetFrame, TIRFrame, LetFrameNode);
};

/*!
* \brief The LaunchThreadFrameNode.
* \note It is used only inside a PrimFunc.
*/
class LaunchThreadFrameNode : public TIRFrameNode {
public:
/*! \brief The extent of environment thread. */
PrimExpr extent;
/*! \brief The attribute key, could be either virtual_thread or thread_extent. */
String attr_key;
/*! \brief The iteration variable. */
tvm::tir::IterVar iter_var;

void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("extent", &extent);
v->Visit("attr_key", &attr_key);
v->Visit("iter_var", &iter_var);
}

static constexpr const char* _type_key = "script.ir_builder.tir.LaunchThreadFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(LaunchThreadFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to LaunchThreadFrameNode.
*
* \sa LaunchThreadFrameNode
*/
class LaunchThreadFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LaunchThreadFrame, TIRFrame,
LaunchThreadFrameNode);
};

/*!
* \brief A frame that represents realization.
*
* \sa RealizeFrame
*/
class RealizeFrameNode : public TIRFrameNode {
public:
/*! \brief The region of buffer access. */
tvm::tir::BufferRegion buffer_slice;
/*! \brief The storage scope associated with this realization. */
String storage_scope;
/*! \brief The condition expression. */
PrimExpr condition;

void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("buffer_slice", &buffer_slice);
v->Visit("storage_scope", &storage_scope);
v->Visit("condition", &condition);
}

static constexpr const char* _type_key = "script.ir_builder.tir.RealizeFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(RealizeFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to RealizeFrameNode.
*
* \sa RealizeFrameNode
*/
class RealizeFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode);
};
} // namespace tir
} // namespace ir_builder
} // namespace script
Expand Down
40 changes: 40 additions & 0 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,46 @@ ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
*/
ForFrame Grid(Array<PrimExpr> extents);

/*!
* \brief The assertion statement.
* \param condition The assertion condition.
* \param message The error message when the assertion fails.
* \return The AssertFrame.
*/
AssertFrame Assert(PrimExpr condition, String message);

/*!
* \brief The let binding.
* \param var The variable to bind.
* \param value The value to be bound.
* \return The created LetFrame.
*/
LetFrame Let(Var var, PrimExpr value);

/*!
* \brief The realization.
* \param buffer_slice The region of buffer access.
* \param storage_scope The storage scope associated with this realization.
* \param condition The condition expression.
* \return The result RealizeFrame.
*/
RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition);

/*!
* \brief Launch a thread.
* \param var The iteration variable.
* \param extent The extent of environment thread.
* \return The result LaunchThreadFrame.
*/
LaunchThreadFrame LaunchThread(Var var, PrimExpr extent);

/*!
* \brief Bind a var to thread env.
* \param thread_tag The thread type tag.
* \return The result variable which gets bound to the thread env.
*/
Var EnvThread(String thread_tag);

/*!
* \brief Evaluate the input expression.
* \param value The input expression to evaluate.
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/script/ir_builder/tir/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,23 @@ class ForFrame(TIRFrame):
def __enter__(self) -> Union[Var, List[Var]]: # type: ignore[override]
super().__enter__()
return self.vars if len(self.vars) > 1 else self.vars[0]


@_register_object("script.ir_builder.tir.AssertFrame")
class AssertFrame(TIRFrame):
...


@_register_object("script.ir_builder.tir.LetFrame")
class LetFrame(TIRFrame):
...


@_register_object("script.ir_builder.tir.RealizeFrame")
class RealizeFrame(TIRFrame):
...


@_register_object("script.ir_builder.tir.LaunchThreadFrame")
class LaunchThreadFrame(TIRFrame):
...
131 changes: 131 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
BufferLoad,
BufferRegion,
IntImm,
IterVar,
Let,
PrimExpr,
StringImm,
Var,
Expand Down Expand Up @@ -813,6 +815,130 @@ def grid(*extents: PrimExpr) -> frame.ForFrame:
return _ffi_api.Grid(extents) # type: ignore[attr-defined] # pylint: disable=no-member


def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame: # pylint: disable=invalid-name
"""Create an assertion statement.

Parameters
----------
condition : PrimExpr
The PrimExpr to test.

message : str
The output error message when the assertion fails.

Returns
-------
res : frame.AssertFrame
The result AssertFrame.
"""
return _ffi_api.Assert(condition, message) # type: ignore[attr-defined] # pylint: disable=no-member


def let(
v: Var,
value: PrimExpr,
body: PrimExpr = None,
) -> frame.LetFrame:
"""Create a new let binding.

Parameters
----------
v : Var
The variable to bind.

value : PrimExpr
The value to be bound.

body : PrimExpr
The body expression, None will be used if it was not specified.

Returns
-------
res : frame.LetFrame
The result LetFrame.
"""
if body is None:
return _ffi_api.Let(v, value) # type: ignore[attr-defined] # pylint: disable=no-member
return Let(v, value, body)


def realize(
buffer_slice: BufferRegion,
storage_scope: str,
condition: PrimExpr = True,
) -> frame.RealizeFrame:
"""Create a realization.

Parameters
----------
buffer_slice : BufferRegion
The region of buffer access.

storage_scope : str
The storage scope associated with this realization.

condition: PrimExpr
The condition expression, the default is True.

Returns
-------
res : frame.RealizeFrame
The result RealizeFrame.
"""
return _ffi_api.Realize( # type: ignore[attr-defined] # pylint: disable=no-member
buffer_slice, storage_scope, condition
)


def launch_thread(
iter_var: IterVar, # pylint: disable=redefined-outer-name
extent: PrimExpr,
) -> frame.LaunchThreadFrame:
"""Launch a thread.

Parameters
----------
iter_var : IterVar
The iteration variable.

extent : PrimExpr
The extent of environment thread.

Returns
-------
res : frame.LaunchThreadFrame
The result LaunchThreadFrame.

Examples
--------

.. code-block:: python

from tvm.script.ir_builder import tir as T
brow = T.env_thread("blockIdx.y")
T.launch_thread(brow, 1)

"""
return _ffi_api.LaunchThread(iter_var, extent) # type: ignore[attr-defined] # pylint: disable=no-member


def env_thread(thread_tag: str) -> IterVar:
"""Bind a var to thread env"

Parameters
----------
thread_tag : str
The thread type tag.

Returns
-------
res : IterVar
The result iteration variable gets bound to the thread env.

"""
return _ffi_api.EnvThread(thread_tag) # type: ignore[attr-defined] # pylint: disable=no-member


def evaluate(value: PrimExpr) -> None:
"""Evaluate the input expression.

Expand Down Expand Up @@ -1159,6 +1285,11 @@ def var(dtype, name="") -> Var:
"unroll",
"thread_binding",
"grid",
"Assert",
"let",
"realize",
"launch_thread",
"env_thread",
"evaluate",
"int8",
"int16",
Expand Down
Loading