Add greedy_zero_padding (PaddlePaddle#8933)
DesmonDay authored and Mangodadada committed Sep 10, 2024
1 parent 77cb825 commit 3247760
Showing 5 changed files with 144 additions and 63 deletions.
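
In short: the DPO-specific `greedy_intokens` and `buffer_size` options are replaced by a single `greedy_zero_padding` flag, the flag is threaded from the DPO and finetuning entry points into the Zero Padding datasets, and a new `generate_greedy_packs` helper implements the greedy packing strategy for both `ZeroPaddingMapDataset` and `ZeroPaddingIterableDataset`.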
11 changes: 4 additions & 7 deletions llm/alignment/dpo/dpo_argument.py
@@ -63,11 +63,10 @@ class DPODataArgument:
         default=False,
         metadata={"help": "Whether to run benchmark by autotuner. True for from_scratch."},
     )
-    greedy_intokens: bool = field(
-        default=True,
-        metadata={"help": "Whether apply greedy intokens."},
+    greedy_zero_padding: bool = field(
+        default=False,
+        metadata={"help": "Whether to use Greedy Zero Padding data stream."},
     )
-    buffer_size: int = field(default=500, metadata={"help": "Buffer size for greedy_intokens strategy."})


 @dataclass
@@ -87,9 +86,7 @@ class DPOModelArgument:
             "help": "The granularity of recompute training can be selected as `full` or `full_attn` or `core_attn`."
         },
     )
-    flash_mask: bool = field(
-        default=False, metadata={"help": "Whether to use flash mask in flash attention."}
-    )
+    flash_mask: bool = field(default=False, metadata={"help": "Whether to use flash mask in flash attention."})
     virtual_pp_degree: int = field(
         default=1,
         metadata={"help": "virtual_pp_degree"},
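
Note that the separate `buffer_size` argument is dropped: the greedy implementation below works on a fixed buffer of 500 examples instead.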
18 changes: 9 additions & 9 deletions llm/alignment/dpo/run_dpo.py
@@ -17,7 +17,6 @@
 import os
 import sys
 import time
-import inspect
 from functools import partial

 import paddle
@@ -30,17 +29,19 @@
     get_last_checkpoint,
     set_seed,
 )
-from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from paddlenlp.transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    LlamaForCausalLM,
+    LlamaForCausalLMPipe,
+)
 from paddlenlp.trl import (
     DPOTrainer,
     calculate_effective_tokens,
     preference_collate_fn,
     preprocess_preference_data,
 )
-from paddlenlp.transformers import (
-    LlamaForCausalLM,
-    LlamaForCausalLMPipe,
-)
 from paddlenlp.utils.log import logger

 flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe]
@@ -132,9 +133,7 @@ def main():
         model.set_state_dict(ref_model.state_dict())

     if model_args.flash_mask and not model.config.use_flash_attention:
-        logger.warning(
-            "`flash_mask` must use with zero padding and flash attention."
-        )
+        logger.warning("`flash_mask` must use with zero padding and flash attention.")
         model.config.use_flash_attention = True

     if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
@@ -161,6 +160,7 @@ def main():
             train_ds.map(trans_func),
             tokenizer=tokenizer,
             max_length=data_args.max_seq_len,
+            greedy_zero_padding=data_args.greedy_zero_padding,
         )
         if train_ds is not None
         else None
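
Besides threading `greedy_zero_padding` from `DPODataArgument` into `ZeroPaddingMapDataset`, this file removes the unused `import inspect` and merges the two `paddlenlp.transformers` import blocks into one.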
2 changes: 2 additions & 0 deletions llm/run_finetune.py
@@ -391,6 +391,7 @@ def neft_post_hook(module, input, output):
             train_ds,
             tokenizer=tokenizer,
             max_length=data_args.max_length,
+            greedy_zero_padding=data_args.greedy_zero_padding,
         )
         if train_ds is not None
         else None
@@ -400,6 +401,7 @@ def neft_post_hook(module, input, output):
             ptq_ds,
             tokenizer=tokenizer,
             max_length=data_args.max_length,
+            greedy_zero_padding=data_args.greedy_zero_padding,
         )
         if ptq_ds is not None
         else None
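
Both the training dataset and the PTQ dataset construction now forward the flag, so greedy packing applies to either path when enabled.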
6 changes: 6 additions & 0 deletions llm/utils/argument.py
@@ -86,6 +86,12 @@ class DataArgument:
     dataset_name_or_path: str = field(default=None, metadata={"help": "Name or path for dataset"})
     task_name: str = field(default=None, metadata={"help": "Additional name to select a more specific task."})
     zero_padding: bool = field(default=False, metadata={"help": "Whether to use Zero Padding data stream"})
+    greedy_zero_padding: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use Greedy Zero Padding data stream, should be used together with `zero_padding=True`."
+        },
+    )
     pad_to_multiple_of: int = field(
         default=None, metadata={"help": "If set will pad the sequence to a multiple of the provided value."}
     )
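
The new option is meant to be switched on alongside `zero_padding`. As a minimal sketch (assuming PaddleNLP's `PdArgumentParser`, which the llm scripts already use, supports `HfArgumentParser`-style `parse_dict`; the import path and dataset path below are illustrative, not part of this commit):

from paddlenlp.trainer import PdArgumentParser

from llm.utils.argument import DataArgument  # hypothetical import path for the file above

parser = PdArgumentParser(DataArgument)
(data_args,) = parser.parse_dict(
    {
        "dataset_name_or_path": "./my_dataset",  # made-up path
        "zero_padding": True,  # greedy_zero_padding should be used together with zero_padding
        "greedy_zero_padding": True,
    }
)
assert data_args.zero_padding and data_args.greedy_zero_padding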
170 changes: 123 additions & 47 deletions paddlenlp/datasets/zero_padding_dataset.py
@@ -17,6 +17,27 @@
 from scipy.linalg import block_diag


+def generate_greedy_packs(examples, max_length):
+    left_len = np.zeros([len(examples)]) - 1
+    left_len[0] = max_length  # At the beginning, only the first pack is valid.
+    generate_packs = [[] for i in range(len(examples))]
+    index, left_index = 0, 0
+
+    while index < len(examples):
+        record = examples[index]
+        max_left_index = left_len.argmax()
+        # Put the current sequence into the largest left space valid pack.
+        if len(record["input_ids"]) <= left_len[max_left_index]:
+            generate_packs[max_left_index].append(record)
+            left_len[max_left_index] -= len(record["input_ids"])
+            index += 1
+        else:
+            left_index += 1
+            left_len[left_index] = max_length
+
+    return generate_packs
+
+
 class ZeroPadding:
     required_output_keys = ["input_ids", "labels", "attention_mask"]
     # Only supported the following keys for ZeroPadding. Keys outside of the set will be ignored.
@@ -80,38 +101,66 @@ def _pad_batch_records(cls, batch_records):


 class ZeroPaddingMapDataset(ZeroPadding, Dataset):
-    def __init__(self, data, tokenizer, max_length):
+    def __init__(self, data, tokenizer, max_length, greedy_zero_padding=False):
         self.tokenizer = tokenizer
         self.max_length = max_length
+        self.greedy_zero_padding = greedy_zero_padding
         self.new_data = self._create_zero_padding_data(data)

     def _create_zero_padding_data(self, data):
-        batch_records, max_len = [], 0
-        cur_len_so_far = 0
-
         total_data = []
-        for i in range(len(data)):
-            record = data[i]
-            max_len = max(max_len, len(record["input_ids"]))
-            to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
-            if to_append:
-                batch_records.append(record)
-                cur_len_so_far += len(record["input_ids"])
-            else:
-                # exceed max length
-                padded_list = self._pad_batch_records(batch_records)
-                total_data.append(padded_list)
-                # reset
-                batch_records, max_len = [], 0
-                cur_len_so_far = 0
-                # append current data
-                batch_records.append(record)
-                cur_len_so_far += len(record["input_ids"])
-
-        # remaining data
-        if batch_records:
-            padded_list = self._pad_batch_records(batch_records)
-            total_data.append(padded_list)
+        if not self.greedy_zero_padding:
+            batch_records = []
+            cur_len_so_far = 0
+            for i in range(len(data)):
+                record = data[i]
+                if len(record["input_ids"]) > self.max_length:
+                    continue
+                to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
+                if to_append:
+                    batch_records.append(record)
+                    cur_len_so_far += len(record["input_ids"])
+                else:
+                    # exceed max length
+                    padded_list = self._pad_batch_records(batch_records)
+                    total_data.append(padded_list)
+                    # reset
+                    batch_records = []
+                    cur_len_so_far = 0
+                    # append current data
+                    batch_records.append(record)
+                    cur_len_so_far += len(record["input_ids"])
+
+            # remaining data
+            if batch_records:
+                padded_list = self._pad_batch_records(batch_records)
+                total_data.append(padded_list)
+        else:
+            examples = []
+            buffer_size = 500
+            i = 0
+            for record in data:
+                if len(record["input_ids"]) > self.max_length:
+                    continue
+                if i < buffer_size:
+                    examples.append(record)
+                    i += 1
+                else:
+                    # Running greedy strategy in examples.
+                    generate_packs = generate_greedy_packs(examples, self.max_length)
+                    for batch_records in generate_packs:
+                        if len(batch_records) > 0:
+                            padded_list = self._pad_batch_records(batch_records)
+                            total_data.append(padded_list)
+                    examples = [record]
+                    i = 1
+            if len(examples) > 0:
+                generate_packs = generate_greedy_packs(examples, self.max_length)
+                for batch_records in generate_packs:
+                    if len(batch_records) > 0:
+                        padded_list = self._pad_batch_records(batch_records)
+                        total_data.append(padded_list)

         return total_data

     def __getitem__(self, idx):
@@ -122,34 +171,61 @@ def __len__(self):


 class ZeroPaddingIterableDataset(ZeroPadding, IterableDataset):
-    def __init__(self, data, tokenizer, max_length):
+    def __init__(self, data, tokenizer, max_length, greedy_zero_padding=False):
         self.data = data
         self.tokenizer = tokenizer
         self.max_length = max_length
         self.zero_padding_global_step = 0
+        self.greedy_zero_padding = greedy_zero_padding

     def __iter__(self):
-        batch_records, max_len = [], 0
-        cur_len_so_far = 0
-        for record in self.data:
-            max_len = max(max_len, len(record["input_ids"]))
-            to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
-            if to_append:
-                batch_records.append(record)
-                self.zero_padding_global_step += 1
-                cur_len_so_far += len(record["input_ids"])
-            else:
-                # exceed max length
-                padded_list = self._pad_batch_records(batch_records)
-                yield padded_list
-                # reset
-                batch_records, max_len = [], 0
-                cur_len_so_far = 0
-                # append current data
-                batch_records.append(record)
-                self.zero_padding_global_step += 1
-                cur_len_so_far += len(record["input_ids"])
-        if batch_records:
-            padded_list = self._pad_batch_records(batch_records)
-            yield padded_list
+        if not self.greedy_zero_padding:
+            batch_records = []
+            cur_len_so_far = 0
+            for record in self.data:
+                to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
+                if to_append:
+                    batch_records.append(record)
+                    self.zero_padding_global_step += 1
+                    cur_len_so_far += len(record["input_ids"])
+                else:
+                    # exceed max length
+                    padded_list = self._pad_batch_records(batch_records)
+                    yield padded_list
+                    # reset
+                    batch_records = []
+                    cur_len_so_far = 0
+                    # append current data
+                    batch_records.append(record)
+                    self.zero_padding_global_step += 1
+                    cur_len_so_far += len(record["input_ids"])
+            if batch_records:
+                padded_list = self._pad_batch_records(batch_records)
+                yield padded_list
+        else:
+            examples = []
+            buffer_size = 500
+            i = 0
+            for record in self.data:
+                if len(record["input_ids"]) > self.max_length:
+                    continue
+                if i < buffer_size:
+                    examples.append(record)
+                    self.zero_padding_global_step += 1
+                    i += 1
+                else:
+                    # Running greedy strategy in examples.
+                    generate_packs = generate_greedy_packs(examples, self.max_length)
+                    for batch_records in generate_packs:
+                        if len(batch_records) > 0:
+                            padded_list = self._pad_batch_records(batch_records)
+                            yield padded_list
+                    examples = [record]
+                    self.zero_padding_global_step += 1
+                    i = 1
+            if len(examples) > 0:
+                generate_packs = generate_greedy_packs(examples, self.max_length)
+                for batch_records in generate_packs:
+                    if len(batch_records) > 0:
+                        padded_list = self._pad_batch_records(batch_records)
+                        yield padded_list
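
To make the packing behavior concrete, the following self-contained sketch reproduces `generate_greedy_packs` from the hunk above and runs it on made-up records (the toy lengths and `max_length` are illustrative, not from this commit):

import numpy as np


def generate_greedy_packs(examples, max_length):
    # Reproduced from the diff above: place each sequence into the open pack
    # with the most space left; open a new pack when nothing fits. Callers
    # filter out records longer than max_length before calling this.
    left_len = np.zeros([len(examples)]) - 1  # -1 marks packs not opened yet
    left_len[0] = max_length  # At the beginning, only the first pack is valid.
    generate_packs = [[] for i in range(len(examples))]
    index, left_index = 0, 0

    while index < len(examples):
        record = examples[index]
        max_left_index = left_len.argmax()
        if len(record["input_ids"]) <= left_len[max_left_index]:
            generate_packs[max_left_index].append(record)
            left_len[max_left_index] -= len(record["input_ids"])
            index += 1
        else:
            left_index += 1
            left_len[left_index] = max_length

    return generate_packs


# Made-up records whose input_ids have lengths 6, 5, 4 and 3, packed to max_length=8.
examples = [{"input_ids": list(range(n))} for n in (6, 5, 4, 3)]
for pack in generate_greedy_packs(examples, max_length=8):
    if pack:
        print([len(r["input_ids"]) for r in pack])
# Prints [6], [5], [4, 3] -- every pack stays within max_length.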
