Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pir] Add conv fuse pass #58252

Closed
wants to merge 14 commits into from
2 changes: 1 addition & 1 deletion paddle/fluid/pir/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ file(GLOB FUSION_PASS_SRCS "fusion/*.cc")
cc_library(
fusion_passes
SRCS ${FUSION_PASS_SRCS}
DEPS drr)
DEPS drr transform_general_functions)

cc_library(
transform_general_functions
Expand Down
192 changes: 192 additions & 0 deletions paddle/fluid/pir/transforms/fusion/conv2d_fuse_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/builtin_attribute.h"
#include "paddle/pir/core/builtin_dialect.h"
#include "paddle/pir/core/ir_context.h"
#include "paddle/pir/core/op_info.h"
#include "paddle/pir/core/parameter.h"
#include "paddle/pir/core/program.h"
#include "paddle/pir/core/value.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_manager.h"
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/pir/pattern_rewrite/pattern_applicator.h"
#include "paddle/pir/pattern_rewrite/pattern_match.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"

#include "paddle/fluid/pir/transforms/fusion/conv2d_fuse_pass.h"
#include "paddle/fluid/pir/transforms/transform_general_functions.h"

#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/phi/core/ddim.h"

namespace {

class Conv2dBnFusePattern
Xinyu302 marked this conversation as resolved.
Show resolved Hide resolved
: public pir::OpRewritePattern<paddle::dialect::BatchNormOp> {
public:
using pir::OpRewritePattern<paddle::dialect::BatchNormOp>::OpRewritePattern;
bool MatchAndRewrite(
paddle::dialect::BatchNormOp op,
pir::PatternRewriter &rewriter) const override { // NOLINT
// The next op should be batch_norm.
paddle::dialect::Conv2dOp conv2d_op =
pir::GetDefiningOpForInput(op, 0)
->dyn_cast<paddle::dialect::Conv2dOp>();
if (!conv2d_op) return false;

pir::OpResult conv2d_out = conv2d_op.out();
if (!conv2d_out.HasOneUse()) return false;

pir::Value conv2d_filter = conv2d_op.filter();

// pir::GetParameterOp filter_parameter_op =
// conv2d_filter.GetDefiningOp()->dyn_cast<pir::GetParameterOp>();
// if (!filter_parameter_op) return false;

pir::OpResult conv2d_filter_result =
conv2d_filter.dyn_cast<pir::OpResult>();
IR_ENFORCE(conv2d_filter_result);

pir::Value bn_input = op.x();
IR_ENFORCE(bn_input == conv2d_out);

pir::Value bn_mean = op.mean();
pir::Value bn_variance = op.variance();
pir::Value bn_scale = op.scale();
pir::Value bn_bias = op.bias();

// --- deal with filter ---
rewriter.SetInsertionPoint(op);
phi::DDim bn_variance_shape =
bn_variance.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
float epsilon = op.attribute<pir::FloatAttribute>("epsilon").data();
paddle::dialect::FullOp full_op = rewriter.Build<paddle::dialect::FullOp>(
phi::vectorize(bn_variance_shape), epsilon);
paddle::dialect::AddOp add_op = rewriter.Build<paddle::dialect::AddOp>(
bn_variance.dyn_cast<pir::OpResult>(), full_op.out());
paddle::dialect::SqrtOp sqrt_op =
rewriter.Build<paddle::dialect::SqrtOp>(add_op.out());
paddle::dialect::DivideOp div_op =
rewriter.Build<paddle::dialect::DivideOp>(
bn_scale.dyn_cast<pir::OpResult>(), sqrt_op.out());
// reshape scale
phi::DDim conv2d_filter_shape = pir::GetShapeFromValue(conv2d_filter);
phi::DDim bn_scale_shape =
bn_scale.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
std::vector<int64_t> bn_scale_new_shape(conv2d_filter_shape.size(), 1);
bn_scale_new_shape[0] = bn_scale_shape[0];
paddle::dialect::ReshapeOp reshape_scale_op =
rewriter.Build<paddle::dialect::ReshapeOp>(div_op.out(),
bn_scale_new_shape);
// new filter --> mul_op.out()
paddle::dialect::MultiplyOp mul_op =
rewriter.Build<paddle::dialect::MultiplyOp>(conv2d_filter_result,
reshape_scale_op.out());

auto conv2d_attributes = conv2d_op->attributes();
auto new_conv2d_op = rewriter.Build<paddle::dialect::Conv2dOp>(
conv2d_op.input().dyn_cast<pir::OpResult>(),
mul_op.out(),
conv2d_attributes);

// --- deal with bias ---
paddle::dialect::MultiplyOp mul_bias_op =
rewriter.Build<paddle::dialect::MultiplyOp>(
bn_mean.dyn_cast<pir::OpResult>(), div_op.out());
// new bias --> sub_op.out()
paddle::dialect::SubtractOp sub_op =
rewriter.Build<paddle::dialect::SubtractOp>(
bn_bias.dyn_cast<pir::OpResult>(), mul_bias_op.out());
// reshape new bias
phi::DDim new_conv2d_out_shape =
pir::GetShapeFromValue(new_conv2d_op.out());
std::vector<int64_t> new_bias_new_shape(new_conv2d_out_shape.size(), 1);
std::string data_format =
new_conv2d_op.attribute<pir::StrAttribute>("data_format").AsString();
IR_ENFORCE(data_format == "NCHW", "Only support NCHW now.");
new_bias_new_shape[1] = new_conv2d_out_shape[1];
paddle::dialect::ReshapeOp reshape_bias_op =
rewriter.Build<paddle::dialect::ReshapeOp>(sub_op.out(),
new_bias_new_shape);
paddle::dialect::AddOp add_bias_op = rewriter.Build<paddle::dialect::AddOp>(
new_conv2d_op.out(), reshape_bias_op.out());

rewriter.ReplaceAllUsesWith(op.out(), add_bias_op.out());

rewriter.EraseOp(op);
rewriter.EraseOp(conv2d_op);
return true;
}
};

// Pass wrapper that applies Conv2dBnFusePattern greedily over a module.
// Opt-level 2; runs top-down for at most 10 iterations per invocation.
class Conv2dFusePass : public pir::Pass {
 public:
  Conv2dFusePass() : pir::Pass("Conv2dFusePass", 2) {}

  bool Initialize(pir::IrContext *context) override {
    pir::RewritePatternSet ps(context);
    // Benefit 1; the listed ops are the ones the pattern may generate,
    // used by the driver for legality/bookkeeping.
    auto conv_bn_pattern = std::make_unique<Conv2dBnFusePattern>(
        context,
        1,
        std::vector<std::string>{paddle::dialect::FullOp::name(),
                                 paddle::dialect::AddOp::name(),
                                 paddle::dialect::SqrtOp::name(),
                                 paddle::dialect::DivideOp::name(),
                                 paddle::dialect::ReshapeOp::name(),
                                 paddle::dialect::MultiplyOp::name(),
                                 paddle::dialect::SubtractOp::name(),
                                 paddle::dialect::Conv2dOp::name()});
    VLOG(4) << "Conv2dBnFusePattern will generate the following operations: ";
    // const& avoids copying each OpInfo per iteration
    // (clang-tidy: performance-for-range-copy).
    for (const auto &op_info : conv_bn_pattern->generated_ops()) {
      VLOG(4) << "--- " << op_info.name();
    }
    ps.Add(std::move(conv_bn_pattern));
    patterns_ = pir::FrozenRewritePatternSet(std::move(ps));
    return true;
  }

  void Run(pir::Operation *op) override {
    pir::GreedyRewriteConfig cfg;
    cfg.use_top_down_traversal = true;
    cfg.max_iterations = 10;
    pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);
  }

  // Only applicable to a ModuleOp that actually owns a region to rewrite.
  bool CanApplyOn(pir::Operation *op) const override {
    return op->isa<::pir::ModuleOp>() && op->num_regions() > 0;
  }

 private:
  pir::FrozenRewritePatternSet patterns_;
};

} // namespace

namespace pir {

// Factory for the conv2d + batch_norm fusion pass defined above.
std::unique_ptr<Pass> CreateConv2dFusePass() {
  std::unique_ptr<Pass> pass = std::make_unique<Conv2dFusePass>();
  return pass;
}

} // namespace pir

REGISTER_IR_PASS(conv2d_fuse_pass, Conv2dFusePass);
26 changes: 26 additions & 0 deletions paddle/fluid/pir/transforms/fusion/conv2d_fuse_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include "paddle/pir/core/dll_decl.h"

namespace pir {

class Pass;

/// Creates the conv2d + batch_norm fusion pass.
/// Registered under the name "conv2d_fuse_pass"; the returned pass rewrites
/// conv2d followed by batch_norm into conv2d with a rescaled filter plus a
/// folded bias add.
IR_API std::unique_ptr<Pass> CreateConv2dFusePass();

}  // namespace pir
4 changes: 2 additions & 2 deletions test/cpp/pir/pattern_rewrite/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
set(PATTERN_REWRITE_TEST_DEPS
pd_constant_folding_pass transform_general_functions gtest pd_op_dialect
pir)
pd_constant_folding_pass fusion_passes transform_general_functions gtest
pd_op_dialect pir)

if(WITH_DISTRIBUTE)
set(PATTERN_REWRITE_TEST_DEPS ${PATTERN_REWRITE_TEST_DEPS} fleet_executor
Expand Down
Loading