train_query.py
import json
import argparse
import torch
import os
import random
import numpy as np
import requests
import logging
import math
import copy
import string
import faiss

from time import time
from tqdm import tqdm
from densephrases.utils.squad_utils import get_question_dataloader
from densephrases.utils.single_utils import load_encoder
from densephrases.utils.open_utils import load_phrase_index, get_query2vec, load_qa_pairs
from densephrases.utils.eval_utils import drqa_exact_match_score, drqa_regex_match_score, \
    drqa_metric_max_over_ground_truths
from eval_phrase_retrieval import evaluate
from densephrases import Options

from transformers import (
    AdamW,
    get_linear_schedule_with_warmup,
)

# `amp` is referenced below only when args.fp16 is set and comes from NVIDIA apex
# (amp.scale_loss / amp.master_params), which is optional, so guard the import.
try:
    from apex import amp
except ImportError:
    amp = None

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)
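

# Query-side fine-tuning of the DensePhrases question encoder: a fixed copy of the
# pretrained encoder produces queries for the phrase index (MIPS), while a trainable
# copy is updated so that its query vectors rank gold phrases/documents higher among
# the retrieved top-k candidates.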
def train_query_encoder(args, mips=None):
    # Freeze one for MIPS
    device = 'cuda' if args.cuda else 'cpu'
    logger.info("Loading pretrained encoder: this one is for MIPS (fixed)")
    pretrained_encoder, tokenizer, _ = load_encoder(device, args)

    # Train a copy of it
    logger.info("Copying target encoder")
    target_encoder = copy.deepcopy(pretrained_encoder)

    # MIPS
    if mips is None:
        mips = load_phrase_index(args)

    # Optimizer setting
    def is_train_param(name):
        if name.endswith(".embeddings.word_embeddings.weight"):
            logger.info(f'freezing {name}')
            return False
        return True

    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [{
            "params": [
                p for n, p in target_encoder.named_parameters()
                if not any(nd in n for nd in no_decay) and is_train_param(n)
            ],
            "weight_decay": 0.01,
        }, {
            "params": [
                p for n, p in target_encoder.named_parameters()
                if any(nd in n for nd in no_decay) and is_train_param(n)
            ],
            "weight_decay": 0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    step_per_epoch = math.ceil(len(load_qa_pairs(args.train_path, args)[1]) / args.per_gpu_train_batch_size)
    t_total = int(step_per_epoch // args.gradient_accumulation_steps * args.num_train_epochs)
    logger.info(f"Train for {t_total} iterations")
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )
    eval_steps = math.ceil(len(load_qa_pairs(args.dev_path, args)[1]) / args.eval_batch_size)
    logger.info(f"Test takes {eval_steps} iterations")

    # Train arguments
    args.per_gpu_train_batch_size = int(args.per_gpu_train_batch_size / args.gradient_accumulation_steps)
    best_acc = -1000.0
    for ep_idx in range(int(args.num_train_epochs)):

        # Training
        total_loss = 0.0
        total_accs = []
        total_accs_k = []

        # Load training dataset
        q_ids, questions, answers, titles = load_qa_pairs(args.train_path, args, shuffle=True)
        pbar = tqdm(get_top_phrases(
            mips, q_ids, questions, answers, titles, pretrained_encoder, tokenizer,
            args.per_gpu_train_batch_size, args)
        )

        for step_idx, (q_ids, questions, answers, titles, outs) in enumerate(pbar):
            train_dataloader, _, _ = get_question_dataloader(
                questions, tokenizer, args.max_query_length, batch_size=args.per_gpu_train_batch_size
            )
            svs, evs, tgts, p_tgts = annotate_phrase_vecs(mips, q_ids, questions, answers, titles, outs, args)

            target_encoder.train()
            svs_t = torch.Tensor(svs).to(device)
            evs_t = torch.Tensor(evs).to(device)
            tgts_t = [torch.Tensor([tgt_ for tgt_ in tgt if tgt_ is not None]).to(device) for tgt in tgts]
            p_tgts_t = [torch.Tensor([tgt_ for tgt_ in tgt if tgt_ is not None]).to(device) for tgt in p_tgts]

            # Train query encoder
            assert len(train_dataloader) == 1
            for batch in train_dataloader:
                batch = tuple(t.to(device) for t in batch)
                loss, accs = target_encoder.train_query(
                    input_ids_=batch[0], attention_mask_=batch[1], token_type_ids_=batch[2],
                    start_vecs=svs_t,
                    end_vecs=evs_t,
                    targets=tgts_t,
                    p_targets=p_tgts_t,
                )

                # Optimize, get acc and report
                if loss is not None:
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps
                    if args.fp16:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()

                    total_loss += loss.mean().item()
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(target_encoder.parameters(), args.max_grad_norm)

                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    target_encoder.zero_grad()

                    pbar.set_description(
                        f"Ep {ep_idx+1} Tr loss: {loss.mean().item():.2f}, acc: {sum(accs)/len(accs):.3f}"
                    )

                if accs is not None:
                    total_accs += accs
                    total_accs_k += [len(tgt) > 0 for tgt in tgts_t]
                else:
                    total_accs += [0.0]*len(tgts_t)
                    total_accs_k += [0.0]*len(tgts_t)

        step_idx += 1
        logger.info(
            f"Avg train loss ({step_idx} iterations): {total_loss/step_idx:.2f} | train " +
            f"acc@1: {sum(total_accs)/len(total_accs):.3f} | acc@{args.top_k}: {sum(total_accs_k)/len(total_accs_k):.3f}"
        )

        # Evaluation
        new_args = copy.deepcopy(args)
        new_args.top_k = 10
        new_args.save_pred = False
        new_args.test_path = args.dev_path
        dev_em, dev_f1, dev_emk, dev_f1k = evaluate(new_args, mips, target_encoder, tokenizer)
logger.info(f"Develoment set acc@1: {dev_em:.3f}, f1@1: {dev_f1:.3f}")
        # Save best model
        if dev_em > best_acc:
            best_acc = dev_em
            save_path = args.output_dir
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            target_encoder.save_pretrained(save_path)
            logger.info(f"Saved best model with acc {best_acc:.3f} into {save_path}")

        if (ep_idx + 1) % 1 == 0:
            logger.info('Updating pretrained encoder')
            pretrained_encoder = copy.deepcopy(target_encoder)

    print()
    logger.info(f"Best model has acc {best_acc:.3f} saved as {save_path}")
def get_top_phrases(mips, q_ids, questions, answers, titles, query_encoder, tokenizer, batch_size, args):
    # Search
    step = batch_size
    phrase_idxs = []
    search_fn = mips.search
    query2vec = get_query2vec(
        query_encoder=query_encoder, tokenizer=tokenizer, args=args, batch_size=batch_size
    )
    for q_idx in tqdm(range(0, len(questions), step)):
        outs = query2vec(questions[q_idx:q_idx+step])
        start = np.concatenate([out[0] for out in outs], 0)
        end = np.concatenate([out[1] for out in outs], 0)
        query_vec = np.concatenate([start, end], 1)

        outs = search_fn(
            query_vec,
            q_texts=questions[q_idx:q_idx+step], nprobe=args.nprobe,
            top_k=args.top_k, return_idxs=True,
            max_answer_length=args.max_answer_length, aggregate=args.aggregate, agg_strat=args.agg_strat,
        )
        yield (
            q_ids[q_idx:q_idx+step], questions[q_idx:q_idx+step], answers[q_idx:q_idx+step],
            titles[q_idx:q_idx+step], outs
        )
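

# annotate_phrase_vecs: pad each retrieved phrase list to 2*top_k with dummy entries,
# stack the start/end vectors, and derive weak supervision from the retrieved top-k:
# indices of phrases whose answer string matches a gold answer (L_phrase) and/or whose
# document title matches a gold title (L_doc), depending on args.label_strat.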
def annotate_phrase_vecs(mips, q_ids, questions, answers, titles, phrase_groups, args):
    assert mips is not None
    batch_size = len(answers)

    # Phrase groups are in size of [batch, top_k, values]
    # phrase_groups = [[(
    #     out_['doc_idx'], out_['start_idx'], out_['end_idx'], out_['answer'],
    #     out_['start_vec'], out_['end_vec'], out_['context'], out_['title'])
    #     for out_ in out] for out in outs
    # ]
    dummy_group = {
        'doc_idx': -1,
        'start_idx': 0, 'end_idx': 0,
        'answer': '',
        'start_vec': np.zeros(768),
        'end_vec': np.zeros(768),
        'context': '', 'title': ['']
    }

    # Pad phrase groups (two separate top-k coming from start/end, so pad with top_k*2)
    for b_idx, phrase_idx in enumerate(phrase_groups):
        while len(phrase_groups[b_idx]) < args.top_k*2:
            phrase_groups[b_idx].append(dummy_group)
        assert len(phrase_groups[b_idx]) == args.top_k*2

    # Flatten phrase groups
    flat_phrase_groups = [phrase for phrase_group in phrase_groups for phrase in phrase_group]
    doc_idxs = [int(phrase_group['doc_idx']) for phrase_group in flat_phrase_groups]
    start_vecs = [phrase_group['start_vec'] for phrase_group in flat_phrase_groups]
    end_vecs = [phrase_group['end_vec'] for phrase_group in flat_phrase_groups]

    # Stack vectors
    start_vecs = np.stack(start_vecs)
    end_vecs = np.stack(end_vecs)
    zero_mask = np.array([[1] if doc_idx >= 0 else [0] for doc_idx in doc_idxs])
    start_vecs = start_vecs * zero_mask
    end_vecs = end_vecs * zero_mask

    # Reshape
    start_vecs = np.reshape(start_vecs, (batch_size, args.top_k*2, -1))
    end_vecs = np.reshape(end_vecs, (batch_size, args.top_k*2, -1))

    # Dummy targets
    targets = [[None for phrase in phrase_group] for phrase_group in phrase_groups]
    p_targets = [[None for phrase in phrase_group] for phrase_group in phrase_groups]

    # TODO: implement dynamic label_strategy based on the task name (label_strat = dynamic)
    # Annotate for L_phrase
    if 'phrase' in args.label_strat.split(','):
        match_fns = [
            drqa_regex_match_score if args.regex or ('trec' in q_id.lower()) else drqa_exact_match_score
            for q_id in q_ids
        ]
        targets = [
            [drqa_metric_max_over_ground_truths(match_fn, phrase['answer'], answer_set) for phrase in phrase_group]
            for phrase_group, answer_set, match_fn in zip(phrase_groups, answers, match_fns)
        ]
        targets = [[ii if val else None for ii, val in enumerate(target)] for target in targets]

    # Annotate for L_doc
    if 'doc' in args.label_strat.split(','):
        p_targets = [
            [any(phrase['title'][0].lower() == tit.lower() for tit in title) for phrase in phrase_group]
            for phrase_group, title in zip(phrase_groups, titles)
        ]
        p_targets = [[ii if val else None for ii, val in enumerate(target)] for target in p_targets]

    return start_vecs, end_vecs, targets, p_targets
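

# Example invocation (hypothetical flag names; assumes the densephrases Options parser
# exposes each args.* field used above as a matching command-line flag):
#   python train_query.py --run_mode train_query --train_path <train-qa-pairs> \
#       --dev_path <dev-qa-pairs> --output_dir <save-dir>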
if __name__ == '__main__':
    # See options in densephrases.options
    options = Options()
    options.add_model_options()
    options.add_index_options()
    options.add_retrieval_options()
    options.add_data_options()
    options.add_qsft_options()
    args = options.parse()

    # Seed for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    if args.run_mode == 'train_query':
        # Train
        mips = load_phrase_index(args)
        train_query_encoder(args, mips)

        # Eval
        args.load_dir = args.output_dir
        logger.info(f"Evaluating {args.load_dir}")
        args.top_k = 10
        evaluate(args, mips)
    else:
        raise NotImplementedError