Optionally use flash-attn's CE loss for metrics (#3394)
* yo

* slam

* cuda

* cuda checks

* test

* fix_test

* gloo

* gloo

* lint

* lint

---------

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
3 people authored Jun 17, 2024
1 parent 6023fe5 commit 2cf9262
Showing 5 changed files with 118 additions and 4 deletions.
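In short, this commit teaches `LanguageCrossEntropy` to look for `flash_attn`'s fused `CrossEntropyLoss` at construction time and to use it in place of `torch.nn.CrossEntropyLoss` whenever the metric's inputs are CUDA tensors, falling back to the torch implementation otherwise. A minimal sketch of the resulting behavior (shapes and vocabulary size are illustrative, and the GPU branch assumes `flash-attn` is installed):

```python
import torch
from composer.metrics.nlp import LanguageCrossEntropy

metric = LanguageCrossEntropy(ignore_index=-100)

# CPU tensors always go through torch.nn.CrossEntropyLoss.
logits = torch.randn(8, 128, 1024)         # (batch, seq_len, vocab)
labels = torch.randint(0, 1024, (8, 128))  # (batch, seq_len)
metric.update(logits, labels)
print(metric.compute())  # mean loss over non-ignored tokens

# CUDA tensors use flash-attn's fused kernel when the import succeeded.
if torch.cuda.is_available() and metric.flash_loss_fn is not None:
    metric.reset()
    metric = metric.to('cuda')
    metric.update(logits.cuda(), labels.cuda())
    print(metric.compute())
```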
2 changes: 1 addition & 1 deletion .github/workflows/pr-cpu.yaml
@@ -22,7 +22,7 @@ jobs:
       markers: not daily and not remote and not gpu and not doctest
       pytest_command: coverage run -m pytest
     - name: cpu-3.11-2.3
-      container: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04
+      container: mosaicml/pytorch:2.3.1_cpu-python3.11-ubuntu20.04
       markers: not daily and not remote and not gpu and not doctest
       pytest_command: coverage run -m pytest
     - name: cpu-doctest
3 changes: 3 additions & 0 deletions composer/devices/device_gpu.py
@@ -12,6 +12,7 @@
 import torch.backends.cudnn
 import torch.cuda
 import torch.cuda.amp
+import torch.distributed as torch_dist
 import torch.utils.data
 
 from composer.devices.device import Device
@@ -42,6 +43,8 @@ def __init__(
     ):
         if not torch.cuda.is_available():
             raise ValueError('DeviceGPU cannot be created as torch.cuda is not available.')
+        if torch_dist.is_gloo_available():
+            DeviceGPU.dist_backend = 'cuda:nccl,cpu:gloo'
         if device_id is None:
            device_id = dist.get_local_rank()
         self._device = torch.device(f'cuda:{device_id}')
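The new `dist_backend` value uses PyTorch's device-scoped backend syntax: a single process group can route CUDA collectives to NCCL and CPU collectives to Gloo. A hedged sketch of what that string means once it reaches `torch.distributed.init_process_group` (assuming the script is launched with `torchrun`, so the rank and world-size environment variables are already set):

```python
import torch
import torch.distributed as dist

# One process group, two backends: tensors on CUDA devices are reduced
# with NCCL, tensors on CPU are reduced with Gloo.
dist.init_process_group(backend='cuda:nccl,cpu:gloo')

gpu_tensor = torch.ones(1, device='cuda')
cpu_tensor = torch.ones(1)

dist.all_reduce(gpu_tensor)  # handled by NCCL
dist.all_reduce(cpu_tensor)  # handled by Gloo

dist.destroy_process_group()
```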
22 changes: 20 additions & 2 deletions composer/metrics/nlp.py
@@ -83,7 +83,21 @@ def __init__(self, dist_sync_on_step: bool = False, ignore_index: int = -100):
         super().__init__(dist_sync_on_step=dist_sync_on_step)
 
         self.ignore_index = ignore_index
-        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
+        self.flash_loss_fn = None
+        try:
+            from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss
+            log.debug(
+                'Found `flash_attn` installation. Using CrossEntropyLoss from `flash_attn` ' +
+                'to compute LanguageCrossEntropy metric for CUDA tensors, which will be faster.',
+            )
+            self.flash_loss_fn = FusedCrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
+        except ImportError:
+            if torch.cuda.is_available():
+                log.debug(
+                    'Package `flash_attn` not installed. Using torch.nn.CrossEntropyLoss ' +
+                    'to compute LanguageCrossEntropy metric for CUDA tensors, which will be slower.',
+                )
+        self.torch_loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
         self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum')
         self.add_state('total_items', default=torch.tensor(0), dist_reduce_fx='sum')
 
@@ -104,7 +118,11 @@ def update(self, output: Union[Mapping, Tensor], target: Tensor) -> None:
 
         target = target.view(-1)
         logits = logits.view(target.shape[0], -1)
-        losses = self.loss_fn(logits, target)
+        # Use flash-attn's CE loss function if it is available and both inputs are CUDA tensors.
+        if self.flash_loss_fn is not None and target.is_cuda and logits.is_cuda:
+            losses = self.flash_loss_fn(logits, target)
+        else:
+            losses = self.torch_loss_fn(logits, target)
 
         total_items = (target != self.ignore_index).sum()
         self.total_items += total_items  # type: ignore (third-party)
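For intuition: `flash_attn`'s `CrossEntropyLoss` accepts the same `ignore_index` and `reduction` arguments and should agree with the torch implementation up to floating-point tolerance, which is what the new test below exercises through the metric. A standalone sketch of that equivalence check (assumes `flash-attn` is installed and a CUDA device is available; the sizes are arbitrary):

```python
import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss

torch_ce = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='sum')
flash_ce = FusedCrossEntropyLoss(ignore_index=-100, reduction='sum')

logits = torch.randn(1024, 512, device='cuda')            # (tokens, vocab)
labels = torch.randint(0, 512, (1024,), device='cuda')
labels[torch.rand(1024, device='cuda') < 0.2] = -100      # ignore ~20% of positions

# The fused kernel should match the torch baseline within tolerance.
assert torch.isclose(torch_ce(logits, labels), flash_ce(logits, labels), rtol=1e-3)
```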
6 changes: 5 additions & 1 deletion tests/checkpoint/test_state_dict.py
@@ -7,6 +7,7 @@
 
 import pytest
 import torch
+import torch.distributed as torch_dist
 from packaging import version
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.optim import adam
@@ -530,7 +531,10 @@ def test_get_metadata_sharded_model(model_type: str, tensor_type: str, world_siz
     assert 'model_name' in metadata_sd
 
     assert 'dist_backend' in metadata_sd
-    assert metadata_sd['dist_backend'] == 'nccl'
+    if torch_dist.is_gloo_available():
+        assert metadata_sd['dist_backend'] == 'cuda:nccl,cpu:gloo'
+    else:
+        assert metadata_sd['dist_backend'] == 'nccl'
 
 
 @pytest.mark.filterwarnings('ignore:SWA has')
89 changes: 89 additions & 0 deletions tests/metrics/test_nlp_metrics.py
@@ -14,6 +14,7 @@
     LanguagePerplexity,
     MaskedAccuracy,
 )
+from tests.common import device
 
 
 @pytest.mark.parametrize('ignore_index', [-100])
@@ -50,12 +51,100 @@ def test_masked_accuracy(ignore_index, num_classes):
     assert abs(final_acc - (1.0 / num_classes)) < 0.02
 
 
+@device('cpu', 'gpu')
+@pytest.mark.parametrize('ignore_index', [-100])
+@pytest.mark.parametrize('batch_size', [1e2, 1e3])
+@pytest.mark.parametrize('sequence_length', [128])
+@pytest.mark.parametrize('num_classes', [2, 10])
+@pytest.mark.parametrize('minibatch_size', [56, 256, 768])
+@pytest.mark.parametrize('tensor_device', ['cpu', 'gpu'])
+def test_cross_entropy(
+    device: str,
+    batch_size: float,
+    ignore_index: Optional[int],
+    sequence_length: int,
+    num_classes: int,
+    minibatch_size: int,
+    tensor_device: str,
+):
+    """Sanity check to make sure that batched CrossEntropyLoss matches the expected performance.
+
+    Generates predicted logits from a normal distribution and ground-truth labels from a uniform
+    distribution over the classes, then verifies the cross-entropy loss against the baseline.
+
+    Args:
+        device (str): the device to run the test on
+        batch_size (float): how many samples are in each batch
+        ignore_index (Optional[int]): if present, the class index to ignore in accuracy calculations.
+        sequence_length (int): the length of the generated sequence
+        num_classes (int): the number of classes in the classification task
+        minibatch_size (int): the minibatch size to simulate for model predictions
+        tensor_device (str): which device the input tensors to the metric are on
+    """
+
+    if device == 'cpu' and tensor_device == 'gpu':
+        pytest.skip('Skipping test that would try to use GPU tensors when only CPU is available.')
+
+    batch_size = int(batch_size)
+    generated_preds = torch.randn((batch_size, sequence_length, num_classes))
+    generated_true = torch.randint(low=0, high=num_classes, size=(batch_size, sequence_length))
+
+    assert ignore_index is not None
+    torchmetrics_xent = LanguageCrossEntropy(dist_sync_on_step=False, ignore_index=ignore_index)
+    ce_with_keys_metric = LanguageCrossEntropy(dist_sync_on_step=False, ignore_index=ignore_index)
+
+    if tensor_device == 'cpu':
+        torchmetrics_xent = torchmetrics_xent.to('cpu')
+        ce_with_keys_metric = ce_with_keys_metric.to('cpu')
+    elif tensor_device == 'gpu':
+        torchmetrics_xent = torchmetrics_xent.to('cuda')
+        ce_with_keys_metric = ce_with_keys_metric.to('cuda')
+
+    if device == 'gpu':
+        assert torchmetrics_xent.flash_loss_fn is not None
+
+    labels_mask = torch.rand((batch_size, sequence_length))
+    labels_mask[labels_mask > 0.8] = 1
+    labels_mask[labels_mask <= 0.8] = 0
+    labels_mask = labels_mask.bool()
+    generated_true[labels_mask] = ignore_index
+
+    num_batches = math.ceil(batch_size / minibatch_size)
+    for batch_idx in range(num_batches):
+        begin_idx = (batch_idx * minibatch_size)
+        end_idx = ((batch_idx + 1) * minibatch_size)
+        preds_subset = generated_preds[begin_idx:end_idx]
+        true_subset = generated_true[begin_idx:end_idx]
+
+        if tensor_device == 'cpu':
+            preds_subset = preds_subset.cpu()
+            true_subset = true_subset.cpu()
+        elif tensor_device == 'gpu':
+            preds_subset = preds_subset.cuda()
+            true_subset = true_subset.cuda()
+
+        torchmetrics_xent.update(preds_subset, true_subset)
+        ce_with_keys_metric.update(
+            {
+                'logits': preds_subset.view(-1, num_classes),
+                'loss': cross_entropy(preds_subset.view(-1, num_classes), true_subset.view(-1)),
+            },
+            true_subset.view(-1),
+        )
+
+    torchmetrics_loss = torchmetrics_xent.compute()
+    ce_with_keys_loss = ce_with_keys_metric.compute()
+    correct_loss = cross_entropy(generated_preds.view(-1, num_classes), generated_true.view(-1))
+    assert torchmetrics_loss == ce_with_keys_loss
+    assert torch.isclose(correct_loss, torchmetrics_loss)
+
+
 @pytest.mark.parametrize('ignore_index', [-100])
 @pytest.mark.parametrize('batch_size', [1e2, 1e3])
 @pytest.mark.parametrize('sequence_length', [128])
 @pytest.mark.parametrize('num_classes', [2, 10])
 @pytest.mark.parametrize('minibatch_size', [56, 256, 768])
 def test_torch_cpu_cross_entropy(
     batch_size: float,
     ignore_index: Optional[int],
     sequence_length: int,
