Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pir infermeta func support inferspmd. #62659

Merged
merged 1 commit into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_tools.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/dialect/distributed/ir/dist_tools.h"
#include "paddle/common/enforce.h"

namespace paddle {
namespace dialect {

bool HasDistInput(const std::vector<pir::Value>& inputs) {
for (auto value : inputs) {
if (value.type().isa<DistDenseTensorType>()) {
return true;
}
}
return false;
}

bool AllInputAreDist(const std::vector<pir::Value>& inputs) {
for (auto value : inputs) {
if (!value.type().isa<DistDenseTensorType>()) {
return false;
}
}
return true;
}

phi::distributed::DistMetaTensor CvtToDistMetaTensor(DistDenseTensorType type) {
auto pir_attr = type.tensor_dist_attr();
phi::distributed::TensorDistAttr phi_attr;
phi_attr.set_process_mesh(pir_attr.process_mesh_attr().process_mesh());
phi_attr.set_dims_mapping(pir_attr.dims_mapping());
phi_attr.set_partial_status(pir_attr.partial_status());
return phi::distributed::DistMetaTensor(type.global_ddim(), phi_attr);
}

TensorDistAttribute CvtToPirDistAttr(
const phi::distributed::ArgDistAttr& dist_attr) {
auto& attr = PADDLE_GET_CONST(phi::distributed::TensorDistAttr, dist_attr);
return TensorDistAttribute::get(pir::IrContext::Instance(),
attr.process_mesh(),
attr.dims_mapping(),
attr.partial_status());
}

} // namespace dialect
} // namespace paddle
31 changes: 31 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_tools.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/common/ddim.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"
#include "paddle/pir/include/core/value.h"

namespace paddle {
namespace dialect {

bool HasDistInput(const std::vector<pir::Value>& inputs);
bool AllInputAreDist(const std::vector<pir::Value>& inputs);
phi::distributed::DistMetaTensor CvtToDistMetaTensor(DistDenseTensorType type);
TensorDistAttribute CvtToPirDistAttr(
const phi::distributed::ArgDistAttr& dist_attr);

} // namespace dialect
} // namespace paddle
24 changes: 22 additions & 2 deletions paddle/fluid/pir/dialect/distributed/ir/dist_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,29 @@ DistDenseTensorType DistDenseTensorType::get(
pir::IrContext* ctx,
pir::DenseTensorType dense_tensor_type,
TensorDistAttribute tensor_dist_attr,
const common::DDim& global_ddim) {
return Base::get(ctx, dense_tensor_type, tensor_dist_attr, global_ddim);
const common::DDim& local_ddim) {
return Base::get(ctx, dense_tensor_type, tensor_dist_attr, local_ddim);
}

common::DDim InferLocalDDim(const common::DDim& global_ddim,
TensorDistAttribute dist_attr) {
auto& mesh_dim = dist_attr.process_mesh_attr().shape();
auto& dim_mapping = dist_attr.dims_mapping();
PADDLE_ENFORCE_EQ(
global_ddim.size(),
dim_mapping.size(),
::common::errors::PreconditionNotMet(
"The global ddim size must equal to dim_mapping's size!"));
common::DDim local_ddim(global_ddim);
for (size_t i = 0; i < dim_mapping.size(); ++i) {
if (dim_mapping[i] != -1) {
auto dim_size = mesh_dim.at(dim_mapping[i]);
local_ddim[i] = (global_ddim[i] + dim_size - 1) / dim_size;
}
}
return local_ddim;
}

} // namespace dialect
} // namespace paddle

Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ namespace dialect {

class DistDenseTensorTypeStorage;

common::DDim InferLocalDDim(const common::DDim& global_ddim,
TensorDistAttribute dist_attr);
class DistDenseTensorType
: public pir::Type::TypeBase<DistDenseTensorType,
pir::Type,
Expand Down Expand Up @@ -57,6 +59,13 @@ class DistDenseTensorType
pir::DenseTensorType dense_tensor_type,
TensorDistAttribute tensor_dist_attr,
const common::DDim& local_ddim);
static DistDenseTensorType get(pir::IrContext* ctx,
pir::DenseTensorType dense_tensor_type,
TensorDistAttribute tensor_dist_attr) {
auto local_ddim =
InferLocalDDim(dense_tensor_type.dims(), tensor_dist_attr);
return get(ctx, dense_tensor_type, tensor_dist_attr, local_ddim);
}
};

} // namespace dialect
Expand Down
121 changes: 40 additions & 81 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,15 @@
from infer_symbolic_shape_gen import gen_infer_symbolic_shape_str
from op_all_func_gen import gen_op_all_func
from op_build_gen import gen_build_func_str, gen_build_func_str_by_invoke
from op_infermeta_gen import (
gen_infermeta_by_invoke_func_str,
gen_infermeta_func_str,
)
from op_interface_gen import (
gen_exclusive_interface_str,
gen_op_infer_meta_str,
gen_op_vjp_str,
)
from op_kerneltype_gen import gen_kernel_type_for_var_str
from op_verify_gen import gen_verify_func_str
from ops_onednn_extra_parser import parse_data_format_tensors, parse_extra_args
from parse_kernel_key_gen import gen_parse_kernel_key_str
from utils import to_pascal_case
from vjp_interface_black_list import vjp_interface_black_list

# import from paddle/fluid/primitive/code_gen/gen.py
Expand Down Expand Up @@ -110,6 +106,8 @@
#include "paddle/phi/core/infermeta_utils.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/infermeta/spmd_rules/rules.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_tools.h"
#endif
{only_pd_op_header_files}

Expand Down Expand Up @@ -1082,14 +1080,6 @@ def get_phi_dtype_name(self, name):
return name


def to_pascal_case(s):
words = s.split("_")
if s[-1] == "_":
return "".join([word.capitalize() for word in words]) + "_"
else:
return "".join([word.capitalize() for word in words]) + ""


def get_input_grad_semantic(op_info, op_info_items):
input_grad_semantics = []
num_inputs = len(op_info.input_name_list)
Expand Down Expand Up @@ -1305,13 +1295,6 @@ def AutoCodeGen(
op_info, op_info_items
)

interface_list, declare_list, impl_list = gen_op_all_func(
args, op_info, op_info_items
)
op_interfaces += interface_list
exclusive_interface_str += '\n' + '\n'.join(declare_list)
ops_defined_list += impl_list

if dialect_name == "pd_op" or dialect_name == "onednn_op":
op_interfaces += ["paddle::dialect::GetKernelTypeForVarInterface"]

Expand Down Expand Up @@ -1384,10 +1367,6 @@ def AutoCodeGen(
# =================================== #
# gen interface list str #
# =================================== #
op_interfaces_str = ""
if len(op_interfaces) > 0:
op_interfaces_str = "," + ",".join(op_interfaces)

if len(func_list) == 1:
op_class_name = to_pascal_case(op_name) + "Op"
op_dialect_name = dialect_name + "." + op_name
Expand All @@ -1413,6 +1392,28 @@ def AutoCodeGen(
kernel_func_name
]

op_info.class_name = op_class_name
op_info.kernel_input_type_list = op_input_type_list
op_info.kernel_output_type_list = op_output_type_list

(
all_interface_list,
exclusive_declare_list,
exclusive_impl_list,
) = gen_op_all_func(args, op_info, op_info_items)
all_interface_list += op_interfaces

all_interface_str = ""
if len(all_interface_list) > 0:
all_interface_str = "," + ",".join(all_interface_list)

all_declare_str = (
exclusive_interface_str
+ '\n'
+ '\n'.join(exclusive_declare_list)
)
ops_defined_list += exclusive_impl_list

# =================================== #
# gen Build methods str #
# =================================== #
Expand All @@ -1432,13 +1433,16 @@ def AutoCodeGen(
)

parse_kernel_key_str = ""
if "paddle::dialect::ParseKernelKeyInterface" in op_interfaces:
if (
"paddle::dialect::ParseKernelKeyInterface"
in all_interface_list
):
parse_kernel_key_str = parse_kernel_key_template

infer_symbolic_shape_str = ""
if (
"paddle::dialect::InferSymbolicShapeInterface"
in op_interfaces
in all_interface_list
):
infer_symbolic_shape_str = infer_symbolic_shape_template

Expand Down Expand Up @@ -1568,15 +1572,15 @@ def AutoCodeGen(
TEST_API=TEST_API,
op_name=op_class_name,
dialect_op_name=op_dialect_name,
interfaces=op_interfaces_str,
interfaces=all_interface_str,
traits=op_traits_str,
attribute_declare=op_0_attribute_declare_str,
attribute_num=0,
build_args=build_args_with_muta_attr_not_input_for_declare,
build_mutable_attr_is_input=build_mutable_attr_is_input,
build_attr_num_over_1=build_attr_num_over_1,
build_mutable_attr_is_input_attr_num_over_1=build_mutable_attr_is_input_attr_num_over_1,
exclusive_interface=exclusive_interface_str,
exclusive_interface=all_declare_str,
get_kernel_type_for_var_declare=get_kernel_type_for_var_declare_str,
parse_kernel_key_declare=parse_kernel_key_str,
infer_symbolic_shape_declare=infer_symbolic_shape_str,
Expand All @@ -1587,7 +1591,7 @@ def AutoCodeGen(
TEST_API=TEST_API,
op_name=op_class_name,
dialect_op_name=op_dialect_name,
interfaces=op_interfaces_str,
interfaces=all_interface_str,
traits=op_traits_str,
attribute_declare=op_n_attribute_declare_str.format(
attribute_num=len(
Expand All @@ -1599,7 +1603,7 @@ def AutoCodeGen(
build_mutable_attr_is_input=build_mutable_attr_is_input,
build_attr_num_over_1=build_attr_num_over_1,
build_mutable_attr_is_input_attr_num_over_1=build_mutable_attr_is_input_attr_num_over_1,
exclusive_interface=exclusive_interface_str,
exclusive_interface=all_declare_str,
get_kernel_type_for_var_declare=get_kernel_type_for_var_declare_str,
parse_kernel_key_declare=parse_kernel_key_str,
infer_symbolic_shape_declare=infer_symbolic_shape_str,
Expand Down Expand Up @@ -1848,7 +1852,10 @@ def AutoCodeGen(

# generate op ParseKernelKeyInterface function str
parse_kernel_key_define_str = ''
if "paddle::dialect::ParseKernelKeyInterface" in op_interfaces:
if (
"paddle::dialect::ParseKernelKeyInterface"
in all_interface_list
):
parse_kernel_key_define_str = gen_parse_kernel_key_str(
op_class_name
)
Expand All @@ -1857,7 +1864,7 @@ def AutoCodeGen(
infer_symbolic_shape_define_str = ''
if (
"paddle::dialect::InferSymbolicShapeInterface"
in op_interfaces
in all_interface_list
):
infer_symbolic_shape_define_str = (
gen_infer_symbolic_shape_str(op_class_name)
Expand All @@ -1867,7 +1874,7 @@ def AutoCodeGen(
infer_symbolic_shape_define_str = ''
if (
"paddle::dialect::InferSymbolicShapeInterface"
in op_interfaces
in all_interface_list
):
infer_symbolic_shape_define_str = (
gen_infer_symbolic_shape_str(op_class_name)
Expand All @@ -1885,52 +1892,6 @@ def AutoCodeGen(
)
)

op_infer_meta_str = gen_op_infer_meta_str(
op_info, op_class_name, op_info_items
)

op_infer_meta_from_type_str = ""
if op_infer_meta_map is not None:
muta_attr_is_input = (
True
if len(op_mutable_attribute_name_list) > 0
else False
)
op_infer_meta_from_type_str = gen_infermeta_func_str(
op_class_name,
op_input_name_list,
op_input_type_list,
op_input_optional_list,
op_mutable_attribute_name_list,
op_mutable_attribute_type_list,
op_output_name_list,
op_output_type_list,
op_output_size_list,
op_output_optional_list,
op_infer_meta_map,
op_inplace_map,
op_attribute_name_list,
op_attribute_type_list,
op_attribute_build_arg_type_list,
op_non_mutable_attribute_name_list,
op_non_mutable_attribute_type_list,
op_non_mutable_attribute_build_arg_type_list,
muta_attr_is_input,
attr_args_is_map=True,
)

if (op_invoke_map is not None) and (
op_invoke_map['func'] in op_info_items
):
op_invoke_class_name = (
to_pascal_case(op_invoke_map['func']) + "Op"
)
op_infer_meta_from_type_str = (
gen_infermeta_by_invoke_func_str(
op_class_name, op_invoke_class_name
)
)

# =================================== #
# gen Vjp func str #
# =================================== #
Expand Down Expand Up @@ -1971,8 +1932,6 @@ def AutoCodeGen(
)

ops_defined_list.append(op_verify_str)
ops_defined_list.append(op_infer_meta_str)
ops_defined_list.append(op_infer_meta_from_type_str)
ops_defined_list.append(op_get_kernel_type_for_var_str)
ops_defined_list.append(parse_kernel_key_define_str)
ops_defined_list.append(infer_symbolic_shape_define_str)
Expand Down
Loading