Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] Printer entry point #12462

Merged
merged 3 commits into from
Aug 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions include/tvm/script/printer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_SCRIPT_PRINTER_H_
#define TVM_SCRIPT_PRINTER_H_

#include <tvm/node/node.h>
#include <tvm/node/object_path.h>

namespace tvm {
namespace script {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, what do you think works better, tvm::script::printer or tvm::script

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the function name is Script, I think tvm::script::printer is a better choice. Otherwise the purpose of this function is less clear, because it's a free function rather than a method on IR node.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved it to the tvm::script::printer

namespace printer {

/*!
* \brief Print IR graph as TVMScript code
*
* \param root_node The root node to print.
* \param ir_name The dispatch token of the target IR, e.g., "tir", "relax".
* \param ir_prefix The symbol name for TVMScript IR namespaces. For example, {"tir": "T"}.
* \param indent_spaces Number of spaces used for indentation
* \param print_line_numbers Whether to print line numbers
* \param num_context_lines Number of context lines to print around the underlined text
* \param path_to_underline Object path to be underlined
*
* \return the TVMScript code as string.
*/
String Script( //
const ObjectRef& root_node, //
String ir_name, //
Map<String, String> ir_prefix, //
int indent_spaces = 4, //
bool print_line_numbers = false, //
int num_context_lines = -1, //
Optional<ObjectPath> path_to_underline = NullOpt //
);

} // namespace printer
} // namespace script
} // namespace tvm

#endif // TVM_SCRIPT_PRINTER_H_
6 changes: 6 additions & 0 deletions include/tvm/script/printer/doc.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,12 @@ class ExprDoc : public Doc {
ExprDoc() = default;

public:
/*!
* \brief Create a doc representing index access on the current ExprDoc
* \param indices The indices to access.
*/
ExprDoc operator[](Array<Doc> indices) const;

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExprDoc, Doc, ExprDocNode);
};

Expand Down
41 changes: 41 additions & 0 deletions include/tvm/script/printer/ir_docsifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,47 @@ class IRDocsifier : public ObjectRef {
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRDocsifier, ObjectRef, IRDocsifierNode);
};

/*!
* \brief A wrapper object to provide injection point for printer of each IR.
*
* For any IR node to be transformed by IRDocsifier, it will be wrapped by RootNodeContainer
* and be dispatched to the corresponding function first. This provides an injection point for
* each IR's printer implemention to add specialized logic, for example, pushing a special
* Frame to the IRDocsifier before doing any IR->Doc transformation.
*
* \code
* TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
* .set_dispatch("relax", [](TracedObject<RootNodeContainer> obj, IRDocsifier p) {
* const ObjectRef& root_node = obj.Get()->root_node;
* // For example, relax printer can create a Frame specialized to Relax here
* RelaxGeneralFrame frame;
* auto ctx = p->WithFrame(frame);
* // More specialized logic for your IR.
* return p->AsDoc<Doc>(MakeTraced(root_node));
* });
* \endcode
*/
class RootNodeContainerNode : public Object {
public:
/*! \brief The root node to print. */
ObjectRef root_node;

void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("root_node", &root_node); }

static constexpr const char* _type_key = "script.printer.RootNodeContainer";
TVM_DECLARE_FINAL_OBJECT_INFO(RootNodeContainerNode, Object);
};

class RootNodeContainer : public ObjectRef {
public:
/*!
* \brief Constructor of RootNodeContainer.
* \param root_node The root node to print.
* */
explicit RootNodeContainer(ObjectRef root_node);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RootNodeContainer, ObjectRef, RootNodeContainerNode);
};

} // namespace printer
} // namespace script
} // namespace tvm
Expand Down
1 change: 1 addition & 0 deletions python/tvm/script/printer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@
"""

from . import _ffi_api
from .entry import script
71 changes: 71 additions & 0 deletions python/tvm/script/printer/entry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This file contains the entry point of TVMScript Unified Printer.
"""

from typing import Dict, Optional

from tvm.runtime import Object, ObjectPath

from . import _ffi_api


def script( # pylint: disable=too-many-arguments
root_node: Object,
ir_name: str,
ir_prefix: Dict[str, str],
indent_spaces: int = 4,
print_line_numbers: bool = False,
num_context_lines: int = -1,
path_to_underline: Optional[ObjectPath] = None,
) -> str:
"""
Print IR graph as TVMScript code

Parameters
----------
root_node : Object
The root node to print.
ir_name : str
The dispatch token of the target IR, e.g., "tir", "relax".
ir_prefix : Dict[str, str]
The symbol name for TVMScript IR namespaces. For example,
{"tir": "T"}.
indent_spaces : int
The number of indent spaces to use in the output
print_line_numbers: bool
Whether to print line numbers
num_context_lines : Optional[int]
Number of context lines to print around the underlined text
path_to_underline : Optional[ObjectPath]
Object path to be underlined

Returns
-------
script : str
The TVMScript code of the root_node
"""
return _ffi_api.Script( # type: ignore # pylint: disable=no-member
root_node,
ir_name,
ir_prefix,
indent_spaces,
print_line_numbers,
num_context_lines,
path_to_underline,
)
51 changes: 49 additions & 2 deletions python/tvm/script/printer/ir_docsifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,21 @@ def _ensure_cleanup_function_registered():
_CLEANUP_REGISTERED = True


@register_object("script.printer.RootNodeContainer")
class RootNodeContainer(Object):
"""
A wrapper object to provide injection point for printer of each IR.

This class shouldn't be used directly. `IRDocsifier.set_root_dispatch`
should be used instead.
"""

root_node: Object

def __init__(self, root_node: Object):
self.__init_handle_by_constructor__(_ffi_api.RootNodeContainer, root_node) # type: ignore # pylint: disable=no-member


@register_object("script.printer.IRDocsifier")
class IRDocsifier(Object):
"""
Expand Down Expand Up @@ -91,7 +106,7 @@ def __init__(self, ir_prefix: Dict[str, str]):
def set_dispatch(
cls,
node_type: Type[_TObject],
dispatch_function: Callable[[_TObject, "IRDocsifier"], Doc],
dispatch_function: Callable[[_TObject, ObjectPath, "IRDocsifier"], Doc],
dispatch_token: str = "",
) -> None:
"""
Expand All @@ -101,7 +116,7 @@ def set_dispatch(
----------
node_type : Type[_TObject]
The type of object to dispatch on.
dispatch_function : Callable[[_TObject, "IRDocsifier"], Doc]
dispatch_function : Callable[[_TObject, ObjectPath, "IRDocsifier"], Doc]
The dispatch function. It's called to transform IR node object to Doc.
dispatch_token : str
Function will only be called when this dispatch_token is the same as the one
Expand All @@ -119,6 +134,38 @@ def set_dispatch(
)
_REGISTERED_TYPES.add((dispatch_token, type_index))

@classmethod
def set_root_dispatch(
cls, dispatch_token: str, root_dispatch_function: Callable[[Object, "IRDocsifier"], Doc]
) -> None:
"""
Set the root dispatch function for an IR.

The root dispatch function will be called with the root node of an IR graph
that's being transformed to Doc. This provides an injection point for
each IR's printer implemention to add specialized logic, for example,
pushing a special Frame to the IRDocsifier before doing actual IR->Doc
transformation.

The simplest root dispatch function is
```
def f(obj, ir_docsifier)
return ir_docsifier.as_doc(obj, ObjectPath.root())
```

Parameters
----------
root_dispatch_function : Callable[[_TObject, "IRDocsifier"], Doc]
The root dispatch function. It's called with the root node to be printed.
dispatch_token : str
The dispatch token of the IR that root_dispatch_funnction applies to.
"""

def dispatch_function(obj: RootNodeContainer, _, ir_docsifier):
return root_dispatch_function(obj.root_node, ir_docsifier)

cls.set_dispatch(RootNodeContainer, dispatch_function, dispatch_token)

def as_doc(self, obj: Object, object_path: ObjectPath) -> Doc:
"""
Transform the input object into Doc.
Expand Down
54 changes: 54 additions & 0 deletions src/script/printer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/runtime/registry.h>
#include <tvm/script/printer.h>
#include <tvm/script/printer/doc.h>
#include <tvm/script/printer/doc_printer.h>
#include <tvm/script/printer/frame.h>
#include <tvm/script/printer/ir_docsifier.h>

namespace tvm {
namespace script {
namespace printer {

String Script( //
const ObjectRef& root_node, //
String ir_name, //
Map<String, String> ir_prefix, //
int indent_spaces, //
bool print_line_numbers, //
int num_context_lines, //
Optional<ObjectPath> path_to_underline //
) {
IRDocsifier ir_docsifier(ir_prefix);

auto dispatch_ctx = ir_docsifier->WithDispatchToken(ir_name);

Doc doc = ir_docsifier->AsDoc<Doc>(MakeTraced(RootNodeContainer(root_node)));

return DocToPythonScript(doc, indent_spaces, print_line_numbers, num_context_lines,
path_to_underline);
}

TVM_REGISTER_GLOBAL("script.printer.Script").set_body_typed(&Script);

} // namespace printer
} // namespace script
} // namespace tvm
2 changes: 2 additions & 0 deletions src/script/printer/doc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ ExprDoc ExprDocNode::Call(Array<ExprDoc, void> args, Array<String, void> kwargs_
return CallDoc(GetRef<ExprDoc>(this), args, kwargs_keys, kwargs_values);
}

ExprDoc ExprDoc::operator[](Array<Doc> indices) const { return (*get())[indices]; }

StmtBlockDoc::StmtBlockDoc(Array<StmtDoc> stmts) {
ObjectPtr<StmtBlockDocNode> n = make_object<StmtBlockDocNode>();
n->stmts = stmts;
Expand Down
32 changes: 32 additions & 0 deletions src/script/printer/ir_docsifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/runtime/container/base.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/script/printer/ir_docsifier.h>
#include <tvm/script/printer/traced_object.h>
Expand All @@ -42,6 +43,31 @@ IRDocsifier::FType& IRDocsifier::vtable() {
return inst;
}

RootNodeContainer::RootNodeContainer(ObjectRef root_node) {
auto n = make_object<RootNodeContainerNode>();
n->root_node = std::move(root_node);
data_ = std::move(n);
}

// Add a default dispatch for the RootNodeContainer to throw error.
// To add implementation for a new IR, RootNodeContainer needs to be
// registered under the dispatch token of that IR, like:
// \code
// TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
// .set_dispatch("relax", [](TracedObject<RootNodeContainer> obj, IRDocsifier p) {
// const ObjectRef& root_node = obj.Get()->root_node;
// \\ More specialized logic for your IR.
// return p->AsDoc<Doc>(MakeTraced(root_node));
// });
// \endcode
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch([](TracedObject<RootNodeContainer> obj, IRDocsifier p) -> Doc {
String top_dispatch_token = p->dispatch_tokens.back();
ICHECK_NE(top_dispatch_token, "");
ICHECK(false) << "Printing IR " << top_dispatch_token << " is not implemented.";
throw;
});

TVM_REGISTER_NODE_TYPE(IRDocsifierNode);
TVM_REGISTER_GLOBAL("script.printer.IRDocsifier").set_body_typed([](Map<String, String> ir_prefix) {
return IRDocsifier(ir_prefix);
Expand Down Expand Up @@ -71,6 +97,12 @@ TVM_REGISTER_GLOBAL("script.printer.IRDocsifierRemoveDispatch")
.set_body_typed([](String token, uint64_t type_index) {
IRDocsifier::vtable().remove_dispatch(token, type_index);
});

TVM_REGISTER_NODE_TYPE(RootNodeContainerNode);
TVM_REGISTER_GLOBAL("script.printer.RootNodeContainer").set_body_typed([](ObjectRef root_node) {
return RootNodeContainer(root_node);
});

} // namespace printer
} // namespace script
} // namespace tvm
Loading