fix utest (#68131)
ForFishes authored Sep 13, 2024
1 parent a40ea5e commit 9baee85
Showing 7 changed files with 211 additions and 10 deletions.
5 changes: 3 additions & 2 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -97,8 +97,9 @@ message DygraphShardingConfig {
optional bool split_param = 4 [ default = false ];
optional bool fuse_optimizer = 5 [ default = true ];
optional bool use_reduce_avg = 6 [ default = true ];
- optional bool release_gradients = 7 [ default = false ];
- optional int32 comm_buffer_size_MB = 8 [ default = 256 ];
+ optional int32 comm_buffer_size_MB = 7 [ default = 256 ];
+ optional bool release_gradients = 8 [ default = false ];
+ optional bool free_grads_in_comm = 9 [ default = false ];
}

message HybridConfig {
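For context, a hedged sketch of how the renumbered and newly added fields behave at the protobuf level; the generated-module import path below is an assumption, not something shown in this diff:

```python
# Hypothetical sketch -- assumes the generated pb2 module is importable under
# this path; only the field names and defaults come from the proto change above.
from paddle.distributed.fleet.proto import distributed_strategy_pb2 as pb

cfg = pb.DygraphShardingConfig()
print(cfg.comm_buffer_size_MB)   # 256   (now field 7)
print(cfg.release_gradients)     # False (now field 8)
print(cfg.free_grads_in_comm)    # False (new field 9, off by default)
cfg.free_grads_in_comm = True    # opt in to freeing gradients during communication
```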
1 change: 1 addition & 0 deletions python/paddle/distributed/fleet/__init__.py
@@ -107,6 +107,7 @@
set_log_level = log_util.set_log_level
get_log_level_code = log_util.get_log_level_code
get_log_level_name = log_util.get_log_level_name
check_memory_usage = log_util.check_memory_usage
save_cache_table = fleet.save_cache_table
collective_perf = fleet.collective_perf
from .. import auto_parallel as auto # noqa: F401
@@ -628,6 +628,7 @@ def __init__(self, optimizer, hcg):
self.comm_overlap = sharding_config.comm_overlap

comm_buffer_size_MB = sharding_config.comm_buffer_size_MB
free_grads_in_comm = sharding_config.free_grads_in_comm

# Setting pipeline parallelism overlap
self.pp_overlap = pp_config.sharding_comm_overlap
@@ -643,7 +644,9 @@ def __init__(self, optimizer, hcg):
"nccl reduce_avg requires paddle compiled with cuda and nccl>=2.10.0, please check compilation setups."
)

self._build_comm_buffers(acc_steps, comm_buffer_size_MB * 1024 * 1024)
self._build_comm_buffers(
acc_steps, comm_buffer_size_MB * 1024 * 1024, free_grads_in_comm
)
# NOTE(shenliang03): Sort the comm_buffers by dst rank; this improves
# the performance of the reduce communication. The default of
# g_shard_sort_reduce_root is True.
@@ -714,7 +717,9 @@ def fused_allreduce(*_):

return fused_allreduce

def _build_comm_buffers(self, acc_steps, group_size=256 * 1024 * 1024):
def _build_comm_buffers(
self, acc_steps, group_size=256 * 1024 * 1024, free_grads_in_comm=False
):
if self.pp_overlap:
return
# NOTE(lijin23): for XPU, we fuse all params to a single comm buffer to
@@ -742,6 +747,7 @@ def _build_comm_buffers(self, acc_steps, group_size=256 * 1024 * 1024):
act=HOOK_ACTION.REDUCE_SCATTER,
release_grads=self.sd_release_grads,
use_reduce_avg=self.use_reduce_avg,
free_grads_in_comm=free_grads_in_comm,
)
self._comm_buffer_list.append(buffer)

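For reference, a quick arithmetic check of the buffer-size conversion passed to `_build_comm_buffers` above (values are illustrative):

```python
# Illustrative only: comm_buffer_size_MB is converted to bytes before fusion.
comm_buffer_size_MB = 256                        # proto default shown earlier
group_size = comm_buffer_size_MB * 1024 * 1024   # 268_435_456 bytes == 256 MiB
```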
36 changes: 36 additions & 0 deletions python/paddle/distributed/fleet/meta_parallel/pipeline_hooks.py
@@ -0,0 +1,36 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections import defaultdict
from typing import Callable


class BubbleHook:
def __init__(self):
self.hooks: dict[int, list[Callable]] = defaultdict(list)

def set_bubble_times(self, bubble_times: int):
self.bubble_times = bubble_times

def register_hook(self, bubble_id: int, hook: Callable):
self.hooks[bubble_id].append(hook)

def run_hook(self, bubble_id: int):
if bubble_id not in self.hooks:
return

for hook in self.hooks[bubble_id]:
hook(bubble_id)
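A minimal usage sketch of the new `BubbleHook` class, using only the methods defined above (each registered callable receives the bubble id):

```python
# Minimal sketch of the BubbleHook API introduced in this file.
hooks = BubbleHook()
hooks.set_bubble_times(4)  # e.g. one bubble slot per pipeline stage
hooks.register_hook(0, lambda bubble_id: print(f"reached bubble {bubble_id}"))
hooks.run_hook(0)          # runs every hook registered for bubble 0
hooks.run_hook(1)          # returns silently: nothing registered for bubble 1
```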
46 changes: 46 additions & 0 deletions python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -51,6 +51,10 @@
assign_group_by_size,
)

from .pipeline_hooks import (
BubbleHook,
)

__all__ = []


@@ -228,6 +232,16 @@ def register_global_pipeline_parallel_hook(
pipeline_parallel_callbacks_.register_hook(location, hook)


pipeline_bubble_hooks_ = BubbleHook()


def register_bubble_pipeline_parallel_hook(location: int, hook: Callable):
"""
Register a bubble hook for pipeline parallelism.
"""
pipeline_bubble_hooks_.register_hook(location, hook)
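As a hedged example, a bubble hook could log memory usage during otherwise idle pipeline slots; the direct module import paths below are an assumption, not shown in this diff:

```python
# Hypothetical usage -- import paths are assumed; check_memory_usage is the
# helper exported via fleet/__init__.py in this same commit.
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import pipeline_parallel

def on_bubble(bubble_id):
    fleet.check_memory_usage(msg=f"pipeline bubble {bubble_id}")

pipeline_parallel.register_bubble_pipeline_parallel_hook(0, on_bubble)
```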


class PipelineParallel(MetaParallelBase):
def __init__(self, layers, hcg, strategy):
if not isinstance(layers, PipelineLayer):
@@ -383,6 +397,10 @@ def __init__(self, layers, hcg, strategy):
self._compute_loss = True
self.callbacks = pipeline_parallel_callbacks_

self.bubble_hooks = pipeline_bubble_hooks_

self.bubble_hooks.set_bubble_times(bubble_times=self.num_stages)

logger.info(
f"Pipeline Info -- num_stages: {self.num_stages}, stage_id: {self.stage_id}"
)
@@ -1462,6 +1480,13 @@ def _process_bwd_buffer(step_id, tensor):

steady_steps = num_steps - startup_steps

bubble_idx = -1
for location in range(self.stage_id):
bubble_idx += 1
self.bubble_hooks.run_hook(bubble_idx)

rest_bubble_times = self.num_stages - 1 - self.stage_id

self.set_virtual_pipeline_rank(0)
if not static_scheduler:
self.input_tensors[0].append(
@@ -1499,6 +1524,10 @@ def _process_bwd_buffer(step_id, tensor):
output_tensor = self._forward_step_helper(micro_dataset, micro_step)
self._record_stamp("F", micro_step, '"E"', forward=True)

if micro_step >= startup_steps - rest_bubble_times:
bubble_idx += 1
self.bubble_hooks.run_hook(bubble_idx)

# determine whether recv forward tensor or not
next_virtual_pp_rank = self._get_virtual_pp_rank(
micro_step + 1, forward=True
@@ -1851,6 +1880,14 @@ def _process_bwd_buffer(step_id, tensor):
f"backward step for {real_micro_step} with virtual pp rank {virtual_pp_rank}"
)
continue

if (
micro_step
< steady_steps + self.num_stages - 1 - self.stage_id
):
bubble_idx += 1
self.bubble_hooks.run_hook(bubble_idx)

# cooldown loop
self._record_stamp("B", micro_step, '"B"', forward=False)
input_tensor_grad = self._backward_step_helper(micro_step)
@@ -1885,6 +1922,15 @@ def _process_bwd_buffer(step_id, tensor):

self._sync_overlap_grads()

for _ in range(self.stage_id):
bubble_idx += 1
self.bubble_hooks.run_hook(bubble_idx)

if not forward_only:
assert (bubble_idx + 1) == (
2 * self.num_stages - 2
), f"All bubbles number {bubble_idx + 1} should be equal to {(2 * self.num_stages - 2)}"

if static_scheduler:
self._reset_counter()
return schedule
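For intuition, a small worked check of the bubble-count invariant asserted above: hooks fire `stage_id` times before the first forward, `num_stages - 1 - stage_id` times in each of the startup and cooldown loops, and `stage_id` times after the last backward (numbers below are illustrative):

```python
# Illustrative check of the 2 * num_stages - 2 bubble-count invariant.
num_stages = 4
for stage_id in range(num_stages):
    total = stage_id + 2 * (num_stages - 1 - stage_id) + stage_id
    assert total == 2 * num_stages - 2   # 6 bubbles per rank when num_stages == 4
```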
72 changes: 72 additions & 0 deletions python/paddle/distributed/fleet/utils/log_util.py
@@ -14,6 +14,7 @@

import logging
import os
import subprocess
from distutils.util import strtobool
from logging.handlers import RotatingFileHandler

@@ -115,3 +116,74 @@ def sync_rotate_logger():
if g_sync_rotate_logger is None:
g_sync_rotate_logger = get_rotate_file_logger("INFO", __name__)
return g_sync_rotate_logger


def check_memory_usage(msg=""):
GB = 1024.0 * 1024.0 * 1024.0
mem_dict = {}
mem_dict['max_memory_allocated_size'] = (
paddle.device.cuda.max_memory_allocated() / GB
)
mem_dict['max_memory_reserved_size'] = (
paddle.device.cuda.max_memory_reserved() / GB
)
mem_dict['memory_allocated_size'] = (
paddle.device.cuda.memory_allocated() / GB
)
mem_dict['memory_reserved_size'] = paddle.device.cuda.memory_reserved() / GB
mem_msg = f"checking gpu memory usage {msg}:"
for key in mem_dict:
mem_msg += f"\n{key}: {mem_dict[key]}GB"
logger.info(mem_msg)

if hasattr(paddle.device.cuda, 'max_pinned_memory_allocated'):
mem_dict = {}
mem_dict['max_memory_allocated_size'] = (
paddle.device.cuda.max_pinned_memory_allocated() / GB
)
mem_dict['max_memory_reserved_size'] = (
paddle.device.cuda.max_pinned_memory_reserved() / GB
)
mem_dict['memory_allocated_size'] = (
paddle.device.cuda.pinned_memory_allocated() / GB
)
mem_dict['memory_reserved_size'] = (
paddle.device.cuda.pinned_memory_reserved() / GB
)
mem_msg = f"checking pinned memory usage {msg}:"
for key in mem_dict:
mem_msg += f"\n{key}: {mem_dict[key]}GB"
logger.info(mem_msg)

if hasattr(paddle.device, 'cpu') and hasattr(
paddle.device.cpu, 'max_memory_allocated'
):
mem_dict = {}
mem_dict['max_memory_allocated_size'] = (
paddle.device.cpu.max_memory_allocated() / GB
)
mem_dict['max_memory_reserved_size'] = (
paddle.device.cpu.max_memory_reserved() / GB
)
mem_dict['memory_allocated_size'] = (
paddle.device.cpu.memory_allocated() / GB
)
mem_dict['memory_reserved_size'] = (
paddle.device.cpu.memory_reserved() / GB
)
mem_msg = f"checking cpu memory usage {msg}:"
for key in mem_dict:
mem_msg += f"\n{key}: {mem_dict[key]}GB"
logger.info(mem_msg)

# Execute the command and get the output
result = subprocess.run(["free", "-h"], capture_output=True, text=True)
lines = result.stdout.strip().split('\n')

# Extract data
mem_data = lines[1].split()
swap_data = lines[2].split()

# Format and print
formatted_output = f"checking system memory usage (free -h) {msg}: Total: {mem_data[1]}, Used: {mem_data[2]}, Free: {mem_data[3]}, Available: {mem_data[-1]}"
logger.info(formatted_output)
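A minimal usage sketch of the helper above, via the `fleet` alias added in `__init__.py`; the surrounding training-loop objects are assumed:

```python
from paddle.distributed import fleet

fleet.check_memory_usage(msg="before optimizer step")
optimizer.step()   # assumed training-loop object, not part of this diff
fleet.check_memory_usage(msg="after optimizer step")
```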
51 changes: 45 additions & 6 deletions python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
@@ -207,6 +207,7 @@ def __init__(
param_end = min(self._index + self._padded_size, rank_end)
self._param_begin = param_begin
self._param_end = param_end
self._rank_begin = rank_begin

self._slice_grad = None

@@ -302,7 +303,7 @@ def assign_slice_grad(self, slice_param):
else:
assert slice_param.grad._is_shared_buffer_with(slice_grad)

def _reset_grad_buffer(self):
def _clear_grad_buffer(self):
if self._slice_grad is not None:
self._slice_grad._clear_dataptr()
self._slice_grad = None
@@ -311,6 +312,15 @@ def _reset_grad_buffer(self):
self._grad_buffer._clear_dataptr()
self._grad_buffer = None

def _reset_grad_buffer(self, slice_grad_buffer):
self._clear_grad_buffer()
self._grad_buffer = slice_grad_buffer
if self._param_begin < self._param_end:
self._slice_grad = self._grad_buffer._slice(
self._param_begin - self._rank_begin,
self._param_end - self._rank_begin,
)
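A small numeric sketch (hypothetical values) of the offsets used in `_reset_grad_buffer` above; the reduce-scattered buffer is indexed relative to `_rank_begin`:

```python
# Hypothetical values for one rank's shard of the fused gradient buffer.
rank_begin, param_begin, param_end = 1024, 1100, 1400
slice_start = param_begin - rank_begin   # 76
slice_end = param_end - rank_begin       # 376 -> slice covers 300 elements
```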


def build_reduce_scatter_buffer(
parameters, sharding_degree, rank, use_main_grad=False, release_grad=False
@@ -384,6 +394,7 @@ def __init__(
scale_after_comm=True,
release_grads=False,
use_reduce_avg=False,
free_grads_in_comm=False,
):
self._id = id
self._params = params
@@ -393,6 +404,16 @@ def __init__(
self._fuse_param = fuse_param
self._release_grads = release_grads
self._use_reduce_avg = use_reduce_avg
self._free_grads_in_comm = free_grads_in_comm

if self._free_grads_in_comm:
assert (
acc_steps == 1
), f"No need to use free_grads_in_comm when acc_steps `{acc_steps}` != 1"
assert (
act == HOOK_ACTION.REDUCE_SCATTER
), "Currently, only support reduce_scatter"
assert release_grads, "Currently, only support release_grads"

assert not (
self._fuse_param and self._release_grads
@@ -490,7 +511,15 @@ def _clear_grad_storage(self):
self.grad_storage = None
if self._act == HOOK_ACTION.REDUCE_SCATTER:
for param in self._params:
self._sharding_param_grad_view[param.name]._reset_grad_buffer()
self._sharding_param_grad_view[param.name]._clear_grad_buffer()

def _reset_grad_storage(self, slice_grad_buffer):
self._clear_grad_storage()
for param in self._params:
self._sharding_param_grad_view[param.name]._reset_grad_buffer(
slice_grad_buffer
)
self.grad_storage = slice_grad_buffer

def _init_step_dict(self):
for p in self._params:
@@ -530,10 +559,13 @@ def _copy_grad_to_buffer(self, param):

if self.use_main_grad:
param.main_grad._clear()
param.main_grad = tmp_var
param.main_grad.name = "main_grad@" + param.name
if not self._free_grads_in_comm:
param.main_grad = tmp_var
param.main_grad.name = "main_grad@" + param.name
else:
param._copy_gradient_from(tmp_var)
param.grad._clear()
if not self._free_grads_in_comm:
param._copy_gradient_from(tmp_var)

# record address for the following `acc_steps - 1` steps.
self._grads_to_addr[param.name] = get_grad_address(
@@ -658,14 +690,21 @@ def _comm_grads(self):
shard_size = self.grad_storage._numel() // self._comm_group.nranks
begin = shard_size * self._comm_group.rank
end = begin + shard_size
reduce_scattered = self.grad_storage._slice(begin, end)
reduce_scattered = (
paddle.empty_like(self.grad_storage._slice(begin, end))
if self._free_grads_in_comm
else self.grad_storage._slice(begin, end)
)
task = paddle.distributed.reduce_scatter(
reduce_scattered,
self.grad_storage,
op=reduce_op,
group=self._comm_group,
sync_op=False,
)
if self._free_grads_in_comm:
self._reset_grad_storage(reduce_scattered)

self._task = task

@imperative_base.no_grad
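Finally, an illustrative walk-through of the slice arithmetic in `_comm_grads` above (numbers are made up):

```python
# Illustrative only: locating each rank's reduce-scatter output slice.
numel = 1024                    # elements in the fused grad_storage
nranks = 4                      # sharding communication group size
rank = 2
shard_size = numel // nranks    # 256
begin = shard_size * rank       # 512
end = begin + shard_size        # 768
# With free_grads_in_comm, the result goes into a fresh empty_like() tensor of
# this slice's shape, so the full grad_storage can be released after the comm.
```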
