-
Notifications
You must be signed in to change notification settings - Fork 6.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add dummy task for translation benchmarking (#1212)
Summary: Pull Request resolved: fairinternal/fairseq-py#1212 Test Plan: python train.py \ -a transformer \ --clip-norm 0.4 --optimizer adam --lr 0.001 \ --dropout 0.0 \ --decoder-layers 7 \ --encoder-layers 7 \ --encoder-ffn-embed-dim 2048 \ --decoder-ffn-embed-dim 2048 \ --encoder-embed-dim 1024 \ --decoder-embed-dim 1024 \ --max-tokens 8192 \ --criterion cross_entropy --max-update 50 \ --attention-dropout 0.0 \ --adam-betas '(0.9, 0.98)' \ --disable-validation --no-save \ --task dummy_mt # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding � Reviewed By: myleott Differential Revision: D22484873 Pulled By: msbaines fbshipit-source-id: bc61165ab91290d0b6aa2077c968ab537bce8a6a
- Loading branch information
1 parent
ffecb4e
commit a541b19
Showing
2 changed files
with
121 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,4 +8,5 @@ | |
dummy_lm, | ||
dummy_masked_lm, | ||
dummy_model, | ||
dummy_mt, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import logging | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from fairseq.data import Dictionary, FairseqDataset | ||
from fairseq.tasks import FairseqTask, register_task | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@register_task('dummy_mt') | ||
class DummyMTTask(FairseqTask): | ||
|
||
@staticmethod | ||
def add_args(parser): | ||
"""Add task-specific arguments to the parser.""" | ||
parser.add_argument('--dict-size', default=49996, type=int) | ||
parser.add_argument('--dataset-size', default=100000, type=int) | ||
parser.add_argument('--tokens-per-sample', default=512, type=int, | ||
help='max number of total tokens over all segments ' | ||
'per sample for BERT dataset') | ||
|
||
def __init__(self, args, dictionary): | ||
super().__init__(args) | ||
self.dictionary = dictionary | ||
self.seed = args.seed | ||
|
||
dictionary.pad_to_multiple_(8) # often faster if divisible by 8 | ||
|
||
seq = torch.arange(args.tokens_per_sample + 1) + dictionary.pad() + 1 | ||
|
||
self.dummy_src = seq[:-1] | ||
self.dummy_tgt = seq[1:] | ||
|
||
@classmethod | ||
def setup_task(cls, args, **kwargs): | ||
"""Setup the task. """ | ||
dictionary = Dictionary() | ||
for i in range(args.dict_size): | ||
dictionary.add_symbol('word{}'.format(i)) | ||
logger.info('dictionary: {} types'.format(len(dictionary))) | ||
return cls(args, dictionary) | ||
|
||
def load_dataset(self, split, epoch=1, combine=False, **kwargs): | ||
"""Load a given dataset split. | ||
Args: | ||
split (str): name of the split (e.g., train, valid, test) | ||
""" | ||
if self.args.max_sentences is not None: | ||
bsz = self.args.max_sentences | ||
else: | ||
bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample) | ||
tgt = torch.stack([self.dummy_tgt for _ in range(bsz)]) | ||
self.datasets[split] = DummyDataset( | ||
{ | ||
'id': 1, | ||
'net_input': { | ||
'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]), | ||
'src_lengths': torch.full( | ||
(bsz, ), self.args.tokens_per_sample, dtype=torch.long | ||
), | ||
'prev_output_tokens': tgt.clone(), | ||
}, | ||
'target': tgt, | ||
'nsentences': bsz, | ||
'ntokens': bsz * self.args.tokens_per_sample, | ||
}, | ||
num_items=self.args.dataset_size, | ||
item_size=self.args.tokens_per_sample, | ||
) | ||
|
||
@property | ||
def source_dictionary(self): | ||
return self.dictionary | ||
|
||
@property | ||
def target_dictionary(self): | ||
return self.dictionary | ||
|
||
|
||
class DummyDataset(FairseqDataset): | ||
|
||
def __init__(self, batch, num_items, item_size): | ||
super().__init__() | ||
self.batch = batch | ||
self.num_items = num_items | ||
self.item_size = item_size | ||
|
||
def __getitem__(self, index): | ||
return index | ||
|
||
def __len__(self): | ||
return self.num_items | ||
|
||
def collater(self, samples): | ||
return self.batch | ||
|
||
@property | ||
def sizes(self): | ||
return np.array([self.item_size] * self.num_items) | ||
|
||
def num_tokens(self, index): | ||
return self.item_size | ||
|
||
def size(self, index): | ||
return self.item_size | ||
|
||
def ordered_indices(self): | ||
return np.arange(self.num_items) | ||
|
||
@property | ||
def supports_prefetch(self): | ||
return False |