From 7d23fe8c05adf2c3ea545aab32a8a92b9edc685b Mon Sep 17 00:00:00 2001
From: Wanglongzhi2001 <583087864@qq.com>
Date: Thu, 5 Dec 2024 21:59:33 +0800
Subject: [PATCH] add token_penalty for speculate decoding

---
 .../transformers/generation_utils.py         | 21 ++++++++++++++++++-
 1 file changed, 20 insertions(+), 1 deletion(-)

diff --git a/paddlenlp/experimental/transformers/generation_utils.py b/paddlenlp/experimental/transformers/generation_utils.py
index 32ab14bbcc70..94749d323234 100644
--- a/paddlenlp/experimental/transformers/generation_utils.py
+++ b/paddlenlp/experimental/transformers/generation_utils.py
@@ -796,8 +796,27 @@ def _post_process_(
         temperature,
         model_kwargs,
     ):
+        step_idx = model_kwargs["step_idx"]
         logits = paddle.cast(outputs, paddle.float32)
-        # TODO(Wanglongzhi2001): token_penalty
+
+        from paddlenlp_ops import speculate_get_token_penalty_multi_scores
+
+        speculate_get_token_penalty_multi_scores(
+            model_kwargs["pre_ids"],
+            logits,
+            penalty_score,
+            frequency_score,
+            presence_score,
+            temperature,
+            model_kwargs["bad_tokens"],
+            step_idx,
+            model_kwargs["min_dec_len"],
+            eos_token_id,
+            model_kwargs["seq_lens_this_time"],
+            model_kwargs["output_padding_offset"],
+            model_kwargs["output_cum_offsets"],
+            self.max_seq_len,
+        )
+
         # sample
         probs = F.softmax(logits)
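
Note (editor's sketch, not part of the patch): speculate_get_token_penalty_multi_scores is a fused custom op imported from paddlenlp_ops, so its kernel is not visible in this diff. Below is a minimal single-query NumPy reference for what such a multi-score token-penalty pass conventionally computes: min-length EOS masking, bad-token masking, repetition penalty, frequency/presence penalties, and temperature scaling. The function name token_penalty_reference, the -1e4 mask value, and the exact step ordering are assumptions; the speculative-decoding bookkeeping in the real op (seq_lens_this_time, output_padding_offset, output_cum_offsets, max_seq_len), which maps each draft-token row of logits back to its originating query, is deliberately omitted here.

import numpy as np

NEG_INF = -1e4  # mask value; an assumption, mirroring the convention in similar inference ops

def token_penalty_reference(
    logits,           # [num_tokens, vocab_size] float32, one row per draft token of a single query
    pre_ids,          # 1D int array of previously generated token ids, -1 = padding
    penalty_score,    # repetition penalty; > 1.0 discourages repeats
    frequency_score,  # subtracted once per prior occurrence of a token
    presence_score,   # subtracted once if a token occurred at all
    temperature,      # divisor applied last
    bad_tokens,       # 1D int array of ids that must never be sampled
    step_idx,         # number of tokens decoded so far
    min_dec_len,      # EOS is blocked until this many tokens are decoded
    eos_token_id,     # int id (or 1D int array of ids) of EOS
):
    logits = logits.astype(np.float32).copy()
    vocab_size = logits.shape[-1]

    # 1. Enforce the minimum decode length by masking EOS.
    if step_idx < min_dec_len:
        logits[:, eos_token_id] = NEG_INF

    # 2. Mask banned tokens unconditionally.
    logits[:, bad_tokens] = NEG_INF

    # 3. Repetition penalty over already-generated ids: divide positive
    #    logits by the penalty, multiply negative ones by it.
    seen = pre_ids[pre_ids >= 0]
    counts = np.bincount(seen, minlength=vocab_size)
    occurred = counts > 0
    hit = logits[:, occurred]
    logits[:, occurred] = np.where(hit > 0, hit / penalty_score, hit * penalty_score)

    # 4. Frequency and presence penalties (OpenAI-style).
    logits -= frequency_score * counts + presence_score * occurred

    # 5. Temperature scaling.
    return logits / temperature

# Illustrative call with made-up values (3 draft tokens, vocab of 32):
rng = np.random.default_rng(0)
scores = token_penalty_reference(
    logits=rng.standard_normal((3, 32)).astype(np.float32),
    pre_ids=np.array([5, 9, 5, -1]), penalty_score=1.2,
    frequency_score=0.2, presence_score=0.3, temperature=0.8,
    bad_tokens=np.array([0]), step_idx=2, min_dec_len=4, eos_token_id=1,
)

Applying the penalties before F.softmax, as the patch does, means the subsequent sampling step draws from the already-penalized distribution; this replaces the earlier TODO, which sampled from raw logits.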