Skip to content

Commit

Permalink
Add TIR var and type printing
Browse files Browse the repository at this point in the history
Co-authored-by: Greg Bonik <gbonik@octoml.ai>
  • Loading branch information
yelite and gbonik committed Aug 18, 2022
1 parent 37e125e commit 63d8623
Show file tree
Hide file tree
Showing 14 changed files with 632 additions and 62 deletions.
64 changes: 64 additions & 0 deletions include/tvm/script/printer/doc.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <tvm/ir/expr.h>
#include <tvm/node/node.h>
#include <tvm/runtime/data_type.h>
#include <tvm/script/printer/traced_object.h>

namespace tvm {
namespace script {
Expand Down Expand Up @@ -87,6 +88,15 @@ class ExprDocNode : public DocNode {
*/
ExprDoc Attr(String attr) const;

/*!
* \brief Create a doc representing attribute access on the current ExprDoc
* \param attr The attribute to access.
*
* The ObjectPath of attr will be pushed to the source_path of the returned
* doc.
*/
ExprDoc Attr(TracedObject<String> attr) const;

/*!
* \brief Create a doc representing index access on the current ExprDoc
* \param indices The indices to access.
Expand Down Expand Up @@ -242,37 +252,91 @@ class LiteralDocNode : public ExprDocNode {
class LiteralDoc : public ExprDoc {
protected:
explicit LiteralDoc(ObjectRef value);
LiteralDoc(ObjectRef value, ObjectPath object_path);

public:
/*!
* \brief Create a LiteralDoc to represent None/null/empty value.
*/
static LiteralDoc None() { return LiteralDoc(ObjectRef(nullptr)); }

/*!
* \brief Create a LiteralDoc to represent None/null/empty value.
* \param object_path The source path of the returned Doc.
*/
static LiteralDoc None(ObjectPath object_path) {
return LiteralDoc(ObjectRef(nullptr), object_path);
}

/*!
* \brief Create a LiteralDoc to represent integer.
* \param v The integer value.
*/
static LiteralDoc Int(int v) { return LiteralDoc(IntImm(DataType::Int(64), v)); }

/*!
* \brief Create a LiteralDoc to represent integer.
* \param v The integer value.
*
* The ObjectPath of v will be pushed to the source_path of the returned doc.
*/
static LiteralDoc Int(const TracedObject<IntImm>& v) { return LiteralDoc(v.Get(), v.GetPath()); }

/*!
* \brief Create a LiteralDoc to represent integer.
* \param v The integer value.
*
* The ObjectPath of v will be pushed to the source_path of the returned doc.
*/
static LiteralDoc Int(const TracedBasicValue<int>& v) {
return LiteralDoc(IntImm(DataType::Int(64), v.Get()), v.GetPath());
}
/*!
* \brief Create a LiteralDoc to represent boolean.
* \param v The boolean value.
*/
static LiteralDoc Boolean(bool v) { return LiteralDoc(IntImm(DataType::Bool(), v)); }

/*!
* \brief Create a LiteralDoc to represent boolean.
* \param v The boolean value.
*
* The ObjectPath of v will be pushed to the source_path of the returned doc.
*/
static LiteralDoc Boolean(const TracedBasicValue<bool>& v) {
return LiteralDoc(IntImm(DataType::Bool(), v.Get()), v.GetPath());
}

/*!
* \brief Create a LiteralDoc to represent float.
* \param v The float value.
*/
static LiteralDoc Float(double v) { return LiteralDoc(FloatImm(DataType::Float(64), v)); }

/*!
* \brief Create a LiteralDoc to represent float.
* \param v The float value.
*
* The ObjectPath of v will be pushed to the source_path of the returned doc.
*/
static LiteralDoc Float(const TracedObject<FloatImm>& v) {
return LiteralDoc(v.Get(), v.GetPath());
}

/*!
* \brief Create a LiteralDoc to represent string.
* \param v The string value.
*/
static LiteralDoc Str(const String& v) { return LiteralDoc(v); }

/*!
* \brief Create a LiteralDoc to represent string.
* \param v The string value.
*
* The ObjectPath of v will be pushed to the source_path of the returned doc.
*/
static LiteralDoc Str(const TracedObject<String>& v) { return LiteralDoc(v.Get(), v.GetPath()); }

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LiteralDoc, ExprDoc, LiteralDocNode);
};

Expand Down
37 changes: 4 additions & 33 deletions include/tvm/script/printer/traced_object_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,35 +34,6 @@ namespace tvm {
namespace script {
namespace printer {

namespace {

namespace detail {
/*!
* \brief Helper template class to extract the type of first argument of a function
* \tparam FType The function type.
*/
template <typename FType>
struct FirstArgTypeGetter;

template <typename R, typename ArgOne, typename... OtherArgs>
struct FirstArgTypeGetter<R(ArgOne, OtherArgs...)> {
using T = ArgOne;
};

/*!
* \brief Template alias for the type of first argument of a function
* \tparam FType The function type.
*
* The name of public functions are in snake case to be consistent with
* tvm/node/functor.h
*/
template <typename FType>
using FirstArgType = typename detail::FirstArgTypeGetter<
typename tvm::runtime::detail::function_signature<FType>::FType>::T;
} // namespace detail

} // namespace

/*
* This type alias and the following free functions are created to reduce the binary bloat
* from template and also hide implementation details from this header
Expand Down Expand Up @@ -156,8 +127,7 @@ class TracedObjectFunctor {
*
* The diaptch function should have signature `R(TracedObject<TObjectRef>, Args...)`.
*/
template <typename TCallable,
typename TObjectRef = typename detail::FirstArgType<TCallable>::ObjectRefType,
template <typename TObjectRef, typename TCallable,
typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>>
TSelf& set_dispatch(String token, TCallable f) {
return set_dispatch(
Expand All @@ -177,9 +147,10 @@ class TracedObjectFunctor {
*
* Default dispatch function has an empty string as dispatch token.
*/
template <typename TCallable>
template <typename TObjectRef, typename TCallable,
typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>>
TSelf& set_dispatch(TCallable&& f) {
return set_dispatch(kDefaultDispatchToken, std::forward<TCallable>(f));
return set_dispatch<TObjectRef>(kDefaultDispatchToken, std::forward<TCallable>(f));
}

/*!
Expand Down
11 changes: 11 additions & 0 deletions include/tvm/script/printer/var_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,17 @@ class VarTableNode : public Object {
*/
Optional<ExprDoc> GetVarDoc(const ObjectRef& obj, const ObjectPath& object_path) const;

/*!
* \brief Get the doc for variable.
* \param obj The traced variable object.
*
* \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt.
*/
template <typename TObjectRef>
Optional<ExprDoc> GetVarDoc(const TracedObject<TObjectRef> obj) const {
return GetVarDoc(obj.Get(), obj.GetPath());
}

/*!
* \brief Check if a variable exists in the table.
* \param obj The variable object.
Expand Down
30 changes: 24 additions & 6 deletions src/script/printer/doc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ namespace printer {

ExprDoc ExprDocNode::Attr(String attr) const { return AttrAccessDoc(GetRef<ExprDoc>(this), attr); }

ExprDoc ExprDocNode::Attr(TracedObject<String> attr) const {
auto doc = AttrAccessDoc(GetRef<ExprDoc>(this), attr.Get());
doc->source_paths.push_back(attr.GetPath());
return doc;
}

ExprDoc ExprDocNode::operator[](Array<Doc> indices) const {
return IndexDoc(GetRef<ExprDoc>(this), indices);
}
Expand Down Expand Up @@ -54,6 +60,13 @@ LiteralDoc::LiteralDoc(ObjectRef value) {
this->data_ = std::move(n);
}

LiteralDoc::LiteralDoc(ObjectRef value, ObjectPath object_path) {
ObjectPtr<LiteralDocNode> n = make_object<LiteralDocNode>();
n->value = value;
n->source_paths.push_back(object_path);
this->data_ = std::move(n);
}

IdDoc::IdDoc(String name) {
ObjectPtr<IdDocNode> n = make_object<IdDocNode>();
n->name = name;
Expand Down Expand Up @@ -225,7 +238,8 @@ TVM_REGISTER_GLOBAL("script.printer.DocSetSourcePaths")
});

TVM_REGISTER_NODE_TYPE(ExprDocNode);
TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr").set_body_method<ExprDoc>(&ExprDocNode::Attr);
TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr")
.set_body_method<ExprDoc, ExprDocNode, ExprDoc, String>(&ExprDocNode::Attr);
TVM_REGISTER_GLOBAL("script.printer.ExprDocIndex")
.set_body_method<ExprDoc>(&ExprDocNode::operator[]);
TVM_REGISTER_GLOBAL("script.printer.ExprDocCall")
Expand All @@ -242,11 +256,15 @@ TVM_REGISTER_GLOBAL("script.printer.StmtBlockDoc").set_body_typed([](Array<StmtD
});

TVM_REGISTER_NODE_TYPE(LiteralDocNode);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean").set_body_typed(LiteralDoc::Boolean);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat").set_body_typed(LiteralDoc::Float);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr").set_body_typed(LiteralDoc::Str);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed<LiteralDoc()>(LiteralDoc::None);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt")
.set_body_typed<LiteralDoc(int)>(LiteralDoc::Int);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean")
.set_body_typed<LiteralDoc(bool)>(LiteralDoc::Boolean);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat")
.set_body_typed<LiteralDoc(double)>(LiteralDoc::Float);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr")
.set_body_typed<LiteralDoc(const String&)>(LiteralDoc::Str);

TVM_REGISTER_NODE_TYPE(IdDocNode);
TVM_REGISTER_GLOBAL("script.printer.IdDoc").set_body_typed([](String name) { return IdDoc(name); });
Expand Down
2 changes: 1 addition & 1 deletion src/script/printer/ir_docsifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ RootNodeContainer::RootNodeContainer(ObjectRef root_node) {
// });
// \endcode
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch([](TracedObject<RootNodeContainer> obj, IRDocsifier p) -> Doc {
.set_dispatch<RootNodeContainer>([](TracedObject<RootNodeContainer> obj, IRDocsifier p) -> Doc {
String top_dispatch_token = p->dispatch_tokens.back();
ICHECK_NE(top_dispatch_token, "");
ICHECK(false) << "Printing IR " << top_dispatch_token << " is not implemented.";
Expand Down
77 changes: 77 additions & 0 deletions src/script/printer/tir/tir.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include "./tir.h"

#include <tvm/script/printer/doc.h>
#include <tvm/script/printer/ir_docsifier.h>
#include <tvm/script/printer/traced_object.h>
#include <tvm/script/printer/traced_object_functor.h>
#include <tvm/tir/op.h>

namespace tvm {
namespace script {
namespace printer {

TIRTopLevelFrame::TIRTopLevelFrame() : TIRFrame(make_object<TIRTopLevelFrameNode>()) {}

TIRGeneralFrame::TIRGeneralFrame() : TIRFrame(make_object<TIRGeneralFrameNode>()) {}

ExprDoc GetTypeAnnotationDocForVar(const TracedObject<tir::Var>& var, const IRDocsifier& p) {
auto type_annotation = var.GetAttr(&tir::VarNode::type_annotation);
if (type_annotation.Get().defined()) {
return p->AsExprDoc(type_annotation);
} else {
auto dtype = var.GetAttr(&tir::VarNode::dtype);
Type raw_type = GetTypeFromRuntimeDataType(dtype.Get());
return p->AsExprDoc(MakeTraced(raw_type, dtype.GetPath()));
}
}

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<RootNodeContainer>("tir", [](TracedObject<RootNodeContainer> obj, IRDocsifier p) {
const ObjectRef& root_node = obj.Get()->root_node;

TIRTopLevelFrame top_level_frame;
auto frame_ctx = p->WithFrame(top_level_frame);

// Because we are printing a single element, concise scoping should be allowed by default
top_level_frame->allow_concise_scoping = true;

Doc root_doc = p->AsDoc<Doc>(MakeTraced(root_node));

Array<StmtDoc> doc_to_print = top_level_frame->free_var_definitions;

if (const auto* stmt_doc_node = root_doc.as<StmtDocNode>()) {
doc_to_print.push_back(GetRef<StmtDoc>(stmt_doc_node));
} else if (const auto* expr_doc_node = root_doc.as<ExprDocNode>()) {
doc_to_print.push_back(ExprStmtDoc(GetRef<ExprDoc>(expr_doc_node)));
} else if (const auto* stmt_block_node = root_doc.as<StmtBlockDocNode>()) {
doc_to_print = runtime::Concat(doc_to_print, stmt_block_node->stmts);
} else if (const auto* slice_doc_node = root_doc.as<SliceDocNode>()) {
doc_to_print.push_back(ExprStmtDoc(IdDoc("_")[{GetRef<SliceDoc>(slice_doc_node)}]));
} else {
ICHECK(false) << "Cannot print " << root_doc->GetTypeKey() << " as top level doc for TIR.";
}

return StmtBlockDoc(doc_to_print);
});
} // namespace printer
} // namespace script
} // namespace tvm
Loading

0 comments on commit 63d8623

Please sign in to comment.