From 2d3332e26f040ff23d6bec5877747e87c5f6f64f Mon Sep 17 00:00:00 2001 From: shibukazu Date: Sun, 5 Feb 2023 23:01:21 +0900 Subject: [PATCH 1/2] add non-speech-token suppression --- whisper.cpp | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/whisper.cpp b/whisper.cpp index aedd343ae6f..903b7b5cd54 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -3133,6 +3133,35 @@ static void whisper_process_logits( logits[vocab.token_translate] = -INFINITY; logits[vocab.token_transcribe] = -INFINITY; + + // suppress non-speech tokens + // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 + std::vector non_speech_tokens{ + "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^", + "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--", + "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", + "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯" + }; + + for (const std::string &token : non_speech_tokens) + { + std::string suppress_tokens[] = {token, " " + token}; + for (const std::string &suppress_token : suppress_tokens) + { + if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) + { + logits[vocab.token_to_id.at(suppress_token)] = -INFINITY; + } + } + } + // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word + if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) { + logits[vocab.token_to_id.at(" -")] = -INFINITY; + } + if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) { + logits[vocab.token_to_id.at(" '")] = -INFINITY; + } + // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424 { From a8f0bd4e89b7ad6d1af5daf6cb0fa9ce004091bb Mon Sep 17 00:00:00 2001 From: shibukazu Date: Mon, 6 Feb 2023 03:04:53 +0900 Subject: [PATCH 2/2] add suppress non-speech_tokens param --- whisper.cpp | 45 ++++++++++++++++++++++++++------------------- whisper.h | 1 + 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index 903b7b5cd54..a1ea5a58c6e 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2936,6 +2936,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.language =*/ "en", /*.suppress_blank =*/ true, + /*.suppress_non_speech_tokens =*/true, /*.temperature =*/ 0.0f, /*.max_initial_ts =*/ 1.0f, @@ -3073,6 +3074,14 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool return res; } +static const std::vector non_speech_tokens +{ + "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^", + "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--", + "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", + "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯" +}; + // process the logits for the selected decoder // - applies logit filters // - computes logprobs and probs @@ -3136,30 +3145,28 @@ static void whisper_process_logits( // suppress non-speech tokens // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 - std::vector non_speech_tokens{ - "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^", - "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--", - "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", - "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯" - }; - - for (const std::string &token : non_speech_tokens) + if (params.suppress_non_speech_tokens) { - std::string suppress_tokens[] = {token, " " + token}; - for (const std::string &suppress_token : suppress_tokens) + for (const std::string &token : non_speech_tokens) { - if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) + std::string suppress_tokens[] = {token, " " + token}; + for (const std::string &suppress_token : suppress_tokens) { - logits[vocab.token_to_id.at(suppress_token)] = -INFINITY; + if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) + { + logits[vocab.token_to_id.at(suppress_token)] = -INFINITY; + } } } - } - // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word - if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) { - logits[vocab.token_to_id.at(" -")] = -INFINITY; - } - if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) { - logits[vocab.token_to_id.at(" '")] = -INFINITY; + // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word + if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) + { + logits[vocab.token_to_id.at(" -")] = -INFINITY; + } + if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) + { + logits[vocab.token_to_id.at(" '")] = -INFINITY; + } } // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly diff --git a/whisper.h b/whisper.h index 72331e6abd4..07257bd6085 100644 --- a/whisper.h +++ b/whisper.h @@ -275,6 +275,7 @@ extern "C" { // common decoding parameters: bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89 + bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478 float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97