Skip to content

Commit

Permalink
Add entry point
Browse files Browse the repository at this point in the history
  • Loading branch information
yelite committed Aug 16, 2022
1 parent 247c54b commit 2cab720
Show file tree
Hide file tree
Showing 11 changed files with 327 additions and 3 deletions.
54 changes: 54 additions & 0 deletions include/tvm/script/printer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_SCRIPT_PRINTER_H_
#define TVM_SCRIPT_PRINTER_H_

#include <tvm/node/node.h>
#include <tvm/node/object_path.h>

namespace tvm {
namespace script {

/*!
* \brief Print IR graph as TVMScript code
*
* \param root_node The root node to print.
* \param ir_name The dispatch token of the target IR, e.g., "tir", "relax".
* \param ir_prefix The symbol name for TVMScript IR namespaces. For example, {"tir": "T"}.
* \param indent_spaces Number of spaces used for indentation
* \param print_line_numbers Whether to print line numbers
* \param num_context_lines Number of context lines to print around the underlined text
* \param path_to_underline Object path to be underlined
*
* \return the TVMScript code as string.
*/
String AsScript( //
const ObjectRef& root_node, //
String ir_name, //
Map<String, String> ir_prefix, //
int indent_spaces = 4, //
bool print_line_numbers = false, //
int num_context_lines = -1, //
Optional<ObjectPath> path_to_underline = NullOpt //
);

} // namespace script
} // namespace tvm

#endif // TVM_SCRIPT_PRINTER_H_
6 changes: 6 additions & 0 deletions include/tvm/script/printer/doc.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,12 @@ class ExprDoc : public Doc {
ExprDoc() = default;

public:
/*!
* \brief Create a doc representing index access on the current ExprDoc
* \param indices The indices to access.
*/
ExprDoc operator[](Array<Doc> indices) const;

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExprDoc, Doc, ExprDocNode);
};

Expand Down
41 changes: 41 additions & 0 deletions include/tvm/script/printer/ir_docsifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,47 @@ class IRDocsifier : public ObjectRef {
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRDocsifier, ObjectRef, IRDocsifierNode);
};

/*!
* \brief A wrapper object to provide injection point for printer of each IR.
*
* For any IR node to be transformed by IRDocsifier, it will be wrapped by RootNodeContainer
* and be dispatched to the corresponding function first. This provides an injection point for
* each IR's printer implemention to add specialized logic, for example, pushing a special
* Frame to the IRDocsifier before doing any IR->Doc transformation.
*
* \code
* TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
* .set_dispatch("relax", [](TracedObject<RootNodeContainer> obj, IRDocsifier p) {
* const ObjectRef& root_node = obj.Get()->root_node;
* // For example, relax printer can create a Frame specialized to Relax here
* RelaxGeneralFrame frame;
* auto ctx = p->WithFrame(frame);
* // More specialized logic for your IR.
* return p->AsDoc<Doc>(MakeTraced(root_node));
* });
* \endcode
*/
class RootNodeContainerNode : public Object {
public:
/*! \brief The root node to print. */
ObjectRef root_node;

void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("root_node", &root_node); }

static constexpr const char* _type_key = "script.printer.RootNodeContainer";
TVM_DECLARE_FINAL_OBJECT_INFO(RootNodeContainerNode, Object);
};

class RootNodeContainer : public ObjectRef {
public:
/*!
* \brief Constructor of RootNodeContainer.
* \param root_node The root node to print.
* */
explicit RootNodeContainer(ObjectRef root_node);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RootNodeContainer, ObjectRef, RootNodeContainerNode);
};

} // namespace printer
} // namespace script
} // namespace tvm
Expand Down
1 change: 1 addition & 0 deletions python/tvm/script/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
from . import tir

from .parser import ir_module, from_source
from .as_script import as_script
45 changes: 45 additions & 0 deletions python/tvm/script/as_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This file contains the entry point of TVMScript Unified Printer.
"""

from typing import Dict, Optional

from tvm.runtime.object_path import ObjectPath

from . import _ffi_api


def as_script(
root_node,
ir_name: str,
ir_prefix: Dict[str, str],
indent_spaces: int = 4,
print_line_numbers: bool = False,
num_context_lines: int = -1,
path_to_underline: Optional[ObjectPath] = None,
) -> str:
return _ffi_api.AsScript(
root_node,
ir_name,
ir_prefix,
indent_spaces,
print_line_numbers,
num_context_lines,
path_to_underline,
)
51 changes: 49 additions & 2 deletions python/tvm/script/printer/ir_docsifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,21 @@ def _ensure_cleanup_function_registered():
_CLEANUP_REGISTERED = True


@register_object("script.printer.RootNodeContainer")
class RootNodeContainer(Object):
"""
A wrapper object to provide injection point for printer of each IR.
This class shouldn't be used directly. `IRDocsifier.set_root_dispatch`
should be used instead.
"""

root_node: Object

def __init__(self, root_node: Object):
self.__init_handle_by_constructor__(_ffi_api.RootNodeContainer, root_node) # type: ignore # pylint: disable=no-member


@register_object("script.printer.IRDocsifier")
class IRDocsifier(Object):
"""
Expand Down Expand Up @@ -91,7 +106,7 @@ def __init__(self, ir_prefix: Dict[str, str]):
def set_dispatch(
cls,
node_type: Type[_TObject],
dispatch_function: Callable[[_TObject, "IRDocsifier"], Doc],
dispatch_function: Callable[[_TObject, ObjectPath, "IRDocsifier"], Doc],
dispatch_token: str = "",
) -> None:
"""
Expand All @@ -101,7 +116,7 @@ def set_dispatch(
----------
node_type : Type[_TObject]
The type of object to dispatch on.
dispatch_function : Callable[[_TObject, "IRDocsifier"], Doc]
dispatch_function : Callable[[_TObject, ObjectPath, "IRDocsifier"], Doc]
The dispatch function. It's called to transform IR node object to Doc.
dispatch_token : str
Function will only be called when this dispatch_token is the same as the one
Expand All @@ -119,6 +134,38 @@ def set_dispatch(
)
_REGISTERED_TYPES.add((dispatch_token, type_index))

@classmethod
def set_root_dispatch(
cls, dispatch_token: str, root_dispatch_function: Callable[[Object, "IRDocsifier"], Doc]
) -> None:
"""
Set the root dispatch function for an IR.
The root dispatch function will be called with the root node of an IR graph
that's being transformed to Doc. This provides an injection point for
each IR's printer implemention to add specialized logic, for example,
pushing a special Frame to the IRDocsifier before doing actual IR->Doc
transformation.
The simplest root dispatch function is
```
def f(obj, ir_docsifier)
return ir_docsifier.as_doc(obj, ObjectPath.root())
```
Parameters
----------
root_dispatch_function : Callable[[_TObject, "IRDocsifier"], Doc]
The root dispatch function. It's called with the root node to be printed.
dispatch_token : str
The dispatch token of the IR that root_dispatch_funnction applies to.
"""

def dispatch_function(obj: RootNodeContainer, _, ir_docsifier):
return root_dispatch_function(obj.root_node, ir_docsifier)

cls.set_dispatch(RootNodeContainer, dispatch_function, dispatch_token)

def as_doc(self, obj: Object, object_path: ObjectPath) -> Doc:
"""
Transform the input object into Doc.
Expand Down
54 changes: 54 additions & 0 deletions src/script/printer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/runtime/registry.h>
#include <tvm/script/printer.h>
#include <tvm/script/printer/doc.h>
#include <tvm/script/printer/doc_printer.h>
#include <tvm/script/printer/frame.h>
#include <tvm/script/printer/ir_docsifier.h>

namespace tvm {
namespace script {

using namespace printer;

String AsScript( //
const ObjectRef& root_node, //
String ir_name, //
Map<String, String> ir_prefix, //
int indent_spaces, //
bool print_line_numbers, //
int num_context_lines, //
Optional<ObjectPath> path_to_underline //
) {
IRDocsifier ir_docsifier(ir_prefix);

auto dispatch_ctx = ir_docsifier->WithDispatchToken(ir_name);

Doc doc = ir_docsifier->AsDoc<Doc>(MakeTraced(RootNodeContainer(root_node)));

return DocToPythonScript(doc, indent_spaces, print_line_numbers, num_context_lines,
path_to_underline);
}

TVM_REGISTER_GLOBAL("script.AsScript").set_body_typed(&AsScript);

} // namespace script
} // namespace tvm
2 changes: 2 additions & 0 deletions src/script/printer/doc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ ExprDoc ExprDocNode::Call(Array<ExprDoc, void> args, Array<String, void> kwargs_
return CallDoc(GetRef<ExprDoc>(this), args, kwargs_keys, kwargs_values);
}

ExprDoc ExprDoc::operator[](Array<Doc> indices) const { return (*get())[indices]; }

StmtBlockDoc::StmtBlockDoc(Array<StmtDoc> stmts) {
ObjectPtr<StmtBlockDocNode> n = make_object<StmtBlockDocNode>();
n->stmts = stmts;
Expand Down
32 changes: 32 additions & 0 deletions src/script/printer/ir_docsifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/runtime/container/base.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/script/printer/ir_docsifier.h>
#include <tvm/script/printer/traced_object.h>
Expand All @@ -42,6 +43,31 @@ IRDocsifier::FType& IRDocsifier::vtable() {
return inst;
}

RootNodeContainer::RootNodeContainer(ObjectRef root_node) {
auto n = make_object<RootNodeContainerNode>();
n->root_node = std::move(root_node);
data_ = std::move(n);
}

// Add a default dispatch for the RootNodeContainer to throw error.
// To add implementation for a new IR, RootNodeContainer needs to be
// registered under the dispatch token of that IR, like:
// \code
// TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
// .set_dispatch("relax", [](TracedObject<RootNodeContainer> obj, IRDocsifier p) {
// const ObjectRef& root_node = obj.Get()->root_node;
// \\ More specialized logic for your IR.
// return p->AsDoc<Doc>(MakeTraced(root_node));
// });
// \endcode
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch([](TracedObject<RootNodeContainer> obj, IRDocsifier p) -> Doc {
String top_dispatch_token = p->dispatch_tokens.back();
ICHECK_NE(top_dispatch_token, "");
ICHECK(false) << "Printing IR " << top_dispatch_token << " is not implemented.";
throw;
});

TVM_REGISTER_NODE_TYPE(IRDocsifierNode);
TVM_REGISTER_GLOBAL("script.printer.IRDocsifier").set_body_typed([](Map<String, String> ir_prefix) {
return IRDocsifier(ir_prefix);
Expand Down Expand Up @@ -71,6 +97,12 @@ TVM_REGISTER_GLOBAL("script.printer.IRDocsifierRemoveDispatch")
.set_body_typed([](String token, uint64_t type_index) {
IRDocsifier::vtable().remove_dispatch(token, type_index);
});

TVM_REGISTER_NODE_TYPE(RootNodeContainerNode);
TVM_REGISTER_GLOBAL("script.printer.RootNodeContainer").set_body_typed([](ObjectRef root_node) {
return RootNodeContainer(root_node);
});

} // namespace printer
} // namespace script
} // namespace tvm
30 changes: 30 additions & 0 deletions tests/python/unittest/test_tvmscript_printer_entry_point.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest

from tvm.error import TVMError
from tvm.script import as_script
from tvm.tir import FloatImm


def test_as_script_unknown_ir():
ir_node = FloatImm("float32", 1.0)

with pytest.raises(TVMError) as e:
as_script(ir_node, "test_xyz", {})

assert "test_xyz" in str(e.value)
Loading

0 comments on commit 2cab720

Please sign in to comment.