Add --eval-bleu for translation
Summary: Pull Request resolved: fairinternal/fairseq-py#989

Reviewed By: MultiPath

Differential Revision: D19411162

Pulled By: myleott

fbshipit-source-id: 74842f0174f58e39a13fb90f3cc1170c63bc89be
MultiPath authored and facebook-github-bot committed Jan 17, 2020
1 parent 122fc1d commit 60fbf64
Showing 5 changed files with 171 additions and 26 deletions.
8 changes: 7 additions & 1 deletion examples/translation/README.md
@@ -116,7 +116,13 @@ CUDA_VISIBLE_DEVICES=0 fairseq-train \
--lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
--dropout 0.3 --weight-decay 0.0001 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--max-tokens 4096
--max-tokens 4096 \
--eval-bleu \
--eval-bleu-args '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}' \
--eval-bleu-detok moses \
--eval-bleu-remove-bpe \
--eval-bleu-print-samples \
--best-checkpoint-metric bleu --maximize-best-checkpoint-metric
```
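The new `--eval-bleu*` options make the translation task run beam search on the validation set and report BLEU, and `--best-checkpoint-metric bleu --maximize-best-checkpoint-metric` then selects checkpoints by that score instead of validation loss. `--eval-bleu-args` (and `--eval-bleu-detok-args`) must be strict JSON; the task parses them into generation options, roughly as in this standalone sketch mirroring `build_model` in the diff below:

```python
import json
from argparse import Namespace

# Example value for --eval-bleu-args; strict JSON (double quotes, no trailing
# commas) is required because the string goes through json.loads.
eval_bleu_args = '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}'

gen_args = Namespace(**json.loads(eval_bleu_args))
print(gen_args.beam, gen_args.max_len_a, gen_args.max_len_b)  # 5 1.2 10
```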

Finally, we can evaluate our trained model:
6 changes: 5 additions & 1 deletion fairseq/meters.py
@@ -230,7 +230,11 @@ def get_smoothed_value(self, key: str) -> float:

def get_smoothed_values(self) -> Dict[str, float]:
"""Get all smoothed values."""
return OrderedDict([(key, self.get_smoothed_value(key)) for key in self.keys()])
return OrderedDict([
(key, self.get_smoothed_value(key))
for key in self.keys()
if not key.startswith("_")
])

def reset(self):
"""Reset Meter instances."""
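The underscore filter added to `get_smoothed_values` keeps intermediate statistics such as `_bleu_counts_0` or `_bleu_sys_len` (logged by the new validation path) out of the printed stats while leaving them available for the derived `bleu` metric. A standalone illustration with plain dicts (not fairseq's meters):

```python
from collections import OrderedDict

# Hypothetical smoothed values: public stats plus "_"-prefixed intermediates
# that only exist so a derived metric can be computed from them later.
smoothed = OrderedDict([
    ("loss", 3.21),
    ("nll_loss", 2.87),
    ("_bleu_counts_0", 1523),
    ("_bleu_sys_len", 2048),
])

visible = OrderedDict(
    (key, value) for key, value in smoothed.items() if not key.startswith("_")
)
print(list(visible))  # ['loss', 'nll_loss']
```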
127 changes: 125 additions & 2 deletions fairseq/tasks/translation.py
@@ -3,23 +3,30 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from argparse import Namespace
import json
import itertools
import logging
import os

from fairseq import options, utils
import numpy as np

from fairseq import metrics, options, utils
from fairseq.data import (
AppendTokenDataset,
ConcatDataset,
data_utils,
encoders,
indexed_dataset,
LanguagePairDataset,
PrependTokenDataset,
StripTokenDataset,
TruncateDataset,
)

from . import FairseqTask, register_task
from fairseq.tasks import FairseqTask, register_task

EVAL_BLEU_ORDER = 4


logger = logging.getLogger(__name__)
@@ -155,6 +162,26 @@ def add_args(parser):
help='amount to upsample primary dataset')
parser.add_argument('--truncate-source', action='store_true', default=False,
help='truncate source to max-source-positions')

# options for reporting BLEU during validation
parser.add_argument('--eval-bleu', action='store_true',
help='evaluation with BLEU scores')
parser.add_argument('--eval-bleu-detok', type=str, default="space",
help='detokenizer before computing BLEU (e.g., "moses"); '
'required if using --eval-bleu; use "space" to '
'disable detokenization; see fairseq.data.encoders '
'for other options')
parser.add_argument('--eval-bleu-detok-args', type=str, metavar='JSON',
help='args for building the tokenizer, if needed')
parser.add_argument('--eval-tokenized-bleu', action='store_true', default=False,
help='if set, compute tokenized BLEU instead of detokenized (sacrebleu) BLEU')
parser.add_argument('--eval-bleu-remove-bpe', nargs='?', const='@@ ', default=None,
help='remove BPE before computing BLEU')
parser.add_argument('--eval-bleu-args', type=str, metavar='JSON',
help='generation args for BLEU scoring, '
'e.g., \'{"beam": 4, "lenpen": 0.6}\'')
parser.add_argument('--eval-bleu-print-samples', action='store_true',
help='print sample generations during validation')
# fmt: on
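The `nargs='?'`/`const='@@ '` combination on `--eval-bleu-remove-bpe` means the flag is a no-op when omitted, strips the conventional `'@@ '` BPE continuation marker when passed bare, and accepts an explicit marker when given a value. A minimal argparse sketch of that behaviour:

```python
import argparse

# Reproduces just the --eval-bleu-remove-bpe argument definition.
parser = argparse.ArgumentParser()
parser.add_argument('--eval-bleu-remove-bpe', nargs='?', const='@@ ', default=None)

print(repr(parser.parse_args([]).eval_bleu_remove_bpe))                          # None
print(repr(parser.parse_args(['--eval-bleu-remove-bpe']).eval_bleu_remove_bpe))  # '@@ '
print(repr(parser.parse_args(['--eval-bleu-remove-bpe', '@@']).eval_bleu_remove_bpe))  # '@@'
```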

def __init__(self, args, src_dict, tgt_dict):
@@ -219,6 +246,75 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs):
def build_dataset_for_inference(self, src_tokens, src_lengths):
return LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary)

def build_model(self, args):
if getattr(args, 'eval_bleu', False):
assert getattr(args, 'eval_bleu_detok', None) is not None, (
'--eval-bleu-detok is required if using --eval-bleu; '
'try --eval-bleu-detok=moses (or --eval-bleu-detok=space '
'to disable detokenization, e.g., when using sentencepiece)'
)
detok_args = json.loads(getattr(args, 'eval_bleu_detok_args', '{}') or '{}')
self.tokenizer = encoders.build_tokenizer(Namespace(
tokenizer=getattr(args, 'eval_bleu_detok', None),
**detok_args
))

gen_args = json.loads(getattr(args, 'eval_bleu_args', '{}') or '{}')
self.sequence_generator = self.build_generator(Namespace(**gen_args))
return super().build_model(args)

def valid_step(self, sample, model, criterion):
loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
if self.args.eval_bleu:
bleu = self._inference_with_bleu(self.sequence_generator, sample, model)
logging_output['_bleu_sys_len'] = bleu.sys_len
logging_output['_bleu_ref_len'] = bleu.ref_len
# we split counts into separate entries so that they can be
# summed efficiently across workers using fast-stat-sync
assert len(bleu.counts) == EVAL_BLEU_ORDER
for i in range(EVAL_BLEU_ORDER):
logging_output['_bleu_counts_' + str(i)] = bleu.counts[i]
logging_output['_bleu_totals_' + str(i)] = bleu.totals[i]
return loss, sample_size, logging_output
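Logging each n-gram order as its own flat scalar (rather than one list-valued entry) keeps the logging outputs reducible by a plain element-wise sum across data-parallel workers, which is what fast-stat-sync and `reduce_metrics` below rely on. A toy illustration with plain dicts:

```python
# Hypothetical logging outputs from two data-parallel workers.
worker_logs = [
    {'_bleu_counts_0': 120, '_bleu_counts_1': 80, '_bleu_sys_len': 150},
    {'_bleu_counts_0': 130, '_bleu_counts_1': 85, '_bleu_sys_len': 160},
]

def sum_logs(key):
    # Mirrors the sum_logs helper in reduce_metrics.
    return sum(log.get(key, 0) for log in worker_logs)

print(sum_logs('_bleu_counts_0'), sum_logs('_bleu_counts_1'), sum_logs('_bleu_sys_len'))
# 250 165 310
```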

def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion)
if self.args.eval_bleu:

def sum_logs(key):
return sum(log.get(key, 0) for log in logging_outputs)

counts, totals = [], []
for i in range(EVAL_BLEU_ORDER):
counts.append(sum_logs('_bleu_counts_' + str(i)))
totals.append(sum_logs('_bleu_totals_' + str(i)))

if max(totals) > 0:
# log counts as numpy arrays -- log_scalar will sum them correctly
metrics.log_scalar('_bleu_counts', np.array(counts))
metrics.log_scalar('_bleu_totals', np.array(totals))
metrics.log_scalar('_bleu_sys_len', sum_logs('_bleu_sys_len'))
metrics.log_scalar('_bleu_ref_len', sum_logs('_bleu_ref_len'))

def compute_bleu(meters):
import inspect
import sacrebleu
fn_sig = inspect.getfullargspec(sacrebleu.compute_bleu)[0]
if 'smooth_method' in fn_sig:
smooth = {'smooth_method': 'exp'}
else:
smooth = {'smooth': 'exp'}
bleu = sacrebleu.compute_bleu(
correct=meters['_bleu_counts'].sum,
total=meters['_bleu_totals'].sum,
sys_len=meters['_bleu_sys_len'].sum,
ref_len=meters['_bleu_ref_len'].sum,
**smooth
)
return round(bleu.score, 2)

metrics.log_derived('bleu', compute_bleu)
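Registering `compute_bleu` via `metrics.log_derived` means corpus BLEU is recomputed from the summed statistics at reporting time, instead of averaging per-batch BLEU scores (which would give a biased number); the `inspect` check only bridges a sacrebleu rename of the `smooth` keyword to `smooth_method`. As a rough back-of-the-envelope check, corpus BLEU can be recomputed from the aggregated statistics with the standard unsmoothed formula (sacrebleu additionally smooths zero counts, so tiny validation sets may differ slightly):

```python
import math

# Hypothetical aggregated statistics for the 4 n-gram orders, i.e. the
# summed _bleu_counts / _bleu_totals / _bleu_*_len values.
counts = [180, 95, 60, 40]      # matched n-grams per order
totals = [200, 190, 180, 170]   # candidate n-grams per order
sys_len, ref_len = 200, 210

precisions = [c / t for c, t in zip(counts, totals)]
brevity_penalty = 1.0 if sys_len > ref_len else math.exp(1 - ref_len / sys_len)
bleu = 100 * brevity_penalty * math.exp(
    sum(math.log(p) for p in precisions) / len(precisions)
)
print(round(bleu, 2))  # ~41.2 for these made-up numbers
```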

def max_positions(self):
"""Return the max sentence length allowed by the task."""
return (self.args.max_source_positions, self.args.max_target_positions)
@@ -232,3 +328,30 @@ def source_dictionary(self):
def target_dictionary(self):
"""Return the target :class:`~fairseq.data.Dictionary`."""
return self.tgt_dict

def _inference_with_bleu(self, generator, sample, model):
import sacrebleu

def decode(toks, escape_unk=False):
s = self.tgt_dict.string(
toks.int().cpu(),
self.args.eval_bleu_remove_bpe,
escape_unk=escape_unk,
)
if self.tokenizer:
s = self.tokenizer.decode(s)
return s

gen_out = self.inference_step(generator, [model], sample, None)
hyps, refs = [], []
for i in range(len(gen_out)):
hyps.append(decode(gen_out[i][0]['tokens']))
refs.append(decode(
utils.strip_pad(sample['target'][i], self.tgt_dict.pad()),
escape_unk=True, # don't count <unk> as matches to the hypo
))
if self.args.eval_bleu_print_samples:
logger.info('example hypothesis: ' + hyps[0])
logger.info('example reference: ' + refs[0])
tokenize = sacrebleu.DEFAULT_TOKENIZER if not self.args.eval_tokenized_bleu else 'none'
return sacrebleu.corpus_bleu(hyps, [refs], tokenize=tokenize)
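`_inference_with_bleu` turns hypotheses and references back into detokenized text and hands them to sacrebleu, which applies its own tokenizer unless `--eval-tokenized-bleu` is set. A minimal standalone sketch of that final call on toy sentences (assumes a sacrebleu 1.x-era API, where `DEFAULT_TOKENIZER` and the `counts`/`totals`/`sys_len`/`ref_len` fields used above are available):

```python
import sacrebleu

# One hypothesis per sentence, and one reference stream (hence the [refs]).
hyps = ['the cat sat on the mat', 'hello world']
refs = ['the cat sat on a mat', 'hello world']

bleu = sacrebleu.corpus_bleu(hyps, [refs], tokenize=sacrebleu.DEFAULT_TOKENIZER)
print(round(bleu.score, 2))
print(bleu.counts, bleu.totals, bleu.sys_len, bleu.ref_len)
```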
43 changes: 21 additions & 22 deletions fairseq_cli/train.py
@@ -163,39 +163,38 @@ def train(args, trainer, task, epoch_itr):

valid_subsets = args.valid_subset.split(',')
max_update = args.max_update or math.inf
for samples in progress:
with metrics.aggregate('train_inner'):
with metrics.aggregate() as agg:
for samples in progress:
log_output = trainer.train_step(samples)
num_updates = trainer.get_num_updates()
if log_output is None:
continue

# log mid-epoch stats
stats = get_training_stats('train_inner')
stats = get_training_stats(agg.get_smoothed_values())
progress.log(stats, tag='train', step=num_updates)

if (
not args.disable_validation
and args.save_interval_updates > 0
and num_updates % args.save_interval_updates == 0
and num_updates > 0
):
valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
if (
not args.disable_validation
and args.save_interval_updates > 0
and num_updates % args.save_interval_updates == 0
and num_updates > 0
):
valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

if num_updates >= max_update:
break
if num_updates >= max_update:
break

# log end-of-epoch stats
stats = get_training_stats('train')
stats = get_training_stats(agg.get_smoothed_values())
progress.print(stats, tag='train', step=num_updates)

# reset epoch-level meters
metrics.reset_meters('train')


def get_training_stats(stats_key):
stats = metrics.get_smoothed_values(stats_key)
def get_training_stats(stats):
if 'nll_loss' in stats and 'ppl' not in stats:
stats['ppl'] = utils.get_perplexity(stats['nll_loss'])
stats['wall'] = round(metrics.get_meter('default', 'wall').elapsed_time, 0)
@@ -233,22 +232,22 @@ def validate(args, trainer, task, epoch_itr, subsets):
no_progress_bar='simple'
)

# reset validation loss meters
# reset validation meters
metrics.reset_meters('valid')

for sample in progress:
trainer.valid_step(sample)
with metrics.aggregate() as agg:
for sample in progress:
trainer.valid_step(sample)

# log validation stats
stats = get_valid_stats(args, trainer)
stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
progress.print(stats, tag=subset, step=trainer.get_num_updates())

valid_losses.append(stats[args.best_checkpoint_metric])
return valid_losses


def get_valid_stats(args, trainer):
stats = metrics.get_smoothed_values('valid')
def get_valid_stats(args, trainer, stats):
if 'nll_loss' in stats and 'ppl' not in stats:
stats['ppl'] = utils.get_perplexity(stats['nll_loss'])
stats['num_updates'] = trainer.get_num_updates()
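In `fairseq_cli/train.py` the training and validation loops now collect statistics through `metrics.aggregate()` context managers and pass the smoothed dict directly to `get_training_stats`/`get_valid_stats`. Since `validate` returns `stats[args.best_checkpoint_metric]`, running with `--best-checkpoint-metric bleu --maximize-best-checkpoint-metric` makes checkpoint selection follow validation BLEU. A tiny sketch of the comparison that `--maximize-best-checkpoint-metric` implies (hypothetical helper, not fairseq's `checkpoint_utils` API):

```python
# Decide whether a new checkpoint beats the current best for a given metric.
def is_better(new_value, best_value, maximize):
    if best_value is None:        # no checkpoint saved yet
        return True
    return new_value > best_value if maximize else new_value < best_value

print(is_better(27.4, 26.9, maximize=True))   # True: higher BLEU is better
print(is_better(3.1, 2.9, maximize=False))    # False: lower loss is better
```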
13 changes: 13 additions & 0 deletions tests/test_binaries.py
@@ -128,6 +128,19 @@ def test_generation(self):
])
generate_main(data_dir, ['--prefix-size', '2'])

def test_eval_bleu(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_eval_bleu') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
train_translation_model(data_dir, 'fconv_iwslt_de_en', [
'--eval-bleu',
'--eval-bleu-print-samples',
'--eval-bleu-remove-bpe',
'--eval-bleu-detok', 'space',
'--eval-bleu-args', '{"beam": 4, "min_len": 10}',
])

def test_lstm(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_lstm') as data_dir:
