Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update support for HumanEval #2550

Merged
merged 11 commits into from
Sep 25, 2023
19 changes: 17 additions & 2 deletions composer/datasets/in_context_learning_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,7 @@ def __init__(
code_prelimiter: str,
fewshot_random_seed: int,
generations_per_sample: int,
pass_at_k: int = 1,
top_p: Optional[float] = 0.95,
top_k: Optional[int] = 40,
):
Expand All @@ -918,7 +919,15 @@ def __init__(
'test_outputs': examples['test_outputs'],
'language': examples['language'],
}))

if generations_per_sample < pass_at_k:
raise ValueError(
f'Invalid for generations_per_sample ({generations_per_sample}) to be less than pass_at_k ({pass_at_k}) for code evaluation.'
)
mcarbin marked this conversation as resolved.
Show resolved Hide resolved

self.pass_at_k = pass_at_k
self.generations_per_sample = generations_per_sample

self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
self.pad_tok_id = pad_tok_id
Expand Down Expand Up @@ -1040,10 +1049,11 @@ def collate_fn(self, data):
'test_inputs': test_inputs, # list of test inputs
'test_outputs': test_outputs, # list of test outputs
'languages': languages, # list of languages
'pass_at_k': self.pass_at_k,
'generation_length': self.max_seq_len - self.max_prompt_length,
'generation_kwargs': {
'pad_token_id': self.pad_tok_id,
'num_beams': self.generations_per_sample, # change strategy to beam search
'num_beams': 1, # single beam
mcarbin marked this conversation as resolved.
Show resolved Hide resolved
'num_return_sequences': self.generations_per_sample, # how many gens per prompt
'do_sample': True,
'top_p': self.top_p,
Expand All @@ -1062,7 +1072,7 @@ def split_batch(self, batch: Any, microbatch_size: int):
# Don't split kwargs that don't change
# Normally split torch tensors
# List split lists of strings
no_split = ['mode', 'generation_length', 'generation_kwargs']
no_split = ['mode', 'generation_length', 'pass_at_k', 'generation_kwargs']
normal_split = ['input_ids', 'attention_mask']
list_split = [
'labels', 'tests', 'canonical_solutions', 'entry_points', 'test_inputs', 'test_outputs', 'prompts',
Expand Down Expand Up @@ -1101,6 +1111,7 @@ def build_icl_dataloader(
destination_path: str,
question_prelimiter: str = '', # e.g. 'Question: '
fewshot_random_seed: int = 1234,
pass_at_k: int = 1,
generations_per_sample: int = 1,
) -> DataSpec:
if icl_task_type == 'multiple_choice':
Expand Down Expand Up @@ -1165,6 +1176,7 @@ def build_icl_dataloader(
destination_path=destination_path,
code_prelimiter=question_prelimiter,
fewshot_random_seed=fewshot_random_seed,
pass_at_k=pass_at_k,
generations_per_sample=generations_per_sample)
effective_batchsize = batch_size
else:
Expand Down Expand Up @@ -1248,6 +1260,7 @@ def get_icl_task_dataloader(
destination_path: str = '',
question_prelimiter: str = '', # e.g. 'Question: '
fewshot_random_seed: int = 1234,
pass_at_k: int = 1,
generations_per_sample: int = 1,
has_categories: bool = False) -> Union[DataSpec, Dict[str, DataSpec]]:
"""This constructs a dataloader (or dataloaders if has_categories is True) capable of evaluating LLMs on in-context learning language modeling tasks, for example LAMBADA. An example usage is below:
Expand Down Expand Up @@ -1316,6 +1329,7 @@ def get_icl_task_dataloader(
partition_uri + '_tmp',
question_prelimiter,
fewshot_random_seed,
pass_at_k,
generations_per_sample,
)
return result_dls
Expand All @@ -1334,5 +1348,6 @@ def get_icl_task_dataloader(
destination_path,
question_prelimiter,
fewshot_random_seed,
pass_at_k,
generations_per_sample,
)
33 changes: 28 additions & 5 deletions composer/metrics/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import warnings
from typing import Any, Dict, List, Mapping, Union

import numpy as np
import torch
from torch import Tensor
from torch.nn import functional as F
Expand Down Expand Up @@ -542,6 +543,18 @@ def get_client(self) -> EvalClient:
'CODE_EVAL_DEVICE to LOCAL or LAMBDA.')
return client

def estimator(self, n: int, c: int, k: int) -> float:
"""Computes the pass@k metric.

Given the number of generated samples, n, the number of correct samples, c, and the k of interest,
this function calculates pass@k as 1 - comb(n - c, k) / comb(n, k) as per the definition of
pass@k in the HumanEval paper (https://arxiv.org/abs/2107.03374) and it's associated implementation:
https://github.com/openai/human-eval.
"""
if n - c < k:
return 1.0
return 1.0 - float(np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))

def update(self, batch: Dict[str, Any], outputs: List[str], labels: List[str]):
"""Updates the pass@k accuracy of code generation.

Expand Down Expand Up @@ -569,8 +582,11 @@ def update(self, batch: Dict[str, Any], outputs: List[str], labels: List[str]):
del labels # never used
client = self.get_client()

num_beams = batch['generation_kwargs']['num_beams']
processed_outputs = [outputs[i * num_beams:(i + 1) * num_beams] for i in range(len(batch['prompts']))]
pass_at_k = batch['pass_at_k']
num_generations = batch['generation_kwargs']['num_return_sequences']
processed_outputs = [
outputs[i * num_generations:(i + 1) * num_generations] for i in range(len(batch['prompts']))
]
payloads = []
for sample_outputs, sample_prompt, test_inputs, test_outputs, entry_point, language in zip(
processed_outputs, batch['prompts'], batch['test_inputs'], batch['test_outputs'], batch['entry_points'],
Expand All @@ -595,9 +611,16 @@ def update(self, batch: Dict[str, Any], outputs: List[str], labels: List[str]):
payloads.append(prompt_payload)

results = client.invoke(payloads)
passes = sum(
[any(all(generation_payload) for generation_payload in prompt_payload) for prompt_payload in results])
self.correct += torch.tensor(float(passes))
for prompt in results:
num_correct = 0
for generation in prompt:
correct = all(generation)
if correct:
num_correct += 1

pass_at_k_rate = self.estimator(num_generations, num_correct, pass_at_k)
self.correct += torch.tensor(pass_at_k_rate)

client.close() # pyright: ignore [reportOptionalMemberAccess]

def compute(self):
Expand Down
30 changes: 30 additions & 0 deletions tests/datasets/test_in_context_learning_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,9 @@ def test_code_eval_split_batch(dataset_uri, tmp_path):
assert len(split2[k]) == 2
assert all(isinstance(val, v) for val in split1[k] + split2[k])

assert isinstance(split1['pass_at_k'], int)
assert isinstance(split2['pass_at_k'], int)

assert isinstance(split1['generation_length'], int)
assert isinstance(split2['generation_length'], int)

Expand Down Expand Up @@ -806,6 +809,33 @@ def test_code_eval_test_cases(dataset_uri, tmp_path):
assert result == eval(test_output)


@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl'])
def test_code_eval_pass_at_k_validity(dataset_uri, tmp_path):
pytest.importorskip('datasets')

local_data = os.path.join(os.path.dirname(__file__), 'local_data')

tokenizer = AutoTokenizer.from_pretrained('huggyllama/llama-7b')
dataset_uri = f'{local_data}/{dataset_uri}'
batch_size = 9
seqlen = 2048

with pytest.raises(ValueError, match=r'.* pass_at_k .*'):
get_icl_task_dataloader('code_evaluation',
dataset_uri,
tokenizer,
batch_size,
max_seq_len=seqlen,
pad_tok_id=tokenizer.eos_token_id,
num_fewshot=0,
prompt_string='',
example_delimiter='\n',
question_prelimiter='Code start: \n',
destination_path=str(tmp_path / f'icl_.jsonl'),
pass_at_k=10,
generations_per_sample=1)


@pytest.mark.parametrize('dataset_uri', ['human_eval_small.jsonl'])
@pytest.mark.parametrize('num_fewshot', [0, 1, 2, 3])
@pytest.mark.parametrize('prompt_string', ['Please code:\n', ''])
Expand Down
12 changes: 10 additions & 2 deletions tests/metrics/test_nlp_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,13 @@ def test_in_context_learning_code_eval_accuracy(monkeypatch):
languages = ['python', 'python', 'python']
monkeypatch.setenv('CODE_EVAL_DEVICE', 'LOCAL')
batch = {
# This tests deterministic beam search rather than sampling
'generation_kwargs': {
'num_beams': 2
'num_beams': 1,
'num_return_sequences': 2
},
'prompts': prompts,
'pass_at_k': 1,
'entry_points': entry_points,
'test_inputs': test_inputs,
'test_outputs': test_outputs,
Expand All @@ -264,7 +267,12 @@ def test_in_context_learning_code_eval_accuracy(monkeypatch):
metric = InContextLearningCodeEvalAccuracy()
metric.update(batch, outputs, labels)

assert metric.compute() == (2 / 3)
# pass@1 values
# program 1: 0
# program 2: 1
# program 3: .5
# mean: 0.5
assert metric.compute() == 0.5


def test_in_context_learning_mc_accuracy(tiny_gpt2_tokenizer):
Expand Down
Loading