-
Notifications
You must be signed in to change notification settings - Fork 12
/
finetune.py
381 lines (335 loc) · 13.6 KB
/
finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
#!/bin/python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os
import sys
from typing import List
import datetime
import fire
import torch
import transformers
from datasets import load_dataset
import wandb
from transformers import (
TrainerCallback,
TrainingArguments,
TrainerState,
TrainerControl,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_int8_training,
set_peft_model_state_dict,
)
from chatbots.conversation import Conversation
from src.utils import fc_prefix, fc_suffix
class SavePeftModelCallback(TrainerCallback):
    """Trainer callback that writes the PEFT adapter into each checkpoint folder.

    On every save event the PEFT-wrapped model passed by the Trainer (``kwargs["model"]``)
    is saved with ``save_pretrained`` into ``<output_dir>/<PREFIX_CHECKPOINT_DIR>-<step>``.
    An empty ``pytorch_model.bin`` is then written next to the adapter files.
    NOTE(review): the empty file is presumably a placeholder so code that looks for
    ``pytorch_model.bin`` in a checkpoint (e.g. the resume logic in ``train``) finds one
    — confirm.
    """

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        step_dir_name = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        target_dir = os.path.join(args.output_dir, step_dir_name)

        peft_model = kwargs["model"]
        peft_model.save_pretrained(target_dir)

        placeholder_path = os.path.join(target_dir, "pytorch_model.bin")
        torch.save({}, placeholder_path)

        # Hand the (unchanged) control object back to the Trainer.
        return control
class LoadBestPeftModelCallback(TrainerCallback):
    """Trainer callback that reloads the best adapter weights when training ends.

    Reads ``adapter_model.bin`` from ``state.best_model_checkpoint`` and loads it into
    the PEFT model passed by the Trainer (``kwargs["model"]``) via
    ``set_peft_model_state_dict``.

    If no best checkpoint was recorded (``state.best_model_checkpoint`` is None, e.g.
    presumably when no evaluation metric was tracked), the callback logs a message and
    leaves the final weights in place. Previously it crashed inside ``os.path.join``.
    """

    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        best_checkpoint = state.best_model_checkpoint
        if best_checkpoint is None:
            print("No best peft model checkpoint recorded; keeping the final weights.")
            return control

        print(
            f"Loading best peft model from {best_checkpoint} (score: {state.best_metric})."
        )
        best_model_path = os.path.join(best_checkpoint, "adapter_model.bin")
        # Load to CPU first: set_peft_model_state_dict copies the values into the
        # model's existing parameters. This avoids materializing the tensors on the
        # device they were saved from (e.g. cuda:0 on every DDP rank).
        adapters_weights = torch.load(best_model_path, map_location="cpu")
        model = kwargs["model"]
        set_peft_model_state_dict(model, adapters_weights)
        return control
def train(
    # model/data params
    base_model: str = "",
    data_path: str = "",
    output_dir: str = "",
    # training hyperparams
    batch_size: int = 128,
    micro_batch_size: int = 8,
    num_epochs: int = 1,
    learning_rate: float = 3e-4,
    cutoff_len: int = 4096,
    val_set_size: int = 0,
    eval_steps: int = 200,
    save_steps: int = 1000,
    lr_scheduler: str = "cosine",
    warmup_steps: int = 100,
    # lora hyperparams
    lora_r: int = 16,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    # from peft docs: ["q_proj", "k_proj", "v_proj", "o_proj", "fc_in", "fc_out", "wte", "gate_proj", "down_proj", "up_proj"]
    lora_target_modules: List[str] = ["gate_proj", "down_proj", "up_proj"],
    # llm hyperparams
    train_on_response: bool = False,  # if False, assistant response text is masked out of the loss
    add_eos_token: bool = False,
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    # wandb params
    wandb_project: str = "",
    wandb_run_name: str = "",
    wandb_watch: str = "",  # options: false | gradients | all
    wandb_log_model: str = "",  # options: false | true
    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
):
    """Fine-tune a Llama causal LM with LoRA adapters on conversation data.

    Each dataset example is read as ``{"system": str, "conversation": [{"role", "content"}, ...]}``.

    Loss masking:
    - System and user text are always masked out of the loss (label -100).
    - The function-call part of an assistant message (text up to and including
      ``fc_suffix``) is always trained on.
    - The remaining assistant response text is trained on only when
      ``train_on_response`` is True.
    - Messages with any other role are ignored.

    The base model is loaded in 8-bit and LoRA adapters are attached to
    ``lora_target_modules``. Training uses ``transformers.Trainer``. The adapter and
    tokenizer are saved to ``output_dir``.
    """
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"base_model: {base_model}\n"
            f"data_path: {data_path}\n"
            f"output_dir: {output_dir}\n"
            f"batch_size: {batch_size}\n"
            f"micro_batch_size: {micro_batch_size}\n"
            f"num_epochs: {num_epochs}\n"
            f"learning_rate: {learning_rate}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"val_set_size: {val_set_size}\n"
            f"eval_steps: {eval_steps}\n"
            f"save_steps: {save_steps}\n"
            f"lr_scheduler: {lr_scheduler}\n"
            f"warmup_steps: {warmup_steps}\n"
            f"lora_r: {lora_r}\n"
            f"lora_alpha: {lora_alpha}\n"
            f"lora_dropout: {lora_dropout}\n"
            f"lora_target_modules: {lora_target_modules}\n"
            f"train_on_response: {train_on_response}\n"
            f"add_eos_token: {add_eos_token}\n"
            f"group_by_length: {group_by_length}\n"
            f"wandb_project: {wandb_project}\n"
            f"wandb_run_name: {wandb_run_name}\n"
            f"wandb_watch: {wandb_watch}\n"
            f"wandb_log_model: {wandb_log_model}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
        )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"

    # Effective batch = micro_batch_size * gradient_accumulation_steps (* world_size under DDP).
    gradient_accumulation_steps = batch_size // micro_batch_size
    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    print("world size: ", world_size)
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size
    # Integer division above can reach 0 (e.g. batch_size < micro_batch_size * world_size).
    # At least one step is required for training to proceed.
    gradient_accumulation_steps = max(1, gradient_accumulation_steps)
    print("gradient_accumulation_steps: ", gradient_accumulation_steps)

    # Check if parameter passed or if set within environ
    use_wandb = len(wandb_project) > 0 or (
        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
    )
    # Only overwrite environ if wandb param passed
    if len(wandb_project) > 0:
        os.environ["WANDB_PROJECT"] = wandb_project
    if len(wandb_watch) > 0:
        os.environ["WANDB_WATCH"] = wandb_watch
    if len(wandb_log_model) > 0:
        os.environ["WANDB_LOG_MODEL"] = wandb_log_model

    model = LlamaForCausalLM.from_pretrained(
        base_model, load_in_8bit=True, torch_dtype=torch.float16, device_map=device_map
    )
    tokenizer = LlamaTokenizer.from_pretrained(base_model)
    bos = tokenizer.bos_token_id
    eos = tokenizer.eos_token_id
    pad = tokenizer.pad_token_id
    print(
        "pre-trained model's BOS EOS and PAD token id:",
        bos,
        eos,
        pad,
        " => It should be 1 2 None",
    )
    tokenizer.pad_token_id = 0  # unk. we want this to be different from the eos token
    tokenizer.padding_side = "right"

    def tokenize(prompt, add_eos_token=True):
        """Tokenize one text segment, truncated to cutoff_len, optionally appending EOS.

        EOS is appended only when the segment does not already end with it and
        there is room left under cutoff_len.
        NOTE(review): the tokenizer presumably prepends BOS to every segment, so a
        concatenated conversation contains one BOS per segment — confirm intended.
        """
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)
        result["labels"] = result["input_ids"].copy()
        return result

    def generate_and_tokenize_conversation(data_point):
        """Turn one conversation example into input_ids / attention_mask / labels.

        Segments are tokenized independently and concatenated. Masked segments get
        label -100 for every token they contribute (including an appended EOS).
        This keeps labels aligned with input_ids. The final sequence is truncated
        to cutoff_len.
        """
        input_ids = []
        labels = []

        def append_segment(text, trainable):
            """Append one tokenized segment; label it with its ids or mask it with -100."""
            segment_ids = tokenize(text, add_eos_token=add_eos_token)["input_ids"]
            input_ids.extend(segment_ids)
            labels.extend(segment_ids if trainable else [-100] * len(segment_ids))

        # system message part (never trained on)
        append_segment(data_point["system"], trainable=False)

        for message in data_point["conversation"]:
            role = message["role"]
            content = message["content"]
            if role == "user":
                append_segment(content, trainable=False)
            elif role == "assistant":
                # Split the assistant output into a function call (up to and including
                # the first fc_suffix) and the free-text response that follows it.
                if fc_prefix in content and fc_suffix in content:
                    head, suffix, response = content.partition(fc_suffix)
                    function_call = head + suffix
                else:
                    function_call, response = "", content
                if function_call:
                    append_segment(function_call, trainable=True)
                if response:
                    append_segment(response, trainable=train_on_response)
            # Messages with any other role are ignored.

        # Per-segment truncation does not bound the concatenated length.
        input_ids = input_ids[:cutoff_len]
        labels = labels[:cutoff_len]
        return {
            "input_ids": input_ids,
            "attention_mask": [1] * len(input_ids),
            "labels": labels,
        }

    model = prepare_model_for_int8_training(model)
    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        # Copy so the shared default list object is never handed to third-party code.
        target_modules=list(lora_target_modules),
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)

    if data_path.endswith(".json") or data_path.endswith(".jsonl"):
        data = load_dataset("json", data_files=data_path)
    else:
        data = load_dataset(data_path)

    if resume_from_checkpoint:
        # Check the available weights and load them
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "pytorch_model.bin"
        )  # Full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )  # only LoRA model - LoRA config above has to fit
            resume_from_checkpoint = False  # So the trainer won't try loading its state
        # The two files above have a different name depending on how they were saved, but are actually the same.
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            # Load to CPU; set_peft_model_state_dict copies into the model's parameters.
            adapters_weights = torch.load(checkpoint_name, map_location="cpu")
            set_peft_model_state_dict(model, adapters_weights)
        else:
            print(f"Checkpoint {checkpoint_name} not found")

    model.print_trainable_parameters()

    if val_set_size > 0:
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
        train_data = train_val["train"].shuffle().map(
            generate_and_tokenize_conversation,
            remove_columns=train_val["train"].column_names,
        )
        val_data = train_val["test"].shuffle().map(
            generate_and_tokenize_conversation,
            remove_columns=train_val["test"].column_names,
        )
    else:
        train_data = data["train"].shuffle().map(
            generate_and_tokenize_conversation,
            remove_columns=data["train"].column_names,
        )
        val_data = None

    if not ddp and torch.cuda.device_count() > 1:
        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            per_device_eval_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=warmup_steps,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            # dataloader_num_workers=16,
            fp16=True,
            logging_steps=1,
            optim="adamw_torch",
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=eval_steps if val_set_size > 0 else None,
            save_steps=save_steps,
            lr_scheduler_type=lr_scheduler,
            output_dir=output_dir,
            save_total_limit=2,
            load_best_model_at_end=val_set_size > 0,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            # report_to=None means "all installed integrations" in transformers 4.x,
            # which would enable wandb even when it was not requested.
            report_to="wandb" if use_wandb else "none",
            run_name=wandb_run_name if use_wandb else None,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
        # callbacks=[SavePeftModelCallback, LoadBestPeftModelCallback], # ONLY USE LoadBestPeftModelCallback if val_set_size > 0
    )
    model.config.use_cache = False

    if torch.__version__ >= "2" and sys.platform != "win32":
        # NOTE(review): the Trainer above already holds the uncompiled model, so this
        # compiled wrapper is only used for the save_pretrained call below.
        model = torch.compile(model)

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    model.save_pretrained(output_dir)
    # model.base_model.save_pretrained(output_dir)
    pytorch_model_path = os.path.join(output_dir, "pytorch_model.bin")
    torch.save({}, pytorch_model_path)
    tokenizer.save_pretrained(output_dir)
def _main():
    """Command-line entry point: clear the CUDA cache, then let Fire dispatch to ``train``."""
    torch.cuda.empty_cache()
    fire.Fire(train)


if __name__ == "__main__":
    _main()