diff --git a/composer/core/precision.py b/composer/core/precision.py
index ea08a10c56..bb91fc64d1 100644
--- a/composer/core/precision.py
+++ b/composer/core/precision.py
@@ -40,6 +40,7 @@ class Precision(StringEnum):
 def get_precision_context(
     precision: Union[str, Precision],
     precision_config: Optional[dict[str, Any]] = None,
+    fp8_autocast_enabled: bool = True,
 ) -> Generator[None, None, None]:
     """Returns a context manager to automatically cast to a specific precision.
 
@@ -47,6 +48,7 @@ def get_precision_context(
         precision (str | Precision): Precision for the context
         precision_config (Optional[dict[str, Any]]): Config for FP8 scaling strategy. See parameters for
             `DelayedScaling `_.
+        fp8_autocast_enabled (bool): Whether to enable FP8 autocast. Defaults to True.
     """
     precision = Precision(precision)
     if precision == Precision.FP32:
@@ -86,7 +88,7 @@ def get_precision_context(
                 'amax_compute_algo': 'max',
             }
             fp8_recipe = DelayedScaling(**precision_config)
-            with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+            with te.fp8_autocast(enabled=fp8_autocast_enabled, fp8_recipe=fp8_recipe):
                 # The te.onnx_export flag ensures that we save all fp8 buffers
                 # as tensors instead of bytes. This is necessary for proper
                 # saving and resumption of checkpoints.
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index b410e8aa96..b497ef669b 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -460,10 +460,15 @@ def _get_ddp_sync_strategy(ddp_sync_strategy: Optional[Union[str, DDPSyncStrateg
     return ddp_sync_strategy
 
 
-def _get_precision_context(precision: Precision, precision_config: Optional[dict[str, Any]], deepspeed_enabled: bool):
+def _get_precision_context(
+    precision: Precision,
+    precision_config: Optional[dict[str, Any]],
+    deepspeed_enabled: bool,
+    fp8_autocast_enabled: bool = True,
+):
     if deepspeed_enabled:
         return contextlib.nullcontext()
-    return get_precision_context(precision, precision_config)
+    return get_precision_context(precision, precision_config, fp8_autocast_enabled)
 
 
 def _generate_run_name() -> str:
@@ -2672,10 +2677,15 @@ def _train_loop(self) -> None:
     def _eval_train_metrics(self, device_batch):
         assert self._train_data_spec is not None, 'The train data spec should be set on __init__ or fit()'
         assert self.state.train_metrics is not None, 'The train metrics should be set on __init__ or fit()'
-
+        # We disable FP8 autocast in eval metrics and default to the activation dtype for the forward pass
+        # This is because FP8 in TE requires all eval data sizes to be divisible by 16 which does not hold for all evaluation datasets.
+        # See https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html for more info.
+        # Note: the activation dtype is BF16 if FSDP Mixed Precision PURE is enabled and FP32 if FSDP Mixed Precision FULL is enabled.
+        # See https://github.com/NVIDIA/TransformerEngine/blob/8e039fdcd98fc56582d81e373880c1509c2b8f73/transformer_engine/pytorch/module/linear.py#L250-L252 and \
+        # https://github.com/NVIDIA/TransformerEngine/blob/8e039fdcd98fc56582d81e373880c1509c2b8f73/transformer_engine/pytorch/module/base.py#L495-L513 for more info.
         with torch.no_grad(),\
                 model_eval_mode(self.state.model),\
-                _get_precision_context(self.state.precision, self.state.precision_config, self.state.deepspeed_enabled):
+                _get_precision_context(self.state.precision, self.state.precision_config, self.state.deepspeed_enabled, fp8_autocast_enabled=False):
             eval_outputs = self._original_model.eval_forward(device_batch, self.state.outputs)
             for metric in self.state.train_metrics.values():
                 self._original_model.update_metric(
@@ -3470,11 +3480,17 @@ def _eval_loop(
                     )[0]
 
                     self.engine.run_event(Event.EVAL_BEFORE_FORWARD)
-
+                    # We disable FP8 autocast in eval mode and default to the activation dtype for the forward pass
+                    # This is because FP8 in TE requires all eval data sizes to be divisible by 16 which does not hold for all evaluation datasets.
+                    # See https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html for more info.
+                    # Note: the activation dtype is BF16 if FSDP Mixed Precision PURE is enabled and FP32 if FSDP Mixed Precision FULL is enabled.
+                    # See https://github.com/NVIDIA/TransformerEngine/blob/8e039fdcd98fc56582d81e373880c1509c2b8f73/transformer_engine/pytorch/module/linear.py#L250-L252 and \
+                    # https://github.com/NVIDIA/TransformerEngine/blob/8e039fdcd98fc56582d81e373880c1509c2b8f73/transformer_engine/pytorch/module/base.py#L495-L513 for more info.
                     with _get_precision_context(
                         self.state.precision,
                         self.state.precision_config,
                         self.state.deepspeed_enabled,
+                        fp8_autocast_enabled=False,
                     ):
                         self.state.outputs = self._original_model.eval_forward(self.state.batch)
 
diff --git a/tests/fixtures/fixtures.py b/tests/fixtures/fixtures.py
index f587079073..c4dd3fa65f 100644
--- a/tests/fixtures/fixtures.py
+++ b/tests/fixtures/fixtures.py
@@ -14,7 +14,7 @@
 from composer.core import State
 from composer.devices import DeviceCPU, DeviceGPU
 from composer.loggers import Logger
-from composer.utils import dist
+from composer.utils import dist, retry
 from tests.common import RandomClassificationDataset, SimpleModel
 from tests.conftest import _get_option
 
@@ -310,6 +310,7 @@ def _session_tiny_t5_config():  # type: ignore
     return tiny_t5_config_helper()
 
 
+@retry(num_attempts=3)
 def tiny_t5_tokenizer_helper():
     transformers = pytest.importorskip('transformers')
 
diff --git a/tests/trainer/test_trainer_eval.py b/tests/trainer/test_trainer_eval.py
index b548efde81..9a2d8d6ab4 100644
--- a/tests/trainer/test_trainer_eval.py
+++ b/tests/trainer/test_trainer_eval.py
@@ -92,6 +92,36 @@ def test_eval_with_nondivisible_dataset(world_size: int, size: int, batch_size:
     assert count.item() == size
 
 
+from unittest.mock import patch
+
+
+@pytest.mark.gpu
+def test_amp_fp8_eval_casts_to_bf16():
+    # Check that we can import FP8 with TE. If not, skip this test.
+    try:
+        import transformer_engine  # pyright: ignore
+    except ImportError:
+        pytest.skip('Precision amp_fp8 requires transformer-engine to be installed',)
+
+    # Mocking the transformer_engine.pytorch.fp8_autocast and running model eval.
+    with patch('transformer_engine.pytorch.fp8_autocast') as mock_fp8_autocast:
+        # Construct the trainer
+        trainer = Trainer(model=SimpleModel(), device='gpu', precision='amp_fp8')
+        # Evaluate the model
+        dataset = RandomClassificationDataset()
+        trainer.eval(eval_dataloader=DataLoader(
+            dataset=dataset,
+            batch_size=10,
+            sampler=dist.get_sampler(dataset),
+        ),)
+
+        # Check that te.fp8_autocast was called with enabled=False.
+        # This ensures that we disable the FP8 context on eval.
+        actual_call = mock_fp8_autocast.call_args_list[0]
+        actual_call_args = actual_call._get_call_arguments()[1]
+        assert actual_call_args['enabled'] is False
+
+
 def test_eval_call_with_trainer_evaluators():
     trainer_dataset = RandomClassificationDataset()
     trainer_evaluator = Evaluator(
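
For reference, a minimal sketch (not part of the patch) of how the new fp8_autocast_enabled flag added in composer/core/precision.py can be exercised outside the trainer. It assumes an FP8-capable GPU with transformer-engine installed; the model and tensor shapes are illustrative only.

# Sketch only: assumes an FP8-capable GPU and transformer-engine installed,
# since the 'amp_fp8' precision is otherwise unavailable.
import torch

from composer.core.precision import get_precision_context

model = torch.nn.Linear(16, 16).cuda()
x = torch.randn(32, 16, device='cuda')

# Training-style forward: FP8 autocast stays enabled (the default).
with get_precision_context('amp_fp8'):
    y = model(x)

# Eval-style forward: keep the same precision plumbing but turn FP8 autocast off,
# mirroring what _eval_train_metrics and _eval_loop now do via
# _get_precision_context(..., fp8_autocast_enabled=False).
with torch.no_grad(), get_precision_context('amp_fp8', fp8_autocast_enabled=False):
    y = model(x)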