Use cross entropy from apex for improved memory efficiency (#1122)
Summary: Pull Request resolved: fairinternal/fairseq-py#1122

Reviewed By: ngoyal2707

Differential Revision: D20745717

Pulled By: myleott

fbshipit-source-id: 877a1185f17952461ef204d8ad7f05b8d37b1fd9
myleott authored and facebook-github-bot committed Mar 31, 2020
1 parent 4d2efae commit 5065077
Showing 4 changed files with 61 additions and 8 deletions.
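
The memory saving comes from how the loss is computed: the unfused path first materializes a full float32 log-softmax tensor over the vocabulary and only then applies `nll_loss`, while apex's fused `SoftmaxCrossEntropyLoss` kernel produces the per-token losses without keeping that intermediate alive. A back-of-envelope sketch of what the intermediate costs (all numbers are hypothetical, not from this commit):

```python
# Size of the float32 log-probability tensor that the unfused path
# materializes and the fused apex kernel avoids. Numbers are
# hypothetical, chosen only to illustrate the scale.
num_tokens = 8192       # e.g. total masked positions in a batch
vocab_size = 50264      # a RoBERTa-sized vocabulary
bytes_per_float32 = 4
gib = num_tokens * vocab_size * bytes_per_float32 / 2**30
print(f'intermediate log-softmax tensor: {gib:.2f} GiB')  # ~1.53 GiB
```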
7 changes: 6 additions & 1 deletion README.md
````diff
@@ -89,7 +89,12 @@ and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more example
 * [PyTorch](http://pytorch.org/) version >= 1.4.0
 * Python version >= 3.6
 * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
-* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library with the `--cuda_ext` and `--deprecated_fused_adam` options
+* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library:
+```bash
+git clone https://github.com/NVIDIA/apex
+cd apex
+pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--deprecated_fused_adam" --global-option="--xentropy" --global-option="--fast_multihead_attn" ./
+```
 
 To install fairseq:
 ```bash
````
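
After building apex with the `--xentropy` option, a quick sanity check (not part of this commit) confirms the fused extension is importable; fairseq falls back to the pure-PyTorch implementation when it is not:

```python
# Sanity check (not part of the commit): is apex's fused kernel available?
try:
    from apex.contrib import xentropy  # noqa: F401
    print('fused cross entropy available')
except ImportError:
    print('apex xentropy extension missing; fairseq will use the PyTorch fallback')
```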
10 changes: 3 additions & 7 deletions fairseq/criterions/masked_lm.py
```diff
@@ -8,7 +8,7 @@
 import torch
 import torch.nn.functional as F
 
-from fairseq import metrics, utils
+from fairseq import metrics, modules, utils
 from fairseq.criterions import FairseqCriterion, register_criterion
 
 
@@ -47,12 +47,8 @@ def forward(self, model, sample, reduce=True):
         targets = model.get_targets(sample, [logits])
         targets = targets[masked_tokens]
 
-        loss = F.nll_loss(
-            F.log_softmax(
-                logits.view(-1, logits.size(-1)),
-                dim=-1,
-                dtype=torch.float32,
-            ),
+        loss = modules.cross_entropy(
+            logits.view(-1, logits.size(-1)),
             targets.view(-1),
             reduction='sum',
             ignore_index=self.padding_idx,
```
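
The new `modules.cross_entropy` call is a drop-in replacement for the `F.nll_loss(F.log_softmax(...))` composition it removes. A small equivalence sketch on the CPU fallback path (shapes and the padding index are illustrative; assumes fairseq is installed):

```python
import torch
import torch.nn.functional as F
from fairseq.modules import cross_entropy

logits = torch.randn(16, 100)            # (num_tokens, vocab_size)
targets = torch.randint(0, 100, (16,))   # gold class indices

# The composition being replaced in masked_lm.py:
old = F.nll_loss(
    F.log_softmax(logits.view(-1, logits.size(-1)), dim=-1, dtype=torch.float32),
    targets.view(-1),
    reduction='sum',
    ignore_index=1,  # illustrative; fairseq dictionaries typically pad at index 1
)
# The replacement:
new = cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1),
                    reduction='sum', ignore_index=1)
assert torch.allclose(old, new)
```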
2 changes: 2 additions & 0 deletions fairseq/modules/__init__.py
```diff
@@ -8,6 +8,7 @@
 from .beamable_mm import BeamableMM
 from .character_token_embedder import CharacterTokenEmbedder
 from .conv_tbc import ConvTBC
+from .cross_entropy import cross_entropy
 from .downsampled_multihead_attention import DownsampledMultiHeadAttention
 from .dynamic_convolution import DynamicConv, DynamicConv1dTBC
 from .dynamic_crf_layer import DynamicCRF
@@ -36,6 +37,7 @@
     'BeamableMM',
     'CharacterTokenEmbedder',
     'ConvTBC',
+    'cross_entropy',
     'DownsampledMultiHeadAttention',
     'DynamicConv1dTBC',
     'DynamicConv',
```
50 changes: 50 additions & 0 deletions fairseq/modules/cross_entropy.py
```diff
@@ -0,0 +1,50 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import torch
+import torch.nn.functional as F
+
+
+logger = logging.getLogger(__name__)
+
+
+def _cross_entropy_pytorch(logits, target, ignore_index=None, reduction='mean'):
+    lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
+    return F.nll_loss(
+        lprobs, target, ignore_index=ignore_index, reduction=reduction,
+    )
+
+
+try:
+    from apex.contrib import xentropy
+
+    logger.info('using fused cross entropy')
+
+    def cross_entropy(logits, target, ignore_index=-100, reduction='mean'):
+        if logits.device == torch.device('cpu'):
+            return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
+        else:
+            half_to_float = (logits.dtype == torch.half)
+            losses = xentropy.SoftmaxCrossEntropyLoss.apply(
+                logits, target, 0.0, ignore_index, half_to_float,
+            )
+            if reduction == 'sum':
+                return losses.sum()
+            elif reduction == 'mean':
+                if ignore_index >= 0:
+                    return losses.sum() / target.ne(ignore_index).sum()
+                else:
+                    return losses.mean()
+            elif reduction == 'none':
+                return losses
+            else:
+                raise NotImplementedError
+
+except ImportError:
+
+    def cross_entropy(logits, target, ignore_index=-100, reduction='mean'):
+        return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
```
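
A minimal usage sketch (shapes are illustrative); on CPU, or when apex is not installed, the call transparently takes the PyTorch path:

```python
import torch
from fairseq.modules import cross_entropy

logits = torch.randn(8, 50)            # (num_tokens, vocab_size)
target = torch.randint(0, 50, (8,))    # gold class indices
loss = cross_entropy(logits, target, reduction='sum')
print(loss.item())
```

Note the `'mean'` reduction in the fused branch: when a non-negative `ignore_index` is set, the summed loss is divided by the count of non-ignored targets rather than by all targets, matching `F.nll_loss` semantics.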
