[WIP] DataCollatorForTextInfilling #12370

Closed
103 changes: 103 additions & 0 deletions src/transformers/data/data_collator.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import random
import warnings
from dataclasses import dataclass
@@ -511,6 +512,108 @@ def mask_tokens(self, inputs: torch.Tensor, mask_labels: torch.Tensor) -> Tuple[
return inputs, labels


@dataclass
class DataCollatorForTextInfilling:
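    """
    Data collator for BART-style text infilling. Span lengths are drawn from a
    Poisson(`poisson_lambda`) distribution, roughly `mlm_probability` of the
    non-special tokens are covered by spans, and each masked span is replaced by
    a single `tokenizer.mask_token`.
    """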
tokenizer: PreTrainedTokenizerBase
mlm_probability: float = 0.15
poisson_lambda: float = 3.0
pad_to_multiple_of: Optional[int] = None

    def __post_init__(self):
        if self.tokenizer.mask_token is None:
            raise ValueError("This collator requires a tokenizer with a mask token, but `tokenizer.mask_token` is None.")

def __call__(self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
) -> Dict[str, torch.Tensor]:
# Handle dict or lists with proper padding and conversion to tensor.
if isinstance(examples[0], (dict, BatchEncoding)):
batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
else:
batch = {"input_ids": _collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)}

# If special token mask has been preprocessed, pop it from the dict.
special_tokens_mask = batch.pop("special_tokens_mask", None)

batch["input_ids"], batch["labels"] = self.mask_tokens(
batch["input_ids"], special_tokens_mask=special_tokens_mask
)

return batch

def mask_tokens(self,
inputs: torch.Tensor,
special_tokens_mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
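        """
        Corrupt `inputs` for text infilling: pick span starts, grow each span to its
        sampled length, then collapse every span to a single mask token. Returns the
        corrupted `input_ids` (right-padded back to the original width) and `labels`
        in which every non-masked position is set to -100.
        """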
labels = inputs.clone()

if special_tokens_mask is None:
special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
else:
special_tokens_mask = special_tokens_mask.bool()

# determine how many tokens we need to mask in total
is_token = ~(inputs == self.tokenizer.pad_token_id) & ~special_tokens_mask
num_to_mask = int(math.ceil(is_token.float().sum() * self.mlm_probability))

if num_to_mask == 0:
return inputs, labels

# generate a sufficient number of span lengths
poisson_distribution = torch.distributions.Poisson(rate=self.poisson_lambda)
lengths = poisson_distribution.sample(sample_shape=(num_to_mask,))
while torch.cumsum(lengths, 0)[-1] < num_to_mask:
lengths = torch.cat([lengths, poisson_distribution.sample(sample_shape=(num_to_mask,))])

# remove all spans of length 0
# Note that BART inserts additional mask tokens where length == 0,
# which we do not implement for now as it adds additional complexity
lengths = lengths[lengths > 0]

        # trim so that the total span length is approximately num_to_mask tokens
        idx = torch.argmin(torch.abs(torch.cumsum(lengths, 0) - num_to_mask)) + 1
        lengths = lengths[:idx]

# select span start indices
token_indices = is_token.nonzero(as_tuple=False)
span_starts = torch.randperm(token_indices.shape[0])[:lengths.shape[0]]

        # prepare a boolean mask with the same shape as the inputs
        masked_indices = token_indices[span_starts]
        mask = torch.zeros_like(inputs, dtype=torch.bool)

# mask span start indices
for mi in masked_indices:
mask[tuple(mi)] = True
lengths -= 1

        # fill up spans: extend each remaining span one token to the right per
        # iteration until its sampled length is used up or the sequence end is reached
        max_index = inputs.shape[1] - 1
        remaining = (lengths > 0) & (masked_indices[:, 1] < max_index)
while torch.any(remaining):
masked_indices[remaining, 1] += 1
for mi in masked_indices:
mask[tuple(mi)] = True
lengths -= 1
remaining = (lengths > 0) & (masked_indices[:, 1] < max_index)

# place the mask tokens
mask[special_tokens_mask] = False
inputs[mask.bool()] = self.tokenizer.mask_token_id
labels[~mask.bool()] = -100

        # BART collapses each masked span to a single mask token, so remove every mask
        # token that immediately follows another one (i.e. is not the start of a span)
        to_remove = mask.bool() & mask.bool().roll(1, 1)
new_inputs = torch.full_like(inputs, fill_value=self.tokenizer.pad_token_id)
for i, example in enumerate(torch.split(inputs, split_size_or_sections=1, dim=0)):
new_example = example[0][~to_remove[i]]
new_inputs[i, 0:new_example.shape[0]] = new_example

return new_inputs, labels


@dataclass
class DataCollatorForSOP(DataCollatorForLanguageModeling):
"""
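A minimal sketch of how the new collator could be exercised with a BART tokenizer. The checkpoint name and example sentences are placeholders, not part of this PR, and the import assumes the class is available from the module added in this diff:

from transformers import AutoTokenizer
from transformers.data.data_collator import DataCollatorForTextInfilling

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
collator = DataCollatorForTextInfilling(tokenizer=tokenizer, mlm_probability=0.15, poisson_lambda=3.0)

# Tokenize a couple of sentences; the collator pads them and applies the infilling noise.
examples = [
    tokenizer("The quick brown fox jumps over the lazy dog."),
    tokenizer("Text infilling replaces whole spans with a single mask token."),
]
batch = collator(examples)

print(batch["input_ids"].shape, batch["labels"].shape)
print(tokenizer.batch_decode(batch["input_ids"]))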