# adapted from https://github.com/CeeZh/LLoVi
import os
from pathlib import Path
from util import *
from eval import *
from dataset import get_dataset
from prompts import PromptFactory
from tqdm import tqdm
from pprint import pprint
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from util import loglikelihood_classifier
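
# A minimal reference sketch (an assumption, not the repository's implementation) of what
# `loglikelihood_classifier` imported from util.py is expected to compute: a per-option
# cross-entropy over the unmasked answer tokens, so that `loss.argmin()` in launch() picks
# the option to which the LLM assigns the highest log-likelihood. The helper name and the
# mean-vs-sum reduction are illustrative only; the real logic lives in util.py and this
# function is never called.
def _loglikelihood_classifier_sketch(logits, labels):
    # logits: [num_options, seq_len, vocab]; labels: [num_options, seq_len] with -100 on non-answer tokens
    shift_logits = logits[:, :-1, :]                      # position t predicts token t+1
    shift_labels = labels[:, 1:].to(shift_logits.device)
    ce = torch.nn.functional.cross_entropy(
        shift_logits.transpose(1, 2), shift_labels,
        ignore_index=-100, reduction='none')              # [num_options, seq_len-1]
    mask = (shift_labels != -100).float()
    return (ce * mask).sum(dim=1) / mask.sum(dim=1)       # mean CE per option
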
def launch():
    args = parse_args()
    pprint(args)
    # output
    makedir(args.output_base_path)
    output_path = os.path.join(args.output_base_path, args.output_filename)
    # resume
    processed = {}
    if not args.start_from_scratch and os.path.exists(output_path):
        processed = load_json(output_path)
        if 'data' in processed:
            processed = processed['data']
    # get input
    quids_to_exclude = set(list(processed.keys()))
    dataset = get_dataset(args, quids_to_exclude=quids_to_exclude, num_examples_to_run=args.num_examples_to_run)
    # configure prompt
    prompter = PromptFactory().get(args.prompt_type)
    # 4-bit quantization for mixtral to manage gpu memory
    if 'mixtral' in args.model.lower():
        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
    else:
        quantization_config = None
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # get LLM
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        device_map="auto",
        quantization_config=quantization_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
    tokenizer.padding_side = "left"
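    # left padding keeps every option's answer tokens at the right end of the padded batch,
    # so the masking labels[idx, :-length] = -100 below leaves exactly the answer tokens.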
    # answer
    pbar = tqdm(total=len(dataset))
    correct, total = 0, 0
    for i, item in enumerate(dataset):
        item['narration'] = item['narration'].replace('\n\n', ' ').replace('\n', ' ')
        del item['raw_naration']  # drop the raw field; only the cleaned narration is used and saved
        # build one prompt per answer option: shared question head + per-option answer tail
        prompt_head = prompter.fill_each(**item, prompt_idx=0)
        prompt_tails = [prompter.fill_each(**item, prompt_idx=1, answer_id=op_i,
                                           answer=item[f'option{op_i}']) for op_i in ['A', 'B', 'C', 'D', 'E']]
        prompts = [prompt_head + pt_i for pt_i in prompt_tails]
        a_prompt_tokens = tokenizer(prompt_tails)
        a_prompt_lengths = [len(x) - 1 for x in a_prompt_tokens['input_ids']]  # to filter only answer tokens
        qa_prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True)
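        # a_prompt_lengths counts the tokens of each answer tail minus one, presumably to drop
        # the special token (e.g. BOS) the tokenizer prepends, so only true answer tokens stay
        # unmasked when the labels are built below.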
        # for debugging
        if i == 0:
            if 'quid' in item:
                print(item['quid'])
            else:
                print(item['uid'])
            print(prompt_head, prompt_tails)
        # feeding all options as a batch
        labels = qa_prompt_tokens['input_ids'].clone()
        for idx, length in enumerate(a_prompt_lengths):
            labels[idx, :-length] = -100
        model_inputs = {x: y.to(device) for x, y in qa_prompt_tokens.items()}
        with torch.no_grad():
            model_outputs = model(**model_inputs, labels=labels).logits.detach()
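        # everything before an option's answer span (question, left padding) is masked with -100,
        # which the loss computation ignores; only the returned logits are used, the model's
        # internal loss is discarded.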
        # manage gpu memory by using bs=1 (feeding each option separately)
        '''max_tk = 4096
        labels = qa_prompt_tokens['input_ids'].clone()[:, -max_tk:]
        # print(qa_prompt_tokens['input_ids'].shape, labels.shape)
        for idx, length in enumerate(a_prompt_lengths):
            labels[idx, :-length] = -100
        model_outputs = []
        for ch_i in range(len(prompts)):
            model_inputs = {x: y[ch_i:ch_i+1, -max_tk:].to(device) for x, y in qa_prompt_tokens.items()}
            with torch.no_grad():
                model_outputs.append(model(**model_inputs, labels=labels[ch_i:ch_i+1]).logits.detach())
        model_outputs = torch.cat(model_outputs, dim=0)
        del qa_prompt_tokens, model_inputs'''
        # log-likelihood classifier
        loss = loglikelihood_classifier(model_outputs, labels)
        # select the most probable option (argmin of the per-option CE loss)
        pred = loss.argmin().item()
        correct += (item['truth'] == pred)
        total += 1
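        # item['truth'] is assumed to hold the ground-truth option index (0-4, in the same
        # A-E order as prompt_tails), so the comparison above directly accumulates accuracy.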
        ukey_name = 'quid' if 'quid' in item else 'uid'
        ukey = item[ukey_name]
        processed[ukey] = item
        processed[ukey]['prompt_template'] = prompter.get_template_str()
        processed[ukey]['pred'] = pred
        if i % args.save_every == 0:
            save_json(processed, output_path)
        pbar.update(1)
        torch.cuda.empty_cache()
    if total != 0:
        print(f"'acc': {correct/total}, 'num_corrects': {correct}, 'num_total': {total}, 'num_valids': {total}")
    # incorporate backup predictions for items without a valid prediction
    if len(args.backup_pred_path) > 0:
        backup = load_json(args.backup_pred_path)
        if 'data' in backup:
            backup = backup['data']
        for uid in processed:
            if processed[uid]['pred'] == -1:
                processed[uid]['pred'] = backup[uid]['pred']
    # eval
    if not args.disable_eval:
        if args.task == 'qa':
            if args.dataset == 'egoschema':
                processed = eval_qa_egoschema(processed)
            elif args.dataset in ['nextqa', 'intentqa', 'nextgqa']:
                processed = eval_qa_nextqa(args.anno_path, processed)
        elif args.task == 'gqa':
            if args.dataset == 'nextgqa':
                pred_qa_path = args.nextgqa_pred_qa_path if len(args.nextgqa_pred_qa_path) > 0 else None
                processed = eval_gqa(args.nextgqa_gt_ground_path, processed, pred_qa_path=pred_qa_path)
        elif args.task == 'sum':
            processed, sum_data = eval_sum(processed)
            save_json(sum_data, f'{Path(output_path).parent / Path(output_path).stem}_data.json')
    save_json(processed, output_path)


if __name__ == '__main__':
    launch()