diff --git a/llm/alignment/dpo/dpo_argument.py b/llm/alignment/dpo/dpo_argument.py
index 86173be4ae2c..6d4d03603068 100644
--- a/llm/alignment/dpo/dpo_argument.py
+++ b/llm/alignment/dpo/dpo_argument.py
@@ -63,11 +63,10 @@ class DPODataArgument:
         default=False,
         metadata={"help": "Whether to run benchmark by autotuner. True for from_scratch."},
     )
-    greedy_intokens: bool = field(
-        default=True,
-        metadata={"help": "Whether apply greedy intokens."},
+    greedy_zero_padding: bool = field(
+        default=False,
+        metadata={"help": "Whether to use Greedy Zero Padding data stream."},
     )
-    buffer_size: int = field(default=500, metadata={"help": "Buffer size for greedy_intokens strategy."})
 
 
 @dataclass
@@ -87,9 +86,7 @@ class DPOModelArgument:
             "help": "The granularity of recompute training can be selected as `full` or `full_attn` or `core_attn`."
         },
     )
-    flash_mask: bool = field(
-        default=False, metadata={"help": "Whether to use flash mask in flash attention."}
-    )
+    flash_mask: bool = field(default=False, metadata={"help": "Whether to use flash mask in flash attention."})
     virtual_pp_degree: int = field(
         default=1,
         metadata={"help": "virtual_pp_degree"},
diff --git a/llm/alignment/dpo/run_dpo.py b/llm/alignment/dpo/run_dpo.py
index 0d542444fe55..3945375aee43 100644
--- a/llm/alignment/dpo/run_dpo.py
+++ b/llm/alignment/dpo/run_dpo.py
@@ -17,7 +17,6 @@
 import os
 import sys
 import time
-import inspect
 from functools import partial
 
 import paddle
@@ -30,17 +29,19 @@
     get_last_checkpoint,
     set_seed,
 )
-from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from paddlenlp.transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    LlamaForCausalLM,
+    LlamaForCausalLMPipe,
+)
 from paddlenlp.trl import (
     DPOTrainer,
     calculate_effective_tokens,
     preference_collate_fn,
     preprocess_preference_data,
 )
-from paddlenlp.transformers import (
-    LlamaForCausalLM,
-    LlamaForCausalLMPipe,
-)
 from paddlenlp.utils.log import logger
 
 flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe]
@@ -132,9 +133,7 @@ def main():
             model.set_state_dict(ref_model.state_dict())
 
     if model_args.flash_mask and not model.config.use_flash_attention:
-        logger.warning(
-            "`flash_mask` must use with zero padding and flash attention."
-        )
+        logger.warning("`flash_mask` must use with zero padding and flash attention.")
         model.config.use_flash_attention = True
 
     if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
@@ -161,6 +160,7 @@ def main():
             train_ds.map(trans_func),
             tokenizer=tokenizer,
             max_length=data_args.max_seq_len,
+            greedy_zero_padding=data_args.greedy_zero_padding,
         )
         if train_ds is not None
         else None
diff --git a/llm/run_finetune.py b/llm/run_finetune.py
index bdb67a5843e6..851f74ae3224 100644
--- a/llm/run_finetune.py
+++ b/llm/run_finetune.py
@@ -391,6 +391,7 @@ def neft_post_hook(module, input, output):
             train_ds,
             tokenizer=tokenizer,
             max_length=data_args.max_length,
+            greedy_zero_padding=data_args.greedy_zero_padding,
         )
         if train_ds is not None
         else None
@@ -400,6 +401,7 @@ def neft_post_hook(module, input, output):
             ptq_ds,
             tokenizer=tokenizer,
             max_length=data_args.max_length,
+            greedy_zero_padding=data_args.greedy_zero_padding,
         )
         if ptq_ds is not None
         else None
diff --git a/llm/utils/argument.py b/llm/utils/argument.py
index 7ee754b3e4a0..5c4df81ff05d 100644
--- a/llm/utils/argument.py
+++ b/llm/utils/argument.py
@@ -86,6 +86,12 @@ class DataArgument:
     dataset_name_or_path: str = field(default=None, metadata={"help": "Name or path for dataset"})
     task_name: str = field(default=None, metadata={"help": "Additional name to select a more specific task."})
     zero_padding: bool = field(default=False, metadata={"help": "Whether to use Zero Padding data stream"})
+    greedy_zero_padding: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use Greedy Zero Padding data stream, should be used together with `zero_padding=True`."
+        },
+    )
     pad_to_multiple_of: int = field(
         default=None, metadata={"help": "If set will pad the sequence to a multiple of the provided value."}
     )
diff --git a/paddlenlp/datasets/zero_padding_dataset.py b/paddlenlp/datasets/zero_padding_dataset.py
index 9b65d448d31b..2a071fc11404 100644
--- a/paddlenlp/datasets/zero_padding_dataset.py
+++ b/paddlenlp/datasets/zero_padding_dataset.py
@@ -17,6 +17,27 @@
 from scipy.linalg import block_diag
 
 
+def generate_greedy_packs(examples, max_length):
+    left_len = np.zeros([len(examples)]) - 1
+    left_len[0] = max_length  # At the beginning, only the first pack is valid.
+    generate_packs = [[] for i in range(len(examples))]
+    index, left_index = 0, 0
+
+    while index < len(examples):
+        record = examples[index]
+        max_left_index = left_len.argmax()
+        # Put the current sequence into the largest left space valid pack.
+        if len(record["input_ids"]) <= left_len[max_left_index]:
+            generate_packs[max_left_index].append(record)
+            left_len[max_left_index] -= len(record["input_ids"])
+            index += 1
+        else:
+            left_index += 1
+            left_len[left_index] = max_length
+
+    return generate_packs
+
+
 class ZeroPadding:
     required_output_keys = ["input_ids", "labels", "attention_mask"]
     # Only supported the following keys for ZeroPadding. Keys outside of the set will be ignored.
@@ -80,38 +101,66 @@ def _pad_batch_records(cls, batch_records):
 
 
 class ZeroPaddingMapDataset(ZeroPadding, Dataset):
-    def __init__(self, data, tokenizer, max_length):
+    def __init__(self, data, tokenizer, max_length, greedy_zero_padding=False):
         self.tokenizer = tokenizer
         self.max_length = max_length
+        self.greedy_zero_padding = greedy_zero_padding
         self.new_data = self._create_zero_padding_data(data)
 
     def _create_zero_padding_data(self, data):
-        batch_records, max_len = [], 0
-        cur_len_so_far = 0
-        total_data = []
-        for i in range(len(data)):
-            record = data[i]
-            max_len = max(max_len, len(record["input_ids"]))
-            to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
-            if to_append:
-                batch_records.append(record)
-                cur_len_so_far += len(record["input_ids"])
-            else:
-                # exceed max length
+        total_data = []
+        if not self.greedy_zero_padding:
+            batch_records = []
+            cur_len_so_far = 0
+            for i in range(len(data)):
+                record = data[i]
+                if len(record["input_ids"]) > self.max_length:
+                    continue
+                to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
+                if to_append:
+                    batch_records.append(record)
+                    cur_len_so_far += len(record["input_ids"])
+                else:
+                    # exceed max length
+                    padded_list = self._pad_batch_records(batch_records)
+                    total_data.append(padded_list)
+                    # reset
+                    batch_records = []
+                    cur_len_so_far = 0
+                    # append current data
+                    batch_records.append(record)
+                    cur_len_so_far += len(record["input_ids"])
+
+            # remaining data
+            if batch_records:
                 padded_list = self._pad_batch_records(batch_records)
                 total_data.append(padded_list)
-                # reset
-                batch_records, max_len = [], 0
-                cur_len_so_far = 0
-                # append current data
-                batch_records.append(record)
-                cur_len_so_far += len(record["input_ids"])
-
-        # remaining data
-        if batch_records:
-            padded_list = self._pad_batch_records(batch_records)
-            total_data.append(padded_list)
+        else:
+            examples = []
+            buffer_size = 500
+            i = 0
+            for record in data:
+                if len(record["input_ids"]) > self.max_length:
+                    continue
+                if i < buffer_size:
+                    examples.append(record)
+                    i += 1
+                else:
+                    # Running greedy strategy in examples.
+                    generate_packs = generate_greedy_packs(examples, self.max_length)
+                    for batch_records in generate_packs:
+                        if len(batch_records) > 0:
+                            padded_list = self._pad_batch_records(batch_records)
+                            total_data.append(padded_list)
+                    examples = [record]
+                    i = 1
+            if len(examples) > 0:
+                generate_packs = generate_greedy_packs(examples, self.max_length)
+                for batch_records in generate_packs:
+                    if len(batch_records) > 0:
+                        padded_list = self._pad_batch_records(batch_records)
+                        total_data.append(padded_list)
+
         return total_data
 
     def __getitem__(self, idx):
@@ -122,34 +171,61 @@ def __len__(self):
 
 
 class ZeroPaddingIterableDataset(ZeroPadding, IterableDataset):
-    def __init__(self, data, tokenizer, max_length):
-
+    def __init__(self, data, tokenizer, max_length, greedy_zero_padding=False):
         self.data = data
         self.tokenizer = tokenizer
         self.max_length = max_length
         self.zero_padding_global_step = 0
+        self.greedy_zero_padding = greedy_zero_padding
 
     def __iter__(self):
-        batch_records, max_len = [], 0
-        cur_len_so_far = 0
-        for record in self.data:
-            max_len = max(max_len, len(record["input_ids"]))
-            to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
-            if to_append:
-                batch_records.append(record)
-                self.zero_padding_global_step += 1
-                cur_len_so_far += len(record["input_ids"])
-            else:
-                # exceed max length
+        if not self.greedy_zero_padding:
+            batch_records = []
+            cur_len_so_far = 0
+            for record in self.data:
+                to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
+                if to_append:
+                    batch_records.append(record)
+                    self.zero_padding_global_step += 1
+                    cur_len_so_far += len(record["input_ids"])
+                else:
+                    # exceed max length
+                    padded_list = self._pad_batch_records(batch_records)
+                    yield padded_list
+                    # reset
+                    batch_records = []
+                    cur_len_so_far = 0
+                    # append current data
+                    batch_records.append(record)
+                    self.zero_padding_global_step += 1
+                    cur_len_so_far += len(record["input_ids"])
+            if batch_records:
                 padded_list = self._pad_batch_records(batch_records)
                 yield padded_list
-                # reset
-                batch_records, max_len = [], 0
-                cur_len_so_far = 0
-                # append current data
-                batch_records.append(record)
-                self.zero_padding_global_step += 1
-                cur_len_so_far += len(record["input_ids"])
-        if batch_records:
-            padded_list = self._pad_batch_records(batch_records)
-            yield padded_list
+        else:
+            examples = []
+            buffer_size = 500
+            i = 0
+            for record in self.data:
+                if len(record["input_ids"]) > self.max_length:
+                    continue
+                if i < buffer_size:
+                    examples.append(record)
+                    self.zero_padding_global_step += 1
+                    i += 1
+                else:
+                    # Running greedy strategy in examples.
+                    generate_packs = generate_greedy_packs(examples, self.max_length)
+                    for batch_records in generate_packs:
+                        if len(batch_records) > 0:
+                            padded_list = self._pad_batch_records(batch_records)
+                            yield padded_list
+                    examples = [record]
+                    self.zero_padding_global_step += 1
+                    i = 1
+            if len(examples) > 0:
+                generate_packs = generate_greedy_packs(examples, self.max_length)
+                for batch_records in generate_packs:
+                    if len(batch_records) > 0:
+                        padded_list = self._pad_batch_records(batch_records)
+                        yield padded_list
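
For reference, here is a minimal standalone sketch of the greedy packing strategy this patch introduces. It mirrors the `generate_greedy_packs` helper added to `paddlenlp/datasets/zero_padding_dataset.py`; the toy `records` and their lengths are illustrative assumptions, not part of the patch. Each sequence goes into the open pack with the most remaining capacity, and a new pack is opened only when even that pack cannot hold it, so short sequences backfill earlier packs instead of being cut off at the first overflow.

```python
# Standalone sketch of the greedy packing added in this diff (toy data assumed).
import numpy as np


def generate_greedy_packs(examples, max_length):
    # left_len[i] is the remaining capacity of pack i; -1 marks unopened packs.
    left_len = np.zeros([len(examples)]) - 1
    left_len[0] = max_length  # At the beginning, only the first pack is valid.
    generate_packs = [[] for _ in range(len(examples))]
    index, left_index = 0, 0

    while index < len(examples):
        record = examples[index]
        max_left_index = left_len.argmax()
        # Place the sequence into the pack with the most space left;
        # if even that pack cannot hold it, open a new pack and retry.
        if len(record["input_ids"]) <= left_len[max_left_index]:
            generate_packs[max_left_index].append(record)
            left_len[max_left_index] -= len(record["input_ids"])
            index += 1
        else:
            left_index += 1
            left_len[left_index] = max_length

    return generate_packs


records = [{"input_ids": list(range(n))} for n in (70, 60, 40, 30)]
packs = [p for p in generate_greedy_packs(records, max_length=100) if p]
print([[len(r["input_ids"]) for r in p] for p in packs])
# -> [[70, 30], [60, 40]]: two full packs, where the plain sequential
#    zero-padding stream would emit three ([70], [60, 40], [30]).
```

In the datasets above, the same routine runs over a fixed buffer of 500 examples at a time (the hardcoded `buffer_size`), which bounds the memory cost of packing while still improving effective-token density within each buffer.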