Skip to content

Commit

Permalink
Add dummy task for translation benchmarking (#1212)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: fairinternal/fairseq-py#1212

Test Plan:
python train.py \
    -a transformer \
    --clip-norm 0.4 --optimizer adam --lr 0.001 \
    --dropout 0.0 \
    --decoder-layers 7 \
    --encoder-layers 7 \
    --encoder-ffn-embed-dim 2048 \
    --decoder-ffn-embed-dim 2048 \
    --encoder-embed-dim 1024 \
    --decoder-embed-dim 1024 \
    --max-tokens 8192 \
    --criterion cross_entropy --max-update 50 \
    --attention-dropout 0.0 \
    --adam-betas '(0.9, 0.98)' \
    --disable-validation --no-save \
    --task dummy_mt

# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes # (issue).

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding �

Reviewed By: myleott

Differential Revision: D22484873

Pulled By: msbaines

fbshipit-source-id: bc61165ab91290d0b6aa2077c968ab537bce8a6a
  • Loading branch information
Mandeep Baines authored and facebook-github-bot committed Jul 15, 2020
1 parent ffecb4e commit a541b19
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 0 deletions.
1 change: 1 addition & 0 deletions fairseq/benchmark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
dummy_lm,
dummy_masked_lm,
dummy_model,
dummy_mt,
)
120 changes: 120 additions & 0 deletions fairseq/benchmark/dummy_mt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np
import torch

from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import FairseqTask, register_task


logger = logging.getLogger(__name__)


@register_task('dummy_mt')
class DummyMTTask(FairseqTask):

@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('--dict-size', default=49996, type=int)
parser.add_argument('--dataset-size', default=100000, type=int)
parser.add_argument('--tokens-per-sample', default=512, type=int,
help='max number of total tokens over all segments '
'per sample for BERT dataset')

def __init__(self, args, dictionary):
super().__init__(args)
self.dictionary = dictionary
self.seed = args.seed

dictionary.pad_to_multiple_(8) # often faster if divisible by 8

seq = torch.arange(args.tokens_per_sample + 1) + dictionary.pad() + 1

self.dummy_src = seq[:-1]
self.dummy_tgt = seq[1:]

@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task. """
dictionary = Dictionary()
for i in range(args.dict_size):
dictionary.add_symbol('word{}'.format(i))
logger.info('dictionary: {} types'.format(len(dictionary)))
return cls(args, dictionary)

def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
if self.args.max_sentences is not None:
bsz = self.args.max_sentences
else:
bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
tgt = torch.stack([self.dummy_tgt for _ in range(bsz)])
self.datasets[split] = DummyDataset(
{
'id': 1,
'net_input': {
'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
'src_lengths': torch.full(
(bsz, ), self.args.tokens_per_sample, dtype=torch.long
),
'prev_output_tokens': tgt.clone(),
},
'target': tgt,
'nsentences': bsz,
'ntokens': bsz * self.args.tokens_per_sample,
},
num_items=self.args.dataset_size,
item_size=self.args.tokens_per_sample,
)

@property
def source_dictionary(self):
return self.dictionary

@property
def target_dictionary(self):
return self.dictionary


class DummyDataset(FairseqDataset):

def __init__(self, batch, num_items, item_size):
super().__init__()
self.batch = batch
self.num_items = num_items
self.item_size = item_size

def __getitem__(self, index):
return index

def __len__(self):
return self.num_items

def collater(self, samples):
return self.batch

@property
def sizes(self):
return np.array([self.item_size] * self.num_items)

def num_tokens(self, index):
return self.item_size

def size(self, index):
return self.item_size

def ordered_indices(self):
return np.arange(self.num_items)

@property
def supports_prefetch(self):
return False

0 comments on commit a541b19

Please sign in to comment.