Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRR] change namespace pir::drr:: to paddle::drr:: #60432

Merged
merged 1 commit into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ namespace cinn {
namespace dialect {
namespace ir {

class SumOpPattern : public pir::drr::DrrPatternBase<SumOpPattern> {
class SumOpPattern : public paddle::drr::DrrPatternBase<SumOpPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Source Pattern
pir::drr::SourcePattern pattern = ctx->SourcePattern();
paddle::drr::SourcePattern pattern = ctx->SourcePattern();
const auto &full_int_array =
pattern.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pattern.Attr("axis_info")},
Expand All @@ -48,7 +48,7 @@ class SumOpPattern : public pir::drr::DrrPatternBase<SumOpPattern> {
pattern.Tensor("ret") = sum(pattern.Tensor("arg0"), full_int_array());

// Result patterns
pir::drr::ResultPattern res = pattern.ResultPattern();
paddle::drr::ResultPattern res = pattern.ResultPattern();
const auto &cinn_reduce_sum =
res.Op(cinn::dialect::ReduceSumOp::name(),
{{"dim", pattern.Attr("axis_info")},
Expand All @@ -57,11 +57,11 @@ class SumOpPattern : public pir::drr::DrrPatternBase<SumOpPattern> {
}
};

class MaxOpPattern : public pir::drr::DrrPatternBase<MaxOpPattern> {
class MaxOpPattern : public paddle::drr::DrrPatternBase<MaxOpPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Source Pattern
pir::drr::SourcePattern pattern = ctx->SourcePattern();
paddle::drr::SourcePattern pattern = ctx->SourcePattern();
const auto &full_int_array =
pattern.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pattern.Attr("axis_info")},
Expand All @@ -73,7 +73,7 @@ class MaxOpPattern : public pir::drr::DrrPatternBase<MaxOpPattern> {
pattern.Tensor("ret") = pd_max(pattern.Tensor("arg0"), full_int_array());

// Result patterns
pir::drr::ResultPattern res = pattern.ResultPattern();
paddle::drr::ResultPattern res = pattern.ResultPattern();
const auto &cinn_reduce_max =
res.Op(cinn::dialect::ReduceMaxOp::name(),
{{"dim", pattern.Attr("axis_info")},
Expand All @@ -82,11 +82,11 @@ class MaxOpPattern : public pir::drr::DrrPatternBase<MaxOpPattern> {
}
};

class MinOpPattern : public pir::drr::DrrPatternBase<MinOpPattern> {
class MinOpPattern : public paddle::drr::DrrPatternBase<MinOpPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Source Pattern
pir::drr::SourcePattern pattern = ctx->SourcePattern();
paddle::drr::SourcePattern pattern = ctx->SourcePattern();
const auto &full_int_array =
pattern.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pattern.Attr("axis_info")},
Expand All @@ -98,7 +98,7 @@ class MinOpPattern : public pir::drr::DrrPatternBase<MinOpPattern> {
pattern.Tensor("ret") = pd_max(pattern.Tensor("arg0"), full_int_array());

// Result patterns
pir::drr::ResultPattern res = pattern.ResultPattern();
paddle::drr::ResultPattern res = pattern.ResultPattern();
const auto &cinn_reduce_max =
res.Op(cinn::dialect::ReduceMinOp::name(),
{{"dim", pattern.Attr("axis_info")},
Expand All @@ -107,11 +107,11 @@ class MinOpPattern : public pir::drr::DrrPatternBase<MinOpPattern> {
}
};

class ProdOpPattern : public pir::drr::DrrPatternBase<ProdOpPattern> {
class ProdOpPattern : public paddle::drr::DrrPatternBase<ProdOpPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Source Pattern
pir::drr::SourcePattern pattern = ctx->SourcePattern();
paddle::drr::SourcePattern pattern = ctx->SourcePattern();
const auto &full_int_array =
pattern.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pattern.Attr("axis_info")},
Expand All @@ -123,7 +123,7 @@ class ProdOpPattern : public pir::drr::DrrPatternBase<ProdOpPattern> {
pattern.Tensor("ret") = pd_max(pattern.Tensor("arg0"), full_int_array());

// Result patterns
pir::drr::ResultPattern res = pattern.ResultPattern();
paddle::drr::ResultPattern res = pattern.ResultPattern();
const auto &cinn_reduce_max =
res.Op(cinn::dialect::ReduceProdOp::name(),
{{"dim", pattern.Attr("axis_info")},
Expand Down Expand Up @@ -552,11 +552,11 @@ class SplitWithNumOpPattern
}
};

class UniformOpPattern : public pir::drr::DrrPatternBase<UniformOpPattern> {
class UniformOpPattern : public paddle::drr::DrrPatternBase<UniformOpPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Source Pattern
pir::drr::SourcePattern pattern = ctx->SourcePattern();
paddle::drr::SourcePattern pattern = ctx->SourcePattern();
const auto &full_int_array =
pattern.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pattern.Attr("axis_info")},
Expand Down Expand Up @@ -585,7 +585,7 @@ class UniformOpPattern : public pir::drr::DrrPatternBase<UniformOpPattern> {
// int64_t[] shape, float min, float max, int seed, DataType dtype, int
// diag_num, int diag_step, float diag_val)
// Result patterns
pir::drr::ResultPattern res = pattern.ResultPattern();
paddle::drr::ResultPattern res = pattern.ResultPattern();
const auto &cinn_uniform =
res.Op(cinn::dialect::UniformRandomOp::name(),
{{"shape", pattern.Attr("axis_info")},
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,22 @@
{op_header}
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
namespace pir {{
namespace paddle {{
namespace drr {{
void OperationFactory::Register{dialect}GeneratedOpCreator() {{
{body}
}}
}} // namespace drr
}} // namespace pir
}} // namespace paddle
"""

NORMAL_FUNCTION_TEMPLATE = """
RegisterOperationCreator(
"{op_name}",
[](const std::vector<Value>& inputs,
[](const std::vector<pir::Value>& inputs,
const pir::AttributeMap& attrs,
pir::PatternRewriter& rewriter) {{
return rewriter.Build<{namespace}::{op_class_name}>(
Expand All @@ -53,7 +53,7 @@
MUTABLE_ATTR_FUNCTION_TEMPLATE = """
RegisterOperationCreator(
"{op_name}",
[](const std::vector<Value>& inputs,
[](const std::vector<pir::Value>& inputs,
const pir::AttributeMap& attrs,
pir::PatternRewriter& rewriter) {{
// mutable_attr is tensor
Expand Down
24 changes: 12 additions & 12 deletions paddle/fluid/pir/drr/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ Taking PASS to eliminate redundant CastOp as an example, the code example develo
~~~ c++
// 1. Inherit specialized template class from DrPatternBase
class RemoveRedundentCastPattern
: public pir::drr::DrrPatternBase<RemoveRedundentCastPattern> {
: public paddle::drr::DrrPatternBase<RemoveRedundentCastPattern> {
// 2. Overload operator()
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// 3. Define a SourcePattern containing two consecutive CastOps using Op, Tensor, and Attribute
auto pat = ctx->SourcePattern();

Expand Down Expand Up @@ -55,7 +55,7 @@ Developers only need to define `SourcePattern`, `Constrains` and `ResultPattern`
<tr>
<td rowspan="1">DrrPatternBase</td>
<td> <pre> virtual void operator()(
pir::drr::DrrPatternContext* ctx) const </pre></td>
paddle::drr::DrrPatternContext* ctx) const </pre></td>
<td> Implement the entry function of DRR PASS </td>
<td> ctx: Context parameters required to create Patten</td>
</tr>
Expand Down Expand Up @@ -165,11 +165,11 @@ Attribute Attr(const AttrComputeFunc& attr_compute_func) const</pre></td>
## 3 Example
Example 1: Matmul + Add -> FusedGemmEpilogue
~~~ c++
class FusedLinearPattern : public pir::drr::DrrPatternBase<FusedLinearPattern> {
class FusedLinearPattern : public paddle::drr::DrrPatternBase<FusedLinearPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Define SourcePattern
pir::drr::SourcePattern pat = ctx->SourcePattern();
paddle::drr::SourcePattern pat = ctx->SourcePattern();
const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(),
{{"transpose_x", pat.Attr("trans_x")},
{"transpose_y", pat.Attr("trans_y")}});
Expand All @@ -179,10 +179,10 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase<FusedLinearPattern> {
pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias"));
// Define ResultPattern
pir::drr::ResultPattern res = pat.ResultPattern();
paddle::drr::ResultPattern res = pat.ResultPattern();
// Define Constrain
const auto &act_attr =
res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any {
return "none";
});
const auto &fused_gemm_epilogue = res.Op(paddle::dialect::FusedGemmEpilogueOp::name(),
Expand All @@ -199,11 +199,11 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase<FusedLinearPattern> {
Example 2: Full + Expand -> Full
~~~ c++
class FoldExpandToConstantPattern
: public pir::drr::DrrPatternBase<FoldExpandToConstantPattern> {
: public paddle::drr::DrrPatternBase<FoldExpandToConstantPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Define SourcePattern
pir::drr::SourcePattern pat = ctx->SourcePattern();
paddle::drr::SourcePattern pat = ctx->SourcePattern();
const auto &full1 = pat.Op(paddle::dialect::FullOp::name(),
{{"shape", pat.Attr("shape_1")},
{"value", pat.Attr("value_1")},
Expand All @@ -218,7 +218,7 @@ class FoldExpandToConstantPattern
pat.Tensor("ret") = expand(full1(), full_int_array1());

// Define ResultPattern
pir::drr::ResultPattern res = pat.ResultPattern();
paddle::drr::ResultPattern res = pat.ResultPattern();
const auto &full2 = res.Op(paddle::dialect::FullOp::name(),
{{"shape", pat.Attr("expand_shape_value")},
{"value", pat.Attr("value_1")},
Expand Down
24 changes: 12 additions & 12 deletions paddle/fluid/pir/drr/README_cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ DRR ( Declarative Rewrite Rule ) 是来处理这种 DAG-to-DAG 类型的一套 P
~~~ c++
// 1. 继承 DrrPatternBase 的特化模板类
class RemoveRedundentCastPattern
: public pir::drr::DrrPatternBase<RemoveRedundentCastPattern> {
: public paddle::drr::DrrPatternBase<RemoveRedundentCastPattern> {
// 2. 重载 operator()
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// 3. 使用 Op、Tensor 和 Attribute 定义一个包含两个连续 CastOp 的 SourcePattern
auto pat = ctx->SourcePattern();

Expand Down Expand Up @@ -56,7 +56,7 @@ DRR PASS 包含以下三个部分:
<tr>
<td rowspan="1">DrrPatternBase</td>
<td> <pre> virtual void operator()(
pir::drr::DrrPatternContext* ctx) const </pre></td>
paddle::drr::DrrPatternContext* ctx) const </pre></td>
<td> 实现 DRR PASS 的入口函数 </td>
<td> ctx: 创建 Patten 所需要的 Context 参数</td>
</tr>
Expand Down Expand Up @@ -168,11 +168,11 @@ Attribute Attr(const AttrComputeFunc& attr_compute_func) const</pre></td>
## 3 使用示例
Example 1: Matmul + Add -> FusedGemmEpilogue
~~~ c++
class FusedLinearPattern : public pir::drr::DrrPatternBase<FusedLinearPattern> {
class FusedLinearPattern : public paddle::drr::DrrPatternBase<FusedLinearPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// 定义 Source Pattern
pir::drr::SourcePattern pat = ctx->SourcePattern();
paddle::drr::SourcePattern pat = ctx->SourcePattern();
const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(),
{{"transpose_x", pat.Attr("trans_x")},
{"transpose_y", pat.Attr("trans_y")}});
Expand All @@ -182,10 +182,10 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase<FusedLinearPattern> {
pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias"));
// 定义 Result Pattern
pir::drr::ResultPattern res = pat.ResultPattern();
paddle::drr::ResultPattern res = pat.ResultPattern();
// 定义 Constrain
const auto &act_attr =
res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any {
return "none";
});
const auto &fused_gemm_epilogue = res.Op(paddle::dialect::FusedGemmEpilogueOp::name(),
Expand All @@ -202,11 +202,11 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase<FusedLinearPattern> {
Example 2: Full + Expand -> Full
~~~ c++
class FoldExpandToConstantPattern
: public pir::drr::DrrPatternBase<FoldExpandToConstantPattern> {
: public paddle::drr::DrrPatternBase<FoldExpandToConstantPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// 定义 Source Pattern
pir::drr::SourcePattern pat = ctx->SourcePattern();
paddle::drr::SourcePattern pat = ctx->SourcePattern();
const auto &full1 = pat.Op(paddle::dialect::FullOp::name(),
{{"shape", pat.Attr("shape_1")},
{"value", pat.Attr("value_1")},
Expand All @@ -221,7 +221,7 @@ class FoldExpandToConstantPattern
pat.Tensor("ret") = expand(full1(), full_int_array1());

// 定义 Result Pattern Constrains: 本 Pass 无额外约束规则
pir::drr::ResultPattern res = pat.ResultPattern();
paddle::drr::ResultPattern res = pat.ResultPattern();
const auto &full2 = res.Op(paddle::dialect::FullOp::name(),
{{"shape", pat.Attr("expand_shape_value")},
{"value", pat.Attr("value_1")},
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/pir/drr/api/drr_pattern_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include "paddle/fluid/pir/drr/api/drr_pattern_context.h"
#include "paddle/fluid/pir/drr/drr_rewrite_pattern.h"

namespace pir {
namespace paddle {
namespace drr {

template <typename DrrPattern>
Expand All @@ -26,7 +26,7 @@ class DrrPatternBase {
virtual ~DrrPatternBase() = default;

// Define the Drr Pattern.
virtual void operator()(pir::drr::DrrPatternContext* ctx) const = 0;
virtual void operator()(paddle::drr::DrrPatternContext* ctx) const = 0;

std::unique_ptr<DrrRewritePattern> Build(
pir::IrContext* ir_context, pir::PatternBenefit benefit = 1) const {
Expand All @@ -39,4 +39,4 @@ class DrrPatternBase {
};

} // namespace drr
} // namespace pir
} // namespace paddle
5 changes: 3 additions & 2 deletions paddle/fluid/pir/drr/api/drr_pattern_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include "paddle/fluid/pir/drr/pattern_graph.h"
#include "paddle/phi/core/enforce.h"

namespace pir {
namespace paddle {
namespace drr {

DrrPatternContext::DrrPatternContext() {
Expand All @@ -28,6 +28,7 @@ DrrPatternContext::DrrPatternContext() {
drr::SourcePattern DrrPatternContext::SourcePattern() {
return drr::SourcePattern(this);
}

const Op& DrrPatternContext::SourceOpPattern(
const std::string& op_type,
const std::unordered_map<std::string, Attribute>& attributes) {
Expand Down Expand Up @@ -167,4 +168,4 @@ void Tensor::operator=(const Tensor& other) const { // NOLINT
}

} // namespace drr
} // namespace pir
} // namespace paddle
4 changes: 2 additions & 2 deletions paddle/fluid/pir/drr/api/drr_pattern_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

#include "paddle/fluid/pir/drr/api/match_context.h"

namespace pir {
namespace paddle {
namespace drr {

class Op;
Expand Down Expand Up @@ -334,4 +334,4 @@ class SourcePattern {
};

} // namespace drr
} // namespace pir
} // namespace paddle
4 changes: 2 additions & 2 deletions paddle/fluid/pir/drr/api/match_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include "paddle/fluid/pir/drr/ir_operation.h"
#include "paddle/fluid/pir/drr/match_context_impl.h"

namespace pir {
namespace paddle {
namespace drr {

MatchContext::MatchContext(std::shared_ptr<const MatchContextImpl> impl)
Expand All @@ -46,4 +46,4 @@ template std::vector<int64_t> MatchContext::Attr<std::vector<int64_t>>(
const std::string&) const;

} // namespace drr
} // namespace pir
} // namespace paddle
4 changes: 2 additions & 2 deletions paddle/fluid/pir/drr/api/match_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include "paddle/fluid/pir/drr/api/tensor_interface.h"
#include "paddle/fluid/pir/drr/ir_operation.h"

namespace pir {
namespace paddle {
namespace drr {

class TensorInterface;
Expand All @@ -40,4 +40,4 @@ class MatchContext final {
};

} // namespace drr
} // namespace pir
} // namespace paddle
Loading