【pir】add fused_linear_param_grad_add_pass (#58401)
* fused_linear_param_grad_add_pass

* modify

* skip CPU CI

* add multi_precision modifications
xiaoguoguo626807 authored Oct 26, 2023
1 parent e99885f commit 6f7cd08
Showing 9 changed files with 724 additions and 38 deletions.
20 changes: 10 additions & 10 deletions paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc
@@ -13,7 +13,7 @@
// limitations under the License.

#include "paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h"

#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"
@@ -26,21 +26,21 @@ class FusedDropoutAddPattern
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
pir::drr::SourcePattern pat = ctx->SourcePattern();
const auto &dropout = pat.Op("pd_op.dropout",
const auto &dropout = pat.Op(paddle::dialect::DropoutOp::name(),
{{"p", pat.Attr("p")},
{"is_test", pat.Attr("is_test")},
{"mode", pat.Attr("mod")},
{"seed", pat.Attr("seed")},
{"fix_seed", pat.Attr("fix_seed")}});
const auto &add = pat.Op("pd_op.add");
const auto &add = pat.Op(paddle::dialect::AddOp::name());

dropout({&pat.Tensor("x"), &pat.Tensor("seed_tensor")},
{&pat.Tensor("dropout_out"), &pat.Tensor("mask")});
pat.Tensor("add_out") = add(pat.Tensor("dropout_out"), pat.Tensor("y"));

pir::drr::ResultPattern res = pat.ResultPattern();
const auto &fused_dropout_add =
res.Op("pd_op.fused_dropout_add",
res.Op(paddle::dialect::FusedDropoutAddOp::name(),
{{{"p", pat.Attr("p")},
{"is_test", pat.Attr("is_test")},
{"mode", pat.Attr("mod")},
@@ -57,16 +57,16 @@ class FusedDropoutGradAddGradPattern
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
pir::drr::SourcePattern pat = ctx->SourcePattern();
const auto &dropout = pat.Op("pd_op.dropout",
const auto &dropout = pat.Op(paddle::dialect::DropoutOp::name(),
{{"p", pat.Attr("p")},
{"is_test", pat.Attr("is_test")},
{"mode", pat.Attr("mod")},
{"seed", pat.Attr("seed")},
{"fix_seed", pat.Attr("fix_seed")}});
const auto &add = pat.Op("pd_op.add");
const auto &add = pat.Op(paddle::dialect::AddOp::name());

const auto &add_grad = pat.Op("pd_op.add_grad");
const auto &dropout_grad = pat.Op("pd_op.dropout_grad",
const auto &add_grad = pat.Op(paddle::dialect::AddGradOp::name());
const auto &dropout_grad = pat.Op(paddle::dialect::DropoutGradOp::name(),
{{"p", pat.Attr("p")},
{"is_test", pat.Attr("is_test")},
{"mode", pat.Attr("mod")}});
@@ -83,15 +83,15 @@ class FusedDropoutGradAddGradPattern

pir::drr::ResultPattern res = pat.ResultPattern();
const auto &fused_dropout_add =
res.Op("pd_op.fused_dropout_add",
res.Op(paddle::dialect::FusedDropoutAddOp::name(),
{{{"p", pat.Attr("p")},
{"is_test", pat.Attr("is_test")},
{"mode", pat.Attr("mod")},
{"seed", pat.Attr("seed")},
{"fix_seed", pat.Attr("fix_seed")}}});

const auto &fused_dropout_add_grad =
res.Op("pd_op.fused_dropout_add_grad",
res.Op(paddle::dialect::FusedDropoutAddGradOp::name(),
{{{"p", pat.Attr("p")},
{"is_test", pat.Attr("is_test")},
{"mode", pat.Attr("mod")},
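Note on the change pattern in this file: each hard-coded op name string (for example "pd_op.dropout") is replaced by the corresponding op class's static name() accessor from the newly included pd_op.h, so a renamed or removed op is caught at compile time instead of silently failing to match. A minimal sketch of why the two spellings are interchangeable; the literal value returned by name() is an assumption inferred from this diff, not something shown in it:

// Sketch only: the typed accessor is assumed to return the same registered
// op name that the removed string literals spelled out by hand.
#include <cassert>
#include <string>

#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"

void CheckDropoutOpName() {
  const std::string literal = "pd_op.dropout";                    // old spelling
  const std::string typed = paddle::dialect::DropoutOp::name();   // new spelling
  assert(typed == literal);  // must hold for the rewrite to be behavior-preserving
}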
59 changes: 31 additions & 28 deletions paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc
@@ -13,7 +13,7 @@
// limitations under the License.

#include "paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.h"

#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"
@@ -25,10 +25,10 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase<FusedLinearPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
pir::drr::SourcePattern pat = ctx->SourcePattern();
const auto &matmul = pat.Op("pd_op.matmul",
const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(),
{{"transpose_x", pat.Attr("trans_x")},
{"transpose_y", pat.Attr("trans_y")}});
const auto &add = pat.Op("pd_op.add");
const auto &add = pat.Op(paddle::dialect::AddOp::name());

pat.Tensor("tmp") = matmul(pat.Tensor("x"), pat.Tensor("w"));
pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias"));
@@ -43,10 +43,11 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase<FusedLinearPattern> {
res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
return "none";
});
const auto &fused_gemm_epilogue = res.Op("pd_op.fused_gemm_epilogue",
{{{"trans_x", pat.Attr("trans_x")},
{"trans_y", pat.Attr("trans_y")},
{"activation", act_attr}}});
const auto &fused_gemm_epilogue =
res.Op(paddle::dialect::FusedGemmEpilogueOp::name(),
{{{"trans_x", pat.Attr("trans_x")},
{"trans_y", pat.Attr("trans_y")},
{"activation", act_attr}}});
fused_gemm_epilogue(
{&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")},
{&res.Tensor("out")});
@@ -58,14 +59,14 @@ class FusedLinearGradPattern
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
pir::drr::SourcePattern pat = ctx->SourcePattern();
const auto &matmul = pat.Op("pd_op.matmul",
const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(),
{{"transpose_x", pat.Attr("trans_x")},
{"transpose_y", pat.Attr("trans_y")}});
const auto &matmul_grad = pat.Op("pd_op.matmul_grad",
const auto &matmul_grad = pat.Op(paddle::dialect::MatmulGradOp::name(),
{{"transpose_x", pat.Attr("trans_x")},
{"transpose_y", pat.Attr("trans_y")}});
const auto &add = pat.Op("pd_op.add");
const auto &add_grad = pat.Op("pd_op.add_grad");
const auto &add = pat.Op(paddle::dialect::AddOp::name());
const auto &add_grad = pat.Op(paddle::dialect::AddGradOp::name());

pat.Tensor("tmp") = matmul(pat.Tensor("x"), pat.Tensor("w"));
pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias"));
@@ -84,12 +85,13 @@ class FusedLinearGradPattern
res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
return "none";
});
const auto &fused_gemm_epilogue = res.Op("pd_op.fused_gemm_epilogue",
{{{"trans_x", pat.Attr("trans_x")},
{"trans_y", pat.Attr("trans_y")},
{"activation", act_attr}}});
const auto &fused_gemm_epilogue =
res.Op(paddle::dialect::FusedGemmEpilogueOp::name(),
{{{"trans_x", pat.Attr("trans_x")},
{"trans_y", pat.Attr("trans_y")},
{"activation", act_attr}}});
const auto &fused_gemm_epilogue_grad =
res.Op("pd_op.fused_gemm_epilogue_grad",
res.Op(paddle::dialect::FusedGemmEpilogueGradOp::name(),
{{{"trans_x", pat.Attr("trans_x")},
{"trans_y", pat.Attr("trans_y")},
{"activation_grad", act_attr}}});
@@ -112,19 +114,20 @@ class FusedLinearGeluGradPattern
void operator()(pir::drr::DrrPatternContext *ctx) const override {
pir::drr::SourcePattern pat = ctx->SourcePattern();
const auto &fused_gemm_epilogue =
pat.Op("pd_op.fused_gemm_epilogue",
pat.Op(paddle::dialect::FusedGemmEpilogueOp::name(),
{{{"trans_x", pat.Attr("trans_x1")},
{"trans_y", pat.Attr("trans_y1")},
{"activation", pat.Attr("act1")}}});
const auto &fused_gemm_epilogue_grad1 =
pat.Op("pd_op.fused_gemm_epilogue_grad",
pat.Op(paddle::dialect::FusedGemmEpilogueGradOp::name(),
{{{"trans_x", pat.Attr("trans_x2")},
{"trans_y", pat.Attr("trans_y2")},
{"activation_grad", pat.Attr("act2")}}});
fused_gemm_epilogue(
{&pat.Tensor("x"), &pat.Tensor("w"), &pat.Tensor("bias")},
{&pat.Tensor("fuse_out"), &pat.Tensor("reserve_space")});
pat.Tensor("out") = pat.Op("pd_op.gelu")(pat.Tensor("fuse_out"));
pat.Tensor("out") =
pat.Op(paddle::dialect::GeluOp::name())(pat.Tensor("fuse_out"));

fused_gemm_epilogue_grad1({&pat.Tensor("x1"),
&pat.Tensor("w1"),
@@ -133,8 +136,8 @@ class FusedLinearGeluGradPattern
{&pat.Tensor("x1_grad"),
&pat.Tensor("w1_grad"),
&pat.Tensor("bias1_grad")});
pat.Tensor("gelu_dx") = pat.Op("pd_op.gelu_grad")(pat.Tensor("fuse_out"),
pat.Tensor("x1_grad"));
pat.Tensor("gelu_dx") = pat.Op(paddle::dialect::GeluGradOp::name())(
pat.Tensor("fuse_out"), pat.Tensor("x1_grad"));

pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) {
return match_ctx.Attr<std::string>("act1") == "none" &&
@@ -147,7 +150,7 @@
return "gelu";
});
const auto &fused_gemm_epilogue_new =
res.Op("pd_op.fused_gemm_epilogue",
res.Op(paddle::dialect::FusedGemmEpilogueOp::name(),
{{{"trans_x", pat.Attr("trans_x1")},
{"trans_y", pat.Attr("trans_y1")},
{"activation", act_attr}}});
@@ -156,7 +159,7 @@
return "gelu_grad";
});
const auto &fused_gemm_epilogue_grad_new =
res.Op("pd_op.fused_gemm_epilogue_grad",
res.Op(paddle::dialect::FusedGemmEpilogueGradOp::name(),
{{{"trans_x", pat.Attr("trans_x2")},
{"trans_y", pat.Attr("trans_y2")},
{"activation_grad", act_grad_attr}}});
@@ -179,17 +182,17 @@ class FusedLinearReluGradPattern
void operator()(pir::drr::DrrPatternContext *ctx) const override {
pir::drr::SourcePattern pat = ctx->SourcePattern();
const auto &fused_gemm_epilogue =
pat.Op("pd_op.fused_gemm_epilogue",
pat.Op(paddle::dialect::FusedGemmEpilogueOp::name(),
{{{"trans_x", pat.Attr("trans_x1")},
{"trans_y", pat.Attr("trans_y1")},
{"activation", pat.Attr("act1")}}});
const auto &fused_gemm_epilogue_grad =
pat.Op("pd_op.fused_gemm_epilogue_grad",
pat.Op(paddle::dialect::FusedGemmEpilogueGradOp::name(),
{{{"trans_x", pat.Attr("trans_x2")},
{"trans_y", pat.Attr("trans_y2")},
{"activation_grad", pat.Attr("act2")}}});
const auto &fused_gemm_epilogue_grad1 =
pat.Op("pd_op.fused_gemm_epilogue_grad",
pat.Op(paddle::dialect::FusedGemmEpilogueGradOp::name(),
{{{"trans_x", pat.Attr("trans_x3")},
{"trans_y", pat.Attr("trans_y3")},
{"activation_grad", pat.Attr("act3")}}});
@@ -226,7 +229,7 @@ class FusedLinearReluGradPattern
return "relu";
});
const auto &fused_gemm_epilogue_new =
res.Op("pd_op.fused_gemm_epilogue",
res.Op(paddle::dialect::FusedGemmEpilogueOp::name(),
{{{"trans_x", pat.Attr("trans_x1")},
{"trans_y", pat.Attr("trans_y1")},
{"activation", act_attr}}});
@@ -235,7 +238,7 @@
return "relu_grad";
});
const auto &fused_gemm_epilogue_grad1_new =
res.Op("pd_op.fused_gemm_epilogue_grad",
res.Op(paddle::dialect::FusedGemmEpilogueGradOp::name(),
{{{"trans_x", pat.Attr("trans_x2")},
{"trans_y", pat.Attr("trans_y2")},
{"activation_grad", act_grad_attr}}});
(The remaining 7 changed files in this commit are not shown here.)
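For context on how the rewritten patterns above are exercised: each pass file registers its pass with the PIR pass registry, and the pass is then run over a program through the pass manager. A rough usage sketch follows; the Create*Pass() factory functions and the PassManager calls are assumed from the conventions of the surrounding codebase, not from code shown in this diff:

// Sketch only: applying the fusion passes touched by this commit to a pir::Program.
// Assumes CreateFusedDropoutAddPass()/CreateFusedGemmEpiloguePass() factories are
// declared in the corresponding pass headers (an assumption, not shown in this diff).
#include "paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h"
#include "paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.h"
#include "paddle/pir/core/ir_context.h"
#include "paddle/pir/core/program.h"
#include "paddle/pir/pass/pass_manager.h"

void ApplyFusionPasses(pir::Program *program) {
  pir::PassManager pm(pir::IrContext::Instance());
  pm.AddPass(pir::CreateFusedDropoutAddPass());
  pm.AddPass(pir::CreateFusedGemmEpiloguePass());
  pm.Run(program);  // rewrites matching dropout+add and matmul+add subgraphs in place
}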
