Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity] Fix FuseOpsByPattern when a subgraph can be matched by multiple residual patterns #15308

Merged
merged 1 commit into from
Jul 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/relax/transform/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,18 @@ class PatternBasedPartitioner : ExprVisitor {
if (check_ != nullptr && !check_(CreatePatternCheckContext(call, matches_opt.value()))) {
return;
}

for (const auto& [pat, match] : matches_opt.value()) {
if ((pat->IsInstance<CallPatternNode>() && match != GetRef<Call>(call)) ||
pat->IsInstance<TupleGetItemPatternNode>()) {
auto g = GetGroup(match);
if (g && g->FindRoot()->num_nodes > 1) {
// This expression has already been matched to a previous pattern.
return;
}
}
}

// If a match is found, put all matching expressions into the same group.
// OperatorFusor also requires that the bound variable be in the same group as the RHS value.
// Since is_op(...) based pattern only matches against call nodes on the right hand side,
Expand Down Expand Up @@ -1108,6 +1120,13 @@ class PatternBasedPartitioner : ExprVisitor {
return group_map_[bound_var.get()]->FindRoot();
}

Group* GetGroup(const Expr& exp) {
if (value_to_bound_var_.count(exp) && group_map_.count(value_to_bound_var_[exp].get())) {
return group_map_[value_to_bound_var_[exp].get()];
}
return nullptr;
}

PatternCheckContext CreatePatternCheckContext(const CallNode* call,
const Map<DFPattern, Expr>& matched_result) {
Map<String, Expr> annotated_expr;
Expand Down
27 changes: 27 additions & 0 deletions tests/python/relax/test_transform_fuse_ops_by_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
wildcard,
)
from tvm.relax.transform import PatternCheckContext
from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
Expand Down Expand Up @@ -894,5 +895,31 @@ def main(
check(mod, [("x.clip", pat_clip)], Expected2)


def test_matmul_add3():
@I.ir_module
class Module:
@R.function
def main(
x: R.Tensor((32, 8), dtype="float16"),
y: R.Tensor((8, 8), dtype="float16"),
x2: R.Tensor((32, 8), dtype="float16"),
y2: R.Tensor((8, 8), dtype="float16"),
bias: R.Tensor((8,), dtype="float16"),
residual: R.Tensor((32, 8), dtype="float16"),
) -> R.Tensor((32, 8), dtype="float16"):
with R.dataflow():
lv_: R.Tensor((32, 8), dtype="float16") = R.matmul(x2, y2, out_dtype="float16")
lv: R.Tensor((32, 8), dtype="float16") = R.matmul(x, y, out_dtype="float16")
lv1: R.Tensor((32, 8), dtype="float16") = R.add(lv, bias)
lv2: R.Tensor((32, 8), dtype="float16") = R.add(lv1, lv_)
out: R.Tensor((32, 8), dtype="float16") = R.add(lv2, residual)
R.output(out)
return out

mod = partition_for_cutlass(Module)
func_names = [name.name_hint for (name, _) in mod.functions.items()]
assert "fused_relax_matmul_relax_add_relax_add_cutlass" in func_names


if __name__ == "__main__":
pytest.main([__file__])