Skip to content

Commit

Permalink
Fix some problems in mp_async_all_reduce and skip_c_identity.
Browse files Browse the repository at this point in the history
  • Loading branch information
GhostScreaming committed Aug 14, 2023
1 parent cba7b23 commit 3114673
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/paddle/distributed/fleet/layers/mpu/mp_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def backward(ctx, dy):
dy.reshape([-1, dy.shape[-1]]),
transpose_x=True,
)
if bias is not None:
if bias is None:
task.wait()
return dx, dw
else:
Expand Down
6 changes: 4 additions & 2 deletions python/paddle/distributed/fleet/layers/mpu/mp_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _get_mp_env_flag(flag):
"Flags_mp_aysnc_allreduce",
"Flags_fused_linear_param_grad_add",
"Flags_skip_mp_c_identity",
], "Only support set Flags_mp_aysnc_allreduce (support all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear), Flags_fused_linear_param_grad_add (support fused_linear_param_grad_add in ColumnParallelLinear) and Flags_skip_mp_c_identity (support skip c_identity in ColumnParallelLinear and RowParallelLinear)"
], "Only support set Flags_mp_aysnc_allreduce (support all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear), Flags_fused_linear_param_grad_add (support fused_linear_param_grad_add in ColumnParallelLinear) and Flags_skip_mp_c_identity (support skip c_identity in ColumnParallelLinear with Flags_mp_aysnc_allreduce=True, and skip c_identity in RowParallelLinear)"
return str(os.getenv(flag)).lower() in ["true", "1"]


Expand All @@ -60,7 +60,9 @@ def _c_identity(tensor, group=None):
class c_identity_eager(PyLayer):
@staticmethod
def forward(ctx, tensor):
if _get_mp_env_flag("Flags_skip_mp_c_identity"):
if _get_mp_env_flag(
"Flags_skip_mp_c_identity"
) and _get_mp_env_flag("Flags_mp_aysnc_allreduce"):
return tensor
else:
return _legacy_C_ops.c_identity(
Expand Down

0 comments on commit 3114673

Please sign in to comment.