-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TVMScript] Base IRBuilder methods for
PrimFunc
(#12745)
Base IRBuilder methods for `PrimFunc` This PR introduces base IRBuilder methods for `PrimFunc`. Co-authored-by: yongwww <yongcale@gmail.com> Co-authored-by: yongwww <yongcale@gmail.com>
- Loading branch information
Showing
14 changed files
with
561 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
#ifndef TVM_SCRIPT_IR_BUILDER_TIR_FRAME_H_ | ||
#define TVM_SCRIPT_IR_BUILDER_TIR_FRAME_H_ | ||
|
||
#include <tvm/script/ir_builder/base.h> | ||
#include <tvm/script/ir_builder/ir/frame.h> | ||
#include <tvm/tir/stmt.h> | ||
|
||
namespace tvm { | ||
namespace script { | ||
namespace ir_builder { | ||
namespace tir { | ||
|
||
/*! | ||
* \brief A base frame that represents the TIR fame with body of statements. | ||
* | ||
* \sa TIRFrame | ||
*/ | ||
class TIRFrameNode : public IRBuilderFrameNode { | ||
public: | ||
/*! \brief The Stmt within in this frame. */ | ||
Array<tvm::tir::Stmt> stmts; | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) { | ||
IRBuilderFrameNode::VisitAttrs(v); | ||
v->Visit("stmts", &stmts); | ||
} | ||
|
||
static constexpr const char* _type_key = "script.ir_builder.tir.TIRFrame"; | ||
TVM_DECLARE_BASE_OBJECT_INFO(TIRFrameNode, IRBuilderFrameNode); | ||
}; | ||
|
||
/*! | ||
* \brief Managed reference to TIRFrameNode. | ||
* | ||
* \sa TIRFrameNode | ||
*/ | ||
class TIRFrame : public IRBuilderFrame { | ||
public: | ||
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, IRBuilderFrame, TIRFrameNode); | ||
|
||
protected: | ||
TIRFrame() = default; | ||
}; | ||
|
||
/*! | ||
* \brief A frame that represents the PrimFunc containing TIR statements. | ||
* | ||
* \sa PrimFuncFrame | ||
*/ | ||
class PrimFuncFrameNode : public TIRFrameNode { | ||
public: | ||
/*! \brief The name of the block. */ | ||
Optional<String> name; | ||
/*! \brief Function parameters. */ | ||
Array<tvm::tir::Var> args; | ||
/*! \brief The return type of the function. */ | ||
Optional<Type> ret_type; | ||
/*! \brief Maps some parameters to specific Buffer data structures. */ | ||
Map<tvm::tir::Var, tvm::tir::Buffer> buffer_map; | ||
/*! \brief The buffer map prior to flattening. */ | ||
Map<tvm::tir::Var, tvm::tir::Buffer> preflattened_buffer_map; | ||
/*! \brief Additional attributes storing the meta-data */ | ||
Optional<Map<String, ObjectRef>> attrs; | ||
/*! \brief The variable map bound to thread env. */ | ||
Map<tvm::tir::Var, tvm::tir::IterVar> env_threads; | ||
/*! \brief The buffer allocated in root block. */ | ||
Array<tvm::tir::Buffer> root_alloc_buffers; | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) { | ||
TIRFrameNode::VisitAttrs(v); | ||
v->Visit("name", &name); | ||
v->Visit("args", &args); | ||
v->Visit("ret_type", &ret_type); | ||
v->Visit("buffer_map", &buffer_map); | ||
v->Visit("preflattened_buffer_map", &preflattened_buffer_map); | ||
v->Visit("attrs", &attrs); | ||
v->Visit("env_threads", &env_threads); | ||
v->Visit("root_alloc_buffers", &root_alloc_buffers); | ||
} | ||
|
||
static constexpr const char* _type_key = "script.ir_builder.tir.PrimFuncFrame"; | ||
TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncFrameNode, TIRFrameNode); | ||
|
||
public: | ||
/*! | ||
* \brief The method called when exiting RAII scope. | ||
* \sa tvm::support::With | ||
*/ | ||
void ExitWithScope() final; | ||
}; | ||
|
||
/*! | ||
* \brief Managed reference to PrimFuncFrameNode. | ||
* | ||
* \sa PrimFuncFrameNode | ||
*/ | ||
class PrimFuncFrame : public TIRFrame { | ||
public: | ||
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode); | ||
}; | ||
|
||
/*! | ||
* \brief A frame that represents the assert statement. Proceeds if the condition is true, | ||
* otherwise aborts with the message. | ||
* | ||
* \sa AssertFrame | ||
*/ | ||
class AssertFrameNode : public TIRFrameNode { | ||
public: | ||
/*! \brief The PrimExpr to test. */ | ||
PrimExpr condition; | ||
/*! \brief The output error message when the assertion failed. */ | ||
PrimExpr message; | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) { | ||
TIRFrameNode::VisitAttrs(v); | ||
v->Visit("condition", &condition); | ||
v->Visit("message", &message); | ||
} | ||
|
||
static constexpr const char* _type_key = "script.ir_builder.tir.AssertFrame"; | ||
TVM_DECLARE_FINAL_OBJECT_INFO(AssertFrameNode, TIRFrameNode); | ||
|
||
public: | ||
/*! | ||
* \brief The method called when exiting RAII scope. | ||
* \sa tvm::support::With | ||
*/ | ||
void ExitWithScope() final; | ||
}; | ||
|
||
} // namespace tir | ||
} // namespace ir_builder | ||
} // namespace script | ||
} // namespace tvm | ||
|
||
#endif // TVM_SCRIPT_IR_BUILDER_TIR_FRAME_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
#ifndef TVM_SCRIPT_IR_BUILDER_TIR_IR_H_ | ||
#define TVM_SCRIPT_IR_BUILDER_TIR_IR_H_ | ||
|
||
#include <tvm/script/ir_builder/base.h> | ||
#include <tvm/script/ir_builder/tir/frame.h> | ||
#include <tvm/tir/op.h> | ||
|
||
namespace tvm { | ||
namespace script { | ||
namespace ir_builder { | ||
namespace tir { | ||
|
||
/*! | ||
* \brief The primitive function statement. | ||
* \return The PrimFuncFrame. | ||
*/ | ||
PrimFuncFrame PrimFunc(); | ||
|
||
/*! | ||
* \brief Evaluate the input expression. | ||
* \param value The input expression to evaluate. | ||
*/ | ||
void Evaluate(PrimExpr value); | ||
|
||
} // namespace tir | ||
} // namespace ir_builder | ||
} // namespace script | ||
} // namespace tvm | ||
|
||
#endif // TVM_SCRIPT_IR_BUILDER_TIR_IR_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Package tvm.script.ir_builder.tir""" | ||
from .ir import * # pylint: disable=wildcard-import,redefined-builtin |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""FFI APIs""" | ||
import tvm._ffi | ||
|
||
tvm._ffi._init_api("script.ir_builder.tir", __name__) # pylint: disable=protected-access |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""IRBuilder for TIR""" | ||
|
||
from tvm._ffi import register_object as _register_object | ||
|
||
from ..base import IRBuilderFrame | ||
|
||
|
||
@_register_object("script.ir_builder.tir.TIRFrame") | ||
class TIRFrame(IRBuilderFrame): | ||
... | ||
|
||
|
||
@_register_object("script.ir_builder.tir.PrimFuncFrame") | ||
class PrimFuncFrame(TIRFrame): | ||
... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# pylint: disable=missing-docstring | ||
"""IRBuilder for TIR""" | ||
|
||
from tvm.tir import PrimExpr, StringImm | ||
|
||
from . import _ffi_api, frame | ||
|
||
|
||
def prim_func() -> frame.PrimFuncFrame: | ||
"""The primitive function statement. | ||
Returns | ||
------- | ||
res : frame.PrimFuncFrame | ||
The PrimFuncFrame. | ||
""" | ||
return _ffi_api.PrimFunc() # pylint: disable=no-member # type: ignore | ||
|
||
|
||
def evaluate(value: PrimExpr) -> None: | ||
"""Evaluate the input expression. | ||
Parameters | ||
---------- | ||
value: PrimExpr | ||
The input expression to evaluate. | ||
""" | ||
if isinstance(value, str): | ||
value = StringImm(value) | ||
return _ffi_api.Evaluate(value) # pylint: disable=no-member # type: ignore | ||
|
||
|
||
# pylint: enable=invalid-name | ||
|
||
|
||
__all__ = [ | ||
"evaluate", | ||
"prim_func", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.