-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_distill.py
473 lines (435 loc) · 24.2 KB
/
model_distill.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
import torch
import torch.nn as nn
from typing import Optional, List, Union, Tuple, Dict, Any
from transformers.modeling_outputs import CausalLMOutputWithPast, SequenceClassifierOutputWithPast, \
SequenceClassifierOutput, CausalLMOutputWithCrossAttentions
from torch.nn import CrossEntropyLoss
from transformers import AutoModelForCausalLM, AutoModel
import torch.nn.functional as F
# from peft import prepare_model_for_int8_training
# from transformers import BitsAndBytesConfig
from transformers import LogitsProcessorList, TopPLogitsWarper, PreTrainedModel
from transformers import AutoTokenizer
import os
# from transformers.models.roberta.modeling_roberta import RobertaModel
from transformers import AutoConfig
from .SmallModel import MyGPT2, add_init_prompt_weight
from collections import OrderedDict
class MyPooler(nn.Module):
def __init__(self, s_hidden_size, t_hidden_size):
super().__init__()
self.dense = nn.Linear(s_hidden_size, t_hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
pooled_output = self.dense(flat_hidden_states)
pooled_output = self.activation(pooled_output)
nested_output = pooled_output.reshape(*hidden_states.shape[:-1], -1)
return nested_output
class DistillModel(PreTrainedModel):
def __init__(self, local_llm_model,
model_args,
config=None,
num_tasks=None,
is_train_llm=False,
):
super().__init__(local_llm_model.config)
self.local_llm_model = local_llm_model
self.config = self.local_llm_model.config
self.is_prompt_tuning = model_args.is_prompt_tuning
self.virtual_demo_len = model_args.virtual_demo_len
self.d_beta = model_args.d_beta
self.d_eta = model_args.d_eta
self.processors = LogitsProcessorList()
self.processors.append(TopPLogitsWarper(0.95))
self.d_tau = model_args.d_tau
self.is_contrast = model_args.is_contrast
self.is_t_demo = model_args.is_t_demo
self.model_args = model_args
self.is_train_llm = is_train_llm
self.sm_pool_name = model_args.sm_pool_name
self.is_kd = model_args.is_kd
self.is_only_kd = model_args.is_only_kd
self.loss_lambda = model_args.loss_lambda
self.is_add_model_encode_demo = model_args.is_add_model_encode_demo
self.is_fid = model_args.is_fid
self.is_shared_prompt = model_args.is_shared_prompt
# FID relevant parameters
self.demo_batch_count = model_args.demo_batch_count
# inner distill
self.is_inner_distill = model_args.is_inner_distill
self.ts_steps = model_args.ts_steps
self.add_same_task_count = model_args.add_same_task_count
self.t_demo_batch_count = model_args.t_demo_batch_count
self.kl_topK = model_args.kl_topK
self.do_train = model_args.do_train
# fake lm_head to fool generate check
# self.lm_head = self.local_llm_model.lm_head
if is_train_llm is False:
if model_args.is_prompt_tuning:
self.num_tasks = num_tasks
self.hypernet = self._init_prompt(d_len=self.virtual_demo_len, num_tasks=self.num_tasks)
else:
config = AutoConfig.from_pretrained(model_args.s_model_name_or_path)
s_tokenizer = AutoTokenizer.from_pretrained(model_args.s_model_name_or_path)
config.mask_token_id = s_tokenizer.mask_token_id
config.demo_batch_count = self.demo_batch_count
config.virtual_demo_len = self.virtual_demo_len
config.label_virtual_demo_len = self.model_args.label_virtual_demo_len
config.expand_s_attention_mask = self.model_args.expand_s_attention_mask
config.is_expand_demo_input_mask = self.model_args.is_expand_demo_input_mask
if 'bart' in model_args.s_model_name_or_path.lower():
self.hypernet = MyFiDBart.from_pretrained(model_args.s_model_name_or_path,
config=config)
self.is_s_roberta = False
elif 'roberta' in model_args.s_model_name_or_path.lower():
config.is_s_reuse_cache = model_args.is_s_reuse_cache
config.is_ed = model_args.is_ed
if self.model_args.is_fid:
config.is_decoder = True
if model_args.is_s_reuse_cache is False and model_args.is_ed is False:
config.add_cross_attention = True
self.hypernet = MyFiDRoBERTa.from_pretrained(model_args.s_model_name_or_path, config=config,
add_pooling_layer=False)
self.is_s_roberta = True
elif "gpt2" in model_args.s_model_name_or_path.lower():
self.hypernet = MyGPT2.from_pretrained(model_args.s_model_name_or_path,
config=config,
# load_in_8bit=not self.do_train
)
elif "t5" in model_args.s_model_name_or_path.lower():
# TODO: Implementation of t5 model here.
raise NotImplementedError
else:
self.hypernet = AutoModel.from_pretrained(model_args.s_model_name_or_path,
config=config,
# load_in_8bit=not self.do_train
)
self.hypernet.adap_pooler = MyPooler(self.hypernet.config.hidden_size,
self.local_llm_model.config.hidden_size) \
if self.hypernet.config.hidden_size != self.local_llm_model.config.hidden_size else None
if self.model_args.is_init_prompt_weight:
self.hypernet.init_prompt_weight = add_init_prompt_weight(self.hypernet.config, model=self.hypernet,
virtual_demo_len=self.virtual_demo_len,
model_args=model_args)
if self.model_args.s_state_dict_path is not None:
s_state_dict = torch.load(self.model_args.s_state_dict_path, map_location='cpu')
s_state_dict = {key.replace("hypernet.", "") if "hypernet" in key else key: value for key, value in
s_state_dict.items()}
self.hypernet.load_state_dict(s_state_dict)
def _init_prompt(self, d_len, num_tasks):
prompt_length = d_len * num_tasks
hypernet = nn.Embedding(d_len * num_tasks, self.local_llm_model.config.hidden_size)
if self.model_args.virtual_demo_init == "random":
hypernet.weight.data.normal_(mean=0.0, std=self.local_llm_model.config.initializer_range)
elif self.model_args.virtual_demo_init == 'vocab':
rand_id = torch.randint(100, self.config.vocab_size, (prompt_length,)).long()
rand_emb = self.embed_encode(rand_id)
hypernet = hypernet.from_pretrained(rand_emb, freeze=False)
# reshape the weight for hypernet
weight = hypernet.weight.data
weight = weight.reshape(num_tasks, -1)
hypernet = nn.Embedding(*weight.shape).from_pretrained(weight, freeze=False)
return hypernet
def get_task_prompt_embed(self, task_index):
if self.is_shared_prompt:
task_index = torch.zeros_like(task_index)
prompt_embed = self.hypernet(task_index)
prompt_embed = prompt_embed.reshape(task_index.shape[0], self.virtual_demo_len, -1)
return prompt_embed
def query_logits_distill(self,
s_logits,
t_logits,
s_input_attention_mask,
t_input_attention_mask,
):
loss_fn = nn.KLDivLoss(reduction="batchmean")
# extract the s_logits
flat_s_logits = s_logits.reshape(-1, s_logits.shape[-1]).contiguous()
flat_t_logits = t_logits.reshape(-1, t_logits.shape[-1]).contiguous()
flat_t_attention_mask = t_input_attention_mask.reshape(-1, 1).contiguous().bool()
flat_s_attention_mask = s_input_attention_mask.reshape(-1, 1).contiguous().bool()
select_s_logits = torch.masked_select(flat_s_logits, flat_s_attention_mask).reshape(-1, flat_s_logits.shape[-1])
select_t_logits = torch.masked_select(flat_t_logits, flat_t_attention_mask).reshape(-1, flat_t_logits.shape[-1])
kl_loss = loss_fn(
F.log_softmax(select_s_logits/self.model_args.temperature, dim=1),
F.softmax(select_t_logits/self.model_args.temperature, dim=1)
)
return kl_loss
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: torch.LongTensor = None,
teacher_input_ids: torch.LongTensor = None,
teacher_attention_mask: Optional[torch.Tensor] = None,
teacher_token_type_ids=None,
demo_ids=None,
demo_attention_mask=None,
num_class=-1,
label_candidates: Optional[torch.LongTensor] = None,
label_candidates_attention_mask: Optional[torch.LongTensor] = None,
label_candidates_token_type_ids: Optional[torch.LongTensor] = None,
add_input_ids=None,
add_attention_mask=None,
add_token_type_ids=None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
clf_label=None,
task_idx=None,
teacher_demo_token_type_ids=None,
is_consecutive=None,
real_num_class_mask=None,
demo_token_type_ids=None,
past_key_values=None,
s_demo_count=None,
**kwargs,
):
is_generation = kwargs.get("is_generation", False)
if clf_label is not None:
batch_size = clf_label.shape[0]
else:
batch_size = input_ids.shape[0]
if (
self.training is False and self.is_t_demo) or self.is_train_llm or self.is_kd or self.model_args.is_meta_gradient_match or self.model_args.is_query_state_match or self.model_args.is_query_kl_loss:
if self.is_train_llm:
t_llm_output = self.local_llm_model(input_ids=teacher_input_ids,
attention_mask=teacher_attention_mask,
use_cache=True,
output_hidden_states=True,
output_attentions=output_attentions
)
else:
# in-context learning inference or teacher
# no need for gradient
with torch.no_grad():
t_llm_output = self.local_llm_model(input_ids=teacher_input_ids,
attention_mask=teacher_attention_mask,
use_cache=True,
output_hidden_states=True,
output_attentions=output_attentions
)
t_logits = t_llm_output.logits
attentions = t_llm_output.attentions
if is_generation:
return CausalLMOutputWithCrossAttentions(past_key_values=t_llm_output.past_key_values, logits=t_llm_output.logits, attentions=attentions)
if ((
self.training is False and self.is_t_demo) or self.is_train_llm) is False or self.is_kd or self.model_args.is_meta_gradient_match or self.is_inner_distill or self.model_args.is_query_state_match or self.model_args.is_query_kl_loss:
# ATTENTION: get the student logits.
distill_attention_mask = None
if self.is_prompt_tuning:
distill_embeds = self.get_task_prompt_embed(task_idx)
else:
if self.model_args.is_encoder_debug:
distill_embeds = self.local_llm_model.get_input_embeddings()(demo_ids)
distill_attention_mask = distill_attention_mask
else:
distill_embeds, distill_attention_mask = self.hypernet(input_ids=demo_ids,
attention_mask=demo_attention_mask,
sm_pool_name=self.sm_pool_name,
demo_token_type_ids=demo_token_type_ids)
if self.hypernet.adap_pooler is not None:
distill_embeds = self.hypernet.adap_pooler(distill_embeds)
# expand for evaluation
if self.training is False and is_generation is False:
assert num_class > 0
distill_embeds = distill_embeds.unsqueeze(1).repeat(1, num_class, 1, 1).reshape(
-1, *distill_embeds.shape[1:]
)
if distill_attention_mask is not None:
distill_attention_mask = distill_attention_mask.repeat(1, num_class, 1).reshape(
-1, *distill_attention_mask.shape[1:]
)
if distill_attention_mask is None:
distill_attention_mask = attention_mask.new_ones(distill_embeds.shape[0], distill_embeds.shape[1])
# TODO: Prepare the 2D attention mask to highlight the special embeddings for outputs
# pretraining loss
if self.model_args.is_seq2seq_pretrain and self.training:
input_ids = teacher_input_ids
attention_mask = teacher_attention_mask
token_type_ids = teacher_attention_mask
distill_embeds = distill_embeds.unsqueeze(1).repeat(1, self.model_args.t_demo_batch_count,
1, 1).reshape(-1, *distill_embeds.shape[1:])
distill_attention_mask = distill_attention_mask.unsqueeze(1).repeat(1,
self.model_args.t_demo_batch_count,
1).reshape(-1,
distill_attention_mask.shape[
-1])
input_ids_embeds = self.local_llm_model.get_input_embeddings()(input_ids)
# concatenate embeds
if self.model_args.is_no_s_embeds:
input_embeds_s = input_ids_embeds
attention_mask_s = attention_mask
else:
input_embeds_s = torch.cat((distill_embeds, input_ids_embeds), dim=1)
attention_mask_s = torch.cat([distill_attention_mask, attention_mask], dim=-1)
# share the past_key_values for other class labels
s_llm_output = self.local_llm_model(
inputs_embeds=input_embeds_s,
attention_mask=attention_mask_s,
use_cache=True,
output_hidden_states=True,
output_attentions=output_attentions
)
# downstream tasks' objective functions
s_logits = s_llm_output.logits
past_key_values = s_llm_output.past_key_values
attentions = s_llm_output.attentions
if is_generation:
return CausalLMOutputWithCrossAttentions(past_key_values=past_key_values, logits=s_logits)
# remove the distilled sentence hidden states
if self.model_args.is_no_s_embeds is False:
s_logits = s_logits[:, distill_embeds.shape[1]:]
if self.training:
if self.is_train_llm:
# MetaICL
train_logits = t_logits
train_gen_labels = torch.clone(teacher_input_ids)
train_token_type_ids = teacher_token_type_ids
train_attention_mask = teacher_attention_mask
else:
train_logits = s_logits
train_gen_labels = torch.clone(input_ids)
train_token_type_ids = token_type_ids
train_attention_mask = attention_mask
# Flatten the tokens
loss_logits = None
loss = None
if self.training and self.model_args.is_only_auxiliary_loss is False:
labels_mask = train_token_type_ids[..., 1:].contiguous()
loss_fct = CrossEntropyLoss()
# Shift so that tokens < n predict n
shift_logits = train_logits[..., :-1, :].contiguous()
shift_labels = train_gen_labels[..., 1:].contiguous()
shift_train_attention_mask = train_attention_mask[:, 1:].contiguous()
if is_consecutive is not None:
# only subsect of a batch requires to calculate teacher forcing generation loss.
is_consecutive = is_consecutive.bool().squeeze()
shift_logits = shift_logits[is_consecutive]
shift_labels = shift_labels[is_consecutive]
labels_mask = labels_mask[is_consecutive]
shift_labels = torch.where((labels_mask * shift_train_attention_mask) == 1,
shift_labels,
torch.ones_like(shift_labels) * -100)
shift_logits = shift_logits.view(-1, shift_logits.shape[-1])
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
loss_logits = None
elif self.training is False:
if self.is_t_demo:
shift_logits = t_logits[:, :-1].contiguous()
shift_labels = torch.where(teacher_token_type_ids == 1, teacher_input_ids,
torch.ones_like(teacher_input_ids) * -100)[:, 1:].contiguous()
here_token_type_ids = teacher_token_type_ids[:, 1:].contiguous()
else:
shift_logits = s_logits[:, :-1].contiguous()
shift_labels = torch.where(token_type_ids == 1, input_ids,
torch.ones_like(input_ids) * -100)[:, 1:].contiguous()
here_token_type_ids = token_type_ids[:, 1:].contiguous()
loss_fct = CrossEntropyLoss(reduction='none')
loss = loss_fct(
shift_logits.view(-1, shift_logits.shape[-1]),
shift_labels.view(-1)
)
loss = loss.view(batch_size * num_class, -1)
# agg loss for classification
# shape: batch_size, num_class for classification
loss_logits = -1 * (torch.sum(loss, dim=1) / torch.sum(here_token_type_ids, dim=1)).reshape(batch_size,
num_class)
# mask out the loss_logits
if real_num_class_mask is not None:
loss_logits = loss_logits.masked_fill((1-real_num_class_mask).bool(), -10000)
loss = torch.mean(loss)
# Auxilary loss
if self.training:
# knowledge distillation between teacher and student
# batch_size, seq_length, vocab_size
if self.model_args.is_query_kl_loss:
s_input_token_type_ids = attention_mask
t_input_token_type_ids = (1 - teacher_demo_token_type_ids) * teacher_attention_mask
aux_loss = self.query_logits_distill(
s_logits=s_logits,
t_logits=t_logits,
s_input_attention_mask=s_input_token_type_ids,
t_input_attention_mask=t_input_token_type_ids
)
else:
aux_loss = 0
if aux_loss is not None and type(aux_loss) is not int and torch.isnan(aux_loss):
raise FloatingPointError("aux_loss is nan")
if self.model_args.is_only_auxiliary_loss:
loss = aux_loss
else:
loss += self.loss_lambda * aux_loss
if loss is not None and torch.isnan(loss):
raise FloatingPointError("loss is nan")
return SequenceClassifierOutput(
loss=loss,
logits=loss_logits,
attentions=attentions
)
@torch.no_grad()
def save_pretrained(self, save_directory, safe_serialization=False, **kwargs):
if "state_dict" in kwargs.keys():
kwargs.pop("state_dict")
# we only save the distillation model
if self.is_train_llm:
state_dict = self.local_llm_model.state_dict()
state_dict = OrderedDict({"local_llm_model." + key: value for key, value in state_dict.items()})
self.local_llm_model.save_pretrained(save_directory, safe_serialization, state_dict=state_dict, **kwargs)
else:
state_dict = self.hypernet.state_dict()
state_dict = OrderedDict({"hypernet." + key: value for key, value in state_dict.items()})
if isinstance(self.hypernet, PreTrainedModel):
self.hypernet.save_pretrained(
save_directory, safe_serialization, state_dict=state_dict, **kwargs
)
else:
torch.save(state_dict, save_directory + "/pytorch_model.bin")
def print_trainable_parameters(self):
"""
Prints the number of trainable parameters in the model.
"""
trainable_params = 0
all_param = 0
for _, param in self.named_parameters():
num_params = param.numel()
# if using DS Zero 3 and the weights are initialized empty
if num_params == 0 and hasattr(param, "ds_numel"):
num_params = param.ds_numel
all_param += num_params
if param.requires_grad:
trainable_params += num_params
print(
f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param}"
)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
**kwargs,
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}