diff --git a/composer/algorithms/gradient_clipping/gradient_clipping.py b/composer/algorithms/gradient_clipping/gradient_clipping.py
index accb88daf6..973d12914a 100644
--- a/composer/algorithms/gradient_clipping/gradient_clipping.py
+++ b/composer/algorithms/gradient_clipping/gradient_clipping.py
@@ -14,7 +14,6 @@
 from composer.core import Algorithm, Event, State
 from composer.loggers import Logger
 from composer.models import ComposerModel
-from composer.utils import using_torch_2
 
 log = logging.getLogger(__name__)
 
@@ -44,38 +43,24 @@ def apply_gradient_clipping(model: Union[ComposerModel, torch.nn.Module], clippi
             raise RuntimeError('To use FSDP with Composer, you must use torch>=1.13.0.')
         from torch.distributed.fsdp import FullyShardedDataParallel
-        is_torch_2_0 = using_torch_2()
-
         for module in model.modules():
-            if isinstance(module, FullyShardedDataParallel):
-                # We can only call grad clip on the parent instance, so we iterate through all
-                # modules and try grad clipping and FSDP will throw an exception if we
-                # clip any gradients that aren't a parent module
-                try:
-                    if clipping_type == 'norm':
-                        module.clip_grad_norm_(max_norm=clipping_threshold)
-                    elif clipping_type == 'value':
-                        module.clip_grad_norm_(max_norm=clipping_threshold, norm_type=float('inf'))
-                    else:
-                        raise ValueError(f"clipping type must be 'norm' or 'value' with FSDP not {clipping_type}")
-                except (AssertionError, RuntimeError) as e:
-                    if (('clip_grad_norm should only be called on the root (parent) instance' == str(e) and
-                         not is_torch_2_0) or
-                        ('`clip_grad_norm_()` should only be called on the root FSDP instance' == str(e) and
-                         is_torch_2_0)):
-                        continue
-                    else:
-                        raise
-        return
-
-    parameters = model.parameters()
-    if clipping_type == 'adaptive':
-        _apply_agc(parameters, clipping_threshold=clipping_threshold)
-    elif clipping_type == 'norm':
-        torch.nn.utils.clip_grad_norm_(parameters, max_norm=clipping_threshold)
-    elif clipping_type == 'value':
-        torch.nn.utils.clip_grad_value_(parameters, clip_value=clipping_threshold)
+            if isinstance(module, FullyShardedDataParallel) and module.check_is_root():
+                if clipping_type == 'norm':
+                    module.clip_grad_norm_(max_norm=clipping_threshold)
+                elif clipping_type == 'value':
+                    module.clip_grad_norm_(max_norm=clipping_threshold, norm_type=float('inf'))
+                else:
+                    raise ValueError(f"clipping type must be 'norm' or 'value' with FSDP not {clipping_type}")
     else:
-        raise ValueError(f"clipping_type must be 'adaptive', 'norm', or 'value' not {clipping_type} ")
+        parameters = model.parameters()
+        if clipping_type == 'adaptive':
+            _apply_agc(parameters, clipping_threshold=clipping_threshold)
+        elif clipping_type == 'norm':
+            torch.nn.utils.clip_grad_norm_(parameters, max_norm=clipping_threshold)
+        elif clipping_type == 'value':
+            torch.nn.utils.clip_grad_value_(parameters, clip_value=clipping_threshold)
+        else:
+            raise ValueError(f"clipping_type must be 'adaptive', 'norm', or 'value' not {clipping_type} ")
 
 
 def _apply_agc(
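
A minimal usage sketch of the rewritten non-FSDP branch. The clipping_type and clipping_threshold names come from the function body above; the fsdp_enabled keyword is an assumption (the signature is truncated in the hunk header), and the model, loss, and threshold values are illustrative only:

    # Sketch: clip gradients of a plain (non-FSDP) torch module after a backward pass.
    import torch

    from composer.algorithms.gradient_clipping.gradient_clipping import apply_gradient_clipping

    model = torch.nn.Linear(8, 8)
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()

    # With FSDP disabled, 'norm' dispatches to torch.nn.utils.clip_grad_norm_,
    # 'value' to torch.nn.utils.clip_grad_value_, and 'adaptive' to _apply_agc.
    # fsdp_enabled is an assumed parameter name, not visible in the truncated signature above.
    apply_gradient_clipping(model, clipping_type='norm', clipping_threshold=1.0, fsdp_enabled=False)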