Skip to content

Commit

Permalink
[DRR] change namespace pir::drr:: to paddle::drr:: (PaddlePaddle#60432)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanlehome authored and Wanglongzhi2001 committed Jan 7, 2024
1 parent 0391f03 commit 5af4bc7
Show file tree
Hide file tree
Showing 35 changed files with 597 additions and 529 deletions.
40 changes: 20 additions & 20 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ namespace cinn {
namespace dialect {
namespace ir {

class SumOpPattern : public pir::drr::DrrPatternBase<SumOpPattern> {
class SumOpPattern : public paddle::drr::DrrPatternBase<SumOpPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Source Pattern
pir::drr::SourcePattern pattern = ctx->SourcePattern();
paddle::drr::SourcePattern pattern = ctx->SourcePattern();
const auto &full_int_array =
pattern.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pattern.Attr("axis_info")},
Expand All @@ -48,7 +48,7 @@ class SumOpPattern : public pir::drr::DrrPatternBase<SumOpPattern> {
pattern.Tensor("ret") = sum(pattern.Tensor("arg0"), full_int_array());

// Result patterns
pir::drr::ResultPattern res = pattern.ResultPattern();
paddle::drr::ResultPattern res = pattern.ResultPattern();
const auto &cinn_reduce_sum =
res.Op(cinn::dialect::ReduceSumOp::name(),
{{"dim", pattern.Attr("axis_info")},
Expand All @@ -57,11 +57,11 @@ class SumOpPattern : public pir::drr::DrrPatternBase<SumOpPattern> {
}
};

class MaxOpPattern : public pir::drr::DrrPatternBase<MaxOpPattern> {
class MaxOpPattern : public paddle::drr::DrrPatternBase<MaxOpPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Source Pattern
pir::drr::SourcePattern pattern = ctx->SourcePattern();
paddle::drr::SourcePattern pattern = ctx->SourcePattern();
const auto &full_int_array =
pattern.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pattern.Attr("axis_info")},
Expand All @@ -73,7 +73,7 @@ class MaxOpPattern : public pir::drr::DrrPatternBase<MaxOpPattern> {
pattern.Tensor("ret") = pd_max(pattern.Tensor("arg0"), full_int_array());

// Result patterns
pir::drr::ResultPattern res = pattern.ResultPattern();
paddle::drr::ResultPattern res = pattern.ResultPattern();
const auto &cinn_reduce_max =
res.Op(cinn::dialect::ReduceMaxOp::name(),
{{"dim", pattern.Attr("axis_info")},
Expand All @@ -82,11 +82,11 @@ class MaxOpPattern : public pir::drr::DrrPatternBase<MaxOpPattern> {
}
};

class MinOpPattern : public pir::drr::DrrPatternBase<MinOpPattern> {
class MinOpPattern : public paddle::drr::DrrPatternBase<MinOpPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Source Pattern
pir::drr::SourcePattern pattern = ctx->SourcePattern();
paddle::drr::SourcePattern pattern = ctx->SourcePattern();
const auto &full_int_array =
pattern.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pattern.Attr("axis_info")},
Expand All @@ -98,7 +98,7 @@ class MinOpPattern : public pir::drr::DrrPatternBase<MinOpPattern> {
pattern.Tensor("ret") = pd_max(pattern.Tensor("arg0"), full_int_array());

// Result patterns
pir::drr::ResultPattern res = pattern.ResultPattern();
paddle::drr::ResultPattern res = pattern.ResultPattern();
const auto &cinn_reduce_max =
res.Op(cinn::dialect::ReduceMinOp::name(),
{{"dim", pattern.Attr("axis_info")},
Expand All @@ -107,11 +107,11 @@ class MinOpPattern : public pir::drr::DrrPatternBase<MinOpPattern> {
}
};

class ProdOpPattern : public pir::drr::DrrPatternBase<ProdOpPattern> {
class ProdOpPattern : public paddle::drr::DrrPatternBase<ProdOpPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Source Pattern
pir::drr::SourcePattern pattern = ctx->SourcePattern();
paddle::drr::SourcePattern pattern = ctx->SourcePattern();
const auto &full_int_array =
pattern.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pattern.Attr("axis_info")},
Expand All @@ -123,7 +123,7 @@ class ProdOpPattern : public pir::drr::DrrPatternBase<ProdOpPattern> {
pattern.Tensor("ret") = pd_max(pattern.Tensor("arg0"), full_int_array());

// Result patterns
pir::drr::ResultPattern res = pattern.ResultPattern();
paddle::drr::ResultPattern res = pattern.ResultPattern();
const auto &cinn_reduce_max =
res.Op(cinn::dialect::ReduceProdOp::name(),
{{"dim", pattern.Attr("axis_info")},
Expand Down Expand Up @@ -552,11 +552,11 @@ class SplitWithNumOpPattern
}
};

class UniformOpPattern : public pir::drr::DrrPatternBase<UniformOpPattern> {
class UniformOpPattern : public paddle::drr::DrrPatternBase<UniformOpPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Source Pattern
pir::drr::SourcePattern pattern = ctx->SourcePattern();
paddle::drr::SourcePattern pattern = ctx->SourcePattern();
const auto &full_int_array =
pattern.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pattern.Attr("axis_info")},
Expand Down Expand Up @@ -585,7 +585,7 @@ class UniformOpPattern : public pir::drr::DrrPatternBase<UniformOpPattern> {
// int64_t[] shape, float min, float max, int seed, DataType dtype, int
// diag_num, int diag_step, float diag_val)
// Result patterns
pir::drr::ResultPattern res = pattern.ResultPattern();
paddle::drr::ResultPattern res = pattern.ResultPattern();
const auto &cinn_uniform =
res.Op(cinn::dialect::UniformRandomOp::name(),
{{"shape", pattern.Attr("axis_info")},
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,22 @@
{op_header}
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
namespace pir {{
namespace paddle {{
namespace drr {{
void OperationFactory::Register{dialect}GeneratedOpCreator() {{
{body}
}}
}} // namespace drr
}} // namespace pir
}} // namespace paddle
"""

NORMAL_FUNCTION_TEMPLATE = """
RegisterOperationCreator(
"{op_name}",
[](const std::vector<Value>& inputs,
[](const std::vector<pir::Value>& inputs,
const pir::AttributeMap& attrs,
pir::PatternRewriter& rewriter) {{
return rewriter.Build<{namespace}::{op_class_name}>(
Expand All @@ -53,7 +53,7 @@
MUTABLE_ATTR_FUNCTION_TEMPLATE = """
RegisterOperationCreator(
"{op_name}",
[](const std::vector<Value>& inputs,
[](const std::vector<pir::Value>& inputs,
const pir::AttributeMap& attrs,
pir::PatternRewriter& rewriter) {{
// mutable_attr is tensor
Expand Down
24 changes: 12 additions & 12 deletions paddle/fluid/pir/drr/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ Taking PASS to eliminate redundant CastOp as an example, the code example develo
~~~ c++
// 1. Inherit specialized template class from DrPatternBase
class RemoveRedundentCastPattern
: public pir::drr::DrrPatternBase<RemoveRedundentCastPattern> {
: public paddle::drr::DrrPatternBase<RemoveRedundentCastPattern> {
// 2. Overload operator()
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// 3. Define a SourcePattern containing two consecutive CastOps using Op, Tensor, and Attribute
auto pat = ctx->SourcePattern();

Expand Down Expand Up @@ -55,7 +55,7 @@ Developers only need to define `SourcePattern`, `Constrains` and `ResultPattern`
<tr>
<td rowspan="1">DrrPatternBase</td>
<td> <pre> virtual void operator()(
pir::drr::DrrPatternContext* ctx) const </pre></td>
paddle::drr::DrrPatternContext* ctx) const </pre></td>
<td> Implement the entry function of DRR PASS </td>
<td> ctx: Context parameters required to create Patten</td>
</tr>
Expand Down Expand Up @@ -165,11 +165,11 @@ Attribute Attr(const AttrComputeFunc& attr_compute_func) const</pre></td>
## 3 Example
Example 1: Matmul + Add -> FusedGemmEpilogue
~~~ c++
class FusedLinearPattern : public pir::drr::DrrPatternBase<FusedLinearPattern> {
class FusedLinearPattern : public paddle::drr::DrrPatternBase<FusedLinearPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Define SourcePattern
pir::drr::SourcePattern pat = ctx->SourcePattern();
paddle::drr::SourcePattern pat = ctx->SourcePattern();
const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(),
{{"transpose_x", pat.Attr("trans_x")},
{"transpose_y", pat.Attr("trans_y")}});
Expand All @@ -179,10 +179,10 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase<FusedLinearPattern> {
pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias"));
// Define ResultPattern
pir::drr::ResultPattern res = pat.ResultPattern();
paddle::drr::ResultPattern res = pat.ResultPattern();
// Define Constrain
const auto &act_attr =
res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any {
return "none";
});
const auto &fused_gemm_epilogue = res.Op(paddle::dialect::FusedGemmEpilogueOp::name(),
Expand All @@ -199,11 +199,11 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase<FusedLinearPattern> {
Example 2: Full + Expand -> Full
~~~ c++
class FoldExpandToConstantPattern
: public pir::drr::DrrPatternBase<FoldExpandToConstantPattern> {
: public paddle::drr::DrrPatternBase<FoldExpandToConstantPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Define SourcePattern
pir::drr::SourcePattern pat = ctx->SourcePattern();
paddle::drr::SourcePattern pat = ctx->SourcePattern();
const auto &full1 = pat.Op(paddle::dialect::FullOp::name(),
{{"shape", pat.Attr("shape_1")},
{"value", pat.Attr("value_1")},
Expand All @@ -218,7 +218,7 @@ class FoldExpandToConstantPattern
pat.Tensor("ret") = expand(full1(), full_int_array1());

// Define ResultPattern
pir::drr::ResultPattern res = pat.ResultPattern();
paddle::drr::ResultPattern res = pat.ResultPattern();
const auto &full2 = res.Op(paddle::dialect::FullOp::name(),
{{"shape", pat.Attr("expand_shape_value")},
{"value", pat.Attr("value_1")},
Expand Down
24 changes: 12 additions & 12 deletions paddle/fluid/pir/drr/README_cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ DRR ( Declarative Rewrite Rule ) 是来处理这种 DAG-to-DAG 类型的一套 P
~~~ c++
// 1. 继承 DrrPatternBase 的特化模板类
class RemoveRedundentCastPattern
: public pir::drr::DrrPatternBase<RemoveRedundentCastPattern> {
: public paddle::drr::DrrPatternBase<RemoveRedundentCastPattern> {
// 2. 重载 operator()
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// 3. 使用 Op、Tensor 和 Attribute 定义一个包含两个连续 CastOp 的 SourcePattern
auto pat = ctx->SourcePattern();

Expand Down Expand Up @@ -56,7 +56,7 @@ DRR PASS 包含以下三个部分:
<tr>
<td rowspan="1">DrrPatternBase</td>
<td> <pre> virtual void operator()(
pir::drr::DrrPatternContext* ctx) const </pre></td>
paddle::drr::DrrPatternContext* ctx) const </pre></td>
<td> 实现 DRR PASS 的入口函数 </td>
<td> ctx: 创建 Patten 所需要的 Context 参数</td>
</tr>
Expand Down Expand Up @@ -168,11 +168,11 @@ Attribute Attr(const AttrComputeFunc& attr_compute_func) const</pre></td>
## 3 使用示例
Example 1: Matmul + Add -> FusedGemmEpilogue
~~~ c++
class FusedLinearPattern : public pir::drr::DrrPatternBase<FusedLinearPattern> {
class FusedLinearPattern : public paddle::drr::DrrPatternBase<FusedLinearPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// 定义 Source Pattern
pir::drr::SourcePattern pat = ctx->SourcePattern();
paddle::drr::SourcePattern pat = ctx->SourcePattern();
const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(),
{{"transpose_x", pat.Attr("trans_x")},
{"transpose_y", pat.Attr("trans_y")}});
Expand All @@ -182,10 +182,10 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase<FusedLinearPattern> {
pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias"));
// 定义 Result Pattern
pir::drr::ResultPattern res = pat.ResultPattern();
paddle::drr::ResultPattern res = pat.ResultPattern();
// 定义 Constrain
const auto &act_attr =
res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any {
return "none";
});
const auto &fused_gemm_epilogue = res.Op(paddle::dialect::FusedGemmEpilogueOp::name(),
Expand All @@ -202,11 +202,11 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase<FusedLinearPattern> {
Example 2: Full + Expand -> Full
~~~ c++
class FoldExpandToConstantPattern
: public pir::drr::DrrPatternBase<FoldExpandToConstantPattern> {
: public paddle::drr::DrrPatternBase<FoldExpandToConstantPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// 定义 Source Pattern
pir::drr::SourcePattern pat = ctx->SourcePattern();
paddle::drr::SourcePattern pat = ctx->SourcePattern();
const auto &full1 = pat.Op(paddle::dialect::FullOp::name(),
{{"shape", pat.Attr("shape_1")},
{"value", pat.Attr("value_1")},
Expand All @@ -221,7 +221,7 @@ class FoldExpandToConstantPattern
pat.Tensor("ret") = expand(full1(), full_int_array1());

// 定义 Result Pattern Constrains: 本 Pass 无额外约束规则
pir::drr::ResultPattern res = pat.ResultPattern();
paddle::drr::ResultPattern res = pat.ResultPattern();
const auto &full2 = res.Op(paddle::dialect::FullOp::name(),
{{"shape", pat.Attr("expand_shape_value")},
{"value", pat.Attr("value_1")},
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/pir/drr/api/drr_pattern_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include "paddle/fluid/pir/drr/api/drr_pattern_context.h"
#include "paddle/fluid/pir/drr/drr_rewrite_pattern.h"

namespace pir {
namespace paddle {
namespace drr {

template <typename DrrPattern>
Expand All @@ -26,7 +26,7 @@ class DrrPatternBase {
virtual ~DrrPatternBase() = default;

// Define the Drr Pattern.
virtual void operator()(pir::drr::DrrPatternContext* ctx) const = 0;
virtual void operator()(paddle::drr::DrrPatternContext* ctx) const = 0;

std::unique_ptr<DrrRewritePattern> Build(
pir::IrContext* ir_context, pir::PatternBenefit benefit = 1) const {
Expand All @@ -39,4 +39,4 @@ class DrrPatternBase {
};

} // namespace drr
} // namespace pir
} // namespace paddle
5 changes: 3 additions & 2 deletions paddle/fluid/pir/drr/api/drr_pattern_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include "paddle/fluid/pir/drr/pattern_graph.h"
#include "paddle/phi/core/enforce.h"

namespace pir {
namespace paddle {
namespace drr {

DrrPatternContext::DrrPatternContext() {
Expand All @@ -28,6 +28,7 @@ DrrPatternContext::DrrPatternContext() {
drr::SourcePattern DrrPatternContext::SourcePattern() {
return drr::SourcePattern(this);
}

const Op& DrrPatternContext::SourceOpPattern(
const std::string& op_type,
const std::unordered_map<std::string, Attribute>& attributes) {
Expand Down Expand Up @@ -167,4 +168,4 @@ void Tensor::operator=(const Tensor& other) const { // NOLINT
}

} // namespace drr
} // namespace pir
} // namespace paddle
4 changes: 2 additions & 2 deletions paddle/fluid/pir/drr/api/drr_pattern_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

#include "paddle/fluid/pir/drr/api/match_context.h"

namespace pir {
namespace paddle {
namespace drr {

class Op;
Expand Down Expand Up @@ -334,4 +334,4 @@ class SourcePattern {
};

} // namespace drr
} // namespace pir
} // namespace paddle
4 changes: 2 additions & 2 deletions paddle/fluid/pir/drr/api/match_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include "paddle/fluid/pir/drr/ir_operation.h"
#include "paddle/fluid/pir/drr/match_context_impl.h"

namespace pir {
namespace paddle {
namespace drr {

MatchContext::MatchContext(std::shared_ptr<const MatchContextImpl> impl)
Expand All @@ -46,4 +46,4 @@ template std::vector<int64_t> MatchContext::Attr<std::vector<int64_t>>(
const std::string&) const;

} // namespace drr
} // namespace pir
} // namespace paddle
4 changes: 2 additions & 2 deletions paddle/fluid/pir/drr/api/match_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include "paddle/fluid/pir/drr/api/tensor_interface.h"
#include "paddle/fluid/pir/drr/ir_operation.h"

namespace pir {
namespace paddle {
namespace drr {

class TensorInterface;
Expand All @@ -40,4 +40,4 @@ class MatchContext final {
};

} // namespace drr
} // namespace pir
} // namespace paddle
Loading

0 comments on commit 5af4bc7

Please sign in to comment.