Skip to content

Commit

Permalink
[TVMScript] Base IRBuilder methods for PrimFunc (#12745)
Browse files Browse the repository at this point in the history
Base IRBuilder methods for `PrimFunc`

This PR introduces base IRBuilder methods for `PrimFunc`.

Co-authored-by: yongwww <yongcale@gmail.com>

Co-authored-by: yongwww <yongcale@gmail.com>
  • Loading branch information
cyx-6 and yongwww authored Sep 9, 2022
1 parent 1d32c40 commit 8bd81e6
Show file tree
Hide file tree
Showing 14 changed files with 561 additions and 0 deletions.
2 changes: 2 additions & 0 deletions include/tvm/script/ir_builder/ir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
namespace tvm {
namespace script {
namespace ir_builder {
namespace ir {

/*!
* \brief A frame that represents the IRModule frame with functions and global variables.
Expand Down Expand Up @@ -64,6 +65,7 @@ class IRModuleFrame : public IRBuilderFrame {
IRModuleFrameNode);
};

} // namespace ir
} // namespace ir_builder
} // namespace script
} // namespace tvm
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/script/ir_builder/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@
namespace tvm {
namespace script {
namespace ir_builder {
namespace ir {

/*!
* \brief The IRModule declaration statement.
* \return The IRModuleFrame.
*/
TVM_DLL IRModuleFrame IRModule();

} // namespace ir
} // namespace ir_builder
} // namespace script
} // namespace tvm
Expand Down
155 changes: 155 additions & 0 deletions include/tvm/script/ir_builder/tir/frame.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_SCRIPT_IR_BUILDER_TIR_FRAME_H_
#define TVM_SCRIPT_IR_BUILDER_TIR_FRAME_H_

#include <tvm/script/ir_builder/base.h>
#include <tvm/script/ir_builder/ir/frame.h>
#include <tvm/tir/stmt.h>

namespace tvm {
namespace script {
namespace ir_builder {
namespace tir {

/*!
* \brief A base frame that represents the TIR fame with body of statements.
*
* \sa TIRFrame
*/
class TIRFrameNode : public IRBuilderFrameNode {
public:
/*! \brief The Stmt within in this frame. */
Array<tvm::tir::Stmt> stmts;

void VisitAttrs(tvm::AttrVisitor* v) {
IRBuilderFrameNode::VisitAttrs(v);
v->Visit("stmts", &stmts);
}

static constexpr const char* _type_key = "script.ir_builder.tir.TIRFrame";
TVM_DECLARE_BASE_OBJECT_INFO(TIRFrameNode, IRBuilderFrameNode);
};

/*!
* \brief Managed reference to TIRFrameNode.
*
* \sa TIRFrameNode
*/
class TIRFrame : public IRBuilderFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, IRBuilderFrame, TIRFrameNode);

protected:
TIRFrame() = default;
};

/*!
* \brief A frame that represents the PrimFunc containing TIR statements.
*
* \sa PrimFuncFrame
*/
class PrimFuncFrameNode : public TIRFrameNode {
public:
/*! \brief The name of the block. */
Optional<String> name;
/*! \brief Function parameters. */
Array<tvm::tir::Var> args;
/*! \brief The return type of the function. */
Optional<Type> ret_type;
/*! \brief Maps some parameters to specific Buffer data structures. */
Map<tvm::tir::Var, tvm::tir::Buffer> buffer_map;
/*! \brief The buffer map prior to flattening. */
Map<tvm::tir::Var, tvm::tir::Buffer> preflattened_buffer_map;
/*! \brief Additional attributes storing the meta-data */
Optional<Map<String, ObjectRef>> attrs;
/*! \brief The variable map bound to thread env. */
Map<tvm::tir::Var, tvm::tir::IterVar> env_threads;
/*! \brief The buffer allocated in root block. */
Array<tvm::tir::Buffer> root_alloc_buffers;

void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("name", &name);
v->Visit("args", &args);
v->Visit("ret_type", &ret_type);
v->Visit("buffer_map", &buffer_map);
v->Visit("preflattened_buffer_map", &preflattened_buffer_map);
v->Visit("attrs", &attrs);
v->Visit("env_threads", &env_threads);
v->Visit("root_alloc_buffers", &root_alloc_buffers);
}

static constexpr const char* _type_key = "script.ir_builder.tir.PrimFuncFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to PrimFuncFrameNode.
*
* \sa PrimFuncFrameNode
*/
class PrimFuncFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode);
};

/*!
* \brief A frame that represents the assert statement. Proceeds if the condition is true,
* otherwise aborts with the message.
*
* \sa AssertFrame
*/
class AssertFrameNode : public TIRFrameNode {
public:
/*! \brief The PrimExpr to test. */
PrimExpr condition;
/*! \brief The output error message when the assertion failed. */
PrimExpr message;

void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("condition", &condition);
v->Visit("message", &message);
}

static constexpr const char* _type_key = "script.ir_builder.tir.AssertFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(AssertFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

} // namespace tir
} // namespace ir_builder
} // namespace script
} // namespace tvm

#endif // TVM_SCRIPT_IR_BUILDER_TIR_FRAME_H_
48 changes: 48 additions & 0 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_SCRIPT_IR_BUILDER_TIR_IR_H_
#define TVM_SCRIPT_IR_BUILDER_TIR_IR_H_

#include <tvm/script/ir_builder/base.h>
#include <tvm/script/ir_builder/tir/frame.h>
#include <tvm/tir/op.h>

namespace tvm {
namespace script {
namespace ir_builder {
namespace tir {

/*!
* \brief The primitive function statement.
* \return The PrimFuncFrame.
*/
PrimFuncFrame PrimFunc();

/*!
* \brief Evaluate the input expression.
* \param value The input expression to evaluate.
*/
void Evaluate(PrimExpr value);

} // namespace tir
} // namespace ir_builder
} // namespace script
} // namespace tvm

#endif // TVM_SCRIPT_IR_BUILDER_TIR_IR_H_
18 changes: 18 additions & 0 deletions python/tvm/script/ir_builder/tir/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Package tvm.script.ir_builder.tir"""
from .ir import * # pylint: disable=wildcard-import,redefined-builtin
20 changes: 20 additions & 0 deletions python/tvm/script/ir_builder/tir/_ffi_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs"""
import tvm._ffi

tvm._ffi._init_api("script.ir_builder.tir", __name__) # pylint: disable=protected-access
31 changes: 31 additions & 0 deletions python/tvm/script/ir_builder/tir/frame.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""IRBuilder for TIR"""

from tvm._ffi import register_object as _register_object

from ..base import IRBuilderFrame


@_register_object("script.ir_builder.tir.TIRFrame")
class TIRFrame(IRBuilderFrame):
...


@_register_object("script.ir_builder.tir.PrimFuncFrame")
class PrimFuncFrame(TIRFrame):
...
55 changes: 55 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
"""IRBuilder for TIR"""

from tvm.tir import PrimExpr, StringImm

from . import _ffi_api, frame


def prim_func() -> frame.PrimFuncFrame:
"""The primitive function statement.
Returns
-------
res : frame.PrimFuncFrame
The PrimFuncFrame.
"""
return _ffi_api.PrimFunc() # pylint: disable=no-member # type: ignore


def evaluate(value: PrimExpr) -> None:
"""Evaluate the input expression.
Parameters
----------
value: PrimExpr
The input expression to evaluate.
"""
if isinstance(value, str):
value = StringImm(value)
return _ffi_api.Evaluate(value) # pylint: disable=no-member # type: ignore


# pylint: enable=invalid-name


__all__ = [
"evaluate",
"prim_func",
]
2 changes: 2 additions & 0 deletions src/script/ir_builder/ir/frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
namespace tvm {
namespace script {
namespace ir_builder {
namespace ir {

void IRModuleFrameNode::ExitWithScope() {
ICHECK_EQ(functions.size(), global_vars.size());
Expand All @@ -38,6 +39,7 @@ void IRModuleFrameNode::ExitWithScope() {

TVM_REGISTER_NODE_TYPE(IRModuleFrameNode);

} // namespace ir
} // namespace ir_builder
} // namespace script
} // namespace tvm
2 changes: 2 additions & 0 deletions src/script/ir_builder/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
namespace tvm {
namespace script {
namespace ir_builder {
namespace ir {

IRModuleFrame IRModule() {
ObjectPtr<IRModuleFrameNode> n = make_object<IRModuleFrameNode>();
Expand All @@ -33,6 +34,7 @@ IRModuleFrame IRModule() {

TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule);

} // namespace ir
} // namespace ir_builder
} // namespace script
} // namespace tvm
Loading

0 comments on commit 8bd81e6

Please sign in to comment.