Skip to content

Commit

Permalink
[PIR] Reconstruct the Verify system (PaddlePaddle#58052)
Browse files Browse the repository at this point in the history
* refine verify of if op

* fix

* fix

* fix

* refine

* fix

* fix

* fix

* fix
  • Loading branch information
zhangbo9674 authored Oct 13, 2023
1 parent 71ee1cd commit 5348600
Show file tree
Hide file tree
Showing 32 changed files with 205 additions and 94 deletions.
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ std::vector<pir::Operation *> GroupOp::ops() {
inner_block->end());
}

void GroupOp::Verify() {}
void GroupOp::VerifySig() {}

void GroupOp::Print(pir::IrPrinter &printer) {
auto &os = printer.os;
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/dialect/operator/ir/manual_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class GroupOp : public pir::Op<GroupOp> {
pir::Block *block();
std::vector<pir::Operation *> ops();

void Verify();
void VerifySig();
void Print(pir::IrPrinter &printer); // NOLINT
};

Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace dialect {

const char* JitKernelOp::attributes_name[attributes_num] = {kAttrName};

void JitKernelOp::Verify() {
void JitKernelOp::VerifySig() {
VLOG(4) << "Verifying inputs, outputs and attributes for: JitKernelOp.";

auto& attributes = this->attributes();
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class JitKernelOp : public ::pir::Op<JitKernelOp> {

hlir::framework::Instruction* instruction();

void Verify();
void VerifySig();
};

} // namespace dialect
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/ir_adaptor/translator/program_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,8 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation(
true);
}
VLOG(4) << "[general op][conditional_block] IfOp false block translate end.";

operation->Verify();
VLOG(4) << "[general op][conditional_block] IfOp translate end.";
return operation;
}
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ const char* PhiKernelOp::attributes_name[attributes_num] = { // NOLINT
"kernel_name",
"kernel_key"};

void PhiKernelOp::Verify() {
void PhiKernelOp::VerifySig() {
VLOG(4) << "Verifying inputs, outputs and attributes for: PhiKernelOp.";

auto& attributes = this->attributes();
Expand Down Expand Up @@ -64,7 +64,7 @@ const char* LegacyKernelOp::attributes_name[attributes_num] = { // NOLINT
"kernel_name",
"kernel_key"};

void LegacyKernelOp::Verify() {
void LegacyKernelOp::VerifySig() {
VLOG(4) << "Verifying inputs, outputs and attributes for: LegacyKernelOp.";

auto& attributes = this->attributes();
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pir/dialect/kernel/ir/kernel_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class PhiKernelOp : public pir::Op<PhiKernelOp> {
std::string op_name();
std::string kernel_name();
phi::KernelKey kernel_key();
void Verify();
void VerifySig();
};

class LegacyKernelOp : public pir::Op<LegacyKernelOp> {
Expand All @@ -41,7 +41,7 @@ class LegacyKernelOp : public pir::Op<LegacyKernelOp> {
std::string op_name();
std::string kernel_name();
phi::KernelKey kernel_key();
void Verify();
void VerifySig();
};

} // namespace dialect
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
{build_mutable_attr_is_input}
{build_attr_num_over_1}
{build_mutable_attr_is_input_attr_num_over_1}
void Verify();
void VerifySig();
{get_inputs_and_outputs}
{exclusive_interface}
}};
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pir/dialect/op_generator/op_verify_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

# verify
OP_VERIFY_TEMPLATE = """
void {op_name}::Verify() {{
void {op_name}::VerifySig() {{
VLOG(4) << "Start Verifying inputs, outputs and attributes for: {op_name}.";
VLOG(4) << "Verifying inputs:";
{{
Expand All @@ -36,7 +36,7 @@
"""

GRAD_OP_VERIFY_TEMPLATE = """
void {op_name}::Verify() {{}}
void {op_name}::VerifySig() {{}}
"""

INPUT_TYPE_CHECK_TEMPLATE = """
Expand Down
70 changes: 69 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ paddle::dialect::IfOp, paddle::dialect::WhileOp

#include "paddle/phi/core/enforce.h"
#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/builtin_type.h"
#include "paddle/pir/core/ir_printer.h"
#include "paddle/pir/core/operation_utils.h"
#include "paddle/pir/dialect/control_flow/ir/cf_ops.h"
Expand Down Expand Up @@ -109,7 +110,74 @@ void IfOp::Print(pir::IrPrinter &printer) {
}
os << "\n }";
}
void IfOp::Verify() {}
void IfOp::VerifySig() {
VLOG(4) << "Start Verifying inputs, outputs and attributes for: IfOp.";
auto input_size = num_operands();
PADDLE_ENFORCE_EQ(
input_size,
1u,
phi::errors::PreconditionNotMet(
"The size %d of inputs must be equal to 1.", input_size));

if ((*this)->operand_source(0).type().isa<pir::DenseTensorType>()) {
PADDLE_ENFORCE(
(*this)
->operand_source(0)
.type()
.dyn_cast<pir::DenseTensorType>()
.dtype()
.isa<pir::BoolType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 1th input, it should be a "
"bool DenseTensorType."));
}

PADDLE_ENFORCE_EQ((*this)->num_regions(),
2u,
phi::errors::PreconditionNotMet(
"The size %d of regions must be equal to 2.",
(*this)->num_regions()));
}

void IfOp::VerifyRegion() {
VLOG(4) << "Start Verifying sub regions for: IfOp.";
PADDLE_ENFORCE_EQ(
(*this)->region(0).size(),
1u,
phi::errors::PreconditionNotMet("The size %d of true_region must be 1.",
(*this)->region(0).size()));

if ((*this)->num_results() != 0) {
PADDLE_ENFORCE_EQ(
(*this)->region(0).size(),
(*this)->region(1).size(),
phi::errors::PreconditionNotMet("The size %d of true_region must be "
"equal to the size %d of false_region.",
(*this)->region(0).size(),
(*this)->region(1).size()));

auto *true_last_op = (*this)->region(0).front()->back();
auto *false_last_op = (*this)->region(1).front()->back();
PADDLE_ENFORCE_EQ(true_last_op->isa<pir::YieldOp>(),
true,
phi::errors::PreconditionNotMet(
"The last of true block must be YieldOp"));
PADDLE_ENFORCE_EQ(true_last_op->num_operands(),
(*this)->num_results(),
phi::errors::PreconditionNotMet(
"The size of last of true block op's input must be "
"equal to IfOp's outputs num."));
PADDLE_ENFORCE_EQ(false_last_op->isa<pir::YieldOp>(),
true,
phi::errors::PreconditionNotMet(
"The last of false block must be YieldOp"));
PADDLE_ENFORCE_EQ(false_last_op->num_operands(),
(*this)->num_results(),
phi::errors::PreconditionNotMet(
"The size of last of false block op's input must be "
"equal to IfOp's outputs num."));
}
}

void WhileOp::Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
Expand Down
6 changes: 4 additions & 2 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ class IfOp : public pir::Op<IfOp> {
pir::Block *true_block();
pir::Block *false_block();
void Print(pir::IrPrinter &printer); // NOLINT
void Verify();
void VerifySig();
void VerifyRegion();
};

class WhileOp : public pir::Op<WhileOp> {
Expand All @@ -57,7 +58,8 @@ class WhileOp : public pir::Op<WhileOp> {
pir::Block *cond_block();
pir::Block *body_block();
void Print(pir::IrPrinter &printer); // NOLINT
void Verify() {}
void VerifySig() {}
void VerifyRegion() {}
};

} // namespace dialect
Expand Down
12 changes: 6 additions & 6 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ OpInfoTuple AddNOp::GetOpInfo() {
return std::make_tuple(inputs, attributes, outputs, run_time_info, "add_n");
}

void AddNOp::Verify() {
void AddNOp::VerifySig() {
VLOG(4) << "Start Verifying inputs, outputs and attributes for: AddNOp.";
VLOG(4) << "Verifying inputs:";
{
Expand Down Expand Up @@ -222,7 +222,7 @@ void AddN_Op::Build(pir::Builder &builder,
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
}

void AddN_Op::Verify() {
void AddN_Op::VerifySig() {
VLOG(4) << "Start Verifying inputs, outputs and attributes for: AddN_Op.";
VLOG(4) << "Verifying inputs:";
{
Expand Down Expand Up @@ -345,7 +345,7 @@ void AddNWithKernelOp::Build(pir::Builder &builder,
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
}

void AddNWithKernelOp::Verify() {
void AddNWithKernelOp::VerifySig() {
VLOG(4) << "Start Verifying inputs, outputs and attributes for: "
"AddNWithKernelOp.";
VLOG(4) << "Verifying inputs:";
Expand Down Expand Up @@ -561,7 +561,7 @@ void FusedGemmEpilogueOp::Build(pir::Builder &builder,
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
}

void FusedGemmEpilogueOp::Verify() {
void FusedGemmEpilogueOp::VerifySig() {
VLOG(4) << "Start Verifying inputs, outputs and attributes for: "
"FusedGemmEpilogueOp.";
VLOG(4) << "Verifying inputs:";
Expand Down Expand Up @@ -833,7 +833,7 @@ void FusedGemmEpilogueGradOp::Build(pir::Builder &builder,
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
}

void FusedGemmEpilogueGradOp::Verify() {}
void FusedGemmEpilogueGradOp::VerifySig() {}

void FusedGemmEpilogueGradOp::InferMeta(phi::InferMetaContext *infer_meta) {
auto fn = PD_INFER_META(phi::FusedGemmEpilogueGradInferMeta);
Expand Down Expand Up @@ -983,7 +983,7 @@ void SplitGradOp::Build(pir::Builder &builder,
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
}

void SplitGradOp::Verify() {
void SplitGradOp::VerifySig() {
VLOG(4) << "Start Verifying inputs, outputs and attributes for: SplitGradOp.";
VLOG(4) << "Verifying inputs:";
{
Expand Down
12 changes: 6 additions & 6 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class AddNOp : public pir::Op<AddNOp,
pir::OperationArgument &argument, // NOLINT
pir::Value inputs);

void Verify();
void VerifySig();
pir::Value inputs() { return operand_source(0); }
pir::OpResult out() { return result(0); }
static void InferMeta(phi::InferMetaContext *infer_meta);
Expand All @@ -69,7 +69,7 @@ class AddN_Op : public pir::Op<AddN_Op,
pir::OperationArgument &argument, // NOLINT
pir::Value inputs_);

void Verify();
void VerifySig();
pir::Value inputs() { return operand_source(0); }
pir::OpResult out() { return result(0); }

Expand All @@ -89,7 +89,7 @@ class AddNWithKernelOp : public pir::Op<AddNWithKernelOp,
pir::OperationArgument &argument, // NOLINT
pir::Value inputs_);

void Verify();
void VerifySig();
pir::Value inputs() { return operand_source(0); }
pir::OpResult out() { return result(0); }

Expand All @@ -113,7 +113,7 @@ class FusedGemmEpilogueOp
pir::Value y_,
pir::Value bias_,
pir::AttributeMap attributes);
void Verify();
void VerifySig();
pir::Value x() { return operand_source(0); }
pir::Value y() { return operand_source(1); }
pir::Value bias() { return operand_source(2); }
Expand Down Expand Up @@ -141,7 +141,7 @@ class FusedGemmEpilogueGradOp
pir::Value reserve_space_,
pir::Value out_grad_,
pir::AttributeMap attributes);
void Verify();
void VerifySig();
pir::Value x() { return operand_source(0); }
pir::Value y() { return operand_source(1); }
pir::Value reserve_space() { return operand_source(2); }
Expand Down Expand Up @@ -169,7 +169,7 @@ class SplitGradOp : public pir::Op<SplitGradOp, OpYamlInfoInterface> {
pir::Value out_grad_,
pir::Value axis_);

void Verify();
void VerifySig();
pir::Value out_grad() { return operand_source(0); }
pir::Value axis() { return operand_source(1); }
pir::OpResult x_grad() { return result(0); }
Expand Down
16 changes: 8 additions & 8 deletions paddle/pir/core/builtin_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ void ModuleOp::Destroy() {
}
}

void ModuleOp::Verify() const {
void ModuleOp::VerifySig() const {
VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp.";
// Verify inputs:
IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0.");
Expand Down Expand Up @@ -118,7 +118,7 @@ void GetParameterOp::PassStopGradients(OperationArgument &argument) {
pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient));
}

void GetParameterOp::Verify() const {
void GetParameterOp::VerifySig() const {
VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp.";
// Verify inputs:
IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0.");
Expand All @@ -144,7 +144,7 @@ void SetParameterOp::Build(Builder &builder, // NOLINT
argument.AddAttribute(attributes_name[0],
pir::StrAttribute::get(builder.ir_context(), name));
}
void SetParameterOp::Verify() const {
void SetParameterOp::VerifySig() const {
VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp.";
// Verify inputs:
IR_ENFORCE(num_operands() == 1, "The size of outputs must be equal to 1.");
Expand All @@ -170,7 +170,7 @@ void ShadowOutputOp::Build(Builder &builder, // NOLINT
argument.AddAttribute(attributes_name[0],
pir::StrAttribute::get(builder.ir_context(), name));
}
void ShadowOutputOp::Verify() const {
void ShadowOutputOp::VerifySig() const {
VLOG(4) << "Verifying inputs, outputs and attributes for: ShadowOutputOp.";
// Verify inputs:
IR_ENFORCE(num_operands() == 1, "The size of outputs must be equal to 1.");
Expand Down Expand Up @@ -198,7 +198,7 @@ void CombineOp::Build(Builder &builder,
PassStopGradientsDefaultly(argument);
}

void CombineOp::Verify() const {
void CombineOp::VerifySig() const {
// outputs.size() == 1
IR_ENFORCE(num_results() == 1u, "The size of outputs must be equal to 1.");

Expand Down Expand Up @@ -260,7 +260,7 @@ void SliceOp::PassStopGradients(OperationArgument &argument, int index) {
pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient));
}

void SliceOp::Verify() const {
void SliceOp::VerifySig() const {
// inputs.size() == 1
auto input_size = num_operands();
IR_ENFORCE(
Expand Down Expand Up @@ -364,7 +364,7 @@ void SplitOp::PassStopGradients(OperationArgument &argument) {
pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient));
}

void SplitOp::Verify() const {
void SplitOp::VerifySig() const {
// inputs.size() == 1
IR_ENFORCE(num_operands() == 1u, "The size of inputs must be equal to 1.");

Expand Down Expand Up @@ -393,7 +393,7 @@ void ConstantOp::Build(Builder &builder,
argument.output_types.push_back(output_type);
}

void ConstantOp::Verify() const {
void ConstantOp::VerifySig() const {
IR_ENFORCE(num_operands() == 0, "The size of inputs must be equal to 0.");
IR_ENFORCE(num_results() == 1, "The size of outputs must be equal to 1.");
IR_ENFORCE(attributes().count("value") > 0, "must has value attribute");
Expand Down
Loading

0 comments on commit 5348600

Please sign in to comment.