
Commit

add scale
sneaxiy committed Oct 28, 2022
1 parent c5bb845 commit fac7af3
Showing 1 changed file with 29 additions and 11 deletions.
40 changes: 29 additions & 11 deletions python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -26,7 +26,7 @@
__all__ = []


def _apply_collective_grads(parameters, comm_group, bucket_size):
def _apply_collective_grads(parameters, comm_group, bucket_size, scale=None):
grad_var_set = set()
grad_vars = []
sparse_grad_vars = []
@@ -48,21 +48,33 @@ def _apply_collective_grads(parameters, comm_group, bucket_size):
if comm_group is None
else comm_group.nranks
)

if scale is None:
scale = nranks
else:
scale = 1.0 / scale

if scale == 1.0:
scale = None

for coalesced_grad, _, _ in coalesced_grads_and_vars:
# need to div nranks
div_factor = paddle.to_tensor(nranks, dtype=coalesced_grad.dtype)
paddle.fluid.framework._dygraph_tracer().trace_op(
type="elementwise_div",
inputs={'X': coalesced_grad, 'Y': div_factor},
outputs={'Out': coalesced_grad},
attrs={'axis': -1},
)
if scale is not None:
div_factor = paddle.to_tensor(scale, dtype=coalesced_grad.dtype)
paddle.fluid.framework._dygraph_tracer().trace_op(
type="elementwise_div",
inputs={'X': coalesced_grad, 'Y': div_factor},
outputs={'Out': coalesced_grad},
attrs={'axis': -1},
)
paddle.distributed.all_reduce(coalesced_grad, group=comm_group)

_split_tensors(coalesced_grads_and_vars)


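Note on the hunk above: in the non-eager path, the unconditional division by nranks is replaced by an optional scale argument. When scale is None the coalesced gradients are still divided by nranks (plain averaging); when a scale is passed they are divided by 1.0 / scale, i.e. multiplied by scale; and a resulting factor of exactly 1.0 skips the division op entirely. A minimal sketch of that divisor selection (the helper name _choose_div_factor is hypothetical, not part of the patch):

# Hypothetical helper, not part of the patch: mirrors how the code above
# picks the divisor applied to each coalesced gradient via elementwise_div.
def _choose_div_factor(nranks, scale=None):
    if scale is None:
        div_factor = float(nranks)   # default: average over the ranks
    else:
        div_factor = 1.0 / scale     # dividing by 1/scale == multiplying by scale
    return None if div_factor == 1.0 else div_factor  # None => skip the div op


assert _choose_div_factor(8) == 8.0               # scale=None: divide by nranks
assert _choose_div_factor(8, scale=0.125) == 8.0  # scale=1/nranks gives the same result
assert _choose_div_factor(8, scale=1.0) is None   # scale=1.0: no division at all
assert _choose_div_factor(1) is None              # single rank: nothing to do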
def _apply_collective_grads_eager(parameters, comm_group, bucket_size):
def _apply_collective_grads_eager(
parameters, comm_group, bucket_size, scale=None
):
grad_var_set = set()
grad_vars = []

@@ -83,9 +95,15 @@ def _apply_collective_grads_eager(parameters, comm_group, bucket_size):
if comm_group is None
else comm_group.nranks
)
if scale is None:
scale = 1.0 / nranks
if scale == 1.0:
scale = None

for coalesced_grad, _, _ in coalesced_grads_and_vars:
# need to div nranks
coalesced_grad.scale_(1.0 / nranks)
if scale is not None:
coalesced_grad.scale_(scale)
paddle.distributed.all_reduce(coalesced_grad, group=comm_group)

_split_tensors(coalesced_grads_and_vars)
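The eager path above expresses the same semantics as a multiplication: scale defaults to 1.0 / nranks, a factor of exactly 1.0 is turned into None, and otherwise each coalesced gradient is scaled in place with Tensor.scale_ before the all-reduce. A rough sketch of that in-place step on a plain dense tensor (the toy values and standalone setup are assumptions, not part of the patch):

import paddle

# Toy stand-in for one coalesced dense gradient.
grad = paddle.to_tensor([2.0, 4.0, 6.0])

nranks, scale = 4, None
factor = (1.0 / nranks) if scale is None else scale
if factor != 1.0:
    grad.scale_(factor)  # in-place multiply, as coalesced_grad.scale_(scale) above

print(grad.numpy())      # [0.5 1.  1.5]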
@@ -173,7 +191,7 @@ def broadcast_dp_parameters(model, hcg):


def fused_allreduce_gradients_with_group(
parameter_list, group, bucket_size=128 * 1024 * 1024
parameter_list, group, bucket_size=128 * 1024 * 1024, scale=None
):
apply_func = (
_apply_collective_grads_eager

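Putting it together, callers of fused_allreduce_gradients_with_group can now choose the post-all-reduce normalization through the forwarded scale argument. A usage sketch under assumed conditions: it needs a multi-process launch (e.g. paddle.distributed.launch), and the model/group setup here is illustrative rather than taken from the patch.

import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.utils.hybrid_parallel_util import (
    fused_allreduce_gradients_with_group,
)

dist.init_parallel_env()
group = dist.new_group(list(range(dist.get_world_size())))

model = paddle.nn.Linear(16, 16)
loss = model(paddle.randn([4, 16])).mean()
loss.backward()

params = list(model.parameters())

# Default (scale=None): gradients are summed across the group and averaged
# over group.nranks, matching the old behaviour.
fused_allreduce_gradients_with_group(params, group)

# With scale=1.0 the normalization is skipped and the raw sum is kept;
# any other value multiplies the summed gradients by that factor.
# fused_allreduce_gradients_with_group(params, group, scale=1.0)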