[ZeroPadding] add greedy_zero_padding #8933

Merged · 1 commit · Aug 19, 2024
llm/alignment/dpo/dpo_argument.py (11 changes: 4 additions & 7 deletions)
@@ -63,11 +63,10 @@ class DPODataArgument:
         default=False,
         metadata={"help": "Whether to run benchmark by autotuner. True for from_scratch."},
     )
-    greedy_intokens: bool = field(
-        default=True,
-        metadata={"help": "Whether apply greedy intokens."},
+    greedy_zero_padding: bool = field(
+        default=False,
+        metadata={"help": "Whether to use Greedy Zero Padding data stream."},
     )
-    buffer_size: int = field(default=500, metadata={"help": "Buffer size for greedy_intokens strategy."})


 @dataclass
@@ -87,9 +86,7 @@ class DPOModelArgument:
             "help": "The granularity of recompute training can be selected as `full` or `full_attn` or `core_attn`."
         },
     )
-    flash_mask: bool = field(
-        default=False, metadata={"help": "Whether to use flash mask in flash attention."}
-    )
+    flash_mask: bool = field(default=False, metadata={"help": "Whether to use flash mask in flash attention."})
     virtual_pp_degree: int = field(
         default=1,
         metadata={"help": "virtual_pp_degree"},
llm/alignment/dpo/run_dpo.py (18 changes: 9 additions & 9 deletions)
@@ -17,7 +17,6 @@
 import os
 import sys
 import time
-import inspect
 from functools import partial

 import paddle
@@ -30,17 +29,19 @@
     get_last_checkpoint,
     set_seed,
 )
-from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from paddlenlp.transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    LlamaForCausalLM,
+    LlamaForCausalLMPipe,
+)
 from paddlenlp.trl import (
     DPOTrainer,
     calculate_effective_tokens,
     preference_collate_fn,
     preprocess_preference_data,
 )
-from paddlenlp.transformers import (
-    LlamaForCausalLM,
-    LlamaForCausalLMPipe,
-)
 from paddlenlp.utils.log import logger

 flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe]
@@ -132,9 +133,7 @@ def main():
     model.set_state_dict(ref_model.state_dict())

     if model_args.flash_mask and not model.config.use_flash_attention:
-        logger.warning(
-            "`flash_mask` must use with zero padding and flash attention."
-        )
+        logger.warning("`flash_mask` must use with zero padding and flash attention.")
         model.config.use_flash_attention = True

     if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
@@ -161,6 +160,7 @@ def main():
                 train_ds.map(trans_func),
                 tokenizer=tokenizer,
                 max_length=data_args.max_seq_len,
+                greedy_zero_padding=data_args.greedy_zero_padding,
             )
             if train_ds is not None
             else None
llm/run_finetune.py (2 changes: 2 additions & 0 deletions)
@@ -391,6 +391,7 @@ def neft_post_hook(module, input, output):
                 train_ds,
                 tokenizer=tokenizer,
                 max_length=data_args.max_length,
+                greedy_zero_padding=data_args.greedy_zero_padding,
             )
             if train_ds is not None
             else None
@@ -400,6 +401,7 @@ def neft_post_hook(module, input, output):
                 ptq_ds,
                 tokenizer=tokenizer,
                 max_length=data_args.max_length,
+                greedy_zero_padding=data_args.greedy_zero_padding,
             )
             if ptq_ds is not None
             else None
llm/utils/argument.py (6 changes: 6 additions & 0 deletions)
@@ -86,6 +86,12 @@ class DataArgument:
     dataset_name_or_path: str = field(default=None, metadata={"help": "Name or path for dataset"})
     task_name: str = field(default=None, metadata={"help": "Additional name to select a more specific task."})
     zero_padding: bool = field(default=False, metadata={"help": "Whether to use Zero Padding data stream"})
+    greedy_zero_padding: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use Greedy Zero Padding data stream, should be used together with `zero_padding=True`."
+        },
+    )
     pad_to_multiple_of: int = field(
         default=None, metadata={"help": "If set will pad the sequence to a multiple of the provided value."}
     )
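
Since the help string ties the two flags together, here is a minimal sketch of the intended pairing (the dataclass below is a stand-in for illustration, not the real DataArgument, which carries many more fields):

from dataclasses import dataclass, field

@dataclass
class DataArgumentSketch:
    zero_padding: bool = field(default=False)
    greedy_zero_padding: bool = field(default=False)

args = DataArgumentSketch(zero_padding=True, greedy_zero_padding=True)
# Per the help string above, greedy_zero_padding should only be enabled
# together with zero_padding=True.
assert not (args.greedy_zero_padding and not args.zero_padding)
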
paddlenlp/datasets/zero_padding_dataset.py (170 changes: 123 additions & 47 deletions)
@@ -17,6 +17,27 @@
 from scipy.linalg import block_diag


+def generate_greedy_packs(examples, max_length):
+    left_len = np.zeros([len(examples)]) - 1
+    left_len[0] = max_length  # At the beginning, only the first pack is valid.
+    generate_packs = [[] for i in range(len(examples))]
+    index, left_index = 0, 0
+
+    while index < len(examples):
+        record = examples[index]
+        max_left_index = left_len.argmax()
+        # Put the current sequence into the largest left space valid pack.
+        if len(record["input_ids"]) <= left_len[max_left_index]:
+            generate_packs[max_left_index].append(record)
+            left_len[max_left_index] -= len(record["input_ids"])
+            index += 1
+        else:
+            left_index += 1
+            left_len[left_index] = max_length
+
+    return generate_packs
+
+
 class ZeroPadding:
     required_output_keys = ["input_ids", "labels", "attention_mask"]
     # Only supported the following keys for ZeroPadding. Keys outside of the set will be ignored.
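
To make the packing rule concrete, the sketch below feeds five toy sequences through generate_greedy_packs (the import path follows the file shown above; the lengths are illustrative):

from paddlenlp.datasets.zero_padding_dataset import generate_greedy_packs

max_length = 8
examples = [{"input_ids": list(range(n))} for n in (5, 4, 3, 2, 1)]

packs = generate_greedy_packs(examples, max_length)
# Keep only the non-empty packs and show the sequence lengths in each.
print([[len(r["input_ids"]) for r in pack] for pack in packs if pack])
# -> [[5, 2, 1], [4, 3]]: each record goes to the pack with the most room
#    left, and a new pack is opened only when the record fits nowhere.
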
@@ -80,38 +101,66 @@


 class ZeroPaddingMapDataset(ZeroPadding, Dataset):
-    def __init__(self, data, tokenizer, max_length):
+    def __init__(self, data, tokenizer, max_length, greedy_zero_padding=False):
         self.tokenizer = tokenizer
         self.max_length = max_length
+        self.greedy_zero_padding = greedy_zero_padding
         self.new_data = self._create_zero_padding_data(data)

     def _create_zero_padding_data(self, data):
-        batch_records, max_len = [], 0
-        cur_len_so_far = 0
         total_data = []
-        for i in range(len(data)):
-            record = data[i]
-            max_len = max(max_len, len(record["input_ids"]))
-            to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
-            if to_append:
-                batch_records.append(record)
-                cur_len_so_far += len(record["input_ids"])
-            else:
-                # exceed max length
+        if not self.greedy_zero_padding:
+            batch_records = []
+            cur_len_so_far = 0
+            for i in range(len(data)):
+                record = data[i]
+                if len(record["input_ids"]) > self.max_length:
+                    continue
+                to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
+                if to_append:
+                    batch_records.append(record)
+                    cur_len_so_far += len(record["input_ids"])
+                else:
+                    # exceed max length
+                    padded_list = self._pad_batch_records(batch_records)
+                    total_data.append(padded_list)
+                    # reset
+                    batch_records = []
+                    cur_len_so_far = 0
+                    # append current data
+                    batch_records.append(record)
+                    cur_len_so_far += len(record["input_ids"])
+
+            # remaining data
+            if batch_records:
                 padded_list = self._pad_batch_records(batch_records)
                 total_data.append(padded_list)
-                # reset
-                batch_records, max_len = [], 0
-                cur_len_so_far = 0
-                # append current data
-                batch_records.append(record)
-                cur_len_so_far += len(record["input_ids"])
-
-        # remaining data
-        if batch_records:
-            padded_list = self._pad_batch_records(batch_records)
-            total_data.append(padded_list)
+        else:
+            examples = []
+            buffer_size = 500
+            i = 0
+            for record in data:
+                if len(record["input_ids"]) > self.max_length:
+                    continue
+                if i < buffer_size:
+                    examples.append(record)
+                    i += 1
+                else:
+                    # Running greedy strategy in examples.
+                    generate_packs = generate_greedy_packs(examples, self.max_length)
+                    for batch_records in generate_packs:
+                        if len(batch_records) > 0:
+                            padded_list = self._pad_batch_records(batch_records)
+                            total_data.append(padded_list)
+                    examples = [record]
+                    i = 1
+            if len(examples) > 0:
+                generate_packs = generate_greedy_packs(examples, self.max_length)
+                for batch_records in generate_packs:
+                    if len(batch_records) > 0:
+                        padded_list = self._pad_batch_records(batch_records)
+                        total_data.append(padded_list)

         return total_data

     def __getitem__(self, idx):
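
For the map-style dataset the new behavior is a one-argument change at construction time. A sketch under assumptions: the model id and toy records are placeholders, and records must already carry the keys _pad_batch_records expects:

from paddlenlp.datasets.zero_padding_dataset import ZeroPaddingMapDataset
from paddlenlp.transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/llama-7b")  # placeholder model id
data = [
    {"input_ids": [1, 2, 3, 4, 5], "labels": [-100, 2, 3, 4, 5]},
    {"input_ids": [6, 7, 8], "labels": [-100, 7, 8]},
]
# greedy_zero_padding=True routes _create_zero_padding_data through the
# buffered greedy-packing branch instead of the sequential packing loop.
ds = ZeroPaddingMapDataset(data, tokenizer=tokenizer, max_length=16, greedy_zero_padding=True)
print(len(ds.new_data))  # expected: 1, since both toy records fit in one pack
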
@@ -122,34 +171,61 @@


 class ZeroPaddingIterableDataset(ZeroPadding, IterableDataset):
-    def __init__(self, data, tokenizer, max_length):
+    def __init__(self, data, tokenizer, max_length, greedy_zero_padding=False):
         self.data = data
         self.tokenizer = tokenizer
         self.max_length = max_length
         self.zero_padding_global_step = 0
+        self.greedy_zero_padding = greedy_zero_padding

     def __iter__(self):
-        batch_records, max_len = [], 0
-        cur_len_so_far = 0
-        for record in self.data:
-            max_len = max(max_len, len(record["input_ids"]))
-            to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
-            if to_append:
-                batch_records.append(record)
-                self.zero_padding_global_step += 1
-                cur_len_so_far += len(record["input_ids"])
-            else:
-                # exceed max length
+        if not self.greedy_zero_padding:
+            batch_records = []
+            cur_len_so_far = 0
+            for record in self.data:
+                to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
+                if to_append:
+                    batch_records.append(record)
+                    self.zero_padding_global_step += 1
+                    cur_len_so_far += len(record["input_ids"])
+                else:
+                    # exceed max length
+                    padded_list = self._pad_batch_records(batch_records)
+                    yield padded_list
+                    # reset
+                    batch_records = []
+                    cur_len_so_far = 0
+                    # append current data
+                    batch_records.append(record)
+                    self.zero_padding_global_step += 1
+                    cur_len_so_far += len(record["input_ids"])
+            if batch_records:
                 padded_list = self._pad_batch_records(batch_records)
                 yield padded_list
-                # reset
-                batch_records, max_len = [], 0
-                cur_len_so_far = 0
-                # append current data
-                batch_records.append(record)
-                self.zero_padding_global_step += 1
-                cur_len_so_far += len(record["input_ids"])
-        if batch_records:
-            padded_list = self._pad_batch_records(batch_records)
-            yield padded_list
+        else:
+            examples = []
+            buffer_size = 500
+            i = 0
+            for record in self.data:
+                if len(record["input_ids"]) > self.max_length:
+                    continue
+                if i < buffer_size:
+                    examples.append(record)
+                    self.zero_padding_global_step += 1
+                    i += 1
+                else:
+                    # Running greedy strategy in examples.
+                    generate_packs = generate_greedy_packs(examples, self.max_length)
+                    for batch_records in generate_packs:
+                        if len(batch_records) > 0:
+                            padded_list = self._pad_batch_records(batch_records)
+                            yield padded_list
+                    examples = [record]
+                    self.zero_padding_global_step += 1
+                    i = 1
+            if len(examples) > 0:
+                generate_packs = generate_greedy_packs(examples, self.max_length)
+                for batch_records in generate_packs:
+                    if len(batch_records) > 0:
+                        padded_list = self._pad_batch_records(batch_records)
+                        yield padded_list
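
The iterable variant applies the same greedy packing over a bounded buffer (hard-coded at 500 records), so a streaming dataset never has to be fully materialized; the final partial buffer is flushed after the loop, and the record that overflows a buffer seeds the next one. Presumably the bounded buffer trades a little packing quality for constant memory. A distilled sketch of that control flow, yielding raw packs rather than the padded dicts the real class produces via _pad_batch_records:

from paddlenlp.datasets.zero_padding_dataset import generate_greedy_packs

def buffered_greedy_packs(stream, max_length, buffer_size=500):
    """Yield raw greedy packs from a record stream, buffer by buffer."""
    examples, i = [], 0
    for record in stream:
        if len(record["input_ids"]) > max_length:
            continue  # oversized records are dropped, as in the diff
        if i < buffer_size:
            examples.append(record)
            i += 1
        else:
            for pack in generate_greedy_packs(examples, max_length):
                if pack:
                    yield pack
            examples, i = [record], 1  # overflow record seeds the next buffer
    if examples:  # flush the final partial buffer
        for pack in generate_greedy_packs(examples, max_length):
            if pack:
                yield pack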