Added constrained decoding (#1536) #2402

Closed · wants to merge 27 commits

Commits (27)
- 9665f5a: Added constrained decoding (Post & Vilar, 2018) (mjpost, Jul 31, 2020)
- 8eda87f: removed nonexistent device (mjpost, Jul 31, 2020)
- 72e16ed: Consolidated on LanguagePairDataset (mjpost, Jul 31, 2020)
- 82ca082: removed stray mention of ConstrainedDataset (mjpost, Jul 31, 2020)
- e2e2d67: named argument, cleanup (mjpost, Jul 31, 2020)
- 4a26f66: build_dataset_for_inference stubs (mjpost, Jul 31, 2020)
- 87d48f7: type() -> isinstance() (mjpost, Jul 31, 2020)
- d6c7a41: removed torchscript-incompatible isinstance() (mjpost, Aug 3, 2020)
- 7843790: moved to explicit parameter in search (mjpost, Aug 3, 2020)
- 68d63a2: fiddling since local tests pass (mjpost, Aug 3, 2020)
- c4a1f74: converted to packed Tensor representation (mjpost, Aug 4, 2020)
- b93c64b: fixed old test case error (not mine) (mjpost, Aug 4, 2020)
- 1cf6271: call signature (mjpost, Aug 4, 2020)
- 5cd6c3b: constraints arg to inference_step (mjpost, Aug 4, 2020)
- 19a47ab: removed comments (mjpost, Aug 4, 2020)
- 451d6b2: Improved packed constraint structure and input handling (mjpost, Aug 5, 2020)
- 2e0a773: bugfix, unpack returns tensors (mjpost, Aug 5, 2020)
- d4100e9: incorporated test cases (mjpost, Aug 5, 2020)
- e84d18a: minor cleanup (mjpost, Aug 5, 2020)
- 740e9f9: "id" is a reserved word (mjpost, Aug 5, 2020)
- 2d2b8dd: bugfix handling 0 constraints; added packing test case to catch it (mjpost, Aug 6, 2020)
- e64feaf: clarified comments, fleshed out base class (mjpost, Aug 7, 2020)
- ff44d99: renamed file; documentation (mjpost, Aug 10, 2020)
- 661ddc5: Merge branch 'master' into constraints (mjpost, Aug 17, 2020)
- 8a8fc19: added example (mjpost, Aug 17, 2020)
- 643af66: cleanup (mjpost, Aug 17, 2020)
- ba8cc13: Added headers; switched function documentation style; cleanup (mjpost, Aug 19, 2020)
1 change: 1 addition & 0 deletions README.md
@@ -32,6 +32,7 @@ We provide reference implementations of various sequence modeling papers:
- [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
- [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
- [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/transformer_lm/README.md)
- [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md)
- [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
- [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
- [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
124 changes: 124 additions & 0 deletions examples/constrained_decoding/README.md
@@ -0,0 +1,124 @@
# (Vectorized) Lexically constrained decoding with dynamic beam allocation

This page provides instructions for how to use lexically constrained decoding in Fairseq.
Fairseq implements the code described in the following papers:

* [Fast Lexically Constrained Decoding With Dynamic Beam Allocation](https://www.aclweb.org/anthology/N18-1119/)
* [Improved Lexically Constrained Decoding for Translation and Monolingual Rewriting](https://www.aclweb.org/anthology/N19-1090/)

## Quick start

Constrained search is enabled by adding the command-line argument `--constraints` to `fairseq-interactive`.
Constraints are appended to each line of input, separated by tabs. Each constraint (one or more tokens)
is a separate field.
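
For example, an input line carrying two constraints looks like this, where `<TAB>` stands for a literal tab character:

```
Die maschinelle Übersetzung ist schwer zu kontrollieren.<TAB>hard<TAB>to influence
```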

The following command, using [Fairseq's WMT19 German--English model](https://github.com/pytorch/fairseq/blob/master/examples/wmt19/README.md),
translates the sentence *Die maschinelle Übersetzung ist schwer zu kontrollieren.* with the constraints
"hard" and "to influence".

```bash
echo -e "Die maschinelle Übersetzung ist schwer zu kontrollieren.\thard\tto influence" \
| normalize.py | tok.py \
| fairseq-interactive /path/to/model \
  --path /path/to/model/model1.pt \
  --bpe fastbpe \
  --bpe-codes /path/to/model/bpecodes \
  --constraints \
  -s de -t en \
  --beam 10
```

(`tok.py` and `normalize.py` can be found in the same directory as this README; they are just shortcuts around Fairseq's WMT19 preprocessing.)
This will generate the following output:

```
[snip]
S-0     Die masch@@ in@@ elle Über@@ setzung ist schwer zu kontrollieren .
W-0     1.844   seconds
C-0     hard
C-0     to influence
H-0     -1.5333266258239746     Mach@@ ine trans@@ lation is hard to influence .
D-0     -1.5333266258239746     Machine translation is hard to influence .
P-0     -0.5434 -0.1423 -0.1930 -0.1415 -0.2346 -1.8031 -0.1701 -11.7727 -0.1815 -0.1511
```

By default, constraints are generated in the order supplied, with any number (zero or more) of tokens generated
between constraints. If you want to let the decoder satisfy the constraints in any order, use
`--constraints unordered` (see the example below). Note that you may want to use a larger beam.
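
For example, the quick-start command becomes the following. (The beam size of 20 here is only an illustrative bump over the default command above; tune it for your data.)

```bash
echo -e "Die maschinelle Übersetzung ist schwer zu kontrollieren.\thard\tto influence" \
| normalize.py | tok.py \
| fairseq-interactive /path/to/model \
  --path /path/to/model/model1.pt \
  --bpe fastbpe \
  --bpe-codes /path/to/model/bpecodes \
  --constraints unordered \
  -s de -t en \
  --beam 20
```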

## Implementation details

The heart of the implementation is in `fairseq/search.py`, which adds a `LexicallyConstrainedBeamSearch` instance.
This instance of beam search tracks the progress of each hypothesis in the beam through the set of constraints
provided for each input sentence. It does this using one of two classes, both found in `fairseq/token_generation_constraints.py`
(a toy sketch of the state tracking follows this list):

* `OrderedConstraintState`: assumes the C input constraints will be generated in the provided order
* `UnorderedConstraintState`: tries to apply the C (phrasal) constraints in all C! orders
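
To make the state tracking concrete, here is a minimal, illustrative sketch of an *ordered* constraint state. It is not the class from `fairseq/token_generation_constraints.py` (whose actual interface and packed-tensor handling are more involved); it only shows how a hypothesis can record its position in the flattened constraint tokens and advance as the beam generates tokens:

```python
from typing import List


class ToyOrderedConstraintState:
    """Illustrative only: tracks one hypothesis's progress through ordered,
    possibly multi-token constraints. Any tokens may be generated between
    constraints; abandoning a constraint mid-phrase rewinds to its start."""

    def __init__(self, constraints: List[List[int]]):
        # Flatten the constraints into one token stream, recording where
        # each phrase begins.
        self.tokens = [tok for phrase in constraints for tok in phrase]
        self.phrase_starts = {sum(len(p) for p in constraints[:i])
                              for i in range(len(constraints))}
        self.position = 0  # number of constraint tokens matched so far

    def advance(self, token: int) -> None:
        """Update the state after the hypothesis generates `token`."""
        if self.position < len(self.tokens) and token == self.tokens[self.position]:
            self.position += 1  # matched the next required token
        elif 0 < self.position < len(self.tokens) and self.position not in self.phrase_starts:
            # Wandered off in the middle of a phrase: rewind to its start.
            self.position = max(s for s in self.phrase_starts if s < self.position)

    def finished(self) -> bool:
        return self.position == len(self.tokens)


# "hard" = [42], "to influence" = [17, 18] (made-up token IDs)
state = ToyOrderedConstraintState([[42], [17, 18]])
for tok in [7, 42, 9, 17, 18]:  # free tokens may appear between constraints
    state.advance(tok)
assert state.finished()
```

The dynamic beam allocation of Post & Vilar (2018) then divides each beam among hypotheses according to how many constraint tokens they have met, so partially constrained candidates are not crowded out during pruning.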

## Differences from Sockeye

There are a number of [differences from Sockeye's implementation](https://awslabs.github.io/sockeye/inference.html#lexical-constraints).

* Generating constraints in the order supplied (the default option here) is not available in Sockeye.
* Due to an improved beam allocation method, there is no need to prune the beam.
* Again due to better allocation, beam sizes as low as 10 or even 5 are often sufficient.
* [The extensions described in Hu et al.](https://github.com/edwardjhu/sockeye/tree/trie_constraints) (NAACL 2019) were never merged
into the main branch.
* Sockeye 2, released in July 2020, no longer supports constrained decoding.

## Citation

The paper first describing lexical constraints for seq2seq decoding is:

```bibtex
@inproceedings{hokamp-liu-2017-lexically,
title = "Lexically Constrained Decoding for Sequence Generation Using Grid Beam Search",
author = "Hokamp, Chris and
Liu, Qun",
booktitle = "Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
month = jul,
year = "2017",
address = "Vancouver, Canada",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/P17-1141",
doi = "10.18653/v1/P17-1141",
pages = "1535--1546",
}
```

The fairseq implementation uses the extensions described in

```bibtex
@inproceedings{post-vilar-2018-fast,
title = "Fast Lexically Constrained Decoding with Dynamic Beam Allocation for Neural Machine Translation",
author = "Post, Matt and
Vilar, David",
booktitle = "Proceedings of the 2018 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long Papers)",
month = jun,
year = "2018",
address = "New Orleans, Louisiana",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/N18-1119",
doi = "10.18653/v1/N18-1119",
pages = "1314--1324",
}
```

and

```bibtex
@inproceedings{hu-etal-2019-improved,
title = "Improved Lexically Constrained Decoding for Translation and Monolingual Rewriting",
author = "Hu, J. Edward and
Khayrallah, Huda and
Culkin, Ryan and
Xia, Patrick and
Chen, Tongfei and
Post, Matt and
Van Durme, Benjamin",
booktitle = "Proceedings of the 2019 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)",
month = jun,
year = "2019",
address = "Minneapolis, Minnesota",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/N19-1090",
doi = "10.18653/v1/N19-1090",
pages = "839--850",
}
```
21 changes: 21 additions & 0 deletions examples/constrained_decoding/normalize.py
@@ -0,0 +1,21 @@
#!/usr/bin/env python3

import sys

from sacremoses.normalize import MosesPunctNormalizer


def main(args):
    normalizer = MosesPunctNormalizer(lang=args.lang, penn=args.penn)
    for line in sys.stdin:
        print(normalizer.normalize(line.rstrip()), flush=True)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--lang', '-l', default='en')
    parser.add_argument('--penn', '-p', action='store_true')
    args = parser.parse_args()

    main(args)
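
A usage sketch (input text and exact output are illustrative; Moses punctuation normalization maps, e.g., guillemets to plain ASCII quotes):

```bash
echo 'Das ist «gut».' | python3 normalize.py -l de
# expected output (approximately): Das ist "gut".
```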
26 changes: 26 additions & 0 deletions examples/constrained_decoding/tok.py
@@ -0,0 +1,26 @@
#!/usr/bin/env python3

import sys
import sacremoses


def main(args):
    """Tokenizes, preserving tabs"""
    mt = sacremoses.MosesTokenizer(lang=args.lang)

    def tok(s):
        return mt.tokenize(s, return_str=True)

    for line in sys.stdin:
        parts = list(map(tok, line.split("\t")))
        print(*parts, sep="\t", flush=True)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--lang', '-l', default='en')
    parser.add_argument('--penn', '-p', action='store_true')
    parser.add_argument('--fields', '-f', help="fields to tokenize")
    args = parser.parse_args()

    main(args)
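
A usage sketch showing that tab-separated fields are tokenized independently, so the tab structure carrying the constraints survives (output assumes standard Moses behavior; `<TAB>` stands for a literal tab):

```bash
echo -e "Die maschinelle Übersetzung ist schwer zu kontrollieren.\thard\tto influence" | python3 tok.py -l de
# expected output, with tabs preserved:
# Die maschinelle Übersetzung ist schwer zu kontrollieren .<TAB>hard<TAB>to influence
```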
3 changes: 2 additions & 1 deletion examples/translation_moe/src/translation_moe.py
@@ -201,13 +201,14 @@ def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
        return loss, sample_size, logging_output

-    def inference_step(self, generator, models, sample, prefix_tokens=None, expert=None):
+    def inference_step(self, generator, models, sample, prefix_tokens=None, expert=None, constraints=None):
        expert = expert or self.args.gen_expert
        with torch.no_grad():
            return generator.generate(
                models,
                sample,
                prefix_tokens=prefix_tokens,
+                constraints=constraints,
                bos_token=self.expert_index(expert),
            )
1 change: 1 addition & 0 deletions fairseq/__init__.py
@@ -21,6 +21,7 @@
import fairseq.optim.lr_scheduler # noqa
import fairseq.pdb # noqa
import fairseq.tasks # noqa
import fairseq.token_generation_constraints # noqa

import fairseq.benchmark # noqa
import fairseq.model_parallel # noqa
2 changes: 1 addition & 1 deletion fairseq/data/dictionary.py
@@ -21,10 +21,10 @@ class Dictionary(object):
    def __init__(
        self,
        *,  # begin keyword-only arguments
+        bos="<s>",
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
-        bos="<s>",
        extra_special_symbols=None,
    ):
        self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
16 changes: 16 additions & 0 deletions fairseq/data/language_pair_dataset.py
@@ -133,6 +133,16 @@ def compute_alignment_weights(alignments):
        batch['alignments'] = alignments
        batch['align_weights'] = align_weights

    if samples[0].get("constraints", None) is not None:
        # Collate the packed constraints across the samples, padding to
        # the length of the longest sample.
        lens = [sample.get("constraints").size(0) for sample in samples]
        max_len = max(lens)
        constraints = torch.zeros((len(samples), max(lens))).long()
        for i, sample in enumerate(samples):
            constraints[i, 0:lens[i]] = samples[i].get("constraints")
        batch["constraints"] = constraints

    return batch

@@ -161,6 +171,8 @@ class LanguagePairDataset(FairseqDataset):
            target if it's absent (default: False).
        align_dataset (torch.utils.data.Dataset, optional): dataset
            containing alignments.
        constraints (Tensor, optional): 2d tensor with a concatenated, zero-
            delimited list of constraints for each sentence.
        append_bos (bool, optional): if set, appends bos to the beginning of
            source/target sentence.
        num_buckets (int, optional): if set to a value greater than 0, then

@@ -180,6 +192,7 @@ def __init__(
        shuffle=True, input_feeding=True,
        remove_eos_from_source=False, append_eos_to_target=False,
        align_dataset=None,
        constraints=None,
        append_bos=False, eos=None,
        num_buckets=0,
        src_lang_id=None,

@@ -206,6 +219,7 @@ def __init__(
        self.align_dataset = align_dataset
        if self.align_dataset is not None:
            assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided"
        self.constraints = constraints
        self.append_bos = append_bos
        self.eos = (eos if eos is not None else src_dict.eos())
        self.src_lang_id = src_lang_id

@@ -279,6 +293,8 @@ def __getitem__(self, index):
        }
        if self.align_dataset is not None:
            example['alignment'] = self.align_dataset[index]
        if self.constraints is not None:
            example["constraints"] = self.constraints[index]
        return example

    def __len__(self):
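
As an aside, the "concatenated, zero-delimited" layout mentioned in the docstring above can be illustrated with a hypothetical packer. This is not the PR's actual helper in `fairseq/token_generation_constraints.py` (whose packed structure also carries extra bookkeeping); it only shows the per-sentence shape that the collater above pads into a batch:

```python
from typing import List

import torch


def pack_sentence_constraints(constraints: List[List[int]]) -> torch.Tensor:
    """Hypothetical: concatenate each constraint's token IDs into one
    1-D tensor, writing a 0 delimiter after each constraint."""
    packed: List[int] = []
    for phrase in constraints:
        packed.extend(phrase)
        packed.append(0)
    return torch.tensor(packed, dtype=torch.long)


# "hard" = [42], "to influence" = [17, 18] (made-up token IDs)
print(pack_sentence_constraints([[42], [17, 18]]))  # tensor([42,  0, 17, 18,  0])
```

The collater then right-pads these per-sentence tensors with zeros to the batch maximum, so an all-zero row simply means a sentence with no constraints (the case fixed in commit 2d2b8dd).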
4 changes: 3 additions & 1 deletion fairseq/iterative_refinement_generator.py
@@ -105,7 +105,9 @@ def generate_batched_itr(

    @torch.no_grad()
-    def generate(self, models, sample, prefix_tokens=None):
+    def generate(self, models, sample, prefix_tokens=None, constraints=None):
+        if constraints is not None:
+            raise NotImplementedError("Constrained decoding with the IterativeRefinementGenerator is not supported")

        # TODO: iterative refinement generator does not support ensemble for now.
        if not self.retain_dropout:
2 changes: 2 additions & 0 deletions fairseq/options.py
@@ -609,6 +609,8 @@ def add_generation_args(parser):
                       help='sample from top K likely next words instead of all words')
    group.add_argument('--sampling-topp', default=-1.0, type=float, metavar='PS',
                       help='sample from the smallest set whose cumulative probability mass exceeds p for next words')
    group.add_argument('--constraints', const="ordered", nargs="?", choices=["ordered", "unordered"],
                       help='enables lexically constrained decoding')
    group.add_argument('--temperature', default=1., type=float, metavar='N',
                       help='temperature for generation')
    group.add_argument('--diverse-beam-groups', default=-1, type=int, metavar='N',
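
A note on the argparse semantics of the new option: with `nargs="?"` and `const="ordered"`, a bare `--constraints` flag selects ordered decoding, while the unordered behavior must be requested explicitly:

```bash
fairseq-interactive ... --constraints            # args.constraints == "ordered"
fairseq-interactive ... --constraints unordered  # args.constraints == "unordered"
fairseq-interactive ...                          # args.constraints is None
```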