From 289fa44b61eb923fa0eb4ba334ece82250d97cd2 Mon Sep 17 00:00:00 2001 From: yanxing Date: Thu, 4 Jul 2024 17:42:15 +0800 Subject: [PATCH] [bugfix] bugfix of repetition_penalty compute. --- src/llm.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/llm.cpp b/src/llm.cpp index 8e63f675..ae319220 100644 --- a/src/llm.cpp +++ b/src/llm.cpp @@ -9,6 +9,7 @@ #include #include #include +#include <unordered_set> #include #include @@ -170,17 +171,18 @@ VARP Llm::forward(const std::vector<int>& input_ids) { } int Llm::sample(VARP logits, const std::vector<int>& pre_ids) { - auto scores = logits->writeMap<float>(); + std::unordered_set<int> ids_set(pre_ids.begin(), pre_ids.end()); + auto scores = (float*)(logits->readMap<float>()); auto size = logits->getInfo()->size; - float max_score = 0; - int token_id = 0; // repetition penalty const float repetition_penalty = 1.1; - for (auto id : pre_ids) { + for (auto id : ids_set) { float score = scores[id]; scores[id] = score < 0 ? score * repetition_penalty : score / repetition_penalty; } // argmax + float max_score = 0; + int token_id = 0; for (int i = 0; i < size; i++) { float score = scores[i]; if (score > max_score) {