diff --git a/paddlenlp/datasets/zero_padding_dataset.py b/paddlenlp/datasets/zero_padding_dataset.py
index 2a071fc11404..870394aaca33 100644
--- a/paddlenlp/datasets/zero_padding_dataset.py
+++ b/paddlenlp/datasets/zero_padding_dataset.py
@@ -53,7 +53,42 @@ class ZeroPadding:
     ]
 
     @classmethod
-    def _pad_batch_records(cls, batch_records):
+    def _pad_batch_records_to_max_length(cls, batch_records, max_length, pad_token=0):
+        # confirm that there is at least one item in the pack
+        if len(batch_records) == 0:
+            return batch_records
+        # count the total length of all records
+        total_length = sum([len(record["input_ids"]) for record in batch_records])
+        reserved_length = max_length - total_length
+
+        # append padding to fill the pack up to max_length
+        if "attn_mask_startend_row_indices" in batch_records[0]:
+            # attn_mask_startend_row_indices is a list of row indices set to `0`,
+            # which indicates that all padding tokens are masked.
+            batch_records.append(
+                {
+                    "input_ids": [pad_token] * reserved_length,
+                    "labels": [-100] * reserved_length,
+                    "attn_mask_startend_row_indices": [0] * reserved_length,
+                }
+            )
+        elif "attention_mask" in batch_records[0]:
+            # attention_mask is a fully masked attention matrix (all False),
+            # which indicates that all padding tokens are masked.
+            batch_records.append(
+                {
+                    "input_ids": [pad_token] * reserved_length,
+                    "labels": [-100] * reserved_length,
+                    "attention_mask": np.zeros((reserved_length, reserved_length), dtype=bool),
+                }
+            )
+
+        return batch_records
+
+    @classmethod
+    def _pad_batch_records(cls, batch_records, max_length):
+        batch_records = cls._pad_batch_records_to_max_length(batch_records, max_length)
+        # Only consider supported input keys
         input_keys = [key for key in batch_records[0].keys() if key in cls.supported_input_keys]
         if "attn_mask_startend_row_indices" not in input_keys and "attention_mask" not in input_keys:
@@ -122,7 +157,7 @@ def _create_zero_padding_data(self, data):
                     cur_len_so_far += len(record["input_ids"])
                 else:
                     # exceed max length
-                    padded_list = self._pad_batch_records(batch_records)
+                    padded_list = self._pad_batch_records(batch_records, self.max_length)
                     total_data.append(padded_list)
                     # reset
                     batch_records = []
@@ -133,7 +168,7 @@ def _create_zero_padding_data(self, data):
 
             # remaining data
             if batch_records:
-                padded_list = self._pad_batch_records(batch_records)
+                padded_list = self._pad_batch_records(batch_records, self.max_length)
                 total_data.append(padded_list)
         else:
             examples = []
@@ -150,7 +185,7 @@ def _create_zero_padding_data(self, data):
                     generate_packs = generate_greedy_packs(examples, self.max_length)
                     for batch_records in generate_packs:
                         if len(batch_records) > 0:
-                            padded_list = self._pad_batch_records(batch_records)
+                            padded_list = self._pad_batch_records(batch_records, self.max_length)
                             total_data.append(padded_list)
                     examples = [record]
                     i = 1
@@ -158,7 +193,7 @@ def _create_zero_padding_data(self, data):
                 generate_packs = generate_greedy_packs(examples, self.max_length)
                 for batch_records in generate_packs:
                     if len(batch_records) > 0:
-                        padded_list = self._pad_batch_records(batch_records)
+                        padded_list = self._pad_batch_records(batch_records, self.max_length)
                         total_data.append(padded_list)
 
         return total_data
@@ -190,7 +225,7 @@ def __iter__(self):
                     cur_len_so_far += len(record["input_ids"])
                 else:
                     # exceed max length
-                    padded_list = self._pad_batch_records(batch_records)
+                    padded_list = self._pad_batch_records(batch_records, self.max_length)
                     yield padded_list
                     # reset
                     batch_records = []
@@ -200,7 +235,7 @@ def __iter__(self):
                     self.zero_padding_global_step += 1
                     cur_len_so_far += len(record["input_ids"])
             if batch_records:
-                padded_list = self._pad_batch_records(batch_records)
+                padded_list = self._pad_batch_records(batch_records, self.max_length)
                 yield padded_list
         else:
             examples = []
@@ -218,7 +253,7 @@ def __iter__(self):
                     generate_packs = generate_greedy_packs(examples, self.max_length)
                     for batch_records in generate_packs:
                         if len(batch_records) > 0:
-                            padded_list = self._pad_batch_records(batch_records)
+                            padded_list = self._pad_batch_records(batch_records, self.max_length)
                             yield padded_list
                     examples = [record]
                     self.zero_padding_global_step += 1
@@ -227,5 +262,5 @@ def __iter__(self):
                generate_packs = generate_greedy_packs(examples, self.max_length)
                for batch_records in generate_packs:
                    if len(batch_records) > 0:
-                        padded_list = self._pad_batch_records(batch_records)
+                        padded_list = self._pad_batch_records(batch_records, self.max_length)
                        yield padded_list
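
A minimal usage sketch of the new padding helper, assuming the module import path from the diff header and a PaddleNLP install that includes this change; the toy records, lengths, and index values below are illustrative placeholders, not values from the PR:

from paddlenlp.datasets.zero_padding_dataset import ZeroPadding

# Two short records occupying 7 of the 16 available positions.
# The attn_mask_startend_row_indices values here are placeholders; only the
# presence of the key matters for the branch being exercised.
batch_records = [
    {"input_ids": [11, 12, 13, 14], "labels": [-100, 12, 13, 14], "attn_mask_startend_row_indices": [4, 4, 4, 4]},
    {"input_ids": [21, 22, 23], "labels": [-100, 22, 23], "attn_mask_startend_row_indices": [3, 3, 3]},
]

padded = ZeroPadding._pad_batch_records_to_max_length(batch_records, max_length=16)

# One extra record is appended to fill the remaining 16 - 7 = 9 positions:
# input_ids are pad_token (0), labels are -100 (ignored by the loss), and
# attn_mask_startend_row_indices are 0 so the padding attends to nothing.
pad_record = padded[-1]
assert len(pad_record["input_ids"]) == 9
assert all(label == -100 for label in pad_record["labels"])
assert all(idx == 0 for idx in pad_record["attn_mask_startend_row_indices"])

Note that when the pack already fills max_length exactly, reserved_length is 0 and the appended pad record is empty (or, depending on the caller, no record is appended), so downstream collation is unaffected.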