Backward reranking public (#667)
Summary:
Implementation of noisy channel model reranking for release with paper
Pull Request resolved: fairinternal/fairseq-py#667

Reviewed By: michaelauli

Differential Revision: D15901665

Pulled By: nng555

fbshipit-source-id: 2de2c518be8e5828ffad72db3e741b0940623373
nng555 authored and facebook-github-bot committed Aug 15, 2019
1 parent ac66df4 commit 49177c9
Showing 13 changed files with 1,629 additions and 3 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -116,3 +116,6 @@ fairseq/modules/*_layer/*_backward.cu

# data
data-bin/

# reranking
examples/reranking/rerank_data
6 changes: 4 additions & 2 deletions eval_lm.py
@@ -146,8 +146,9 @@ def main(parsed_args):
hypos = scorer.generate(models, sample)
gen_timer.stop(sample['ntokens'])

- for hypos_i in hypos:
+ for i, hypos_i in enumerate(hypos):
hypo = hypos_i[0]
+ sample_id = sample['id'][i]

tokens = hypo['tokens']
tgt_len = tokens.numel()
@@ -199,7 +200,8 @@ def main(parsed_args):
is_bpe = False
w = ''
if args.output_word_probs:
- print('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
+ print(str(int(sample_id)) + " " +
+ ('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob)))

wps_meter.update(sample['ntokens'])
t.log({'wps': round(wps_meter.avg)})
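
With this change, each line printed under `--output-word-probs` starts with the integer sample id, followed by a space and the tab-separated `word [score]` pairs. A minimal sketch of a reader for that line format follows; the helper name `parse_word_prob_line` is hypothetical and not part of this commit, and it assumes the words themselves contain no tab characters.

```python
def parse_word_prob_line(line):
    """Parse one '<id> <word> [<score>]\\t<word> [<score>]...' line.

    Hypothetical helper, not part of fairseq; it only mirrors the
    print format shown in the hunk above.
    """
    # The sample id is everything before the first space.
    sample_id, _, rest = line.rstrip("\n").partition(" ")
    pairs = []
    for item in rest.split("\t"):
        if not item:
            continue  # tolerate an empty word list
        # Split on the last ' [' so words containing '[' still parse.
        word, _, score = item.rpartition(" [")
        pairs.append((word, float(score.rstrip("]"))))
    return int(sample_id), pairs
```

Keeping the id on each line lets a downstream script match scores to sentences even when the batches are not emitted in dataset order.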
10 changes: 10 additions & 0 deletions examples/__init__.py
@@ -0,0 +1,10 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

__version__ = '0.7.2'

import examples.noisychannel # noqa
72 changes: 72 additions & 0 deletions examples/noisychannel/README.md
@@ -0,0 +1,72 @@
# Simple and Effective Noisy Channel Modeling for Neural Machine Translation (Yee et al., 2019)
This page contains pointers to pre-trained models as well as instructions on how to run the reranking scripts.

## Citation:
```bibtex
@inproceedings{yee2018simple,
title = {Simple and Effective Noisy Channel Modeling for Neural Machine Translation},
author = {Kyra Yee and Yann Dauphin and Michael Auli},
booktitle = {Conference on Empirical Methods in Natural Language Processing},
year = {2019},
}
```

## Pre-trained Models:

Model | Description | Download
---|---|---
`transformer.noisychannel.de-en` | De->En Forward Model | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/forward_de2en.tar.bz2)
`transformer.noisychannel.en-de` | En->De Channel Model | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/backward_en2de.tar.bz2)
`transformer_lm.noisychannel.en` | En Language model | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/reranking_en_lm.tar.bz2)

Test Data: [newstest_wmt17](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/wmt17test.tar.bz2)

## Example usage

```
mkdir rerank_example
curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/forward_de2en.tar.bz2 | tar xvjf - -C rerank_example
curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/backward_en2de.tar.bz2 | tar xvjf - -C rerank_example
curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/reranking_en_lm.tar.bz2 | tar xvjf - -C rerank_example
curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/wmt17test.tar.bz2 | tar xvjf - -C rerank_example
beam=50
num_trials=1000
fw_name=fw_model_ex
bw_name=bw_model_ex
lm_name=lm_ex
data_dir=rerank_example/hyphen-splitting-mixed-case-wmt17test-wmt14bpe
data_dir_name=wmt17
lm=rerank_example/lm/checkpoint_best.pt
lm_bpe_code=rerank_example/lm/bpe32k.code
lm_dict=rerank_example/lm/dict.txt
batch_size=32
bw=rerank_example/backward_en2de.pt
fw=rerank_example/forward_de2en.pt
# reranking with P(T|S), P(S|T), and P(T)
python examples/noisychannel/rerank_tune.py $data_dir --tune-param lenpen weight1 weight3 \
--lower-bound 0 0 0 --upper-bound 3 3 3 --data-dir-name $data_dir_name \
--num-trials $num_trials --source-lang de --target-lang en --gen-model $fw \
-n $beam --batch-size $batch_size --score-model2 $fw --score-model1 $bw \
--backwards1 --weight2 1 \
-lm $lm --lm-dict $lm_dict --lm-name en_newscrawl --lm-bpe-code $lm_bpe_code \
--model2-name $fw_name --model1-name $bw_name --gen-model-name $fw_name
# reranking with P(T|S) and P(T)
python examples/noisychannel/rerank_tune.py $data_dir --tune-param lenpen weight3 \
--lower-bound 0 0 --upper-bound 3 3 --data-dir-name $data_dir_name \
--num-trials $num_trials --source-lang de --target-lang en --gen-model $fw \
-n $beam --batch-size $batch_size --score-model1 $fw \
-lm $lm --lm-dict $lm_dict --lm-name en_newscrawl --lm-bpe-code $lm_bpe_code \
--model1-name $fw_name --gen-model-name $fw_name
# to run with a preconfigured set of hyperparameters for the lenpen and model weights, use rerank.py instead
python examples/noisychannel/rerank.py $data_dir \
--lenpen 0.269 --weight1 1 --weight2 0.929 --weight3 0.831 \
--data-dir-name $data_dir_name --source-lang de --target-lang en --gen-model $fw \
-n $beam --batch-size $batch_size --score-model2 $fw --score-model1 $bw --backwards1 \
-lm $lm --lm-dict $lm_dict --lm-name en_newscrawl --lm-bpe-code $lm_bpe_code \
--model2-name $fw_name --model1-name $bw_name --gen-model-name $fw_name
```
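
The commands above tune `lenpen` and the model weights between a lower and an upper bound over a number of trials. The sketch below illustrates the general idea, not the code in `rerank.py` or `rerank_tune.py`. It assumes that the final score is a weighted sum of log-probabilities divided by the target length raised to `lenpen`, that `weight1` scales the channel model P(S|T), `weight2` the forward model P(T|S), and `weight3` the language model P(T) (a mapping inferred from the flag names), and that tuning samples parameters uniformly at random. All function names and the hypothesis dictionary layout are hypothetical.

```python
import random


def noisy_channel_score(fw_lprob, bw_lprob, lm_lprob, tgt_len,
                        weight1=1.0, weight2=1.0, weight3=1.0, lenpen=1.0):
    """Combine log P(S|T), log P(T|S) and log P(T) into one score (assumed form)."""
    combined = weight1 * bw_lprob + weight2 * fw_lprob + weight3 * lm_lprob
    # Assumed length normalization: divide by target length ** lenpen.
    return combined / (tgt_len ** lenpen)


def rerank(hypotheses, **params):
    """Return the highest-scoring hypothesis from one n-best list."""
    return max(
        hypotheses,
        key=lambda h: noisy_channel_score(h["fw"], h["bw"], h["lm"], h["len"], **params),
    )


def random_search(nbest_lists, objective, names, lower, upper,
                  num_trials, fixed=None, seed=0):
    """Sample parameters uniformly within bounds and keep the best trial.

    `objective` maps the list of selected hypotheses to a number to
    maximize (for example a corpus-level BLEU function).
    """
    rng = random.Random(seed)
    best_params, best_value = None, float("-inf")
    for _ in range(num_trials):
        params = dict(fixed or {})
        for name, lo, hi in zip(names, lower, upper):
            params[name] = rng.uniform(lo, hi)
        picks = [rerank(nbest, **params) for nbest in nbest_lists]
        value = objective(picks)
        if value > best_value:
            best_params, best_value = params, value
    return best_params, best_value
```

Under these assumptions, a call such as `random_search(nbest_lists, bleu_fn, ["lenpen", "weight1", "weight3"], [0, 0, 0], [3, 3, 3], num_trials=1000, fixed={"weight2": 1.0})` would mirror the first command, where `bleu_fn` stands in for whatever metric is computed on the tuning set.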

8 changes: 8 additions & 0 deletions examples/noisychannel/__init__.py
@@ -0,0 +1,8 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from .rerank_options import *