Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#67 from tc20042008/xk-cinn-trivalop-fuse
Browse files Browse the repository at this point in the history
Xk cinn trivalop fuse
  • Loading branch information
tc20042008 authored Mar 14, 2024
2 parents 7bd854e + 23e8341 commit 5180b55
Show file tree
Hide file tree
Showing 7 changed files with 1,110 additions and 188 deletions.
3 changes: 2 additions & 1 deletion paddle/cinn/api/op_topo_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ struct ReductionPattern {

using Nothing = std::monostate;
std::variant<Nothing, InjectiveSourcePattern<T>, PartialShardablePattern<T>> input;
SingleReductionOpPattern<T> reduction_op_pattern;
SingleReductionOpPattern<T> reduce_op_pattern;

bool HasFusedInput() const {
return !std::holds_alternative<Nothing>(this->input);
Expand All @@ -52,6 +52,7 @@ using StmtsPattern = std::vector<StmtPattern<T>>;
// 2. PS -> Stmts
// 3. Stmts * Stmts -> Stmts
// OpTopoPattern := Error | Stmts

template <typename T>
using OpTopoPattern = std::variant<ErrorPattern<T>, StmtsPattern<T>>;

Expand Down
19 changes: 7 additions & 12 deletions paddle/cinn/frontend/group_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ struct ErrorPattern<frontend::FrontendPattern> {
template<>
struct InjectiveSourcePattern<frontend::FrontendPattern> {
std::vector<const pir::Operation*> ops;
const pir::Operation* sole_sink;
};

template<>
Expand All @@ -136,6 +137,7 @@ struct SingleReductionOpPattern<frontend::FrontendPattern> {
template<>
struct PartialShardablePattern<frontend::FrontendPattern> {
std::vector<const pir::Operation*> ops;
const pir::Operation* sole_sink;
frontend::ShardableAxesSignature shardable_axes_signature;
};

Expand All @@ -146,19 +148,12 @@ namespace cinn::frontend {
using ErrorGroupPattern = api::ErrorPattern<frontend::FrontendPattern>;
using GroupPattern = api::OpTopoPattern<frontend::FrontendPattern>;

template <typename T>
struct PatternBranches {
using LhsLessThanRhs = adt::LT<symbol::DimExpr, symbol::DimExpr>;
using LhsGreaterEqualRhs = adt::GE<symbol::DimExpr, symbol::DimExpr>;
using Condition = std::variant<LhsLessThanRhs, LhsGreaterEqualRhs>;

Condition condition;
adt::List<T> true_branch;
adt::List<T> false_branch;
struct LoopAlignableStmtsPattern {
std::vector<api::StmtPattern<frontend::FrontendPattern>> stmts;
};

// ConditionalGroupPattern = GroupPatternBranches | GroupPattern
using ConditionalGroupPattern = adt::Tree<PatternBranches, GroupPattern>;
using GroupPatternBranches = PatternBranches<ConditionalGroupPattern>;
struct ClusteringResult {
std::vector<LoopAlignableStmtsPattern> loop_alignable_list;
};

}
Loading

0 comments on commit 5180b55

Please sign in to comment.