Skip to content

Commit

Permalink
PIR infermeta functions now support InferSpmd.
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang committed Mar 14, 2024
1 parent b92fd61 commit 86719a2
Show file tree
Hide file tree
Showing 19 changed files with 577 additions and 352 deletions.
58 changes: 58 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_tools.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/dialect/distributed/ir/dist_tools.h"
#include "paddle/common/enforce.h"

namespace paddle {
namespace dialect {

// Returns true iff at least one input value carries a distributed
// dense-tensor type.
bool HasDistInput(const std::vector<pir::Value>& inputs) {
  return std::any_of(inputs.begin(), inputs.end(), [](const pir::Value& v) {
    return v.type().isa<DistDenseTensorType>();
  });
}

// Returns true iff every input value carries a distributed dense-tensor
// type (vacuously true for an empty input list).
bool AllInputAreDist(const std::vector<pir::Value>& inputs) {
  return std::all_of(inputs.begin(), inputs.end(), [](const pir::Value& v) {
    return v.type().isa<DistDenseTensorType>();
  });
}

// Translates the dist attribute of a PIR DistDenseTensorType into its phi
// counterpart and wraps it, together with the global dims, into a
// DistMetaTensor suitable for SPMD inference.
phi::distributed::DistMetaTensor CvtToDistMetaTensor(DistDenseTensorType type) {
  const auto src_attr = type.tensor_dist_attr();
  phi::distributed::TensorDistAttr dst_attr;
  dst_attr.set_process_mesh(src_attr.process_mesh_attr().process_mesh());
  dst_attr.set_dims_mapping(src_attr.dims_mapping());
  dst_attr.set_partial_status(src_attr.partial_status());
  return phi::distributed::DistMetaTensor(type.global_ddim(), dst_attr);
}

// Converts a phi ArgDistAttr back into a PIR TensorDistAttribute.
// NOTE(review): ArgDistAttr is a variant; PADDLE_GET_CONST enforces at
// runtime that it holds the single-TensorDistAttr alternative.
TensorDistAttribute CvtToPirDistAttr(
    const phi::distributed::ArgDistAttr& dist_attr) {
  const auto& tensor_attr =
      PADDLE_GET_CONST(phi::distributed::TensorDistAttr, dist_attr);
  return TensorDistAttribute::get(pir::IrContext::Instance(),
                                  tensor_attr.process_mesh(),
                                  tensor_attr.dims_mapping(),
                                  tensor_attr.partial_status());
}

} // namespace dialect
} // namespace paddle
31 changes: 31 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_tools.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/common/ddim.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"
#include "paddle/pir/include/core/value.h"

namespace paddle {
namespace dialect {

/// Returns true if any value in `inputs` has DistDenseTensorType.
bool HasDistInput(const std::vector<pir::Value>& inputs);
/// Returns true if every value in `inputs` has DistDenseTensorType.
bool AllInputAreDist(const std::vector<pir::Value>& inputs);
/// Converts a PIR distributed tensor type into a phi DistMetaTensor
/// (global dims + phi TensorDistAttr) for SPMD inference.
phi::distributed::DistMetaTensor CvtToDistMetaTensor(DistDenseTensorType type);
/// Converts a phi ArgDistAttr (must hold the TensorDistAttr alternative)
/// back into a PIR TensorDistAttribute.
TensorDistAttribute CvtToPirDistAttr(
    const phi::distributed::ArgDistAttr& dist_attr);

} // namespace dialect
} // namespace paddle
24 changes: 22 additions & 2 deletions paddle/fluid/pir/dialect/distributed/ir/dist_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,29 @@ DistDenseTensorType DistDenseTensorType::get(
pir::IrContext* ctx,
pir::DenseTensorType dense_tensor_type,
TensorDistAttribute tensor_dist_attr,
const common::DDim& global_ddim) {
return Base::get(ctx, dense_tensor_type, tensor_dist_attr, global_ddim);
const common::DDim& local_ddim) {
return Base::get(ctx, dense_tensor_type, tensor_dist_attr, local_ddim);
}

// Derives the per-rank (local) shape from the global shape: each tensor
// dimension mapped onto a process-mesh axis is split across that axis,
// while a mapping of -1 means the dimension is replicated and keeps its
// global extent.
common::DDim InferLocalDDim(const common::DDim& global_ddim,
                            TensorDistAttribute dist_attr) {
  const auto& mesh_shape = dist_attr.process_mesh_attr().shape();
  const auto& mapping = dist_attr.dims_mapping();
  PADDLE_ENFORCE_EQ(
      global_ddim.size(),
      mapping.size(),
      ::common::errors::PreconditionNotMet(
          "The global ddim size must equal to dim_mapping's size!"));
  common::DDim local_ddim(global_ddim);
  for (size_t axis = 0; axis < mapping.size(); ++axis) {
    if (mapping[axis] == -1) continue;  // replicated dimension
    auto shard_num = mesh_shape.at(mapping[axis]);
    // Ceiling division: the last shard may hold fewer elements.
    local_ddim[axis] = (global_ddim[axis] + shard_num - 1) / shard_num;
  }
  return local_ddim;
}

} // namespace dialect
} // namespace paddle

Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ namespace dialect {

class DistDenseTensorTypeStorage;

common::DDim InferLocalDDim(const common::DDim& global_ddim,
TensorDistAttribute dist_attr);
class DistDenseTensorType
: public pir::Type::TypeBase<DistDenseTensorType,
pir::Type,
Expand Down Expand Up @@ -57,6 +59,13 @@ class DistDenseTensorType
pir::DenseTensorType dense_tensor_type,
TensorDistAttribute tensor_dist_attr,
const common::DDim& local_ddim);
/// Convenience overload: derives the local ddim from the dense tensor's
/// dims (treated as the global shape) via InferLocalDDim, then delegates
/// to the four-argument get().
static DistDenseTensorType get(pir::IrContext* ctx,
                               pir::DenseTensorType dense_tensor_type,
                               TensorDistAttribute tensor_dist_attr) {
  auto local_ddim =
      InferLocalDDim(dense_tensor_type.dims(), tensor_dist_attr);
  return get(ctx, dense_tensor_type, tensor_dist_attr, local_ddim);
}
};

} // namespace dialect
Expand Down
121 changes: 40 additions & 81 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,15 @@
from infer_symbolic_shape_gen import gen_infer_symbolic_shape_str
from op_all_func_gen import gen_op_all_func
from op_build_gen import gen_build_func_str, gen_build_func_str_by_invoke
from op_infermeta_gen import (
gen_infermeta_by_invoke_func_str,
gen_infermeta_func_str,
)
from op_interface_gen import (
gen_exclusive_interface_str,
gen_op_infer_meta_str,
gen_op_vjp_str,
)
from op_kerneltype_gen import gen_kernel_type_for_var_str
from op_verify_gen import gen_verify_func_str
from ops_onednn_extra_parser import parse_data_format_tensors, parse_extra_args
from parse_kernel_key_gen import gen_parse_kernel_key_str
from utils import to_pascal_case
from vjp_interface_black_list import vjp_interface_black_list

# import from paddle/fluid/primitive/code_gen/gen.py
Expand Down Expand Up @@ -110,6 +106,8 @@
#include "paddle/phi/core/infermeta_utils.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/infermeta/spmd_rules/rules.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_tools.h"
#endif
{only_pd_op_header_files}
Expand Down Expand Up @@ -1082,14 +1080,6 @@ def get_phi_dtype_name(self, name):
return name


def to_pascal_case(s):
    """Convert a snake_case identifier to PascalCase.

    A trailing underscore is preserved (e.g. ``"softmax_"`` ->
    ``"Softmax_"``), matching the naming convention for inplace ops.

    Fix: the previous implementation indexed ``s[-1]`` and raised
    IndexError on an empty string; the empty string now returns "".
    """
    pascal = "".join(word.capitalize() for word in s.split("_"))
    return pascal + "_" if s.endswith("_") else pascal


def get_input_grad_semantic(op_info, op_info_items):
input_grad_semantics = []
num_inputs = len(op_info.input_name_list)
Expand Down Expand Up @@ -1305,13 +1295,6 @@ def AutoCodeGen(
op_info, op_info_items
)

interface_list, declare_list, impl_list = gen_op_all_func(
args, op_info, op_info_items
)
op_interfaces += interface_list
exclusive_interface_str += '\n' + '\n'.join(declare_list)
ops_defined_list += impl_list

if dialect_name == "pd_op" or dialect_name == "onednn_op":
op_interfaces += ["paddle::dialect::GetKernelTypeForVarInterface"]

Expand Down Expand Up @@ -1384,10 +1367,6 @@ def AutoCodeGen(
# =================================== #
# gen interface list str #
# =================================== #
op_interfaces_str = ""
if len(op_interfaces) > 0:
op_interfaces_str = "," + ",".join(op_interfaces)

if len(func_list) == 1:
op_class_name = to_pascal_case(op_name) + "Op"
op_dialect_name = dialect_name + "." + op_name
Expand All @@ -1413,6 +1392,28 @@ def AutoCodeGen(
kernel_func_name
]

op_info.class_name = op_class_name
op_info.kernel_input_type_list = op_input_type_list
op_info.kernel_output_type_list = op_output_type_list

(
all_interface_list,
exclusive_declare_list,
exclusive_impl_list,
) = gen_op_all_func(args, op_info, op_info_items)
all_interface_list += op_interfaces

all_interface_str = ""
if len(all_interface_list) > 0:
all_interface_str = "," + ",".join(all_interface_list)

all_declare_str = (
exclusive_interface_str
+ '\n'
+ '\n'.join(exclusive_declare_list)
)
ops_defined_list += exclusive_impl_list

# =================================== #
# gen Build methods str #
# =================================== #
Expand All @@ -1432,13 +1433,16 @@ def AutoCodeGen(
)

parse_kernel_key_str = ""
if "paddle::dialect::ParseKernelKeyInterface" in op_interfaces:
if (
"paddle::dialect::ParseKernelKeyInterface"
in all_interface_list
):
parse_kernel_key_str = parse_kernel_key_template

infer_symbolic_shape_str = ""
if (
"paddle::dialect::InferSymbolicShapeInterface"
in op_interfaces
in all_interface_list
):
infer_symbolic_shape_str = infer_symbolic_shape_template

Expand Down Expand Up @@ -1568,15 +1572,15 @@ def AutoCodeGen(
TEST_API=TEST_API,
op_name=op_class_name,
dialect_op_name=op_dialect_name,
interfaces=op_interfaces_str,
interfaces=all_interface_str,
traits=op_traits_str,
attribute_declare=op_0_attribute_declare_str,
attribute_num=0,
build_args=build_args_with_muta_attr_not_input_for_declare,
build_mutable_attr_is_input=build_mutable_attr_is_input,
build_attr_num_over_1=build_attr_num_over_1,
build_mutable_attr_is_input_attr_num_over_1=build_mutable_attr_is_input_attr_num_over_1,
exclusive_interface=exclusive_interface_str,
exclusive_interface=all_declare_str,
get_kernel_type_for_var_declare=get_kernel_type_for_var_declare_str,
parse_kernel_key_declare=parse_kernel_key_str,
infer_symbolic_shape_declare=infer_symbolic_shape_str,
Expand All @@ -1587,7 +1591,7 @@ def AutoCodeGen(
TEST_API=TEST_API,
op_name=op_class_name,
dialect_op_name=op_dialect_name,
interfaces=op_interfaces_str,
interfaces=all_interface_str,
traits=op_traits_str,
attribute_declare=op_n_attribute_declare_str.format(
attribute_num=len(
Expand All @@ -1599,7 +1603,7 @@ def AutoCodeGen(
build_mutable_attr_is_input=build_mutable_attr_is_input,
build_attr_num_over_1=build_attr_num_over_1,
build_mutable_attr_is_input_attr_num_over_1=build_mutable_attr_is_input_attr_num_over_1,
exclusive_interface=exclusive_interface_str,
exclusive_interface=all_declare_str,
get_kernel_type_for_var_declare=get_kernel_type_for_var_declare_str,
parse_kernel_key_declare=parse_kernel_key_str,
infer_symbolic_shape_declare=infer_symbolic_shape_str,
Expand Down Expand Up @@ -1848,7 +1852,10 @@ def AutoCodeGen(

# generate op ParseKernelKeyInterface function str
parse_kernel_key_define_str = ''
if "paddle::dialect::ParseKernelKeyInterface" in op_interfaces:
if (
"paddle::dialect::ParseKernelKeyInterface"
in all_interface_list
):
parse_kernel_key_define_str = gen_parse_kernel_key_str(
op_class_name
)
Expand All @@ -1857,7 +1864,7 @@ def AutoCodeGen(
infer_symbolic_shape_define_str = ''
if (
"paddle::dialect::InferSymbolicShapeInterface"
in op_interfaces
in all_interface_list
):
infer_symbolic_shape_define_str = (
gen_infer_symbolic_shape_str(op_class_name)
Expand All @@ -1867,7 +1874,7 @@ def AutoCodeGen(
infer_symbolic_shape_define_str = ''
if (
"paddle::dialect::InferSymbolicShapeInterface"
in op_interfaces
in all_interface_list
):
infer_symbolic_shape_define_str = (
gen_infer_symbolic_shape_str(op_class_name)
Expand All @@ -1885,52 +1892,6 @@ def AutoCodeGen(
)
)

op_infer_meta_str = gen_op_infer_meta_str(
op_info, op_class_name, op_info_items
)

op_infer_meta_from_type_str = ""
if op_infer_meta_map is not None:
muta_attr_is_input = (
True
if len(op_mutable_attribute_name_list) > 0
else False
)
op_infer_meta_from_type_str = gen_infermeta_func_str(
op_class_name,
op_input_name_list,
op_input_type_list,
op_input_optional_list,
op_mutable_attribute_name_list,
op_mutable_attribute_type_list,
op_output_name_list,
op_output_type_list,
op_output_size_list,
op_output_optional_list,
op_infer_meta_map,
op_inplace_map,
op_attribute_name_list,
op_attribute_type_list,
op_attribute_build_arg_type_list,
op_non_mutable_attribute_name_list,
op_non_mutable_attribute_type_list,
op_non_mutable_attribute_build_arg_type_list,
muta_attr_is_input,
attr_args_is_map=True,
)

if (op_invoke_map is not None) and (
op_invoke_map['func'] in op_info_items
):
op_invoke_class_name = (
to_pascal_case(op_invoke_map['func']) + "Op"
)
op_infer_meta_from_type_str = (
gen_infermeta_by_invoke_func_str(
op_class_name, op_invoke_class_name
)
)

# =================================== #
# gen Vjp func str #
# =================================== #
Expand Down Expand Up @@ -1971,8 +1932,6 @@ def AutoCodeGen(
)

ops_defined_list.append(op_verify_str)
ops_defined_list.append(op_infer_meta_str)
ops_defined_list.append(op_infer_meta_from_type_str)
ops_defined_list.append(op_get_kernel_type_for_var_str)
ops_defined_list.append(parse_kernel_key_define_str)
ops_defined_list.append(infer_symbolic_shape_define_str)
Expand Down
Loading

0 comments on commit 86719a2

Please sign in to comment.