diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_tools.cc b/paddle/fluid/pir/dialect/distributed/ir/dist_tools.cc new file mode 100644 index 0000000000000..16eb061d55c4f --- /dev/null +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_tools.cc @@ -0,0 +1,58 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/dialect/distributed/ir/dist_tools.h" +#include "paddle/common/enforce.h" + +namespace paddle { +namespace dialect { + +bool HasDistInput(const std::vector& inputs) { + for (auto value : inputs) { + if (value.type().isa()) { + return true; + } + } + return false; +} + +bool AllInputAreDist(const std::vector& inputs) { + for (auto value : inputs) { + if (!value.type().isa()) { + return false; + } + } + return true; +} + +phi::distributed::DistMetaTensor CvtToDistMetaTensor(DistDenseTensorType type) { + auto pir_attr = type.tensor_dist_attr(); + phi::distributed::TensorDistAttr phi_attr; + phi_attr.set_process_mesh(pir_attr.process_mesh_attr().process_mesh()); + phi_attr.set_dims_mapping(pir_attr.dims_mapping()); + phi_attr.set_partial_status(pir_attr.partial_status()); + return phi::distributed::DistMetaTensor(type.global_ddim(), phi_attr); +} + +TensorDistAttribute CvtToPirDistAttr( + const phi::distributed::ArgDistAttr& dist_attr) { + auto& attr = PADDLE_GET_CONST(phi::distributed::TensorDistAttr, dist_attr); + return TensorDistAttribute::get(pir::IrContext::Instance(), + attr.process_mesh(), + attr.dims_mapping(), + attr.partial_status()); +} + +} // namespace dialect +} // namespace paddle diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_tools.h b/paddle/fluid/pir/dialect/distributed/ir/dist_tools.h new file mode 100644 index 0000000000000..aa6cfe9343b9d --- /dev/null +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_tools.h @@ -0,0 +1,31 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "paddle/common/ddim.h" +#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" +#include "paddle/pir/include/core/value.h" + +namespace paddle { +namespace dialect { + +bool HasDistInput(const std::vector& inputs); +bool AllInputAreDist(const std::vector& inputs); +phi::distributed::DistMetaTensor CvtToDistMetaTensor(DistDenseTensorType type); +TensorDistAttribute CvtToPirDistAttr( + const phi::distributed::ArgDistAttr& dist_attr); + +} // namespace dialect +} // namespace paddle diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_type.cc b/paddle/fluid/pir/dialect/distributed/ir/dist_type.cc index 5044fb5b0b5c2..3f0e896801287 100644 --- a/paddle/fluid/pir/dialect/distributed/ir/dist_type.cc +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_type.cc @@ -34,9 +34,29 @@ DistDenseTensorType DistDenseTensorType::get( pir::IrContext* ctx, pir::DenseTensorType dense_tensor_type, TensorDistAttribute tensor_dist_attr, - const common::DDim& global_ddim) { - return Base::get(ctx, dense_tensor_type, tensor_dist_attr, global_ddim); + const common::DDim& local_ddim) { + return Base::get(ctx, dense_tensor_type, tensor_dist_attr, local_ddim); } + +common::DDim InferLocalDDim(const common::DDim& global_ddim, + TensorDistAttribute dist_attr) { + auto& mesh_dim = dist_attr.process_mesh_attr().shape(); + auto& dim_mapping = dist_attr.dims_mapping(); + PADDLE_ENFORCE_EQ( + global_ddim.size(), + dim_mapping.size(), + ::common::errors::PreconditionNotMet( + "The global ddim size must equal to dim_mapping's size!")); + common::DDim local_ddim(global_ddim); + for (size_t i = 0; i < dim_mapping.size(); ++i) { + if (dim_mapping[i] != -1) { + auto dim_size = mesh_dim.at(dim_mapping[i]); + local_ddim[i] = (global_ddim[i] + dim_size - 1) / dim_size; + } + } + return local_ddim; +} + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_type.h b/paddle/fluid/pir/dialect/distributed/ir/dist_type.h index 7b35c52c7ea58..c8964a516af76 100644 --- a/paddle/fluid/pir/dialect/distributed/ir/dist_type.h +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_type.h @@ -23,6 +23,8 @@ namespace dialect { class DistDenseTensorTypeStorage; +common::DDim InferLocalDDim(const common::DDim& global_ddim, + TensorDistAttribute dist_attr); class DistDenseTensorType : public pir::Type::TypeBase 0: - op_interfaces_str = "," + ",".join(op_interfaces) - if len(func_list) == 1: op_class_name = to_pascal_case(op_name) + "Op" op_dialect_name = dialect_name + "." + op_name @@ -1413,6 +1392,28 @@ def AutoCodeGen( kernel_func_name ] + op_info.class_name = op_class_name + op_info.kernel_input_type_list = op_input_type_list + op_info.kernel_output_type_list = op_output_type_list + + ( + all_interface_list, + exclusive_declare_list, + exclusive_impl_list, + ) = gen_op_all_func(args, op_info, op_info_items) + all_interface_list += op_interfaces + + all_interface_str = "" + if len(all_interface_list) > 0: + all_interface_str = "," + ",".join(all_interface_list) + + all_declare_str = ( + exclusive_interface_str + + '\n' + + '\n'.join(exclusive_declare_list) + ) + ops_defined_list += exclusive_impl_list + # =================================== # # gen Build methods str # # =================================== # @@ -1432,13 +1433,16 @@ def AutoCodeGen( ) parse_kernel_key_str = "" - if "paddle::dialect::ParseKernelKeyInterface" in op_interfaces: + if ( + "paddle::dialect::ParseKernelKeyInterface" + in all_interface_list + ): parse_kernel_key_str = parse_kernel_key_template infer_symbolic_shape_str = "" if ( "paddle::dialect::InferSymbolicShapeInterface" - in op_interfaces + in all_interface_list ): infer_symbolic_shape_str = infer_symbolic_shape_template @@ -1568,7 +1572,7 @@ def AutoCodeGen( TEST_API=TEST_API, op_name=op_class_name, dialect_op_name=op_dialect_name, - interfaces=op_interfaces_str, + interfaces=all_interface_str, traits=op_traits_str, attribute_declare=op_0_attribute_declare_str, attribute_num=0, @@ -1576,7 +1580,7 @@ def AutoCodeGen( build_mutable_attr_is_input=build_mutable_attr_is_input, build_attr_num_over_1=build_attr_num_over_1, build_mutable_attr_is_input_attr_num_over_1=build_mutable_attr_is_input_attr_num_over_1, - exclusive_interface=exclusive_interface_str, + exclusive_interface=all_declare_str, get_kernel_type_for_var_declare=get_kernel_type_for_var_declare_str, parse_kernel_key_declare=parse_kernel_key_str, infer_symbolic_shape_declare=infer_symbolic_shape_str, @@ -1587,7 +1591,7 @@ def AutoCodeGen( TEST_API=TEST_API, op_name=op_class_name, dialect_op_name=op_dialect_name, - interfaces=op_interfaces_str, + interfaces=all_interface_str, traits=op_traits_str, attribute_declare=op_n_attribute_declare_str.format( attribute_num=len( @@ -1599,7 +1603,7 @@ def AutoCodeGen( build_mutable_attr_is_input=build_mutable_attr_is_input, build_attr_num_over_1=build_attr_num_over_1, build_mutable_attr_is_input_attr_num_over_1=build_mutable_attr_is_input_attr_num_over_1, - exclusive_interface=exclusive_interface_str, + exclusive_interface=all_declare_str, get_kernel_type_for_var_declare=get_kernel_type_for_var_declare_str, parse_kernel_key_declare=parse_kernel_key_str, infer_symbolic_shape_declare=infer_symbolic_shape_str, @@ -1848,7 +1852,10 @@ def AutoCodeGen( # generate op ParseKernelKeyInterface function str parse_kernel_key_define_str = '' - if "paddle::dialect::ParseKernelKeyInterface" in op_interfaces: + if ( + "paddle::dialect::ParseKernelKeyInterface" + in all_interface_list + ): parse_kernel_key_define_str = gen_parse_kernel_key_str( op_class_name ) @@ -1857,7 +1864,7 @@ def AutoCodeGen( infer_symbolic_shape_define_str = '' if ( "paddle::dialect::InferSymbolicShapeInterface" - in op_interfaces + in all_interface_list ): infer_symbolic_shape_define_str = ( gen_infer_symbolic_shape_str(op_class_name) @@ -1867,7 +1874,7 @@ def AutoCodeGen( infer_symbolic_shape_define_str = '' if ( "paddle::dialect::InferSymbolicShapeInterface" - in op_interfaces + in all_interface_list ): infer_symbolic_shape_define_str = ( gen_infer_symbolic_shape_str(op_class_name) @@ -1885,52 +1892,6 @@ def AutoCodeGen( ) ) - op_infer_meta_str = gen_op_infer_meta_str( - op_info, op_class_name, op_info_items - ) - - op_infer_meta_from_type_str = "" - if op_infer_meta_map is not None: - muta_attr_is_input = ( - True - if len(op_mutable_attribute_name_list) > 0 - else False - ) - op_infer_meta_from_type_str = gen_infermeta_func_str( - op_class_name, - op_input_name_list, - op_input_type_list, - op_input_optional_list, - op_mutable_attribute_name_list, - op_mutable_attribute_type_list, - op_output_name_list, - op_output_type_list, - op_output_size_list, - op_output_optional_list, - op_infer_meta_map, - op_inplace_map, - op_attribute_name_list, - op_attribute_type_list, - op_attribute_build_arg_type_list, - op_non_mutable_attribute_name_list, - op_non_mutable_attribute_type_list, - op_non_mutable_attribute_build_arg_type_list, - muta_attr_is_input, - attr_args_is_map=True, - ) - - if (op_invoke_map is not None) and ( - op_invoke_map['func'] in op_info_items - ): - op_invoke_class_name = ( - to_pascal_case(op_invoke_map['func']) + "Op" - ) - op_infer_meta_from_type_str = ( - gen_infermeta_by_invoke_func_str( - op_class_name, op_invoke_class_name - ) - ) - # =================================== # # gen Vjp func str # # =================================== # @@ -1971,8 +1932,6 @@ def AutoCodeGen( ) ops_defined_list.append(op_verify_str) - ops_defined_list.append(op_infer_meta_str) - ops_defined_list.append(op_infer_meta_from_type_str) ops_defined_list.append(op_get_kernel_type_for_var_str) ops_defined_list.append(parse_kernel_key_define_str) ops_defined_list.append(infer_symbolic_shape_define_str) diff --git a/paddle/fluid/pir/dialect/op_generator/op_infer_spmd_func_gen.py b/paddle/fluid/pir/dialect/op_generator/op_infer_spmd_func_gen.py index b14453f44236c..e8ab19ccf8863 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_infer_spmd_func_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_infer_spmd_func_gen.py @@ -20,7 +20,7 @@ def gen_op_infer_spmd_func(args, op_info, op_info_items): - if not args.with_distributed or not op_info.spmd_rule_func: + if not args.with_distributed or op_info.spmd_rule_func is None: return [], None, None input_types_map = { 'paddle::dialect::DenseTensorType': 'const phi::distributed::DistMetaTensor&', @@ -36,9 +36,18 @@ def gen_op_infer_spmd_func(args, op_info, op_info_items): attr_name_list = op_info.attribute_name_list attr_type_list = op_info.attribute_gen_arg_type_list + attr_name_type_dict = {} for attr_idx in range(len(attr_type_list)): attr_name_type_dict[attr_name_list[attr_idx]] = attr_type_list[attr_idx] + scalar_list = [ + "Scalar(int64_t)", + "Scalar(int)", + "Scalar(float)", + "Scalar(double)", + ] + if op_info.op_yaml_item['attrs'][attr_idx]['typename'] in scalar_list: + attr_name_type_dict[attr_name_list[attr_idx]] = "const phi::Scalar&" spmd_params = input_name_list + attr_name_list if op_info.kernel_map is not None: @@ -60,9 +69,12 @@ def gen_op_infer_spmd_func(args, op_info, op_info_items): args_list_with_type.append(param_type + " " + param) args_list.append(param) + spmd_rule_func = op_info.spmd_rule_func + if spmd_rule_func is None: + spmd_rule_func = "VariadicReplicatedInferSpmdDynamic" declare_str = OP_INFER_SPMD_TEMPLATE.format( infer_spmd_args=', '.join(args_list_with_type), - func=op_info.infer_meta_map["spmd_rule"], + func=spmd_rule_func, args=', '.join(args_list), ) return [], declare_str, None diff --git a/paddle/fluid/pir/dialect/op_generator/op_infermeta_gen.py b/paddle/fluid/pir/dialect/op_generator/op_infermeta_gen.py index bebc10bf756c3..491ba61e49f20 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_infermeta_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_infermeta_gen.py @@ -16,9 +16,22 @@ _INFERMETA_NEED_META_CONFIG, _PREPARE_DATA_WITH_VECTOR_INT64_MTTABLE_ATTRIBUTE, ) +from utils import to_pascal_case -OP_INFERMETA_TEMPLATE = """ -std::vector {op_name}::InferMeta(const std::vector& input_values, const pir::AttributeMap& attributes) {{ +OP_INFERMETA_DECL_STRING = ( + " static void InferMeta( phi::InferMetaContext *infer_meta );\n" + " static std::vector InferMeta( const std::vector& input_values, pir::AttributeMap& attributes );" +) + +OP_INFERMETA_IMPL_TEMPLATE_1 = """ +void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{ + auto fn = PD_INFER_META(phi::{infer_meta_func}); + fn(infer_meta); +}} +""" + +OP_INFERMETA_IMPL_TEMPLATE_2 = """ +std::vector {op_name}::InferMeta(const std::vector& input_values, pir::AttributeMap& attributes) {{ {infermeta_inputs} {get_attributes_str} {infermeta_outputs} @@ -26,6 +39,12 @@ }} """ +OP_INFERMETA_IMPL_TEMPLATE_2_BY_INVOKE = """ +std::vector {op_name}::InferMeta(const std::vector& input_values, pir::AttributeMap& attributes) {{ + return {invoke_class}::InferMeta(input_values, attributes); +}} +""" + CREATE_INPUT_VALUE_TEMPLATE = """ pir::Value {input_name}_ = input_values[{index}]; (void){input_name}_;""" @@ -34,12 +53,6 @@ "Num of inputs is expected to be {op_input_name_list_size} but got %d.", input_values.size()); """ -OP_INFERMETA_BY_INVOKE_TEMPLATE = """ -std::vector {op_name}::InferMeta(const std::vector& input_values, const pir::AttributeMap& attributes) {{ - return {invoke_class}::InferMeta(input_values, attributes); -}} -""" - GET_INPUT_TYPE_TEMPLATE = """ {type} {name}; if ({name}_.type().isa<{type}>()) {{ @@ -51,6 +64,7 @@ def get_infermeta_inputs_str( + op_info, inuse_infer_meta_args, op_input_name_list, op_input_type_list, @@ -58,7 +72,7 @@ def get_infermeta_inputs_str( op_mutable_attribute_name_list, mutable_attr_is_input, ): - op_input_name_list_size = len(op_input_name_list) + op_input_name_list_size = len(op_info.input_name_list) if mutable_attr_is_input: op_input_name_list_size += len(op_mutable_attribute_name_list) @@ -66,11 +80,11 @@ def get_infermeta_inputs_str( op_input_name_list_size=str(op_input_name_list_size), ) - for i in range(len(op_input_name_list)): - if op_input_name_list[i] not in inuse_infer_meta_args: + for i in range(len(op_info.input_name_list)): + if op_info.input_name_list[i] not in inuse_infer_meta_args: continue infermeta_inputs_str += CREATE_INPUT_VALUE_TEMPLATE.format( - input_name=op_input_name_list[i], index=str(i) + input_name=op_info.input_name_list[i], index=str(i) ) if mutable_attr_is_input: @@ -119,7 +133,8 @@ def get_infermeta_inputs_str( def GenBuildOutputsPart2( - op_class_name, + args, + op_info, inuse_infer_meta_args, op_input_name_list, op_input_type_list, @@ -285,7 +300,7 @@ def GenBuildOutputsPart2( # int_array if attr_dtype[0] == "paddle::dialect::IntArrayAttribute": if ( - op_class_name + op_info.class_name in _PREPARE_DATA_WITH_VECTOR_INT64_MTTABLE_ATTRIBUTE ): build_output_str += CREATE_VECTOR_INT_MUTABLE_ATTRIBUTE_WITH_UNKNOWN_DATA_TEMPLATE.format( @@ -415,28 +430,21 @@ def GenBuildOutputsPart2( build_output_str += "\n std::vector argument_outputs;" CREATE_OUTPUT_DENSE_TENSOR_TEMPLATE = """ - pir::Type {name}_dense_tensor_type = {type}::get(pir::IrContext::Instance(), paddle::dialect::TransToIrDataType(dense_{name}.dtype()), dense_{name}.dims(), dense_{name}.layout(), dense_{name}.lod(), dense_{name}.offset()); - argument_outputs.push_back({name}_dense_tensor_type); + pir::Type {name}_type = CvtTo{type}(dense_{name}); """ - CREATE_OUTPUT_INPLACE_OPTIONAL_DENSE_TENSOR_TEMPLATE = """ + pir::Type {name}_type; if ({input_name}_.impl() != nullptr) {{ - pir::Type {output_name}_dense_tensor_type = {type}::get(pir::IrContext::Instance(), paddle::dialect::TransToIrDataType(dense_{output_name}.dtype()), dense_{output_name}.dims(), dense_{output_name}.layout(), dense_{output_name}.lod(), dense_{output_name}.offset()); - argument_outputs.push_back({output_name}_dense_tensor_type); - }} else {{ - pir::Type {output_name}_type; - argument_outputs.push_back({output_name}_type); + {name}_type = CvtTo{type}(dense_{name}); }} - """ CREATE_OUTPUT_VEC_DENSE_TENSOR_TEMPLATE = """ std::vector {name}_types; for (size_t i=0; i < static_cast({output_size}); i++) {{ - {name}_types.push_back(paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), paddle::dialect::TransToIrDataType(vec_dense_{name}[i].dtype()), vec_dense_{name}[i].dims(), vec_dense_{name}[i].layout(), vec_dense_{name}[i].lod(), vec_dense_{name}[i].offset())); + {name}_types.push_back(CvtToDenseTensorType(vec_dense_{name}[i])); }} - pir::Type {name}_vector_type = pir::VectorType::get(pir::IrContext::Instance(), {name}_types); - argument_outputs.push_back({name}_vector_type); + pir::Type {name}_type = pir::VectorType::get(pir::IrContext::Instance(), {name}_types); """ for idx in range(len(op_output_name_list)): # is a vector @@ -457,27 +465,30 @@ def GenBuildOutputsPart2( build_output_str += ( CREATE_OUTPUT_INPLACE_OPTIONAL_DENSE_TENSOR_TEMPLATE.format( input_name=op_inplace_map[output_name], - output_name=output_name, - type=op_output_type_list[idx], + name=output_name, + type=op_output_type_list[idx][17:], ) ) else: build_output_str += CREATE_OUTPUT_DENSE_TENSOR_TEMPLATE.format( - type=op_output_type_list[idx], name=output_name + type=op_output_type_list[idx][17:], name=output_name ) + build_output_str += GenDistBranch(args, op_info) + + PUSH_BACK_OUTPUT_TYPE_TEMPLATE = """ + argument_outputs.push_back({name}); +""" + for idx in range(len(op_output_name_list)): + build_output_str += PUSH_BACK_OUTPUT_TYPE_TEMPLATE.format( + name=op_output_name_list[idx] + "_type", + ) return build_output_str def GetAttributes( - op_class_name, + op_info, mutable_attr_is_input, inuse_infer_meta_args, - op_attribute_name_list, - op_attribute_type_list, - op_attribute_build_arg_type_list, - op_non_mutable_attribute_name_list, - op_non_mutable_attribute_type_list, - op_non_mutable_attribute_build_arg_type_list, attr_args_is_map, ): GET_ATTRIBUTES_FROM_MAP_TEMPLATE = """ @@ -521,13 +532,13 @@ def GetAttributes( attr_types = [] attr_build_arg_types = [] if not mutable_attr_is_input: - attr_names = op_attribute_name_list - attr_types = op_attribute_type_list - attr_build_arg_types = op_attribute_build_arg_type_list + attr_names = op_info.attribute_name_list + attr_types = op_info.attribute_type_list + attr_build_arg_types = op_info.attribute_build_arg_type_list else: - attr_names = op_non_mutable_attribute_name_list - attr_types = op_non_mutable_attribute_type_list - attr_build_arg_types = op_non_mutable_attribute_build_arg_type_list + attr_names = op_info.non_mutable_attribute_name_list + attr_types = op_info.non_mutable_attribute_type_list + attr_build_arg_types = op_info.non_mutable_attribute_build_arg_type_list if attr_args_is_map: for idx in range(len(attr_names)): if attr_names[idx] not in inuse_infer_meta_args: @@ -545,7 +556,7 @@ def GetAttributes( data_name = "AsString" get_attributes_str += ( GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE.format( - op_name=op_class_name, + op_name=op_info.class_name, attr_type=attr_type, attribute_name=attr_names[idx], inner_type=inner_type, @@ -555,7 +566,7 @@ def GetAttributes( elif "paddle::dialect::IntArrayAttribute" in attr_types[idx]: get_attributes_str += ( GET_INTARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE.format( - op_name=op_class_name, + op_name=op_info.class_name, attr_type=attr_type, attribute_name=attr_names[idx], ) @@ -563,7 +574,7 @@ def GetAttributes( elif "paddle::dialect::ScalarAttribute" in attr_types[idx]: get_attributes_str += ( GET_SCALAR_ATTRIBUTE_FROM_MAP_TEMPLATE.format( - op_name=op_class_name, + op_name=op_info.class_name, attr_type=attr_type, attribute_name=attr_names[idx], ) @@ -571,7 +582,7 @@ def GetAttributes( elif "pir::StrAttribute" in attr_types[idx]: get_attributes_str += ( GET_STR_ATTRIBUTES_FROM_MAP_TEMPLATE.format( - op_name=op_class_name, + op_name=op_info.class_name, attr_type=attr_type, attribute_name=attr_names[idx], attr_ir_type=attr_types[idx], @@ -579,7 +590,7 @@ def GetAttributes( ) else: get_attributes_str += GET_ATTRIBUTES_FROM_MAP_TEMPLATE.format( - op_name=op_class_name, + op_name=op_info.class_name, attr_type=attr_type, attribute_name=attr_names[idx], attr_ir_type=attr_types[idx], @@ -587,81 +598,153 @@ def GetAttributes( return get_attributes_str -def gen_infermeta_func_str( - op_class_name, - op_input_name_list, - op_input_type_list, - op_input_optional_list, - op_mutable_attribute_name_list, - op_mutable_attribute_type_list, - op_output_name_list, - op_output_type_list, - op_output_size_list, - op_output_optional_list, - op_infer_meta_map, - op_inplace_map, - op_attribute_name_list, - op_attribute_type_list, - op_attribute_build_arg_type_list, - op_non_mutable_attribute_name_list, - op_non_mutable_attribute_type_list, - op_non_mutable_attribute_build_arg_type_list, - mutable_attr_is_input=False, - attr_args_is_map=True, -): +def GenDistBranch(args, op_info): + if not args.with_distributed or op_info.spmd_rule_func is None: + return "" + TEMPLATE = """ + // Auto Parallel condition + if(!input_values.empty() && AllInputAreDist(input_values)) {{ + ProcessMeshAttribute op_mesh = input_values[0].type().dyn_cast().process_mesh_attr(); + std::vector operand_dist_attrs, result_dist_attrs;""" + dist_branch_str = TEMPLATE.format() + infer_spmd_args_list = [] + # Prepare inputs_meta_tensor & attributes for infer spmd + for name in op_info.spmd_params: + # is input + if name in op_info.input_name_list: + input_index = op_info.input_name_list.index(name) + # is a vector + if 'pir::VectorType' in op_info.input_type_list[input_index]: + TEMPLATE = """ + std::vector vec_dist_meta_{name}; + for(auto& sub_ir_tensor: {name}.data()) {{ + vec_dist_meta_{name}.push_back(CvtToDistMetaTensor(sub_ir_tensor.dyn_cast())); + }}""" + dist_branch_str += TEMPLATE.format(name=name) + infer_spmd_args_list.append("vec_dist_meta_" + name) + # is a Tensor + else: + if op_info.input_optional_list[input_index] == 'true': + TEMPLATE = """ + phi::distributed::DistMetaTensor dist_meta_{name}; + if({name}_) {{ + dist_meta_{name} = CvtToDistMetaTensor({name}_.type().dyn_cast()); + }}""" + dist_branch_str += TEMPLATE.format(name=name) + else: + TEMPLATE = """ + auto dist_meta_{name} = CvtToDistMetaTensor({name}_.type().dyn_cast());""" + dist_branch_str += TEMPLATE.format(name=name) + infer_spmd_args_list.append("dist_meta_" + name) + else: + attr_index = op_info.attribute_name_list.index(name) + param_type = op_info.attribute_gen_arg_type_list[attr_index] + infer_spmd_args_list.append(name) + if param_type == "phi::IntArray": + if name in op_info.mutable_attribute_name_list: + attr_index = op_info.mutable_attribute_name_list.index(name) + attr_type = op_info.mutable_attribute_type_list[attr_index] + if attr_type[0] == "paddle::dialect::IntArrayAttribute": + infer_spmd_args_list[-1] = name + ".GetData()" + TEMPLATE = """ + auto spmd_info = InferSpmd({args}); + for(auto& arg_dist : spmd_info.first) {{ + operand_dist_attrs.push_back(CvtToPirDistAttr(arg_dist)); + }} +""" + dist_branch_str += TEMPLATE.format(args=', '.join(infer_spmd_args_list)) + for idx, output_name in enumerate(op_info.output_name_list): + # is a vector + if 'pir::VectorType' in op_info.output_type_list[idx]: + # Todo: support vector case + dist_branch_str += "" + # is a Tensor + else: + TEMPLATE = """ + auto dist_attr_{name} = CvtToPirDistAttr(spmd_info.second[{idx}]); + result_dist_attrs.push_back(dist_attr_{name}); + argument_outputs.push_back(DistDenseTensorType::get(pir::IrContext::Instance(), {name}_type.dyn_cast(), dist_attr_{name})); +""" + dist_branch_str += TEMPLATE.format(idx=idx, name=output_name) + TEMPLATE = """ + attributes[kAttrOpDistAttrs] = OperationDistAttribute::get( + pir::IrContext::Instance(), + op_mesh, + operand_dist_attrs, + result_dist_attrs + ); + return argument_outputs; + }} +""" + dist_branch_str += TEMPLATE.format() + return dist_branch_str + + +def gen_infermeta_func_str(args, op_info): + attr_args_is_map = True + mutable_attr_is_input = ( + True if len(op_info.mutable_attribute_name_list) > 0 else False + ) inuse_infer_meta_args = [] - for idx in range(len(op_infer_meta_map['param'])): - inuse_infer_meta_args.append(op_infer_meta_map['param'][idx]) + for idx in range(len(op_info.infer_meta_map['param'])): + inuse_infer_meta_args.append(op_info.infer_meta_map['param'][idx]) # Prepare outputs_meta_tensor for infer meta - for idx in range(len(op_output_name_list)): - if op_output_name_list[idx].endswith('_grad'): - inuse_infer_meta_args.append(f"{op_output_name_list[idx][0:-5]}") - if op_output_name_list[idx].endswith('_grad_'): - inuse_infer_meta_args.append(f"{op_output_name_list[idx][0:-6]}") - inuse_infer_meta_args.append(f"{op_output_name_list[idx]}") + for idx in range(len(op_info.output_name_list)): + if op_info.output_name_list[idx].endswith('_grad'): + inuse_infer_meta_args.append( + f"{op_info.output_name_list[idx][0:-5]}" + ) + if op_info.output_name_list[idx].endswith('_grad_'): + inuse_infer_meta_args.append( + f"{op_info.output_name_list[idx][0:-6]}" + ) + inuse_infer_meta_args.append(f"{op_info.output_name_list[idx]}") + + spmd_params = [] + if args.with_distributed and op_info.spmd_rule_func is not None: + spmd_params = op_info.input_name_list + op_info.attribute_name_list + if op_info.kernel_map is not None: + spmd_params = op_info.kernel_map['param'] + op_info.spmd_params = spmd_params infermeta_inputs_str = get_infermeta_inputs_str( - inuse_infer_meta_args, - op_input_name_list, - op_input_type_list, - op_input_optional_list, - op_mutable_attribute_name_list, + op_info, + inuse_infer_meta_args + spmd_params, + op_info.input_name_list, + op_info.kernel_input_type_list, + op_info.input_optional_list, + op_info.mutable_attribute_name_list, mutable_attr_is_input, ) get_attributes_str = GetAttributes( - op_class_name, + op_info, mutable_attr_is_input, - inuse_infer_meta_args, - op_attribute_name_list, - op_attribute_type_list, - op_attribute_build_arg_type_list, - op_non_mutable_attribute_name_list, - op_non_mutable_attribute_type_list, - op_non_mutable_attribute_build_arg_type_list, + inuse_infer_meta_args + spmd_params, attr_args_is_map, ) infermeta_outputs_str = GenBuildOutputsPart2( - op_class_name, - inuse_infer_meta_args, - op_input_name_list, - op_input_type_list, - op_input_optional_list, - op_mutable_attribute_name_list, - op_mutable_attribute_type_list, - op_output_name_list, - op_output_type_list, - op_output_size_list, - op_output_optional_list, - op_infer_meta_map, - op_inplace_map, + args, + op_info, + inuse_infer_meta_args + spmd_params, + op_info.input_name_list, + op_info.kernel_input_type_list, + op_info.input_optional_list, + op_info.mutable_attribute_name_list, + op_info.mutable_attribute_type_list, + op_info.output_name_list, + op_info.kernel_output_type_list, + op_info.output_size_list, + op_info.output_optional_list, + op_info.infer_meta_map, + op_info.inplace_map, mutable_attr_is_input, ) - infermeta_func = OP_INFERMETA_TEMPLATE.format( - op_name=op_class_name, + infermeta_func = OP_INFERMETA_IMPL_TEMPLATE_2.format( + op_name=op_info.class_name, infermeta_inputs=infermeta_inputs_str, get_attributes_str=get_attributes_str, infermeta_outputs=infermeta_outputs_str, @@ -670,17 +753,45 @@ def gen_infermeta_func_str( return infermeta_func -def gen_infermeta_by_invoke_func_str(op_class_name, invoke_class_name): - return OP_INFERMETA_BY_INVOKE_TEMPLATE.format( - op_name=op_class_name, invoke_class=invoke_class_name +def gen_infermeta_impl_str(args, op_info): + return ( + OP_INFERMETA_IMPL_TEMPLATE_1.format( + op_name=op_info.class_name, + infer_meta_func=op_info.infer_meta_func, + ) + + "\n" + + gen_infermeta_func_str(args, op_info) + ) + + +def gen_infermeta_by_invoke_impl_str(op_info, op_info_items): + invoke_class_name = to_pascal_case(op_info.invoke_map['func']) + "Op" + return ( + OP_INFERMETA_IMPL_TEMPLATE_1.format( + op_name=op_info.class_name, + infer_meta_func=op_info_items[ + op_info.invoke_map['func'] + ].infer_meta_func, + ) + + "\n" + + OP_INFERMETA_IMPL_TEMPLATE_2_BY_INVOKE.format( + op_name=op_info.class_name, invoke_class=invoke_class_name + ) ) def gen_op_infermeta_func(args, op_info, op_info_items): interface = [] + declare_str = "" + impl_str = "" if op_info.infer_meta_func: interface = ["paddle::dialect::InferMetaInterface"] + declare_str = OP_INFERMETA_DECL_STRING + impl_str = gen_infermeta_impl_str(args, op_info) elif op_info.invoke_map and op_info.invoke_map['func'] in op_info_items: if op_info_items[op_info.invoke_map['func']].infer_meta_func: interface = ["paddle::dialect::InferMetaInterface"] - return interface, None, None + declare_str = OP_INFERMETA_DECL_STRING + impl_str = gen_infermeta_by_invoke_impl_str(op_info, op_info_items) + + return interface, declare_str, impl_str diff --git a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py index 0a0cae38ec2e5..ce9990350e486 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py @@ -15,12 +15,6 @@ # generator interfaces from vjp_interface_black_list import vjp_interface_black_list -OP_INFER_SHAPE_TEMPLATE = """ -void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{ - auto fn = PD_INFER_META(phi::{infer_meta_func}); - fn(infer_meta); -}} -""" CHECK_INPUT_TEMPLATE = """ PADDLE_ENFORCE_EQ( inputs_.size(), @@ -272,37 +266,8 @@ def gen_op_vjp_str( return str -def gen_op_infer_meta_str(op_info, op_class_name, op_info_items): - op_infer_meta_str = "" - if op_info.infer_meta_func: - op_infer_meta_str = OP_INFER_SHAPE_TEMPLATE.format( - op_name=op_class_name, - infer_meta_func=op_info.infer_meta_func, - ) - elif op_info.invoke_map and op_info.invoke_map['func'] in op_info_items: - if op_info_items[op_info.invoke_map['func']].infer_meta_func: - op_infer_meta_str = OP_INFER_SHAPE_TEMPLATE.format( - op_name=op_class_name, - infer_meta_func=op_info_items[ - op_info.invoke_map['func'] - ].infer_meta_func, - ) - return op_infer_meta_str - - def gen_exclusive_interface_str(op_info, op_info_items): exclusive_interface_str = "" - if op_info.infer_meta_func: - exclusive_interface_str += ( - " static void InferMeta( phi::InferMetaContext *infer_meta );\n" - " static std::vector InferMeta( const std::vector& input_values, const pir::AttributeMap& attributes );" - ) - elif op_info.invoke_map and op_info.invoke_map['func'] in op_info_items: - if op_info_items[op_info.invoke_map['func']].infer_meta_func: - exclusive_interface_str += ( - " static void InferMeta( phi::InferMetaContext *infer_meta );\n" - " static std::vector InferMeta( const std::vector& input_values, const pir::AttributeMap& attributes );" - ) if op_info.op_phi_name[0] not in vjp_interface_black_list: exclusive_interface_str += "\n static std::vector> Vjp(pir::Operation* op, const std::vector>& inputs_, const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients);" return exclusive_interface_str diff --git a/paddle/fluid/pir/dialect/op_generator/utils.py b/paddle/fluid/pir/dialect/op_generator/utils.py new file mode 100644 index 0000000000000..79a1f99fca058 --- /dev/null +++ b/paddle/fluid/pir/dialect/op_generator/utils.py @@ -0,0 +1,21 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def to_pascal_case(s): + words = s.split("_") + if s[-1] == "_": + return "".join([word.capitalize() for word in words]) + "_" + else: + return "".join([word.capitalize() for word in words]) + "" diff --git a/paddle/fluid/pir/dialect/operator/interface/infermeta.h b/paddle/fluid/pir/dialect/operator/interface/infermeta.h index bd6d1f7d42013..6a33729ba6899 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infermeta.h +++ b/paddle/fluid/pir/dialect/operator/interface/infermeta.h @@ -25,13 +25,12 @@ class InferMetaInterface : public pir::OpInterfaceBase { struct Concept { explicit Concept(void (*infer_meta)(phi::InferMetaContext *), std::vector (*infer_meta_by_value)( - const std::vector &, - const pir::AttributeMap &)) + const std::vector &, pir::AttributeMap &)) : infer_meta_(infer_meta), infer_meta_by_value_(infer_meta_by_value) {} void (*infer_meta_)(phi::InferMetaContext *); std::vector (*infer_meta_by_value_)( - const std::vector &, const pir::AttributeMap &); + const std::vector &, pir::AttributeMap &); // NOLINT }; template @@ -41,7 +40,7 @@ class InferMetaInterface : public pir::OpInterfaceBase { } static inline std::vector InferMetaByValue( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT return ConcreteOp::InferMeta(input_values, attributes); } Model() : Concept(InferMeta, InferMetaByValue) {} @@ -56,7 +55,7 @@ class InferMetaInterface : public pir::OpInterfaceBase { } std::vector InferMeta(const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT return impl_->infer_meta_by_value_(input_values, attributes); } diff --git a/paddle/fluid/pir/dialect/operator/ir/ir_selected_rows.h b/paddle/fluid/pir/dialect/operator/ir/ir_selected_rows.h index 37000c86b5b65..856ddb2f7542c 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ir_selected_rows.h +++ b/paddle/fluid/pir/dialect/operator/ir/ir_selected_rows.h @@ -14,6 +14,8 @@ #pragma once +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/phi/core/allocator.h" #include "paddle/phi/core/tensor_base.h" #include "paddle/phi/core/tensor_meta.h" @@ -87,5 +89,14 @@ class IrSelectedRows size_t offset_{0}; }; +inline SelectedRowsType CvtToSelectedRowsType(const IrSelectedRows& ir_tensor) { + return SelectedRowsType::get(pir::IrContext::Instance(), + TransToIrDataType(ir_tensor.dtype()), + ir_tensor.dims(), + ir_tensor.layout(), + ir_tensor.lod(), + ir_tensor.offset()); +} + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/ir/ir_tensor.h b/paddle/fluid/pir/dialect/operator/ir/ir_tensor.h index 21d8a9fdd7ae5..45847d3080387 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ir_tensor.h +++ b/paddle/fluid/pir/dialect/operator/ir/ir_tensor.h @@ -14,9 +14,11 @@ #pragma once +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/phi/core/allocator.h" #include "paddle/phi/core/tensor_base.h" #include "paddle/phi/core/tensor_meta.h" +#include "paddle/pir/include/core/builtin_type.h" namespace paddle { namespace dialect { @@ -86,5 +88,14 @@ class IrTensor : public phi::TensorBase, size_t offset_{0}; }; +inline pir::DenseTensorType CvtToDenseTensorType(const IrTensor& ir_tensor) { + return pir::DenseTensorType::get(pir::IrContext::Instance(), + TransToIrDataType(ir_tensor.dtype()), + ir_tensor.dims(), + ir_tensor.layout(), + ir_tensor.lod(), + ir_tensor.offset()); +} + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.cc index 6ee537d1ee1a7..588cd210a4523 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.cc @@ -243,7 +243,7 @@ void ExpandOp::InferMeta(phi::InferMetaContext* infer_meta) { std::vector ExpandOp::InferMeta( const std::vector& input_values, - const pir::AttributeMap& attributes) { + pir::AttributeMap& attributes) { // NOLINT IR_ENFORCE(input_values.size() == 2, "Num of inputs is expected to be 2 but got %d.", input_values.size()); diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.h index 3c8050480ade9..54d564f9a77e2 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.h @@ -84,7 +84,7 @@ class ExpandOp : public pir::Op InferMeta( const std::vector& input_values, - const pir::AttributeMap& attributes); + pir::AttributeMap& attributes); // NOLINT }; } // namespace dialect diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index c673ece8fdf46..43d22fce3561d 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -146,7 +146,7 @@ void AddNOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector AddNOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT VLOG(4) << "Start infermeta AddNOp"; IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", @@ -292,7 +292,7 @@ void AddN_Op::InferMeta(phi::InferMetaContext *infer_meta) { std::vector AddN_Op::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT VLOG(4) << "Start infermeta AddN_Op"; IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", @@ -441,7 +441,7 @@ void AddNArrayOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector AddNArrayOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT VLOG(4) << "Start infermeta AddNArrayOp"; IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", @@ -662,7 +662,7 @@ void FusedGemmEpilogueOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector FusedGemmEpilogueOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT VLOG(4) << "Start infermeta FusedGemmEpilogueOp"; IR_ENFORCE(input_values.size() == 3, "Num of inputs is expected to be 3 but got %d.", @@ -907,7 +907,7 @@ void FusedGemmEpilogueGradOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector FusedGemmEpilogueGradOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT IR_ENFORCE(input_values.size() == 4, "Num of inputs is expected to be 4 but got %d.", input_values.size()); @@ -1204,7 +1204,7 @@ void SplitGradOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector SplitGradOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT VLOG(4) << "Start infermeta SplitGradOp"; IR_ENFORCE(input_values.size() == 2, @@ -1343,7 +1343,7 @@ void CreateArrayOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector CreateArrayOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT VLOG(4) << "Start infermeta CreateArrayOp"; PADDLE_ENFORCE( @@ -1461,7 +1461,7 @@ void CreateArrayLikeOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector CreateArrayLikeOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT VLOG(4) << "Start infermeta CreateArrayLikeOp"; IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", @@ -1582,7 +1582,7 @@ void ArrayLengthOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector ArrayLengthOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT VLOG(4) << "Start infermeta ArrayLengthOp"; IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", @@ -1738,7 +1738,7 @@ void ArrayReadOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector ArrayReadOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT VLOG(4) << "Start infermeta ArrayLengthOp"; IR_ENFORCE(input_values.size() == 2, "Num of inputs is expected to be 2 but got %d.", @@ -1907,7 +1907,7 @@ void ArrayWrite_Op::InferMeta(phi::InferMetaContext *infer_meta) { std::vector ArrayWrite_Op::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT VLOG(4) << "Start infermeta ArrayWrite_Op"; IR_ENFORCE(input_values.size() == 3, "Num of inputs is expected to be 3 but got %d.", @@ -2099,7 +2099,7 @@ void ArrayToTensorOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector ArrayToTensorOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT VLOG(4) << "Start infermeta ArrayToTensorOp"; IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", @@ -2288,7 +2288,7 @@ void TensorToArrayOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector TensorToArrayOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT VLOG(4) << "Start infermeta TensorToArrayOp"; IR_ENFORCE(input_values.size() == 2, "Num of inputs is expected to be 2 but got %d.", @@ -2501,7 +2501,7 @@ void SliceArrayOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector SliceArrayOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT VLOG(4) << "Start infermeta SliceArrayOp"; IR_ENFORCE(input_values.size() == 3, "Num of inputs is expected to be 3 but got %d.", @@ -2652,7 +2652,7 @@ void SliceArrayDenseOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector SliceArrayDenseOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT VLOG(4) << "Start infermeta SliceArrayDenseOp"; IR_ENFORCE(input_values.size() == 2, "Num of inputs is expected to be 2 but got %d.", @@ -2791,7 +2791,7 @@ phi::DataType AssignArrayOp::GetKernelTypeForVar( std::vector AssignArrayOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT VLOG(4) << "Start infermeta AssignArrayOp"; IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", @@ -2892,7 +2892,7 @@ void AssignArray_Op::InferMeta(phi::InferMetaContext *infer_meta) { std::vector AssignArray_Op::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT VLOG(4) << "Start infermeta AssignArray_Op"; IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", @@ -3158,7 +3158,7 @@ void ExpandOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector ExpandOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT VLOG(4) << "Start infermeta ExpandOp"; IR_ENFORCE(input_values.size() == 2, "Num of inputs is expected to be 2 but got %d.", @@ -3381,7 +3381,7 @@ void IncrementOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector IncrementOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT VLOG(4) << "Start infermeta IncrementOp"; IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", @@ -3562,7 +3562,7 @@ void Increment_Op::InferMeta(phi::InferMetaContext *infer_meta) { std::vector Increment_Op::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT VLOG(4) << "Start infermeta Increment_Op"; IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", @@ -3722,7 +3722,7 @@ void AssignOut_Op::InferMeta(phi::InferMetaContext *infer_meta) { std::vector AssignOut_Op::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT IR_ENFORCE(input_values.size() == 2, "Num of inputs is expected to be 2 but got %d.", input_values.size()); @@ -3801,7 +3801,7 @@ void ShapeBroadcastOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector ShapeBroadcastOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT VLOG(4) << "Start infermeta ShapeBroadcastOp"; IR_ENFORCE(input_values.size() == 2, "Num of inputs is expected to be 2 but got %d.", @@ -4011,7 +4011,7 @@ void MemcpyD2hMultiIoOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector MemcpyD2hMultiIoOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", input_values.size()); @@ -4158,7 +4158,7 @@ void ArrayPopOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector ArrayPopOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap &attributes) { // NOLINT VLOG(4) << "Start infermeta ArrayPopOp"; IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index 36feddf569dad..9e76b9255bfcf 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -55,7 +55,7 @@ class AddNOp : public pir::Op InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT static std::vector> Vjp( pir::Operation *op, @@ -87,7 +87,7 @@ class AddN_Op : public pir::Op InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT }; class AddNArrayOp : public pir::Op InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT }; class FusedGemmEpilogueOp @@ -140,7 +140,7 @@ class FusedGemmEpilogueOp static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT }; class FusedGemmEpilogueGradOp @@ -173,7 +173,7 @@ class FusedGemmEpilogueGradOp static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT }; class SplitGradOp : public pir::Op { @@ -199,7 +199,7 @@ class SplitGradOp : public pir::Op { static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT }; class CreateArrayOp @@ -218,7 +218,7 @@ class CreateArrayOp static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT }; class CreateArrayLikeOp : public pir::Op InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT }; class ArrayLengthOp @@ -260,7 +260,7 @@ class ArrayLengthOp static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT }; class ArrayReadOp : public pir::Op InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, @@ -321,7 +321,7 @@ class ArrayWrite_Op : public pir::Op InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, @@ -352,7 +352,7 @@ class ArrayToTensorOp : public pir::Op InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, @@ -382,7 +382,7 @@ class TensorToArrayOp static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT }; class SliceArrayOp @@ -416,7 +416,7 @@ class SliceArrayOp static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT }; class SliceArrayDenseOp @@ -448,7 +448,7 @@ class SliceArrayDenseOp static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT }; class AssignArrayOp @@ -479,7 +479,7 @@ class AssignArrayOp static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT }; class AssignArray_Op @@ -507,7 +507,7 @@ class AssignArray_Op static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT }; class ExpandOp : public pir::Op InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, @@ -597,7 +597,7 @@ class IncrementOp static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, @@ -644,7 +644,7 @@ class Increment_Op static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, @@ -686,7 +686,7 @@ class AssignOut_Op static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, @@ -729,7 +729,7 @@ class MemcpyD2hMultiIoOp static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT }; class IR_API ShapeBroadcastOp @@ -755,7 +755,7 @@ class IR_API ShapeBroadcastOp static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis); }; @@ -790,7 +790,7 @@ class ArrayPopOp : public pir::Op InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap &attributes); // NOLINT }; } // namespace dialect diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index 94e0c8599ff88..6caaeb81b0fe1 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -92,7 +92,7 @@ pir::Type ConvertOpTypeToKernelType(pir::IrContext* ctx, static const std::vector InferMetaByValue( pir::Operation* op, const std::vector& input_values, - const pir::AttributeMap& attribute_map) { + pir::AttributeMap& attribute_map) { // NOLINT pir::OpInfo op_info = pir::IrContext::Instance()->GetRegisteredOpInfo(op->name()); auto infer_meta_interface = diff --git a/test/auto_parallel/pir/CMakeLists.txt b/test/auto_parallel/pir/CMakeLists.txt index 65e827d046313..ad278c460f59b 100644 --- a/test/auto_parallel/pir/CMakeLists.txt +++ b/test/auto_parallel/pir/CMakeLists.txt @@ -2,4 +2,6 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_to_static_pir_program MODULES test_to_static_pir_program) set_tests_properties(test_to_static_pir_program PROPERTIES ENVIRONMENT "FLAGS_enable_pir_api=1") + py_test_modules(test_ir_dist_attr MODULES test_ir_dist_attr ENVS + FLAGS_enable_pir_api=1) endif() diff --git a/test/ir/pir/test_ir_dist_attr.py b/test/auto_parallel/pir/test_ir_dist_attr.py similarity index 65% rename from test/ir/pir/test_ir_dist_attr.py rename to test/auto_parallel/pir/test_ir_dist_attr.py index a4107199308bf..b0abbca3e87bb 100644 --- a/test/ir/pir/test_ir_dist_attr.py +++ b/test/auto_parallel/pir/test_ir_dist_attr.py @@ -74,7 +74,7 @@ def test_build_replicated_program(self): dist_input = dtensor_from_local(input, mesh, [dist.Replicate()]) dist_w0 = dtensor_from_local(w0, mesh, [dist.Replicate()]) - # dist_out = paddle.matmul(dist_input, dist_w0) + dist_out = paddle.matmul(dist_input, dist_w0) self.assertTrue(dist_input.is_dist_dense_tensor_type()) self.assertTrue(dist_w0.is_dist_dense_tensor_type()) @@ -101,13 +101,17 @@ def test_build_replicated_program(self): self.assertTrue(len(dist_w0.partial_dims) == 0) # matmul out - # self.assertTrue(dist_out.shape == [BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE]) - # self.assertTrue(dist_out._local_shape == [BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE]) - # self.assertTrue(dist_out.dims_mapping == [-1, -1]) - # self.assertTrue(isinstance(dist_out.process_mesh, paddle.base.libpaddle.ProcessMesh)) - # self.assertTrue(dist_out.process_mesh.shape == [2]) - # self.assertTrue(dist_out.process_mesh.process_ids == [0, 1]) - # self.assertTrue(len(dist_out.partial_dims) == 0) + self.assertTrue(dist_out.shape == [BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE]) + self.assertTrue( + dist_out._local_shape == [BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE] + ) + self.assertTrue(dist_out.dims_mapping == [-1, -1, -1]) + self.assertTrue( + isinstance(dist_out.process_mesh, paddle.base.libpaddle.ProcessMesh) + ) + self.assertTrue(dist_out.process_mesh.shape == [2]) + self.assertTrue(dist_out.process_mesh.process_ids == [0, 1]) + self.assertTrue(len(dist_out.partial_dims) == 0) def test_build_col_parallel_program(self): with paddle.pir_utils.IrGuard(): @@ -128,6 +132,7 @@ def test_build_col_parallel_program(self): dist_input = dtensor_from_local(input, mesh, [dist.Replicate()]) dist_w0 = dtensor_from_local(w0, mesh, [dist.Shard(1)]) + dist_out = paddle.matmul(dist_input, dist_w0) self.assertTrue(dist_input.is_dist_dense_tensor_type()) self.assertTrue(dist_w0.is_dist_dense_tensor_type()) @@ -141,13 +146,18 @@ def test_build_col_parallel_program(self): self.assertTrue(dist_input.dims_mapping == [-1, -1, -1]) self.assertTrue(dist_w0.dims_mapping == [-1, 0]) # matmul out - # self.assertTrue(dist_out.shape == [BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE]) - # self.assertTrue(dist_out._local_shape == [BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE // MP_SIZE]) - # self.assertTrue(dist_out.dims_mapping == [-1, -1, 0]) - # self.assertTrue(isinstance(dist_out.process_mesh, paddle.base.libpaddle.ProcessMesh)) - # self.assertTrue(dist_out.process_mesh.shape == [2]) - # self.assertTrue(dist_out.process_mesh.process_ids == [0, 1]) - # self.assertTrue(len(dist_out.partial_dims) == 0) + self.assertTrue(dist_out.shape == [BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE]) + self.assertTrue( + dist_out._local_shape + == [BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE // MP_SIZE] + ) + self.assertTrue(dist_out.dims_mapping == [-1, -1, 0]) + self.assertTrue( + isinstance(dist_out.process_mesh, paddle.base.libpaddle.ProcessMesh) + ) + self.assertTrue(dist_out.process_mesh.shape == [2]) + self.assertTrue(dist_out.process_mesh.process_ids == [0, 1]) + self.assertTrue(len(dist_out.partial_dims) == 0) def test_build_row_parallel_program(self): with paddle.pir_utils.IrGuard(): @@ -169,6 +179,7 @@ def test_build_row_parallel_program(self): dist_input = dtensor_from_local(input, mesh, [dist.Shard(2)]) dist_w0 = dtensor_from_local(w0, mesh, [dist.Shard(0)]) + dist_out = paddle.matmul(dist_input, dist_w0) self.assertTrue(dist_input.is_dist_dense_tensor_type()) self.assertTrue(dist_w0.is_dist_dense_tensor_type()) @@ -185,58 +196,63 @@ def test_build_row_parallel_program(self): self.assertTrue(dist_input.dims_mapping == [-1, -1, 0]) self.assertTrue(dist_w0.dims_mapping == [0, -1]) # matmul out - # self.assertTrue(dist_out.shape == [BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE]) - # self.assertTrue(dist_out._local_shape == [BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE]) - # self.assertTrue(dist_out.dims_mapping == [-1, -1, -1]) - # self.assertTrue(isinstance(dist_out.process_mesh, paddle.base.libpaddle.ProcessMesh)) - # self.assertTrue(dist_out.process_mesh.shape == [2]) - # self.assertTrue(dist_out.process_mesh.process_ids == [0, 1]) - # self.assertTrue(len(dist_out.partial_dims) == set(0)) - - # def test_build_with_shard_tensor(self): - # with paddle.pir_utils.IrGuard(): - # main_program = paddle.base.Program() - # with paddle.base.program_guard(main_program): - # mesh = dist.ProcessMesh([0, 1], dim_names=['mp']) - # input = paddle.static.data( - # name='input', - # shape=[BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE], - # ) - # w0 = paddle.pir.core.create_parameter( - # dtype="float32", - # shape=[HIDDEN_SIZE, HIDDEN_SIZE], - # name="w0", - # initializer=paddle.nn.initializer.Uniform(), - # ) - # w1 = paddle.pir.core.create_parameter( - # dtype="float32", - # shape=[HIDDEN_SIZE, HIDDEN_SIZE], - # name="w0", - # initializer=paddle.nn.initializer.Uniform(), - # ) - # self.assertTrue(input.is_dense_tensor_type()) - # self.assertTrue(w0.is_dense_tensor_type()) - - # dist_input = dist.shard_tensor(input, mesh, [dist.Replicate()]) - # dist_w0 = dist.shard_tensor(w0, mesh, [dist.Shard(0)]) - # dist_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(1)]) - # self.assertTrue(dist_input.is_dist_dense_tensor_type()) - # self.assertTrue(dist_w0.is_dist_dense_tensor_type()) - - # # check global shape - # self.assertTrue(dist_input.shape == [BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE]) - # self.assertTrue(dist_w0.shape == [HIDDEN_SIZE, HIDDEN_SIZE]) - # self.assertTrue(dist_w1.shape == [HIDDEN_SIZE, HIDDEN_SIZE]) - # # check local shape - # self.assertTrue( - # dist_input._local_shape == dist_input.shape - # ) # replicated, local = global - # self.assertTrue( - # dist_w0._local_shape == [HIDDEN_SIZE // MP_SIZE, HIDDEN_SIZE] - # ) # sharded, local != global, sharded by mesh size - # self.assertTrue( - # dist_w1._local_shape == [HIDDEN_SIZE, HIDDEN_SIZE // MP_SIZE] - # ) # sharded, local != global, sharded by mesh size + self.assertTrue(dist_out.shape == [BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE]) + self.assertTrue( + dist_out._local_shape == [BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE] + ) + self.assertTrue(dist_out.dims_mapping == [-1, -1, -1]) + self.assertTrue( + isinstance(dist_out.process_mesh, paddle.base.libpaddle.ProcessMesh) + ) + self.assertTrue(dist_out.process_mesh.shape == [2]) + self.assertTrue(dist_out.process_mesh.process_ids == [0, 1]) + self.assertTrue(dist_out.partial_dims == {0}) + + def test_build_with_shard_tensor(self): + with paddle.pir_utils.IrGuard(): + main_program = paddle.base.Program() + with paddle.base.program_guard(main_program): + mesh = dist.ProcessMesh([0, 1], dim_names=['mp']) + input = paddle.static.data( + name='input', + shape=[BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE], + ) + w0 = paddle.pir.core.create_parameter( + dtype="float32", + shape=[HIDDEN_SIZE, HIDDEN_SIZE], + name="w0", + initializer=paddle.nn.initializer.Uniform(), + ) + w1 = paddle.pir.core.create_parameter( + dtype="float32", + shape=[HIDDEN_SIZE, HIDDEN_SIZE], + name="w0", + initializer=paddle.nn.initializer.Uniform(), + ) + self.assertTrue(input.is_dense_tensor_type()) + self.assertTrue(w0.is_dense_tensor_type()) + + dist_input = dist.shard_tensor(input, mesh, [dist.Replicate()]) + dist_w0 = dist.shard_tensor(w0, mesh, [dist.Shard(0)]) + dist_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(1)]) + self.assertTrue(dist_input.is_dist_dense_tensor_type()) + self.assertTrue(dist_w0.is_dist_dense_tensor_type()) + + # check global shape + self.assertTrue(dist_input.shape == [BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE]) + self.assertTrue(dist_w0.shape == [HIDDEN_SIZE, HIDDEN_SIZE]) + self.assertTrue(dist_w1.shape == [HIDDEN_SIZE, HIDDEN_SIZE]) + # check local shape + self.assertTrue( + dist_input._local_shape == dist_input.shape + ) # replicated, local = global + self.assertTrue( + dist_w0._local_shape == [HIDDEN_SIZE // MP_SIZE, HIDDEN_SIZE] + ) # sharded, local != global, sharded by mesh size + self.assertTrue( + dist_w1._local_shape == [HIDDEN_SIZE, HIDDEN_SIZE // MP_SIZE] + ) # sharded, local != global, sharded by mesh size + # TODO check Dtype, layout same as densetensor # TODO check dims_mapping & mesh as user annotated