From 13e76d76e5caf2a993ecb7d0422e53f348eafc35 Mon Sep 17 00:00:00 2001
From: Lans1ot
Date: Wed, 31 Jul 2024 14:12:09 +0800
Subject: [PATCH 01/10] Complete parameter type information
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../communication/reduce_scatter.py              | 41 +++++++++++++++++--
 .../distributed/communication/scatter.py         | 23 +++++++++--
 python/paddle/distributed/communication/send.py  | 15 ++++++-
 .../communication/stream/all_gather.py           | 41 ++++++++++++++-----
 4 files changed, 100 insertions(+), 20 deletions(-)

diff --git a/python/paddle/distributed/communication/reduce_scatter.py b/python/paddle/distributed/communication/reduce_scatter.py
index 8513d79f8c7fa..1b5c33ad41559 100644
--- a/python/paddle/distributed/communication/reduce_scatter.py
+++ b/python/paddle/distributed/communication/reduce_scatter.py
@@ -12,6 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 import paddle
 from paddle.distributed.communication import stream
 from paddle.distributed.communication.reduce import ReduceOp
@@ -19,10 +23,21 @@
     _reduce_scatter_base as _reduce_scatter_base_stream,
 )

+if TYPE_CHECKING:
+    from paddle import Tensor
+
+    _ReduceOpType = int
+    from paddle.base.core import task
+    from paddle.distributed.communication.group import Group
+

 def reduce_scatter(
-    tensor, tensor_list, op=ReduceOp.SUM, group=None, sync_op=True
-):
+    tensor: Tensor,
+    tensor_list: list[Tensor],
+    op: _ReduceOpType = ReduceOp.SUM,
+    group: Group = None,
+    sync_op: bool = True,
+) -> task:
     """
     Reduces, then scatters a list of tensors to all processes in a group
@@ -62,6 +77,16 @@ def reduce_scatter(
             >>> # [8, 10] (2 GPUs, out for rank 1)
     """
+    if op not in [
+        ReduceOp.AVG,
+        ReduceOp.MAX,
+        ReduceOp.MIN,
+        ReduceOp.PROD,
+        ReduceOp.SUM,
+    ]:
+        raise RuntimeError(
+            "Invalid ``op`` value. Expected ``op`` to be one of ``ReduceOp.SUM``, ``ReduceOp.MAX``, ``ReduceOp.MIN``, ``ReduceOp.PROD`` or ``ReduceOp.AVG``."
+        )
     # AVG is only supported when nccl >= 2.10
     if op == ReduceOp.AVG and paddle.base.core.nccl_version() < 21000:
         group = (
@@ -89,8 +114,12 @@


 def _reduce_scatter_base(
-    output, input, op=ReduceOp.SUM, group=None, sync_op=True
-):
+    output: Tensor,
+    input: Tensor,
+    op: _ReduceOpType = ReduceOp.SUM,
+    group: Group = None,
+    sync_op: bool = True,
+) -> task | None:
     """
     Reduces, then scatters a flattened tensor to all processes in a group.
@@ -126,6 +155,10 @@
             >>> # [5, 7] (2 GPUs, out for rank 1)
     """
+    if op not in [ReduceOp.MAX, ReduceOp.MIN, ReduceOp.PROD, ReduceOp.SUM]:
+        raise RuntimeError(
+            "Invalid ``op`` value. Expected ``op`` to be one of ``ReduceOp.SUM``, ``ReduceOp.MAX``, ``ReduceOp.MIN`` or ``ReduceOp.PROD``."
+        )
     return _reduce_scatter_base_stream(
         output,
         input,
diff --git a/python/paddle/distributed/communication/scatter.py b/python/paddle/distributed/communication/scatter.py
index ba62c30aecb83..6858c436dea26 100644
--- a/python/paddle/distributed/communication/scatter.py
+++ b/python/paddle/distributed/communication/scatter.py
@@ -12,6 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from paddle import Tensor
+    from paddle.distributed.communication.group import Group
+
 import numpy as np

 import paddle
@@ -25,7 +33,13 @@
 )


-def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True):
+def scatter(
+    tensor: Tensor,
+    tensor_list: list[Tensor] | tuple[Tensor] = None,
+    src: int = 0,
+    group: Group = None,
+    sync_op: bool = True,
+) -> None:
     """
     Scatter a tensor to all participators. As shown below, one process is started with a GPU and the source of the scatter
@@ -72,8 +86,11 @@ def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True):


 def scatter_object_list(
-    out_object_list, in_object_list=None, src=0, group=None
-):
+    out_object_list: list[Any],
+    in_object_list: list[Any] = None,
+    src: int = 0,
+    group: Group = None,
+) -> None:
     """
     Scatter picklable objects from the source to all others. Similar to scatter(), but Python objects can be passed in.
diff --git a/python/paddle/distributed/communication/send.py b/python/paddle/distributed/communication/send.py
index c1c3e19204a4e..65f77235bc535 100644
--- a/python/paddle/distributed/communication/send.py
+++ b/python/paddle/distributed/communication/send.py
@@ -12,10 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 from paddle.distributed.communication import stream

+if TYPE_CHECKING:
+    from paddle import Tensor
+    from paddle.base.core import task
+    from paddle.distributed.communication.group import Group
+

-def send(tensor, dst=0, group=None, sync_op=True):
+def send(
+    tensor: Tensor, dst: int = 0, group: Group = None, sync_op: bool = True
+) -> task:
     """
     Send a tensor to the receiver.

@@ -51,7 +62,7 @@
     )


-def isend(tensor, dst, group=None):
+def isend(tensor: Tensor, dst: int, group: Group = None) -> task:
     """
     Send tensor asynchronously

diff --git a/python/paddle/distributed/communication/stream/all_gather.py b/python/paddle/distributed/communication/stream/all_gather.py
index f2cbf5522f47e..f14dbf12967be 100644
--- a/python/paddle/distributed/communication/stream/all_gather.py
+++ b/python/paddle/distributed/communication/stream/all_gather.py
@@ -12,16 +12,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 import paddle
 import paddle.distributed as dist
 from paddle import framework
 from paddle.base import data_feeder
 from paddle.distributed.communication.group import _get_global_group

+if TYPE_CHECKING:
+    from paddle import Tensor
+    from paddle.base.core import task
+    from paddle.distributed.communication.group import Group
+

 def _all_gather_into_tensor_in_dygraph(
-    out_tensor, in_tensor, group, sync_op, use_calc_stream
-):
+    out_tensor: Tensor,
+    in_tensor: Tensor,
+    group: Group,
+    sync_op: bool,
+    use_calc_stream: bool,
+) -> task:
     group = _get_global_group() if group is None else group

     if use_calc_stream:
@@ -40,8 +53,12 @@


 def _all_gather_in_dygraph(
-    tensor_list, tensor, group, sync_op, use_calc_stream
-):
+    tensor_list: list[Tensor],
+    tensor: Tensor,
+    group: Group,
+    sync_op: bool,
+    use_calc_stream: bool,
+) -> task:
     group = _get_global_group() if group is None else group

     if len(tensor_list) == 0:
@@ -59,7 +76,9 @@
     return task


-def _all_gather_in_static_mode(tensor_list, tensor, group, sync_op):
+def _all_gather_in_static_mode(
+    tensor_list: list[Tensor], tensor: Tensor, group: Group, sync_op: bool
+) -> task:
     op_type = 'all_gather'
     helper = framework.LayerHelper(op_type, **locals())
     out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
@@ -121,12 +140,12 @@


 def all_gather(
-    tensor_or_tensor_list,
-    tensor,
-    group=None,
-    sync_op=True,
-    use_calc_stream=False,
-):
+    tensor_or_tensor_list: Tensor | list[Tensor],
+    tensor: Tensor,
+    group: Group = None,
+    sync_op: bool = True,
+    use_calc_stream: bool = False,
+) -> task:
     """
     Gather tensors across devices to a correctly-sized tensor or a tensor list.
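
Note on the pattern patch 01 introduces (the same in all four files):
`from __future__ import annotations` turns every annotation into a lazily
evaluated string, so the names used in signatures only have to exist for the
static type checker and can be imported under a `typing.TYPE_CHECKING` guard
that the interpreter never executes. A minimal sketch of the pattern, runnable
even without Paddle installed; the function `send_like` is illustrative and
not part of the patch:

    from __future__ import annotations  # annotations become lazily evaluated strings

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen by static checkers such as mypy, but skipped at runtime, so it
        # adds no import cost and cannot create an import cycle.
        from paddle import Tensor
        from paddle.distributed.communication.group import Group

    def send_like(
        tensor: Tensor, dst: int = 0, group: Group | None = None
    ) -> None:
        # Tensor and Group resolve for the checker even though they are
        # never imported when this module actually runs.
        ...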
From feb3ba9ad1b7be8dae9d599cc19fb8c9fe801f8e Mon Sep 17 00:00:00 2001
From: Lans1ot
Date: Wed, 31 Jul 2024 15:25:53 +0800
Subject: [PATCH 02/10] Revise parameter type information
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../paddle/distributed/communication/reduce_scatter.py | 10 ++++------
 python/paddle/distributed/communication/scatter.py     |  6 +++---
 python/paddle/distributed/communication/send.py        |  7 +++++--
 .../distributed/communication/stream/all_gather.py     |  2 +-
 4 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/python/paddle/distributed/communication/reduce_scatter.py b/python/paddle/distributed/communication/reduce_scatter.py
index 1b5c33ad41559..1785fc8191346 100644
--- a/python/paddle/distributed/communication/reduce_scatter.py
+++ b/python/paddle/distributed/communication/reduce_scatter.py
@@ -25,8 +25,6 @@

 if TYPE_CHECKING:
     from paddle import Tensor
-
-    _ReduceOpType = int
     from paddle.base.core import task
     from paddle.distributed.communication.group import Group

@@ -34,8 +32,8 @@ def reduce_scatter(
     tensor: Tensor,
     tensor_list: list[Tensor],
-    op: _ReduceOpType = ReduceOp.SUM,
-    group: Group = None,
+    op: ReduceOp = ReduceOp.SUM,
+    group: Group | None = None,
     sync_op: bool = True,
 ) -> task:
     """
@@ -116,8 +114,8 @@ def _reduce_scatter_base(
     output: Tensor,
     input: Tensor,
-    op: _ReduceOpType = ReduceOp.SUM,
-    group: Group = None,
+    op: ReduceOp = ReduceOp.SUM,
+    group: Group | None = None,
     sync_op: bool = True,
 ) -> task | None:
     """
diff --git a/python/paddle/distributed/communication/scatter.py b/python/paddle/distributed/communication/scatter.py
index 6858c436dea26..cd33b890b89bc 100644
--- a/python/paddle/distributed/communication/scatter.py
+++ b/python/paddle/distributed/communication/scatter.py
@@ -35,9 +35,9 @@

 def scatter(
     tensor: Tensor,
-    tensor_list: list[Tensor] | tuple[Tensor] = None,
+    tensor_list: list[Tensor] | tuple[Tensor] | None = None,
     src: int = 0,
-    group: Group = None,
+    group: Group | None = None,
     sync_op: bool = True,
 ) -> None:
     """
@@ -87,7 +87,7 @@

 def scatter_object_list(
     out_object_list: list[Any],
-    in_object_list: list[Any] = None,
+    in_object_list: list[Any] | None = None,
     src: int = 0,
     group: Group = None,
 ) -> None:
diff --git a/python/paddle/distributed/communication/send.py b/python/paddle/distributed/communication/send.py
index 65f77235bc535..fe6140fec598a 100644
--- a/python/paddle/distributed/communication/send.py
+++ b/python/paddle/distributed/communication/send.py
@@ -25,7 +25,10 @@


 def send(
-    tensor: Tensor, dst: int = 0, group: Group = None, sync_op: bool = True
+    tensor: Tensor,
+    dst: int = 0,
+    group: Group | None = None,
+    sync_op: bool = True,
 ) -> task:
     """
     Send a tensor to the receiver.
@@ -62,7 +65,7 @@ def send(
     )


-def isend(tensor: Tensor, dst: int, group: Group = None) -> task:
+def isend(tensor: Tensor, dst: int, group: Group | None = None) -> task:
     """
     Send tensor asynchronously

diff --git a/python/paddle/distributed/communication/stream/all_gather.py b/python/paddle/distributed/communication/stream/all_gather.py
index f14dbf12967be..1b7f7ce10bd49 100644
--- a/python/paddle/distributed/communication/stream/all_gather.py
+++ b/python/paddle/distributed/communication/stream/all_gather.py
@@ -142,7 +142,7 @@ def all_gather(
     tensor_or_tensor_list: Tensor | list[Tensor],
     tensor: Tensor,
-    group: Group = None,
+    group: Group | None = None,
     sync_op: bool = True,
     use_calc_stream: bool = False,
 ) -> task:

From f7dac7a075d94ce301ffa53958a7c5e88f0aa37a Mon Sep 17 00:00:00 2001
From: Lans1ot
Date: Wed, 31 Jul 2024 15:32:54 +0800
Subject: [PATCH 03/10] Add missing type information
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 python/paddle/distributed/communication/scatter.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/distributed/communication/scatter.py b/python/paddle/distributed/communication/scatter.py
index cd33b890b89bc..01964bffa9e13 100644
--- a/python/paddle/distributed/communication/scatter.py
+++ b/python/paddle/distributed/communication/scatter.py
@@ -89,7 +89,7 @@ def scatter_object_list(
     out_object_list: list[Any],
     in_object_list: list[Any] | None = None,
     src: int = 0,
-    group: Group = None,
+    group: Group | None = None,
 ) -> None:
     """

From 0e4d12c36bb2df6b25f370bffe38560049827d8f Mon Sep 17 00:00:00 2001
From: Lans1ot <47025645+Lans1ot@users.noreply.github.com>
Date: Thu, 1 Aug 2024 11:10:52 +0800
Subject: [PATCH 04/10] Update python/paddle/distributed/communication/scatter.py

Co-authored-by: megemini
---
 python/paddle/distributed/communication/scatter.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/paddle/distributed/communication/scatter.py b/python/paddle/distributed/communication/scatter.py
index 01964bffa9e13..c8fbb911ac668 100644
--- a/python/paddle/distributed/communication/scatter.py
+++ b/python/paddle/distributed/communication/scatter.py
@@ -86,8 +86,8 @@


 def scatter_object_list(
-    out_object_list: list[Any],
-    in_object_list: list[Any] | None = None,
+    out_object_list: list[object],
+    in_object_list: list[object] | None = None,
     src: int = 0,
     group: Group | None = None,
 ) -> None:

From 8c7e446dfaeab380da53eeb96531683353641616 Mon Sep 17 00:00:00 2001
From: Lans1ot
Date: Thu, 1 Aug 2024 11:12:42 +0800
Subject: [PATCH 05/10] Apply requested changes
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 python/paddle/distributed/communication/reduce_scatter.py | 5 +++--
 python/paddle/distributed/communication/scatter.py        | 8 ++++----
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/python/paddle/distributed/communication/reduce_scatter.py b/python/paddle/distributed/communication/reduce_scatter.py
index 1785fc8191346..472dc6179382b 100644
--- a/python/paddle/distributed/communication/reduce_scatter.py
+++ b/python/paddle/distributed/communication/reduce_scatter.py
@@ -27,12 +27,13 @@
     from paddle import Tensor
     from paddle.base.core import task
     from paddle.distributed.communication.group import Group
+    from paddle.distributed.communication.reduce import _ReduceOp


 def reduce_scatter(
     tensor: Tensor,
     tensor_list: list[Tensor],
-    op: ReduceOp = ReduceOp.SUM,
+    op: _ReduceOp = ReduceOp.SUM,
     group: Group | None = None,
     sync_op: bool = True,
 ) -> task:
@@ -114,7 +115,7 @@ def _reduce_scatter_base(
     output: Tensor,
     input: Tensor,
-    op: ReduceOp = ReduceOp.SUM,
+    op: _ReduceOp = ReduceOp.SUM,
     group: Group | None = None,
     sync_op: bool = True,
 ) -> task | None:
diff --git a/python/paddle/distributed/communication/scatter.py b/python/paddle/distributed/communication/scatter.py
index 01964bffa9e13..48d1413b8bfaf 100644
--- a/python/paddle/distributed/communication/scatter.py
+++ b/python/paddle/distributed/communication/scatter.py
@@ -14,7 +14,7 @@

 from __future__ import annotations

-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Sequence

 if TYPE_CHECKING:
     from paddle import Tensor
@@ -35,7 +35,7 @@

 def scatter(
     tensor: Tensor,
-    tensor_list: list[Tensor] | tuple[Tensor] | None = None,
+    tensor_list: Sequence[Tensor] | None = None,
     src: int = 0,
     group: Group | None = None,
     sync_op: bool = True,
@@ -86,8 +86,8 @@

 def scatter_object_list(
-    out_object_list: list[Any],
-    in_object_list: list[Any] | None = None,
+    out_object_list: list[object],
+    in_object_list: list[object] | None = None,
     src: int = 0,
     group: Group | None = None,
 ) -> None:

From 4c0c198f880d9c79b9043052888179bdec4a906e Mon Sep 17 00:00:00 2001
From: Lans1ot <47025645+Lans1ot@users.noreply.github.com>
Date: Fri, 2 Aug 2024 16:57:17 +0800
Subject: [PATCH 06/10] Update python/paddle/distributed/communication/scatter.py

Co-authored-by: megemini
---
 python/paddle/distributed/communication/scatter.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/paddle/distributed/communication/scatter.py b/python/paddle/distributed/communication/scatter.py
index 48d1413b8bfaf..7d74920727c74 100644
--- a/python/paddle/distributed/communication/scatter.py
+++ b/python/paddle/distributed/communication/scatter.py
@@ -86,8 +86,8 @@


 def scatter_object_list(
-    out_object_list: list[object],
-    in_object_list: list[object] | None = None,
+    out_object_list: list[Any],
+    in_object_list: list[Any] | None = None,
     src: int = 0,
     group: Group | None = None,
 ) -> None:

From f488c7d72cb1699ea3015fae56bacfb7d1173932 Mon Sep 17 00:00:00 2001
From: Lans1ot
Date: Mon, 5 Aug 2024 10:03:10 +0800
Subject: [PATCH 07/10] Fix import error
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 python/paddle/distributed/communication/scatter.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/distributed/communication/scatter.py b/python/paddle/distributed/communication/scatter.py
index 7d74920727c74..134fe615395cf 100644
--- a/python/paddle/distributed/communication/scatter.py
+++ b/python/paddle/distributed/communication/scatter.py
@@ -14,7 +14,7 @@

 from __future__ import annotations

-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence

 if TYPE_CHECKING:
     from paddle import Tensor

From 3d4abf37d27812d5d2afab56f77135d6b2372078 Mon Sep 17 00:00:00 2001
From: Lans1ot
Date: Mon, 5 Aug 2024 14:18:55 +0800
Subject: [PATCH 08/10] Revise according to review suggestions
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 python/paddle/distributed/communication/scatter.py           | 3 ++-
 python/paddle/distributed/communication/send.py              | 4 ++--
 python/paddle/distributed/communication/stream/all_gather.py | 2 +-
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/python/paddle/distributed/communication/scatter.py b/python/paddle/distributed/communication/scatter.py
index 134fe615395cf..c5fd58f61d197 100644
--- a/python/paddle/distributed/communication/scatter.py
+++ b/python/paddle/distributed/communication/scatter.py
@@ -18,6 +18,7 @@

 if TYPE_CHECKING:
     from paddle import Tensor
+    from paddle.base.core import task
     from paddle.distributed.communication.group import Group

 import numpy as np
@@ -39,7 +40,7 @@ def scatter(
     src: int = 0,
     group: Group | None = None,
     sync_op: bool = True,
-) -> None:
+) -> task | None:
     """
     Scatter a tensor to all participators. As shown below, one process is started with a GPU and the source of the scatter
diff --git a/python/paddle/distributed/communication/send.py b/python/paddle/distributed/communication/send.py
index fe6140fec598a..da238dfd20ebf 100644
--- a/python/paddle/distributed/communication/send.py
+++ b/python/paddle/distributed/communication/send.py
@@ -29,7 +29,7 @@ def send(
     dst: int = 0,
     group: Group | None = None,
     sync_op: bool = True,
-) -> task:
+) -> task | None:
     """
     Send a tensor to the receiver.

@@ -65,7 +65,7 @@
     )


-def isend(tensor: Tensor, dst: int, group: Group | None = None) -> task:
+def isend(tensor: Tensor, dst: int, group: Group | None = None) -> task | None:
     """
     Send tensor asynchronously

diff --git a/python/paddle/distributed/communication/stream/all_gather.py b/python/paddle/distributed/communication/stream/all_gather.py
index 1b7f7ce10bd49..fa50c410defb7 100644
--- a/python/paddle/distributed/communication/stream/all_gather.py
+++ b/python/paddle/distributed/communication/stream/all_gather.py
@@ -78,7 +78,7 @@

 def _all_gather_in_static_mode(
     tensor_list: list[Tensor], tensor: Tensor, group: Group, sync_op: bool
-) -> task:
+) -> None:
     op_type = 'all_gather'
     helper = framework.LayerHelper(op_type, **locals())
     out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

From 9f1f05f59ab1e4b318d1ddcf44db6628d61e162c Mon Sep 17 00:00:00 2001
From: Lans1ot
Date: Mon, 5 Aug 2024 16:19:25 +0800
Subject: [PATCH 09/10] Add type information
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 python/paddle/distributed/communication/stream/all_gather.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/distributed/communication/stream/all_gather.py b/python/paddle/distributed/communication/stream/all_gather.py
index fa50c410defb7..a6e66e4471a58 100644
--- a/python/paddle/distributed/communication/stream/all_gather.py
+++ b/python/paddle/distributed/communication/stream/all_gather.py
@@ -145,7 +145,7 @@ def all_gather(
     group: Group | None = None,
     sync_op: bool = True,
     use_calc_stream: bool = False,
-) -> task:
+) -> task | None:
     """
     Gather tensors across devices to a correctly-sized tensor or a tensor list.
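
Note on patches 08 and 09: several return annotations are widened to
`task | None`, because with `sync_op=True` (and in static-graph mode) these
APIs have no task to hand back. Under that annotation mypy flags an unguarded
`.wait()` with `[union-attr]`. A hedged sketch of the two caller-side options
(tensor values and ranks are illustrative, mirroring the doctests above):

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    data = paddle.to_tensor([7, 8, 9])

    task = dist.isend(data, dst=1)  # annotated as returning task | None
    if task is not None:  # option 1: narrow the optional with a runtime guard
        task.wait()

    task = dist.isend(data, dst=1)
    task.wait()  # type: ignore[union-attr]  # option 2: targeted suppression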
From 383202140672d4af0e80463629693ddf2b5bfbfe Mon Sep 17 00:00:00 2001
From: SigureMo
Date: Tue, 6 Aug 2024 02:17:59 +0800
Subject: [PATCH 10/10] ignore [union-attr]

---
 python/paddle/distributed/communication/send.py              | 2 +-
 python/paddle/distributed/communication/stream/all_gather.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/paddle/distributed/communication/send.py b/python/paddle/distributed/communication/send.py
index da238dfd20ebf..6dded800afa97 100644
--- a/python/paddle/distributed/communication/send.py
+++ b/python/paddle/distributed/communication/send.py
@@ -95,7 +95,7 @@ def isend(tensor: Tensor, dst: int, group: Group | None = None) -> task | None:
             >>> else:
             ...     data = paddle.to_tensor([1, 2, 3])
             ...     task = dist.irecv(data, src=0)
-            >>> task.wait()
+            >>> task.wait()  # type: ignore[union-attr]
             >>> print(data)
             >>> # [7, 8, 9] (2 GPUs)

diff --git a/python/paddle/distributed/communication/stream/all_gather.py b/python/paddle/distributed/communication/stream/all_gather.py
index a6e66e4471a58..a64f0b00b5f83 100644
--- a/python/paddle/distributed/communication/stream/all_gather.py
+++ b/python/paddle/distributed/communication/stream/all_gather.py
@@ -181,7 +181,7 @@ def all_gather(
             >>> else:
             ...     data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
             >>> task = dist.stream.all_gather(tensor_list, data, sync_op=False)
-            >>> task.wait()
+            >>> task.wait()  # type: ignore[union-attr]
             >>> print(tensor_list)
             [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
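
Patch 10 places the ignore comments inside the doctests themselves, which
makes sense if the example code in docstrings is itself type-checked. A
self-contained reproduction of the diagnostic being suppressed, using stand-in
types (nothing below is Paddle API):

    from typing import Optional

    class FakeTask:
        # Stand-in for paddle.base.core.task; only the method we call here.
        def wait(self) -> None:
            pass

    def fake_irecv() -> Optional[FakeTask]:
        # Mirrors the ``task | None`` returns introduced in patches 08 and 09.
        return FakeTask()

    task = fake_irecv()
    # Without the trailing comment, mypy reports:
    #   error: Item "None" of "Optional[FakeTask]" has no attribute "wait"  [union-attr]
    task.wait()  # type: ignore[union-attr]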