Skip to content

Commit

Permalink
[bugfix] Fix the repetition_penalty computation.
Browse files Browse the repository at this point in the history
  • Loading branch information
wangzhaode committed Jul 4, 2024
1 parent cd82d79 commit 289fa44
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/llm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <iostream>
#include <fstream>
#include <sstream>
#include <unordered_set>
#include <regex>

#include <MNN/expr/ExecutorScope.hpp>
Expand Down Expand Up @@ -170,17 +171,18 @@ VARP Llm::forward(const std::vector<int>& input_ids) {
}

int Llm::sample(VARP logits, const std::vector<int>& pre_ids) {
auto scores = logits->writeMap<float>();
std::unordered_set<int> ids_set(pre_ids.begin(), pre_ids.end());
auto scores = (float*)(logits->readMap<float>());
auto size = logits->getInfo()->size;
float max_score = 0;
int token_id = 0;
// repetition penalty
const float repetition_penalty = 1.1;
for (auto id : pre_ids) {
for (auto id : ids_set) {
float score = scores[id];
scores[id] = score < 0 ? score * repetition_penalty : score / repetition_penalty;
}
// argmax
float max_score = 0;
int token_id = 0;
for (int i = 0; i < size; i++) {
float score = scores[i];
if (score > max_score) {
Expand Down

0 comments on commit 289fa44

Please sign in to comment.