diff --git a/paddle/fluid/pybind/dist_api.cc b/paddle/fluid/pybind/dist_api.cc
index 13abf232ba248a..d1b000da60c5be 100644
--- a/paddle/fluid/pybind/dist_api.cc
+++ b/paddle/fluid/pybind/dist_api.cc
@@ -17,6 +17,7 @@
 #include "paddle/fluid/pir/dialect/distributed/ir/dist_api.h"
 #include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
+#include "paddle/fluid/pir/dialect/distributed/ir/dist_tools.h"
 #include "paddle/fluid/pir/dialect/distributed/transforms/dist_to_dense_pass.h"
 #include "paddle/fluid/pir/dialect/distributed/transforms/mix_to_dist_pass.h"
 #include "paddle/fluid/pybind/dist_api.h"
@@ -122,6 +123,7 @@ OperationDistAttribute CreateOperationDistAttribute(
 void BindDistUtils(pybind11::module *m) {
   m->def("create_tensor_dist_attribute", CreateTensorDistAttribute);
   m->def("create_op_dist_attribute", CreateOperationDistAttribute);
+  m->def("cvt_to_dist_type", &dialect::CvtToPirDistType);
 }
 
 void BindDistPassAPI(pybind11::module *module) {
diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc
index a4d5d0f3c90469..85ce4abcda94d0 100644
--- a/paddle/fluid/pybind/pir.cc
+++ b/paddle/fluid/pybind/pir.cc
@@ -33,6 +33,7 @@
 #include "paddle/fluid/ir_adaptor/translator/utils.h"
 #include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
 #include "paddle/fluid/pir/dialect/distributed/ir/dist_dialect.h"
+#include "paddle/fluid/pir/dialect/distributed/ir/dist_tools.h"
 #include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h"
 #include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h"
 #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
@@ -815,45 +816,43 @@ py::str Value2String(Value self) {
   return print_stream.str();
 }
 
-phi::DataType GetValueDtype(Value value) {
-  if (!value.type()) {
+phi::DataType GetTensorDtype(Type type) {
+  if (!type) {
     PADDLE_THROW(phi::errors::InvalidArgument("The type of value is nullptr."));
   }
-  if (value.type().isa<DenseTensorType>()) {
-    return paddle::dialect::TransToPhiDataType(
-        value.type().dyn_cast<DenseTensorType>().dtype());
-  } else if (value.type().isa<SelectedRowsType>()) {
-    return paddle::dialect::TransToPhiDataType(
-        value.type().dyn_cast<SelectedRowsType>().dtype());
-  } else if (value.type().isa<DenseTensorArrayType>()) {
-    return paddle::dialect::TransToPhiDataType(
-        value.type().dyn_cast<DenseTensorArrayType>().dtype());
-  } else if (value.type().isa<DistDenseTensorType>()) {
-    return paddle::dialect::TransToPhiDataType(
-        value.type().dyn_cast<DistDenseTensorType>().dtype());
+  if (auto dense_tensor_type = type.dyn_cast<DenseTensorType>()) {
+    return dialect::TransToPhiDataType(dense_tensor_type.dtype());
+  } else if (auto select_rows = type.dyn_cast<SelectedRowsType>()) {
+    return dialect::TransToPhiDataType(select_rows.dtype());
+  } else if (auto dense_array = type.dyn_cast<DenseTensorArrayType>()) {
+    return dialect::TransToPhiDataType(dense_array.dtype());
   } else {
     PADDLE_THROW(phi::errors::InvalidArgument(
         "Currently, we can only get phi::DataType from DenseTensorType and "
-        "SelectedRowsType, DistDenseTensorType."));
+        "SelectedRowsType, DenseTensorArrayType."));
   }
 }
+
+phi::DataType GetValueDtype(Value value) {
+  return GetTensorDtype(value.type());
+}
 
-const phi::DDim &GetValueDims(Value value) {
-  if (!value.type()) {
-    PADDLE_THROW(phi::errors::InvalidArgument("The type of value is nullptr."));
+const phi::DDim &GetTensorDims(Type type) {
+  if (!type) {
+    PADDLE_THROW(common::errors::InvalidArgument(
+        "The type used to get dims is nullptr."));
   }
-  if (value.type().isa<DenseTensorType>()) {
-    return value.type().dyn_cast<DenseTensorType>().dims();
-  } else if (value.type().isa<SelectedRowsType>()) {
-    return value.type().dyn_cast<SelectedRowsType>().dims();
-  } else if (value.type().isa<DistDenseTensorType>()) {
-    return value.type().dyn_cast<DistDenseTensorType>().global_ddim();
+  if (auto dense_type = type.dyn_cast<DenseTensorType>()) {
+    return dense_type.dims();
+  } else if (auto select_rows_type = type.dyn_cast<SelectedRowsType>()) {
+    return select_rows_type.dims();
   } else {
-    PADDLE_THROW(phi::errors::InvalidArgument(
-        "Currently, we can only get shape for dense and distdense"
-        "tensor."));
+    PADDLE_THROW(common::errors::InvalidArgument(
+        "Currently, we can only get shape for dense and select rows type."));
   }
 }
+
+const phi::DDim &GetValueDims(Value value) {
+  return GetTensorDims(value.type());
+}
 
 pir::Value apply(Value self, py::object func) {
   py::gil_scoped_acquire gil;
@@ -1100,13 +1099,10 @@ void BindValue(py::module *m) {
             }
            return self.type().dyn_cast<DistDenseTensorType>().tensor_dist_attr();
          })
+      // The function will calculate the new local shape based on the global
+      // shape and the dist_attr argument.
       .def("update_dist_attr",
           [](Value &self, TensorDistAttribute dist_attr) {
-            if (auto dist_type = self.type().dyn_cast<DistDenseTensorType>()) {
-              self.set_type(dist_type.CopyWithNewDistAttr(dist_attr));
-            } else {
-              PADDLE_THROW(common::errors::InvalidArgument(
-                  "update_dist_attr is only for dist type tensor."));
-            }
+            self.set_type(dialect::CvtToPirDistType(self.type(), dist_attr));
          });
 }
 
@@ -1137,11 +1133,26 @@ bool GetValueBoolAttr(Value value, const std::string &attr_name) {
 
 void BindType(py::module *m) {
   py::class_<Type> ir_type(*m, "Type");
-  ir_type.def("__eq__", &Type::operator==).def("__str__", [](Type &self) {
-    std::ostringstream print_stream;
-    print_stream << self;
-    return print_stream.str();
-  });
+  ir_type.def("__eq__", &Type::operator==)
+      .def_property(
+          "shape",
+          [](Type self) { return phi::vectorize(GetTensorDims(self)); },
+          [](Type self, const std::vector<int64_t> &shape) {
+            PADDLE_THROW(phi::errors::InvalidArgument(
+                "can't set shape when building static graph"));
+          })
+      .def_property(
+          "dtype",
+          [](Type self) { return GetTensorDtype(self); },
+          [](Type self, phi::DataType dtype) {
+            PADDLE_THROW(phi::errors::InvalidArgument(
+                "can't set dtype when building static graph"));
+          })
+      .def("__str__", [](Type &self) {
+        std::ostringstream print_stream;
+        print_stream << self;
+        return print_stream.str();
+      });
 
   m->def("create_shaped_type",
          [](Type &type, const std::vector<int64_t> &shape) -> Type {
diff --git a/python/paddle/distributed/auto_parallel/static/engine.py b/python/paddle/distributed/auto_parallel/static/engine.py
index 4fe6aa06326055..e3ae2bf090fda0 100644
--- a/python/paddle/distributed/auto_parallel/static/engine.py
+++ b/python/paddle/distributed/auto_parallel/static/engine.py
@@ -672,12 +672,12 @@ def _parallel_pir(self, mode):
         # TODO(JZ-LIANG) Step 3.1: Partition Pass
         # insert reshard op if operand tensor's placements if different from what the cumsumer op need.
         # Partition the computation graph into different pipeline stage if need.
-        dist_program = apply_partition_pass(dist_program)
+        apply_partition_pass(dist_program)
 
         # TODO(hitywt) Step 3.2: Reshard Pass
         # resolute the reshard op into special collective operation.
         # collect the communicator created during resolution.
-        dist_program = apply_reshard_pass(dist_program)
+        apply_reshard_pass(dist_program)
 
         # Part 4: Optimization Pass
         # NOTE Only those Optimization Pass that related to Parallelism (need dist attr) should be placed here and all the Pass should be Optional.
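The new `cvt_to_dist_type` binding and the read-only `shape`/`dtype` properties on `Type` are what the reworked reshard functions below use instead of peeking at `op.result(0)`. A minimal sketch of how they can be exercised from Python — the tensor shape, mesh, and dims_mapping values here are illustrative, and it assumes these utilities are reachable under `paddle.base.libpaddle.pir`, which is how the files in this diff call them:

import paddle
import paddle.distributed as dist

paddle.enable_static()
with paddle.pir_utils.IrGuard():
    main_program = paddle.base.Program()
    with paddle.base.program_guard(main_program):
        mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
        x = paddle.static.data(name='x', shape=[4, 8], dtype='float32')
        x_dist = dist.shard_tensor(x, mesh, [dist.Replicate()])

        # Properties added by BindType in this PR: inspect a Type directly.
        t = x.type()
        print(t.shape, t.dtype)

        # Re-create a sharded TensorDistAttribute on the same mesh and convert
        # the existing type; the local shape is recomputed from the dist_attr.
        src_attr = x_dist.dist_attr()
        shard_attr = paddle.base.libpaddle.pir.create_tensor_dist_attribute(
            src_attr.process_mesh, [0, -1], src_attr.partial_status
        )
        new_type = paddle.base.libpaddle.pir.cvt_to_dist_type(x.type(), shard_attr)
        x_dist.set_type(new_type)

This is the same conversion `Value.update_dist_attr` now performs internally via `dialect::CvtToPirDistType`.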
diff --git a/python/paddle/distributed/auto_parallel/static/pir_pass.py b/python/paddle/distributed/auto_parallel/static/pir_pass.py
index 895156bbebf784..5f4d34aad77186 100644
--- a/python/paddle/distributed/auto_parallel/static/pir_pass.py
+++ b/python/paddle/distributed/auto_parallel/static/pir_pass.py
@@ -61,99 +61,90 @@ def reshard_combine_value(op, operand, attr):
 
 
 def apply_partition_pass(program):
-    with paddle.static.program_guard(program):
-        for op in program.global_block().ops:
-            if op.name() in partition_skip_op_list:
-                continue
-            assert len(op.operands()) == len(
-                op.dist_attr.operands()
-            ), f"The number of operands and the number of op_dist_attr's operands are not equal in op: {op}"
-
-            for operand, attr in zip(op.operands(), op.dist_attr.operands()):
-                prev_var = operand.source()
-                if prev_var.is_combine():
-                    operand.set_source(reshard_combine_value(op, operand, attr))
-                else:
-                    operand.set_source(reshard_single_value(op, operand, attr))
-                prev_op = prev_var.get_defining_op()
-                if (
-                    prev_op
-                    and prev_op.num_results() == 1
-                    and prev_var.use_empty()
-                ):
-                    prev_op.erase()
-
-            for var, attr in zip(op.results(), op.dist_attr.results()):
-                if (
-                    var.initialized()
-                    and var.is_dist()
-                    and var.dist_attr() != attr
-                ):
-                    paddle.pir.set_insertion_point_after(op)
-                    old_dist_attr = var.dist_attr()
-                    var.update_dist_attr(attr.as_tensor_dist_attr())
-                    # insert reshard
-                    reshard_var = paddle._C_ops.reshard_v2(var, old_dist_attr)
-                    var.replace_all_uses_with(reshard_var)
-                    reshard_var.get_defining_op().operand(0).set_source(var)
-
-        # pruning op and value not belong to cur rank
-        cur_rank = paddle.distributed.get_rank()
-        for op in program.global_block().ops[::-1]:
-            if cur_rank not in op.dist_attr.process_mesh.process_ids:
-                program.global_block().remove_op(op)
-            else:
-                # set the operand as null when it is not belong to cur rank
-                if (
-                    op.name() == 'dist_op.reshard'
-                    and cur_rank
-                    not in op.operand(0)
-                    .source()
-                    .dist_attr()
-                    .process_mesh.process_ids
-                ):
-                    op.operand(0).set_source(None)
-
-        # merge pd.data ops for
-        lr_ops = []
-        for op in program.global_block().ops[::-1]:
-            if (
-                op.name() == 'pd_op.data'
-                and "learning_rate" in op.attrs()["name"]
-            ):
-                lr_ops.append(op)
-
-        if len(lr_ops) > 1:
-            lr_value = lr_ops[0].result(0)
-            for op in lr_ops[1:]:
-                lr = op.result(0)
-                lr.replace_all_uses_with(lr_value)
-                program.global_block().remove_op(op)
-    return program
+    for op in program.global_block().ops:
+        if op.name() in partition_skip_op_list:
+            continue
+        assert len(op.operands()) == len(
+            op.dist_attr.operands()
+        ), f"The number of operands and the number of op_dist_attr's operands are not equal in op: {op}"
+
+        for operand, attr in zip(op.operands(), op.dist_attr.operands()):
+            prev_var = operand.source()
+            if prev_var.is_combine():
+                operand.set_source(reshard_combine_value(op, operand, attr))
+            else:
+                operand.set_source(reshard_single_value(op, operand, attr))
+            prev_op = prev_var.get_defining_op()
+            if prev_op and prev_op.num_results() == 1 and prev_var.use_empty():
+                prev_op.erase()
+
+        for var, attr in zip(op.results(), op.dist_attr.results()):
+            if var.initialized() and var.is_dist() and var.dist_attr() != attr:
+                paddle.pir.set_insertion_point_after(op)
+                old_dist_attr = var.dist_attr()
+                var.update_dist_attr(attr.as_tensor_dist_attr())
+                # insert reshard
+                reshard_var = paddle._C_ops.reshard_v2(var, old_dist_attr)
+                var.replace_all_uses_with(reshard_var)
+                reshard_var.get_defining_op().operand(0).set_source(var)
+
+    # prune the ops and values that do not belong to the current rank
+    cur_rank = paddle.distributed.get_rank()
+    for op in program.global_block().ops[::-1]:
+        if cur_rank not in op.dist_attr.process_mesh.process_ids:
+            op.erase()
+        else:
+            # set the operand to null when it does not belong to the current rank
+            if (
+                op.name() == 'dist_op.reshard'
+                and cur_rank
+                not in op.operand(0)
+                .source()
+                .dist_attr()
+                .process_mesh.process_ids
+            ):
+                op.operand(0).set_source(None)
+
+    # merge pd.data ops for learning rate
+    lr_ops = []
+    for op in program.global_block().ops[::-1]:
+        if op.name() == 'pd_op.data' and "learning_rate" in op.attrs()["name"]:
+            lr_ops.append(op)
+
+    if len(lr_ops) > 1:
+        lr_value = lr_ops[0].result(0)
+        for op in lr_ops[1:]:
+            lr = op.result(0)
+            lr.replace_all_uses_with(lr_value)
+            op.erase()
 
 
 def apply_reshard_pass(program):
-    new_program = program.clone()
-    with paddle.base.program_guard(new_program):
-        for op in new_program.global_block().ops:
-            if op.name() == 'dist_op.reshard':
-                var = op.operand_source(0)
-                op_dist_attr = op.dist_attr
-                src_dist_attr = op_dist_attr.operand(0).as_tensor_dist_attr()
-                dst_dist_attr = op_dist_attr.result(0).as_tensor_dist_attr()
-                assert (
-                    not var.initialized() or var.dist_attr() == src_dist_attr
-                ), f"The dist_attr of reshard op's input and operand should be equal, but got {var.dist_attr()} and {src_dist_attr}"
-
-                reshard_func = choose_reshard_func(src_dist_attr, dst_dist_attr)
-                assert (
-                    reshard_func is not None
-                ), f'There is no reshard function that matches src_dist_attr: {src_dist_attr} and dst_dist_attr: {dst_dist_attr}'
-                reshard_func.reshard(
-                    new_program, op, src_dist_attr, dst_dist_attr
-                )
-
-    return new_program
+    for op in program.global_block().ops:
+        if op.name() == 'dist_op.reshard':
+            var = op.operand_source(0)
+            op_dist_attr = op.dist_attr
+            src_dist_attr = op_dist_attr.operand(0).as_tensor_dist_attr()
+            dst_dist_attr = op_dist_attr.result(0).as_tensor_dist_attr()
+            assert (
+                not var.initialized() or var.dist_attr() == src_dist_attr
+            ), f"The dist_attr of the reshard op's input value and its operand dist_attr should be equal, but got {var.dist_attr()} and {src_dist_attr}"
+
+            reshard_func = choose_reshard_func(src_dist_attr, dst_dist_attr)
+            assert (
+                reshard_func is not None
+            ), f'There is no reshard function that matches src_dist_attr: {src_dist_attr} and dst_dist_attr: {dst_dist_attr}'
+            paddle.pir.set_insertion_point_after(op)
+            out_value = reshard_func.reshard(
+                src_dist_attr,
+                dst_dist_attr,
+                op.operand_source(0),
+                op.result(0).type(),
+            )
+            if out_value is not None:
+                op.result(0).replace_all_uses_with(out_value)
+            if op.result(0).use_empty():
+                op.erase()
 
 
 # In sequence_parallel, we need to transpose hidden_states
@@ -183,5 +174,5 @@ def eliminate_transpose_by_reshape(program):
             transpose_var = op.result(0)
             reshape_var = paddle._C_ops.reshape(var, transpose_var.shape)
             transpose_var.replace_all_uses_with(reshape_var)
-            program.global_block().remove_op(op)
+            op.erase()
     return program
diff --git a/python/paddle/distributed/auto_parallel/static/reshard_funcs/base_reshard_func.py b/python/paddle/distributed/auto_parallel/static/reshard_funcs/base_reshard_func.py
index b34fe137958305..c01c77ad7f0d7b 100644
--- a/python/paddle/distributed/auto_parallel/static/reshard_funcs/base_reshard_func.py
+++ b/python/paddle/distributed/auto_parallel/static/reshard_funcs/base_reshard_func.py
@@ -20,7 +20,7 @@ class ReshardFunction:
     def is_suitable(self, dist_tensor, dist_attr):
         raise NotImplementedError
 
-    def reshard(self, program, op, src_tensor, dst_dist_attr):
+    def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
         raise NotImplementedError
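Under the new base-class signature above, a reshard function no longer receives the program or the reshard op, no longer erases anything, and simply returns the value it produced (or None when the current rank contributes nothing); `apply_reshard_pass` wires the returned value back in. A minimal sketch of a custom function written against that contract — the class name and the trivial rule it implements are illustrative only, and `register_reshard_func` is assumed to be the existing helper living next to `choose_reshard_func` in `base_reshard_func.py`:

from paddle.distributed.auto_parallel.static.reshard_funcs.base_reshard_func import (
    ReshardFunction,
    register_reshard_func,  # assumed location of the existing registry helper
)


class NoOpReshardFunction(ReshardFunction):
    """Illustrative only: handles the case where nothing has to move."""

    def is_suitable(self, src_dist_attr, dst_dist_attr):
        return src_dist_attr == dst_dist_attr

    def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
        # No collective is needed; stamp the destination type and hand the
        # value back so apply_reshard_pass can replace the reshard op's result.
        src_value.set_type(dst_type)
        return src_value


register_reshard_func(NoOpReshardFunction())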
diff --git a/python/paddle/distributed/auto_parallel/static/reshard_funcs/p_to_r_reshard_func.py b/python/paddle/distributed/auto_parallel/static/reshard_funcs/p_to_r_reshard_func.py
index 32e632c6da9173..6c0c9445449938 100644
--- a/python/paddle/distributed/auto_parallel/static/reshard_funcs/p_to_r_reshard_func.py
+++ b/python/paddle/distributed/auto_parallel/static/reshard_funcs/p_to_r_reshard_func.py
@@ -39,9 +39,7 @@ def is_suitable(self, src_dist_attr, dst_dist_attr):
                 return False
         return True
 
-    def reshard(
-        self, program, op, src_dist_attr, dst_dist_attr, reshard_op=True
-    ):
+    def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
         src_mesh = src_dist_attr.process_mesh
         src_reduce_type = src_dist_attr.partial_status[0]
         reduce_mean = False
@@ -49,28 +47,19 @@ def reshard(
             src_reduce_type = ReduceOp.SUM
             reduce_mean = True
 
-        op_value = op.result(0)
-        op_type = op_value.type()
-        if reshard_op:
-            paddle.pir.set_insertion_point(op)
-            op_value = op.operand_source(0)
-        else:
-            paddle.pir.set_insertion_point_after(op)
         group = new_process_group(src_mesh.process_ids)
         reduced_value = paddle._pir_ops.c_allreduce_sum_(
-            op_value, group.id, True, False
+            src_value, group.id, True, False
        )
 
         # set dist type and dist attr
-        reduced_value.set_type(op_type)
+        reduced_value.set_type(dst_type)
         reduced_value.get_defining_op().dist_attr = (
             paddle.base.libpaddle.pir.create_op_dist_attribute(
                 src_mesh, [src_dist_attr], [dst_dist_attr]
             )
         )
-        if reshard_op:
-            op.result(0).replace_all_uses_with(reduced_value)
-            program.global_block().remove_op(op)
+        return reduced_value
 
 
 class PToRReshardFunctionCrossMesh(ReshardFunction):
@@ -96,26 +85,30 @@ def is_suitable(self, src_dist_attr, dst_dist_attr):
 
         return True
 
-    def reshard(self, program, op, src_dist_attr, dst_dist_attr):
+    def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
         same_status_func = SameStatusReshardFunction()
         tmp_dist_attr = paddle.base.libpaddle.pir.create_tensor_dist_attribute(
             dst_dist_attr.process_mesh,
             src_dist_attr.dims_mapping,
             src_dist_attr.partial_status,
         )
-        pre_op, out_dist_attr = same_status_func.reshard(
-            program, op, src_dist_attr, tmp_dist_attr
+        tmp_dst_type = paddle.base.libpaddle.pir.cvt_to_dist_type(
+            src_value.type(), tmp_dist_attr
+        )
+        out_value = same_status_func.reshard(
+            src_dist_attr, tmp_dist_attr, src_value, tmp_dst_type
         )
-        if pre_op is None:
-            return None, out_dist_attr
+        if out_value is None:
+            return None
 
         curr_global_rank = paddle.distributed.get_rank()
         if curr_global_rank in dst_dist_attr.process_mesh.process_ids:
             p_to_r_func = PToRReshardFunction()
             assert p_to_r_func.is_suitable(
-                out_dist_attr, dst_dist_attr
-            ), f"Invoke the p to r reshard function is not valid from {pre_op.dist_attr()} to {dst_dist_attr}"
-            p_to_r_func.reshard(
-                program, pre_op, out_dist_attr, dst_dist_attr, False
+                tmp_dist_attr, dst_dist_attr
+            ), f"Invoking the p to r reshard function is not valid from {tmp_dist_attr} to {dst_dist_attr}"
+            return p_to_r_func.reshard(
+                tmp_dist_attr, dst_dist_attr, out_value, dst_type
             )
+        return None
diff --git a/python/paddle/distributed/auto_parallel/static/reshard_funcs/r_to_s_reshard_func.py b/python/paddle/distributed/auto_parallel/static/reshard_funcs/r_to_s_reshard_func.py
index 87d85cb54c6168..922df440c5a21c 100644
--- a/python/paddle/distributed/auto_parallel/static/reshard_funcs/r_to_s_reshard_func.py
+++ b/python/paddle/distributed/auto_parallel/static/reshard_funcs/r_to_s_reshard_func.py
@@ -15,6 +15,7 @@
 import paddle
 
 from .base_reshard_func import ReshardFunction, is_replicated, is_shard
+from .same_status_reshard_func import SameStatusReshardFunction
 
 
 class RToSReshardFunction(ReshardFunction):
@@ -36,9 +37,7 @@ def is_suitable(self, src_dist_attr, dst_dist_attr):
             return False
         return True
 
-    def reshard(
-        self, program, op, src_dist_attr, dst_dist_attr, remove_op=True
-    ):
+    def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
         split_axis = -1
         mesh_axis = -1
         for idx, v in enumerate(dst_dist_attr.dims_mapping):
@@ -49,7 +48,7 @@ def reshard(
         mesh = src_dist_attr.process_mesh
         curr_global_rank = paddle.distributed.get_rank()
         if curr_global_rank in mesh.process_ids:
-            total_nums = op.operand_source(0).shape[split_axis]
+            total_nums = src_value.shape[split_axis]
             num_of_pieces = mesh.shape[mesh_axis]
             piece_len = (total_nums + num_of_pieces - 1) // num_of_pieces
             rank_relative = mesh.process_ids.index(curr_global_rank)
@@ -58,9 +57,66 @@ def reshard(
             if curr_global_rank == mesh.process_ids[-1]:
                 end = total_nums
 
-            paddle.pir.set_insertion_point(op)
-            out_value = paddle.slice(
-                op.operand_source(0), [split_axis], [start], [end]
+            out_value = paddle.slice(src_value, [split_axis], [start], [end])
+
+            out_value.set_type(src_value.type())
+            out_value.update_dist_attr(dst_dist_attr)
+            out_value.get_defining_op().dist_attr = (
+                paddle.base.libpaddle.pir.create_op_dist_attribute(
+                    mesh, [src_dist_attr], [dst_dist_attr]
+                )
+            )
+            return out_value
+        return None
+
+
+class RToSReshardFunctionCrossMesh(ReshardFunction):
+    def is_suitable(self, src_dist_attr, dst_dist_attr):
+        if not is_replicated(src_dist_attr):
+            return False
+
+        if not is_shard(dst_dist_attr):
+            return False
+
+        in_mesh = src_dist_attr.process_mesh
+        out_mesh = dst_dist_attr.process_mesh
+
+        if (
+            in_mesh.ndim != 1
+            or out_mesh.ndim != 1
+            or in_mesh.shape != out_mesh.shape
+        ):
+            return False
+
+        if in_mesh == out_mesh:
+            return False
+
+        return True
+
+    def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
+        same_status_func = SameStatusReshardFunction()
+        tmp_dist_attr = paddle.base.libpaddle.pir.create_tensor_dist_attribute(
+            dst_dist_attr.process_mesh,
+            src_dist_attr.dims_mapping,
+            src_dist_attr.partial_status,
+        )
+        tmp_dst_type = paddle.base.libpaddle.pir.cvt_to_dist_type(
+            src_value.type(), tmp_dist_attr
+        )
+        out_value = same_status_func.reshard(
+            src_dist_attr, tmp_dist_attr, src_value, tmp_dst_type
+        )
+
+        if out_value is None:
+            return None
+
+        curr_global_rank = paddle.distributed.get_rank()
+        if curr_global_rank in dst_dist_attr.process_mesh.process_ids:
+            r_to_s_func = RToSReshardFunction()
+            assert r_to_s_func.is_suitable(
+                tmp_dist_attr, dst_dist_attr
+            ), f"Invoking the r to s reshard function is not valid from {tmp_dist_attr} to {dst_dist_attr}"
+            return r_to_s_func.reshard(
+                tmp_dist_attr, dst_dist_attr, out_value, dst_type
             )
-            op.result(0).replace_all_uses_with(out_value)
-            op.get_parent_block().remove_op(op)
+        return None
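For reference, the slice bounds computed in `RToSReshardFunction.reshard` above split the sharded axis into near-equal pieces, with the last rank of the mesh absorbing the remainder. A standalone worked example of just that arithmetic (plain Python, no Paddle required), for a length-10 axis on a 1-D mesh of 4 ranks:

total_nums = 10             # size of the axis being sharded (illustrative)
process_ids = [0, 1, 2, 3]  # 1-D mesh (illustrative)
num_of_pieces = len(process_ids)
piece_len = (total_nums + num_of_pieces - 1) // num_of_pieces  # ceil div -> 3

for curr_global_rank in process_ids:
    rank_relative = process_ids.index(curr_global_rank)
    start = rank_relative * piece_len
    end = start + piece_len
    if curr_global_rank == process_ids[-1]:
        end = total_nums  # last rank takes whatever is left
    print(curr_global_rank, start, end)
# rank 0: [0, 3)   rank 1: [3, 6)   rank 2: [6, 9)   rank 3: [9, 10)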
diff --git a/python/paddle/distributed/auto_parallel/static/reshard_funcs/reshard_func_register.py b/python/paddle/distributed/auto_parallel/static/reshard_funcs/reshard_func_register.py
index 5afe0f995e4f01..136ebfdcbbf104 100644
--- a/python/paddle/distributed/auto_parallel/static/reshard_funcs/reshard_func_register.py
+++ b/python/paddle/distributed/auto_parallel/static/reshard_funcs/reshard_func_register.py
@@ -17,14 +17,18 @@
     PToRReshardFunction,
     PToRReshardFunctionCrossMesh,
 )
-from .r_to_s_reshard_func import RToSReshardFunction
+from .r_to_s_reshard_func import (
+    RToSReshardFunction,
+    RToSReshardFunctionCrossMesh,
+)
 from .same_status_reshard_func import SameStatusReshardFunction
 
 
 def register_reshard_funcs():
     register_reshard_func(PToRReshardFunction())
-    register_reshard_func(RToSReshardFunction())
     register_reshard_func(PToRReshardFunctionCrossMesh())
+    register_reshard_func(RToSReshardFunction())
+    register_reshard_func(RToSReshardFunctionCrossMesh())
     register_reshard_func(SameStatusReshardFunction())
diff --git a/python/paddle/distributed/auto_parallel/static/reshard_funcs/same_status_reshard_func.py b/python/paddle/distributed/auto_parallel/static/reshard_funcs/same_status_reshard_func.py
index af910104de6277..1e79afbaf0ab06 100644
--- a/python/paddle/distributed/auto_parallel/static/reshard_funcs/same_status_reshard_func.py
+++ b/python/paddle/distributed/auto_parallel/static/reshard_funcs/same_status_reshard_func.py
@@ -34,7 +34,7 @@ def is_suitable(self, src_dist_attr, dst_dist_attr):
             return False
         return True
 
-    def reshard(self, program, op, src_dist_attr, dst_dist_attr):
+    def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
         src_mesh = src_dist_attr.process_mesh
         dst_mesh = dst_dist_attr.process_mesh
 
@@ -45,14 +45,13 @@ def reshard(self, program, op, src_dist_attr, dst_dist_attr):
         cur_global_rank = paddle.distributed.get_rank()
         comm_group = new_process_group(all_process_ids)
 
-        paddle.pir.set_insertion_point(op)
         is_send = True
         for src, dst in zip(src_mesh.process_ids, dst_mesh.process_ids):
             if src == cur_global_rank:
                 dst_local_rank = all_process_ids.index(dst)
                 paddle._pir_ops.send_v2(
-                    op.operand_source(0),
+                    src_value,
                     comm_group.id,
                     dst_local_rank,
                     True,
@@ -72,11 +71,11 @@ def reshard(self, program, op, src_dist_attr, dst_dist_attr):
             elif dst == cur_global_rank:
                 src_local_rank = all_process_ids.index(src)
                 assert (
-                    -1 not in op.result(0).shape
+                    -1 not in dst_type.shape
                 ), "dynamic shape is not supported by pir-auto parallel yet."
                 recv_value = paddle._pir_ops.recv_v2(
-                    op.result(0).shape,
-                    op.result(0).dtype,
+                    dst_type.shape,
+                    dst_type.dtype,
                     src_local_rank,
                     comm_group.id,
                     True,
@@ -88,14 +87,11 @@ def reshard(self, program, op, src_dist_attr, dst_dist_attr):
                         dst_mesh, [], [dst_dist_attr]
                     )
                 )
-                recv_value.set_type(op.result(0).type())
-                op.result(0).replace_all_uses_with(recv_value)
+                recv_value.update_dist_attr(dst_dist_attr)
                 is_send = False
                 break
 
-        program.global_block().remove_op(op)
-
         if is_send:
-            return None, None
+            return None
         else:
-            return new_op, dst_dist_attr
+            return recv_value
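Both cross-mesh functions lean on `SameStatusReshardFunction`, which pairs source and destination ranks positionally via `zip(src_mesh.process_ids, dst_mesh.process_ids)`; that is why the new test below only expects `send_v2` on the input mesh and `recv_v2` (followed by `slice`) on the output mesh. The pairing in isolation, as a hedged sketch — the concatenation order used for the temporary group's local ranks is an assumption here, since `all_process_ids` is built outside the shown hunks:

src_process_ids = [0, 2]  # self._in_mesh in the test below
dst_process_ids = [1, 3]  # self._out_mesh in the test below
all_process_ids = src_process_ids + dst_process_ids  # assumed group ordering

for src, dst in zip(src_process_ids, dst_process_ids):
    # src sends to dst; local ranks are indices into the temporary group
    print(f"rank {src}: send_v2(peer={all_process_ids.index(dst)})  "
          f"rank {dst}: recv_v2(peer={all_process_ids.index(src)})")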
diff --git a/test/auto_parallel/pir_reshard_r_to_s_cross_mesh.py b/test/auto_parallel/pir_reshard_r_to_s_cross_mesh.py
new file mode 100644
index 00000000000000..875850ccd035ac
--- /dev/null
+++ b/test/auto_parallel/pir_reshard_r_to_s_cross_mesh.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import paddle
+import paddle.distributed as dist
+from paddle.distributed.auto_parallel.static.pir_pass import (
+    apply_reshard_pass,
+)
+
+
+class TestReshardRToSCrossMesh:
+    def __init__(self):
+        self._shape = eval(os.getenv("shape"))
+        self._dtype = os.getenv("dtype")
+        self._seeds = eval(os.getenv("seeds"))
+        self._shard = eval(os.getenv("shard"))
+        self._backend = os.getenv("backend")
+        self._in_mesh = dist.ProcessMesh([0, 2], dim_names=["x"])
+        self._out_mesh = dist.ProcessMesh([1, 3], dim_names=["x"])
+
+    def run_test_case(self):
+        paddle.enable_static()
+
+        BATCH_SIZE = 2
+        SEQ_LEN = 4
+        HIDDEN_SIZE = 8
+
+        with paddle.pir_utils.IrGuard():
+            main_program = paddle.base.Program()
+            with paddle.base.program_guard(main_program):
+                input = paddle.static.data(
+                    name='input', shape=[BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE]
+                )
+                input_tensor = dist.shard_tensor(
+                    input, self._in_mesh, [dist.Replicate()]
+                )
+                out = paddle._C_ops.reshard(
+                    input_tensor, self._out_mesh, [dist.Shard(self._shard)]
+                )
+            target_type = out.type()
+
+            old_ops = [op.name() for op in main_program.global_block().ops]
+            assert 'dist_op.reshard' in old_ops
+
+            apply_reshard_pass(main_program)
+            # np.testing.assert_equal(dist_program.num_ops(), 6)
+            new_ops = [op.name() for op in main_program.global_block().ops]
+            assert 'dist_op.reshard' not in new_ops
+            if dist.get_rank() in self._in_mesh.process_ids:
+                assert 'pd_op.send_v2' in new_ops
+            else:
+                assert 'pd_op.recv_v2' in new_ops
+                assert 'pd_op.slice' in new_ops
+
+            for op in main_program.global_block().ops:
+                if op.name() == 'pd_op.send_v2':
+                    assert op.dist_attr.process_mesh == self._in_mesh
+                    assert op.operand_source(0).dist_attr() == op.dist_attr.operand(
+                        0
+                    )
+
+                    operand_dist_attr = op.operand_source(0).dist_attr()
+                    assert operand_dist_attr.process_mesh == self._in_mesh
+                    assert operand_dist_attr.dims_mapping == [-1, -1, -1]
+                    assert operand_dist_attr.partial_status == {}
+                elif op.name() == 'pd_op.recv_v2':
+                    assert op.dist_attr.process_mesh == self._out_mesh
+                    assert op.result(0).dist_attr() == op.dist_attr.result(0)
+
+                    result_dist_attr = op.result(0).dist_attr()
+                    assert result_dist_attr.process_mesh == self._out_mesh
+                    assert result_dist_attr.dims_mapping == [-1, -1, -1]
+                    assert result_dist_attr.partial_status == {}
+                elif op.name() == 'pd_op.slice':
+                    assert op.dist_attr.process_mesh == self._out_mesh
+                    assert op.result(0).dist_attr() == op.dist_attr.result(0)
+                    assert op.result(0).type() == target_type
+
+
+if __name__ == '__main__':
+    TestReshardRToSCrossMesh().run_test_case()
diff --git a/test/auto_parallel/reshard_p_to_r.py b/test/auto_parallel/reshard_p_to_r.py
index d680dad2f38cc5..0a71492ba11d0d 100644
--- a/test/auto_parallel/reshard_p_to_r.py
+++ b/test/auto_parallel/reshard_p_to_r.py
@@ -88,9 +88,9 @@ def run_pir_static_test_case(self):
                 reshard_tensor = paddle._C_ops.reshard(
                     input_tensor, self._mesh, [dist.Replicate()]
                 )
-            dist_program = apply_reshard_pass(main_program)
-            np.testing.assert_equal(dist_program.num_ops(), 4)
-            ops = dist_program.global_block().ops
+            apply_reshard_pass(main_program)
+            np.testing.assert_equal(main_program.num_ops(), 4)
+            ops = main_program.global_block().ops
             np.testing.assert_equal(
                 [op.name() for op in ops],
                 [
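Because `apply_reshard_pass` now rewrites the given program in place and returns nothing, tests that want to compare the before/after op lists clone the program first, as `reshard_r_to_s.py` below does. A condensed sketch of that calling pattern (the helper name is illustrative; `main_program` is assumed to be a PIR program that already contains a `dist_op.reshard`, built the same way as in the tests above):

from paddle.distributed.auto_parallel.static.pir_pass import apply_reshard_pass


def lower_reshard(main_program, keep_original=True):
    # Clone when the pre-pass program must stay inspectable; otherwise mutate
    # main_program directly, as engine.py now does.
    target = main_program.clone() if keep_original else main_program
    apply_reshard_pass(target)  # mutates `target` in place
    return target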
diff --git a/test/auto_parallel/reshard_p_to_r_cross_mesh.py b/test/auto_parallel/reshard_p_to_r_cross_mesh.py
index 1b3b4b07447058..42a34a478a7ffb 100644
--- a/test/auto_parallel/reshard_p_to_r_cross_mesh.py
+++ b/test/auto_parallel/reshard_p_to_r_cross_mesh.py
@@ -85,11 +85,12 @@ def run_pir_static_test_case(self):
                 reshard_tensor = paddle._pir_ops.reshard(
                     input_tensor, self._out_mesh, [dist.Replicate()]
                 )
-            dist_program = apply_reshard_pass(main_program)
-            ops = [op.name() for op in dist_program.global_block().ops]
+            apply_reshard_pass(main_program)
+
+            ops = [op.name() for op in main_program.global_block().ops]
             if paddle.distributed.get_rank() == 0:
-                np.testing.assert_equal(dist_program.num_ops(), 4)
+                np.testing.assert_equal(main_program.num_ops(), 4)
                 std_ops = [
                     'builtin.parameter',
                     'pd_op.data',
@@ -97,7 +98,7 @@ def run_pir_static_test_case(self):
                     'pd_op.send_v2',
                 ]
             else:
-                np.testing.assert_equal(dist_program.num_ops(), 5)
+                np.testing.assert_equal(main_program.num_ops(), 5)
                 std_ops = [
                     'builtin.parameter',
                     'pd_op.data',
@@ -109,7 +110,7 @@ def run_pir_static_test_case(self):
             ops,
             std_ops,
         )
-        for op in dist_program.global_block().ops:
+        for op in main_program.global_block().ops:
             if op.name() == 'pd_op.send_v2':
                 assert op.dist_attr.num_operands() == 1
                 assert op.dist_attr.num_results() == 0
diff --git a/test/auto_parallel/reshard_r_to_s.py b/test/auto_parallel/reshard_r_to_s.py
index 8a37255654ead5..ea36ef7a344245 100644
--- a/test/auto_parallel/reshard_r_to_s.py
+++ b/test/auto_parallel/reshard_r_to_s.py
@@ -93,7 +93,8 @@ def run_pir_test_case(self):
                 paddle._C_ops.reshard(
                     input_tensor, self._mesh, [dist.Shard(self._shard)]
                 )
-            dist_program = apply_reshard_pass(main_program)
+            dist_program = main_program.clone()
+            apply_reshard_pass(dist_program)
             np.testing.assert_equal(dist_program.num_ops(), 6)
             old_ops = [op.name() for op in main_program.global_block().ops]
             new_ops = [op.name() for op in dist_program.global_block().ops]
diff --git a/test/auto_parallel/test_reshard_r_to_s.py b/test/auto_parallel/test_reshard_r_to_s.py
index b951508f8c1c95..ada6f934f4aae9 100644
--- a/test/auto_parallel/test_reshard_r_to_s.py
+++ b/test/auto_parallel/test_reshard_r_to_s.py
@@ -52,5 +52,30 @@ def test_reshard_r_to_s_cross_mesh(self):
         )
 
 
+class TestReshardRToSCrossMesh(test_base.CommunicationTestDistBase):
+    def setUp(self):
+        super().setUp(num_of_devices=2, timeout=120)
+        self._default_envs = {
+            "dtype": "float32",
+            "seeds": "2023",
+        }
+        self._changeable_envs = {
+            "shape": ["(10, 20)"],
+            "shard": ["0", "1"],
+            "backend": ["cpu"],
+        }
+
+    def test_reshard_r_to_s_cross_mesh(self):
+        envs_list = test_base.gen_product_envs_list(
+            self._default_envs, self._changeable_envs
+        )
+        # self._log_dir.name = "./log"
+        for envs in envs_list:
+            self.run_test_case(
+                "pir_reshard_r_to_s_cross_mesh.py",
+                user_defined_envs=envs,
+            )
+
+
 if __name__ == "__main__":
     unittest.main()