diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h new file mode 100644 index 000000000000..61ca3eb9f7eb --- /dev/null +++ b/include/tvm/script/ir_builder/base.h @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_BASE_H_ +#define TVM_SCRIPT_IR_BUILDER_BASE_H_ + +#include +#include +#include + +#include + +namespace tvm { +namespace script { +namespace ir_builder { + +////////////////////////////// IRBuilderFrame ////////////////////////////// + +/*! + * \brief A stack frame of the IRBuilder used to keep track of the current scope. + * Furthermore, the information stored in each stack frame can be useful for context-dependent + * IR construction. + * + * \example + * + * The `T::MatchBuffer` below adds an element in `PrimFuncNode::buffer_map`: + * + * \code {.cpp} + * + * using T = tvm::script::ir_builder::tir; + * With _(...); + * Buffer buffer = T::MatchBuffer(...); + * + * \endcode + * + * The `T::MatchBuffer` below instead generates `MatchBufferRegion` in a TIR block: + * + * \code {.cpp} + * + * using T = tvm::script::ir_builder::tir; + * With _(...); + * { + * With _2(...); + * Buffer buffer = T::MatchBuffer(...); + * } + * + * \endcode + */ +class IRBuilderFrameNode : public runtime::Object { + public: + /*! \brief A list of callbacks used when exiting the frame. */ + std::vector> callbacks; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `callbacks` is not visited. + } + + static constexpr const char* _type_key = "script.ir_builder.IRBuilderFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(IRBuilderFrameNode, runtime::Object); + + public: + /*! \brief Default destructor. */ + virtual ~IRBuilderFrameNode() = default; + /*! + * \brief The method called when entering RAII scope. + * \sa tvm::support::With + */ + virtual void EnterWithScope(); + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + virtual void ExitWithScope(); + /*! + * \brief Add a callback method invoked when exiting the RAII scope. + * \param callback The callback to be added. + */ + void AddCallback(runtime::TypedPackedFunc callback); +}; + +/*! + * \brief Managed reference to an IRBuilderFrameNode. + * \sa IRBuilderFrameNode + */ +class IRBuilderFrame : public runtime::ObjectRef { + public: + /*! \brief Default destructor. */ + virtual ~IRBuilderFrame() = default; + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilderFrame, ObjectRef, IRBuilderFrameNode); + + protected: + /*! \brief Disallow direct construction of this object. */ + IRBuilderFrame() = default; + + public: + /*! + * \brief Redirected to `IRBuilderFrameNode::EnterWithScope`. + * \sa IRBuilderFrameNode::EnterWithScope + */ + inline void EnterWithScope() { + ICHECK(data_ != nullptr); + static_cast(data_.get())->EnterWithScope(); + } + /*! + * \brief Redirected to `IRBuilderFrameNode::ExitWithScope`. + * \sa IRBuilderFrameNode::ExitWithScope + */ + inline void ExitWithScope() { + ICHECK(data_ != nullptr); + static_cast(data_.get())->ExitWithScope(); + data_.reset(); + } +}; + +////////////////////////////// IRBuilder ////////////////////////////// + +/*! + * \brief A dialect-agnostic IRBuilder that constructs any IR of TVM. + * An idiomatic use of this class is to put this inside the RAII with-scope, + * call dialect-specific methods accordingly. Upon exiting the scope. + * + * \code + * + * PrimFunc ConstructPrimFunc() { + * using tvm::script::ir_builder::IRBuilder; + * using T = tvm::script::ir_builder::tir; + * IRBuilder builder; + * // Step 1. Place IRBuilder inside the with-scope. + * { + * With _(builder); + * // Step 2. Call dialect-specific methods. + * With _2(...); + * T::MatchBuffer(...); + * } + * // Step 3. Return the constructed PrimFunc. + * return builder->Get(); + * } + * + * \endcode + */ +class IRBuilderNode : public runtime::Object { + public: + /*! \brief A stack of context frames in the IRBuilder */ + runtime::Array frames; + /*! \brief The outcome of IR construction */ + Optional result; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("frames", &frames); + v->Visit("result", &result); + } + + static constexpr const char* _type_key = "script.ir_builder.IRBuilder"; + TVM_DECLARE_FINAL_OBJECT_INFO(IRBuilderNode, runtime::Object); + + public: + /*! + * \brief Find a frame of the given type in the stack `this->frames` from top to bottom. + * \tparam T The type of the frame to find. + * \return The frame if found, otherwise NullOpt. + */ + template + inline Optional FindFrame() const; + /*! + * \brief Get the frame on top of the stack `this->frames` if its type is `TFrame`. + * \tparam TFrame The assumed type of the last frame on stack. + * \return The frame if the stack is non-empty and the top of the stack is of type `TFrame`. + * Otherwise NullOpt. + */ + template + inline Optional GetLastFrame() const; + /*! + * \brief Get the IR being constructed. + * \tparam TObjectRef The type of the IR being constructed. + * \return The resulting IR. Throw an exception if the IR is not constructed yet. + */ + template + inline TObjectRef Get() const; +}; + +/*! + * \brief Managed reference to an IRBuilderNode. + * \sa IRBuilderNode + */ +class IRBuilder : public runtime::ObjectRef { + public: + /*! \brief Creates an IRBuilder. */ + IRBuilder(); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilder, ObjectRef, IRBuilderNode); + + public: + /*! + * \brief Puts the current IRBuilder into a thread-local scope, which can be retrieved using + * `IRBuilder::Current()`. + * + * \code {.cpp} + * IRBuilder builder; + * { + * With _(builder); + * // IRBuilder::Current() == builder + * } + * // IRBuilder::Current() == nullptr + * \endcode + * + * \sa IRBuilder::Current + * \sa IRBuilder::ExitWithScope + * \sa tvm::support::With + */ + void EnterWithScope(); + /*! + * \brief Exit the RAII scope. + * \sa IRBuilder::EnterWithScope + * \sa IRBuilder::Current + * \sa tvm::support::With + */ + void ExitWithScope(); + /*! + * \brief Get the current IRBuilder in the current thread-local scope. + * \return The current IRBuilder. + * \sa IRBuilder::EnterWithScope + * \sa IRBuilder::ExitWithScope + * \sa tvm::support::With + */ + static IRBuilder Current(); + /*! + * \brief Give a string name to the `obj` + * \tparam TObjectRef The type of the object to name. + * \param name The name to give to the object. + * \param obj The object to name. + */ + template + inline static TObjectRef Name(String name, TObjectRef obj); +}; + +////////////////////////////// Details ////////////////////////////// + +namespace details { + +class Namer { + public: + using FType = NodeFunctor; + static FType& vtable(); + static void Name(ObjectRef node, String name); +}; + +} // namespace details + +template +inline TObjectRef IRBuilder::Name(String name, TObjectRef obj) { + details::Namer::Name(obj, name); + return Downcast(obj); +} + +template +inline Optional IRBuilderNode::FindFrame() const { + using TFrameNode = typename TFrame::ContainerType; + for (auto it = frames.rbegin(); it != frames.rend(); ++it) { + if (const TFrameNode* p = (*it).template as()) { + return GetRef(p); + } + } + return NullOpt; +} + +template +inline Optional IRBuilderNode::GetLastFrame() const { + using TFrameNode = typename TFrame::ContainerType; + if (!frames.empty() && frames.back()->IsInstance()) { + return Downcast(frames.back()); + } + return NullOpt; +} + +template +inline TObjectRef IRBuilderNode::Get() const { + using TObject = typename TObjectRef::ContainerType; + CHECK(result.defined()) << "IndexError: No result exists in IRBuilder yet"; + const auto* n = result.as(); + CHECK(n != nullptr) << "TypeError: IRBuilder result is not of type: " << TObject::_type_key; + return GetRef(n); +} + +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_BASE_H_ diff --git a/python/tvm/script/ir_builder/__init__.py b/python/tvm/script/ir_builder/__init__.py new file mode 100644 index 000000000000..b325fadd864b --- /dev/null +++ b/python/tvm/script/ir_builder/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.script.ir_builder is a generic IR builder for TVM.""" +from .base import IRBuilder diff --git a/python/tvm/script/ir_builder/_ffi_api.py b/python/tvm/script/ir_builder/_ffi_api.py new file mode 100644 index 000000000000..68811c9e018c --- /dev/null +++ b/python/tvm/script/ir_builder/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.script.ir_builder""" +import tvm._ffi + +tvm._ffi._init_api("script.ir_builder", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py new file mode 100644 index 000000000000..767fa8bf2596 --- /dev/null +++ b/python/tvm/script/ir_builder/base.py @@ -0,0 +1,181 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A generic IRBuilder across the TVM stack""" +from typing import Any, Callable, List + +from tvm._ffi import register_object as _register_object +from tvm.runtime import Object as _Object + +from . import _ffi_api + + +@_register_object("script.ir_builder.IRBuilderFrame") +class IRBuilderFrame(_Object): + """A stack frame of the IRBuilder used to keep track of the current scope. + Furthermore, the information stored in each stack frame can be useful for context-dependent + IR construction. + + Examples + -------- + + The `T.match_buffer` below instead an element in the buffer map of `PrimFuncFrame`: + + .. code-block:: python + + from tvm.script.ir_builder import tir as T + from tvm.script.ir_builder import IRBuilder + + with IRBuilder() as builder: + with T.prim_func(...): # pushes a PrimFuncFrame (subclass of IRBuilderFrame) + # to `builder`'s stack of frames + buffer = T.match_buffer(...) + + + The `T.match_buffer` below instead generates `MatchBufferRegion` in a TIR block: + + .. code-block:: python + + from tvm.script.ir_builder import tir as T + from tvm.script.ir_builder import IRBuilder + + with IRBuilder() as builder: + with T.prim_func(...): # pushes a PrimFuncFrame (subclass of IRBuilderFrame) + # to `builder`'s stack of frames + with T.block(...): # pushes a BlockFrame (subclass of IRBuilderFrame) + # to `builder`'s stack of frames + buffer = T.match_buffer(...) + """ + + def __enter__(self) -> "IRBuilderFrame": + _ffi_api.IRBuilderFrameEnter(self) # pylint: disable=no-member # type: ignore + return self + + def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument + _ffi_api.IRBuilderFrameExit(self) # pylint: disable=no-member # type: ignore + + def add_callback(self, callback: Callable[[], None]) -> None: + """Add a callback method invoked when exiting the with-scope. + + Parameters + ---------- + callback : Callable[[], None] + The callback method to be invoked. + """ + _ffi_api.IRBuilderFrameAddCallback( # pylint: disable=no-member # type: ignore + self, callback + ) + + +@_register_object("script.ir_builder.IRBuilder") +class IRBuilder(_Object): + """A dialect-agnostic IRBuilder that constructs any IR of TVM. + + Examples + -------- + An idiomatic use of this class is to put this inside the with-scope, + call dialect-specific methods accordingly. Upon exiting the scope. + + .. code-block:: python + from tvm.script.ir_builder import tir as T + from tvm.script.ir_builder import IRBuilder + + with IRBuilder() as builder: + with T.prim_func(...): # pushes a PrimFuncFrame (subclass of IRBuilderFrame) + # to `builder`'s stack of frames + buffer = T.match_buffer(...) + + return builder.get() # returns the constructed IR, i.e. tir.PrimFunc + """ + + def __init__(self) -> None: + """Construct an IRBuilder.""" + self.__init_handle_by_constructor__( + _ffi_api.IRBuilder # pylint: disable=no-member # type: ignore + ) + + def __enter__(self) -> "IRBuilder": + """Enter the with-scope for IRBuilder, which allows the IRBuilder to be discoverable + using `IRBuilder.current()`. + + Examples + -------- + .. code-block:: python + from tvm.script.ir_builder import IRBuilder + + with IRBuilder() as builder: + assert IRBuilder.current() == builder + """ + _ffi_api.IRBuilderEnter(self) # pylint: disable=no-member # type: ignore + return self + + def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument + _ffi_api.IRBuilderExit(self) # pylint: disable=no-member # type: ignore + + @staticmethod + def current() -> "IRBuilder": + """Get the current IRBuilder put in the with-scope. + + Returns + ------- + builder : IRBuilder + The current IRBuilder. + """ + return _ffi_api.IRBuilderCurrent() # pylint: disable=no-member # type: ignore + + def get(self) -> _Object: + """Get the constructed IR.""" + return _ffi_api.IRBuilderGet(self) # pylint: disable=no-member # type: ignore + + @staticmethod + def name(s: str, v: Any) -> Any: + """Set the name of an object. + + Parameters + ---------- + s : str + The name of the object. + v : Any + The object to name. + + Returns + ------- + v : Any + The same object with the name set. + """ + return _ffi_api.IRBuilderName(s, v) # pylint: disable=no-member # type: ignore + + @staticmethod + def name_many( # pylint: disable=invalid-name + s: List[str], + vs: List[Any], + ) -> List[Any]: + """Set the name of a list of objects. + + Parameters + ---------- + s : List[str] + The names of the objects. + vs : List[Any] + The objects to name. + + Returns + ------- + vs : List[Any] + The same objects with the names set. + """ + assert len(s) == len(vs) + return [IRBuilder.name(i, v) for i, v in zip(s, vs)] diff --git a/src/script/ir_builder/base.cc b/src/script/ir_builder/base.cc new file mode 100644 index 000000000000..8303efff4f20 --- /dev/null +++ b/src/script/ir_builder/base.cc @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { + +void IRBuilderFrameNode::EnterWithScope() { + IRBuilder::Current()->frames.push_back(GetRef(this)); +} + +void IRBuilderFrameNode::ExitWithScope() { + for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) { + (*it)(); + } + this->callbacks.clear(); + IRBuilder::Current()->frames.pop_back(); +} + +void IRBuilderFrameNode::AddCallback(runtime::TypedPackedFunc callback) { + if (IRBuilder::Current()->frames.empty()) { + LOG(FATAL) << "ValueError: No frames in Builder to add callback"; + } + IRBuilder::Current()->frames.back()->callbacks.push_back(callback); +} + +IRBuilder::IRBuilder() { + ObjectPtr n = make_object(); + n->frames.clear(); + n->result = NullOpt; + data_ = n; +} + +std::vector* ThreadLocalBuilderStack() { + thread_local std::vector stack; + return &stack; +} + +void IRBuilder::EnterWithScope() { + IRBuilderNode* n = this->get(); + CHECK(n->frames.empty()) << "ValueError: There are frame(s) left in the builder: " + << n->frames.size() + << ". Please use a fresh new builder every time building IRs"; + n->result = NullOpt; + std::vector* stack = ThreadLocalBuilderStack(); + stack->push_back(*this); +} + +void IRBuilder::ExitWithScope() { + std::vector* stack = ThreadLocalBuilderStack(); + ICHECK(!stack->empty()); + stack->pop_back(); +} + +IRBuilder IRBuilder::Current() { + std::vector* stack = ThreadLocalBuilderStack(); + CHECK(!stack->empty()) << "ValueError: No builder in current scope"; + return stack->back(); +} + +namespace details { + +Namer::FType& Namer::vtable() { + static FType inst; + return inst; +} + +void Namer::Name(ObjectRef node, String name) { + static const FType& f = vtable(); + CHECK(node.defined()) << "ValueError: Cannot name nullptr with: " << name; + CHECK(f.can_dispatch(node)) << "ValueError: Do not know how to name type \"" + << node->GetTypeKey(); + f(node, name); +} + +} // namespace details + +TVM_REGISTER_NODE_TYPE(IRBuilderFrameNode); +TVM_REGISTER_NODE_TYPE(IRBuilderNode); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameEnter") + .set_body_method(&IRBuilderFrameNode::EnterWithScope); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameExit") + .set_body_method(&IRBuilderFrameNode::ExitWithScope); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameAddCallback") + .set_body_method(&IRBuilderFrameNode::AddCallback); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilder").set_body_typed([]() { return IRBuilder(); }); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderEnter").set_body_method(&IRBuilder::EnterWithScope); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderExit").set_body_method(&IRBuilder::ExitWithScope); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderGet") + .set_body_method(&IRBuilderNode::Get); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderName").set_body_typed(IRBuilder::Name); + +} // namespace ir_builder +} // namespace script +} // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_ir_builder_base.py b/tests/python/unittest/test_tvmscript_ir_builder_base.py new file mode 100644 index 000000000000..b41e8cdd92cb --- /dev/null +++ b/tests/python/unittest/test_tvmscript_ir_builder_base.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unittests for tvm.script.ir_builder.base""" +import pytest +from tvm.script.ir_builder import IRBuilder + + +def test_ir_builder_scope(): + with IRBuilder() as ib: # pylint: disable=invalid-name + assert IRBuilder.current() == ib + + +def test_ir_builder_multi_scope(): + with IRBuilder() as ib: # pylint: disable=invalid-name + with IRBuilder() as ib2: # pylint: disable=invalid-name + assert IRBuilder.current() == ib2 + assert IRBuilder.current() == ib + + +def test_ir_builder_no_scope(): + with pytest.raises(ValueError): + IRBuilder.current() + + +if __name__ == "__main__": + test_ir_builder_scope() + test_ir_builder_multi_scope() + test_ir_builder_no_scope()