-
Notifications
You must be signed in to change notification settings - Fork 0
/
analyze.py
89 lines (73 loc) · 2.89 KB
/
analyze.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import re
import typer
from rich.console import Console
from rich.table import Table
from rich.align import Align
from constants import Perturbation, SEPARATOR
# Module-level Rich console; main() prints the results table through it.
console = Console()
# Field patterns for one datapoint of a cleaned experiment file, compiled once
# rather than on every loop iteration. Each answer pattern captures the rest of
# the line after its label; `(.*)` (instead of `(.*?)\n`) also matches when the
# field sits on the file's final line with no trailing newline.
# NOTE(review): `\s*` can cross a line break, so a label with nothing after it
# captures the next non-blank line (same as the original regexes) — confirm
# extracted answers are never empty.
_ID_PATTERN = re.compile(r"ID:\s*(\d+)")
_CORRECT_ANSWER_PATTERN = re.compile(r">>>> Extracted Correct Answer:\s*(.*)")
_BASELINE_RESPONSE_PATTERN = re.compile(r">>>> Extracted Baseline Response:\s*(.*)")
_EXPERIMENT_RESPONSE_PATTERN = re.compile(
    r">>>> Extracted Experiment Response:\s*(.*)"
)


def _extract_field(pattern, datapoint, label):
    """Return the first capture group of the first match of ``pattern``.

    Args:
        pattern: Compiled regex with one capture group.
        datapoint: Text of a single datapoint.
        label: Human-readable field name, used only in the error message.

    Raises:
        ValueError: if ``pattern`` does not occur in ``datapoint``.
    """
    match = pattern.search(datapoint)
    if match is None:
        raise ValueError(
            f"Malformed datapoint: missing {label!r} field in:\n"
            f"{datapoint.strip()[:200]}"
        )
    return match.group(1)


def main(
    model: str = typer.Option(help="Model to use for analysis"),
    perturbation: Perturbation = typer.Option(help="Perturbation to analyze"),
):
    """Print a table comparing baseline and experiment answers for one run.

    Reads ``data/experiments/<model>/cleaned-<perturbation>.txt``, splits it
    into datapoints on ``SEPARATOR``, and extracts from each the ID, the
    correct answer, the baseline response and the experiment response. It then
    renders a Rich table with a per-row ✅/❌ for each response. The footer
    shows the percentage of exact matches against the correct answer.

    Args:
        model: Model name. If it contains "/", only the final segment is used
            as the directory name.
        perturbation: Perturbation whose cleaned results file is analyzed.

    Raises:
        FileNotFoundError: if the results file does not exist.
        ValueError: if a non-blank datapoint lacks one of the expected fields.
        typer.Exit: with code 1 if the file contains no datapoints.
    """
    # Keep only the last "/"-separated segment; identical to the old behavior
    # for ids with zero or one "/", and sensible for ids with several.
    model = model.rsplit("/", 1)[-1]
    path = "data/experiments/{}/cleaned-{}.txt".format(model, perturbation.value)

    with open(path, "r", encoding="utf-8") as file:
        datapoints = file.read().split(SEPARATOR)

    table = Table(
        title="\n\n[bold]Results for {}, {} perturbation[/bold]".format(
            model, perturbation.value
        ),
        padding=(0, 2),
        show_footer=True,
    )
    table.add_column("ID", style="cyan bold")
    table.add_column("Correct Answer", justify="center", style="magenta bold")
    table.add_column("Baseline Answer", justify="center", style="magenta bold")
    table.add_column("Baseline Correctness", justify="center")
    table.add_column("Experiment Answer", justify="center", style="magenta bold")
    table.add_column("Experiment Correctness", justify="center")

    baseline_correct_count = 0
    experiment_correct_count = 0
    total_rows = 0
    for datapoint in datapoints:
        # Splitting on SEPARATOR leaves blank chunks (e.g. the trailing "\n"
        # after the last separator); skip any whitespace-only chunk.
        if not datapoint.strip():
            continue
        idd = _extract_field(_ID_PATTERN, datapoint, "ID")
        correct_answer = _extract_field(
            _CORRECT_ANSWER_PATTERN, datapoint, "Extracted Correct Answer"
        )
        baseline_response = _extract_field(
            _BASELINE_RESPONSE_PATTERN, datapoint, "Extracted Baseline Response"
        )
        experiment_response = _extract_field(
            _EXPERIMENT_RESPONSE_PATTERN, datapoint, "Extracted Experiment Response"
        )

        # Correctness is an exact string match against the correct answer.
        baseline_correct = baseline_response == correct_answer
        experiment_correct = experiment_response == correct_answer
        if baseline_correct:
            baseline_correct_count += 1
        if experiment_correct:
            experiment_correct_count += 1
        total_rows += 1

        table.add_row(
            idd,
            correct_answer,
            baseline_response,
            "✅" if baseline_correct else "❌",
            experiment_response,
            "✅" if experiment_correct else "❌",
        )

    if total_rows == 0:
        # Without this guard the percentages below raise ZeroDivisionError.
        # markup=False so characters in the path are not parsed as Rich markup.
        console.print(f"No datapoints found in {path}", style="yellow", markup=False)
        raise typer.Exit(code=1)

    baseline_percentage = (baseline_correct_count / total_rows) * 100
    experiment_percentage = (experiment_correct_count / total_rows) * 100
    # Columns 3 and 5 are the two "Correctness" columns added above.
    table.columns[3].footer = f"{baseline_percentage:.2f}% Correct"
    table.columns[5].footer = f"{experiment_percentage:.2f}% Correct"
    console.print(Align.center(table))
# Script entry point: typer.run wraps main() as a command-line app and runs it.
if __name__ == "__main__":
    typer.run(main)