Skip to content

Commit

Permalink
Add TIR var and type printing
Browse files Browse the repository at this point in the history
Co-authored-by: Greg Bonik <gbonik@octoml.ai>
  • Loading branch information
yelite and gbonik committed Sep 17, 2022
1 parent 22755f7 commit 7551fe0
Show file tree
Hide file tree
Showing 5 changed files with 404 additions and 0 deletions.
77 changes: 77 additions & 0 deletions src/script/printer/tir/tir.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include "./tir.h"

#include <tvm/script/printer/doc.h>
#include <tvm/script/printer/ir_docsifier.h>
#include <tvm/script/printer/traced_object.h>
#include <tvm/script/printer/traced_object_functor.h>
#include <tvm/tir/op.h>

namespace tvm {
namespace script {
namespace printer {

TIRTopLevelFrame::TIRTopLevelFrame() : TIRFrame(make_object<TIRTopLevelFrameNode>()) {}

TIRGeneralFrame::TIRGeneralFrame() : TIRFrame(make_object<TIRGeneralFrameNode>()) {}

ExprDoc GetTypeAnnotationDocForVar(const TracedObject<tir::Var>& var, const IRDocsifier& p) {
auto type_annotation = var.GetAttr(&tir::VarNode::type_annotation);
if (type_annotation.Get().defined()) {
return p->AsExprDoc(type_annotation);
} else {
auto dtype = var.GetAttr(&tir::VarNode::dtype);
Type raw_type = GetTypeFromRuntimeDataType(dtype.Get());
return p->AsExprDoc(MakeTraced(raw_type, dtype.GetPath()));
}
}

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<RootNodeContainer>("tir", [](TracedObject<RootNodeContainer> obj, IRDocsifier p) {
const ObjectRef& root_node = obj.Get()->root_node;

TIRTopLevelFrame top_level_frame;
auto frame_ctx = p->WithFrame(top_level_frame);

// Because we are printing a single element, concise scoping should be allowed by default
top_level_frame->allow_concise_scoping = true;

Doc root_doc = p->AsDoc<Doc>(MakeTraced(root_node));

Array<StmtDoc> doc_to_print = top_level_frame->free_var_definitions;

if (const auto* stmt_doc_node = root_doc.as<StmtDocNode>()) {
doc_to_print.push_back(GetRef<StmtDoc>(stmt_doc_node));
} else if (const auto* expr_doc_node = root_doc.as<ExprDocNode>()) {
doc_to_print.push_back(ExprStmtDoc(GetRef<ExprDoc>(expr_doc_node)));
} else if (const auto* stmt_block_node = root_doc.as<StmtBlockDocNode>()) {
doc_to_print = runtime::Concat(doc_to_print, stmt_block_node->stmts);
} else if (const auto* slice_doc_node = root_doc.as<SliceDocNode>()) {
doc_to_print.push_back(ExprStmtDoc(IdDoc("_")[{GetRef<SliceDoc>(slice_doc_node)}]));
} else {
ICHECK(false) << "Cannot print " << root_doc->GetTypeKey() << " as top level doc for TIR.";
}

return StmtBlockDoc(doc_to_print);
});
} // namespace printer
} // namespace script
} // namespace tvm
89 changes: 89 additions & 0 deletions src/script/printer/tir/tir.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_SCRIPT_PRINTER_TIR_TIR_H_
#define TVM_SCRIPT_PRINTER_TIR_TIR_H_

#include <tvm/script/printer/ir_docsifier.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>

namespace tvm {
namespace script {
namespace printer {

class TIRFrameNode : public FrameNode {
public:
mutable bool allow_concise_scoping{false};

void VisitAttrs(AttrVisitor* v) {
FrameNode::VisitAttrs(v);
v->Visit("allow_concise_scoping", &allow_concise_scoping);
}

static constexpr const char* _type_key = "script.printer.TIRFrame";
TVM_DECLARE_BASE_OBJECT_INFO(TIRFrameNode, FrameNode);
};

class TIRFrame : public Frame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, Frame, TIRFrameNode);
};

class TIRTopLevelFrameNode : public TIRFrameNode {
public:
Array<StmtDoc> free_var_definitions;

void VisitAttrs(AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("free_var_definitions", &free_var_definitions);
}

static constexpr const char* _type_key = "script.printer.TIRTopLevelFrame";
TVM_DECLARE_BASE_OBJECT_INFO(TIRTopLevelFrameNode, FrameNode);
};

class TIRTopLevelFrame : public TIRFrame {
public:
TIRTopLevelFrame();
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRTopLevelFrame, TIRFrame,
TIRTopLevelFrameNode);
};

class TIRGeneralFrameNode : public TIRFrameNode {
public:
static constexpr const char* _type_key = "script.printer.TIRGeneralFrame";
TVM_DECLARE_BASE_OBJECT_INFO(TIRGeneralFrameNode, FrameNode);
};

class TIRGeneralFrame : public TIRFrame {
public:
TIRGeneralFrame();
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRGeneralFrame, TIRFrame, TIRGeneralFrameNode);
};

inline IdDoc TIR(const IRDocsifier& p) { return IdDoc(p->ir_prefix.Get("tir").value_or("T")); }

ExprDoc GetTypeAnnotationDocForVar(const TracedObject<tir::Var>& var, const IRDocsifier& p);

} // namespace printer
} // namespace script
} // namespace tvm

#endif // TVM_SCRIPT_PRINTER_TIR_TIR_H_
69 changes: 69 additions & 0 deletions src/script/printer/tir/type.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <dlpack/dlpack.h>
#include <tvm/ir/type.h>
#include <tvm/node/functor.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/logging.h>
#include <tvm/script/printer/ir_docsifier.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/var.h>

#include "../utils.h"
#include "./tir.h"

namespace tvm {
namespace script {
namespace printer {

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<PrimType>("tir", [](TracedObject<PrimType> ty, IRDocsifier p) -> Doc {
TracedBasicValue<DataType> dtype = ty.GetAttr(&PrimTypeNode::dtype);
String ty_str = runtime::DLDataType2String(dtype.Get());
return TIR(p)->Attr(MakeTraced(ty_str, ty.GetPath()));
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<PointerType>("tir", [](TracedObject<PointerType> ty, IRDocsifier p) -> Doc {
TracedObject<Type> element_type = ty.GetAttr(&PointerTypeNode::element_type);
TracedObject<String> storage_scope = ty.GetAttr(&PointerTypeNode::storage_scope);

ExprDoc element_type_doc = p->AsDoc<ExprDoc>(element_type);
if (storage_scope.Get().empty()) {
return TIR(p)->Attr("Ptr")->Call({element_type_doc});
} else {
return TIR(p)->Attr("Ptr")->Call({element_type_doc, LiteralDoc::Str(storage_scope)});
}
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<TupleType>("tir", [](TracedObject<TupleType> ty, IRDocsifier p) -> Doc {
auto fields = ty.GetAttr(&TupleTypeNode::fields);

if (fields.empty()) {
return LiteralDoc::None(fields.GetPath());
}
return TIR(p)->Attr("Tuple")->Call(AsExprDocArray(fields, p));
});

} // namespace printer
} // namespace script
} // namespace tvm
77 changes: 77 additions & 0 deletions src/script/printer/tir/var.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/node/functor.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/logging.h>
#include <tvm/script/printer/ir_docsifier.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/var.h>

#include "./tir.h"

namespace tvm {
namespace script {
namespace printer {

TracedObject<String> GetVarNameHint(const TracedObject<tir::Var>& var) {
TracedObject<String> name_hint = var.GetAttr(&tir::VarNode::name_hint);
if (name_hint.Get().empty()) {
return MakeTraced(String("v"), var.GetPath());
} else {
return name_hint;
}
}

IdDoc CreateFreeVariableDefinition(TracedObject<tir::Var> var, IRDocsifier p) {
TracedObject<String> name_hint = GetVarNameHint(var);
// TODO(yelite): When implementing the PrimFunc printing, the logic here
// needs to change, putting variable def into PrimFuncFrame if it exists.
TIRTopLevelFrame top_level_frame = p->GetFrame<TIRTopLevelFrame>().value();
IdDoc doc = p->vars->Define(var.Get(), name_hint, top_level_frame);
StmtDoc def_doc = AssignDoc(doc, NullOpt, GetTypeAnnotationDocForVar(var, p));
top_level_frame->free_var_definitions.push_back(def_doc);
return doc;
}

ExprDoc PrintVariable(TracedObject<tir::Var> var, IRDocsifier p) {
Optional<ExprDoc> doc = p->vars->GetVarDoc(var);
if (doc.defined()) {
return doc.value();
} else {
return CreateFreeVariableDefinition(var, p);
}
}

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch<tir::Var>(PrintVariable);
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::SizeVar>([](TracedObject<tir::SizeVar> var, IRDocsifier p) {
return PrintVariable(MakeTraced(var.Get(), var.GetPath()), p);
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::IterVar>([](TracedObject<tir::IterVar> v, IRDocsifier p) -> Doc {
LOG(FATAL) << "Cannot print IterVar directly. Please use the helper functions in tir.h for "
"specific usage of IterVar.";
throw;
});

} // namespace printer
} // namespace script
} // namespace tvm
Loading

0 comments on commit 7551fe0

Please sign in to comment.