From 9baee8503e4e4bafa571d948b8a2dfbe08117863 Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Fri, 13 Sep 2024 20:07:11 +0800 Subject: [PATCH] fix utest (#68131) --- .../framework/distributed_strategy.proto | 5 +- python/paddle/distributed/fleet/__init__.py | 1 + .../dygraph_sharding_optimizer.py | 10 ++- .../fleet/meta_parallel/pipeline_hooks.py | 36 ++++++++++ .../fleet/meta_parallel/pipeline_parallel.py | 46 ++++++++++++ .../distributed/fleet/utils/log_util.py | 72 +++++++++++++++++++ .../fleet/utils/tensor_fusion_helper.py | 51 +++++++++++-- 7 files changed, 211 insertions(+), 10 deletions(-) create mode 100644 python/paddle/distributed/fleet/meta_parallel/pipeline_hooks.py diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 03a079ed6d18c..8726daf317924 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -97,8 +97,9 @@ message DygraphShardingConfig { optional bool split_param = 4 [ default = false ]; optional bool fuse_optimizer = 5 [ default = true ]; optional bool use_reduce_avg = 6 [ default = true ]; - optional bool release_gradients = 7 [ default = false ]; - optional int32 comm_buffer_size_MB = 8 [ default = 256 ]; + optional int32 comm_buffer_size_MB = 7 [ default = 256 ]; + optional bool release_gradients = 8 [ default = false ]; + optional bool free_grads_in_comm = 9 [ default = false ]; } message HybridConfig { diff --git a/python/paddle/distributed/fleet/__init__.py b/python/paddle/distributed/fleet/__init__.py index a5174be7ef68b..7c83bd89189da 100755 --- a/python/paddle/distributed/fleet/__init__.py +++ b/python/paddle/distributed/fleet/__init__.py @@ -107,6 +107,7 @@ set_log_level = log_util.set_log_level get_log_level_code = log_util.get_log_level_code get_log_level_name = log_util.get_log_level_name +check_memory_usage = log_util.check_memory_usage save_cache_table = fleet.save_cache_table collective_perf = fleet.collective_perf from .. import auto_parallel as auto # noqa: F401 diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py index 2b3279f62cbad..33d4c60c2ae06 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py @@ -628,6 +628,7 @@ def __init__(self, optimizer, hcg): self.comm_overlap = sharding_config.comm_overlap comm_buffer_size_MB = sharding_config.comm_buffer_size_MB + free_grads_in_comm = sharding_config.free_grads_in_comm # Setting pipeline parallelism overlap self.pp_overlap = pp_config.sharding_comm_overlap @@ -643,7 +644,9 @@ def __init__(self, optimizer, hcg): "nccl reduce_avg requires paddle compiled with cuda and nccl>=2.10.0, please check compilation setups." ) - self._build_comm_buffers(acc_steps, comm_buffer_size_MB * 1024 * 1024) + self._build_comm_buffers( + acc_steps, comm_buffer_size_MB * 1024 * 1024, free_grads_in_comm + ) # NOTE(shenliang03): Sort the comm_buffers by dst rank, # it will improve the performance in reduce communicate. Default # g_shard_sort_reduce_root is True. 
@@ -714,7 +717,9 @@ def fused_allreduce(*_): return fused_allreduce - def _build_comm_buffers(self, acc_steps, group_size=256 * 1024 * 1024): + def _build_comm_buffers( + self, acc_steps, group_size=256 * 1024 * 1024, free_grads_in_comm=False + ): if self.pp_overlap: return # NOTE(lijin23): for XPU, we fuse all params to a single comm buffer to @@ -742,6 +747,7 @@ def _build_comm_buffers(self, acc_steps, group_size=256 * 1024 * 1024): act=HOOK_ACTION.REDUCE_SCATTER, release_grads=self.sd_release_grads, use_reduce_avg=self.use_reduce_avg, + free_grads_in_comm=free_grads_in_comm, ) self._comm_buffer_list.append(buffer) diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_hooks.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_hooks.py new file mode 100644 index 0000000000000..59573f67f6584 --- /dev/null +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_hooks.py @@ -0,0 +1,36 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections import defaultdict +from typing import Callable + + +class BubbleHook: + def __init__(self): + self.hooks: dict[int, list[Callable]] = defaultdict(list) + + def set_bubble_times(self, bubble_times: int): + self.bubble_times = bubble_times + + def register_hook(self, bubble_id: int, hook: Callable): + self.hooks[bubble_id].append(hook) + + def run_hook(self, bubble_id: int): + if bubble_id not in self.hooks: + return + + for hook in self.hooks[bubble_id]: + hook(bubble_id) diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 65440431bc8e0..5f8f512d384bb 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -51,6 +51,10 @@ assign_group_by_size, ) +from .pipeline_hooks import ( + BubbleHook, +) + __all__ = [] @@ -228,6 +232,16 @@ def register_global_pipeline_parallel_hook( pipeline_parallel_callbacks_.register_hook(location, hook) +pipeline_bubble_hooks_ = BubbleHook() + + +def register_bubble_pipeline_parallel_hook(location: int, hook: Callable): + """ + Registering bubble hooks for pipeline parallelism. 
+ """ + pipeline_bubble_hooks_.register_hook(location, hook) + + class PipelineParallel(MetaParallelBase): def __init__(self, layers, hcg, strategy): if not isinstance(layers, PipelineLayer): @@ -383,6 +397,10 @@ def __init__(self, layers, hcg, strategy): self._compute_loss = True self.callbacks = pipeline_parallel_callbacks_ + self.bubble_hooks = pipeline_bubble_hooks_ + + self.bubble_hooks.set_bubble_times(bubble_times=self.num_stages) + logger.info( f"Pipeline Info -- num_stages: {self.num_stages}, stage_id: {self.stage_id}" ) @@ -1462,6 +1480,13 @@ def _process_bwd_buffer(step_id, tensor): steady_steps = num_steps - startup_steps + bubble_idx = -1 + for location in range(self.stage_id): + bubble_idx += 1 + self.bubble_hooks.run_hook(bubble_idx) + + rest_bubble_times = self.num_stages - 1 - self.stage_id + self.set_virtual_pipeline_rank(0) if not static_scheduler: self.input_tensors[0].append( @@ -1499,6 +1524,10 @@ def _process_bwd_buffer(step_id, tensor): output_tensor = self._forward_step_helper(micro_dataset, micro_step) self._record_stamp("F", micro_step, '"E"', forward=True) + if micro_step >= startup_steps - rest_bubble_times: + bubble_idx += 1 + self.bubble_hooks.run_hook(bubble_idx) + # determine whether recv forward tensor or not next_virtual_pp_rank = self._get_virtual_pp_rank( micro_step + 1, forward=True @@ -1851,6 +1880,14 @@ def _process_bwd_buffer(step_id, tensor): f"backward step for {real_micro_step} with virtual pp rank {virtual_pp_rank}" ) continue + + if ( + micro_step + < steady_steps + self.num_stages - 1 - self.stage_id + ): + bubble_idx += 1 + self.bubble_hooks.run_hook(bubble_idx) + # cooldown loop self._record_stamp("B", micro_step, '"B"', forward=False) input_tensor_grad = self._backward_step_helper(micro_step) @@ -1885,6 +1922,15 @@ def _process_bwd_buffer(step_id, tensor): self._sync_overlap_grads() + for _ in range(self.stage_id): + bubble_idx += 1 + self.bubble_hooks.run_hook(bubble_idx) + + if not forward_only: + assert (bubble_idx + 1) == ( + 2 * self.num_stages - 2 + ), f"All bubbles number {bubble_idx + 1} should be equal to {(2 * self.num_stages - 2)}" + if static_scheduler: self._reset_counter() return schedule diff --git a/python/paddle/distributed/fleet/utils/log_util.py b/python/paddle/distributed/fleet/utils/log_util.py index fc277a77a8326..6db05e27018a4 100644 --- a/python/paddle/distributed/fleet/utils/log_util.py +++ b/python/paddle/distributed/fleet/utils/log_util.py @@ -14,6 +14,7 @@ import logging import os +import subprocess from distutils.util import strtobool from logging.handlers import RotatingFileHandler @@ -115,3 +116,74 @@ def sync_rotate_logger(): if g_sync_rotate_logger is None: g_sync_rotate_logger = get_rotate_file_logger("INFO", __name__) return g_sync_rotate_logger + + +def check_memory_usage(msg=""): + GB = 1024.0 * 1024.0 * 1024.0 + mem_dict = {} + mem_dict['max_memory_allocated_size'] = ( + paddle.device.cuda.max_memory_allocated() / GB + ) + mem_dict['max_memory_reserved_size'] = ( + paddle.device.cuda.max_memory_reserved() / GB + ) + mem_dict['memory_allocated_size'] = ( + paddle.device.cuda.memory_allocated() / GB + ) + mem_dict['memory_reserved_size'] = paddle.device.cuda.memory_reserved() / GB + mem_msg = f"checking gpu memory usage {msg}:" + for key in mem_dict: + mem_msg += f"\n{key}: {mem_dict[key]}GB" + logger.info(mem_msg) + + if hasattr(paddle.device.cuda, 'max_pinned_memory_allocated'): + mem_dict = {} + mem_dict['max_memory_allocated_size'] = ( + paddle.device.cuda.max_pinned_memory_allocated() / GB + ) + 
mem_dict['max_memory_reserved_size'] = (
+            paddle.device.cuda.max_pinned_memory_reserved() / GB
+        )
+        mem_dict['memory_allocated_size'] = (
+            paddle.device.cuda.pinned_memory_allocated() / GB
+        )
+        mem_dict['memory_reserved_size'] = (
+            paddle.device.cuda.pinned_memory_reserved() / GB
+        )
+        mem_msg = f"checking pinned memory usage {msg}:"
+        for key in mem_dict:
+            mem_msg += f"\n{key}: {mem_dict[key]}GB"
+        logger.info(mem_msg)
+
+    if hasattr(paddle.device, 'cpu') and hasattr(
+        paddle.device.cpu, 'max_memory_allocated'
+    ):
+        mem_dict = {}
+        mem_dict['max_memory_allocated_size'] = (
+            paddle.device.cpu.max_memory_allocated() / GB
+        )
+        mem_dict['max_memory_reserved_size'] = (
+            paddle.device.cpu.max_memory_reserved() / GB
+        )
+        mem_dict['memory_allocated_size'] = (
+            paddle.device.cpu.memory_allocated() / GB
+        )
+        mem_dict['memory_reserved_size'] = (
+            paddle.device.cpu.memory_reserved() / GB
+        )
+        mem_msg = f"checking cpu memory usage {msg}:"
+        for key in mem_dict:
+            mem_msg += f"\n{key}: {mem_dict[key]}GB"
+        logger.info(mem_msg)
+
+    # Execute the command and get the output
+    result = subprocess.run(["free", "-h"], capture_output=True, text=True)
+    lines = result.stdout.strip().split('\n')
+
+    # Extract data
+    mem_data = lines[1].split()
+    swap_data = lines[2].split()
+
+    # Format and print
+    formatted_output = f"checking CPU memory usage: {msg} Memory - Total: {mem_data[1]}, Used: {mem_data[2]}, Free: {mem_data[3]}, Available: {mem_data[-1]}"
+    logger.info(formatted_output)
diff --git a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
index 30db50885d90f..5ca91bcdb91ec 100644
--- a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
+++ b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
@@ -207,6 +207,7 @@ def __init__(
         param_end = min(self._index + self._padded_size, rank_end)
         self._param_begin = param_begin
         self._param_end = param_end
+        self._rank_begin = rank_begin
 
         self._slice_grad = None
 
@@ -302,7 +303,7 @@ def assign_slice_grad(self, slice_param):
         else:
             assert slice_param.grad._is_shared_buffer_with(slice_grad)
 
-    def _reset_grad_buffer(self):
+    def _clear_grad_buffer(self):
         if self._slice_grad is not None:
             self._slice_grad._clear_dataptr()
             self._slice_grad = None
@@ -311,6 +312,15 @@ def _reset_grad_buffer(self):
             self._grad_buffer._clear_dataptr()
             self._grad_buffer = None
 
+    def _reset_grad_buffer(self, slice_grad_buffer):
+        self._clear_grad_buffer()
+        self._grad_buffer = slice_grad_buffer
+        if self._param_begin < self._param_end:
+            self._slice_grad = self._grad_buffer._slice(
+                self._param_begin - self._rank_begin,
+                self._param_end - self._rank_begin,
+            )
+
 
 def build_reduce_scatter_buffer(
     parameters, sharding_degree, rank, use_main_grad=False, release_grad=False
@@ -384,6 +394,7 @@ def __init__(
         scale_after_comm=True,
         release_grads=False,
         use_reduce_avg=False,
+        free_grads_in_comm=False,
     ):
         self._id = id
         self._params = params
@@ -393,6 +404,16 @@ def __init__(
         self._fuse_param = fuse_param
         self._release_grads = release_grads
         self._use_reduce_avg = use_reduce_avg
+        self._free_grads_in_comm = free_grads_in_comm
+
+        if self._free_grads_in_comm:
+            assert (
+                acc_steps == 1
+            ), f"free_grads_in_comm requires acc_steps == 1, but got acc_steps `{acc_steps}`"
+            assert (
+                act == HOOK_ACTION.REDUCE_SCATTER
+            ), "Currently, only support reduce_scatter"
+            assert release_grads, "Currently, only support release_grads"
 
         assert not (
             self._fuse_param and self._release_grads
@@ -490,7 +511,15 @@ def 
_clear_grad_storage(self): self.grad_storage = None if self._act == HOOK_ACTION.REDUCE_SCATTER: for param in self._params: - self._sharding_param_grad_view[param.name]._reset_grad_buffer() + self._sharding_param_grad_view[param.name]._clear_grad_buffer() + + def _reset_grad_storage(self, slice_grad_buffer): + self._clear_grad_storage() + for param in self._params: + self._sharding_param_grad_view[param.name]._reset_grad_buffer( + slice_grad_buffer + ) + self.grad_storage = slice_grad_buffer def _init_step_dict(self): for p in self._params: @@ -530,10 +559,13 @@ def _copy_grad_to_buffer(self, param): if self.use_main_grad: param.main_grad._clear() - param.main_grad = tmp_var - param.main_grad.name = "main_grad@" + param.name + if not self._free_grads_in_comm: + param.main_grad = tmp_var + param.main_grad.name = "main_grad@" + param.name else: - param._copy_gradient_from(tmp_var) + param.grad._clear() + if not self._free_grads_in_comm: + param._copy_gradient_from(tmp_var) # record address for the following `acc_steps - 1` steps. self._grads_to_addr[param.name] = get_grad_address( @@ -658,7 +690,11 @@ def _comm_grads(self): shard_size = self.grad_storage._numel() // self._comm_group.nranks begin = shard_size * self._comm_group.rank end = begin + shard_size - reduce_scattered = self.grad_storage._slice(begin, end) + reduce_scattered = ( + paddle.empty_like(self.grad_storage._slice(begin, end)) + if self._free_grads_in_comm + else self.grad_storage._slice(begin, end) + ) task = paddle.distributed.reduce_scatter( reduce_scattered, self.grad_storage, @@ -666,6 +702,9 @@ def _comm_grads(self): group=self._comm_group, sync_op=False, ) + if self._free_grads_in_comm: + self._reset_grad_storage(reduce_scattered) + self._task = task @imperative_base.no_grad
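
Usage sketch (illustrative, not part of the patch): how the new free_grads_in_comm switch in DygraphShardingConfig might be enabled from user code. The hybrid_configs["sharding_configs"] access pattern and the parallel degrees below are assumptions; per the assertions added to FusedCommBuffer above, the option only applies with split_param (the REDUCE_SCATTER path), release_gradients=True, and accumulate_steps == 1 (the default).

# Illustrative sketch only; degrees and config access pattern are assumed, adapt to your setup.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 8,
}

sharding_configs = strategy.hybrid_configs["sharding_configs"]
sharding_configs.split_param = True  # sharding V2, uses the REDUCE_SCATTER path
sharding_configs.release_gradients = True  # required by the new assertions
sharding_configs.comm_buffer_size_MB = 256
sharding_configs.free_grads_in_comm = True  # new field added in this patch

fleet.init(is_collective=True, strategy=strategy)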
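
Usage sketch (illustrative, not part of the patch): registering a bubble hook that logs memory usage through the newly exported fleet.check_memory_usage. The num_stages value is a placeholder; the scheduler above fires 2 * num_stages - 2 bubbles per training step (see the assertion added in forward_backward_pipeline).

# Illustrative sketch only; assumes fleet has been initialized with pp_degree > 1.
import paddle.distributed.fleet as fleet
from paddle.distributed.fleet.meta_parallel import pipeline_parallel


def log_bubble_memory(bubble_id):
    # check_memory_usage is newly exported through fleet.__init__ in this patch
    fleet.check_memory_usage(msg=f"in pipeline bubble {bubble_id}")


num_stages = 4  # placeholder: pipeline parallel degree of the run
for bubble_id in range(2 * num_stages - 2):
    pipeline_parallel.register_bubble_pipeline_parallel_hook(
        bubble_id, log_bubble_memory
    )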