Cherry pick for sharding (#47061)
* [dygraph sharding] Overlap the reduce and the calculation for sharding stage 2. (#46495)

* [dygraph sharding stage 2] sharding broadcast overlap (#46656)

* Multi groups for broadcast of sharding stage 2 (#46894)
FeixLiu authored Oct 18, 2022
1 parent b84edd9 commit 5b64214
Showing 4 changed files with 408 additions and 17 deletions.
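
The three cherry-picked changes make the communication overlap opt-in: gradient reduces can overlap the backward pass, and the post-optimizer parameter broadcasts can overlap the next batch's forward pass, optionally spread across several broadcast groups. Based on the comments introduced in this diff, a user would flip the new switches on the objects returned by `group_sharded_parallel` roughly as in the minimal sketch below; the import paths, the `level="os_g"` stage 2 setting, and the toy model/optimizer are assumptions drawn from Paddle's public group-sharded API, not part of this commit.

import paddle
from paddle.distributed import init_parallel_env
from paddle.distributed.sharding import group_sharded_parallel

init_parallel_env()

model = paddle.nn.Linear(1024, 1024)
optimizer = paddle.optimizer.AdamW(learning_rate=1e-3, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

# Wrap for group sharded stage 2 ("os_g": shard optimizer states and gradients).
model, optimizer, scaler = group_sharded_parallel(model, optimizer, level="os_g", scaler=scaler)

# Overlap the gradient reduce with the backward calculation (#46495).
model._set_reduce_overlap(True)

# Overlap the post-optimizer broadcasts with the next batch's forward,
# optionally spread over two broadcast groups (#46656, #46894).
optimizer._set_broadcast_overlap(True, layers=model, num_groups=2)

# With broadcast overlap enabled, synchronize manually before saving or inference.
paddle.device.cuda.synchronize()
paddle.save(model.state_dict(), "model.pdparams")
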
@@ -24,6 +24,8 @@

import copy
import logging
import warnings

import numpy as np
from collections import OrderedDict

@@ -86,6 +88,11 @@ def __init__(self,
# Default information
self._optim = optim

# sharding stage 2 comm overlap flag
self._reduce_overlap = False
# record the last task used for comm overlap for sharding stage 2
self._comm_task = None

assert hasattr(self._optim, "_master_weights"
), "Must use optimizer with _master_weights attribute"

@@ -103,6 +110,17 @@ def __init__(self,
filter(lambda x: x.trainable and x.dtype == Type.fp16.value,
self._local_params))) > 0

self._broadcast_overlap = False
self._forward_pre_hook_remove_helper = []
try:
# The fp32 params such as layer_norm_0.w_0 will be at the end of param_list.
# Have to sort the params to make sure all params are in forward-use order.
self._broadcast_order_params = sorted(
self.local_params,
key=lambda x: int(x.name.split('.')[0].split('_')[-1]))
except ValueError:
self._broadcast_order_params = None

self._group = new_group(
_get_global_group().ranks) if group is None else group

@@ -157,6 +175,60 @@ def _sync_params_and_buffers(self):
group=self._group,
sync_op=True)

def _update_task(self, task):
if self._reduce_overlap:
assert task is not None
# Only track the last reduce task.
# Since all tasks run on the same stream, it is enough to wait for the last one.
# Once the last reduce task has finished, all earlier reduce tasks have finished as well.
self._comm_task = task

def _set_reduce_overlap(self, reduce_overlap):
# Enable overlapping the gradient reduces with the backward calculation.
self._reduce_overlap = reduce_overlap

def _set_broadcast_overlap(self,
broadcast_overlap,
layers=None,
num_groups=None):
# Enable overlapping the post-optimizer broadcasts with the next batch's forward calculation.
self._broadcast_overlap = broadcast_overlap
if self._broadcast_overlap:
assert layers is not None, \
"To enable broadcast overlap forward, please pass the module to the function."
self._layers = layers
warnings.warn(
"Setting overlap broadcast means the `paddle.device.cuda.synchronize()` "
"must be called manually before calling `paddle.save()` and before and inference."
)
if self._broadcast_order_params is None:
# Param names should follow the column_linear_32.w_0 pattern to get the best performance.
warnings.warn(
"The param name passed to the optimizer doesn't follow .+_[0-9]+\..+ patter, "
"overlap broadcast may harm the performance.")
self._broadcast_order_params = self._local_params

if num_groups is None or num_groups > len(self._broadcast_order_params):
warnings.warn(
"The num_groups for broadcast is larger than the number of params to be broadcast. "
"It will set to default value: 1 (use the default sharding group)."
)
num_groups = 1

assert isinstance(
num_groups,
int) and num_groups > 0, "num_groups should be a positive integer"

self._number_of_broadcast_groups = num_groups
self._broadcast_groups = [
None for _ in range(self._number_of_broadcast_groups)
]
self._broadcast_groups[0] = self._group

ranks = self._group.ranks
for i in range(1, self._number_of_broadcast_groups):
self._broadcast_groups[i] = new_group(ranks)

def _generate_master_params(self, trainable_params):
if self.offload:
for param in trainable_params:
@@ -364,6 +436,13 @@ def step(self):
"""
A wrapper for Optimizer's step function to finish the update operation of the optimizer.
"""
# This method won't be called directly by opt.step()!
# The _redefine_opt_step() in class GroupShardedStage2 will wrap this function.
if self._broadcast_overlap:
# Clear the registered forward pre-hooks in the optimizer step.
for hook_remove in self._forward_pre_hook_remove_helper:
hook_remove.remove()
self._forward_pre_hook_remove_helper = []

if self.offload:
params_list = [self.offload_params.buffer]
@@ -408,9 +487,52 @@ def _broadcast_params(self):
"""Broadcast the parameters of the current rank to each rank"""

# Exchange all the shards with the other ranks
for dtype_per_rank in self.param_storages.values():
for dst_rank, internal_storage in dtype_per_rank.items():
broadcast(tensor=internal_storage.buffer,
src=self._group.ranks[dst_rank],
group=self._group,
sync_op=True)
if self._broadcast_overlap:
self._broadcast_params_overlap_forward()
else:
for dtype_per_rank in self.param_storages.values():
for dst_rank, internal_storage in dtype_per_rank.items():
broadcast(tensor=internal_storage.buffer,
src=self._group.ranks[dst_rank],
group=self._group,
sync_op=True)

def _forward_pre_hook_function(self, tasks):
# Since the layers call the pre-hook as `forward_pre_hook(self, inputs)`,
# the helper function needs the x and y arguments to accept those params.
def __impl__(x, y):
for task in tasks:
# Wait for the broadcast task before using its result.
task.wait()

return __impl__

@paddle.autograd.no_grad()
def _broadcast_params_overlap_forward(self):
# Exchange all the shards with the other ranks,
# but overlap the broadcast with the next batch's calculation.
group_idx = 0

param2task = {}
for x in self._broadcast_order_params:
if x.trainable:
group = self._broadcast_groups[group_idx]
group_idx = (group_idx + 1) % self._number_of_broadcast_groups
task = broadcast(tensor=x,
src=group.ranks[self._param2rank[x.name]],
group=group,
sync_op=False)
assert x.name not in param2task
param2task[x.name] = task

for layer in self._layers.sublayers():
if len(layer.sublayers()) == 0:
# Register the forward pre-hook for leaf layers. This gives the best performance.
tasks = []
for param in layer.parameters():
if param.trainable:
if param.name in param2task:
tasks.append(param2task[param.name])
self._forward_pre_hook_remove_helper.append(
layer.register_forward_pre_hook(
self._forward_pre_hook_function(tasks)))
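
That last hunk is the heart of the broadcast/forward overlap: each parameter broadcast is launched asynchronously, and the wait is deferred to a forward pre-hook on the leaf layer that owns the parameter, so the communication runs while earlier layers compute. The sketch below strips the same hook-and-wait pattern out of the sharding machinery; the toy `Linear` layer and the single-parameter broadcast are illustrative assumptions.

import paddle
import paddle.distributed as dist

dist.init_parallel_env()

layer = paddle.nn.Linear(8, 8)

# Launch the parameter broadcast asynchronously; sync_op=False returns a task.
task = dist.broadcast(layer.weight, src=0, sync_op=False)

def _wait_before_forward(layer, inputs):
    # Runs right before layer.forward, i.e. only once the parameter is needed.
    task.wait()

# Register the pre-hook on the leaf layer that owns the broadcast parameter.
hook_handle = layer.register_forward_pre_hook(_wait_before_forward)

out = layer(paddle.randn([4, 8]))

# Mirroring _forward_pre_hook_remove_helper above: drop the hook after the step
# so fresh hooks can be registered for the next batch's broadcasts.
hook_handle.remove()

The second changed file, the GroupShardedStage2 model wrapper, follows.
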
@@ -100,6 +100,9 @@ def __init__(
for optim in self._sharding_optimizers:
self._all_params.extend(list(optim.local_params))

# sharding stage 2 comm overlap flag
self._reduce_overlap = False

self._trainable_params = []
self._grad_reduced = []
self._trainable_param2rank = {}
@@ -306,6 +309,18 @@ def _clear_counters(self):
for grad_storage in self._grad_storage_list:
grad_storage.reset_checked_in()

def _set_reduce_overlap(self, reduce_overlap):
# Hacky way to avoid adding an extra parameter to the `group_sharded_parallel` function.
# Users should use this like:
# model, optimizer, scaler = group_sharded_parallel(...)
# model._set_reduce_overlap(True)
self._reduce_overlap = reduce_overlap
if self._reduce_overlap:
assert len(
self._sharding_optimizers
) == 1, "Only support comm overlap strategy for single optimizer"
self._sharding_optimizers[0]._set_reduce_overlap(reduce_overlap)

def _get_reduce_fn(self, index, param, dst_rank):
"""
There are two ways to reduce gradient.
@@ -337,11 +352,12 @@ def cleanup():
del tmp_grad
param.clear_gradient(False)

# Synchronize the reduce parameter gradient
collective.reduce(tensor=param.grad,
dst=self._group.ranks[dst_rank],
group=self._group)
# TODO (Baibaifan) Asynchronous the reduce parameter gradient
# Reduce the parameter gradient, asynchronously when reduce overlap is enabled
self._sharding_optimizers[0]._update_task(
collective.reduce(tensor=param.grad,
dst=self._group.ranks[dst_rank],
group=self._group,
sync_op=not self._reduce_overlap))

# Clear the task flow and trigger callback to clear the redundant gradient
# self._clear_task_flow()
@@ -385,12 +401,13 @@ def cleanup():

# Reduce the bucket
grad_storage.sent = True
# Synchronize the reduce parameter gradient
collective.reduce(
tensor=grad_storage.buffer,
dst=self._group.ranks[grad_storage.destination],
group=self._group)
# TODO (Baibaifan) Asynchronous the reduce parameter gradient
# Reduce the gradient bucket, asynchronously when reduce overlap is enabled
self._sharding_optimizers[0]._update_task(
collective.reduce(
tensor=grad_storage.buffer,
dst=self._group.ranks[grad_storage.destination],
group=self._group,
sync_op=not self._reduce_overlap))

cleanup()

@@ -528,6 +545,10 @@ def _redefine_opt_step(self):
opt_step = opt.step

def _opt_step(self):
if self._reduce_overlap:
# Wait for the last reduce task. This wait must happen before the grad scale function.
assert self._comm_task is not None
self._comm_task.wait()
grad_func()
opt_step()
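
The reason a single `_comm_task` is enough is stated in `_update_task` above: the asynchronous reduces are queued on one communication stream in issue order, so waiting on the most recent task before the wrapped optimizer step implies every earlier reduce has completed. Below is a minimal illustration of that property; the toy gradient tensors and `dst=0` are assumptions for demonstration.

import paddle
import paddle.distributed as dist

dist.init_parallel_env()

grads = [paddle.randn([1024]) for _ in range(4)]

last_task = None
for grad in grads:
    # Each asynchronous reduce is queued on the same comm stream, in order.
    last_task = dist.reduce(grad, dst=0, sync_op=False)

# The stream drains its queue in FIFO order, so once the last task is done
# every earlier reduce has finished as well; a single wait covers them all.
last_task.wait()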
