Skip to content

Commit

Permalink
[Relax] Stabilize relax pass mutation order (#16883)
Browse files Browse the repository at this point in the history
The current implementation of the relax pass is not stable, to be more
specific, the order of the mutation is not stable. This PR aims to
stabilize the mutation order of the relax pass, and further stabilize
the output of the relax pass.

Also fixes a minor doc typo in NN frontend
  • Loading branch information
Hzfengsy authored Apr 15, 2024
1 parent a64d1f1 commit f267691
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 20 deletions.
3 changes: 2 additions & 1 deletion include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ class IRModuleNode : public Object {
TVM_DLL GlobalVar GetGlobalVar(const String& str) const;

/*!
* \brief Collect all global vars defined in this module.
* \brief Collect all global vars defined in this module, ordered by
* the global variable name.
* \returns An array of global vars
*/
TVM_DLL Array<GlobalVar> GetGlobalVars() const;
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/relax/frontend/nn/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,10 +475,10 @@ def export_tvm(
-------
irmodule : tvm.ir.IRModule
The converted tvm IR representation of the model.
params : Dict[str, tvm.nd.array]
A dictionary of parameters corresponding to the weights of
the model.
params : List[Tuple[str, Parameter]]
A list of Parameters corresponding to the weights of the model.
ext_mods : List[nn.ExternModule]
A list of ExternModules that are used in the model.
"""
# pylint: disable=import-outside-toplevel
from . import spec as _spec
Expand Down
4 changes: 4 additions & 0 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <tvm/node/structural_equal.h>
#include <tvm/runtime/registry.h>

#include <algorithm>
#include <fstream>
#include <sstream>
#include <unordered_set>
Expand Down Expand Up @@ -183,6 +184,9 @@ tvm::Array<GlobalVar> IRModuleNode::GetGlobalVars() const {
for (const auto& pair : global_var_map_) {
global_vars.push_back(pair.second);
}
std::sort(global_vars.begin(), global_vars.end(), [](const GlobalVar& lhs, const GlobalVar& rhs) {
return lhs->name_hint < rhs->name_hint;
});
return tvm::Array<GlobalVar>(global_vars);
}

Expand Down
3 changes: 2 additions & 1 deletion src/relax/transform/alter_op_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ class AlterOpImplMutator : public ExprMutator {
op_buffer_axis_separators__(axis_separators_) {}

IRModule Run() {
for (const auto& [gv, func] : mod_->functions) {
for (const auto& gv : mod_->GetGlobalVars()) {
const auto& func = mod_->Lookup(gv);
if (func->IsInstance<relax::FunctionNode>()) {
relax::Function update_func = Downcast<Function>(VisitExpr(func));
builder_->UpdateFunction(gv, update_func);
Expand Down
3 changes: 2 additions & 1 deletion src/relax/transform/dead_code_elimination.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ IRModule DeadCodeElimination(const IRModule& arg_mod, Array<runtime::String> ent
for (const auto& name : entry_function_names) {
entry_functions.insert(mod->GetGlobalVar(name));
}
for (const auto& [gv, func] : mod->functions) {
for (const auto& gv : mod->GetGlobalVars()) {
const auto& func = mod->Lookup(gv);
if (func.as<ExternFuncNode>() || func->GetLinkageType() == LinkageType::kExternal) {
entry_functions.insert(gv);
}
Expand Down
22 changes: 12 additions & 10 deletions src/relax/transform/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,8 @@ class OperatorFusor : public ExprMutator {
* \return The new IRModule after transformation
*/
IRModule Transform() {
for (const auto& [gv, func] : mod_->functions) {
for (const auto& gv : mod_->GetGlobalVars()) {
const auto& func = mod_->Lookup(gv);
// Only visit Relax function without attr kPrimitive.
if (func->IsInstance<relax::FunctionNode>() && !func->HasNonzeroAttr(attr::kPrimitive)) {
auto updated_func = Downcast<Function>(VisitExpr(func));
Expand Down Expand Up @@ -1196,9 +1197,9 @@ class CompositeFunctionAnnotator : public ExprMutator {

IRModule Run() {
auto mod = builder_->GetContextIRModule();
auto all_functions = mod->functions;
for (const auto& entry : all_functions) {
if (const auto* func = entry.second.as<FunctionNode>()) {
for (const auto& gv : mod->GetGlobalVars()) {
const auto& base_func = mod->Lookup(gv);
if (const auto* func = base_func.as<FunctionNode>()) {
if (func->GetAttr<String>(attr::kComposite).defined() ||
func->GetAttr<String>(attr::kCodegen).defined()) {
continue;
Expand All @@ -1208,7 +1209,7 @@ class CompositeFunctionAnnotator : public ExprMutator {
if (!new_body.same_as(func->body)) {
auto new_func = Function(func->params, new_body, func->ret_struct_info, func->is_pure,
func->attrs, func->span);
builder_->UpdateFunction(entry.first, new_func);
builder_->UpdateFunction(gv, new_func);
}
}
}
Expand Down Expand Up @@ -1272,11 +1273,12 @@ IRModule FuseOpsByPattern(const tvm::Array<transform::FusionPattern>& patterns,
support::Arena arena;
for (const auto& pattern : patterns) {
OperatorFusor::GroupMap group_map;
for (const auto& entry : mod->functions) {
if (entry.second->IsInstance<tir::PrimFuncNode>()) {
for (const auto& gv : mod->GetGlobalVars()) {
const auto& base_func = mod->Lookup(gv);
if (base_func->IsInstance<tir::PrimFuncNode>()) {
continue;
}
const FunctionNode* function = entry.second.as<FunctionNode>();
const FunctionNode* function = base_func.as<FunctionNode>();
if (function->GetAttr<Integer>(attr::kPrimitive).defined() ||
function->GetAttr<String>(attr::kComposite).defined() ||
function->GetAttr<String>(attr::kCodegen).defined()) {
Expand All @@ -1285,8 +1287,8 @@ IRModule FuseOpsByPattern(const tvm::Array<transform::FusionPattern>& patterns,

auto map = PatternBasedPartitioner::Run(pattern->name, pattern->pattern,
pattern->annotation_patterns,
pattern->check.value_or(nullptr), entry.second,
&arena, pattern->attrs_getter.value_or(nullptr));
pattern->check.value_or(nullptr), base_func, &arena,
pattern->attrs_getter.value_or(nullptr));
for (const auto& [key, value] : map) {
CHECK(!group_map.count(key))
<< "ValueError: "
Expand Down
3 changes: 2 additions & 1 deletion src/relax/transform/fuse_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,8 @@ class TIRFuseMutator : public ExprMutator {
static IRModule Transform(IRModule mod) {
// Collect all primitive relax functions
Map<GlobalVar, Function> primitive_relax;
for (const auto& [gvar, base_func] : mod->functions) {
for (const auto& gvar : mod->GetGlobalVars()) {
const auto& base_func = mod->Lookup(gvar);
// Only fuse primitive relax functions
if (base_func->HasNonzeroAttr(attr::kPrimitive)) {
if (auto func = base_func.as<relax::Function>()) {
Expand Down
8 changes: 5 additions & 3 deletions src/relax/transform/legalize_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,18 @@ class LegalizeMutator : public ExprMutator {
}

IRModule Transform() {
for (const auto& [gv, func] : mod_->functions) {
for (const auto& gv : mod_->GetGlobalVars()) {
const auto& func = mod_->Lookup(gv);
if (func->IsInstance<FunctionNode>()) {
auto updated_func = Downcast<Function>(this->VisitExpr(func));
builder_->UpdateFunction(gv, Downcast<BaseFunc>(updated_func));
}
}
// Fill the "kTarget" attribute of PrimFunc
for (const auto& [gv, func] : builder_->GetContextIRModule()->functions) {
const auto& mod = builder_->GetContextIRModule();
for (const auto& gv : mod->GetGlobalVars()) {
const tir::PrimFuncNode* prim_func;
if (tmap_.count(gv) && (prim_func = func.as<tir::PrimFuncNode>())) {
if (tmap_.count(gv) && (prim_func = mod->Lookup(gv).as<tir::PrimFuncNode>())) {
auto f = WithAttr(GetRef<tir::PrimFunc>(prim_func), tvm::attr::kTarget, tmap_[gv]);
builder_->UpdateFunction(gv, f);
}
Expand Down

0 comments on commit f267691

Please sign in to comment.