add NoRepeatNGramLogitsProcessor #3977

Merged
8 commits merged on Dec 8, 2022
71 changes: 69 additions & 2 deletions paddlenlp/transformers/generation_utils.py
@@ -13,15 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
import inspect
from abc import ABC
from typing import List

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.common_ops_import import convert_dtype
from paddle.fluid.layers.utils import map_structure

from paddlenlp.utils.log import logger

__all__ = ["GenerationMixin"]
@@ -306,6 +307,7 @@ def get_logits_processor(
num_beam_groups=1,
diversity_rate=0.0,
repetition_penalty=None,
no_repeat_ngram_size=None,
logits_processors=None,
):
processors = LogitsProcessorList()
@@ -320,6 +322,8 @@
)
if repetition_penalty is not None and repetition_penalty != 1.0:
processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
if forced_bos_token_id is not None:
processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
if forced_eos_token_id is not None:
@@ -503,7 +507,7 @@ def _build_faster(self, kwargs):
if kwargs["num_beam_groups"] != 1:
# not support for group_beam_search yet in the faster version
raise AttributeError("'num_beam_groups != 1' is not supported yet in the faster version")
if paddle.get_default_dtype() == "float16" and kwargs["use_fp16_decoding"] == False:
if paddle.get_default_dtype() == "float16" and kwargs["use_fp16_decoding"] is False:
logger.info(
"Since the default dtype is float16, float16 would be used " "though 'use_fp16_decoding=False'."
)
@@ -531,6 +535,7 @@ def generate(
decoder_start_token_id=None,
forced_bos_token_id=None,
forced_eos_token_id=None,
no_repeat_ngram_size=None,
num_return_sequences=1,
diversity_rate=0.0,
use_cache=True,
@@ -729,6 +734,9 @@ def generate(
if decoder_start_token_id is not None
else getattr(self, "decoder_start_token_id", None)
)
no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else getattr(self, "no_repeat_ngram_size", None)
)

if getattr(self, "_faster_entry", None) is not False and use_faster:
args = locals()
@@ -804,6 +812,7 @@ def generate(
num_beam_groups=num_beam_groups,
diversity_rate=diversity_rate,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
logits_processors=model_kwargs["logits_processors"]
if "logits_processors" in model_kwargs
and isinstance(model_kwargs["logits_processors"], LogitsProcessorList)
@@ -1337,6 +1346,64 @@ def __call__(self, input_ids, logits):
return outputs


def _get_ngrams(ngram_size, prev_input_ids, num_hypos):
generated_ngrams = [{} for _ in range(num_hypos)]
for idx in range(num_hypos):
gen_tokens = prev_input_ids[idx].tolist()
generated_ngram = generated_ngrams[idx]
for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
prev_ngram_tuple = tuple(ngram[:-1])
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
return generated_ngrams


def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
# Before decoding the next token, prevent decoding of ngrams that have already appeared
start_idx = cur_len + 1 - ngram_size
ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
return banned_ngrams.get(ngram_idx, [])


def _calc_banned_ngram_tokens(ngram_size, prev_input_ids, num_hypos, cur_len):
"""Copied from fairseq for no_repeat_ngram in beam_search"""
if cur_len + 1 < ngram_size:
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
return [[] for _ in range(num_hypos)]

generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)

banned_tokens = [
_get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
for hypo_idx in range(num_hypos)
]
return banned_tokens


class NoRepeatNGramLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that enforces no repetition of n-grams. See
[Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).
Args:
ngram_size (`int`):
All ngrams of size `ngram_size` can only occur once.
"""

def __init__(self, ngram_size):
if not isinstance(ngram_size, int) or ngram_size <= 0:
raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
self.ngram_size = ngram_size

def __call__(self, input_ids, scores):
num_batch_hypotheses = scores.shape[0]
cur_len = input_ids.shape[-1]
banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len)

for i, banned_tokens in enumerate(banned_batch_tokens):
scores[i, banned_tokens] = -float("inf")

return scores


class HammingDiversityLogitsProcessor(LogitsProcessor):
"""
This `LogitsProcessor` enforces diverse beam search. Note that this logits
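
For readers skimming the diff, here is a minimal sketch of what the new processor does. The helpers above implement the fairseq trick: `_get_ngrams` maps every (ngram_size - 1)-token prefix to the tokens that followed it, and `_get_generated_ngrams` looks up the current prefix to get the banned continuations. The snippet below is illustrative only and not part of the PR; the import path assumes the module location of the file changed above, and the token ids and vocabulary size are toy values.

import paddle

from paddlenlp.transformers.generation_utils import NoRepeatNGramLogitsProcessor

# One hypothesis whose history already contains the bigrams (5, 3), (3, 8) and (8, 5).
input_ids = paddle.to_tensor([[5, 3, 8, 5, 3]])
# Uniform logits over a toy vocabulary of 10 tokens.
scores = paddle.zeros([1, 10])

processor = NoRepeatNGramLogitsProcessor(ngram_size=2)
scores = processor(input_ids, scores)

# The last generated token is 3, and 8 has already followed a 3 once, so token 8
# is masked to -inf: emitting it would repeat the bigram (3, 8). All other tokens
# keep their original score.
print(scores.numpy()[0])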
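And a hedged end-to-end sketch of the new `generate()` argument. Only the `no_repeat_ngram_size` keyword comes from this PR; the model and tokenizer names (GPTLMHeadModel, GPTTokenizer, "gpt2-en") and the decoding settings are assumptions chosen for illustration.

import paddle

from paddlenlp.transformers import GPTLMHeadModel, GPTTokenizer

# Assumed model/tokenizer; any PaddleNLP model that mixes in GenerationMixin
# exposes the same keyword through generate().
tokenizer = GPTTokenizer.from_pretrained("gpt2-en")
model = GPTLMHeadModel.from_pretrained("gpt2-en")
model.eval()

input_ids = paddle.to_tensor([tokenizer("The quick brown fox")["input_ids"]])

# With no_repeat_ngram_size=2, no bigram can occur twice in the generated ids.
ids, scores = model.generate(
    input_ids=input_ids,
    max_length=64,
    decode_strategy="greedy_search",
    no_repeat_ngram_size=2,
)
print(tokenizer.convert_ids_to_string(ids[0].tolist()))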