From a41d87b691162446114660743891ebe4e5def43d Mon Sep 17 00:00:00 2001 From: drownfish19 Date: Tue, 20 Aug 2024 14:23:14 +0000 Subject: [PATCH 1/3] fix zero_padding for sequence parallel --- paddlenlp/datasets/zero_padding_dataset.py | 50 ++++++++++++++++++---- 1 file changed, 41 insertions(+), 9 deletions(-) diff --git a/paddlenlp/datasets/zero_padding_dataset.py b/paddlenlp/datasets/zero_padding_dataset.py index 2a071fc11404..b8d77b21e163 100644 --- a/paddlenlp/datasets/zero_padding_dataset.py +++ b/paddlenlp/datasets/zero_padding_dataset.py @@ -53,7 +53,39 @@ class ZeroPadding: ] @classmethod - def _pad_batch_records(cls, batch_records): + def _pad_batch_records_to_max_length(cls, batch_records, max_length): + for records in batch_records: + # confirm the at least one item in the pack + if len(records) == 0: + continue + # count all records total length + total_length = sum([len(record["input_ids"]) for record in records]) + reserved_length = max_length - total_length + + # append padding to the max_length + if "attn_mask_startend_row_indices" in records[0]: + records.append( + { + "input_ids": [-100] * reserved_length, + "labels": [-100] * reserved_length, + "attn_mask_startend_row_indices": [0] * reserved_length, + } + ) + elif "attention_mask" in records[0]: + records.append( + { + "input_ids": [-100] * reserved_length, + "labels": [-100] * reserved_length, + "attention_mask": [True] * reserved_length, + } + ) + + return batch_records + + @classmethod + def _pad_batch_records(cls, batch_records, max_length): + batch_records = cls._pad_batch_records_to_max_length(batch_records, max_length) + # Only consider supported input keys input_keys = [key for key in batch_records[0].keys() if key in cls.supported_input_keys] if "attn_mask_startend_row_indices" not in input_keys and "attention_mask" not in input_keys: @@ -122,7 +154,7 @@ def _create_zero_padding_data(self, data): cur_len_so_far += len(record["input_ids"]) else: # exceed max length - padded_list = 
self._pad_batch_records(batch_records) + padded_list = self._pad_batch_records(batch_records, self.max_length) total_data.append(padded_list) # reset batch_records = [] @@ -133,7 +165,7 @@ def _create_zero_padding_data(self, data): # remaining data if batch_records: - padded_list = self._pad_batch_records(batch_records) + padded_list = self._pad_batch_records(batch_records, self.max_length) total_data.append(padded_list) else: examples = [] @@ -150,7 +182,7 @@ def _create_zero_padding_data(self, data): generate_packs = generate_greedy_packs(examples, self.max_length) for batch_records in generate_packs: if len(batch_records) > 0: - padded_list = self._pad_batch_records(batch_records) + padded_list = self._pad_batch_records(batch_records, self.max_length) total_data.append(padded_list) examples = [record] i = 1 @@ -158,7 +190,7 @@ def _create_zero_padding_data(self, data): generate_packs = generate_greedy_packs(examples, self.max_length) for batch_records in generate_packs: if len(batch_records) > 0: - padded_list = self._pad_batch_records(batch_records) + padded_list = self._pad_batch_records(batch_records, self.max_length) total_data.append(padded_list) return total_data @@ -190,7 +222,7 @@ def __iter__(self): cur_len_so_far += len(record["input_ids"]) else: # exceed max length - padded_list = self._pad_batch_records(batch_records) + padded_list = self._pad_batch_records(batch_records, self.max_length) yield padded_list # reset batch_records = [] @@ -200,7 +232,7 @@ def __iter__(self): self.zero_padding_global_step += 1 cur_len_so_far += len(record["input_ids"]) if batch_records: - padded_list = self._pad_batch_records(batch_records) + padded_list = self._pad_batch_records(batch_records, self.max_length) yield padded_list else: examples = [] @@ -218,7 +250,7 @@ def __iter__(self): generate_packs = generate_greedy_packs(examples, self.max_length) for batch_records in generate_packs: if len(batch_records) > 0: - padded_list = self._pad_batch_records(batch_records) + 
padded_list = self._pad_batch_records(batch_records, self.max_length) yield padded_list examples = [record] self.zero_padding_global_step += 1 @@ -227,5 +259,5 @@ def __iter__(self): generate_packs = generate_greedy_packs(examples, self.max_length) for batch_records in generate_packs: if len(batch_records) > 0: - padded_list = self._pad_batch_records(batch_records) + padded_list = self._pad_batch_records(batch_records, self.max_length) yield padded_list From 8fe6cb5775aaaf5d2720b401cafe7e13afed20d8 Mon Sep 17 00:00:00 2001 From: drownfish19 Date: Wed, 21 Aug 2024 03:29:10 +0000 Subject: [PATCH 2/3] fix --- paddlenlp/datasets/zero_padding_dataset.py | 49 +++++++++++----------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/paddlenlp/datasets/zero_padding_dataset.py b/paddlenlp/datasets/zero_padding_dataset.py index b8d77b21e163..4a8d51eb8ac6 100644 --- a/paddlenlp/datasets/zero_padding_dataset.py +++ b/paddlenlp/datasets/zero_padding_dataset.py @@ -53,32 +53,31 @@ class ZeroPadding: ] @classmethod - def _pad_batch_records_to_max_length(cls, batch_records, max_length): - for records in batch_records: - # confirm the at least one item in the pack - if len(records) == 0: - continue - # count all records total length - total_length = sum([len(record["input_ids"]) for record in records]) - reserved_length = max_length - total_length + def _pad_batch_records_to_max_length(cls, batch_records, max_length, pad_token=0): + # confirm the at least one item in the pack + if len(batch_records) == 0: + return batch_records + # count all records total length + total_length = sum([len(record["input_ids"]) for record in batch_records]) + reserved_length = max_length - total_length - # append padding to the max_length - if "attn_mask_startend_row_indices" in records[0]: - records.append( - { - "input_ids": [-100] * reserved_length, - "labels": [-100] * reserved_length, - "attn_mask_startend_row_indices": [0] * reserved_length, - } - ) - elif "attention_mask" in 
records[0]: - records.append( - { - "input_ids": [-100] * reserved_length, - "labels": [-100] * reserved_length, - "attention_mask": [True] * reserved_length, - } - ) + # append padding to the max_length + if "attn_mask_startend_row_indices" in batch_records[0]: + batch_records.append( + { + "input_ids": [pad_token] * reserved_length, + "labels": [-100] * reserved_length, + "attn_mask_startend_row_indices": [0] * reserved_length, + } + ) + elif "attention_mask" in batch_records[0]: + batch_records.append( + { + "input_ids": [pad_token] * reserved_length, + "labels": [-100] * reserved_length, + "attention_mask": np.zeros((reserved_length, reserved_length), dtype=bool), + } + ) return batch_records From a4b0264cd116d7c8d0b8012e1e784390a1bdf2d3 Mon Sep 17 00:00:00 2001 From: drownfish19 Date: Wed, 21 Aug 2024 04:19:34 +0000 Subject: [PATCH 3/3] add comments --- paddlenlp/datasets/zero_padding_dataset.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/paddlenlp/datasets/zero_padding_dataset.py b/paddlenlp/datasets/zero_padding_dataset.py index 4a8d51eb8ac6..870394aaca33 100644 --- a/paddlenlp/datasets/zero_padding_dataset.py +++ b/paddlenlp/datasets/zero_padding_dataset.py @@ -63,6 +63,8 @@ def _pad_batch_records_to_max_length(cls, batch_records, max_length, pad_token=0 # append padding to the max_length if "attn_mask_startend_row_indices" in batch_records[0]: + # attn_mask_startend_row_indices is a list of row indices `0`, + # which indicates that all tokens are masked. batch_records.append( { "input_ids": [pad_token] * reserved_length, @@ -71,6 +73,8 @@ def _pad_batch_records_to_max_length(cls, batch_records, max_length, pad_token=0 } ) elif "attention_mask" in batch_records[0]: + # attention_mask is a fully masked attention matrix (all False) + # which indicates that all tokens are masked. batch_records.append( { "input_ids": [pad_token] * reserved_length,