Change FP8 Eval to default to activation dtype #3454

Merged: 87 commits merged on Jul 11, 2024
Commits (87 total; the diff below shows changes from 86 commits)
76dacf5
commit change
Jun 17, 2024
8e62837
commit change
Jun 19, 2024
4c92aa1
commit change
Jul 2, 2024
a0ccee1
commit change
Jul 2, 2024
1032f21
commit change
Jul 2, 2024
8693638
commit change
Jul 3, 2024
5cf97dc
Bump coverage[toml] from 7.5.3 to 7.5.4 (#3422)
dependabot[bot] Jun 24, 2024
e8d9d67
Update psutil requirement from <6,>=5.8.0 to >=5.8.0,<7 (#3424)
dependabot[bot] Jun 24, 2024
c07e949
Add support for variable length dataloaders in DDP (#3416)
JAEarly Jun 24, 2024
87798e0
Hsdp + MoE CI tests (#3378)
KuuCi Jun 24, 2024
17bc140
bumping mlflow to 2.14.1 (#3425)
JackZ-db Jun 25, 2024
66620df
Skip HSDP + TP pytests that require torch 2.3 or above (#3426)
KuuCi Jun 25, 2024
b0c85b8
remove codeql (#3429)
mvpatel2000 Jun 26, 2024
15e6f62
Remove save overwrite (#3431)
mvpatel2000 Jun 27, 2024
17c6c68
LeDocs (#3430)
snarayan21 Jun 28, 2024
cfb66ff
Lower the system metrics logging frequency to reduce MLflow server's …
chenmoneygithub Jun 28, 2024
db4ff97
Update paramiko requirement from <3,>=2.11.0 to >=3.4.0,<4 (#3439)
dependabot[bot] Jul 1, 2024
8801127
bump versions (#3433)
mvpatel2000 Jul 1, 2024
64a3182
fix eval after all (#3445)
mvpatel2000 Jul 1, 2024
385a129
skip log (#3446)
mvpatel2000 Jul 1, 2024
0f06eb1
Remove MosaicMLLambdaEvalClient (#3432)
aspfohl Jul 1, 2024
76d39d5
Relax hf hub pin (#3435)
dakinggg Jul 1, 2024
a2fdfe9
Pytest skip 2 (#3448)
KuuCi Jul 2, 2024
47ffb06
bump version (#3450)
XiaohanZhangCMU Jul 2, 2024
3cfffcb
commit change
Jul 5, 2024
c1463dd
commit change
Jul 5, 2024
8e34d73
Merge branch 'dev' into chuck/fix_pytorch_patch
j316chuck Jul 5, 2024
bfc9dae
commit change
Jul 6, 2024
bfbad84
commit change
Jul 6, 2024
c3350a1
commit change
Jul 6, 2024
ebe5fa1
commit change
Jul 9, 2024
d3378af
Merge branch 'dev' into chuck/fix_pytorch_patch
mvpatel2000 Jul 9, 2024
6d16995
commit change
Jul 9, 2024
d84b047
commit change
Jul 9, 2024
6bfd1d5
commit change
Jul 9, 2024
fe7ef4c
commit change
Jul 9, 2024
db2bee8
ok
Jul 9, 2024
288c5a3
commit change
Jul 9, 2024
a744648
commit change
Jul 9, 2024
6b00538
commit change
Jul 9, 2024
876b882
commit change
Jul 9, 2024
a9ecbb8
commit change
Jul 9, 2024
e1f2c7f
commit change
Jul 9, 2024
10ec772
commit change
Jul 10, 2024
efd048f
commit change
Jul 10, 2024
b64f614
commit change
Jul 10, 2024
5aa85e1
commit change
Jul 10, 2024
6598b7e
commit change
Jul 10, 2024
d8e5276
commit change
Jul 10, 2024
c00aa37
commit change
Jul 10, 2024
79a70a5
commit change
Jul 10, 2024
72b8f80
commit change
Jul 10, 2024
06064c7
commit change
Jul 10, 2024
a0e8400
commit change
Jul 10, 2024
bd5b727
commit change
Jul 10, 2024
79fad52
commit change
Jul 10, 2024
ff8cc01
commit change
Jul 11, 2024
4528f4d
commit change
Jul 11, 2024
32aa6c8
commit change
Jul 11, 2024
4eef8c9
commit change
Jul 11, 2024
88887a0
commit change
Jul 11, 2024
c5b8c0d
commit change
Jul 11, 2024
9d2a0e4
commit change
Jul 11, 2024
0b21acd
commit change
Jul 11, 2024
51ad2c9
commit change
Jul 11, 2024
fcf22b0
commit change
Jul 11, 2024
f9cbb55
commit change
Jul 11, 2024
2f60a7e
commit change
Jul 11, 2024
6502008
commit change
Jul 11, 2024
a827f87
commit change
Jul 11, 2024
0dd800e
commit change
Jul 11, 2024
32dc7e1
commit change
Jul 11, 2024
433e16d
commit change
Jul 11, 2024
9a88e38
commit change
Jul 11, 2024
fad6f88
commit change
Jul 11, 2024
584b225
commit change
Jul 11, 2024
d2a503c
commit change
Jul 11, 2024
ba0d23b
commit change
Jul 11, 2024
38f3b7a
commit change
Jul 11, 2024
94cfda8
commit change
Jul 11, 2024
7e57c7c
commit change
Jul 11, 2024
9ca5c5a
commit change
Jul 11, 2024
22a941a
commit change
Jul 11, 2024
de4d313
commit change
Jul 11, 2024
3c585a5
commit change
Jul 11, 2024
228051e
Merge branch 'dev' into chuck/fix_pytorch_patch
j316chuck Jul 11, 2024
5315d79
commit change
Jul 11, 2024
composer/core/precision.py (4 changes: 3 additions & 1 deletion)
@@ -40,13 +40,15 @@ class Precision(StringEnum):
def get_precision_context(
precision: Union[str, Precision],
precision_config: Optional[dict[str, Any]] = None,
fp8_autocast_enabled: bool = True,
) -> Generator[None, None, None]:
"""Returns a context manager to automatically cast to a specific precision.

Args:
precision (str | Precision): Precision for the context
precision_config (Optional[dict[str, Any]]): Config for FP8 scaling strategy. See parameters for
`DelayedScaling <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html?highlight=delayedscaling#transformer_engine.common.recipe.DelayedScaling>`_.
fp8_autocast_enabled (bool): Whether to enable FP8 autocast. Defaults to True.
"""
precision = Precision(precision)
if precision == Precision.FP32:
@@ -86,7 +88,7 @@ def get_precision_context(
'amax_compute_algo': 'max',
}
fp8_recipe = DelayedScaling(**precision_config)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
with te.fp8_autocast(enabled=fp8_autocast_enabled, fp8_recipe=fp8_recipe):
# The te.onnx_export flag ensures that we save all fp8 buffers
# as tensors instead of bytes. This is necessary for proper
# saving and resumption of checkpoints.
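For reference, a minimal sketch of the new flag at a call site (not part of this PR's diff; it assumes a GPU with FP8 support and transformer-engine installed, and the tiny te.Linear model and random batches are placeholders only): the same precision context is requested in both cases, but passing fp8_autocast_enabled=False makes TE modules run in the activation dtype instead of FP8.

import torch
import transformer_engine.pytorch as te

from composer.core import Precision, get_precision_context

model = te.Linear(64, 64).cuda()  # placeholder TE module; only TE modules are affected by FP8 autocast
train_batch = torch.randn(32, 64, device='cuda')  # 32 rows: compatible with FP8 GEMMs
eval_batch = torch.randn(10, 64, device='cuda')   # 10 rows: a typical leftover eval batch, not divisible by 16

# Training-style forward: FP8 autocast enabled (the default), matching the previous behavior.
with get_precision_context(Precision.AMP_FP8, precision_config=None, fp8_autocast_enabled=True):
    out = model(train_batch)

# Eval-style forward: FP8 autocast disabled, so the pass runs in the activation dtype and
# the batch shape no longer has to satisfy TE's FP8 divisibility requirements.
with torch.no_grad(), get_precision_context(Precision.AMP_FP8, precision_config=None, fp8_autocast_enabled=False):
    out = model(eval_batch)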
composer/trainer/trainer.py (26 changes: 21 additions & 5 deletions)
@@ -460,10 +460,15 @@ def _get_ddp_sync_strategy(ddp_sync_strategy: Optional[Union[str, DDPSyncStrateg
return ddp_sync_strategy


def _get_precision_context(precision: Precision, precision_config: Optional[dict[str, Any]], deepspeed_enabled: bool):
def _get_precision_context(
precision: Precision,
precision_config: Optional[dict[str, Any]],
deepspeed_enabled: bool,
fp8_autocast_enabled: bool = True,
):
if deepspeed_enabled:
return contextlib.nullcontext()
return get_precision_context(precision, precision_config)
return get_precision_context(precision, precision_config, fp8_autocast_enabled)


def _generate_run_name() -> str:
@@ -2672,10 +2677,15 @@ def _train_loop(self) -> None:
def _eval_train_metrics(self, device_batch):
assert self._train_data_spec is not None, 'The train data spec should be set on __init__ or fit()'
assert self.state.train_metrics is not None, 'The train metrics should be set on __init__ or fit()'

# We disable FP8 autocast in eval metrics and default to the activation dtype for the forward pass
# This is because FP8 in TE requires all eval data sizes to be divisible by 16 which does not hold for all evaluation datasets.
# See https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html for more info.
# Note: the activation dtype is BF16 if FSDP Mixed Precision PURE is enabled and FP32 if FSDP Mixed Precision FULL is enabled.
# See https://github.com/NVIDIA/TransformerEngine/blob/8e039fdcd98fc56582d81e373880c1509c2b8f73/transformer_engine/pytorch/module/linear.py#L250-L252 and \
# https://github.com/NVIDIA/TransformerEngine/blob/8e039fdcd98fc56582d81e373880c1509c2b8f73/transformer_engine/pytorch/module/base.py#L495-L513 for more info.
with torch.no_grad(),\
model_eval_mode(self.state.model),\
_get_precision_context(self.state.precision, self.state.precision_config, self.state.deepspeed_enabled):
_get_precision_context(self.state.precision, self.state.precision_config, self.state.deepspeed_enabled, fp8_autocast_enabled=False):
eval_outputs = self._original_model.eval_forward(device_batch, self.state.outputs)
for metric in self.state.train_metrics.values():
self._original_model.update_metric(
@@ -3470,11 +3480,17 @@ def _eval_loop(
)[0]

self.engine.run_event(Event.EVAL_BEFORE_FORWARD)

# We disable FP8 autocast in eval mode and default to the activation dtype for the forward pass
# This is because FP8 in TE requires all eval data sizes to be divisible by 16 which does not hold for all evaluation datasets.
# See https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html for more info.
# Note: the activation dtype is BF16 if FSDP Mixed Precision PURE is enabled and FP32 if FSDP Mixed Precision FULL is enabled.
# See https://github.com/NVIDIA/TransformerEngine/blob/8e039fdcd98fc56582d81e373880c1509c2b8f73/transformer_engine/pytorch/module/linear.py#L250-L252 and \
# https://github.com/NVIDIA/TransformerEngine/blob/8e039fdcd98fc56582d81e373880c1509c2b8f73/transformer_engine/pytorch/module/base.py#L495-L513 for more info.
with _get_precision_context(
self.state.precision,
self.state.precision_config,
self.state.deepspeed_enabled,
fp8_autocast_enabled=False,
):
self.state.outputs = self._original_model.eval_forward(self.state.batch)

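The comments above lean on a shape constraint from transformer-engine; here is a short hedged illustration using the raw TE API (not from this PR; the exact divisibility rules may vary across TE versions, and a GPU with transformer-engine installed is assumed).

import torch
import transformer_engine.pytorch as te

layer = te.Linear(32, 32).cuda()
odd_batch = torch.randn(10, 32, device='cuda')  # e.g. the last eval batch of an odd-sized dataset

# What the eval loop now does: autocast disabled, so the forward pass runs in the
# activation dtype and any batch size is accepted.
with torch.no_grad(), te.fp8_autocast(enabled=False):
    out = layer(odd_batch)

# What it used to do: autocast enabled. TE rejects shapes that violate its FP8
# divisibility requirements, so this 10-row batch would raise an error.
# with te.fp8_autocast(enabled=True):
#     out = layer(odd_batch)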
tests/fixtures/fixtures.py (3 changes: 2 additions & 1 deletion)
@@ -14,7 +14,7 @@
from composer.core import State
from composer.devices import DeviceCPU, DeviceGPU
from composer.loggers import Logger
from composer.utils import dist
from composer.utils import dist, retry
from tests.common import RandomClassificationDataset, SimpleModel
from tests.conftest import _get_option

@@ -310,6 +310,7 @@ def _session_tiny_t5_config(): # type: ignore
return tiny_t5_config_helper()


@retry(num_attempts=3)
def tiny_t5_tokenizer_helper():
transformers = pytest.importorskip('transformers')

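For context on the fixture change above, a small sketch of composer.utils.retry (illustrative only; flaky_tokenizer_download is a hypothetical stand-in, not code from this PR): the decorator re-runs the wrapped callable if it raises, which smooths over transient Hugging Face Hub failures when building the tiny T5 tokenizer in CI.

from composer.utils import retry

@retry(num_attempts=3)  # same usage as in the fixture above
def flaky_tokenizer_download():
    # hypothetical body standing in for the transformers call inside tiny_t5_tokenizer_helper
    ...

flaky_tokenizer_download()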
tests/trainer/test_trainer_eval.py (30 changes: 30 additions & 0 deletions)
@@ -92,6 +92,36 @@ def test_eval_with_nondivisible_dataset(world_size: int, size: int, batch_size:
assert count.item() == size


from unittest.mock import patch


@pytest.mark.gpu
def test_amp_fp8_eval_casts_to_bf16():
# Check that we can import FP8 with TE. If not, skip this test.
try:
import transformer_engine # pyright: ignore
except ImportError:
pytest.skip('Precision amp_fp8 requires transformer-engine to be installed',)

# Mocking the transformer_engine.pytorch.fp8_autocast and running model eval.
with patch('transformer_engine.pytorch.fp8_autocast') as mock_fp8_autocast:
# Construct the trainer
trainer = Trainer(model=SimpleModel(), device='gpu', precision='amp_fp8')
# Evaluate the model
dataset = RandomClassificationDataset()
trainer.eval(eval_dataloader=DataLoader(
dataset=dataset,
batch_size=10,
sampler=dist.get_sampler(dataset),
),)

# Check that te.fp8_autocast was called with enabled=False.
# This ensures that we disable the FP8 context on eval.
actual_call = mock_fp8_autocast.call_args_list[0]
actual_call_args = actual_call._get_call_arguments()[1]
assert actual_call_args['enabled'] is False


def test_eval_call_with_trainer_evaluators():
trainer_dataset = RandomClassificationDataset()
trainer_evaluator = Evaluator(
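One small aside on the assertion in the new test (a sketch against the standard unittest.mock API, not a change to this PR): mock call objects expose their keyword arguments directly, so the check can avoid the private _get_call_arguments helper.

first_call = mock_fp8_autocast.call_args_list[0]  # 'mock_fp8_autocast' is the patched object from the test above
assert first_call.kwargs['enabled'] is False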