Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#2 from Fridge003/cinn_tmp
Browse files Browse the repository at this point in the history
  • Loading branch information
feifei-111 authored Mar 20, 2024
2 parents 384dafc + d86e15e commit bcbf191
Show file tree
Hide file tree
Showing 10 changed files with 196 additions and 74 deletions.
6 changes: 2 additions & 4 deletions paddle/cinn/frontend/cluster_ops/cluster_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

#include "paddle/cinn/frontend/cluster_ops/cluster_policy.h"

namespace cinn::frontend {
namespace cinn::frontend::cluster_ops {

class LoopAlignableClusteringPolicy final : public ClusteringPolicy {
public:
Expand Down Expand Up @@ -233,6 +233,4 @@ std::shared_ptr<ClusteringPolicy> MakeLoopAlignableClusteringPolicy(
const pir::ShapeConstraintIRAnalysis* shape_analysis) {
return std::make_shared<LoopAlignableClusteringPolicy>(shape_analysis);
}


} // namespace cinn::frontend
} // namespace cinn::frontend::cluster_ops
4 changes: 2 additions & 2 deletions paddle/cinn/frontend/cluster_ops/cluster_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include "paddle/cinn/frontend/cluster_ops/shardable_axes_provider.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"

namespace cinn::frontend {
namespace cinn::frontend::cluster_ops {

class ClusteringPolicy {
public:
Expand All @@ -44,4 +44,4 @@ class ClusteringPolicy {

std::shared_ptr<ClusteringPolicy> MakeLoopAlignableClusteringPolicy(
const pir::ShapeConstraintIRAnalysis* shape_analysis);
} // namespace cinn::frontend
} // namespace cinn::frontend::cluster_ops
5 changes: 4 additions & 1 deletion paddle/cinn/frontend/cluster_ops/clustering_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "paddle/cinn/frontend/cluster_ops/cluster_engine.h"

namespace cinn::frontend::cluster_ops {
class ClusteringEngine {
public:
ClusteringEngine(const std::vector<const pir::Operation*>& ops,
Expand Down Expand Up @@ -495,4 +496,6 @@ class ClusteringEngine {
const std::shared_ptr<ClusteringPolicy> clustering_policy_;
ShardableAxesInferer shardable_axes_inferer_;
const OpTopo op_topo_;
};
};

} // namespace cinn::frontend::cluster_ops
4 changes: 2 additions & 2 deletions paddle/cinn/frontend/cluster_ops/clustering_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include "paddle/cinn/frontend/cluster_ops/shardable_axes_provider.h"


namespace cinn::frontend {
namespace cinn::frontend::cluster_ops {

class ClusteringEngine {
public:
Expand Down Expand Up @@ -117,4 +117,4 @@ class ClusteringEngine {
const OpTopo op_topo_;
};

} // namespace cinn::frontend
} // namespace cinn::frontend::cluster_ops
14 changes: 13 additions & 1 deletion paddle/cinn/frontend/cluster_ops/common_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,17 @@ std::function<size_t(const pir::Operation*)> MakeTopoOrderFinderOfOp(
return iter->second;
};
}

std::function<bool(const pir::Operation*)> MakePredicatorIsInThisFusionOp(
const std::vector<const pir::Operation*>& ops) {
std::set<const pir::Operation*> set;
for (const pir::Operation* op : ops) {
if (!op->isa<::pir::YieldOp>()) {
set.insert(op);
}
}
return [set = std::move(set)](const pir::Operation* op) {
return set.count(op) > 0;
};
}

} // namespace cinn::frontend::cluster_ops
5 changes: 4 additions & 1 deletion paddle/cinn/frontend/cluster_ops/common_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,7 @@ struct OpTopo {
}
};

}
std::function<bool(const pir::Operation*)> MakePredicatorIsInThisFusionOp(
const std::vector<const pir::Operation*>& ops);

} // namespace cinn::frontend::cluster_ops
66 changes: 3 additions & 63 deletions paddle/cinn/frontend/cluster_ops/fusion_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,70 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

std::function<bool(const pir::Operation*)> MakePredicatorIsInThisFusionOp(
const std::vector<const pir::Operation*>& ops) {
std::set<const pir::Operation*> set;
for (const pir::Operation* op : ops) {
if (!op->isa<::pir::YieldOp>()) {
set.insert(op);
}
}
return [set = std::move(set)](const pir::Operation* op) {
return set.count(op) > 0;
};
}

std::function<bool(const pir::Operation*)> MakePredicatorIsInjectiveSource(
const OpTopo& op_topo) {
const auto& IsSource = [&](const pir::Operation* op) {
std::size_t num_inputs = 0;
op_topo.VisitInputOp(op,
[&](const pir::Operation* input) { ++num_inputs; });
return num_inputs == 0;
};

const auto starts = [&] {
std::list<const pir::Operation*> starts;
for (const auto* op : *op_topo.ops) {
if (IsSource(op)) {
starts.push_back(op);
} else {
// do nothing.
}
}
return starts;
}();

std::unordered_map<const pir::Operation*, bool> op_2_is_injective_source;

auto IsInputsAllInjectiveSource = [&](const pir::Operation* op) {
bool is_inputs_all_injective_source = true;
op_topo.VisitInputOp(op, [&](const pir::Operation* input) {
is_inputs_all_injective_source = (is_inputs_all_injective_source &&
op_2_is_injective_source.at(input));
});
return is_inputs_all_injective_source;
};
const auto VisitInput = [&](const pir::Operation* op,
const OpVisitor& DoEach) {
op_topo.VisitInputOp(op, DoEach);
};
const auto VisitOutput = [&](const pir::Operation* op,
const OpVisitor& DoEach) {
op_topo.VisitOutputOp(op, DoEach);
};
common::TopoWalker<const pir::Operation*> walker{VisitInput, VisitOutput};
walker(starts.begin(), starts.end(), [&](const pir::Operation* op) {
op_2_is_injective_source[op] =
(IsGeneralInjective(op) && IsInputsAllInjectiveSource(op));
});
return [map = std::move(op_2_is_injective_source)](const pir::Operation* op) {
const auto& iter = map.find(op);
CHECK(iter != map.end());
return iter->second;
};
}
#include "paddle/cinn/frontend/cluster_ops/fusion_helper.h"

namespace cinn::frontend::cluster_ops {
class StmtFusionHelper {
public:
StmtFusionHelper(const std::vector<const pir::Operation*>& ops,
Expand Down Expand Up @@ -524,3 +463,4 @@ class StmtFusionHelper {
std::function<bool(const pir::Operation*)> IsInjectiveSource;
std::function<size_t(const pir::Operation*)> GetOrderValue4Op;
};
} // namespace cinn::frontend::cluster_ops
111 changes: 111 additions & 0 deletions paddle/cinn/frontend/cluster_ops/fusion_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,114 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/cinn/frontend/cluster_ops/common_utils.h"
#include "paddle/cinn/frontend/cluster_ops/group_pattern.h"
#include "paddle/cinn/frontend/cluster_ops/pattern_utils.h"
#include "paddle/cinn/frontend/cluster_ops/shardable_axes_provider.h"

namespace cinn::frontend::cluster_ops {

class StmtFusionHelper {
public:
StmtFusionHelper(const std::vector<const pir::Operation*>& ops,
const ShardableAxesInferer& shardable_axes_inferer);

GroupPattern FuseToGroupPattern();

private:
std::vector<StmtPattern> ConvertToStmtsPattern();
void SortStmtPatterns(std::vector<StmtPattern>* stmt_patterns);

std::optional<ErrorGroupPattern> Fuse_IS_x_IS_2_IS(
std::vector<StmtPattern>* stmt_patterns);

std::optional<ErrorGroupPattern> Fuse_PS_x_PS_2_PS(
std::vector<StmtPattern>* stmt_patterns);

struct FusePolicy_IS_x_PS_2_PS {
static bool FuseCondition(const StmtPattern& upstream,
const StmtPattern& downstream);
static std::variant<StmtPattern, ErrorGroupPattern> MergePattern(
const StmtPattern& upstream, const StmtPattern& downstream);
static std::variant<StmtPattern, ErrorGroupPattern> MergePatternImpl(
const IS& upstream, const PS& downstream);
static ShardableAxesSignature MergeShardableAxesSignature(
const IS& upstream, const PS& downstream);
};

std::optional<ErrorGroupPattern> Fuse_IS_x_PS_2_PS(
std::vector<StmtPattern>* stmt_patterns);
struct FusePolicy_IS_x_R_2_R {
static bool FuseCondition(const StmtPattern& upstream,
const StmtPattern& downstream);
static std::variant<StmtPattern, ErrorGroupPattern> MergePattern(
const StmtPattern& upstream, const StmtPattern& downstream);
static std::variant<StmtPattern, ErrorGroupPattern> MergePatternImpl(
const IS& upstream, const R& downstream);
};

std::optional<ErrorGroupPattern> Fuse_IS_x_R_2_R(
std::vector<StmtPattern>* stmt_patterns);

struct FusePolicy_PS_x_R_2_R {
static bool FuseCondition(const StmtPattern& upstream,
const StmtPattern& downstream);
static std::variant<StmtPattern, ErrorGroupPattern> MergePattern(
const StmtPattern& upstream, const StmtPattern& downstream);
static std::variant<StmtPattern, ErrorGroupPattern> MergePatternImpl(
const PS& upstream, const R& downstream);
};

std::optional<ErrorGroupPattern> Fuse_PS_x_R_2_R(
std::vector<StmtPattern>* stmt_patterns);
StmtPattern ConvertToStmtPattern(const pir::Operation* op);

IS ConvertToIS(const pir::Operation* op);

R ConvertReductionOpToReductionPattern(const pir::Operation* op);

PS ConvertOpToPS(const pir::Operation* op);
using StmtPtr4OpT =
std::function<std::optional<StmtPattern*>(const pir::Operation*)>;
static StmtPtr4OpT MakeStmtFinderFromOp(std::vector<StmtPattern>* stmts);

template <typename IsChozenPatternT, typename ConstructPatternT>
std::optional<ErrorGroupPattern> MultiFuse(
const IsChozenPatternT& IsChozenPattern,
const ConstructPatternT& ConstructPattern,
std::vector<StmtPattern>* stmts);

struct StmtIterPair {
std::list<StmtPattern*>::iterator upstream_iter;
std::list<StmtPattern*>::iterator downstream_iter;
};

bool IsConnected(const StmtPtr4OpT& StmtFinder,
const StmtPattern* upstream,
const StmtPattern* downstream);

template <typename FuseTargetConditionT>
std::optional<StmtIterPair> FindConnetedPattenPairWithCondition(
const StmtPtr4OpT& StmtFinder,
std::list<StmtPattern*>* stmt_ptrs,
const FuseTargetConditionT& FuseTargetCondition);

template <typename FusionPolicy>
std::optional<ErrorGroupPattern> FuseFilteredStmtPatterns(
std::vector<StmtPattern>* stmt_patterns)

ShardableAxesSignature GetShardableAxesSignature(const OpTopo& op_topo)

private:
std::vector<const pir::Operation*> ops_;
ShardableAxesInferer shardable_axes_inferer_;
OpTopo op_topo_;
std::function<bool(const pir::Operation*)> IsInThisOpList;
std::function<bool(const pir::Operation*)> IsInjectiveSource;
std::function<size_t(const pir::Operation*)> GetOrderValue4Op;
};

} // namespace cinn::frontend::cluster_ops
52 changes: 52 additions & 0 deletions paddle/cinn/frontend/cluster_ops/pattern_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,4 +180,56 @@ auto VisitCachedOutput = [stmt2outputs](const auto* stmt,
};
return common::TopoWalker<const StmtPattern*>(VisitCachedInput,
VisitCachedOutput);

}

std::function<bool(const pir::Operation*)> MakePredicatorIsInjectiveSource(
const OpTopo& op_topo) {
const auto& IsSource = [&](const pir::Operation* op) {
std::size_t num_inputs = 0;
op_topo.VisitInputOp(op,
[&](const pir::Operation* input) { ++num_inputs; });
return num_inputs == 0;
};

const auto starts = [&] {
std::list<const pir::Operation*> starts;
for (const auto* op : *op_topo.ops) {
if (IsSource(op)) {
starts.push_back(op);
} else {
// do nothing.
}
}
return starts;
}();

std::unordered_map<const pir::Operation*, bool> op_2_is_injective_source;

auto IsInputsAllInjectiveSource = [&](const pir::Operation* op) {
bool is_inputs_all_injective_source = true;
op_topo.VisitInputOp(op, [&](const pir::Operation* input) {
is_inputs_all_injective_source = (is_inputs_all_injective_source &&
op_2_is_injective_source.at(input));
});
return is_inputs_all_injective_source;
};
const auto VisitInput = [&](const pir::Operation* op,
const OpVisitor& DoEach) {
op_topo.VisitInputOp(op, DoEach);
};
const auto VisitOutput = [&](const pir::Operation* op,
const OpVisitor& DoEach) {
op_topo.VisitOutputOp(op, DoEach);
};
common::TopoWalker<const pir::Operation*> walker{VisitInput, VisitOutput};
walker(starts.begin(), starts.end(), [&](const pir::Operation* op) {
op_2_is_injective_source[op] =
(IsGeneralInjective(op) && IsInputsAllInjectiveSource(op));
});
return [map = std::move(op_2_is_injective_source)](const pir::Operation* op) {
const auto& iter = map.find(op);
CHECK(iter != map.end());
return iter->second;
};
}
3 changes: 3 additions & 0 deletions paddle/cinn/frontend/cluster_ops/pattern_utils.h
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
common::TopoWalker<const StmtPattern*> MakeTopoWalker(
const OpTopo& op_topo, const std::vector<StmtPattern>& stmt_patterns);

std::function<bool(const pir::Operation*)> MakePredicatorIsInjectiveSource(
const OpTopo& op_topo);

0 comments on commit bcbf191

Please sign in to comment.