Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[TVMScript] Add source_paths to Doc (apache#12324)
Browse files Browse the repository at this point in the history
This PR:

- Add the source_paths attribute to Doc base class.
- Add the corresponding Python binding for it.

This PR is depended by multiple tasks, including the diagnostic output in DocPrinter, VarTable and IRDocisifer.

Tracking issue: apache#11912

Co-authored-by: Greg Bonik <gbonik@octoml.ai>
  • Loading branch information
2 people authored and xinetzone committed Nov 25, 2022
1 parent f1975a0 commit 661e68e
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 3 deletions.
11 changes: 10 additions & 1 deletion include/tvm/script/printer/doc.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,16 @@ namespace printer {
*/
class DocNode : public Object {
public:
void VisitAttrs(AttrVisitor* v) {}
/*!
* \brief The list of object paths of the source IR node.
*
* This is used to trace back to the IR node position where
* this Doc is generated, in order to position the diagnostic
* message.
*/
mutable Array<ObjectPath> source_paths;

void VisitAttrs(AttrVisitor* v) { v->Visit("source_paths", &source_paths); }

static constexpr const char* _type_key = "script.printer.Doc";
TVM_DECLARE_BASE_OBJECT_INFO(DocNode, Object);
Expand Down
27 changes: 25 additions & 2 deletions python/tvm/script/printer/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import Dict, List, Optional, Sequence, Tuple, Union

from tvm._ffi import register_object
from tvm.runtime import Object
from tvm.runtime import Object, ObjectPath
from tvm.tir import FloatImm, IntImm

from . import _ffi_api
Expand All @@ -29,8 +29,23 @@
class Doc(Object):
"""Base class of all Docs"""

@property
def source_paths(self) -> Sequence[ObjectPath]:
"""
The list of object paths of the source IR node.
This is used to trace back to the IR node position where
this Doc is generated, in order to position the diagnostic
message.
"""
return self.__getattr__("source_paths") # pylint: disable=unnecessary-dunder-call

@source_paths.setter
def source_paths(self, value):
return _ffi_api.DocSetSourcePaths(self, value) # type: ignore # pylint: disable=no-member

class ExprDoc(Object):

class ExprDoc(Doc):
"""Base class of all expression Docs"""

def attr(self, name: str) -> "AttrAccessDoc":
Expand Down Expand Up @@ -104,6 +119,14 @@ class StmtDoc(Doc):

@property
def comment(self) -> Optional[str]:
"""
The comment of this doc.
The actual position of the comment depends on the type of Doc
and also the DocPrinter implementation. It could be on the same
line as the statement, or the line above, or inside the statement
if it spans over multiple lines.
"""
# It has to call the dunder method to avoid infinite recursion
return self.__getattr__("comment") # pylint: disable=unnecessary-dunder-call

Expand Down
4 changes: 4 additions & 0 deletions src/script/printer/doc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,10 @@ ClassDoc::ClassDoc(IdDoc name, Array<ExprDoc> decorators, Array<StmtDoc> body) {
}

TVM_REGISTER_NODE_TYPE(DocNode);
TVM_REGISTER_GLOBAL("script.printer.DocSetSourcePaths")
.set_body_typed([](Doc doc, Array<ObjectPath> source_paths) {
doc->source_paths = source_paths;
});

TVM_REGISTER_NODE_TYPE(ExprDocNode);
TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr").set_body_method<ExprDoc>(&ExprDocNode::Attr);
Expand Down
19 changes: 19 additions & 0 deletions tests/python/unittest/test_tvmscript_printer_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pytest

import tvm
from tvm.runtime import ObjectPath
from tvm.script.printer.doc import (
AssertDoc,
AssignDoc,
Expand Down Expand Up @@ -510,8 +511,26 @@ def test_stmt_doc_comment():

comment = "test comment"
doc.comment = comment
# Make sure the previous statement doesn't set attribute
# as if it's an ordinary Python object.
assert "comment" not in doc.__dict__
assert doc.comment == comment


def test_doc_source_paths():
doc = IdDoc("x")
assert len(doc.source_paths) == 0

source_paths = [ObjectPath.root(), ObjectPath.root().attr("x")]

doc.source_paths = source_paths
# This should triggers the __getattr__ and gets a tvm.ir.container.Array
assert not isinstance(doc.source_paths, list)
assert list(doc.source_paths) == source_paths

doc.source_paths = []
assert len(doc.source_paths) == 0


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 661e68e

Please sign in to comment.