
Commit 2c08b3d

Fix some problems.

This commit makes Flags_skip_mp_c_identity take effect only when Flags_mp_aysnc_allreduce is also enabled, and updates the flag documentation to note that Flags_fused_linear_param_grad_add and Flags_skip_mp_c_identity only work when Flags_mp_aysnc_allreduce is True.
GhostScreaming committed Aug 14, 2023
Parent: 3114673
Showing 2 changed files with 11 additions and 6 deletions.
python/paddle/distributed/fleet/layers/mpu/mp_layers.py (4 additions, 1 deletion)
@@ -327,7 +327,10 @@ class InnerOverlapLinear(paddle.autograd.PyLayer):
     @staticmethod
     def forward(ctx, x, weight, bias):
         ctx.save_for_backward(x, weight, bias)
-        if not _get_mp_env_flag("Flags_skip_mp_c_identity"):
+        if (
+            _get_mp_env_flag("Flags_mp_aysnc_allreduce")
+            and _get_mp_env_flag("Flags_skip_mp_c_identity")
+        ) is False:
             x = paddle._legacy_C_ops.c_identity(
                 x,
                 'use_calc_stream',
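
The new condition means c_identity is inserted unless both flags are enabled; previously, Flags_skip_mp_c_identity alone suppressed it even with async all_reduce overlap disabled. A minimal sketch of the combined predicate (the helper name is hypothetical, not part of the patch):

    # Hypothetical helper expressing the new gating rule: c_identity may be
    # skipped only when async all_reduce overlap is also enabled.
    def _should_skip_c_identity():
        return _get_mp_env_flag("Flags_mp_aysnc_allreduce") and _get_mp_env_flag(
            "Flags_skip_mp_c_identity"
        )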
python/paddle/distributed/fleet/layers/mpu/mp_ops.py (7 additions, 5 deletions)
@@ -28,8 +28,8 @@
 def _get_mp_env_flag(flag):
     # Model parallel environment flag.
     # Flags_mp_aysnc_allreduce: support all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear
-    # Flags_fused_linear_param_grad_add: support fused_linear_param_grad_add in ColumnParallelLinear
-    # Flags_skip_mp_c_identity: support skip c_identity in ColumnParallelLinear and RowParallelLinear
+    # Flags_fused_linear_param_grad_add: support fused_linear_param_grad_add in ColumnParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.
+    # Flags_skip_mp_c_identity: support skip c_identity in ColumnParallelLinear and RowParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.
     assert flag in [
         "Flags_mp_aysnc_allreduce",
         "Flags_fused_linear_param_grad_add",
@@ -61,8 +61,8 @@ class c_identity_eager(PyLayer):
     @staticmethod
     def forward(ctx, tensor):
         if _get_mp_env_flag(
-            "Flags_skip_mp_c_identity"
-        ) and _get_mp_env_flag("Flags_mp_aysnc_allreduce"):
+            "Flags_mp_aysnc_allreduce"
+        ) and _get_mp_env_flag("Flags_skip_mp_c_identity"):
             return tensor
         else:
             return _legacy_C_ops.c_identity(
@@ -276,7 +276,9 @@ def forward(

     @staticmethod
     def backward(ctx, dy):
-        if _get_mp_env_flag("Flags_skip_mp_c_identity"):
+        if _get_mp_env_flag(
+            "Flags_mp_aysnc_allreduce"
+        ) and _get_mp_env_flag("Flags_skip_mp_c_identity"):
             return dy
         else:
             return _legacy_C_ops.c_identity(
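
To exercise the fixed path, both flags must be enabled. A usage sketch, assuming the flags are picked up from the process environment (flag names are copied from the diff; everything else is illustrative):

    import os

    # Illustrative only: enable async all_reduce overlap first; the
    # c_identity skip (after this fix) only takes effect alongside it.
    os.environ["Flags_mp_aysnc_allreduce"] = "True"
    os.environ["Flags_skip_mp_c_identity"] = "True"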
