Expand S2S range and make pad trimming controllable #235

Merged
merged 3 commits on Mar 17, 2023
62 changes: 43 additions & 19 deletions examples/llm/src/data/denoising.py
@@ -42,7 +42,7 @@ def ul2_prefix_function(
"""
if mean_length is None:
# This is the case for "sequence to sequence"
prefix = '[S2S]'
prefix = '[S2S]' if mask_ratio < 1.0 else '[CLM]'
elif mean_length >= 12 or mask_ratio >= 0.3:
# UL2 tags this corruption rate "extreme"
prefix = '[NLG]'
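
For context, the full tag-selection logic after this change can be sketched roughly as follows. This is a minimal sketch: only the '[S2S]'/'[CLM]' branch is touched by this PR, and the final '[NLU]' fallback is assumed from the UL2 paper rather than visible in this hunk.

from typing import Optional

def pick_ul2_prefix(mask_ratio: float, mean_length: Optional[float]) -> str:
    # Sequential ("S2S") denoising has no mean_length; a mask ratio of 1.0
    # now degenerates to plain causal language modeling, hence '[CLM]'.
    if mean_length is None:
        return '[S2S]' if mask_ratio < 1.0 else '[CLM]'
    # UL2 tags this corruption rate "extreme".
    if mean_length >= 12 or mask_ratio >= 0.3:
        return '[NLG]'
    return '[NLU]'  # assumed fallback for regular span corruption

assert pick_ul2_prefix(mask_ratio=1.0, mean_length=None) == '[CLM]'
assert pick_ul2_prefix(mask_ratio=0.25, mean_length=None) == '[S2S]'
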
@@ -101,13 +101,20 @@ class MixtureOfDenoisersCollator:
Default: ``None`` does not add any span corruption tasks.
sequence_mask_ratios (optional): A float or list of floats, one for each
sequence corruption denoising task to add to the task mixture. Each
sequence mask ratio must be greater than 0.0 and less than 0.5.
sequence mask ratio must be greater than 0.0 and at most 1.0.
This type of task is a special instance of span corruption, with
exactly one masked span taken from the end of the sequence. The
length of the span is sampled uniformly from
[1, 2*mask_ratio*n_tokens], where n_tokens is the length of the
unmasked token sequence.
length of the span is sampled uniformly such that the average
portion of masked tokens equals sequence_mask_ratio.
Note: A value of 1.0 essentially yields causal LM examples.
Default: ``None`` does not add any sequence corruption tasks.
allow_pad_trimming (bool, optional): Whether to allow the collator to
trim away sequence regions that are entirely padding (i.e. padding
for each example in the batch). If ``True``, shorter sequences may
improve throughput but at a potentially higher memory cost
owing to variable sequence lengths from batch to batch.
Default: ``False`` yields batches that are always padded to
max_seq_length.
prefix_function (callable, optional): A function that maps denoising
task parameters (e.g. mean_length=3, mask_ratio=0.15) to a prefix
that will be added to sequences when the associated "noiser" is
@@ -124,13 +131,18 @@ def __init__(
decoder_only_format: bool = False,
span_mean_lengths_and_ratios: Optional[List] = None,
sequence_mask_ratios: Optional[Union[List[float], float]] = None,
allow_pad_trimming: bool = False,
prefix_function: Optional[PREFIX_FUNCTION] = ul2_prefix_function,
):
self.tokenizer = tokenizer
self.max_seq_length = max_seq_length
self.decoder_only_format = decoder_only_format
self._sentinel_token_ids = np.array(self.tokenizer.sentinel_token_ids)

# Trimming will always be skipped on at least the first __call__
self._allow_pad_trimming = allow_pad_trimming
self._seen_first_batch = False

# Prepare the tokenizer for denoising tasks
utils.adapt_tokenizer_for_denoising(self.tokenizer)
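
As a usage sketch (hypothetical values; the tokenizer is assumed to already expose the sentinel-token attributes that `utils.adapt_tokenizer_for_denoising` works with), the new argument is simply passed at construction time:

collator = MixtureOfDenoisersCollator(
    tokenizer=tokenizer,               # assumed: a denoising-adapted tokenizer
    max_seq_length=2048,
    decoder_only_format=True,
    span_mean_lengths_and_ratios=[[3, 0.15], [8, 0.5]],  # illustrative values
    sequence_mask_ratios=[0.25, 1.0],  # 1.0 now allowed: ~causal LM examples
    allow_pad_trimming=True,           # new flag introduced in this PR
)
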

@@ -165,9 +177,9 @@ def __init__(
else:
# In this case, there is one or more sequence corruption tasks
for ratio in sequence_mask_ratios:
if not (0 < ratio < 0.5):
if not (0 < ratio <= 1.0):
raise ValueError('`sequence_mask_ratios` must be a float (or list '+\
'of floats) between 0.0 and 0.5, or None. '+\
'of floats) that are each >0.0 and <=1.0, or None. '+\
f'Got {sequence_mask_ratios}.')
self.sequence_mask_ratios = sequence_mask_ratios
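
The widened range can be illustrated with a small standalone check that mirrors the validation above. This is a sketch; the normalization of a scalar ratio into a list is assumed and not shown in this hunk.

from typing import List, Optional, Union

def validate_sequence_mask_ratios(
        ratios: Optional[Union[List[float], float]]) -> List[float]:
    if ratios is None:
        return []
    if isinstance(ratios, float):
        ratios = [ratios]  # assumed scalar-to-list normalization
    for ratio in ratios:
        if not (0 < ratio <= 1.0):
            raise ValueError(
                '`sequence_mask_ratios` must be a float (or list of floats) '
                f'that are each >0.0 and <=1.0, or None. Got {ratios}.')
    return ratios

validate_sequence_mask_ratios([0.25, 1.0])  # ratios above 0.5 are now accepted
# validate_sequence_mask_ratios([1.5])      # would raise ValueError
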

@@ -274,6 +286,12 @@ def __call__(self, examples: List[Dict[str,
decoder_only_format=self.decoder_only_format))
batch = self.tokenizer.pad(processed_examples)

# This logic prevents trimming on at least the first batch
if not (self._allow_pad_trimming and self._seen_first_batch):
self._seen_first_batch = True
return batch
self._seen_first_batch = True

# Truncate portions of the inputs that are purely padding
# (up to a multiple of 8)
multiple_of = 8
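
The trimming itself sits below the visible portion of this hunk. Roughly, it can be sketched as dropping trailing all-padding columns while keeping the sequence length a multiple of 8 (a hedged sketch assuming right-side padding and the usual `input_ids`/`attention_mask` key names; the actual implementation is not shown here):

import torch

def trim_batch_padding(batch: dict, multiple_of: int = 8) -> dict:
    # Longest non-padding length in the batch (assumes right-side padding).
    keep = int(batch['attention_mask'].sum(dim=1).max().item())
    # Round up to a multiple of 8 so downstream kernels stay efficient.
    keep = ((keep + multiple_of - 1) // multiple_of) * multiple_of
    keep = min(keep, batch['input_ids'].shape[1])
    batch['input_ids'] = batch['input_ids'][:, :keep]
    batch['attention_mask'] = batch['attention_mask'][:, :keep]
    return batch

batch = {
    'input_ids': torch.tensor([[5, 6, 7, 0, 0, 0, 0, 0, 0, 0]]),
    'attention_mask': torch.tensor([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0]]),
}
assert trim_batch_padding(batch)['input_ids'].shape[1] == 8  # 3 tokens -> 8
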
@@ -349,15 +367,18 @@ def build_text_denoising_dataloader(cfg: DictConfig,
cfg.mixture_of_denoisers.decoder_only_format (bool): Whether the
batches should use the format required for training a decoder-only
model (if ``True``) or an encoder-decoder model (if ``False``).
cfg.mixture_of_denoisers.span_mean_lengths_and_ratios (optiona): The
cfg.mixture_of_denoisers.span_mean_lengths_and_ratios (optional): The
parameters for any span corruption denoising tasks to include in
the task mixture.
See :class:`MixtureOfDenoisersCollator` docstring for details.
cfg.mixture_of_denoisers.sequence_mask_ratios (optiona): The
cfg.mixture_of_denoisers.sequence_mask_ratios (optional): The
parameters for any sequence denoising tasks to include in the
task mixture.
See :class:`MixtureOfDenoisersCollator` docstring for details.
cfg.mixture_of_denoisers.prefix_function (optiona): Set to ``None``
cfg.mixture_of_denoisers.allow_pad_trimming (optional): Whether to
allow the collator to trim padding when possible (if ``True``).
Defaults to ``False``.
cfg.mixture_of_denoisers.prefix_function (optional): Set to ``None``
to disable the UL2-style prefixes that will be automatically
added by default.
---
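
For orientation, the mixture_of_denoisers portion of such a config might look as follows. This is an illustrative sketch only: other required dataloader keys (dataset paths, batch size, etc.) are omitted, and the exact value formats should be checked against the collator docstring.

from omegaconf import OmegaConf

cfg_fragment = OmegaConf.create({
    'mixture_of_denoisers': {
        'decoder_only_format': True,
        # Illustrative span-corruption settings (mean_length, mask_ratio pairs)
        'span_mean_lengths_and_ratios': [[3, 0.15], [8, 0.5]],
        # A ratio of 1.0 is now accepted and behaves like causal LM data
        'sequence_mask_ratios': [0.25, 1.0],
        # New in this PR; defaults to False when omitted
        'allow_pad_trimming': True,
        # 'prefix_function': None,  # set to None to disable UL2-style prefixes
    },
})
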
@@ -377,6 +398,8 @@ def build_text_denoising_dataloader(cfg: DictConfig,
'span_mean_lengths_and_ratios'),
sequence_mask_ratios=cfg.mixture_of_denoisers.get(
'sequence_mask_ratios'),
allow_pad_trimming=cfg.mixture_of_denoisers.get('allow_pad_trimming',
False),
prefix_function=cfg.mixture_of_denoisers.get('prefix_function',
ul2_prefix_function))

@@ -463,20 +486,21 @@ def noise_token_sequence(

prefix_tokens = prefix_tokens or []

if length < 1:
raise ValueError('Example cannot be empty but token length <1.')

# mean_span_length==None is a special case for "sequential" denoising
# (where a single span at the end of the sequence is masked)
if mean_span_length is None:
# This ensures that exactly 1 span will be produced and that
# trimming to max_seq_length won't cut off any <EOS> token.
# In the decoder-only case, this won't insert new tokens.
min_span_length = np.maximum(
1, length + len(prefix_tokens) - max_seq_length)
max_span_length = np.maximum(
min_span_length, np.minimum(length - 1, 2 * mask_ratio * length))
mean_span_length = float(
np.floor(
np.random.uniform(low=min_span_length, high=max_span_length)))
mask_ratio = mean_span_length / length
if mask_ratio <= 0.5:
u = np.random.uniform(low=0.0, high=mask_ratio * 2)
else:
u = np.random.uniform(low=(mask_ratio * 2) - 1, high=1.0)
mean_span_length = float(np.round(1 + u * (length - 1)))
mask_ratio = mean_span_length / length # type: ignore
use_sentinels = False
else:
use_sentinels = True
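
The new sampling is the substantive change here: instead of drawing the span length uniformly from roughly [1, 2*mask_ratio*length], which effectively capped usable mask ratios below 0.5, it draws a factor u whose mean is mask_ratio over the full (0, 1] range. A standalone sketch of just this branch (prefix handling and the surrounding masking code are omitted):

import numpy as np

def sample_sequential_span_length(length: int, mask_ratio: float) -> float:
    """Sketch of the sequential-denoising sampling shown above."""
    if length < 1:
        raise ValueError('Example cannot be empty but token length <1.')
    # E[u] == mask_ratio in both branches, so the expected span length is
    # roughly mask_ratio * length for any mask_ratio in (0, 1].
    if mask_ratio <= 0.5:
        u = np.random.uniform(low=0.0, high=mask_ratio * 2)
    else:
        u = np.random.uniform(low=(mask_ratio * 2) - 1, high=1.0)
    return float(np.round(1 + u * (length - 1)))

# With mask_ratio=1.0 the draw collapses to u=1, so the span covers the whole
# sequence and the example is effectively causal LM data.
assert sample_sequential_span_length(100, 1.0) == 100.0
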
@@ -714,7 +738,7 @@ def _format_tokens_for_decoder_only(
n_input = len(tokens_inputs)
n_label = len(tokens_labels)
n_concat = n_input + n_label
assert n_concat <= max_seq_length
assert n_concat <= max_seq_length, f'{n_concat=}, {n_input=}, {n_label=}'

tokens_concat = torch.concat([tokens_inputs, tokens_labels], dim=0)
