diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_tools.cc b/paddle/fluid/pir/dialect/distributed/ir/dist_tools.cc
index 11f418dba1b62..f9ab0609bdcf3 100644
--- a/paddle/fluid/pir/dialect/distributed/ir/dist_tools.cc
+++ b/paddle/fluid/pir/dialect/distributed/ir/dist_tools.cc
@@ -377,7 +377,9 @@ pir::Attribute CreateReplicatedDistAttr(pir::Type prim_type,
   }
   return nullptr;
 }
-pir::Type CvtToPirDistType(pir::Type global_type, pir::Attribute dist_attr) {
+pir::Type CvtToPirDistType(pir::Type global_type,
+                           pir::Attribute dist_attr,
+                           const std::vector<int64_t>& local_ddim) {
   if (!global_type) return nullptr;
   auto ctx = pir::IrContext::Instance();
   if (auto dense_tensor_type = global_type.dyn_cast<pir::DenseTensorType>()) {
@@ -389,7 +391,14 @@ pir::Type CvtToPirDistType(pir::Type global_type, pir::Attribute dist_attr) {
           "Only allowed convert a densor tensor type to dist dense tensor type "
           "with non-empty TensorDistAttr"));
     }
-    return DistDenseTensorType::get(ctx, dense_tensor_type, tensor_dist_attr);
+    if (!local_ddim.empty()) {
+      return DistDenseTensorType::get(ctx,
+                                      dense_tensor_type,
+                                      tensor_dist_attr,
+                                      common::make_ddim(local_ddim));
+    } else {
+      return DistDenseTensorType::get(ctx, dense_tensor_type, tensor_dist_attr);
+    }
   } else if (auto vec_type = global_type.dyn_cast<pir::VectorType>()) {
     auto array_attr = dist_attr.dyn_cast<pir::ArrayAttribute>();
     if (!array_attr) {
@@ -406,7 +415,8 @@ pir::Type CvtToPirDistType(pir::Type global_type, pir::Attribute dist_attr) {
             "The vector type size must equal to array attribute size."));
     std::vector<pir::Type> dist_vec_type;
     for (size_t idx = 0; idx < vec_type.size(); ++idx) {
-      dist_vec_type.push_back(CvtToPirDistType(vec_type[idx], array_attr[idx]));
+      dist_vec_type.push_back(
+          CvtToPirDistType(vec_type[idx], array_attr[idx], local_ddim));
     }
     return pir::VectorType::get(ctx, dist_vec_type);
   } else {
diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_tools.h b/paddle/fluid/pir/dialect/distributed/ir/dist_tools.h
index 5c45afb6f0e90..d9dfaddd5ad13 100644
--- a/paddle/fluid/pir/dialect/distributed/ir/dist_tools.h
+++ b/paddle/fluid/pir/dialect/distributed/ir/dist_tools.h
@@ -46,7 +46,10 @@ pir::Attribute CvtToPirAttr(const phi::distributed::ArgDistAttr& dist_attr);
 pir::Attribute CreateReplicatedDistAttr(pir::Type prim_type,
                                         ProcessMeshAttribute mesh);

-pir::Type CvtToPirDistType(pir::Type global_type, pir::Attribute dist_attr);
+pir::Type CvtToPirDistType(
+    pir::Type global_type,
+    pir::Attribute dist_attr,
+    const std::vector<int64_t>& local_ddim = std::vector<int64_t>());

 ///
 /// When the following conditions are met:
diff --git a/paddle/fluid/pybind/dist_api.cc b/paddle/fluid/pybind/dist_api.cc
index c12eeb49b883c..593ee0883efd2 100644
--- a/paddle/fluid/pybind/dist_api.cc
+++ b/paddle/fluid/pybind/dist_api.cc
@@ -42,6 +42,7 @@ using paddle::dialect::DistTypeInterface;
 using paddle::dialect::OperationDistAttribute;
 using paddle::dialect::ProcessMeshAttribute;
 using paddle::dialect::TensorDistAttribute;
+using pir::ArrayAttribute;

 namespace paddle::pybind {

@@ -127,11 +128,21 @@ OperationDistAttribute CreateOperationDistAttribute(
       pir::IrContext::Instance(), mesh, operands, results);
 }

+ArrayAttribute CreateArrayAttribute(
+    const std::vector<pir::Attribute> &elements) {
+  return ArrayAttribute::get(pir::IrContext::Instance(), elements);
+}
+
 void BindDistUtils(pybind11::module *m) {
   m->def("create_tensor_dist_attribute", CreateTensorDistAttribute);
   m->def("create_op_dist_attribute", CreateOperationDistAttribute);
+  m->def("create_array_attribute", CreateArrayAttribute);
   m->def("get_sub_meshes", phi::distributed::GetSubMeshes);
-  m->def("cvt_to_dist_type", &dialect::CvtToPirDistType);
+  m->def("cvt_to_dist_type",
+         &dialect::CvtToPirDistType,
+         py::arg("global_type"),
+         py::arg("dist_attr"),
+         py::arg("local_ddim") = std::vector<int64_t>());
 }

 void BindDistPassAPI(pybind11::module *module) {
diff --git a/python/paddle/distributed/auto_parallel/static/reshard_funcs/s_to_r_reshard_func.py b/python/paddle/distributed/auto_parallel/static/reshard_funcs/s_to_r_reshard_func.py
index 916d9748aef1a..36d163f7e07d7 100644
--- a/python/paddle/distributed/auto_parallel/static/reshard_funcs/s_to_r_reshard_func.py
+++ b/python/paddle/distributed/auto_parallel/static/reshard_funcs/s_to_r_reshard_func.py
@@ -49,8 +49,9 @@ def infer_allgather_dist_type(self, in_value, split_axis):
         # may be shard and it will call this 1-D s_to_r function on each
         # axis. In this case, we should recompute the local and global shape.
         out_local_shape = list(in_value.shape)
-        out_local_shape[split_axis] = (
-            in_value.shape[split_axis] // mesh.shape[split_mesh_dim]
+        out_local_shape[split_axis] = int(
+            (in_value.shape[split_axis] + mesh.shape[split_mesh_dim] - 1)
+            / mesh.shape[split_mesh_dim]
         )
         out_global_shape = list(out_local_shape)
         out_global_shape[0] *= mesh.shape[split_mesh_dim]
@@ -96,10 +97,8 @@ def get_split_axis_with_dims_mapping(dims_mapping):
         for k, v in split_axis_map.items():
             split_axis = k
             break
-
-        num_of_padding = (
-            src_value.shape[split_axis] % src_dist_attr.process_mesh.size
-        )
+        num_of_process = src_dist_attr.process_mesh.size
+        num_of_padding = src_value.shape[split_axis] % num_of_process
         is_balanced_split = num_of_padding == 0

         if is_balanced_split:
@@ -113,8 +112,102 @@ def get_split_axis_with_dims_mapping(dims_mapping):
             )
             return new_value
         else:
-            # TODO(ywt01) support unbalanced split
-            raise NotImplementedError("unbalanced split is not implemented")
+            # find the last one
+            need_padding = (
+                paddle.distributed.get_rank()
+                == src_dist_attr.process_mesh.process_ids[-1]
+            )
+
+            # get padding_num
+            avg_size_on_split_axis = int(
+                (src_value.shape[split_axis] + num_of_process - 1)
+                / num_of_process
+            )
+            padding_num = (
+                avg_size_on_split_axis * num_of_process
+                - src_value.shape[split_axis]
+            )
+            if need_padding:
+                # set right _local_shape
+                local_shape_at_split_axis = src_value.shape[
+                    split_axis
+                ] - avg_size_on_split_axis * (num_of_process - 1)
+                local_shape = src_value._local_shape
+                local_shape[split_axis] = local_shape_at_split_axis
+                tmp_src_type = paddle.base.libpaddle.pir.cvt_to_dist_type(
+                    src_value.type(), src_dist_attr, list(local_shape)
+                )
+                src_value.set_type(tmp_src_type)
+                padding_shape = src_value._local_shape
+                padding_shape[split_axis] = padding_num
+                padding_tensor = paddle.full(
+                    padding_shape,
+                    0.0,
+                    src_value.dtype,
+                )
+                tmp_src_type1 = paddle.base.libpaddle.pir.cvt_to_dist_type(
+                    padding_tensor.type(), dst_dist_attr
+                )
+                padding_tensor.set_type(tmp_src_type1)
+                padding_tensor.get_defining_op().dist_attr = (
+                    paddle.base.libpaddle.pir.create_op_dist_attribute(
+                        dst_dist_attr.process_mesh, [], [dst_dist_attr]
+                    )
+                )
+
+                concat_value = paddle._C_ops.concat(
+                    [src_value, padding_tensor], split_axis
+                )
+                # set concat dist_attr
+                axis_dist_attr = (
+                    paddle.base.libpaddle.pir.create_tensor_dist_attribute(
+                        src_dist_attr.process_mesh, [-1], {}
+                    )
+                )
+                concat_value.get_defining_op().dist_attr = (
+                    paddle.base.libpaddle.pir.create_op_dist_attribute(
+                        src_dist_attr.process_mesh,
+                        [
+                            paddle.base.libpaddle.pir.create_array_attribute(
+                                [src_dist_attr, dst_dist_attr]
+                            ),
+                            axis_dist_attr,
+                        ],
+                        [src_dist_attr],
+                    )
+                )
+                # set concat_value type
+                concat_global_shape = list(src_value.shape)
+                concat_global_shape[split_axis] = (
+                    avg_size_on_split_axis * num_of_process
+                )
+                concat_type = paddle.pir.create_shaped_type(
+                    src_value.type(), concat_global_shape
+                )
+                concat_type = paddle.base.libpaddle.pir.cvt_to_dist_type(
+                    concat_type, src_dist_attr
+                )
+                concat_value.set_type(concat_type)
+
+                new_value = self.reshard_s_to_r_with_padding(
+                    concat_value,
+                    split_axis,
+                    src_dist_attr,
+                    dst_dist_attr,
+                    dst_type,
+                    padding_num,
+                )
+                return new_value
+            else:
+                new_value = self.reshard_s_to_r_with_padding(
+                    src_value,
+                    split_axis,
+                    src_dist_attr,
+                    dst_dist_attr,
+                    dst_type,
+                    padding_num,
+                )
+                return new_value

     def reshard_s_to_r_with_padding(
         self,
@@ -168,14 +261,29 @@ def reshard_s_to_r_with_padding(
             )
             pd_splite_op.result(0).set_type(vec_type)

-            concat_value = paddle._C_ops.concat(split_values, split_axis)
-            # fold builtin.split op and builtin.combine op
-            concat_op = concat_value.get_defining_op()
-            builtin_combine_op = concat_op.operand_source(0).get_defining_op()
-            concat_op.operand(0).set_source(pd_splite_op.result(0))
-            builtin_combine_op.erase()
-            builtin_split_op.erase()
-            return concat_value
+            if padding_num != 0:
+                tmp_split_values = paddle._C_ops.split(
+                    split_values[-1],
+                    [
+                        split_values[-1].shape[split_axis] - padding_num,
+                        padding_num,
+                    ],
+                    split_axis,
+                )
+                split_values[-1] = tmp_split_values[0]
+                concat_value = paddle._C_ops.concat(split_values, split_axis)
+                return concat_value
+            else:
+                concat_value = paddle._C_ops.concat(split_values, split_axis)
+                # fold builtin.split op and builtin.combine op
+                concat_op = concat_value.get_defining_op()
+                builtin_combine_op = concat_op.operand_source(
+                    0
+                ).get_defining_op()
+                concat_op.operand(0).set_source(pd_splite_op.result(0))
+                builtin_combine_op.erase()
+                builtin_split_op.erase()
+                return concat_value

         return allgather_value

diff --git a/test/auto_parallel/pir/pir_reshard_s_to_r.py b/test/auto_parallel/pir/pir_reshard_s_to_r.py
index 5eae888415322..324597e4e72cf 100644
--- a/test/auto_parallel/pir/pir_reshard_s_to_r.py
+++ b/test/auto_parallel/pir/pir_reshard_s_to_r.py
@@ -149,7 +149,8 @@ def run_pir_test_case(self):
                 assert operand_1_dist_attr.dims_mapping == [-1, -1]
                 assert operand_2_dist_attr.dims_mapping == [-1]

-                assert operand_dist_attr.partial_status == {}
+                assert operand_1_dist_attr.partial_status == {}
+                assert operand_2_dist_attr.partial_status == {}

                 result_dist_attrs = op.dist_attr.result(0).as_array_attr()
                 assert len(result_dist_attrs) == 2
@@ -209,6 +210,263 @@ def run_pir_test_case(self):
                 assert op_value.dist_attr().dims_mapping == [-1, -1]
                 assert op_value.dist_attr().partial_status == {}

+    def run_pir_unbalanced_split_test_case(self):
+        paddle.enable_static()
+        if self._backend == "cpu":
+            paddle.set_device("cpu")
+            place = paddle.CPUPlace()
+        elif self._backend == "gpu":
+            place = paddle.CUDAPlace(dist.get_rank())
+
+        BATCH_SIZE = 2
+        SEQ_LEN = 4
+        HIDDEN_SIZE = 9
+        MP_SIZE = 2
+
+        with paddle.pir_utils.IrGuard():
+            main_program = paddle.base.Program()
+            with paddle.base.program_guard(main_program):
+                mesh = dist.ProcessMesh([0, 1], dim_names=['mp'])
+                input = paddle.static.data(
+                    name='input', shape=[BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE]
+                )
+                w1 = paddle.pir.core.create_parameter(
+                    dtype="float32",
+                    shape=[HIDDEN_SIZE, HIDDEN_SIZE],
+                    name="w1",
+                    initializer=paddle.nn.initializer.Uniform(),
+                )
+
+                input_tensor = dist.shard_tensor(
+                    w1, self._mesh, [dist.Shard(self._shard)]
+                )
+
+                reshard_tensor = paddle._C_ops.reshard(
+                    input_tensor, self._mesh, [dist.Replicate()]
+                )
+            apply_reshard_pass(main_program)
+        # last one will pad
+        need_padding = dist.get_rank() == self._mesh.process_ids[-1]
+        ops = [op.name() for op in main_program.global_block().ops]
+        if need_padding:
+            np.testing.assert_equal(main_program.num_ops(), 18)
+            std_ops = [
+                'builtin.parameter',
+                'pd_op.data',
+                'dist_op.shard_tensor',
+                'pd_op.full',
+                'pd_op.full',
+                'builtin.combine',
+                'pd_op.concat',
+                'pd_op.all_gather',
+                'pd_op.full',
+                'pd_op.split_with_num',
+                'builtin.split',
+                'pd_op.full_int_array',
+                'pd_op.full',
+                'pd_op.split',
+                'builtin.split',
+                'pd_op.full',
+                'builtin.combine',
+                'pd_op.concat',
+            ]
+            np.testing.assert_equal(
+                ops,
+                std_ops,
+            )
+        else:
+            np.testing.assert_equal(main_program.num_ops(), 14)
+            std_ops = [
+                'builtin.parameter',
+                'pd_op.data',
+                'dist_op.shard_tensor',
+                'pd_op.all_gather',
+                'pd_op.full',
+                'pd_op.split_with_num',
+                'builtin.split',
+                'pd_op.full_int_array',
+                'pd_op.full',
+                'pd_op.split',
+                'builtin.split',
+                'pd_op.full',
+                'builtin.combine',
+                'pd_op.concat',
+            ]
+
+            np.testing.assert_equal(
+                ops,
+                std_ops,
+            )
+
+        first_concat = True
+        for op in main_program.global_block().ops:
+            if op.name() == 'pd_op.all_gather':
+                # check op dist_attr
+                assert op.dist_attr.num_operands() == 1
+                assert op.dist_attr.num_results() == 1
+
+                operand_dist_attr = op.dist_attr.operand(
+                    0
+                ).as_tensor_dist_attr()
+                result_dist_attr = op.dist_attr.result(0).as_tensor_dist_attr()
+
+                assert op.dist_attr.process_mesh == self._mesh
+                assert operand_dist_attr.process_mesh == self._mesh
+                if self._shard == 0:
+                    assert operand_dist_attr.dims_mapping == [0, -1]
+                elif self._shard == 1:
+                    assert operand_dist_attr.dims_mapping == [-1, 0]
+                assert operand_dist_attr.partial_status == {}
+
+                assert result_dist_attr.process_mesh == self._mesh
+                assert result_dist_attr.dims_mapping == [-1, -1]
+                assert result_dist_attr.partial_status == {}
+
+                # check op_value dist_attr
+                assert op.num_results() == 1
+                op_value = op.result(0)
+                assert op_value.is_dense_tensor_type()
+                assert op_value.is_dist_dense_tensor_type()
+                assert op_value.is_dist_dense_tensor_type()
+                assert op_value.dist_attr().process_mesh == self._mesh
+                assert op_value.dist_attr().dims_mapping == [-1, -1]
+                assert op_value.dist_attr().partial_status == {}
+            elif op.name() == 'pd_op.split_with_num':
+                # check op dist_attr
+                assert op.dist_attr.num_operands() == 2
+                assert op.dist_attr.num_results() == 1
+
+                operand_1_dist_attr = op.dist_attr.operand(
+                    0
+                ).as_tensor_dist_attr()
+                operand_2_dist_attr = op.dist_attr.operand(
+                    1
+                ).as_tensor_dist_attr()
+
+                assert op.dist_attr.process_mesh == self._mesh
+                assert operand_1_dist_attr.process_mesh == self._mesh
+                assert operand_2_dist_attr.process_mesh == self._mesh
+
+                assert operand_1_dist_attr.dims_mapping == [-1, -1]
+                assert operand_2_dist_attr.dims_mapping == [-1]
+
+                assert operand_1_dist_attr.partial_status == {}
+                assert operand_2_dist_attr.partial_status == {}
+
+                result_dist_attrs = op.dist_attr.result(0).as_array_attr()
+                assert len(result_dist_attrs) == 2
+                result_dist_attr_1 = result_dist_attrs[0].as_tensor_dist_attr()
+                result_dist_attr_2 = result_dist_attrs[1].as_tensor_dist_attr()
+                assert result_dist_attr_1.process_mesh == self._mesh
+                assert result_dist_attr_1.dims_mapping == [-1, -1]
+                assert result_dist_attr_1.partial_status == {}
+
+                assert result_dist_attr_2.process_mesh == self._mesh
+                assert result_dist_attr_2.dims_mapping == [-1, -1]
+                assert result_dist_attr_2.partial_status == {}
+
+                # check op_value dist_attr
+                assert op.num_results() == 1
+                op_value = op.result(0)
+                assert op_value.is_combine()
+                values = op_value.first_use().owner().results()
+                for value in values:
+                    assert value.dist_attr().process_mesh == self._mesh
+                    assert value.dist_attr().dims_mapping == [-1, -1]
+                    assert value.dist_attr().partial_status == {}
+            elif op.name() == 'pd_op.concat':
+                if need_padding and first_concat:
+                    first_concat = False
+                    # check op dist_attr
+                    assert op.dist_attr.num_operands() == 2
+                    assert op.dist_attr.num_results() == 1
+
+                    operand_1_dist_attrs = op.dist_attr.operand(
+                        0
+                    ).as_array_attr()
+                    assert len(operand_1_dist_attrs) == 2
+
+                    operand_1_dist_attr_1 = operand_1_dist_attrs[
+                        0
+                    ].as_tensor_dist_attr()
+                    operand_1_dist_attr_2 = operand_1_dist_attrs[
+                        1
+                    ].as_tensor_dist_attr()
+                    assert operand_1_dist_attr_1.process_mesh == self._mesh
+                    if self._shard == 0:
+                        assert operand_1_dist_attr_1.dims_mapping == [0, -1]
+                    elif self._shard == 1:
+                        assert operand_1_dist_attr_1.dims_mapping == [-1, 0]
+                    assert operand_1_dist_attr_1.partial_status == {}
+
+                    assert operand_1_dist_attr_2.process_mesh == self._mesh
+                    assert operand_1_dist_attr_2.dims_mapping == [-1, -1]
+                    assert operand_1_dist_attr_2.partial_status == {}
+
+                    result_dist_attr = op.dist_attr.result(
+                        0
+                    ).as_tensor_dist_attr()
+                    assert result_dist_attr.process_mesh == self._mesh
+                    if self._shard == 0:
+                        assert result_dist_attr.dims_mapping == [0, -1]
+                    elif self._shard == 1:
+                        assert result_dist_attr.dims_mapping == [-1, 0]
+                    assert result_dist_attr.partial_status == {}
+
+                    # check op_value dist_attr
+                    assert op.num_results() == 1
+                    op_value = op.result(0)
+                    assert op_value.is_dense_tensor_type()
+                    assert op_value.is_dist_dense_tensor_type()
+                    assert op_value.is_dist_dense_tensor_type()
+                    assert op_value.dist_attr().process_mesh == self._mesh
+                    if self._shard == 0:
+                        assert op_value.dist_attr().dims_mapping == [0, -1]
+                    elif self._shard == 1:
+                        assert op_value.dist_attr().dims_mapping == [-1, 0]
+                    assert op_value.dist_attr().partial_status == {}
+                else:
+                    # check op dist_attr
+                    assert op.dist_attr.num_operands() == 2
+                    assert op.dist_attr.num_results() == 1
+
+                    operand_1_dist_attrs = op.dist_attr.operand(
+                        0
+                    ).as_array_attr()
+                    assert len(operand_1_dist_attrs) == 2
+
+                    operand_1_dist_attr_1 = operand_1_dist_attrs[
+                        0
+                    ].as_tensor_dist_attr()
+                    operand_1_dist_attr_2 = operand_1_dist_attrs[
+                        1
+                    ].as_tensor_dist_attr()
+                    assert operand_1_dist_attr_1.process_mesh == self._mesh
+                    assert operand_1_dist_attr_1.dims_mapping == [-1, -1]
+                    assert operand_1_dist_attr_1.partial_status == {}
+
+                    assert operand_1_dist_attr_2.process_mesh == self._mesh
+                    assert operand_1_dist_attr_2.dims_mapping == [-1, -1]
+                    assert operand_1_dist_attr_2.partial_status == {}
+
+                    result_dist_attr = op.dist_attr.result(
+                        0
+                    ).as_tensor_dist_attr()
+                    assert result_dist_attr.process_mesh == self._mesh
+                    assert result_dist_attr.dims_mapping == [-1, -1]
+                    assert result_dist_attr.partial_status == {}
+
+                    # check op_value dist_attr
+                    assert op.num_results() == 1
+                    op_value = op.result(0)
+                    assert op_value.is_dense_tensor_type()
+                    assert op_value.is_dist_dense_tensor_type()
+                    assert op_value.is_dist_dense_tensor_type()
+                    assert op_value.dist_attr().process_mesh == self._mesh
+                    assert op_value.dist_attr().dims_mapping == [-1, -1]
+                    assert op_value.dist_attr().partial_status == {}
+

 if __name__ == '__main__':
     TestReshardSToR().run_pir_test_case()
+    TestReshardSToR().run_pir_unbalanced_split_test_case()
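
Note: a minimal standalone sketch (illustrative only, not part of the patch; the helper name is made up) of the ceil-division padding arithmetic that the unbalanced-split path above relies on:

def unbalanced_split_sizes(dim_size, num_of_process):
    # Mirrors the arithmetic in the reshard method above: every rank gathers
    # avg_size_on_split_axis elements along the split axis, the last rank owns
    # last_rank_size real elements plus padding_num zeros, and
    # reshard_s_to_r_with_padding slices the padding off after the all_gather.
    avg_size_on_split_axis = (dim_size + num_of_process - 1) // num_of_process
    padding_num = avg_size_on_split_axis * num_of_process - dim_size
    last_rank_size = dim_size - avg_size_on_split_axis * (num_of_process - 1)
    return avg_size_on_split_axis, last_rank_size, padding_num

# HIDDEN_SIZE = 9 over 2 ranks, as in run_pir_unbalanced_split_test_case above:
# each rank contributes 5 slices, the last rank holds 4 real slices and 1 zero slice.
assert unbalanced_split_sizes(9, 2) == (5, 4, 1)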