Skip to content

Commit

Permalink
Add IRDocsifier
Browse files Browse the repository at this point in the history
Co-authored-by: Greg Bonik <gbonik@octoml.ai>
  • Loading branch information
yelite and gbonik committed Aug 12, 2022
1 parent e242b78 commit 9bff2ca
Show file tree
Hide file tree
Showing 6 changed files with 662 additions and 0 deletions.
189 changes: 189 additions & 0 deletions include/tvm/script/printer/ir_docsifier.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_SCRIPT_PRINTER_IR_DOCSIFIER_H_
#define TVM_SCRIPT_PRINTER_IR_DOCSIFIER_H_

#include <tvm/node/node.h>
#include <tvm/runtime/logging.h>
#include <tvm/script/printer/doc.h>
#include <tvm/script/printer/frame.h>
#include <tvm/script/printer/traced_object.h>
#include <tvm/script/printer/traced_object_functor.h>
#include <tvm/script/printer/var_table.h>
#include <tvm/support/with.h>

namespace tvm {
namespace script {
namespace printer {

using WithCtx = With<ContextManager>;

/*!
* \breif IRDocsifier is the top-level interface in the IR->Doc process.
*
* It provides methods to convert IR node object to Doc, operate on Frame
* objects and change dispatch tokens.
*
* Example usage:
* \code
* TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
* .set_dispatch([](TracedObject<tir::Var> obj, IRDocsifier p) { return IdDoc("x"); });
*
* TracedObject<tir::Var> var = ...;
* IRDocsifier p;
* p->AsDoc(var); // returns an IdDoc("x")
* \endcode
*
*/
class IRDocsifierNode : public Object {
public:
/*!
* \brief The var table to use during the printing process.
* \sa VarTableNode
*/
VarTable vars;
/*!
* \brief The stack of frames.
* \sa FrameNode
*/
Array<Frame> frames;
/*!
* \brief The stack of dispatch tokens.
*
* The dispatch token on the top decides which dispatch function to use
* when converting IR node object to Doc.
*/
Array<String> dispatch_tokens;
/*!
* \brief This map connects IR dipatch token to the name of identifier.
*/
Map<String, String> ir_prefix;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("vars", &vars);
v->Visit("frames", &frames);
v->Visit("dispatch_tokens", &dispatch_tokens);
v->Visit("ir_prefix", &ir_prefix);
}

static constexpr const char* _type_key = "script.printer.IRDocsifier";
TVM_DECLARE_FINAL_OBJECT_INFO(IRDocsifierNode, Object);

public:
/*!
* \brief Transform the input object into TDoc.
* \param obj The object to be transformed.
*
* \return The Doc object.
*/
template <class TDoc>
TDoc AsDoc(const TracedObject<ObjectRef>& obj) const {
auto result = Downcast<TDoc>(AsDocImpl(obj));
result->source_paths.push_back(obj.GetPath());
return result;
}

/*!
* \brief Helper method to transform object into ExprDoc.
* \param obj The object to be transformed.
*
* \return The ExprDoc object.
*/
ExprDoc AsExprDoc(const TracedObject<ObjectRef>& obj) { return AsDoc<ExprDoc>(obj); }

/*!
* \brief Push a new dispatch token into the stack
* \details The top dispatch token decides which dispatch table to use
* when printing Object. This method returns a RAII guard which
* pops the token when going out of the scope.
*
* \param token The dispatch token to push.
*
* \return A RAII guard to pop dispatch token when going out of scope.
*/
WithCtx WithDispatchToken(const String& token) {
this->dispatch_tokens.push_back(token);
return WithCtx(nullptr, [this]() { this->dispatch_tokens.pop_back(); });
}

/*!
* \brief Push a new frame the stack
* \details Frame contains the contextual information that's needed during printing,
* for example, variables in the scope. This method returns a RAII guard which
* pops the frame and call the cleanup method of frame when going out of the scope.
*
* \param frame The frame to push.
*
* \return A RAII guard to pop frame and call the exit method of frame
* when going out of scope
*/
WithCtx WithFrame(const Frame& frame) {
frame->EnterWithScope();
this->frames.push_back(frame);
return WithCtx(nullptr, [this, pushed_frame = frame]() {
Frame last_frame = this->frames.back();
ICHECK_EQ(last_frame, pushed_frame);
this->frames.pop_back();
last_frame->ExitWithScope();
});
}

/*!
* \brief Get the top frame with type FrameType
* \tparam FrameType The type of frame to get.
*/
template <typename FrameType>
Optional<FrameType> GetFrame() const {
for (auto it = frames.rbegin(); it != frames.rend(); ++it) {
if (const auto* f = (*it).as<typename FrameType::ContainerType>()) {
return GetRef<FrameType>(f);
}
}
return NullOpt;
}

private:
Doc AsDocImpl(const TracedObject<ObjectRef>& obj) const;
};

/*!
* \breif Reference type of IRDocsifierNode.
*/
class IRDocsifier : public ObjectRef {
public:
/*!
* \brief Create a IRDocsifier.
* \param ir_prefix The ir_prefix to use for this IRDocsifier.
*/
explicit IRDocsifier(Map<String, String> ir_prefix);

using FType = TracedObjectFunctor<printer::Doc, IRDocsifier>;
/*!
* \brief The registration table for IRDocsifier.
*/
TVM_DLL static FType& vtable();

TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRDocsifier, ObjectRef, IRDocsifierNode);
};

} // namespace printer
} // namespace script
} // namespace tvm

#endif // TVM_SCRIPT_PRINTER_IR_DOCSIFIER_H_
19 changes: 19 additions & 0 deletions include/tvm/support/with.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <dmlc/common.h>

#include <functional>
#include <utility>

namespace tvm {
Expand Down Expand Up @@ -80,5 +81,23 @@ class With {
ContextType ctx_;
};

class ContextManager {
public:
template <class FEnter, class FExit>
explicit ContextManager(FEnter f_enter, FExit f_exit) : f_enter_(f_enter), f_exit_(f_exit) {}

private:
void EnterWithScope() {
if (f_enter_) f_enter_();
}
void ExitWithScope() {
if (f_exit_) f_exit_();
}
std::function<void()> f_enter_;
std::function<void()> f_exit_;
template <typename>
friend class With;
};

} // namespace tvm
#endif // TVM_SUPPORT_WITH_H_
159 changes: 159 additions & 0 deletions python/tvm/script/printer/ir_docsifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from contextlib import ExitStack, contextmanager
from typing import Callable, Dict, Mapping, Optional, Sequence, Type, TypeVar, Generator

from tvm._ffi import get_object_type_index, register_object
from tvm.runtime import Object, ObjectPath

from . import _ffi_api
from .doc import Doc
from .frame import Frame
from .var_table import VarTable


@register_object("script.printer.IRDocsifier")
class IRDocsifier(Object):
"""
IRDocsifier is the top-level interface in the IR->Doc process.
It provides methods to convert IR node object to Doc, operate on Frame
objects and change dispatch tokens.
"""

ir_prefix: Mapping[str, str]
vars: VarTable
frames: Sequence[Frame]
dispatch_tokens: Sequence[str]

def __init__(self, ir_prefix: Dict[str, str]):
"""
Create a new IRDocsifier.
Parameters
----------
ir_prefix : Dict[str, str]
The ir prefix to use. Key is the IR dispatch token and
value is the name of identifier for this IR's namespace in TVMScript.
"""
self.__init_handle_by_constructor__(_ffi_api.IRDocsifier, ir_prefix) # type: ignore # pylint: disable=no-member

_TObject = TypeVar("_TObject", bound=Object)

@classmethod
def set_dispatch(
cls,
node_type: Type[_TObject],
dispatch_function: Callable[[_TObject, "IRDocsifier"], Doc],
dispatch_token: str = "",
) -> None:
"""
Set the dispatch function to transform a particular IR node type to Doc
Parameters
----------
node_type : Type[_TObject]
The type of object to dispatch on.
dispatch_function : Callable[[_TObject, "IRDocsifier"], Doc]
The dispatch function. It's called to transform IR node object to Doc.
dispatch_token : str
Function will only be called when this dispatch_token is the same as the one
on the top of IRDocsifier's dispatch_tokens stack. An empty dispatch token
means registering as default dispatch function, which will be called when
there is no dispatch function registered with the current dispatch token.
"""
type_index = get_object_type_index(node_type)
_ffi_api.IRDocsifierSetDispatch(dispatch_token, type_index, dispatch_function) # type: ignore # pylint: disable=no-member

def as_doc(self, obj: Object, object_path: ObjectPath) -> Doc:
"""
Transform the input object into Doc.
Parameters
----------
obj : Object
The IR node object.
object_path : ObjectPath
The object path of this object. It's used for locating diagnostic message.
Returns
-------
doc : Doc
The doc for this object.
"""
return _ffi_api.IRDocsifierAsDoc(self, obj, object_path) # type: ignore # pylint: disable=no-member

def get_frame(self, frame_type: Type[Frame]) -> Optional[Frame]:
"""
Get the top frame with type `frame_type`.
Parameters
----------
frame_type : Type[Frame]
The target frame type.
Returns
-------
frame : Optional[Frame]
The frame if found, otherwise None.
"""
for i in range(len(self.frames) - 1, -1, -1):
if isinstance(self.frames[i], frame_type):
return self.frames[i]
return None

@contextmanager
def dispatch_token(self, token: str):
"""
Push a new dispatch token to the stack.
Parameters
----------
token : str
The token to push.
Returns
-------
A context manager that pops this dispatch token when exits.
"""
with ExitStack() as stack:
_ffi_api.IRDocsifierPushDispatchToken(self, token) # type: ignore # pylint: disable=no-member
stack.callback(_ffi_api.IRDocsifierPopDispatchToken, self) # type: ignore # pylint: disable=no-member
yield

_TFrame = TypeVar("_TFrame", bound=Frame)

@contextmanager
def frame(self, frame: _TFrame) -> Generator[_TFrame, None, None]:
"""
Push a new frame to the stack.
Parameters
----------
frame : Frame
The frame to push.
Returns
-------
A context manager that pops this frame when exits.
"""
with ExitStack() as stack:
stack.enter_context(frame)
_ffi_api.IRDocsifierPushFrame(self, frame) # type: ignore # pylint: disable=no-member
stack.callback(_ffi_api.IRDocsifierPopFrame, self) # type: ignore # pylint: disable=no-member
yield frame
Loading

0 comments on commit 9bff2ca

Please sign in to comment.