Skip to content

Commit

Permalink
add token_penalty for speculate decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
Wanglongzhi2001 committed Dec 5, 2024
1 parent cecc2b8 commit 7d23fe8
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion paddlenlp/experimental/transformers/generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,8 +796,27 @@ def _post_process_(
temperature,
model_kwargs,
):
step_idx = model_kwargs["step_idx"]
logits = paddle.cast(outputs, paddle.float32)
# TODO(Wanglongzhi2001): token_penalty

from paddlenlp_ops import speculate_get_token_penalty_multi_scores

speculate_get_token_penalty_multi_scores(
model_kwargs["pre_ids"],
logits,
penalty_score,
frequency_score,
presence_score,
temperature,
model_kwargs["bad_tokens"],
step_idx,
model_kwargs["min_dec_len"],
eos_token_id,
model_kwargs["seq_lens_this_time"],
model_kwargs["output_padding_offset"],
model_kwargs["output_cum_offsets"],
self.max_seq_len,
)

# sample
probs = F.softmax(logits)
Expand Down

0 comments on commit 7d23fe8

Please sign in to comment.