Skip to content

Commit

Permalink
[BugFix][Relax] change FuseOpsByPattern strategy to pattern-match max…
Browse files Browse the repository at this point in the history
…imal subgraph (#16922)

* [BugFix][Relax] change FuseOpsByPattern strategy to pattern-match maximal subgraph

* add testcase

---------

Co-authored-by: Huibin Wang <wang.huibin@intellif.com>
  • Loading branch information
wanghuibin0 and wanghuibin0 authored May 10, 2024
1 parent fffd168 commit 2565aa3
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 2 deletions.
31 changes: 29 additions & 2 deletions src/relax/transform/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1073,7 +1073,11 @@ class PatternBasedPartitioner : ExprVisitor {
current_block_use_def_ = {};
}

void VisitVarDef(const Var& var) final { group_map_[var.get()] = arena_->make<Group>(); }
void VisitVarDef(const Var& var) final {
Group* g = arena_->make<Group>();
group_map_[var.get()] = g;
vars_in_group_[g].push_back(var);
}

void VisitBinding_(const VarBindingNode* binding) final {
bindings_.Set(binding->var, binding->value);
Expand All @@ -1097,7 +1101,13 @@ class PatternBasedPartitioner : ExprVisitor {
auto g = GetGroup(match);
if (g && g->FindRoot()->num_nodes > 1) {
// This expression has already been matched to a previous pattern.
return;
// If the prior matched subgraph is subsumed by the new matched one,
// we can safely merge them, obtaining a maximized matched subgraph enventually.
// Otherwise, merging them will result in an incorrect subgraph,
// so we keep the prior subgraph and discard the current one by directly return.
auto vars_in_prior_matched_graph = vars_in_group_[g];
if (!GraphSubsumedInMatchedValues(vars_in_prior_matched_graph, matches_opt.value()))
return;
}
}
}
Expand Down Expand Up @@ -1145,6 +1155,7 @@ class PatternBasedPartitioner : ExprVisitor {
if (group_map_[e.get()] != to) {
--group_map_[e.get()]->num_nodes;
group_map_[e.get()]->parent = to;
vars_in_group_[to].push_back(e);
++to->num_nodes;
}
}
Expand Down Expand Up @@ -1181,6 +1192,21 @@ class PatternBasedPartitioner : ExprVisitor {
current_block_use_def_, value_to_bound_var_);
}

// check if a previous matched subgraph is subsumed by the current matched result
bool GraphSubsumedInMatchedValues(const Array<Expr>& vars_in_graph,
const Map<DFPattern, Expr>& matched_result) {
std::set<Expr> matched_vars;
for (const auto& [pat, match] : matched_result) {
if ((pat->IsInstance<CallPatternNode>() || pat->IsInstance<TupleGetItemPatternNode>()))
matched_vars.insert(value_to_bound_var_[match]);
}

for (const auto var : vars_in_graph) {
if (matched_vars.find(var) == matched_vars.end()) return false;
}
return true;
}

String pat_name_;
DFPattern pat_;
Map<String, DFPattern> annotation_pat_;
Expand All @@ -1191,6 +1217,7 @@ class PatternBasedPartitioner : ExprVisitor {
Map<Expr, Var> value_to_bound_var_;
Map<Var, Array<Var>> current_block_use_def_;
GroupMap group_map_;
std::map<Group*, Array<Expr>> vars_in_group_;
};

/*!
Expand Down
26 changes: 26 additions & 0 deletions tests/python/relax/test_transform_fuse_ops_by_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,5 +1217,31 @@ def inner_func(
tvm.ir.assert_structural_equal(Expected, After)


def test_match_maximal_subgraph():
@R.function
def func(
x: R.Tensor((32, 8), dtype="int32"),
y: R.Tensor((8, 8), dtype="int32"),
bias: R.Tensor((8,), dtype="int32"),
) -> R.Tensor((32, 8), dtype="int32"):
R.func_attr({"global_symbol": "main"})
with R.dataflow():
lv0 = R.matmul(x, y, out_dtype="int32")
lv1 = R.add(lv0, bias)
lv2 = R.clip(lv1, -128, 127)
R.output(lv2)
return lv2

mod = tvm.IRModule({"main": func})

matmul = is_op("relax.matmul")(wildcard(), wildcard())
matmul_add = is_op("relax.add")(matmul, wildcard())
pattern = matmul_add | is_op("relax.clip")(matmul_add, wildcard(), wildcard())

partitioned = relax.transform.FuseOpsByPattern([("orclip", pattern)])(mod)
func_names = [name.name_hint for (name, _) in partitioned.functions.items()]
assert "fused_relax_matmul_relax_add_relax_clip" in func_names


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 2565aa3

Please sign in to comment.