From 826809a291054b6281f01e47db2b5b4b0e187695 Mon Sep 17 00:00:00 2001 From: jiahy0825 Date: Wed, 6 Mar 2024 08:12:59 +0000 Subject: [PATCH 1/3] redefine OpTopoPattern --- .../{ops_topo_pattern.h => op_topo_pattern.h} | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) rename paddle/cinn/api/{ops_topo_pattern.h => op_topo_pattern.h} (59%) diff --git a/paddle/cinn/api/ops_topo_pattern.h b/paddle/cinn/api/op_topo_pattern.h similarity index 59% rename from paddle/cinn/api/ops_topo_pattern.h rename to paddle/cinn/api/op_topo_pattern.h index 88d4084ec10c5..fe2ac78d36e16 100644 --- a/paddle/cinn/api/ops_topo_pattern.h +++ b/paddle/cinn/api/op_topo_pattern.h @@ -16,18 +16,22 @@ struct ReductionPattern {}; template struct PartialShardablePattern {}; -// SR := [R | PS] template -using ShardableReductionsPattern = std::vector, PartialShardablePattern>>; +using ShardableReductionPattern = std::vector, PartialShardablePattern>>; // fuse rules: // 1. IS * PS -> PS // 2. PS * PS -> PS -// 3. R * PS -> RS -// 4. RS * (PS | R) -> RS +// 3. PS * R -> R +// 4. IS * R -> R -// OpsTopoPattern := IS | SR +// lifting rules: +// 1. R -> SR +// 2. PS -> SR +// 3. SR * SR -> SR + +// OpTopoPattern := IS | SR template -using OpsTopoPattern = std::variant, ShardableReductionsPattern>; +using OpTopoPattern = std::variant, ShardableReductionPattern>; } From 918095c037a3c24533da8fb542e9df64e0015d58 Mon Sep 17 00:00:00 2001 From: jiahy0825 Date: Wed, 6 Mar 2024 08:14:22 +0000 Subject: [PATCH 2/3] fix typo --- paddle/cinn/api/op_topo_pattern.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/cinn/api/op_topo_pattern.h b/paddle/cinn/api/op_topo_pattern.h index fe2ac78d36e16..47c7f2b225fec 100644 --- a/paddle/cinn/api/op_topo_pattern.h +++ b/paddle/cinn/api/op_topo_pattern.h @@ -17,7 +17,7 @@ template struct PartialShardablePattern {}; template -using ShardableReductionPattern = std::vector, PartialShardablePattern>>; +using ShardableReductionsPattern = std::vector, PartialShardablePattern>>; // fuse rules: // 1. IS * PS -> PS @@ -32,6 +32,6 @@ using ShardableReductionPattern = std::vector, // OpTopoPattern := IS | SR template -using OpTopoPattern = std::variant, ShardableReductionPattern>; +using OpTopoPattern = std::variant, ShardableReductionsPattern>; } From 7731441dcba3fc38e863ecbd1b03ead6a22e8fc0 Mon Sep 17 00:00:00 2001 From: jiahy0825 Date: Wed, 6 Mar 2024 08:15:45 +0000 Subject: [PATCH 3/3] add comments for SR --- paddle/cinn/api/op_topo_pattern.h | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/cinn/api/op_topo_pattern.h b/paddle/cinn/api/op_topo_pattern.h index 47c7f2b225fec..8febb35a20e6e 100644 --- a/paddle/cinn/api/op_topo_pattern.h +++ b/paddle/cinn/api/op_topo_pattern.h @@ -16,6 +16,7 @@ struct ReductionPattern {}; template struct PartialShardablePattern {}; +// SR := [R | PS] template using ShardableReductionsPattern = std::vector, PartialShardablePattern>>;