Fix ICL race conditions #1978

Merged: 13 commits, Feb 21, 2023
55 changes: 35 additions & 20 deletions composer/datasets/in_context_learning_evaluation.py
@@ -97,16 +97,16 @@ def __init__(
prompt_string: str,
example_delimiter: str,
continuation_delimiter: str,
-destination_path: str = 'icl_lm_task.jsonl',
+destination_path: str,
):
try:
from datasets import load_dataset # pyright: ignore [reportGeneralTypeIssues]
except ImportError as e:
raise MissingConditionalImportError(extra_deps_group='nlp',
conda_package='datasets',
conda_channel='conda-forge') from e

-get_file(dataset_uri, destination_path, overwrite=True)
+with dist.local_rank_zero_download_and_wait(destination_path):
+    get_file(dataset_uri, destination_path, overwrite=True)
dataset = load_dataset('json', data_files=destination_path, split='train', streaming=False)
self.samples = list(
dataset.map(lambda examples: {
@@ -238,7 +238,7 @@ def __init__(
prompt_string: str,
example_delimiter: str,
continuation_delimiter: str,
-destination_path: str = 'icl_mc_task.jsonl',
+destination_path: str,
):
try:
from datasets import load_dataset # pyright: ignore [reportGeneralTypeIssues]
@@ -247,7 +247,8 @@ def __init__(
conda_package='datasets',
conda_channel='conda-forge') from e

-get_file(dataset_uri, destination_path, overwrite=True)
+with dist.local_rank_zero_download_and_wait(destination_path):
+    get_file(dataset_uri, destination_path, overwrite=True)
dataset = load_dataset('json', data_files=destination_path, split='train', streaming=False)
self.samples = list(
dataset.map(lambda examples: {
@@ -367,16 +368,17 @@ def split_batch(self, batch: Any, microbatch_size: int):


def get_icl_task_dataloader(
-icl_task_type: str,
-dataset_uri: str,
-tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast],
-batch_size: int,
-max_seq_len: int,
-pad_tok_id: int,
-num_fewshot: int,
-prompt_string: str, # e.g. 'translate english to french:'
-example_delimiter: str, # e.g. '\n'
-continuation_delimiter: str, # e.g. ''
+icl_task_type: str,
+dataset_uri: str,
+tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast],
+batch_size: int,
+max_seq_len: int,
+pad_tok_id: int,
+num_fewshot: int,
+prompt_string: str, # e.g. 'translate english to french:'
+example_delimiter: str, # e.g. '\n'
+continuation_delimiter: str, # e.g. ''
+destination_path: str,
) -> DataSpec:
"""This constructs a dataloader capable of evaluating LLMs on in-context learning language modeling tasks, for example LAMBADA. An example usage is below:

@@ -422,14 +424,27 @@ def get_icl_task_dataloader(
"""

if icl_task_type == 'multiple_choice':
-dataset = InContextLearningMultipleChoiceTaskDataset(dataset_uri, tokenizer, max_seq_len, pad_tok_id,
-num_fewshot, prompt_string, example_delimiter,
-continuation_delimiter)
+dataset = InContextLearningMultipleChoiceTaskDataset(dataset_uri,
+tokenizer,
+max_seq_len,
+pad_tok_id,
+num_fewshot,
+prompt_string,
+example_delimiter,
+continuation_delimiter,
+destination_path=destination_path)
batch_size = max(dataset.num_choices, batch_size)
effective_batchsize = batch_size // dataset.num_choices
elif icl_task_type == 'language_modeling':
-dataset = InContextLearningLMTaskDataset(dataset_uri, tokenizer, max_seq_len, pad_tok_id, num_fewshot,
-prompt_string, example_delimiter, continuation_delimiter)
+dataset = InContextLearningLMTaskDataset(dataset_uri,
+tokenizer,
+max_seq_len,
+pad_tok_id,
+num_fewshot,
+prompt_string,
+example_delimiter,
+continuation_delimiter,
+destination_path=destination_path)
effective_batchsize = batch_size
else:
raise Exception(f'Unrecognized ICL task type: {icl_task_type}')
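With this change, get_icl_task_dataloader requires an explicit destination_path, so each run controls where the dataset copy is written instead of sharing a default filename. The following is a minimal sketch of a call under the new signature; the dataset URI, tokenizer choice, and paths are illustrative placeholders, not values from this PR:

import transformers

from composer.datasets.in_context_learning_evaluation import get_icl_task_dataloader

tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')  # placeholder tokenizer

dl = get_icl_task_dataloader(
    icl_task_type='language_modeling',
    dataset_uri='/path/to/lambada_small.jsonl',  # placeholder: local path or remote URI
    tokenizer=tokenizer,
    batch_size=8,
    max_seq_len=1024,
    pad_tok_id=tokenizer.eos_token_id,  # gpt2 has no pad token; eos is a common stand-in
    num_fewshot=0,
    prompt_string='',
    example_delimiter='\n',
    continuation_delimiter='',
    destination_path='/tmp/icl.jsonl',  # placeholder: per-run download target
)
# dl is a DataSpec wrapping a torch DataLoader.

For 'multiple_choice' tasks, each sample expands into dataset.num_choices sequences, which is why the diff above computes effective_batchsize = batch_size // dataset.num_choices.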
21 changes: 21 additions & 0 deletions composer/utils/dist.py
@@ -35,6 +35,7 @@
import datetime
import logging
import os
+import time
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, TypeVar, Union, cast

@@ -454,6 +455,26 @@ def get_sampler(dataset: torch.utils.data.Dataset, *, drop_last: bool = False, s
)


+@contextmanager
+def local_rank_zero_download_and_wait(expected_file_path: str):
+    """Context manager to wait for a file to exist on all ranks except local rank zero.
+
+    It is expected that the file will be created by local rank zero. This function is useful
+    as an alternative to ``run_local_rank_zero_first`` when downloading a file, because it does
+    not require dist to be initialized. It only requires that the ``LOCAL_RANK`` environment variable
+    is set. If dist is initialized, you probably want to use ``run_local_rank_zero_first`` instead.
+
+    Args:
+        expected_file_path (str): The file to wait for existence of
+    """
+    local_rank = get_local_rank()
+    if local_rank != 0:
+        while not os.path.exists(expected_file_path):
+            time.sleep(0.1)
+
+    yield


@contextmanager
def run_local_rank_zero_first():
"""Context manager to hold all non-zero ranks until rank zero completes.
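To make the new helper's behavior concrete, here is a minimal usage sketch of local_rank_zero_download_and_wait, mirroring the call sites in in_context_learning_evaluation.py above; the remote URI and local path are placeholders, and the get_file import location is assumed:

from composer.utils import dist, get_file  # import locations assumed

remote_uri = 's3://my-bucket/dataset.jsonl'  # placeholder
local_path = '/tmp/dataset.jsonl'  # placeholder

# Local rank zero enters the body immediately and creates local_path; every
# other local rank polls for the file's existence before running the body.
# Only the LOCAL_RANK environment variable is required; torch.distributed
# does not need to be initialized.
with dist.local_rank_zero_download_and_wait(local_path):
    get_file(remote_uri, local_path, overwrite=True)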
32 changes: 20 additions & 12 deletions tests/datasets/test_in_context_learning_datasets.py
@@ -39,7 +39,7 @@ def test_batch_padding_logic(tiny_gpt2_tokenizer):


@pytest.mark.parametrize('dataset_uri', ['lambada_small.jsonl'])
-def test_lm_task_dataloader(dataset_uri, tiny_gpt2_tokenizer):
+def test_lm_task_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path):
local_data = os.path.join(os.path.dirname(__file__), 'local_data')

tokenizer = tiny_gpt2_tokenizer
@@ -55,7 +55,8 @@ def test_lm_task_dataloader(dataset_uri, tiny_gpt2_tokenizer):
num_fewshot=0,
prompt_string='',
example_delimiter='\n',
-continuation_delimiter='')
+continuation_delimiter='',
+destination_path=str(tmp_path / 'icl.jsonl'))

assert isinstance(dl.dataloader, DataLoader) # pyright
batch = next(dl.dataloader._get_iterator())
Expand All @@ -75,7 +76,7 @@ def test_lm_task_dataloader(dataset_uri, tiny_gpt2_tokenizer):

@pytest.mark.parametrize('dataset_uri', ['lambada_small.jsonl'])
@pytest.mark.parametrize('num_fewshot', [0, 1])
-def test_lm_task_dataloader_opt_tokenizer(dataset_uri, num_fewshot):
+def test_lm_task_dataloader_opt_tokenizer(dataset_uri, num_fewshot, tmp_path):
local_data = os.path.join(os.path.dirname(__file__), 'local_data')

tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m', use_fast=False)
@@ -91,7 +92,8 @@ def test_lm_task_dataloader_opt_tokenizer(dataset_uri, num_fewshot):
num_fewshot=num_fewshot,
prompt_string='',
example_delimiter='\n',
-continuation_delimiter='')
+continuation_delimiter='',
+destination_path=str(tmp_path / 'icl.jsonl'))

assert isinstance(dl.dataloader, DataLoader) # pyright
batch = next(dl.dataloader._get_iterator())
@@ -113,7 +115,7 @@ def test_lm_task_dataloader_opt_tokenizer(dataset_uri, num_fewshot):

@pytest.mark.parametrize('dataset_uri', ['piqa_small.jsonl'])
@pytest.mark.parametrize('num_fewshot', [0, 1])
-def test_mc_task_dataloader_opt_tokenizer(dataset_uri, num_fewshot):
+def test_mc_task_dataloader_opt_tokenizer(dataset_uri, num_fewshot, tmp_path):
local_data = os.path.join(os.path.dirname(__file__), 'local_data')

tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m', use_fast=False)
@@ -130,7 +132,8 @@ def test_mc_task_dataloader_opt_tokenizer(dataset_uri, num_fewshot):
num_fewshot=num_fewshot,
prompt_string='',
example_delimiter='\n',
-continuation_delimiter=': ')
+continuation_delimiter=': ',
+destination_path=str(tmp_path / 'icl.jsonl'))

assert isinstance(dl.dataloader, DataLoader) # pyright
batch = next(dl.dataloader._get_iterator())
@@ -158,7 +161,7 @@ def test_mc_task_dataloader_opt_tokenizer(dataset_uri, num_fewshot):


@pytest.mark.parametrize('dataset_uri', ['piqa_small.jsonl'])
-def test_mc_task_dataloader(dataset_uri, tiny_gpt2_tokenizer):
+def test_mc_task_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path):
local_data = os.path.join(os.path.dirname(__file__), 'local_data')

tokenizer = tiny_gpt2_tokenizer
@@ -174,7 +177,8 @@ def test_mc_task_dataloader(dataset_uri, tiny_gpt2_tokenizer):
num_fewshot=1,
prompt_string='',
example_delimiter='\n',
-continuation_delimiter=': ')
+continuation_delimiter=': ',
+destination_path=str(tmp_path / 'icl.jsonl'))

assert isinstance(dl.dataloader, DataLoader) # pyright
batch = next(dl.dataloader._get_iterator())
@@ -202,7 +206,7 @@ def test_mc_task_dataloader(dataset_uri, tiny_gpt2_tokenizer):
@pytest.mark.parametrize('dataset_uri', ['lambada_small.jsonl'])
@pytest.mark.parametrize('num_fewshot', [0, 5])
@device('gpu')
-def test_lm_task_evaluation(device, dataset_uri, num_fewshot, tiny_gpt2_tokenizer):
+def test_lm_task_evaluation(device, dataset_uri, num_fewshot, tiny_gpt2_tokenizer, tmp_path):
pytest.importorskip('datasets')
in_memory_logger = InMemoryLogger() # track the logged metrics in the in_memory_logger
local_data = os.path.join(os.path.dirname(__file__), 'local_data')
@@ -217,7 +221,9 @@ def test_lm_task_evaluation(device, dataset_uri, num_fewshot, tiny_gpt2_tokenize
num_fewshot=num_fewshot,
prompt_string='',
example_delimiter='\n',
-continuation_delimiter='')
+continuation_delimiter='',
+destination_path=str(tmp_path / 'icl.jsonl'))
+
evaluator = Evaluator(label='lambada', dataloader=dl, metric_names=['InContextLearningLMAccuracy'])
model = create_gpt2(use_pretrained=False, pretrained_model_name='EleutherAI/gpt-neo-125M')
model.add_eval_metrics(evaluator)
@@ -230,7 +236,7 @@ def test_lm_task_evaluation(device, dataset_uri, num_fewshot, tiny_gpt2_tokenize
@pytest.mark.parametrize('dataset_uri', ['piqa_small.jsonl', 'hellaswag_small.jsonl'])
@device('gpu')
@pytest.mark.parametrize('num_fewshot', [0, 5])
-def test_mc_task_evaluation(device, num_fewshot, dataset_uri, tiny_gpt2_tokenizer):
+def test_mc_task_evaluation(device, num_fewshot, dataset_uri, tiny_gpt2_tokenizer, tmp_path):
pytest.importorskip('datasets')
in_memory_logger = InMemoryLogger() # track the logged metrics in the in_memory_logger
local_data = os.path.join(os.path.dirname(__file__), 'local_data')
@@ -245,7 +251,9 @@ def test_mc_task_evaluation(device, num_fewshot, dataset_uri, tiny_gpt2_tokenize
num_fewshot=num_fewshot,
prompt_string='',
example_delimiter='\n',
-continuation_delimiter=': ')
+continuation_delimiter=': ',
+destination_path=str(tmp_path / 'icl.jsonl'))
+
evaluator = Evaluator(label='lambada', dataloader=dl, metric_names=['InContextLearningMultipleChoiceAccuracy'])
model = create_gpt2(use_pretrained=False, pretrained_model_name='EleutherAI/gpt-neo-125M')
model.add_eval_metrics(evaluator)