improve consistency of zero_grad
tohtana committed Sep 18, 2024
1 parent 2a56f53 commit 84ca923
Showing 5 changed files with 87 additions and 53 deletions.
10 changes: 9 additions & 1 deletion deepspeed/runtime/base_optimizer.py
@@ -5,10 +5,11 @@
 
 import os
 import torch
+from typing import Callable, Iterable
 
 from deepspeed.utils import logger
 from deepspeed.utils.tensor_fragment import map_to_flat_opt_states
-from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank
+from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank, zero_grad_params
 
 
 class DeepSpeedOptimizer(object):
@@ -61,3 +62,10 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec
                 if key == 'params':
                     continue
                 param_group[key] = value
+
+    def _do_zero_grad(self,
+                      params: Iterable[torch.nn.Parameter],
+                      set_to_none_fn: Callable[[torch.Tensor], None],
+                      set_to_none: bool = True,
+                      force: bool = False) -> None:
+        zero_grad_params(params, set_to_none_fn, self.is_gradient_accumulation_boundary, set_to_none, force)
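
A minimal sketch (not part of the commit) of how an optimizer subclass is expected to plug into the new _do_zero_grad helper; the attribute name fp16_groups is only illustrative, and the real wiring appears in the stage3.py and stage_1_and_2.py diffs below.

    # Illustrative subclass method, assuming the subclass stores its parameters
    # in self.fp16_groups and inherits _do_zero_grad from DeepSpeedOptimizer.
    def zero_grad(self, set_to_none: bool = True, force: bool = False) -> None:

        def set_grad_to_none(p):
            p.grad = None  # drop the gradient tensor entirely

        params = [p for group in self.fp16_groups for p in group]
        self._do_zero_grad(params, set_grad_to_none, set_to_none, force)
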
37 changes: 21 additions & 16 deletions deepspeed/runtime/engine.py
@@ -74,7 +74,7 @@
 from deepspeed.utils.debug import debug_extract_module_and_param_names, debug_clear_module_and_param_names
 from deepspeed.monitor.monitor import MonitorMaster
 from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
-from deepspeed.runtime.utils import clip_grad_norm_
+from deepspeed.runtime.utils import clip_grad_norm_, zero_grad_params
 from deepspeed.runtime.eigenvalue import Eigenvalue
 from deepspeed.runtime.data_pipeline.constants import DATA_SAMPLING, \
     DATA_ROUTING, DATA_SAMPLING_ENABLED, CURRICULUM_LEARNING, \
@@ -2097,12 +2097,27 @@ def set_gradient_accumulation_boundary(self, is_boundary):
         self._is_gradient_accumulation_boundary = is_boundary
         self.optimizer.is_gradient_accumulation_boundary = is_boundary
 
-    def zero_grad(self):
+    def zero_grad(self, set_to_none: bool = True, force: bool = False) -> None:
         """
         Zero parameter grads.
         """
-        for param_name, param in self.module.named_parameters():
-            param.grad = None
+        # zero grad in basic optimizer could be unreliable and may not exhibit
+        # the behavior that we want
+        if self.bfloat16_enabled():
+            # TODO: Temporary until bf16_optimizer and zero_optimizer are integrated
+            if self.zero_optimization() and hasattr(self.optimizer, "zero_grad"):
+                self.optimizer.zero_grad(set_to_none, force)
+            else:
+                pass
+        elif self.zero_optimization() or self.fp16_enabled() or self.amp_enabled():
+            self.optimizer.zero_grad(set_to_none, force)
+        else:
+
+            def set_to_none_fn(param):
+                param.grad = None
+
+            zero_grad_params(self.module.parameters(), set_to_none_fn, self.is_gradient_accumulation_boundary(),
+                             set_to_none, force)
 
     def clip_fp32_gradients(self):
         clip_grad_norm_(parameters=self.module.parameters(), max_norm=self.gradient_clipping(), mpu=self.mpu)
@@ -2132,18 +2147,8 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
                 self.eigenvalue_enabled(),
                 block_eigenvalue,
             )
-        # zero grad in basic optimizer could be unreliable and may not exhibit
-        # the behavior that we want
-        if self.bfloat16_enabled():
-            # TODO: Temporary until bf16_optimizer and zero_optimizer are integrated
-            if self.zero_optimization() and hasattr(self.optimizer, "zero_grad"):
-                self.optimizer.zero_grad()
-            else:
-                pass
-        elif self.zero_optimization() or self.fp16_enabled() or self.amp_enabled():
-            self.optimizer.zero_grad()
-        else:
-            self.zero_grad()
+
+        self.zero_grad(force=True)
 
         report_progress = self.global_rank == 0 if self.global_rank else True
 
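
The user-visible effect of the engine changes: DeepSpeedEngine.zero_grad() now accepts set_to_none and force, becomes a logged no-op when the current micro-step is not a gradient accumulation boundary, and is reused by _take_model_step via zero_grad(force=True). A hedged usage sketch follows (the loop shape, data_loader, and batch handling are assumed, not taken from this commit):

    # With gradient_accumulation_steps > 1, a manual zero_grad() between
    # micro-steps is skipped with a one-time log message; force=True overrides.
    for batch in data_loader:
        loss = engine(batch)  # assumes the wrapped model returns the loss
        engine.backward(loss)
        engine.step()          # on a boundary, step() already clears grads via zero_grad(force=True)
        engine.zero_grad()     # no-op off-boundary; clears grads only at an accumulation boundary
        # engine.zero_grad(force=True)  # always clears grads, regardless of the boundary
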
31 changes: 30 additions & 1 deletion deepspeed/runtime/utils.py
@@ -8,7 +8,7 @@
 Helper functions and classes from multiple sources.
 """
 
-from collections.abc import Iterable
+from collections.abc import Iterable, Callable
 from deepspeed.moe.utils import is_moe_param
 import os
 import psutil
@@ -1065,3 +1065,32 @@ def to_tensor(v):
         total_norm = -1
 
     return total_norm
+
+
+warn_zero_grad_shown = False
+
+
+def warn_zero_grad() -> None:
+    global warn_zero_grad_shown
+    if not warn_zero_grad_shown:
+        msg = "zero_grad() was called but gradients are not cleared because " \
+              "the current iteration is not a gradient accumulation boundary. " \
+              "If you want to clear gradients, please set force=True."
+        logger.info(msg)
+        warn_zero_grad_shown = True
+    return
+
+
+def zero_grad_params(params: Iterable[torch.nn.Parameter], set_to_none_fn: Callable[[torch.Tensor], None],
+                     is_gradient_accumulation_boundary: bool, set_to_none: bool, force: bool) -> None:
+    if not is_gradient_accumulation_boundary and not force:
+        warn_zero_grad()
+        return
+
+    for param in params:
+        if set_to_none:
+            set_to_none_fn(param)
+        else:
+            if param.grad is not None:
+                param.grad.detach_()
+                param.grad.zero_()
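
The helper's two paths can be seen in isolation. A small hedged sketch (toy tensor; assumes a working deepspeed installation so that zero_grad_params resolves):

    import torch
    from deepspeed.runtime.utils import zero_grad_params

    p = torch.nn.Parameter(torch.ones(4))
    p.grad = torch.ones(4)

    def set_to_none_fn(param):
        param.grad = None

    # Not a boundary and force=False: warn_zero_grad() logs once, grads are kept.
    zero_grad_params([p], set_to_none_fn, is_gradient_accumulation_boundary=False,
                     set_to_none=True, force=False)
    assert p.grad is not None

    # force=True: set_to_none_fn runs and the gradient is dropped.
    zero_grad_params([p], set_to_none_fn, is_gradient_accumulation_boundary=False,
                     set_to_none=True, force=True)
    assert p.grad is None
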
41 changes: 19 additions & 22 deletions deepspeed/runtime/zero/stage3.py
@@ -103,7 +103,7 @@ def unwrap_model_for_generation(model):
     return
 
 
-INITIAL_MICRO_STEP_ID = -1
+INITIAL_MICRO_STEP_ID = 0
 
 
 class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
@@ -293,7 +293,8 @@ def __init__(
         self.gradient_predivide_factor = gradient_predivide_factor
         self.postscale_gradients = postscale_gradients
         self.gradient_accumulation_steps = gradient_accumulation_steps
-        self.micro_step_id = 0
+        self.micro_step_id = INITIAL_MICRO_STEP_ID
+        self.force_overwrite_grads = False
         self.reduce_bucket_size = int(reduce_bucket_size)
 
         if self.all2all_process_group is not None:
@@ -1463,7 +1464,7 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L
             # move or accumulate gradient partition to target buffer
             grad_buffer = self.__param_id_to_grad_partition[param.ds_id].narrow(0, 0, grad_partition.numel())
             buffers.append(grad_buffer)
-            if self.micro_step_id == 0:  # don't accumulate
+            if self.micro_step_id == 0 or self.force_overwrite_grads:  # don't accumulate
                 grad_buffer.copy_(grad_partition, non_blocking=True)
                 # ensure grad buffer is a CUDA buffer to speed up the next few
                 # operations and so it can be used asynchronously
@@ -1504,6 +1505,8 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L
                 param.grad.record_stream(get_accelerator().current_stream())
             param.grad = None
 
+        self.force_overwrite_grads = False
+
         if self.offload_optimizer and self.swap_optimizer:
             for i in offload_fp32_gradients.keys():
                 self.optimizer_swapper.swap_out_gradients(parameter=self.fp32_partitioned_groups_flat[i],
@@ -1719,24 +1722,18 @@ def get_partition_info(self, tensor_list, partition_size, partition_id):
         return params_in_partition, params_not_in_partition, first_offset
 
     @instrument_w_nvtx
-    def zero_grad(self, set_to_none=True):
-        """
-        Zero FP16 parameter grads.
-        """
-        self.micro_step_id = 0
-
-        # FP32 grad should never exist.
-        # For speed, set model fp16 grad to None by default
-        for group in self.fp16_groups:
-            for p in group:
-                if set_to_none:
-                    if p.grad is not None and get_accelerator().on_accelerator(p.grad):
-                        p.grad.record_stream(get_accelerator().current_stream())
-                    p.grad = None
-                else:
-                    if p.grad is not None:
-                        p.grad.detach_()
-                        p.grad.zero_()
+    def zero_grad(self, set_to_none=True, force=False):
+
+        def set_grad_to_none(p):
+            if p.grad is not None and get_accelerator().on_accelerator(p.grad):
+                p.grad.record_stream(get_accelerator().current_stream())
+            p.grad = None
+
+        params = [p for group in self.fp16_groups for p in group]
+        self._do_zero_grad(params, set_grad_to_none, set_to_none, force)
+
+        # Flag to indicate that the reduced gradients should be copied to the buffer, not accumulated
+        self.force_overwrite_grads = True
 
     def _model_parallel_all_reduce(self, tensor, op):
         """ Perform all reduce within model parallel group, if any.
@@ -1856,7 +1853,7 @@ def reset_cpu_buffers(self):
         self.norm_for_param_grads = {}
 
     def _pre_step(self):
-        self.micro_step_id = 0
+        self.micro_step_id = INITIAL_MICRO_STEP_ID
 
         print_rank_0(f"Inside Step function")
         see_memory_usage(f"In step before checking overflow", force=False)
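
For ZeRO-3 the reduced gradients live in flat partition buffers, so a forced zero_grad() does not zero those buffers directly; zero_grad() sets force_overwrite_grads, and the next partition_grads() call copies the freshly reduced partition into the buffer instead of accumulating, then clears the flag. A condensed, paraphrased view of that decision (the accumulate branch is simplified here; the real code also handles offloaded/CPU buffers):

    # Paraphrase of the partition_grads() branch shown above, not literal source.
    if self.micro_step_id == 0 or self.force_overwrite_grads:
        grad_buffer.copy_(grad_partition, non_blocking=True)  # start a fresh accumulation window
    else:
        grad_buffer.add_(grad_partition)  # accumulate into the existing partition (simplified)
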
21 changes: 8 additions & 13 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -1633,22 +1633,17 @@ def get_partition_info(self, tensor_list, partition_size, partition_id):
 
         return params_in_partition, params_not_in_partition, first_offset
 
-    def zero_grad(self, set_to_none=True):
+    def zero_grad(self, set_to_none=True, force=False):
         """
         Zero FP16 parameter grads.
         """
-        # FP32 grad should never exist.
-        # For speed, set model fp16 grad to None by default
-        # zero all pointers to grad tensors
-        for group in self.bit16_groups:
-            for p in group:
-                if set_to_none:
-                    p.grad = None  # epilogue and in step
-                    p.grad_accum = None
-                else:
-                    if p.grad is not None:
-                        p.grad.detach_()
-                        p.grad.zero_()
+
+        def set_grad_to_none(p):
+            p.grad = None  # epilogue and in step
+            p.grad_accum = None
+
+        params = [p for group in self.bit16_groups for p in group]
+        self._do_zero_grad(params, set_grad_to_none, set_to_none, force)
 
     def _model_parallel_all_reduce(self, tensor, op):
         """ Perform all reduce within model parallel group, if any.
