[CINN / PIR] Cinn trivalop fuse (#62088)
* implement FuseFilteredStmtPatterns

* update

* split trivial op into a single file.

* fix compiler complaints

* rename StmtIter to StmtPtr

* declare group_pattern.InferShardableAxes

* refine signature of group_pattern.InferShardableAxes

* move group_pattern.InferShardableAxes to group_pattern_util.InferShardableAxes

* implement group_pattern_util.InferShardableAxes

* add group_pattern_util.InferShardableAxesFromSink

* ReversedInferShardableAxes support sinks

* update op lower

* support multiple sinks in group_pattern_util.InferShardableAxes

* update

* fix link error

* update

* remove FusionOp to OpList

* update

* update

* update

* update

* declare group_pattern_util.h

* fix compiler complaints

* declare group_pattern_util.ClusteringHelper

* refine signature of group_pattern_util.ClusterIntoGroupPatternsFromOpList

* update op lower

* add todo

* minor refine by group_pattern_util.OpSet

* update

* update

* update (#57)

* update

* update

* Cinn trivalop fuse (#58)

* fix

* refactor StmtFusionHelper by OpTopo

* Complete: CreateReduceExpr function.

* update

* recursive done.

* update

* Cinn trivalop fuse (#59)

* clean up all the TODOs.

* update

* fix cluster

* remove unused OpTopo.downstream_disconnected_ops

* Cinn trivalop fuse (#60)

* fix compile error

* update

* Cinn trivalop fuse (#61)

* add R + T skeleton

* add search utils.

* update

* Cinn trivalop fuse (#62)

* push

* update

* fix

* fix transformer

* fix

* Implement iterator vars fetching in ReduceOp

* small fix

* add GetOuterIterVars API

* fix

* fix compiler complaint

* modify GetOutputIters of TrivialOp

* remove duplicate code in visit

* implement ClusterIntoGroupPatternsFromOpList

* Fix most error in trivial_op.cc.

* CreateReduceExpr is OK!

* fix

* add CheckIterEq

* implement group_pattern_util.ClusteringEngine and group_pattern_util.ClusteringPolicy

* SinkTrivialTransform OK!

* update

* fix init_tensor name problem.

* update

* fix compiler complaints

* refactor ShardableAxesSignature by group_pattern.SoleOutputShardableAxes

* split trivial_op.cc

* update

* implement group_pattern_util.MakeShardableAxesSignature4ReduceOp

* update

* implement group_pattern_util.MakeEmptyShardableAxesSignature

* add helper class group_pattern_util.ShardableAxesProvider

* implement group_pattern_util.MakeShardableAxesSignature4BroadcastOp

* update

* update

* fix softmax error

* fix

* update

* merge

* fix

* Implement new OpMergeWithOp and add a relevant flag

* update

* update

* fix reduce_load error. add splitReduceTransform

* fix conflict

* update

* update

* update

* disable horizontal fusion

* fix

* Add some VLOG

* Fix group cluster bug (#71)

* fix

* fix dyshape

* fix

* init split cluster files

* update

* update

* update

* splitting

* update

* splitting

* splitting

* pattern utils

* update

* update

* clean cmake

* update

* update

* update

* fix clustering_engine

* fix fusion_helper

* update

* fix

* update

* update

* update

* update

* fix

* fix some errors

* update

* update

* fix split with num problem

* update

* fix

* fix static issues

* fix

* init split cluster files (#72)

* update

* update

* update

* update

* update

* update

* update

* update

* update

* split shardable axes provider (#73)

* update

* update

* fix broadcast (#75)

* update

* update

* fix

* fix code format

* fix code format

* remove unittest

* update

* update (#77)

* update

* update

* update

---------

Co-authored-by: tc20042008 <156998525+tc20042008@users.noreply.github.com>
Co-authored-by: feifei-111 <2364819892@qq.com>
Co-authored-by: jiahy0825 <jiahongyu@baidu.com>
Co-authored-by: zhangbaizhou <zhangbaizhou@baidu.com>
Co-authored-by: Baizhou Zhang <eddiezhang@pku.edu.cn>
6 people authored Mar 26, 2024
1 parent c3f5747 commit fec0b3d
Showing 46 changed files with 3,198 additions and 109 deletions.
77 changes: 77 additions & 0 deletions paddle/cinn/api/op_topo_pattern.h
@@ -0,0 +1,77 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <list>
#include <variant>
#include <vector>

namespace cinn::api {

template <typename T>
struct ErrorPattern {};

// ElementWise/Broadcast/Injective Ops without reduction ancestors.
template <typename T>
struct InjectiveSourcePattern {};

// Reduce op
template <typename T>
struct SingleReductionOpPattern {};

// ElementWise/Broadcast ops which have shardable dimensions and reduction
// ancestors.
template <typename T>
struct PartialShardablePattern {};

// Reduce-based pattern
template <typename T>
struct ReductionPattern {
  using Nothing = std::monostate;
  std::variant<Nothing, InjectiveSourcePattern<T>, PartialShardablePattern<T>>
      input;
  SingleReductionOpPattern<T> reduce_op_pattern;

  bool HasFusedInput() const {
    return !std::holds_alternative<Nothing>(this->input);
  }
};

// Stmt := IS | R | PS
// Ops in a StmtPattern will be lowered into inlined CUDA code.
template <typename T>
using StmtPattern = std::variant<InjectiveSourcePattern<T>,
                                 ReductionPattern<T>,
                                 PartialShardablePattern<T>>;

// Stmts := [Stmt]
template <typename T>
using StmtPatternVec = std::vector<StmtPattern<T>>;
// fuse rules:
// 1. IS * IS -> IS
// 2. PS * PS -> PS
// 3. IS * PS -> PS
// 4. IS * R -> R
// 5. PS * R -> R
// lifting rules:
// 1. R -> Stmts
// 2. PS -> Stmts
// 3. Stmts * Stmts -> Stmts
// OpTopoPattern := Error | Stmts

template <typename T>
using OpTopoPattern = std::variant<ErrorPattern<T>, StmtPatternVec<T>>;

} // namespace cinn::api
23 changes: 1 addition & 22 deletions paddle/cinn/ast_gen_ius/ast_gen.cc
@@ -100,13 +100,6 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
   const std::vector<ir::Var>& reduce_axis = tensor->reduce_axis;
   VLOG(4) << "ast gen: tensor init_body is " << init_body;
   for (int i = 0; i < shape.size(); ++i) {
-    bool is_keep_dim = axis[i]->is_keepdim;
-    if (FLAGS_group_schedule_tiling_first && is_keep_dim) {
-      // if tiling first, we need to replace the reduce axis with 0, but don't
-      // deal with the non-reduce axis
-      optim::ReplaceVarWithExpr(&init_body, axis[i], Expr(0));
-      continue;
-    }
     if (!FLAGS_group_schedule_tiling_first &&
         FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
       optim::ReplaceVarWithExpr(&init_body, axis[i], Expr(0));
@@ -144,13 +137,6 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
   // for same axis so we re-create objects
   std::vector<Var> reduce_axis_vars = cinn::common::GenDefaultAxis(axis_len);
   for (int i = 0; i < shape.size(); ++i) {
-    bool is_keep_dim = axis[i]->is_keepdim;
-    if (FLAGS_group_schedule_tiling_first && is_keep_dim) {
-      // if tiling first, we need to replace the reduce axis with 0, but don't
-      // deal with the non-reduce axis
-      optim::ReplaceVarWithExpr(&reduce_body, axis[i], Expr(0));
-      continue;
-    }
     if (!FLAGS_group_schedule_tiling_first &&
         FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) {
       optim::ReplaceVarWithExpr(&reduce_body, axis[i], Expr(0));
@@ -185,10 +171,7 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
   std::vector<ir::Var> non_reduce_axis_vars = [&]() {
     std::vector<ir::Var> res;
     for (int i = 0; i < shape.size(); ++i) {
-      bool is_keep_dim = axis[i]->is_keepdim;
-      if (!is_keep_dim) {
-        res.push_back(axis[i]);
-      }
+      res.push_back(axis[i]);
     }
     return res;
   }();
@@ -240,10 +223,6 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
   // Put the two parts together
   ir::Expr body = ir::Block::Make({init_body, reduce_body});
   for (int i = static_cast<int>(axis_len) - 1; i >= 0; --i) {
-    bool is_keep_dim = axis[i]->is_keepdim;
-    if (FLAGS_group_schedule_tiling_first && is_keep_dim) {
-      continue;
-    }
     if ((!FLAGS_group_schedule_tiling_first || !FLAGS_cinn_bucket_compile) &&
         shape[i] == Expr(1)) {
       continue;
1 change: 1 addition & 0 deletions paddle/cinn/backends/codegen_cuda_util.cc
@@ -78,6 +78,7 @@ detail::CollectBucketStrategyHostFunctionVisitor::GenDeviceKernelName(
 
 void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
     ir::Expr func, ir::Expr predicate) {
+  VLOG(4) << "Process Lowered Func" << func;
   ir::_LoweredFunc_ *func_node = func.as_lowered_func();
   CHECK(func_node);
   if (!func_node->cuda_axis_info.valid()) {
1 change: 1 addition & 0 deletions paddle/cinn/frontend/CMakeLists.txt
@@ -62,6 +62,7 @@ add_subdirectory(paddle)
 add_subdirectory(decomposer)
 add_subdirectory(op_mappers)
 add_subdirectory(pass)
+add_subdirectory(group_cluster)
 
 cinn_cc_test(test_op_mapper_registry SRCS op_mapper_registry_test.cc DEPS
              cinncore)
6 changes: 6 additions & 0 deletions paddle/cinn/frontend/group_cluster/CMakeLists.txt
@@ -0,0 +1,6 @@
gather_srcs(group_cluster_src SRCS common_utils.cc pattern_node.cc
            pattern_graph.cc)

add_subdirectory(cluster_policy)

cc_library(group_cluster SRCS ${group_cluster_src})
3 changes: 3 additions & 0 deletions paddle/cinn/frontend/group_cluster/cluster_policy/CMakeLists.txt
@@ -0,0 +1,3 @@
gather_srcs(group_cluster_src SRCS general_topo_policy.cc policy_manager.cc)

add_subdirectory(shardable_axes_policy)
25 changes: 25 additions & 0 deletions paddle/cinn/frontend/group_cluster/cluster_policy/general_topo_policy.cc
@@ -0,0 +1,25 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/frontend/group_cluster/cluster_policy/general_topo_policy.h"

namespace cinn::frontend::group_cluster::policy {

bool GeneralTopoPolicy::CanFuse(const PatternNodePtr upstream,
                                const PatternNodePtr downstream) {
  // TODO(wuzhanfei) topo policy (if lead to loop)
  return false;
}

} // namespace cinn::frontend::group_cluster::policy
25 changes: 25 additions & 0 deletions paddle/cinn/frontend/group_cluster/cluster_policy/general_topo_policy.h
@@ -0,0 +1,25 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "paddle/cinn/frontend/group_cluster/cluster_policy/policy_manager.h"

namespace cinn::frontend::group_cluster::policy {

class GeneralTopoPolicy final : virtual public Policy {
 public:
  bool CanFuse(const PatternNodePtr upstream, const PatternNodePtr downstream);
};

} // namespace cinn::frontend::group_cluster::policy
28 changes: 28 additions & 0 deletions paddle/cinn/frontend/group_cluster/cluster_policy/policy_manager.cc
@@ -0,0 +1,28 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/frontend/group_cluster/cluster_policy/policy_manager.h"
#include "paddle/common/enforce.h"

namespace cinn::frontend::group_cluster::policy {

bool PolicyManager::CanFuse(const PatternNodePtr upstream,
                            const PatternNodePtr downstream) {
  for (const auto& policy : policies_) {
    if (!policy->CanFuse(upstream, downstream)) return false;
  }
  return true;
}

} // namespace cinn::frontend::group_cluster::policy
39 changes: 39 additions & 0 deletions paddle/cinn/frontend/group_cluster/cluster_policy/policy_manager.h
@@ -0,0 +1,39 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/cinn/frontend/group_cluster/pattern_node.h"

namespace cinn::frontend::group_cluster::policy {

class Policy {
 public:
  virtual bool CanFuse(const PatternNodePtr upstream,
                       const PatternNodePtr downstream) = 0;
};

using PolicyPtr = std::shared_ptr<Policy>;

class PolicyManager {
 public:
  explicit PolicyManager(const std::vector<PolicyPtr>& policies)
      : policies_(policies) {}
  bool CanFuse(const PatternNodePtr upstream, const PatternNodePtr downstream);

 private:
  std::vector<PolicyPtr> policies_;
};

} // namespace cinn::frontend::group_cluster::policy
2 changes: 2 additions & 0 deletions paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/CMakeLists.txt
@@ -0,0 +1,2 @@
gather_srcs(group_cluster_src SRCS shardable_axes_base.cc
            shardable_axes_policy.cc)