Skip to content

Commit

Permalink
change FuseOpsByPattern policy to match max subgraph
Browse files Browse the repository at this point in the history
body: change FuseOpsByPattern policy to match max subgraph
  • Loading branch information
wanghuibin0 committed May 9, 2024
1 parent 28d32b5 commit 9d46d67
Showing 1 changed file with 29 additions and 2 deletions.
31 changes: 29 additions & 2 deletions src/relax/transform/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1073,7 +1073,11 @@ class PatternBasedPartitioner : ExprVisitor {
current_block_use_def_ = {};
}

void VisitVarDef(const Var& var) final { group_map_[var.get()] = arena_->make<Group>(); }
void VisitVarDef(const Var& var) final {
Group* g = arena_->make<Group>();
group_map_[var.get()] = g;
vars_in_group_[g].push_back(var);
}

void VisitBinding_(const VarBindingNode* binding) final {
bindings_.Set(binding->var, binding->value);
Expand All @@ -1097,7 +1101,13 @@ class PatternBasedPartitioner : ExprVisitor {
auto g = GetGroup(match);
if (g && g->FindRoot()->num_nodes > 1) {
// This expression has already been matched to a previous pattern.
return;
// If the prior matched subgraph is subsumed by the new matched one,
// we can safely merge them, obtaining a maximized matched subgraph enventually.
// Otherwise, merging them will result in an incorrect subgraph,
// so we keep the prior subgraph and discard the current one by directly return.
auto vars_in_prior_matched_graph = vars_in_group_[g];
if (!GraphSubsumedInMatchedValues(vars_in_prior_matched_graph, matches_opt.value()))
return;
}
}
}
Expand Down Expand Up @@ -1145,6 +1155,7 @@ class PatternBasedPartitioner : ExprVisitor {
if (group_map_[e.get()] != to) {
--group_map_[e.get()]->num_nodes;
group_map_[e.get()]->parent = to;
vars_in_group_[to].push_back(e);
++to->num_nodes;
}
}
Expand Down Expand Up @@ -1181,6 +1192,21 @@ class PatternBasedPartitioner : ExprVisitor {
current_block_use_def_, value_to_bound_var_);
}

// check if a previous matched subgraph is subsumed by the current matched result
bool GraphSubsumedInMatchedValues(const Array<Expr>& vars_in_graph,
const Map<DFPattern, Expr>& matched_result) {
std::set<Expr> matched_vars;
for (const auto& [pat, match] : matched_result) {
if ((pat->IsInstance<CallPatternNode>() || pat->IsInstance<TupleGetItemPatternNode>()))
matched_vars.insert(value_to_bound_var_[match]);
}

for (const auto var : vars_in_graph) {
if (matched_vars.find(var) == matched_vars.end()) return false;
}
return true;
}

String pat_name_;
DFPattern pat_;
Map<String, DFPattern> annotation_pat_;
Expand All @@ -1191,6 +1217,7 @@ class PatternBasedPartitioner : ExprVisitor {
Map<Expr, Var> value_to_bound_var_;
Map<Var, Array<Var>> current_block_use_def_;
GroupMap group_map_;
std::map<Group*, Array<Expr>> vars_in_group_;
};

/*!
Expand Down

0 comments on commit 9d46d67

Please sign in to comment.