Skip to content

Commit

Permalink
rm _init_npu_pipeline_comm (#53150)
Browse files Browse the repository at this point in the history
  • Loading branch information
Liyulingyue authored Apr 21, 2023
1 parent 43b950f commit 07878a3
Showing 1 changed file with 0 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -737,89 +737,6 @@ def _init_pair_comm(self, pair, ring_id):
sync=False,
)

def _init_npu_pipeline_comm(self, startup_block):
    """Initialise this stage's pairwise pipeline communicators for NPU.

    The stage exchanges data with its two ring neighbours,
    ``(pp_rank + 1) % pp_degree`` and ``(pp_rank - 1) % pp_degree``.
    The directed pairs are initialised through ``_init_pair_comm`` in a
    fixed order that depends on the parity of ``pp_rank``:

    1. even -> odd, e.g. 0->1, 2->3
    2. odd -> even on the same link, e.g. 1->0, 3->2
    3. odd -> even on the other link, e.g. 1->2, 3->0 (``pp_degree > 2``)
    4. even -> odd on the other link, e.g. 2->1, 0->3 (``pp_degree > 2``)

    NOTE(review): presumably the parity ordering exists so that both
    endpoints of a link set it up in the same step -- confirm.

    Args:
        startup_block: Not used by this method.

    Raises:
        AssertionError: If ``pp_degree`` is odd, or if a pair from
            ``pipeline_pair`` that involves this rank was not handled.
        KeyError: If a pair in ``pipeline_pair``, or the pair of step 1
            or 2, has no entry in ``pp_ring_map``.
        ValueError: If a pair that is expected to be pending is not in
            the list collected from ``pipeline_pair``.
    """
    degree = self.pp_degree
    assert (degree % 2) == 0

    # Collect every configured pair this rank takes part in, and track the
    # largest ring id seen so wrap-around pairs can get fresh ids later.
    highest_ring = -1
    pending = []
    for candidate in self.pipeline_pair:
        ring = self.pp_ring_map[candidate[0] * 1000 + candidate[1]]
        if ring > highest_ring:
            highest_ring = ring
        logger.info(f"pp pair:{candidate}, ring_id: {ring}")

        if self.pp_rank in candidate:
            pending.append(candidate)

    rank = self.pp_rank
    next_rank = (rank + 1) % degree
    prev_rank = (rank - 1 + degree) % degree

    # for example: rank=2, degree=4
    to_next = (rank, next_rank)  # 2->3
    from_next = (next_rank, rank)  # 3->2
    from_prev = (prev_rank, rank)  # 1->2
    to_prev = (rank, prev_rank)  # 2->1

    # Even ranks handle the link to their next neighbour first, odd ranks
    # the link to their previous neighbour; the roles swap for steps 3/4.
    if (rank % 2) == 0:
        step1, step2, step3, step4 = to_next, from_next, from_prev, to_prev
    else:
        step1, step2, step3, step4 = from_prev, to_prev, to_next, from_next

    # 1. even send to next, odd recv from prev, 0->1, 2->3
    ring = self.pp_ring_map[step1[0] * 1000 + step1[1]]
    self._init_pair_comm(step1, ring)
    pending.remove(step1)
    logger.info(f"pair0(even->odd): pp pair:{step1}, ring_id: {ring}")

    # 2. even recv from next, odd send to prev, 1->0, 3->2
    ring = self.pp_ring_map[step2[0] * 1000 + step2[1]]
    self._init_pair_comm(step2, ring)
    pending.remove(step2)
    logger.info(f"pair1(even<-odd): pp pair:{step2}, ring_id: {ring}")

    # if pp_degree is 2, only need pair(0->1, 1->0)
    if degree > 2:
        # The first and last stages use the wrap-around link here, which is
        # not removed from the pending list (3->0 / 0->3 are not in
        # pp_ring_map and fall back to ids past the largest known one).
        is_inner_stage = rank not in (0, degree - 1)

        # 3. odd send to next, even recv from prev, 1->2, 3->0
        ring = self.pp_ring_map.get(
            step3[0] * 1000 + step3[1], highest_ring + 1
        )
        self._init_pair_comm(step3, ring)
        if is_inner_stage:
            pending.remove(step3)
        logger.info(
            "pair2(odd->even): pp pair:{}, ring_id: {}".format(step3, ring)
        )

        # 4. odd recv from next, even send to prev, 2->1, 0->3
        ring = self.pp_ring_map.get(
            step4[0] * 1000 + step4[1], highest_ring + 2
        )
        self._init_pair_comm(step4, ring)
        if is_inner_stage:
            pending.remove(step4)
        logger.info(
            "pair3(odd<-even): pp pair:{}, ring_id: {}".format(step4, ring)
        )

    # Anything left would be a pair between non-adjacent stages.
    assert not pending, (
        "Current pipeline does not support cross stage communication, "
        "please check unexpected pair {}".format(pending)
    )

def _init_pipeline_comm(self, startup_block):
# TODO (JZ-LIANG) to unify pp_rank_ and pp_rank
if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None) is None:
Expand All @@ -834,7 +751,6 @@ def _init_pipeline_comm(self, startup_block):
)

if core.is_compiled_with_custom_device('npu'):
self._init_npu_pipeline_comm(startup_block)
return

# GPU
Expand Down

0 comments on commit 07878a3

Please sign in to comment.