diff --git a/llm/config/qwen/emb_argument.json b/llm/config/qwen/emb_argument.json new file mode 100644 index 000000000000..d8c6aeeb7f6e --- /dev/null +++ b/llm/config/qwen/emb_argument.json @@ -0,0 +1,36 @@ +{ + "model_name_or_path": "Qwen/Qwen2-0.5B", + "dataset_name_or_path": "./dureader_data", + "output_dir": "./checkpoints/sft_ckpts", + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 4, + "per_device_eval_batch_size": 1, + "eval_accumulation_steps": 1, + "max_steps": 2000, + "learning_rate": 3e-5, + "warmup_steps": 30, + "logging_steps": 1, + "evaluation_strategy": "no", + "save_strategy": "epoch", + "max_query_len": 512, + "max_passage_len": 512, + "group_size": 4, + "bf16": true, + "fp16_opt_level": "O2", + "do_train": true, + "do_eval": false, + "disable_tqdm": true, + "load_best_model_at_end": false, + "eval_with_do_generation": false, + "metric_for_best_model": "accuracy", + "recompute": true, + "save_total_limit": 1, + "tensor_parallel_degree": 1, + "pipeline_parallel_degree": 1, + "sharding": "stage1", + "zero_padding": false, + "unified_checkpoint": true, + "use_flash_attention": true, + "amp_custom_black_list": "elementwise_div", + "release_grads": true +} diff --git a/llm/run_embedding.py b/llm/run_embedding.py new file mode 100644 index 000000000000..e598f24839cf --- /dev/null +++ b/llm/run_embedding.py @@ -0,0 +1,288 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# import inspect +import os +import sys + +import paddle +from utils.argument import EmbeddingArgument + +from paddlenlp.data import DataCollatorForEmbedding +from paddlenlp.datasets import EmbeddingIterableDataset, load_dataset +from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint, set_seed +from paddlenlp.trainer.trainer_callback import TrainerState +from paddlenlp.transformers import ( + AutoConfig, + AutoTokenizer, + Qwen2Config, + Qwen2SentenceEmbedding, +) +from paddlenlp.transformers.configuration_utils import LlmMetaConfig +from paddlenlp.transformers.refined_recompute import update_refined_recompute +from paddlenlp.trl import DataConfig, EmbeddingTrainer, ModelConfig, SFTConfig +from paddlenlp.trl.llm_utils import compute_metrics, init_chat_template +from paddlenlp.utils.log import logger + +# Fine-tune Environment Variables to support sharding stage1 overlap optimization. 
+os.environ["USE_CASUAL_MASK"] = "False" + + +def main(): + parser = PdArgumentParser((ModelConfig, DataConfig, SFTConfig, EmbeddingArgument)) + if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): + model_args, data_args, training_args, embedding_args = parser.parse_json_file_and_cmd_lines() + else: + model_args, data_args, training_args, embedding_args = parser.parse_args_into_dataclasses() + + training_args.print_config(model_args, "Model") + training_args.print_config(data_args, "Data") + + # Setup GPU & distributed training + paddle.set_device(training_args.device) + set_seed(seed=training_args.seed) + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, " + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}" + ) + + if training_args.pipeline_parallel_degree > 1: + raise NotImplementedError("Cannot support pipeline parallel for Embedding training now.") + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Load model + if training_args.fp16_opt_level == "O2": + if training_args.fp16: + dtype = "float16" + elif training_args.bf16: + dtype = "bfloat16" + else: + raise ValueError("Please specific dtype: --fp16 or --bf16") + else: + dtype = "float32" + + model_config = AutoConfig.from_pretrained( + model_args.model_name_or_path, + dtype=dtype, + from_aistudio=model_args.from_aistudio, + ) + assert isinstance(model_config, Qwen2Config), "Now only qwen2 supported" + + LlmMetaConfig.set_llm_config(model_config, training_args) + model_config.refined_recompute = update_refined_recompute(training_args.refined_recompute) + model_config.use_fast_layer_norm = model_args.use_fast_layer_norm + + # Config for model using dropout, such as GPT. 
+    if hasattr(model_config, "hidden_dropout_prob"):
+        model_config.hidden_dropout_prob = model_args.hidden_dropout_prob
+    if hasattr(model_config, "attention_probs_dropout_prob"):
+        model_config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob
+    if hasattr(model_config, "ignore_index"):
+        model_config.ignore_index = -100
+
+    if model_args.fuse_attention_qkv is not None:
+        model_config.fuse_attention_qkv = model_args.fuse_attention_qkv
+    if model_args.fuse_attention_ffn is not None:
+        model_config.fuse_attention_ffn = model_args.fuse_attention_ffn
+
+    model_config.seq_length = data_args.max_length
+    model_config.embedding_negatives_cross_device = embedding_args.embedding_negatives_cross_device
+    logger.info(f"Final model config: {model_config}")
+
+    model_class = Qwen2SentenceEmbedding
+
+    if model_args.continue_training and not training_args.autotuner_benchmark:
+        model = model_class.from_pretrained(
+            model_args.model_name_or_path,
+            config=model_config,
+            from_aistudio=model_args.from_aistudio,
+        )
+    else:
+        model = model_class.from_config(model_config, dtype=dtype)
+
+    if model_args.flash_mask and (not data_args.zero_padding or not model.config.use_flash_attention):
+        logger.warning("`flash_mask` requires zero padding and flash attention; enabling both.")
+        data_args.zero_padding = True
+        model.config.use_flash_attention = True
+
+    # Load tokenizer & dataset
+    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, from_aistudio=model_args.from_aistudio)
+
+    # Init chat_template for the tokenizer
+    init_chat_template(tokenizer, model_args.model_name_or_path, data_args.chat_template)
+
+    # If using chat_template, data_args.eval_with_do_generation must be False
+    if tokenizer.chat_template is not None:
+        data_args.eval_with_do_generation = False
+
+    if training_args.do_eval:
+        logger.warning("'do_eval' is set to True, but evaluation is not supported for embedding training yet; setting it to False.")
+        training_args.do_eval = False
+        training_args.evaluation_strategy = "no"
+
+    if data_args.dataset_name_or_path is None:
+        raise ValueError(f"Please specify a dataset name or path (got {data_args.dataset_name_or_path})")
+    elif os.path.exists(os.path.join(data_args.dataset_name_or_path, "train.json")) or os.path.exists(
+        os.path.join(data_args.dataset_name_or_path, "dev.json")
+    ):
+        if training_args.do_train:
+            train_ds = load_dataset(
+                "json",
+                data_files=os.path.join(data_args.dataset_name_or_path, "train.json"),
+                lazy=data_args.lazy,
+            )[0]
+        else:
+            train_ds = None
+        if training_args.do_eval:
+            dev_ds = load_dataset(
+                "json",
+                data_files=os.path.join(data_args.dataset_name_or_path, "dev.json"),
+                lazy=data_args.lazy,
+            )[0]
+        else:
+            dev_ds = None
+
+    elif os.path.exists(os.path.join(data_args.dataset_name_or_path, "train")) or os.path.exists(
+        os.path.join(data_args.dataset_name_or_path, "dev")
+    ):
+        import glob
+
+        if training_args.do_train:
+            train_ds = load_dataset(
+                "json",
+                data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "train", "*.json")),
+                lazy=data_args.lazy,
+            )[0]
+        else:
+            train_ds = None
+        if training_args.do_eval:
+            dev_ds = load_dataset(
+                "json",
+                data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "dev", "*.json")),
+                lazy=data_args.lazy,
+            )[0]
+        else:
+            dev_ds = None
+
+    else:
+        if training_args.do_train:
+            train_ds = load_dataset(data_args.dataset_name_or_path, splits=["train"])[0]
+        else:
+            train_ds = None
+        if training_args.do_eval:
+            dev_ds = load_dataset(data_args.dataset_name_or_path, splits=["dev"])[0]
+
else: + dev_ds = None + + # TODO(ZHUI & sijunhe): Temporary implementation. Generalize this logic and move to Trainer later. + if training_args.resume_from_checkpoint is not None and data_args.lazy: + logger.info( + f"Loading from '{training_args.resume_from_checkpoint}' with `lazy=True`, manually skipping dataset and setting `ignore_data_skip` to True." + ) + training_args.ignore_data_skip = True + state = TrainerState.load_from_json(os.path.join(training_args.resume_from_checkpoint, "trainer_state.json")) + if state.trial_params is not None and "zero_padding_global_step" in state.trial_params: + consumed_samples = state.trial_params["zero_padding_global_step"] + else: + consumed_samples = ( + state.global_step + * training_args.per_device_train_batch_size + * training_args.gradient_accumulation_steps + * training_args.dataset_world_size + ) + logger.info( + f"Skipping the first {consumed_samples} samples to warmup the dataset from checkpoint '{training_args.resume_from_checkpoint}'." + ) + train_ds = train_ds.skip(consumed_samples) + + if train_ds is not None: + train_ds = EmbeddingIterableDataset( + train_ds, + tokenizer, + max_query_len=embedding_args.max_query_len, + max_passage_len=embedding_args.max_passage_len, + group_size=embedding_args.group_size, + query_template=embedding_args.query_template, + passage_template=embedding_args.passage_template, + ) + + if dev_ds is not None: + dev_ds = EmbeddingIterableDataset( + dev_ds, + tokenizer, + max_query_len=embedding_args.max_query_len, + max_passage_len=embedding_args.max_passage_len, + group_size=embedding_args.group_size, + query_template=embedding_args.query_template, + passage_template=embedding_args.passage_template, + ) + + # Create trainer + if data_args.pad_to_max_length: + padding = "max_length" + else: + padding = True + + data_collator_fn = DataCollatorForEmbedding( + tokenizer=tokenizer, + max_query_len=embedding_args.max_query_len, + padding=padding, + max_passage_len=embedding_args.max_passage_len, + return_tensors="np", + return_attention_mask=not model_args.flash_mask, + pad_to_multiple_of=data_args.pad_to_multiple_of, + ) + trainer = EmbeddingTrainer( + model=model, + model_args=embedding_args, + args=training_args, + train_dataset=train_ds, + eval_dataset=dev_ds, + tokenizer=tokenizer, + compute_metrics=compute_metrics, + data_collator=data_collator_fn, + ) + trainable_parameters = [p for p in model.parameters() if not p.stop_gradient] + trainer.set_optimizer_grouped_parameters(trainable_parameters) + + # Train + if training_args.do_train: + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1) + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + + # Evaluation dev set + if training_args.do_eval: + logger.info("*** Evaluate result after train ***") + eval_result = trainer.evaluate(dev_ds) + trainer.log_metrics("eval", eval_result) + + +if __name__ == "__main__": + main() diff --git a/llm/utils/argument.py b/llm/utils/argument.py index 812293f1ab8f..99df142e826e 100644 --- a/llm/utils/argument.py +++ b/llm/utils/argument.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 from dataclasses import dataclass, field
+from typing import List, Optional
 
 
 @dataclass
@@ -36,3 +37,54 @@ class GenerateArgument:
     top_p: float = field(
         default=1.0, metadata={"help": "The cumulative probability for top-p-filtering in the sampling strategy."}
     )
+
+
+@dataclass
+class EmbeddingArgument:
+    max_query_len: int = field(
+        default=512,
+        metadata={
+            "help": "The maximum number of tokens in a query, template tokens included."
+        },
+    )
+    max_passage_len: int = field(
+        default=512, metadata={"help": "The maximum number of tokens in a passage, template tokens included."}
+    )
+    group_size: int = field(
+        default=8,
+        metadata={
+            "help": (
+                "Number of total positive and negative samples associated with " "each query for embedding training."
+            )
+        },
+    )
+    query_template: str = field(
+        default="Query: {text}\nUse one word to summarize the query's relevant information. The word is: \"",
+        metadata={
+            "help": (
+                "Query template. Ensure the template includes the placeholder "
+                "'{text}' to insert the actual query text."
+            )
+        },
+    )
+    passage_template: str = field(
+        default="Text: {text}\nUse one word to summarize the text's content. The word is: \"",
+        metadata={
+            "help": (
+                "Passage template. Ensure the template includes the placeholder "
+                "'{text}' to insert the actual passage text."
+            )
+        },
+    )
+    embedding_temperature: float = field(
+        default=0.02,
+        metadata={"help": "The temperature used in embedding learning."},
+    )
+    embedding_negatives_cross_device: bool = field(
+        default=True,
+        metadata={"help": "Whether to share the negatives across all GPUs."},
+    )
+    embedding_matryoshka_dims: Optional[List[int]] = field(
+        default=None,
+        metadata={"help": "The dims for matryoshka training."},
+    )
diff --git a/paddlenlp/data/data_collator.py b/paddlenlp/data/data_collator.py
index 78d3b3517ca0..d06953b4ee7a 100644
--- a/paddlenlp/data/data_collator.py
+++ b/paddlenlp/data/data_collator.py
@@ -39,6 +39,7 @@
     "DataCollatorForSeq2Seq",
     "DataCollatorForLanguageModeling",
     "DataCollatorForWholeWordMask",
+    "DataCollatorForEmbedding",
 ]
 
 InputDataClass = NewType("InputDataClass", Any)
@@ -417,6 +418,129 @@ def __call__(self, features, return_tensors=None):
 
         return batch
 
+
+@dataclass
+class DataCollatorForEmbedding:
+    tokenizer: PretrainedTokenizerBase
+    model: Optional[Any] = None
+    padding: Union[bool, str, PaddingStrategy] = True
+    pad_to_multiple_of: Optional[int] = None
+    label_pad_token_id: int = -100
+    return_tensors: str = "pd"
+    return_attention_mask: Optional[bool] = None
+    max_label_length: Optional[int] = None
+
+    max_query_len: int = 512
+    max_passage_len: int = 512
+
+    def __call__(self, batch, return_tensors=None) -> Any:
+        """Convert batch data into tensors."""
+        input_keys = ["input_ids", "position_ids"]
+
+        attn_key = "attention_mask"
+        input_keys.append(attn_key)
+
+        # Initialize query and passage lists
+        queries = {key: [] for key in input_keys}
+        passages = {key: [] for key in input_keys}
+
+        batch_query_embedding_indices = []
+        batch_passage_embedding_indices = []
+
+        global_passage_idx = 0
+
+        # Process each batch sequence
+        for idx, batch_sequence in enumerate(batch):
+            query_data = [pair.query for pair in batch_sequence]
+            padded_query_token_ids, padded_query_position_ids, query_token_ids = self.process_data(
+                query_data, self.tokenizer.pad_token_id, self.max_query_len
+            )
+
+            queries["input_ids"].append(padded_query_token_ids)
+            queries["position_ids"].append(padded_query_position_ids)
+            batch_query_embedding_indices.append([idx,
len(query_token_ids[0]) - 1]) + + queries[attn_key].append(self.gen_self_attn_mask(query_token_ids, self.max_query_len)) + + for pair in batch_sequence: + for passage in pair.passages: + passage_data = [passage] + padded_passage_token_ids, padded_passage_position_ids, passage_token_ids = self.process_data( + passage_data, self.tokenizer.pad_token_id, self.max_passage_len + ) + + passages["input_ids"].append(padded_passage_token_ids) + passages["position_ids"].append(padded_passage_position_ids) + batch_passage_embedding_indices.append([global_passage_idx, len(passage_token_ids[0]) - 1]) + + passages[attn_key].append(self.gen_self_attn_mask(passage_token_ids, self.max_passage_len)) + global_passage_idx += 1 + + for data in (queries, passages): + for k, v in data.items(): + data[k] = paddle.to_tensor(np.concatenate(v)) + + queries["embedding_indices"] = paddle.to_tensor(np.array(batch_query_embedding_indices, dtype="int32")) + passages["embedding_indices"] = paddle.to_tensor(np.array(batch_passage_embedding_indices, dtype="int32")) + + return { + "query": queries, + "passages": passages, + } + + def process_data(self, data, pad_idx, max_len): + """padding token_ids & position_ids.""" + token_ids = [sum((item.token_ids for item in data), [])] + position_ids = [sum((item.position_ids for item in data), [])] + padded_token_ids = self.pad_batch_data(token_ids, pad_id=pad_idx, max_seq_len=max_len) + padded_position_ids = self.pad_batch_data(position_ids, pad_id=0, max_seq_len=max_len) + return padded_token_ids, padded_position_ids, token_ids + + @staticmethod + def pad_batch_data(insts, pad_id=0, max_seq_len=None, return_seq_len=False, pad_style="right"): + """Pad sequences to the max sequence length in batch.""" + max_len = max_seq_len if max_seq_len is not None else max(map(len, insts)) + if pad_style == "left": + inst_data = np.array([[pad_id] * (max_len - len(inst)) + list(inst) for inst in insts]) + else: + inst_data = np.array([list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts]) + + if return_seq_len: + seq_len = np.array([len(inst) for inst in insts]) + return inst_data.astype("int64").reshape([-1, max_len]), seq_len + else: + return inst_data.astype("int64").reshape([-1, max_len]) + + @staticmethod + def gen_self_attn_mask(batch_token_ids: List[List[int]], max_seq_len: int): + """Generate self attention mask for multiple sub-sequence.""" + input_mask_data = np.zeros((1, 1, max_seq_len, max_seq_len), dtype="float32") + offset = 0 + for index, token_ids in enumerate(batch_token_ids): + cur_len = len(token_ids) + b = np.tril(np.ones([cur_len, cur_len]), 0) + input_mask_data[0, 0, offset : offset + cur_len, offset : offset + cur_len] = b + offset += cur_len + return input_mask_data + + @staticmethod + def gen_attn_mask_start_row_indices(batch_token_ids: List[List[int]], max_seq_len: int, sliding_window: int): + """Generate attn_mask_start_row_indices for flash attention.""" + offset = 0 + attn_mask_start_row_indices = [] + for token_ids in batch_token_ids: + cur_len = len(token_ids) + if sliding_window > 0: + for i in range(cur_len): + attn_mask_start_row_indices.append(offset + min(cur_len, i + sliding_window)) + else: + attn_mask_start_row_indices.extend([offset + cur_len] * cur_len) + offset += cur_len + if offset < max_seq_len: + attn_mask_start_row_indices.extend(list(range(offset + 1, max_seq_len + 1))) + + return np.array(attn_mask_start_row_indices, dtype=np.int32)[None, None] + + def _paddle_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None): 
"""Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary.""" import paddle diff --git a/paddlenlp/datasets/__init__.py b/paddlenlp/datasets/__init__.py index fda1d65868cf..49fb25fcf319 100644 --- a/paddlenlp/datasets/__init__.py +++ b/paddlenlp/datasets/__init__.py @@ -25,6 +25,7 @@ from .drcd import * from .drcd_cn import * from .dureader_robust import * +from .embedding_dataset import * from .glue import * from .imdb import * from .lcqmc import * diff --git a/paddlenlp/datasets/embedding_dataset.py b/paddlenlp/datasets/embedding_dataset.py new file mode 100644 index 000000000000..da34b9164e48 --- /dev/null +++ b/paddlenlp/datasets/embedding_dataset.py @@ -0,0 +1,252 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Embedding dataset.""" + +import random +from dataclasses import dataclass +from typing import List + +from paddle.io import Dataset, IterableDataset + +from ..utils.log import logger + + +@dataclass +class Example: + """Dataset example.""" + + query: str + pos_passage: List[str] + neg_passage: List[str] = None + + +@dataclass +class Sequence: + """Sequence.""" + + token_ids: List[int] + position_ids: List[int] + + +@dataclass +class Pair: + """Pair.""" + + query: Sequence + passages: List[Sequence] + + +class EmbeddingDatasetMixin: + """EmbeddingDatasetMixin.""" + + def convert_example(tokenizer, example): + """Convert raw json format example to Example.""" + + assert all( + (key in example for key in ["query", "pos_passage", "neg_passage"]) + ), "query, pos_passage, neg_passage are needed" + + if not isinstance(example["query"], str): + raise ValueError("query must be a string.") + if isinstance(example["pos_passage"], str): + example["pos_passage"] = [example["pos_passage"]] + if isinstance(example["neg_passage"], str): + example["neg_passage"] = [example["neg_passage"]] + + if len(example["neg_passage"]) > 0: + for item in [example["query"]] + example["pos_passage"] + example["neg_passage"]: + if not isinstance(item, str): + raise ValueError("The item in pos_passage / neg_passage must be a string.") + if len(item.strip()) == 0: + raise ValueError("Example with empty string in query / pos_passage / neg_passage field.") + + query = example["query"] + pos_passage = example["pos_passage"] + neg_passage = example["neg_passage"] + return Example(query=query, pos_passage=pos_passage, neg_passage=neg_passage) + + def tokenize_template(cls, tokenizer, template: str): + """Tokenize a given template using the provided tokenizer.""" + assert template.count("{text}") == 1, "Template must contain exactly one {text} placeholder" + + template_prefix, template_suffix = template.split("{text}") + + prefix_tokens = tokenizer(template_prefix, add_special_tokens=False).input_ids + suffix_tokens = tokenizer(template_suffix, add_special_tokens=False).input_ids + return prefix_tokens, suffix_tokens + + def _process_truncation(self, tokens, text_type): + """ + Process tokens by converting them 
into a complete token sequence with prefix and suffix, + and generate corresponding position ids. + """ + if text_type not in ["query", "passage"]: + raise ValueError("text_type must be either 'query' or 'passage'") + + prefix_key = f"{text_type}_template_prefix" + suffix_key = f"{text_type}_template_suffix" + max_len_key = f"max_{text_type}_len" + + # If the template does not contain a suffix token, add the EOS token to the end + if getattr(self, suffix_key) == []: + setattr(self, suffix_key, [self.tokenizer.eos_token_id]) + + # Calculate the available length + max_len = getattr(self, max_len_key) + prefix_tokens = getattr(self, prefix_key) + suffix_tokens = getattr(self, suffix_key) + available_len = int(max_len - len(prefix_tokens) - len(suffix_tokens)) + + # Convert tokens to ids and truncate + token_ids_converted = self.tokenizer.convert_tokens_to_ids(tokens) + truncated_token_ids = token_ids_converted[:available_len] + + # Combine prefix, truncated tokens, and suffix + token_ids = prefix_tokens + truncated_token_ids + suffix_tokens + pos_ids = list(range(len(token_ids))) + return token_ids, pos_ids + + def _postprocess_sequence(self, example: Example): + """Post process sequence: tokenization & truncation.""" + query = example.query + pos_passage = random.choice(example.pos_passage) + neg_passage = example.neg_passage + if len(neg_passage) > 0: + if len(neg_passage) < self.group_size - 1: + # Calculate how many full sets are needed to ensure each element appears at least once + full_sets_needed = (self.group_size - 1) // len(neg_passage) + remainder = (self.group_size - 1) % len(neg_passage) + + # Initialize the list and add complete sets + selected_neg_passage = neg_passage * full_sets_needed + + # Ensure the remainder part is filled; randomly select from neg_passage + selected_neg_passage += random.sample(neg_passage, remainder) + + # Shuffle the result to ensure randomness + random.shuffle(selected_neg_passage) + else: + selected_neg_passage = random.sample(neg_passage, self.group_size - 1) + else: + selected_neg_passage = [] + # Process query tokens + query_tokens = self.tokenizer.tokenize(query) + query_token_ids, query_pos_ids = self._process_truncation(query_tokens, "query") + + query = Sequence( + token_ids=query_token_ids, + position_ids=query_pos_ids, + ) + + # Process passage tokens + passages = [] + for passage in [pos_passage] + selected_neg_passage: + passage_tokens = self.tokenizer.tokenize(passage) + passage_token_ids, passage_pos_ids = self._process_truncation(passage_tokens, "passage") + passages.append( + Sequence( + token_ids=passage_token_ids, + position_ids=passage_pos_ids, + ) + ) + return Pair(query=query, passages=passages) + + +class EmbeddingDataset(EmbeddingDatasetMixin, Dataset): + def __init__( + self, + dataset, + tokenizer, + max_query_len: int = 64, + max_passage_len: int = 256, + group_size: int = 2, + query_template: str = "{text}", + passage_template: str = "{text}", + ): + super().__init__() + self.example_dataset = dataset + self.tokenizer = tokenizer + self.max_query_len = max_query_len + self.max_passage_len = max_passage_len + self.group_size = group_size + self.query_template = query_template + self.passage_template = passage_template + self.query_template_prefix, self.query_template_suffix = self.tokenize_template( + self.tokenizer, self.query_template + ) + self.passage_template_prefix, self.passage_template_suffix = self.tokenize_template( + self.tokenizer, self.passage_template + ) + + for index, data in enumerate(self.example_dataset): + 
            self.example_dataset[index] = self.convert_example(data)
+
+    def __getitem__(self, index):
+        return self._postprocess_sequence(self.example_dataset[index])
+
+    def __len__(self):
+        return len(self.example_dataset)
+
+
+class EmbeddingIterableDataset(EmbeddingDatasetMixin, IterableDataset):
+    """Create sequences from an Example dataset.
+
+    This is a stateful dataset.
+    """
+
+    def __init__(
+        self,
+        dataset,
+        tokenizer,
+        max_query_len: int = 64,
+        max_passage_len: int = 256,
+        group_size: int = 2,
+        query_template: str = "{text}",
+        passage_template: str = "{text}",
+    ):
+        super().__init__()
+        self.example_dataset = dataset
+        self.tokenizer = tokenizer
+        self.max_query_len = max_query_len
+        self.max_passage_len = max_passage_len
+        self.group_size = group_size
+        self.query_template = query_template
+        self.passage_template = passage_template
+        self.query_template_prefix, self.query_template_suffix = self.tokenize_template(
+            self.tokenizer, self.query_template
+        )
+        self.passage_template_prefix, self.passage_template_suffix = self.tokenize_template(
+            self.tokenizer, self.passage_template
+        )
+
+        self.epoch_index = 0
+
+    def __iter__(self):
+        while True:
+            logger.info(f"Start to load dataset on epoch={self.epoch_index}")
+            yield from self.iter_one_epoch()
+
+    def iter_one_epoch(self):
+        """Iterates through one epoch of the dataset."""
+
+        num_sequences = 0
+        for index, example in enumerate(self.example_dataset):
+            example = self.convert_example(example)
+            sequence = self._postprocess_sequence(example)
+            if sequence is None:
+                continue
+            num_sequences += 1
+            yield [sequence]
+
+        self.epoch_index += 1
diff --git a/paddlenlp/transformers/qwen2/modeling.py b/paddlenlp/transformers/qwen2/modeling.py
index 35b3ea91f2b5..195f40e02188 100644
--- a/paddlenlp/transformers/qwen2/modeling.py
+++ b/paddlenlp/transformers/qwen2/modeling.py
@@ -23,15 +23,17 @@
 import math
 import warnings
 from functools import partial
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import paddle
+import paddle.distributed as dist
 import paddle.distributed.fleet.meta_parallel as mpu
 import paddle.nn.functional as F
 from paddle import Tensor, nn
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 
+from paddlenlp.transformers.contrastive_loss import SimpleContrastiveLoss
 from paddlenlp.transformers.refined_recompute import (
     RRColumnParallelLinear,
     RRColumnSequenceParallelLinear,
@@ -45,6 +47,7 @@
 from .. import linear_utils
 from ..activations import ACT2FN
 from ..conversion_utils import StateDictNameMapping, init_name_mappings
+from ..embedding_utils import dist_gather_tensor_with_gradient
 from ..linear_utils import Linear
 from ..llama import fusion_ops
 from ..model_outputs import (
@@ -84,6 +87,7 @@
     "Qwen2PretrainingCriterion",
     "Qwen2ForSequenceClassification",
     "Qwen2ForTokenClassification",
+    "Qwen2SentenceEmbedding",
 ]
 
 
@@ -1662,3 +1666,80 @@ def forward(
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
+
+class Qwen2SentenceEmbedding(Qwen2PretrainedModel):
+    def __init__(
+        self,
+        config: Qwen2Config,
+        embedding_temperature: float = 0.02,
+    ):
+        """Qwen2 sentence-embedding model.
+
+        To obtain a larger effective batch size, negatives can be shared across devices (see `embedding_negatives_cross_device`).
+
+        Args:
+            config (Qwen2Config): model configuration.
+            embedding_temperature (float, optional): temperature of the in-batch-negative contrastive loss. Defaults to 0.02.
+ """ + super(Qwen2SentenceEmbedding, self).__init__(config) + self.config = config + self.qwen2 = Qwen2Model(config) + self.in_batch_negative_loss = SimpleContrastiveLoss(embedding_temperature) + self.world_size = dist.get_world_size() + self.process_rank = dist.get_rank() + self.embedding_negatives_cross_device = config.embedding_negatives_cross_device + if self.world_size <= 1: + self.embedding_negatives_cross_device = False + + def forward( + self, + query: Optional[Dict[str, paddle.Tensor]] = None, + passages: Optional[Dict[str, paddle.Tensor]] = None, + return_encode=False, + ): + """forward""" + q_reps = self.encode(**query) + p_reps = self.encode(**passages) + + q_reps = nn.functional.normalize(q_reps, axis=-1) + p_reps = nn.functional.normalize(p_reps, axis=-1) + + if return_encode: + return q_reps, p_reps + + if self.embedding_negatives_cross_device: + q_reps = dist_gather_tensor_with_gradient(q_reps) + p_reps = dist_gather_tensor_with_gradient(p_reps) + + loss = self.in_batch_negative_loss(q_reps, p_reps) + return loss + + def encode( + self, + input_ids, + position_ids=None, + embedding_indices=None, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=False, + **kwargs, + ): + """encode""" + input_type = type(input_ids) + outputs = self.qwen2( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + if isinstance(outputs, input_type): + hidden_states = outputs + else: + hidden_states = outputs[0] + last_hidden_states = hidden_states.gather_nd(embedding_indices) + return last_hidden_states