Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
haohongxiang committed Sep 28, 2022
1 parent e94606e commit 17a9161
Showing 1 changed file with 24 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,14 @@ def _partial_send_op(tensor, group, use_calc_stream, ring_id, dst, nranks,
elif in_dygraph_mode():
group = paddle.distributed.collective._get_default_group(
) if group is None else group
comm_op = group.process_group.send_partial_on_calc_stream \
if use_calc_stream else group.process_group.send_partial
return comm_op(tensor, dst_rank_in_group, nranks, rank_id)
if use_calc_stream:
task = group.process_group.send_partial_on_calc_stream(
tensor, dst, nranks, rank_id)
task.wait()
return None
else:
return group.process_group.send_partial(tensor, dst, nranks,
rank_id)


def send_partial(tensor,
Expand Down Expand Up @@ -214,9 +219,14 @@ def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks,
elif in_dygraph_mode():
group = paddle.distributed.collective._get_default_group(
) if group is None else group
comm_op = group.process_group.recv_partial_on_calc_stream \
if use_calc_stream else group.process_group.recv_partial
return comm_op(tensor, src_rank_in_group, nranks, rank_id)
if use_calc_stream:
task = group.process_group.recv_partial_on_calc_stream(
tensor, src, nranks, rank_id)
task.wait()
return None
else:
return group.process_group.recv_partial(tensor, src, nranks,
rank_id)


def recv_partial(tensor,
Expand Down Expand Up @@ -254,9 +264,14 @@ def _partial_allgather_op(tensor, group, use_calc_stream, ring_id, nranks,
elif in_dygraph_mode():
group = paddle.distributed.collective._get_default_group(
) if group is None else group
comm_op = group.process_group.all_gather_partial_on_calc_stream \
if use_calc_stream else group.process_group.all_gather_partial
return comm_op(tensor, tensor, nranks, rank_id)
if use_calc_stream:
task = group.process_group.all_gather_partial_on_calc_stream(
tensor, tensor, nranks, rank_id)
task.wait()
return None
else:
return group.process_group.all_gather_partial(
tensor, tensor, nranks, rank_id)


def allgather_partial(tensor,
Expand Down

0 comments on commit 17a9161

Please sign in to comment.