"""Verify the correctness of model predictions. """
import os
import random
import evaluate
from typing import Dict, List
from processor import CodeProcessor
from prompt import add_indent, get_entry_point
code_eval_metric = evaluate.load("code_eval")
# os.environ["HF_ALLOW_CODE_EVAL"] = "0"
os.environ["HF_ALLOW_CODE_EVAL"] = "1"


def get_valid_solutions(
    predictions: List[str],
    deduplicate: bool = False,
) -> List[str]:
    """Extract a candidate solution from each raw model prediction.

    If `deduplicate` is set, identical solutions are collapsed; note that
    this does not preserve the original ordering.
    """
    processor = CodeProcessor()
    solutions = [processor.code_extract(pred) for pred in predictions]
    if deduplicate:
        solutions = list(set(solutions))
    return solutions
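
# Usage sketch (hypothetical inputs; the raw generations are illustrative,
# and the extracted form depends on what CodeProcessor.code_extract returns):
#
#     raw = [
#         "def add(a, b):\n    return a + b\n\nExplanation: ...",
#         "def add(a, b):\n    return a + b\n\nExplanation: ...",
#     ]
#     solutions = get_valid_solutions(raw, deduplicate=True)
#     # -> one cleaned solution string instead of two duplicates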


def wrap_check(
    sample: Dict,
    solution_list: List[str],
    k: List[int],
    num_workers: int = 1,
    max_num_tests: int = 1,
    verbose: bool = False,
    exclude_suffix: bool = False,
    function_name: str = "id",
):
    """Wrap each solution with the sample's prompt (and, unless
    `exclude_suffix` is set, its suffix), assemble a check function from up
    to `max_num_tests` randomly sampled test cases, and score the solutions
    with the code_eval metric.
    """
    suffix = '' if exclude_suffix else sample['suffix']
    wrapped_solution_list = [
        f"{sample['prompt']}{solution}{suffix}".replace('\t', ' ' * 4)
        for solution in solution_list
    ]
    max_num_tests = min(len(sample["test"]), max_num_tests)
    test_case = random.sample(sample["test"], max_num_tests)
    entry_point = get_entry_point(sample, function_name)
    check_function = '\n'.join([
        sample['test_start'],
        ''.join(test_case),
        '',
        f"check({entry_point})",
    ])
    scores, outputs = code_eval_metric.compute(
        predictions=[wrapped_solution_list],
        references=[check_function],
        k=k,
        num_workers=num_workers,
    )
    if verbose:
        print(f"[predic] {wrapped_solution_list}")
        print(f"[fcheck] {check_function}")
        print(f"[scores] {scores}")
        print(f"[output] {outputs[0]}")
    return scores, outputs[0]
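
# Usage sketch (hypothetical sample; assumes the schema used above, i.e. a
# dict with "prompt", "suffix", "test", and "test_start" fields):
#
#     sample = {
#         "prompt": "def add(a, b):\n",
#         "suffix": "",
#         "test": ["    assert candidate(1, 2) == 3\n"],
#         "test_start": "def check(candidate):",
#         "entry_point": "add",
#     }
#     scores, outputs = wrap_check(sample, ["    return a + b\n"], k=[1])
#     # scores -> {"pass@1": 1.0} when the solution passes the sampled test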


def wrap_check_test(
    prompt: str,
    suffix: str,
    solution_list: List[str],
    test_start: str,
    test_case: str,
    entry_point: str,
    k: List[int] = [1],
    num_workers: int = 1,
    add_indent_test: bool = True,
    verbose: bool = False,
):
    """Like `wrap_check`, but takes the prompt, suffix, and test pieces
    directly instead of a sample dict, to check solutions against a single
    test case.
    """
    wrapped_solution_list = [
        f"{prompt}{solution}{suffix}".replace('\t', ' ' * 4)
        for solution in solution_list
    ]
    if add_indent_test:
        test_case = add_indent(test_case)
    check_function = '\n'.join([
        test_start,
        test_case,
        '',
        f"check({entry_point})",
    ])
    if verbose:
        print(f"[solution list] \n{wrapped_solution_list}")
        print(f"[check function] \n{check_function}")
    scores, outputs = code_eval_metric.compute(
        predictions=[wrapped_solution_list],
        references=[check_function],
        k=k,
        num_workers=num_workers,
    )
    if verbose:
        print(f"[scores] {scores}")
        print(f"[output] {outputs[0]}")
        print('-' * 25)
    return scores, outputs[0]
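

if __name__ == "__main__":
    # Minimal smoke test with hypothetical inputs (not part of the original
    # module). Assumes `add_indent` indents each line of `test_case` so it
    # sits inside the body of `check`.
    scores, outputs = wrap_check_test(
        prompt="def add(a, b):\n",
        suffix="",
        solution_list=["    return a + b\n"],
        test_start="def check(candidate):",
        test_case="assert candidate(1, 2) == 3",
        entry_point="add",
        verbose=True,
    )
    print(scores)  # expected: {'pass@1': 1.0}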