Skip to content

Commit

Permalink
[IR] Refine PhiKernelOp attributes name and delete some unused code2 (#…
Browse files Browse the repository at this point in the history
…54944)

* refine code

* add some interface for phi kernel op

* fix compile bug

* delete unused code

* support code

* fix bug

* refine code

* delete unused code

* fix compile bug

* fix compile bug

* delete unused code

* add elementwise add op

* fix compile bug

* refine code

* fix compile bug

* add ut for attribute member function

* delete unused code

* refine ut
  • Loading branch information
zhangbo9674 authored Jun 29, 2023
1 parent b8e4d74 commit a7419ff
Show file tree
Hide file tree
Showing 22 changed files with 192 additions and 109 deletions.
8 changes: 4 additions & 4 deletions paddle/fluid/ir/dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ set(PD_DIALECT_SOURCE_DIR "${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect")
set(PD_DIALECT_BINARY_DIR "${PADDLE_BINARY_DIR}/paddle/fluid/ir/dialect")

# Generate pd_dialect files defining op using op_gen_file
set(op_gen_file ${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/op_gen.py)
set(op_gen_file
${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/op_generator/op_gen.py)
set(op_compat_yaml_file ${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml)
set(op_forward_yaml_file1
${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/ops.parsed.yaml
Expand All @@ -17,10 +18,9 @@ set(op_backward_yaml_file2
${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_backward_ops.parsed.yaml
)
set(op_yaml_file3 ${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/pd_op.yaml)
set(op_yaml_file4
${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/pd_legacy_op.yaml)

set(op_yaml_files
${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${op_yaml_file3},${op_yaml_file4}
${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${op_yaml_file3}
)
set(op_namespace paddle,dialect)
set(dialect_name pd)
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/ir/dialect/kernel_attribute_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#pragma once

#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/attribute_base.h"
#include "paddle/ir/core/utils.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_factory.h"
Expand Down
22 changes: 5 additions & 17 deletions paddle/fluid/ir/dialect/kernel_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,26 +44,14 @@ void PhiKernelOp::Verify() {
"Type of attribute: kernel_key is not right."));
}

const std::string PhiKernelOp::op_name() {
return operation()
->attributes()
.at("op_name")
.dyn_cast<ir::StrAttribute>()
.data();
std::string PhiKernelOp::op_name() {
return attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
}
const std::string PhiKernelOp::kernel_name() {
return operation()
->attributes()
.at("kernel_name")
.dyn_cast<ir::StrAttribute>()
.data();
std::string PhiKernelOp::kernel_name() {
return attributes().at("kernel_name").dyn_cast<ir::StrAttribute>().data();
}
phi::KernelKey PhiKernelOp::kernel_key() {
return operation()
->attributes()
.at("kernel_key")
.dyn_cast<KernelAttribute>()
.data();
return attributes().at("kernel_key").dyn_cast<KernelAttribute>().data();
}

} // namespace dialect
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/ir/dialect/kernel_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ class PhiKernelOp : public ir::Op<PhiKernelOp> {
static const char *name() { return "phi.kernel"; }
static constexpr uint32_t attributes_num = 3;
static const char *attributes_name[attributes_num];
const std::string op_name();
const std::string kernel_name();
std::string op_name();
std::string kernel_name();
phi::KernelKey kernel_key();
void Verify();
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import os

import yaml
from op_interface_gen import gen_exclusive_interface_str, gen_op_infer_meta_str
from op_member_func_gen import gen_op_get_inputs_outputs_str
from op_verify_gen import gen_verify_func_str

# =====================================
Expand All @@ -29,7 +31,7 @@
#undef GET_OP_LIST
{op_declare}
#else
// This file is generated by "paddle/fluid/ir/dialect/op_gen.py"
// This file is generated by "paddle/fluid/ir/dialect/op_generator/op_gen.py"
#include <vector>
Expand Down Expand Up @@ -78,17 +80,12 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{
"static const char *attributes_name[{attribute_num}];"
)

OP_GET_INPUT_TEMPLATE = """ ir::Value {input_name}() {{ return operand({input_index}); }}
"""
OP_GET_OUTPUT_TEMPLATE = """ ir::OpResult {output_name}() {{ return result({output_index}); }}
"""

# =====================================
# String Template for cc file code gen
# =====================================
CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_gen.py"
CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_generator/op_gen.py"
#include "{h_file}"
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/ir/core/builtin_attribute.h"
Expand Down Expand Up @@ -142,12 +139,6 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{
{build_outputs}
}}
"""
OP_INFER_SHAPE_TEMPLATE = """
void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{
auto fn = PD_INFER_META(phi::{infer_meta_func});
fn(infer_meta);
}}
"""

DEFINE_OP_TYPE_ID = """
IR_DEFINE_EXPLICIT_TYPE_ID({op_name})
Expand Down Expand Up @@ -1217,12 +1208,10 @@ def OpGenerator(
op_interfaces = ["OpYamlInfoInterface"]
op_traits = []

exclusive_interface_str = ""
if op_info.infer_meta_func:
op_interfaces += ["InferMetaInterface"]
exclusive_interface_str += (
" static void InferMeta( phi::InferMetaContext *infer_meta );"
)

exclusive_interface_str = gen_exclusive_interface_str(op_info)

# If op has inplace info, we will generate inplace op and non-inplace op.
for op_name in op_info.op_phi_name:
Expand All @@ -1242,22 +1231,11 @@ def OpGenerator(
# =================================== #
# gen get input/output methods str #
# =================================== #
op_get_inputs_outputs_str = ""
for idx in range(len(op_input_name_list)):
op_get_inputs_outputs_str += OP_GET_INPUT_TEMPLATE.format(
input_name=op_input_name_list[idx],
input_index=idx,
)
for idx in range(len(op_mutable_attribute_name_list)):
op_get_inputs_outputs_str += OP_GET_INPUT_TEMPLATE.format(
input_name=op_mutable_attribute_name_list[idx],
input_index=idx + len(op_input_name_list),
)
for idx in range(len(op_output_name_list)):
op_get_inputs_outputs_str += OP_GET_OUTPUT_TEMPLATE.format(
output_name=op_output_name_list[idx],
output_index=idx,
)
op_get_inputs_outputs_str = gen_op_get_inputs_outputs_str(
op_input_name_list,
op_mutable_attribute_name_list,
op_output_name_list,
)

# =================================== #
# gen Build methods str #
Expand Down Expand Up @@ -1472,12 +1450,7 @@ def OpGenerator(
op_output_optional_list,
)

op_infer_meta_str = ""
if op_info.infer_meta_func:
op_infer_meta_str = OP_INFER_SHAPE_TEMPLATE.format(
op_name=op_class_name,
infer_meta_func=op_info.infer_meta_func,
)
op_infer_meta_str = gen_op_infer_meta_str(op_info, op_class_name)

ops_name_list.append(op_class_name)
ops_declare_list.append(op_declare_str)
Expand Down
41 changes: 41 additions & 0 deletions paddle/fluid/ir/dialect/op_generator/op_interface_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# generator interfaces

OP_INFER_SHAPE_TEMPLATE = """
void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{
auto fn = PD_INFER_META(phi::{infer_meta_func});
fn(infer_meta);
}}
"""


def gen_op_infer_meta_str(op_info, op_class_name):
op_infer_meta_str = ""
if op_info.infer_meta_func:
op_infer_meta_str = OP_INFER_SHAPE_TEMPLATE.format(
op_name=op_class_name,
infer_meta_func=op_info.infer_meta_func,
)
return op_infer_meta_str


def gen_exclusive_interface_str(op_info):
exclusive_interface_str = ""
if op_info.infer_meta_func:
exclusive_interface_str += (
" static void InferMeta( phi::InferMetaContext *infer_meta );"
)
return exclusive_interface_str
55 changes: 55 additions & 0 deletions paddle/fluid/ir/dialect/op_generator/op_member_func_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# generator op member function

OP_GET_INPUT_TEMPLATE = """ ir::Value {input_name}() {{ return operand({input_index}); }}
"""
OP_GET_OUTPUT_TEMPLATE = """ ir::OpResult {output_name}() {{ return result({output_index}); }}
"""
OP_GET_ATTRIBUTE_TEMPLATE = """ ir::Attribute attribute(const std::string &name) {{
PADDLE_ENFORCE(attributes().count(name) > 0,
phi::errors::PreconditionNotMet("Attribute is not exist."));
return attributes().at(name);
}}
template <typename T>
T attribute(const std::string &name) {{
PADDLE_ENFORCE(attributes().count(name) > 0 && attributes().at(name).isa<T>(),
phi::errors::PreconditionNotMet("Attribute is not right."));
return attributes().at(name).dyn_cast<T>();
}}
"""


def gen_op_get_inputs_outputs_str(
op_input_name_list, op_mutable_attribute_name_list, op_output_name_list
):
op_get_inputs_outputs_str = ""
for idx in range(len(op_input_name_list)):
op_get_inputs_outputs_str += OP_GET_INPUT_TEMPLATE.format(
input_name=op_input_name_list[idx],
input_index=idx,
)
for idx in range(len(op_mutable_attribute_name_list)):
op_get_inputs_outputs_str += OP_GET_INPUT_TEMPLATE.format(
input_name=op_mutable_attribute_name_list[idx],
input_index=idx + len(op_input_name_list),
)
for idx in range(len(op_output_name_list)):
op_get_inputs_outputs_str += OP_GET_OUTPUT_TEMPLATE.format(
output_name=op_output_name_list[idx],
output_index=idx,
)
op_get_inputs_outputs_str += OP_GET_ATTRIBUTE_TEMPLATE
return op_get_inputs_outputs_str
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/ir/dialect/pd_attribute_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#pragma once

#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/attribute_base.h"
#include "paddle/ir/core/utils.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
Expand Down
32 changes: 0 additions & 32 deletions paddle/fluid/ir/dialect/pd_legacy_op.yaml

This file was deleted.

1 change: 1 addition & 0 deletions paddle/fluid/ir/dialect/pd_type_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <type_traits>

#include "paddle/ir/core/type.h"
#include "paddle/ir/core/type_base.h"
#include "paddle/ir/core/utils.h"
#include "paddle/phi/core/tensor_meta.h"

Expand Down
12 changes: 12 additions & 0 deletions paddle/ir/core/attribute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,20 @@
// limitations under the License.

#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/attribute_base.h"
#include "paddle/ir/core/dialect.h"

namespace ir {
IrContext *Attribute::ir_context() const { return dialect().ir_context(); }

TypeId Attribute::type_id() { return storage_->abstract_attribute().type_id(); }

const AbstractAttribute &Attribute::abstract_attribute() {
return storage_->abstract_attribute();
}

const Dialect &Attribute::dialect() const {
return storage_->abstract_attribute().dialect();
}

} // namespace ir
17 changes: 9 additions & 8 deletions paddle/ir/core/attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@

#pragma once

#include "paddle/ir/core/attribute_base.h"
#include "paddle/ir/core/cast_utils.h"
#include "paddle/ir/core/type_id.h"

namespace ir {
class AttributeStorage;
class AbstractAttribute;
class IrContext;
class Dialect;

///
/// \brief Unified interface of the Attribute class. Derivation of all Attribute
/// classes only derives interfaces, not members.
Expand Down Expand Up @@ -46,17 +51,13 @@ class IR_API Attribute {
///
/// \brief Some Attribute attribute acquisition interfaces.
///
TypeId type_id() { return storage_->abstract_attribute().type_id(); }
TypeId type_id();

const AbstractAttribute &abstract_attribute() {
return storage_->abstract_attribute();
}
const AbstractAttribute &abstract_attribute();

const Storage *storage() const { return storage_; }

const Dialect &dialect() const {
return storage_->abstract_attribute().dialect();
}
const Dialect &dialect() const;

IrContext *ir_context() const;

Expand Down
Loading

0 comments on commit a7419ff

Please sign in to comment.