From 37e125eebc06c3cb8c1d6cf9df559ed2e63fbfce Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 16 Aug 2022 14:30:22 -0400 Subject: [PATCH 1/3] Add entry point --- include/tvm/script/printer.h | 54 +++++++++++++++++++ include/tvm/script/printer/doc.h | 6 +++ include/tvm/script/printer/ir_docsifier.h | 41 ++++++++++++++ python/tvm/script/__init__.py | 1 + python/tvm/script/as_script.py | 45 ++++++++++++++++ python/tvm/script/printer/ir_docsifier.py | 51 +++++++++++++++++- src/script/printer.cc | 54 +++++++++++++++++++ src/script/printer/doc.cc | 2 + src/script/printer/ir_docsifier.cc | 32 +++++++++++ .../test_tvmscript_printer_entry_point.py | 30 +++++++++++ .../test_tvmscript_printer_irdocsifier.py | 14 ++++- 11 files changed, 327 insertions(+), 3 deletions(-) create mode 100644 include/tvm/script/printer.h create mode 100644 python/tvm/script/as_script.py create mode 100644 src/script/printer.cc create mode 100644 tests/python/unittest/test_tvmscript_printer_entry_point.py diff --git a/include/tvm/script/printer.h b/include/tvm/script/printer.h new file mode 100644 index 000000000000..471ffe06d92a --- /dev/null +++ b/include/tvm/script/printer.h @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_H_ +#define TVM_SCRIPT_PRINTER_H_ + +#include +#include + +namespace tvm { +namespace script { + +/*! + * \brief Print IR graph as TVMScript code + * + * \param root_node The root node to print. + * \param ir_name The dispatch token of the target IR, e.g., "tir", "relax". + * \param ir_prefix The symbol name for TVMScript IR namespaces. For example, {"tir": "T"}. + * \param indent_spaces Number of spaces used for indentation + * \param print_line_numbers Whether to print line numbers + * \param num_context_lines Number of context lines to print around the underlined text + * \param path_to_underline Object path to be underlined + * + * \return the TVMScript code as string. + */ +String AsScript( // + const ObjectRef& root_node, // + String ir_name, // + Map ir_prefix, // + int indent_spaces = 4, // + bool print_line_numbers = false, // + int num_context_lines = -1, // + Optional path_to_underline = NullOpt // +); + +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_H_ diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 55faed33fb89..9d777ed16eed 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -125,6 +125,12 @@ class ExprDoc : public Doc { ExprDoc() = default; public: + /*! + * \brief Create a doc representing index access on the current ExprDoc + * \param indices The indices to access. + */ + ExprDoc operator[](Array indices) const; + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExprDoc, Doc, ExprDocNode); }; diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index c882cf1a0f90..8945bd6e7a94 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -182,6 +182,47 @@ class IRDocsifier : public ObjectRef { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRDocsifier, ObjectRef, IRDocsifierNode); }; +/*! + * \brief A wrapper object to provide injection point for printer of each IR. + * + * For any IR node to be transformed by IRDocsifier, it will be wrapped by RootNodeContainer + * and be dispatched to the corresponding function first. This provides an injection point for + * each IR's printer implemention to add specialized logic, for example, pushing a special + * Frame to the IRDocsifier before doing any IR->Doc transformation. + * + * \code + * TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + * .set_dispatch("relax", [](TracedObject obj, IRDocsifier p) { + * const ObjectRef& root_node = obj.Get()->root_node; + * // For example, relax printer can create a Frame specialized to Relax here + * RelaxGeneralFrame frame; + * auto ctx = p->WithFrame(frame); + * // More specialized logic for your IR. + * return p->AsDoc(MakeTraced(root_node)); + * }); + * \endcode + */ +class RootNodeContainerNode : public Object { + public: + /*! \brief The root node to print. */ + ObjectRef root_node; + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("root_node", &root_node); } + + static constexpr const char* _type_key = "script.printer.RootNodeContainer"; + TVM_DECLARE_FINAL_OBJECT_INFO(RootNodeContainerNode, Object); +}; + +class RootNodeContainer : public ObjectRef { + public: + /*! + * \brief Constructor of RootNodeContainer. + * \param root_node The root node to print. + * */ + explicit RootNodeContainer(ObjectRef root_node); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RootNodeContainer, ObjectRef, RootNodeContainerNode); +}; + } // namespace printer } // namespace script } // namespace tvm diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py index 555659d0c55e..cc43e8ecf535 100644 --- a/python/tvm/script/__init__.py +++ b/python/tvm/script/__init__.py @@ -19,3 +19,4 @@ from . import tir from .parser import ir_module, from_source +from .as_script import as_script diff --git a/python/tvm/script/as_script.py b/python/tvm/script/as_script.py new file mode 100644 index 000000000000..753e07e2e4c9 --- /dev/null +++ b/python/tvm/script/as_script.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This file contains the entry point of TVMScript Unified Printer. +""" + +from typing import Dict, Optional + +from tvm.runtime.object_path import ObjectPath + +from . import _ffi_api + + +def as_script( + root_node, + ir_name: str, + ir_prefix: Dict[str, str], + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: int = -1, + path_to_underline: Optional[ObjectPath] = None, +) -> str: + return _ffi_api.AsScript( + root_node, + ir_name, + ir_prefix, + indent_spaces, + print_line_numbers, + num_context_lines, + path_to_underline, + ) diff --git a/python/tvm/script/printer/ir_docsifier.py b/python/tvm/script/printer/ir_docsifier.py index 16f3ab62ecab..c5ba8a498b1e 100644 --- a/python/tvm/script/printer/ir_docsifier.py +++ b/python/tvm/script/printer/ir_docsifier.py @@ -59,6 +59,21 @@ def _ensure_cleanup_function_registered(): _CLEANUP_REGISTERED = True +@register_object("script.printer.RootNodeContainer") +class RootNodeContainer(Object): + """ + A wrapper object to provide injection point for printer of each IR. + + This class shouldn't be used directly. `IRDocsifier.set_root_dispatch` + should be used instead. + """ + + root_node: Object + + def __init__(self, root_node: Object): + self.__init_handle_by_constructor__(_ffi_api.RootNodeContainer, root_node) # type: ignore # pylint: disable=no-member + + @register_object("script.printer.IRDocsifier") class IRDocsifier(Object): """ @@ -91,7 +106,7 @@ def __init__(self, ir_prefix: Dict[str, str]): def set_dispatch( cls, node_type: Type[_TObject], - dispatch_function: Callable[[_TObject, "IRDocsifier"], Doc], + dispatch_function: Callable[[_TObject, ObjectPath, "IRDocsifier"], Doc], dispatch_token: str = "", ) -> None: """ @@ -101,7 +116,7 @@ def set_dispatch( ---------- node_type : Type[_TObject] The type of object to dispatch on. - dispatch_function : Callable[[_TObject, "IRDocsifier"], Doc] + dispatch_function : Callable[[_TObject, ObjectPath, "IRDocsifier"], Doc] The dispatch function. It's called to transform IR node object to Doc. dispatch_token : str Function will only be called when this dispatch_token is the same as the one @@ -119,6 +134,38 @@ def set_dispatch( ) _REGISTERED_TYPES.add((dispatch_token, type_index)) + @classmethod + def set_root_dispatch( + cls, dispatch_token: str, root_dispatch_function: Callable[[Object, "IRDocsifier"], Doc] + ) -> None: + """ + Set the root dispatch function for an IR. + + The root dispatch function will be called with the root node of an IR graph + that's being transformed to Doc. This provides an injection point for + each IR's printer implemention to add specialized logic, for example, + pushing a special Frame to the IRDocsifier before doing actual IR->Doc + transformation. + + The simplest root dispatch function is + ``` + def f(obj, ir_docsifier) + return ir_docsifier.as_doc(obj, ObjectPath.root()) + ``` + + Parameters + ---------- + root_dispatch_function : Callable[[_TObject, "IRDocsifier"], Doc] + The root dispatch function. It's called with the root node to be printed. + dispatch_token : str + The dispatch token of the IR that root_dispatch_funnction applies to. + """ + + def dispatch_function(obj: RootNodeContainer, _, ir_docsifier): + return root_dispatch_function(obj.root_node, ir_docsifier) + + cls.set_dispatch(RootNodeContainer, dispatch_function, dispatch_token) + def as_doc(self, obj: Object, object_path: ObjectPath) -> Doc: """ Transform the input object into Doc. diff --git a/src/script/printer.cc b/src/script/printer.cc new file mode 100644 index 000000000000..c8fce03e800e --- /dev/null +++ b/src/script/printer.cc @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace script { + +using namespace printer; + +String AsScript( // + const ObjectRef& root_node, // + String ir_name, // + Map ir_prefix, // + int indent_spaces, // + bool print_line_numbers, // + int num_context_lines, // + Optional path_to_underline // +) { + IRDocsifier ir_docsifier(ir_prefix); + + auto dispatch_ctx = ir_docsifier->WithDispatchToken(ir_name); + + Doc doc = ir_docsifier->AsDoc(MakeTraced(RootNodeContainer(root_node))); + + return DocToPythonScript(doc, indent_spaces, print_line_numbers, num_context_lines, + path_to_underline); +} + +TVM_REGISTER_GLOBAL("script.AsScript").set_body_typed(&AsScript); + +} // namespace script +} // namespace tvm diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index b94d4c55bfbb..d6f5ff35ab53 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -40,6 +40,8 @@ ExprDoc ExprDocNode::Call(Array args, Array kwargs_ return CallDoc(GetRef(this), args, kwargs_keys, kwargs_values); } +ExprDoc ExprDoc::operator[](Array indices) const { return (*get())[indices]; } + StmtBlockDoc::StmtBlockDoc(Array stmts) { ObjectPtr n = make_object(); n->stmts = stmts; diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 7d9ba2352d88..b72ed48db63b 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include #include #include @@ -42,6 +43,31 @@ IRDocsifier::FType& IRDocsifier::vtable() { return inst; } +RootNodeContainer::RootNodeContainer(ObjectRef root_node) { + auto n = make_object(); + n->root_node = std::move(root_node); + data_ = std::move(n); +} + +// Add a default dispatch for the RootNodeContainer to throw error. +// To add implementation for a new IR, RootNodeContainer needs to be +// registered under the dispatch token of that IR, like: +// \code +// TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) +// .set_dispatch("relax", [](TracedObject obj, IRDocsifier p) { +// const ObjectRef& root_node = obj.Get()->root_node; +// \\ More specialized logic for your IR. +// return p->AsDoc(MakeTraced(root_node)); +// }); +// \endcode +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject obj, IRDocsifier p) -> Doc { + String top_dispatch_token = p->dispatch_tokens.back(); + ICHECK_NE(top_dispatch_token, ""); + ICHECK(false) << "Printing IR " << top_dispatch_token << " is not implemented."; + throw; + }); + TVM_REGISTER_NODE_TYPE(IRDocsifierNode); TVM_REGISTER_GLOBAL("script.printer.IRDocsifier").set_body_typed([](Map ir_prefix) { return IRDocsifier(ir_prefix); @@ -71,6 +97,12 @@ TVM_REGISTER_GLOBAL("script.printer.IRDocsifierRemoveDispatch") .set_body_typed([](String token, uint64_t type_index) { IRDocsifier::vtable().remove_dispatch(token, type_index); }); + +TVM_REGISTER_NODE_TYPE(RootNodeContainerNode); +TVM_REGISTER_GLOBAL("script.printer.RootNodeContainer").set_body_typed([](ObjectRef root_node) { + return RootNodeContainer(root_node); +}); + } // namespace printer } // namespace script } // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_printer_entry_point.py b/tests/python/unittest/test_tvmscript_printer_entry_point.py new file mode 100644 index 000000000000..3fbffb39463f --- /dev/null +++ b/tests/python/unittest/test_tvmscript_printer_entry_point.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +from tvm.error import TVMError +from tvm.script import as_script +from tvm.tir import FloatImm + + +def test_as_script_unknown_ir(): + ir_node = FloatImm("float32", 1.0) + + with pytest.raises(TVMError) as e: + as_script(ir_node, "test_xyz", {}) + + assert "test_xyz" in str(e.value) diff --git a/tests/python/unittest/test_tvmscript_printer_irdocsifier.py b/tests/python/unittest/test_tvmscript_printer_irdocsifier.py index 357a710584c1..d9d552ce4b9f 100644 --- a/tests/python/unittest/test_tvmscript_printer_irdocsifier.py +++ b/tests/python/unittest/test_tvmscript_printer_irdocsifier.py @@ -19,7 +19,7 @@ from tvm.runtime import ObjectPath from tvm.script.printer.doc import IdDoc from tvm.script.printer.frame import MetadataFrame, VarDefFrame -from tvm.script.printer.ir_docsifier import IRDocsifier +from tvm.script.printer.ir_docsifier import IRDocsifier, RootNodeContainer from tvm.tir import Var @@ -40,9 +40,16 @@ def printer(obj, object_path, ir_docsifier): # pylint: disable=unused-argument return printer +def _root_dispatch_function(obj, ir_docsifier): + doc = ir_docsifier.as_doc(obj, ObjectPath.root()) + doc.source_paths = [ObjectPath.root().attr("irdocsifier_test")] + return doc + + # Because the dispatch table is global, tests should only set dispatch function under # unique dispatch token. IRDocsifier.set_dispatch(Var, _get_id_doc_printer("x"), f"{__file__}") +IRDocsifier.set_root_dispatch(f"{__file__}", _root_dispatch_function) def test_set_dispatch(ir_docsifier): @@ -55,6 +62,11 @@ def test_set_dispatch(ir_docsifier): assert doc.name == "x" +def test_set_root_dispatch(ir_docsifier): + doc = ir_docsifier.as_doc(RootNodeContainer(Var("x", dtype="int8")), ObjectPath.root()) + assert ObjectPath.root().attr("irdocsifier_test") in doc.source_paths + + def test_as_doc(ir_docsifier): object_path = ObjectPath.root() doc = ir_docsifier.as_doc(Var("x", "int8"), ObjectPath.root()) From 906dad9e8a23588f76f45578ee43f29643fff5ff Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Fri, 19 Aug 2022 11:08:44 -0400 Subject: [PATCH 2/3] Rename AsScript to Script --- include/tvm/script/printer.h | 4 ++- python/tvm/script/__init__.py | 1 - python/tvm/script/printer/__init__.py | 1 + .../script/{as_script.py => printer/entry.py} | 34 ++++++++++++++++--- src/script/printer.cc | 8 ++--- .../test_tvmscript_printer_entry_point.py | 4 +-- 6 files changed, 40 insertions(+), 12 deletions(-) rename python/tvm/script/{as_script.py => printer/entry.py} (58%) diff --git a/include/tvm/script/printer.h b/include/tvm/script/printer.h index 471ffe06d92a..b0fc54108c92 100644 --- a/include/tvm/script/printer.h +++ b/include/tvm/script/printer.h @@ -24,6 +24,7 @@ namespace tvm { namespace script { +namespace printer { /*! * \brief Print IR graph as TVMScript code @@ -38,7 +39,7 @@ namespace script { * * \return the TVMScript code as string. */ -String AsScript( // +String Script( // const ObjectRef& root_node, // String ir_name, // Map ir_prefix, // @@ -48,6 +49,7 @@ String AsScript( // Optional path_to_underline = NullOpt // ); +} // namespace printer } // namespace script } // namespace tvm diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py index cc43e8ecf535..555659d0c55e 100644 --- a/python/tvm/script/__init__.py +++ b/python/tvm/script/__init__.py @@ -19,4 +19,3 @@ from . import tir from .parser import ir_module, from_source -from .as_script import as_script diff --git a/python/tvm/script/printer/__init__.py b/python/tvm/script/printer/__init__.py index 84ab7b0ba836..d49614db0f21 100644 --- a/python/tvm/script/printer/__init__.py +++ b/python/tvm/script/printer/__init__.py @@ -24,3 +24,4 @@ """ from . import _ffi_api +from .entry import script diff --git a/python/tvm/script/as_script.py b/python/tvm/script/printer/entry.py similarity index 58% rename from python/tvm/script/as_script.py rename to python/tvm/script/printer/entry.py index 753e07e2e4c9..c812db8bd249 100644 --- a/python/tvm/script/as_script.py +++ b/python/tvm/script/printer/entry.py @@ -20,13 +20,13 @@ from typing import Dict, Optional -from tvm.runtime.object_path import ObjectPath +from tvm.runtime import Object, ObjectPath from . import _ffi_api -def as_script( - root_node, +def script( # pylint: disable=too-many-arguments + root_node: Object, ir_name: str, ir_prefix: Dict[str, str], indent_spaces: int = 4, @@ -34,7 +34,33 @@ def as_script( num_context_lines: int = -1, path_to_underline: Optional[ObjectPath] = None, ) -> str: - return _ffi_api.AsScript( + """ + Print IR graph as TVMScript code + + Parameters + ---------- + root_node : Object + The root node to print. + ir_name : str + The dispatch token of the target IR, e.g., "tir", "relax". + ir_prefix : Dict[str, str] + The symbol name for TVMScript IR namespaces. For example, + {"tir": "T"}. + indent_spaces : int + The number of indent spaces to use in the output + print_line_numbers: bool + Whether to print line numbers + num_context_lines : Optional[int] + Number of context lines to print around the underlined text + path_to_underline : Optional[ObjectPath] + Object path to be underlined + + Returns + ------- + script : str + The TVMScript code of the root_node + """ + return _ffi_api.Script( root_node, ir_name, ir_prefix, diff --git a/src/script/printer.cc b/src/script/printer.cc index c8fce03e800e..051b774ba6ac 100644 --- a/src/script/printer.cc +++ b/src/script/printer.cc @@ -26,10 +26,9 @@ namespace tvm { namespace script { +namespace printer { -using namespace printer; - -String AsScript( // +String Script( // const ObjectRef& root_node, // String ir_name, // Map ir_prefix, // @@ -48,7 +47,8 @@ String AsScript( // path_to_underline); } -TVM_REGISTER_GLOBAL("script.AsScript").set_body_typed(&AsScript); +TVM_REGISTER_GLOBAL("script.printer.Script").set_body_typed(&Script); +} // namespace printer } // namespace script } // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_printer_entry_point.py b/tests/python/unittest/test_tvmscript_printer_entry_point.py index 3fbffb39463f..208386dbdd4a 100644 --- a/tests/python/unittest/test_tvmscript_printer_entry_point.py +++ b/tests/python/unittest/test_tvmscript_printer_entry_point.py @@ -17,7 +17,7 @@ import pytest from tvm.error import TVMError -from tvm.script import as_script +from tvm.script.printer import script from tvm.tir import FloatImm @@ -25,6 +25,6 @@ def test_as_script_unknown_ir(): ir_node = FloatImm("float32", 1.0) with pytest.raises(TVMError) as e: - as_script(ir_node, "test_xyz", {}) + script(ir_node, "test_xyz", {}) assert "test_xyz" in str(e.value) From fdcee7ca7b9652f800f7c9d2e81c20d593d23dec Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Fri, 19 Aug 2022 16:50:21 -0400 Subject: [PATCH 3/3] Fix lint --- python/tvm/script/printer/entry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/script/printer/entry.py b/python/tvm/script/printer/entry.py index c812db8bd249..c015702af09b 100644 --- a/python/tvm/script/printer/entry.py +++ b/python/tvm/script/printer/entry.py @@ -60,7 +60,7 @@ def script( # pylint: disable=too-many-arguments script : str The TVMScript code of the root_node """ - return _ffi_api.Script( + return _ffi_api.Script( # type: ignore # pylint: disable=no-member root_node, ir_name, ir_prefix,