diff --git a/src/script/printer/tir/tir.cc b/src/script/printer/tir/tir.cc new file mode 100644 index 000000000000..38bd94a72bb5 --- /dev/null +++ b/src/script/printer/tir/tir.cc @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "./tir.h" + +#include +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +TIRTopLevelFrame::TIRTopLevelFrame() : TIRFrame(make_object()) {} + +TIRGeneralFrame::TIRGeneralFrame() : TIRFrame(make_object()) {} + +ExprDoc GetTypeAnnotationDocForVar(const TracedObject& var, const IRDocsifier& p) { + auto type_annotation = var.GetAttr(&tir::VarNode::type_annotation); + if (type_annotation.Get().defined()) { + return p->AsExprDoc(type_annotation); + } else { + auto dtype = var.GetAttr(&tir::VarNode::dtype); + Type raw_type = GetTypeFromRuntimeDataType(dtype.Get()); + return p->AsExprDoc(MakeTraced(raw_type, dtype.GetPath())); + } +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("tir", [](TracedObject obj, IRDocsifier p) { + const ObjectRef& root_node = obj.Get()->root_node; + + TIRTopLevelFrame top_level_frame; + auto frame_ctx = p->WithFrame(top_level_frame); + + // Because we are printing a single element, concise scoping should be allowed by default + top_level_frame->allow_concise_scoping = true; + + Doc root_doc = p->AsDoc(MakeTraced(root_node)); + + Array doc_to_print = top_level_frame->free_var_definitions; + + if (const auto* stmt_doc_node = root_doc.as()) { + doc_to_print.push_back(GetRef(stmt_doc_node)); + } else if (const auto* expr_doc_node = root_doc.as()) { + doc_to_print.push_back(ExprStmtDoc(GetRef(expr_doc_node))); + } else if (const auto* stmt_block_node = root_doc.as()) { + doc_to_print = runtime::Concat(doc_to_print, stmt_block_node->stmts); + } else if (const auto* slice_doc_node = root_doc.as()) { + doc_to_print.push_back(ExprStmtDoc(IdDoc("_")[{GetRef(slice_doc_node)}])); + } else { + ICHECK(false) << "Cannot print " << root_doc->GetTypeKey() << " as top level doc for TIR."; + } + + return StmtBlockDoc(doc_to_print); + }); +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/tir/tir.h b/src/script/printer/tir/tir.h new file mode 100644 index 000000000000..bb5973ee4f3b --- /dev/null +++ b/src/script/printer/tir/tir.h @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_TIR_TIR_H_ +#define TVM_SCRIPT_PRINTER_TIR_TIR_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +class TIRFrameNode : public FrameNode { + public: + mutable bool allow_concise_scoping{false}; + + void VisitAttrs(AttrVisitor* v) { + FrameNode::VisitAttrs(v); + v->Visit("allow_concise_scoping", &allow_concise_scoping); + } + + static constexpr const char* _type_key = "script.printer.TIRFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(TIRFrameNode, FrameNode); +}; + +class TIRFrame : public Frame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, Frame, TIRFrameNode); +}; + +class TIRTopLevelFrameNode : public TIRFrameNode { + public: + Array free_var_definitions; + + void VisitAttrs(AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("free_var_definitions", &free_var_definitions); + } + + static constexpr const char* _type_key = "script.printer.TIRTopLevelFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(TIRTopLevelFrameNode, FrameNode); +}; + +class TIRTopLevelFrame : public TIRFrame { + public: + TIRTopLevelFrame(); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRTopLevelFrame, TIRFrame, + TIRTopLevelFrameNode); +}; + +class TIRGeneralFrameNode : public TIRFrameNode { + public: + static constexpr const char* _type_key = "script.printer.TIRGeneralFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(TIRGeneralFrameNode, FrameNode); +}; + +class TIRGeneralFrame : public TIRFrame { + public: + TIRGeneralFrame(); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRGeneralFrame, TIRFrame, TIRGeneralFrameNode); +}; + +inline IdDoc TIR(const IRDocsifier& p) { return IdDoc(p->ir_prefix.Get("tir").value_or("T")); } + +ExprDoc GetTypeAnnotationDocForVar(const TracedObject& var, const IRDocsifier& p); + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_TIR_TIR_H_ diff --git a/src/script/printer/tir/type.cc b/src/script/printer/tir/type.cc new file mode 100644 index 000000000000..09aa96be7847 --- /dev/null +++ b/src/script/printer/tir/type.cc @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../utils.h" +#include "./tir.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("tir", [](TracedObject ty, IRDocsifier p) -> Doc { + TracedBasicValue dtype = ty.GetAttr(&PrimTypeNode::dtype); + String ty_str = runtime::DLDataType2String(dtype.Get()); + return TIR(p)->Attr(MakeTraced(ty_str, ty.GetPath())); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("tir", [](TracedObject ty, IRDocsifier p) -> Doc { + TracedObject element_type = ty.GetAttr(&PointerTypeNode::element_type); + TracedObject storage_scope = ty.GetAttr(&PointerTypeNode::storage_scope); + + ExprDoc element_type_doc = p->AsDoc(element_type); + if (storage_scope.Get().empty()) { + return TIR(p)->Attr("Ptr")->Call({element_type_doc}); + } else { + return TIR(p)->Attr("Ptr")->Call({element_type_doc, LiteralDoc::Str(storage_scope)}); + } + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("tir", [](TracedObject ty, IRDocsifier p) -> Doc { + auto fields = ty.GetAttr(&TupleTypeNode::fields); + + if (fields.empty()) { + return LiteralDoc::None(fields.GetPath()); + } + return TIR(p)->Attr("Tuple")->Call(AsExprDocArray(fields, p)); + }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/tir/var.cc b/src/script/printer/tir/var.cc new file mode 100644 index 000000000000..e6c200e1fe8e --- /dev/null +++ b/src/script/printer/tir/var.cc @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include +#include +#include + +#include "./tir.h" + +namespace tvm { +namespace script { +namespace printer { + +TracedObject GetVarNameHint(const TracedObject& var) { + TracedObject name_hint = var.GetAttr(&tir::VarNode::name_hint); + if (name_hint.Get().empty()) { + return MakeTraced(String("v"), var.GetPath()); + } else { + return name_hint; + } +} + +IdDoc CreateFreeVariableDefinition(TracedObject var, IRDocsifier p) { + TracedObject name_hint = GetVarNameHint(var); + // TODO(yelite): When implementing the PrimFunc printing, the logic here + // needs to change, putting variable def into PrimFuncFrame if it exists. + TIRTopLevelFrame top_level_frame = p->GetFrame().value(); + IdDoc doc = p->vars->Define(var.Get(), name_hint, top_level_frame); + StmtDoc def_doc = AssignDoc(doc, NullOpt, GetTypeAnnotationDocForVar(var, p)); + top_level_frame->free_var_definitions.push_back(def_doc); + return doc; +} + +ExprDoc PrintVariable(TracedObject var, IRDocsifier p) { + Optional doc = p->vars->GetVarDoc(var); + if (doc.defined()) { + return doc.value(); + } else { + return CreateFreeVariableDefinition(var, p); + } +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch(PrintVariable); +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject var, IRDocsifier p) { + return PrintVariable(MakeTraced(var.Get(), var.GetPath()), p); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject v, IRDocsifier p) -> Doc { + LOG(FATAL) << "Cannot print IterVar directly. Please use the helper functions in tir.h for " + "specific usage of IterVar."; + throw; + }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py new file mode 100644 index 000000000000..936a2d74b48e --- /dev/null +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +from tvm.ir import PointerType, PrimType, TupleType +from tvm.script.printer import script +from tvm.tir import SizeVar, Var + + +def format_script(s: str) -> str: + """ + Remove leading and trailing blank lines, and make the minimum idention 0 + """ + s = s.strip("\n") + + non_empty_lines = [line for line in s.splitlines() if line and not line.isspace()] + if not non_empty_lines: + # no actual content + return "\n" + + line_indents = [len(line) - len(line.lstrip(" ")) for line in non_empty_lines] + spaces_to_remove = min(line_indents) + + cleaned_lines = "\n".join(line[spaces_to_remove:] for line in s.splitlines()) + if not cleaned_lines.endswith("\n"): + cleaned_lines += "\n" + return cleaned_lines + + +@pytest.mark.parametrize( + "ty, expected", + [ + ( + PrimType("int8"), + """ + T.int8 + """, + ), + ( + PrimType("float32"), + """ + T.float32 + """, + ), + ( + PointerType(PrimType("int32")), + """ + T.Ptr(T.int32) + """, + ), + ( + PointerType(PrimType("int32"), "global"), + """ + T.Ptr(T.int32, "global") + """, + ), + ( + TupleType([]), + """ + None + """, + ), + ], +) +def test_type(ty, expected): + assert format_script(expected) == script(ty, "tir", {"tir": "T"}) + + +@pytest.mark.parametrize("var_type", [Var, SizeVar]) +def test_var(var_type): + var = var_type("x", "int8") + + assert script(var, "tir", {"tir": "T"}) == format_script( + """ + x: T.int8 + x + """ + )