From 5c4fe71e9854cb02ac03160ffd7e6d935cb9d82a Mon Sep 17 00:00:00 2001 From: caimingzhu Date: Thu, 1 Dec 2022 17:01:00 +0800 Subject: [PATCH 1/7] add NoRepeatNGramLogitsProcessor --- paddlenlp/transformers/generation_utils.py | 66 ++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/paddlenlp/transformers/generation_utils.py b/paddlenlp/transformers/generation_utils.py index 39c4277ff821..c950163d6b09 100644 --- a/paddlenlp/transformers/generation_utils.py +++ b/paddlenlp/transformers/generation_utils.py @@ -331,6 +331,7 @@ def get_logits_processor(self, num_beam_groups=1, diversity_rate=0.0, repetition_penalty=None, + no_repeat_ngram_size=None, logits_processors=None): processors = LogitsProcessorList() @@ -346,6 +347,9 @@ def get_logits_processor(self, if repetition_penalty is not None and repetition_penalty != 1.0: processors.append( RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)) + if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0: + processors.append( + NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)) if forced_bos_token_id is not None: processors.append( ForcedBOSTokenLogitsProcessor(forced_bos_token_id)) @@ -584,6 +588,7 @@ def generate(self, top_k=0, top_p=1.0, repetition_penalty=1.0, + no_repeat_ngram_size=None, num_beams=1, num_beam_groups=1, length_penalty=0.0, @@ -786,6 +791,8 @@ def generate(self, self, 'forced_eos_token_id', None) decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else getattr( self, 'decoder_start_token_id', None) + no_repeat_ngram_size = no_repeat_ngram_size if no_repeat_ngram_size is not None else getattr( + self, 'no_repeat_ngram_size', None) if getattr(self, '_faster_entry', None) is not False and use_faster: args = locals() @@ -869,6 +876,7 @@ def generate(self, num_beam_groups=num_beam_groups, diversity_rate=diversity_rate, repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, logits_processors=model_kwargs["logits_processors"] if "logits_processors" in model_kwargs and isinstance( model_kwargs["logits_processors"], LogitsProcessorList) else @@ -1426,6 +1434,64 @@ def __call__(self, input_ids, logits): return outputs +def _get_ngrams(ngram_size, prev_input_ids, num_hypos): + generated_ngrams = [{} for _ in range(num_hypos)] + for idx in range(num_hypos): + gen_tokens = prev_input_ids[idx].tolist() + generated_ngram = generated_ngrams[idx] + for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]): + prev_ngram_tuple = tuple(ngram[:-1]) + generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]] + return generated_ngrams + + +def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len): + # Before decoding the next token, prevent decoding of ngrams that have already appeared + start_idx = cur_len + 1 - ngram_size + ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist()) + return banned_ngrams.get(ngram_idx, []) + + +def _calc_banned_ngram_tokens(ngram_size, prev_input_ids, num_hypos, cur_len): + """Copied from fairseq for no_repeat_ngram in beam_search""" + if cur_len + 1 < ngram_size: + # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet + return [[] for _ in range(num_hypos)] + + generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos) + + banned_tokens = [ + _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len) + for hypo_idx in range(num_hypos) + ] + return banned_tokens + + +class 
NoRepeatNGramLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] that enforces no repetition of n-grams. See + [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345). + Args: + ngram_size (`int`): + All ngrams of size `ngram_size` can only occur once. + """ + + def __init__(self, ngram_size): + if not isinstance(ngram_size, int) or ngram_size <= 0: + raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}") + self.ngram_size = ngram_size + + def __call__(self, input_ids, scores): + num_batch_hypotheses = scores.shape[0] + cur_len = input_ids.shape[-1] + banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len) + + for i, banned_tokens in enumerate(banned_batch_tokens): + scores[i, banned_tokens] = -float("inf") + + return scores + + class HammingDiversityLogitsProcessor(LogitsProcessor): """ This `LogitsProcessor` enforces diverse beam search. Note that this logits From 1037f98c5103fb7025592abad570b526008fbda8 Mon Sep 17 00:00:00 2001 From: christineaa <49200582+christineaa@users.noreply.github.com> Date: Thu, 8 Dec 2022 17:07:36 +0800 Subject: [PATCH 2/7] Update generation_utils.py solve invalid syntax --- paddlenlp/transformers/generation_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/paddlenlp/transformers/generation_utils.py b/paddlenlp/transformers/generation_utils.py index 530fb904136d..af3f0b7d98f0 100644 --- a/paddlenlp/transformers/generation_utils.py +++ b/paddlenlp/transformers/generation_utils.py @@ -734,6 +734,7 @@ def generate( no_repeat_ngram_size if no_repeat_ngram_size is not None else getattr(self, 'no_repeat_ngram_size', None) + ) if getattr(self, "_faster_entry", None) is not False and use_faster: args = locals() From 7e0e5f1a020dbb32dbf639634e7f59dee78b801d Mon Sep 17 00:00:00 2001 From: christineaa <49200582+christineaa@users.noreply.github.com> Date: Thu, 8 Dec 2022 17:08:20 +0800 Subject: [PATCH 3/7] Update generation_utils.py --- paddlenlp/transformers/generation_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/paddlenlp/transformers/generation_utils.py b/paddlenlp/transformers/generation_utils.py index af3f0b7d98f0..7983ecd0d551 100644 --- a/paddlenlp/transformers/generation_utils.py +++ b/paddlenlp/transformers/generation_utils.py @@ -708,6 +708,7 @@ def generate( print(response) # ['是的', '嗯嗯'] """ + assert decode_strategy in [ "greedy_search", "sampling", From 5e9005d20737c9b4caef7f208253ac56d165c40a Mon Sep 17 00:00:00 2001 From: christineaa <49200582+christineaa@users.noreply.github.com> Date: Thu, 8 Dec 2022 17:08:53 +0800 Subject: [PATCH 4/7] Update generation_utils.py --- paddlenlp/transformers/generation_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlenlp/transformers/generation_utils.py b/paddlenlp/transformers/generation_utils.py index 7983ecd0d551..91610958cf6c 100644 --- a/paddlenlp/transformers/generation_utils.py +++ b/paddlenlp/transformers/generation_utils.py @@ -708,7 +708,7 @@ def generate( print(response) # ['是的', '嗯嗯'] """ - + assert decode_strategy in [ "greedy_search", "sampling", From d2a58d4d119cc2e26d11d6068d7f20accc6fe901 Mon Sep 17 00:00:00 2001 From: christineaa <49200582+christineaa@users.noreply.github.com> Date: Thu, 8 Dec 2022 17:40:15 +0800 Subject: [PATCH 5/7] Update generation_utils.py add no_repeat_ngram_size=None --- paddlenlp/transformers/generation_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/paddlenlp/transformers/generation_utils.py b/paddlenlp/transformers/generation_utils.py index 91610958cf6c..c371b8e24517 100644 --- a/paddlenlp/transformers/generation_utils.py +++ b/paddlenlp/transformers/generation_utils.py @@ -306,6 +306,7 @@ def get_logits_processor( num_beam_groups=1, diversity_rate=0.0, repetition_penalty=None, + no_repeat_ngram_size=None, logits_processors=None, ): processors = LogitsProcessorList() @@ -533,6 +534,7 @@ def generate( decoder_start_token_id=None, forced_bos_token_id=None, forced_eos_token_id=None, + no_repeat_ngram_size=None, num_return_sequences=1, diversity_rate=0.0, use_cache=True, From c715d46c1c179ed9e1983faaf589c5088514a391 Mon Sep 17 00:00:00 2001 From: christineaa <49200582+christineaa@users.noreply.github.com> Date: Thu, 8 Dec 2022 17:42:02 +0800 Subject: [PATCH 6/7] Update generation_utils.py fix code style --- paddlenlp/transformers/generation_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlenlp/transformers/generation_utils.py b/paddlenlp/transformers/generation_utils.py index c371b8e24517..1115251b536e 100644 --- a/paddlenlp/transformers/generation_utils.py +++ b/paddlenlp/transformers/generation_utils.py @@ -506,7 +506,7 @@ def _build_faster(self, kwargs): if kwargs["num_beam_groups"] != 1: # not support for group_beam_search yet in the faster version raise AttributeError("'num_beam_groups != 1' is not supported yet in the faster version") - if paddle.get_default_dtype() == "float16" and kwargs["use_fp16_decoding"] == False: + if paddle.get_default_dtype() == "float16" and kwargs["use_fp16_decoding"] is False: logger.info( "Since the default dtype is float16, float16 would be used " "though 'use_fp16_decoding=False'." ) From 4a1fb32ca52206a256cbf37e45ebb4e256708943 Mon Sep 17 00:00:00 2001 From: caimingzhu Date: Thu, 8 Dec 2022 19:15:17 +0800 Subject: [PATCH 7/7] fix code style --- paddlenlp/transformers/generation_utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/paddlenlp/transformers/generation_utils.py b/paddlenlp/transformers/generation_utils.py index 1115251b536e..7bb8590d4104 100644 --- a/paddlenlp/transformers/generation_utils.py +++ b/paddlenlp/transformers/generation_utils.py @@ -13,15 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List import inspect from abc import ABC +from typing import List import paddle import paddle.nn as nn import paddle.nn.functional as F from paddle.common_ops_import import convert_dtype from paddle.fluid.layers.utils import map_structure + from paddlenlp.utils.log import logger __all__ = ["GenerationMixin"] @@ -734,9 +735,7 @@ def generate( else getattr(self, "decoder_start_token_id", None) ) no_repeat_ngram_size = ( - no_repeat_ngram_size - if no_repeat_ngram_size is not None - else getattr(self, 'no_repeat_ngram_size', None) + no_repeat_ngram_size if no_repeat_ngram_size is not None else getattr(self, "no_repeat_ngram_size", None) ) if getattr(self, "_faster_entry", None) is not False and use_faster:
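
For illustration, the banning rule that NoRepeatNGramLogitsProcessor applies can be reproduced on plain Python lists. This is a minimal, framework-free sketch of what _get_ngrams and _calc_banned_ngram_tokens do for a single hypothesis; the helper names collect_ngrams and banned_next_tokens are illustrative only and are not part of the patch:

def collect_ngrams(tokens, ngram_size):
    # Map each (ngram_size - 1)-token prefix to the tokens that have followed it so far.
    ngrams = {}
    for ngram in zip(*[tokens[i:] for i in range(ngram_size)]):
        prefix = tuple(ngram[:-1])
        ngrams.setdefault(prefix, []).append(ngram[-1])
    return ngrams


def banned_next_tokens(tokens, ngram_size):
    # Tokens that would complete an already-seen ngram; empty until enough tokens exist.
    if len(tokens) + 1 < ngram_size:
        return []
    ngrams = collect_ngrams(tokens, ngram_size)
    prefix = tuple(tokens[len(tokens) + 1 - ngram_size:])
    return ngrams.get(prefix, [])


if __name__ == "__main__":
    generated = [5, 3, 8, 5, 3]
    # The 3-gram (5, 3, 8) already occurred, so 8 is banned at the next step.
    print(banned_next_tokens(generated, ngram_size=3))  # [8]

Inside the processor, the banned ids computed this way for each hypothesis are set to -inf in the logits before sampling or beam search. After this series, the behavior is exposed through the new generate() argument, e.g. model.generate(input_ids, no_repeat_ngram_size=3) for any model that uses GenerationMixin (the exact model and remaining keyword arguments here are illustrative).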