diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 6286aa6cf0c4..19b14f69b003 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -524,8 +524,8 @@ def _aggregate_total_loss(self): assert self.global_rank in self.grid.pp_group losses = torch.Tensor([self.dp_group_loss, agg_loss]).to(self.device) - dist.broadcast(tensor=losses, src=self.global_rank, group=self.mpu.get_pipe_parallel_group()) - + if self.is_pipe_parallel: + dist.broadcast(tensor=losses, src=self.global_rank, group=self.mpu.get_pipe_parallel_group()) else: # Get loss from last stage src_rank = self.grid.stage_to_global(self.num_stages - 1)