diff --git a/composer/datasets/in_context_learning_evaluation.py b/composer/datasets/in_context_learning_evaluation.py
index 0f4cd525cc..424c4e3f71 100644
--- a/composer/datasets/in_context_learning_evaluation.py
+++ b/composer/datasets/in_context_learning_evaluation.py
@@ -153,10 +153,11 @@ def prep_examples(self, num_fewshot: int, prompt_string: str, example_delimiter:
             cont = f'{continuation_delimiter}{cont}'
-            encoded_example['context'] = self.tokenizer(ctxt)
-            encoded_example['continuation'] = self.tokenizer(cont)
             encoded_example['preamble'] = self.tokenizer(
-                preamble)  # if the preamble is empty then these will be 0-length lists
+                preamble
+            )  # if the preamble is empty then these will be 0-length lists, unless the tokenizer adds special tokens to empty strings (e.g. OPT tokenizer)
+            encoded_example['context'] = self.tokenizer(ctxt, add_special_tokens=False)
+            encoded_example['continuation'] = self.tokenizer(cont, add_special_tokens=False)
 
             examples.append(encoded_example)
@@ -298,13 +299,13 @@ def prep_examples(self, num_fewshot: int, prompt_string: str, example_delimiter:
                 'choices'], self.samples[sample_idx]['gold'],
             if len(preamble) > 0:
                 query = f'{example_delimiter}{query}'
             choices = [f'{continuation_delimiter}{choice}' for choice in choices]
-            encoded_example['query'] = self.tokenizer(query)
-            encoded_example['choices'] = [self.tokenizer(choice) for choice in choices]
             encoded_example['preamble'] = self.tokenizer(
-                preamble)  # if the preamble is empty then these will be 0-length lists
+                preamble
+            )  # if the preamble is empty then these will be 0-length lists, unless the tokenizer adds special tokens to empty strings (e.g. OPT tokenizer)
             encoded_example['gold_idx'] = gold_idx
+            encoded_example['query'] = self.tokenizer(query, add_special_tokens=False)
+            encoded_example['choices'] = [self.tokenizer(choice, add_special_tokens=False) for choice in choices]
 
             examples.append(encoded_example)
diff --git a/tests/datasets/test_in_context_learning_datasets.py b/tests/datasets/test_in_context_learning_datasets.py
index 1af596e559..f7df77d01f 100644
--- a/tests/datasets/test_in_context_learning_datasets.py
+++ b/tests/datasets/test_in_context_learning_datasets.py
@@ -5,6 +5,7 @@
 
 import pytest
 from torch.utils.data import DataLoader
+from transformers import AutoTokenizer
 
 from composer.core import Evaluator
 from composer.datasets.in_context_learning_evaluation import (_get_fewshot_sample_idxs, _make_padded_input,
@@ -51,7 +52,43 @@ def test_lm_task_dataloader(dataset_uri, tiny_gpt2_tokenizer):
                                  batch_size,
                                  max_seq_len=seqlen,
                                  pad_tok_id=tokenizer.eos_token_id,
-                                 num_fewshot=1,
+                                 num_fewshot=0,
+                                 prompt_string='',
+                                 example_delimiter='\n',
+                                 continuation_delimiter='')
+
+    assert isinstance(dl.dataloader, DataLoader)  # pyright
+    batch = next(dl.dataloader._get_iterator())
+
+    assert 'input_ids' in batch
+    assert tuple(batch['input_ids'].shape) == (batch_size, seqlen)
+    assert 'attention_mask' in batch
+    assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen)
+    assert 'continuation_indices' in batch
+    assert isinstance(batch['continuation_indices'], list) and len(batch['continuation_indices']) == batch_size
+    assert 'mode' in batch
+    assert batch['mode'] == 'icl_task'
+    min_idx = min(batch['continuation_indices'][0]).item()
+    max_idx = max(batch['continuation_indices'][0]).item()
+    assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + 1]) == ' glen'
+
+
+@pytest.mark.parametrize('dataset_uri', ['lambada_small.jsonl'])
+@pytest.mark.parametrize('num_fewshot', [0, 1])
+def test_lm_task_dataloader_opt_tokenizer(dataset_uri, num_fewshot):
+    local_data = os.path.join(os.path.dirname(__file__), 'local_data')
+
+    tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m', use_fast=False)
+    dataset_uri = f'{local_data}/{dataset_uri}'
+    batch_size = 2
+    seqlen = 2048
+    dl = get_icl_task_dataloader('language_modeling',
+                                 dataset_uri,
+                                 tokenizer,
+                                 batch_size,
+                                 max_seq_len=seqlen,
+                                 pad_tok_id=tokenizer.eos_token_id,
+                                 num_fewshot=num_fewshot,
                                  prompt_string='',
                                  example_delimiter='\n',
                                  continuation_delimiter='')
@@ -70,6 +107,54 @@ def test_lm_task_dataloader(dataset_uri, tiny_gpt2_tokenizer):
     min_idx = min(batch['continuation_indices'][0]).item()
     max_idx = max(batch['continuation_indices'][0]).item()
     assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + 1]) == ' glen'
+    assert tokenizer.decode(batch['input_ids'][0][0:min_idx]).startswith('</s>')
+    assert tokenizer.decode(batch['input_ids'][0][0:min_idx]).count('</s>') == 1
+
+
+@pytest.mark.parametrize('dataset_uri', ['piqa_small.jsonl'])
+@pytest.mark.parametrize('num_fewshot', [0, 1])
+def test_mc_task_dataloader_opt_tokenizer(dataset_uri, num_fewshot):
+    local_data = os.path.join(os.path.dirname(__file__), 'local_data')
+
+    tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m', use_fast=False)
+
+    dataset_uri = f'{local_data}/{dataset_uri}'
+    batch_size = 2
+    seqlen = 2048
+    dl = get_icl_task_dataloader('multiple_choice',
+                                 dataset_uri,
+                                 tokenizer,
+                                 batch_size,
+                                 max_seq_len=seqlen,
+                                 pad_tok_id=tokenizer.eos_token_id,
+                                 num_fewshot=num_fewshot,
+                                 prompt_string='',
+                                 example_delimiter='\n',
+                                 continuation_delimiter=': ')
+
+    assert isinstance(dl.dataloader, DataLoader)  # pyright
+    batch = next(dl.dataloader._get_iterator())
+
+    choices_per_question = 2
+    assert 'input_ids' in batch
+    assert tuple(batch['input_ids'].shape) == (batch_size, seqlen)
+    assert 'attention_mask' in batch
+    assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen)
+    assert 'continuation_indices' in batch
+    assert isinstance(batch['continuation_indices'], list) and len(batch['continuation_indices']) == batch_size
+    assert 'mode' in batch
+    assert batch['mode'] == 'icl_task'
+    assert 'gold_indices' in batch
+    assert isinstance(batch['gold_indices'], list) and len(batch['gold_indices']) == batch_size // choices_per_question
+    assert 'choice_groupings' in batch
+    assert isinstance(batch['choice_groupings'], list) and len(
+        batch['choice_groupings']) == batch_size // choices_per_question
+
+    min_idx = min(batch['continuation_indices'][0]).item()
+    max_idx = max(batch['continuation_indices'][0]).item()
+    assert tokenizer.decode(batch['input_ids'][0][min_idx:max_idx + 1]) == ': Pour it onto a plate'
+    assert tokenizer.decode(batch['input_ids'][0][0:min_idx]).startswith('</s>')
+    assert tokenizer.decode(batch['input_ids'][0][0:min_idx]).count('</s>') == 1
 
 
 @pytest.mark.parametrize('dataset_uri', ['piqa_small.jsonl'])
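
Reviewer note, not part of the patch: the fix hinges on OPT's tokenizer prepending its BOS token '</s>' (id 2) to every encoding, including the empty string, which is why the preamble is now the only piece tokenized with special tokens enabled. Below is a minimal sketch of that behavior, assuming `transformers` is installed and using the same facebook/opt-125m checkpoint as the new tests; the sample strings are illustrative only.

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m', use_fast=False)

    # An "empty" preamble is not a 0-length list: OPT prepends BOS ('</s>', id 2).
    assert tokenizer('')['input_ids'] == [2]

    # Without add_special_tokens=False, each separately tokenized piece would
    # carry its own BOS, scattering '</s>' tokens through the concatenated example.
    assert tokenizer(' glen')['input_ids'][0] == 2

    # With the fix, only the preamble keeps special tokens, so the assembled
    # example starts with exactly one '</s>' -- which is what the new
    # startswith/count assertions in the tests check.
    ids = (tokenizer('')['input_ids'] +
           tokenizer('Some context', add_special_tokens=False)['input_ids'] +
           tokenizer(' glen', add_special_tokens=False)['input_ids'])
    assert tokenizer.decode(ids).startswith('</s>')
    assert tokenizer.decode(ids).count('</s>') == 1

The tests pin use_fast=False, presumably because the fast OPT tokenizer did not consistently prepend BOS at the time this patch was written.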