Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[TVMScript] Add more helper functions to the printer infra (apache#12829
Browse files Browse the repository at this point in the history
)

This PR is split from apache#12492, to make the necessary updates to the printer infra for future PRs of TIR printer.

Tracking issue: apache#11912

Co-authored-by: Greg Bonik <gbonik@octoml.ai>
  • Loading branch information
2 people authored and xinetzone committed Nov 25, 2022
1 parent 62cee24 commit f9de7aa
Show file tree
Hide file tree
Showing 9 changed files with 228 additions and 62 deletions.
64 changes: 64 additions & 0 deletions include/tvm/script/printer/doc.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <tvm/ir/expr.h>
#include <tvm/node/node.h>
#include <tvm/runtime/data_type.h>
#include <tvm/script/printer/traced_object.h>

namespace tvm {
namespace script {
Expand Down Expand Up @@ -87,6 +88,15 @@ class ExprDocNode : public DocNode {
*/
ExprDoc Attr(String attr) const;

/*!
* \brief Create a doc representing attribute access on the current ExprDoc
* \param attr The attribute to access.
*
* The ObjectPath of attr will be pushed to the source_path of the returned
* doc.
*/
ExprDoc Attr(TracedObject<String> attr) const;

/*!
* \brief Create a doc representing index access on the current ExprDoc
* \param indices The indices to access.
Expand Down Expand Up @@ -242,37 +252,91 @@ class LiteralDocNode : public ExprDocNode {
class LiteralDoc : public ExprDoc {
protected:
explicit LiteralDoc(ObjectRef value);
LiteralDoc(ObjectRef value, ObjectPath object_path);

public:
/*!
* \brief Create a LiteralDoc to represent None/null/empty value.
*/
static LiteralDoc None() { return LiteralDoc(ObjectRef(nullptr)); }

/*!
* \brief Create a LiteralDoc to represent None/null/empty value.
* \param object_path The source path of the returned Doc.
*/
static LiteralDoc None(ObjectPath object_path) {
return LiteralDoc(ObjectRef(nullptr), object_path);
}

/*!
* \brief Create a LiteralDoc to represent integer.
* \param v The integer value.
*/
static LiteralDoc Int(int v) { return LiteralDoc(IntImm(DataType::Int(64), v)); }

/*!
* \brief Create a LiteralDoc to represent integer.
* \param v The integer value.
*
* The ObjectPath of v will be pushed to the source_path of the returned doc.
*/
static LiteralDoc Int(const TracedObject<IntImm>& v) { return LiteralDoc(v.Get(), v.GetPath()); }

/*!
* \brief Create a LiteralDoc to represent integer.
* \param v The integer value.
*
* The ObjectPath of v will be pushed to the source_path of the returned doc.
*/
static LiteralDoc Int(const TracedBasicValue<int>& v) {
return LiteralDoc(IntImm(DataType::Int(64), v.Get()), v.GetPath());
}
/*!
* \brief Create a LiteralDoc to represent boolean.
* \param v The boolean value.
*/
static LiteralDoc Boolean(bool v) { return LiteralDoc(IntImm(DataType::Bool(), v)); }

/*!
* \brief Create a LiteralDoc to represent boolean.
* \param v The boolean value.
*
* The ObjectPath of v will be pushed to the source_path of the returned doc.
*/
static LiteralDoc Boolean(const TracedBasicValue<bool>& v) {
return LiteralDoc(IntImm(DataType::Bool(), v.Get()), v.GetPath());
}

/*!
* \brief Create a LiteralDoc to represent float.
* \param v The float value.
*/
static LiteralDoc Float(double v) { return LiteralDoc(FloatImm(DataType::Float(64), v)); }

/*!
* \brief Create a LiteralDoc to represent float.
* \param v The float value.
*
* The ObjectPath of v will be pushed to the source_path of the returned doc.
*/
static LiteralDoc Float(const TracedObject<FloatImm>& v) {
return LiteralDoc(v.Get(), v.GetPath());
}

/*!
* \brief Create a LiteralDoc to represent string.
* \param v The string value.
*/
static LiteralDoc Str(const String& v) { return LiteralDoc(v); }

/*!
* \brief Create a LiteralDoc to represent string.
* \param v The string value.
*
* The ObjectPath of v will be pushed to the source_path of the returned doc.
*/
static LiteralDoc Str(const TracedObject<String>& v) { return LiteralDoc(v.Get(), v.GetPath()); }

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LiteralDoc, ExprDoc, LiteralDocNode);
};

Expand Down
37 changes: 4 additions & 33 deletions include/tvm/script/printer/traced_object_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,35 +34,6 @@ namespace tvm {
namespace script {
namespace printer {

namespace {

namespace detail {
/*!
* \brief Helper template class to extract the type of first argument of a function
* \tparam FType The function type.
*/
template <typename FType>
struct FirstArgTypeGetter;

template <typename R, typename ArgOne, typename... OtherArgs>
struct FirstArgTypeGetter<R(ArgOne, OtherArgs...)> {
using T = ArgOne;
};

/*!
* \brief Template alias for the type of first argument of a function
* \tparam FType The function type.
*
* The name of public functions are in snake case to be consistent with
* tvm/node/functor.h
*/
template <typename FType>
using FirstArgType = typename detail::FirstArgTypeGetter<
typename tvm::runtime::detail::function_signature<FType>::FType>::T;
} // namespace detail

} // namespace

/*
* This type alias and the following free functions are created to reduce the binary bloat
* from template and also hide implementation details from this header
Expand Down Expand Up @@ -156,8 +127,7 @@ class TracedObjectFunctor {
*
* The diaptch function should have signature `R(TracedObject<TObjectRef>, Args...)`.
*/
template <typename TCallable,
typename TObjectRef = typename detail::FirstArgType<TCallable>::ObjectRefType,
template <typename TObjectRef, typename TCallable,
typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>>
TSelf& set_dispatch(String token, TCallable f) {
return set_dispatch(
Expand All @@ -177,9 +147,10 @@ class TracedObjectFunctor {
*
* Default dispatch function has an empty string as dispatch token.
*/
template <typename TCallable>
template <typename TObjectRef, typename TCallable,
typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>>
TSelf& set_dispatch(TCallable&& f) {
return set_dispatch(kDefaultDispatchToken, std::forward<TCallable>(f));
return set_dispatch<TObjectRef>(kDefaultDispatchToken, std::forward<TCallable>(f));
}

/*!
Expand Down
11 changes: 11 additions & 0 deletions include/tvm/script/printer/var_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,17 @@ class VarTableNode : public Object {
*/
Optional<ExprDoc> GetVarDoc(const ObjectRef& obj, const ObjectPath& object_path) const;

/*!
* \brief Get the doc for variable.
* \param obj The traced variable object.
*
* \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt.
*/
template <typename TObjectRef>
Optional<ExprDoc> GetVarDoc(const TracedObject<TObjectRef> obj) const {
return GetVarDoc(obj.Get(), obj.GetPath());
}

/*!
* \brief Check if a variable exists in the table.
* \param obj The variable object.
Expand Down
30 changes: 24 additions & 6 deletions src/script/printer/doc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ namespace printer {

ExprDoc ExprDocNode::Attr(String attr) const { return AttrAccessDoc(GetRef<ExprDoc>(this), attr); }

ExprDoc ExprDocNode::Attr(TracedObject<String> attr) const {
auto doc = AttrAccessDoc(GetRef<ExprDoc>(this), attr.Get());
doc->source_paths.push_back(attr.GetPath());
return doc;
}

ExprDoc ExprDocNode::operator[](Array<Doc> indices) const {
return IndexDoc(GetRef<ExprDoc>(this), indices);
}
Expand Down Expand Up @@ -54,6 +60,13 @@ LiteralDoc::LiteralDoc(ObjectRef value) {
this->data_ = std::move(n);
}

LiteralDoc::LiteralDoc(ObjectRef value, ObjectPath object_path) {
ObjectPtr<LiteralDocNode> n = make_object<LiteralDocNode>();
n->value = value;
n->source_paths.push_back(object_path);
this->data_ = std::move(n);
}

IdDoc::IdDoc(String name) {
ObjectPtr<IdDocNode> n = make_object<IdDocNode>();
n->name = name;
Expand Down Expand Up @@ -225,7 +238,8 @@ TVM_REGISTER_GLOBAL("script.printer.DocSetSourcePaths")
});

TVM_REGISTER_NODE_TYPE(ExprDocNode);
TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr").set_body_method<ExprDoc>(&ExprDocNode::Attr);
TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr")
.set_body_method<ExprDoc, ExprDocNode, ExprDoc, String>(&ExprDocNode::Attr);
TVM_REGISTER_GLOBAL("script.printer.ExprDocIndex")
.set_body_method<ExprDoc>(&ExprDocNode::operator[]);
TVM_REGISTER_GLOBAL("script.printer.ExprDocCall")
Expand All @@ -242,11 +256,15 @@ TVM_REGISTER_GLOBAL("script.printer.StmtBlockDoc").set_body_typed([](Array<StmtD
});

TVM_REGISTER_NODE_TYPE(LiteralDocNode);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean").set_body_typed(LiteralDoc::Boolean);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat").set_body_typed(LiteralDoc::Float);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr").set_body_typed(LiteralDoc::Str);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed<LiteralDoc()>(LiteralDoc::None);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt")
.set_body_typed<LiteralDoc(int)>(LiteralDoc::Int);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean")
.set_body_typed<LiteralDoc(bool)>(LiteralDoc::Boolean);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat")
.set_body_typed<LiteralDoc(double)>(LiteralDoc::Float);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr")
.set_body_typed<LiteralDoc(const String&)>(LiteralDoc::Str);

TVM_REGISTER_NODE_TYPE(IdDocNode);
TVM_REGISTER_GLOBAL("script.printer.IdDoc").set_body_typed([](String name) { return IdDoc(name); });
Expand Down
2 changes: 1 addition & 1 deletion src/script/printer/ir_docsifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ RootNodeContainer::RootNodeContainer(ObjectRef root_node) {
// });
// \endcode
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch([](TracedObject<RootNodeContainer> obj, IRDocsifier p) -> Doc {
.set_dispatch<RootNodeContainer>([](TracedObject<RootNodeContainer> obj, IRDocsifier p) -> Doc {
String top_dispatch_token = p->dispatch_tokens.back();
ICHECK_NE(top_dispatch_token, "");
ICHECK(false) << "Printing IR " << top_dispatch_token << " is not implemented.";
Expand Down
93 changes: 93 additions & 0 deletions src/script/printer/utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_SCRIPT_PRINTER_UTILS_H_
#define TVM_SCRIPT_PRINTER_UTILS_H_

#include <tvm/script/printer/doc.h>
#include <tvm/script/printer/ir_docsifier.h>

#include <utility>

namespace tvm {
namespace script {
namespace printer {

template <typename DocType, typename NodeType>
Array<DocType> AsDocArray(const TracedArray<NodeType>& refs, const IRDocsifier& ir_docsifier) {
Array<DocType> result;
for (auto ref : refs) {
result.push_back(ir_docsifier->AsExprDoc(ref));
}
return result;
}

template <typename DocType, typename NodeType>
Array<DocType> AsDocArray(std::initializer_list<NodeType>&& refs, const IRDocsifier& ir_docsifier) {
Array<DocType> result;
for (auto& ref : refs) {
result.push_back(ir_docsifier->AsExprDoc(ref));
}
return result;
}

template <typename RefType>
Array<ExprDoc> AsExprDocArray(const TracedArray<RefType>& refs, const IRDocsifier& ir_docsifier) {
return AsDocArray<ExprDoc>(refs, ir_docsifier);
}

template <typename RefType>
Array<ExprDoc> AsExprDocArray(std::initializer_list<RefType>&& refs,
const IRDocsifier& ir_docsifier) {
return AsDocArray<ExprDoc>(std::move(refs), ir_docsifier);
}

inline DictDoc AsDictDoc(const TracedMap<String, ObjectRef>& dict,
const IRDocsifier& ir_docsifier) {
Array<ExprDoc> keys;
Array<ExprDoc> values;

for (auto p : dict) {
keys.push_back(LiteralDoc::Str(p.first));
values.push_back(ir_docsifier->AsExprDoc(p.second));
}

auto doc = DictDoc(keys, values);
doc->source_paths.push_back(dict.GetPath());
return doc;
}

template <typename T>
inline ListDoc AsListDoc(const TracedArray<T>& arr, const IRDocsifier& ir_docsifier) {
auto ret = ListDoc(AsExprDocArray(arr, ir_docsifier));
ret->source_paths.push_back(arr.GetPath());
return ret;
}

template <typename T>
inline TupleDoc AsTupleDoc(const TracedArray<T>& arr, const IRDocsifier& ir_docsifier) {
auto ret = TupleDoc(AsExprDocArray(arr, ir_docsifier));
ret->source_paths.push_back(arr.GetPath());
return ret;
}

} // namespace printer
} // namespace script
} // namespace tvm

#endif // TVM_SCRIPT_PRINTER_UTILS_H_
3 changes: 2 additions & 1 deletion src/script/printer/var_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ TVM_REGISTER_GLOBAL("script.printer.VarTableDefineByDoc")
obj, [f = std::move(factory)]() { return f(); }, frame);
});
TVM_REGISTER_GLOBAL("script.printer.VarTableGetVarDoc")
.set_body_method<VarTable>(&VarTableNode::GetVarDoc);
.set_body_method<VarTable, VarTableNode, Optional<ExprDoc>, const ObjectRef&,
const ObjectPath&>(&VarTableNode::GetVarDoc);
TVM_REGISTER_GLOBAL("script.printer.VarTableIsVarDefined")
.set_body_method<VarTable>(&VarTableNode::IsVarDefined);

Expand Down
13 changes: 9 additions & 4 deletions tests/cpp/tvmscript_printer_irdocsifier_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,19 @@ class TestObject : public ObjectRef {
TVM_REGISTER_NODE_TYPE(TestObjectNode);

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch([](TracedObject<TestObject> obj, IRDocsifier p) { return IdDoc("x"); });
.set_dispatch<TestObject>([](TracedObject<TestObject> obj, IRDocsifier p) {
return IdDoc("x");
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch("tir", [](TracedObject<TestObject> obj, IRDocsifier p) { return IdDoc("tir"); });
.set_dispatch<TestObject>("tir", [](TracedObject<TestObject> obj, IRDocsifier p) {
return IdDoc("tir");
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch("relax",
[](TracedObject<TestObject> obj, IRDocsifier p) { return IdDoc("relax"); });
.set_dispatch<TestObject>("relax", [](TracedObject<TestObject> obj, IRDocsifier p) {
return IdDoc("relax");
});

TEST(PrinterIRDocsifierTest, AsDoc) {
IRDocsifier p(Map<String, String>{});
Expand Down
Loading

0 comments on commit f9de7aa

Please sign in to comment.