[ZeroPadding] padding to max_length for sequence parallel #8973

Merged: 3 commits, Aug 21, 2024

Changes from all commits
53 changes: 44 additions & 9 deletions paddlenlp/datasets/zero_padding_dataset.py
@@ -53,7 +53,42 @@ class ZeroPadding:
     ]
 
     @classmethod
-    def _pad_batch_records(cls, batch_records):
+    def _pad_batch_records_to_max_length(cls, batch_records, max_length, pad_token=0):
+        # confirm that there is at least one item in the pack
+        if len(batch_records) == 0:
+            return batch_records
+        # count the total length of all records
+        total_length = sum(len(record["input_ids"]) for record in batch_records)
+        reserved_length = max_length - total_length
+
+        # append padding up to max_length
+        if "attn_mask_startend_row_indices" in batch_records[0]:
+            # a row index of `0` in attn_mask_startend_row_indices
+            # indicates that the padding tokens are fully masked.
+            batch_records.append(
+                {
+                    "input_ids": [pad_token] * reserved_length,
+                    "labels": [-100] * reserved_length,
+                    "attn_mask_startend_row_indices": [0] * reserved_length,
+                }
+            )
+        elif "attention_mask" in batch_records[0]:
+            # attention_mask is a fully masked attention matrix (all False),
+            # which indicates that the padding tokens are masked.
+            batch_records.append(
+                {
+                    "input_ids": [pad_token] * reserved_length,
+                    "labels": [-100] * reserved_length,
+                    "attention_mask": np.zeros((reserved_length, reserved_length), dtype=bool),
+                }
+            )
+
+        return batch_records
+
+    @classmethod
+    def _pad_batch_records(cls, batch_records, max_length):
+        batch_records = cls._pad_batch_records_to_max_length(batch_records, max_length)
+
         # Only consider supported input keys
         input_keys = [key for key in batch_records[0].keys() if key in cls.supported_input_keys]
         if "attn_mask_startend_row_indices" not in input_keys and "attention_mask" not in input_keys:
@@ -122,7 +157,7 @@ def _create_zero_padding_data(self, data):
                     cur_len_so_far += len(record["input_ids"])
                 else:
                     # exceed max length
-                    padded_list = self._pad_batch_records(batch_records)
+                    padded_list = self._pad_batch_records(batch_records, self.max_length)
                     total_data.append(padded_list)
                     # reset
                     batch_records = []
@@ -133,7 +168,7 @@ def _create_zero_padding_data(self, data):

             # remaining data
             if batch_records:
-                padded_list = self._pad_batch_records(batch_records)
+                padded_list = self._pad_batch_records(batch_records, self.max_length)
                 total_data.append(padded_list)
         else:
             examples = []
@@ -150,15 +185,15 @@ def _create_zero_padding_data(self, data):
                     generate_packs = generate_greedy_packs(examples, self.max_length)
                     for batch_records in generate_packs:
                         if len(batch_records) > 0:
-                            padded_list = self._pad_batch_records(batch_records)
+                            padded_list = self._pad_batch_records(batch_records, self.max_length)
                             total_data.append(padded_list)
                     examples = [record]
                     i = 1
             if len(examples) > 0:
                 generate_packs = generate_greedy_packs(examples, self.max_length)
                 for batch_records in generate_packs:
                     if len(batch_records) > 0:
-                        padded_list = self._pad_batch_records(batch_records)
+                        padded_list = self._pad_batch_records(batch_records, self.max_length)
                         total_data.append(padded_list)
 
         return total_data
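Padding every pack to exactly max_length is what makes these packs safe for sequence parallelism: the token dimension is sharded across parallel ranks, so the sequence length must be constant across batches and divisible by the parallel degree. A rough sketch of the constraint (sp_degree and the plain numpy split are illustrative assumptions, not the PaddleNLP implementation):

import numpy as np

max_length, sp_degree = 8, 4             # hypothetical sizes
input_ids = np.arange(max_length)        # one fully padded pack
shards = np.split(input_ids, sp_degree)  # np.split raises if the split is uneven

assert all(len(s) == max_length // sp_degree for s in shards)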
@@ -190,7 +225,7 @@ def __iter__(self):
                     cur_len_so_far += len(record["input_ids"])
                 else:
                     # exceed max length
-                    padded_list = self._pad_batch_records(batch_records)
+                    padded_list = self._pad_batch_records(batch_records, self.max_length)
                     yield padded_list
                     # reset
                     batch_records = []
@@ -200,7 +235,7 @@ def __iter__(self):
                     self.zero_padding_global_step += 1
                     cur_len_so_far += len(record["input_ids"])
             if batch_records:
-                padded_list = self._pad_batch_records(batch_records)
+                padded_list = self._pad_batch_records(batch_records, self.max_length)
                 yield padded_list
         else:
             examples = []
@@ -218,7 +253,7 @@ def __iter__(self):
                     generate_packs = generate_greedy_packs(examples, self.max_length)
                     for batch_records in generate_packs:
                         if len(batch_records) > 0:
-                            padded_list = self._pad_batch_records(batch_records)
+                            padded_list = self._pad_batch_records(batch_records, self.max_length)
                             yield padded_list
                     examples = [record]
                     self.zero_padding_global_step += 1
@@ -227,5 +262,5 @@ def __iter__(self):
                generate_packs = generate_greedy_packs(examples, self.max_length)
                for batch_records in generate_packs:
                    if len(batch_records) > 0:
-                        padded_list = self._pad_batch_records(batch_records)
+                        padded_list = self._pad_batch_records(batch_records, self.max_length)
                         yield padded_list
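Why does a row index of 0 fully mask the padding? The dense-mask reconstruction below is a hedged sketch of the attn_mask_startend_row_indices semantics as this dataset file appears to use them (my reading, not PaddleNLP's attention kernel): the value stored at column j is the row at which masking starts, so query rows at or beyond it cannot attend to token j, and a value of 0 masks the column for every row.

import numpy as np

def to_dense_mask(startend_row_indices):
    # True = query row i may attend to key column j
    n = len(startend_row_indices)
    rows = np.arange(n)[:, None]                                   # query positions i
    causal = rows >= np.arange(n)[None, :]                         # j <= i
    in_segment = rows < np.asarray(startend_row_indices)[None, :]  # i below the mask-start row
    return causal & in_segment

# two real segments (rows 0-2 and 3-4) followed by three padding tokens
indices = [3, 3, 3, 5, 5, 0, 0, 0]
mask = to_dense_mask(indices)
assert not mask[:, 5:].any()  # no query row ever attends to the padding columns
assert not mask[3, :3].any()  # segment boundaries also block cross-segment attention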