-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Embedding] Add embedding training #9508
Merged
Merged
Changes from 18 commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
c141720
add Qwen2SentenceEmbedding
DrownFish19 6e9efb2
update modeling
DrownFish19 89d23e6
Merge remote-tracking branch 'paddlenlp/develop' into dev_20241121_ad…
DrownFish19 4d974e5
add embedding trainer
DesmonDay f408f2e
embedding
DrownFish19 a3da81b
fix
DrownFish19 18405ed
Merge remote-tracking branch 'paddlenlp-daisiming/add_embedding_train…
DrownFish19 f8e877b
Merge remote-tracking branch 'paddlenlp/develop' into dev_20241121_ad…
DrownFish19 e6394ad
update
DrownFish19 ba2c286
support cross device
DesmonDay a71783b
Merge branch 'dev_20241121_add_qwen2_embedding' of https://github.com…
DesmonDay 759d832
update trainer
DesmonDay b92df93
add loss
DesmonDay d3d5a7f
delete unused code
DesmonDay 88ba45e
delete unused code
DesmonDay b5c08aa
optimize code
DesmonDay d815fce
update
DesmonDay 0a618b0
update
DesmonDay 344d4f0
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
DesmonDay 39d324d
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
DesmonDay File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
{ | ||
"model_name_or_path": "Qwen/Qwen2-0.5B", | ||
"dataset_name_or_path": "./dureader_data", | ||
"output_dir": "./checkpoints/sft_ckpts", | ||
"per_device_train_batch_size": 1, | ||
"gradient_accumulation_steps": 4, | ||
"per_device_eval_batch_size": 1, | ||
"eval_accumulation_steps": 1, | ||
"max_steps": 2000, | ||
"learning_rate": 3e-5, | ||
"warmup_steps": 30, | ||
"logging_steps": 1, | ||
"evaluation_strategy": "no", | ||
"save_strategy": "epoch", | ||
"max_query_len": 512, | ||
"max_passage_len": 512, | ||
"group_size": 4, | ||
"bf16": true, | ||
"fp16_opt_level": "O2", | ||
"do_train": true, | ||
"do_eval": false, | ||
"disable_tqdm": true, | ||
"load_best_model_at_end": false, | ||
"eval_with_do_generation": false, | ||
"metric_for_best_model": "accuracy", | ||
"recompute": true, | ||
"save_total_limit": 1, | ||
"tensor_parallel_degree": 1, | ||
"pipeline_parallel_degree": 1, | ||
"sharding": "stage1", | ||
"zero_padding": false, | ||
"unified_checkpoint": true, | ||
"use_flash_attention": true, | ||
"amp_custom_black_list": "elementwise_div", | ||
"release_grads": true | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,288 @@ | ||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# import inspect | ||
import os | ||
import sys | ||
|
||
import paddle | ||
from utils.argument import EmbeddingArgument | ||
|
||
from paddlenlp.data import DataCollatorForEmbedding | ||
from paddlenlp.datasets import EmbeddingIterableDataset, load_dataset | ||
from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint, set_seed | ||
from paddlenlp.trainer.trainer_callback import TrainerState | ||
from paddlenlp.transformers import ( | ||
AutoConfig, | ||
AutoTokenizer, | ||
Qwen2Config, | ||
Qwen2SentenceEmbedding, | ||
) | ||
from paddlenlp.transformers.configuration_utils import LlmMetaConfig | ||
from paddlenlp.transformers.refined_recompute import update_refined_recompute | ||
from paddlenlp.trl import DataConfig, EmbeddingTrainer, ModelConfig, SFTConfig | ||
from paddlenlp.trl.llm_utils import compute_metrics, init_chat_template | ||
from paddlenlp.utils.log import logger | ||
|
||
# Fine-tune Environment Variables to support sharding stage1 overlap optimization. | ||
os.environ["USE_CASUAL_MASK"] = "False" | ||
|
||
|
||
def main(): | ||
parser = PdArgumentParser((ModelConfig, DataConfig, SFTConfig, EmbeddingArgument)) | ||
if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): | ||
model_args, data_args, training_args, embedding_args = parser.parse_json_file_and_cmd_lines() | ||
else: | ||
model_args, data_args, training_args, embedding_args = parser.parse_args_into_dataclasses() | ||
|
||
training_args.print_config(model_args, "Model") | ||
training_args.print_config(data_args, "Data") | ||
|
||
# Setup GPU & distributed training | ||
paddle.set_device(training_args.device) | ||
set_seed(seed=training_args.seed) | ||
logger.warning( | ||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, " | ||
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}" | ||
) | ||
|
||
if training_args.pipeline_parallel_degree > 1: | ||
raise NotImplementedError("Cannot support pipeline parallel for Embedding training now.") | ||
|
||
# Detecting last checkpoint. | ||
last_checkpoint = None | ||
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: | ||
last_checkpoint = get_last_checkpoint(training_args.output_dir) | ||
if last_checkpoint is not None and training_args.resume_from_checkpoint is None: | ||
logger.info( | ||
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " | ||
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch." | ||
) | ||
|
||
# Load model | ||
if training_args.fp16_opt_level == "O2": | ||
if training_args.fp16: | ||
dtype = "float16" | ||
elif training_args.bf16: | ||
dtype = "bfloat16" | ||
else: | ||
raise ValueError("Please specific dtype: --fp16 or --bf16") | ||
else: | ||
dtype = "float32" | ||
|
||
model_config = AutoConfig.from_pretrained( | ||
model_args.model_name_or_path, | ||
dtype=dtype, | ||
from_aistudio=model_args.from_aistudio, | ||
) | ||
assert isinstance(model_config, Qwen2Config), "Now only qwen2 supported" | ||
|
||
LlmMetaConfig.set_llm_config(model_config, training_args) | ||
model_config.refined_recompute = update_refined_recompute(training_args.refined_recompute) | ||
model_config.use_fast_layer_norm = model_args.use_fast_layer_norm | ||
|
||
# Config for model using dropout, such as GPT. | ||
if hasattr(model_config, "hidden_dropout_prob"): | ||
model_config.hidden_dropout_prob = model_args.hidden_dropout_prob | ||
if hasattr(model_config, "attention_probs_dropout_prob"): | ||
model_config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob | ||
if hasattr(model_config, "ignore_index"): | ||
model_config.ignore_index = -100 | ||
|
||
if model_args.fuse_attention_qkv is not None: | ||
model_config.fuse_attention_qkv = model_args.fuse_attention_qkv | ||
if model_args.fuse_attention_ffn is not None: | ||
model_config.fuse_attention_ffn = model_args.fuse_attention_ffn | ||
|
||
model_config.seq_length = data_args.max_length | ||
model_config.embedding_negatives_cross_device = embedding_args.embedding_negatives_cross_device | ||
logger.info(f"Final model config: {model_config}") | ||
|
||
model_class = Qwen2SentenceEmbedding | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 后面改掉吧 可以搞一个 Auto的class |
||
|
||
if model_args.continue_training and not training_args.autotuner_benchmark: | ||
model = model_class.from_pretrained( | ||
model_args.model_name_or_path, | ||
config=model_config, | ||
from_aistudio=model_args.from_aistudio, | ||
) | ||
else: | ||
model = model_class.from_config(model_config, dtype=dtype) | ||
|
||
if model_args.flash_mask and (not data_args.zero_padding or not model.config.use_flash_attention): | ||
logger.warning("`flash_mask` must use with zero padding and flash attention.") | ||
data_args.zero_padding = True | ||
model.config.use_flash_attention = True | ||
|
||
# Load tokenizer & dataset | ||
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, from_aistudio=model_args.from_aistudio) | ||
|
||
# init chat_template for tokenizer | ||
init_chat_template(tokenizer, model_args.model_name_or_path, data_args.chat_template) | ||
|
||
# if using chat_template, data_args.eval_with_do_generation must be false | ||
if tokenizer.chat_template is not None: | ||
data_args.eval_with_do_generation = False | ||
|
||
if training_args.do_eval: | ||
logger.warning("Warning: 'do_eval' is set to True, but will be set to False for Embedding training currently.") | ||
training_args.do_eval = False | ||
training_args.evaluation_strategy = "no" | ||
|
||
if data_args.dataset_name_or_path is None: | ||
raise ValueError(f"Please specific dataset name or path (got {data_args.dataset_name_or_path})") | ||
elif os.path.exists(os.path.join(data_args.dataset_name_or_path, "train.json")) or os.path.exists( | ||
os.path.join(data_args.dataset_name_or_path, "dev.json") | ||
): | ||
if training_args.do_train: | ||
train_ds = load_dataset( | ||
"json", | ||
data_files=os.path.join(data_args.dataset_name_or_path, "train.json"), | ||
lazy=data_args.lazy, | ||
)[0] | ||
else: | ||
train_ds = None | ||
if training_args.do_eval: | ||
dev_ds = load_dataset( | ||
"json", | ||
data_files=os.path.join(data_args.dataset_name_or_path, "dev.json"), | ||
lazy=data_args.lazy, | ||
)[0] | ||
else: | ||
dev_ds = None | ||
|
||
elif os.path.exists(os.path.join(data_args.dataset_name_or_path, "train")) or os.path.exists( | ||
os.path.join(data_args.dataset_name_or_path, "dev") | ||
): | ||
import glob | ||
|
||
if training_args.do_train: | ||
train_ds = load_dataset( | ||
"json", | ||
data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "train", "*.json")), | ||
lazy=data_args.lazy, | ||
)[0] | ||
else: | ||
train_ds = None | ||
if training_args.do_eval: | ||
dev_ds = load_dataset( | ||
"json", | ||
data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "dev", "*.json")), | ||
lazy=data_args.lazy, | ||
)[0] | ||
else: | ||
dev_ds = None | ||
|
||
else: | ||
if training_args.do_train: | ||
train_ds = load_dataset(data_args.dataset_name_or_path, splits=["train"])[0] | ||
else: | ||
train_ds = None | ||
if training_args.do_eval: | ||
dev_ds = load_dataset(data_args.dataset_name_or_path, splits=["dev"])[0] | ||
else: | ||
dev_ds = None | ||
|
||
# TODO(ZHUI & sijunhe): Temporary implementation. Generalize this logic and move to Trainer later. | ||
if training_args.resume_from_checkpoint is not None and data_args.lazy: | ||
logger.info( | ||
f"Loading from '{training_args.resume_from_checkpoint}' with `lazy=True`, manually skipping dataset and setting `ignore_data_skip` to True." | ||
) | ||
training_args.ignore_data_skip = True | ||
state = TrainerState.load_from_json(os.path.join(training_args.resume_from_checkpoint, "trainer_state.json")) | ||
if state.trial_params is not None and "zero_padding_global_step" in state.trial_params: | ||
consumed_samples = state.trial_params["zero_padding_global_step"] | ||
else: | ||
consumed_samples = ( | ||
state.global_step | ||
* training_args.per_device_train_batch_size | ||
* training_args.gradient_accumulation_steps | ||
* training_args.dataset_world_size | ||
) | ||
logger.info( | ||
f"Skipping the first {consumed_samples} samples to warmup the dataset from checkpoint '{training_args.resume_from_checkpoint}'." | ||
) | ||
train_ds = train_ds.skip(consumed_samples) | ||
|
||
if train_ds is not None: | ||
train_ds = EmbeddingIterableDataset( | ||
train_ds, | ||
tokenizer, | ||
max_query_len=embedding_args.max_query_len, | ||
max_passage_len=embedding_args.max_passage_len, | ||
group_size=embedding_args.group_size, | ||
query_template=embedding_args.query_template, | ||
passage_template=embedding_args.passage_template, | ||
) | ||
|
||
if dev_ds is not None: | ||
dev_ds = EmbeddingIterableDataset( | ||
dev_ds, | ||
tokenizer, | ||
max_query_len=embedding_args.max_query_len, | ||
max_passage_len=embedding_args.max_passage_len, | ||
group_size=embedding_args.group_size, | ||
query_template=embedding_args.query_template, | ||
passage_template=embedding_args.passage_template, | ||
) | ||
|
||
# Create trainer | ||
if data_args.pad_to_max_length: | ||
padding = "max_length" | ||
else: | ||
padding = True | ||
|
||
data_collator_fn = DataCollatorForEmbedding( | ||
tokenizer=tokenizer, | ||
max_query_len=embedding_args.max_query_len, | ||
padding=padding, | ||
max_passage_len=embedding_args.max_passage_len, | ||
return_tensors="np", | ||
return_attention_mask=not model_args.flash_mask, | ||
pad_to_multiple_of=data_args.pad_to_multiple_of, | ||
) | ||
trainer = EmbeddingTrainer( | ||
model=model, | ||
model_args=embedding_args, | ||
args=training_args, | ||
train_dataset=train_ds, | ||
eval_dataset=dev_ds, | ||
tokenizer=tokenizer, | ||
compute_metrics=compute_metrics, | ||
data_collator=data_collator_fn, | ||
) | ||
trainable_parameters = [p for p in model.parameters() if not p.stop_gradient] | ||
trainer.set_optimizer_grouped_parameters(trainable_parameters) | ||
Comment on lines
+264
to
+265
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个 @lugimzzz 之前是为啥加来着? |
||
|
||
# Train | ||
if training_args.do_train: | ||
checkpoint = None | ||
if training_args.resume_from_checkpoint is not None: | ||
checkpoint = training_args.resume_from_checkpoint | ||
elif last_checkpoint is not None: | ||
checkpoint = last_checkpoint | ||
train_result = trainer.train(resume_from_checkpoint=checkpoint) | ||
trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1) | ||
trainer.log_metrics("train", train_result.metrics) | ||
trainer.save_metrics("train", train_result.metrics) | ||
trainer.save_state() | ||
|
||
# Evaluation dev set | ||
if training_args.do_eval: | ||
logger.info("*** Evaluate result after train ***") | ||
eval_result = trainer.evaluate(dev_ds) | ||
trainer.log_metrics("eval", eval_result) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
使用方法,数据集格式,加一个readme吧