Score CTC prefix beams with KenLM #805
9f500ed
to
abe318e
Compare
2f701d3
to
895be68
Compare
FWIW, all the tests passed: https://tools.taskcluster.net/groups/Vxu6YJ_vR-6GFI-vi0Z0nA
I need to figure out a solution for users in general, probably downloading the appropriate native_client.tar.xz automatically from the bin/run-* scripts and extracting the library so that training can be done without having to set up a TensorFlow build environment.
895be68
to
e1ca3bc
Compare
@kdavis-mozilla @lissyx I've split the changes into logical chunks where possible. I'll work on the scripts I mentioned above before merging the PR, but the commits that are here are ready to be reviewed.
native_client/BUILD
Outdated
name = "ctc_decoder_with_kenlm",
srcs = ["beam_search.cc",
"alphabet.h",
"trie_node.h"] +
nit: the "] +" should be on the next line
native_client/BUILD
Outdated
"trie_node.h",
"alphabet.h",
] + glob(["kenlm/lm/*.cc", "kenlm/util/*.cc", "kenlm/util/double-conversion/*.cc",
"kenlm/lm/*.hh", "kenlm/util/*.hh", "kenlm/util/double-conversion/*.h"],
nit: alignment
native_client/BUILD
Outdated
"kenlm/lm/*.hh", "kenlm/util/*.hh", "kenlm/util/double-conversion/*.h"],
exclude = ["kenlm/*/*test.cc", "kenlm/*/*main.cc"]),
includes = ["kenlm"],
copts = ['-std=c++11'],
nit: use " instead of '
native_client/beam_search.cc
Outdated
limitations under the License.
==============================================================================*/

// This test illustrates how to make use of the CTCBeamSearchDecoder using a
I think those are not needed anymore :)
c->set_output(out_idx++, c->Matrix(batch_size, top_paths));
return tf::Status::OK();
})
.Doc(R"doc(
I think the doc needs to be updated to include the new parameters
I didn't change this line so GitHub is still showing the comment but it's been fixed.
native_client/beam_search.cc
Outdated
OP_REQUIRES_OK(ctx, ctx->GetAttr("top_paths", &top_paths));
decode_helper_.SetTopPaths(top_paths);

// const tf::Tensor* model_tensor;
Is that a leftover from previous testing?
native_client/generate_trie.cpp
Outdated
}

int main(void) {
return generate_trie("/Users/remorais/Development/DeepSpeech/data/alphabet.txt",
Seems to be leftover as well :/
native_client/generate_trie.cpp
Outdated
ifs.open(vocab_path, std::ifstream::in);

if (!ifs.is_open()) {
std::cout << "unable to open vocabulary" << std::endl;
errors should go to stderr
.tc.training.yml
Outdated
apt-get -qq -y install make build-essential libssl-dev zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev xz-utils tk-dev &&
{{ SYSTEM_ADD_USER }} &&
echo -e "#!/bin/bash\nset -xe\nexport PATH=/home/build-user/bin:$PATH && env && id && wget https://github.com/git-lfs/git-lfs/releases/download/v2.2.1/git-lfs-linux-amd64-2.2.1.tar.gz -O - | tar -C /tmp -zxf - && PREFIX=/home/build-user/ /tmp/git-lfs-2.2.1/install.sh && mkdir ~/DeepSpeech/ && git clone --quiet {{ GITHUB_HEAD_REPO_URL }} ~/DeepSpeech/ds/ && cd ~/DeepSpeech/ds && git checkout --quiet {{ GITHUB_HEAD_SHA }}" > /tmp/clone.sh && chmod +x /tmp/clone.sh &&
{{ SYSTEM_DO_CLONE }} &&
sudo -H -u build-user TENSORFLOW_WHEEL=${TENSORFLOW_WHEEL} /bin/bash /home/build-user/DeepSpeech/ds/tc-train-tests.sh 2.7.13
sudo -H -u build-user TENSORFLOW_WHEEL=${TENSORFLOW_WHEEL} {{ TASK_ENV_VARS }} /bin/bash /home/build-user/DeepSpeech/ds/tc-train-tests.sh 2.7.13
We don't need {{ TASK_ENV_VARS }}, just DEEPSPEECH_ARTIFACTS_ROOT=${DEEPSPEECH_ARTIFACTS_ROOT}
DeepSpeech.py
Outdated
@@ -484,7 +538,7 @@ def calculate_mean_edit_distance_and_loss(model_feeder, tower, dropout):
avg_loss = tf.reduce_mean(total_loss)

# Beam search decode the batch
decoded, _ = tf.nn.ctc_beam_search_decoder(logits, batch_seq_len, merge_repeated=False)
decoded, _ = decode_with_lm(logits, batch_seq_len, merge_repeated=False, beam_width=1024)
Would it make sense to keep the ability to use the vanilla TensorFlow decoder?
I don't think anyone will want to use it, and I'd rather avoid forking the execution/maintenance paths and have some part of the code break because it's not being tested.
native_client/generate_trie.cpp
Outdated
}

std::ofstream ofs;
ofs.open(trie_path);
shouldn't we check that as well? if we cannot open the output file
native_client/generate_trie.cpp
Outdated
ofs.open(trie_path);

std::string word;
while (ifs >> word) {
So, is this going to read the words.txt file line by line?
Word by word, it splits on whitespace.
native_client/trie_node.h
Outdated
}
}

int GetFrequency() {
That is a bit misleading: a frequency is not exactly a number of occurrences in my mind, or am I missing something?
Sure, it's more of a count. I'll rename it (or delete it, since there are no users).
native_client/generate_trie.cpp
Outdated
for_each(word.begin(), word.end(), [](char& a) { a = tolower(a); });
lm::WordIndex vocab = GetWordIndex(model, word);
float unigram_score = ScoreWord(model, vocab);
root.Insert(word.c_str(), [&a](char c) {
nit: For readability, I guess the translator should be completely on a new line:
root.Insert(
word.c_str(),
[&a](char c) {
},
vocab,
unigram_score);
native_client/trie_node.h
Outdated
TrieNode *child = children[vocabIndex];
if (child == nullptr)
child = children[vocabIndex] = new TrieNode(vocab_size);
child->Insert(word + 1, translator, lm_word, unigram_score);
Are we playing with pointers here with the + 1?
Yes.
native_client/trie_node.h
Outdated
if (wordCharacter != '\0') {
int vocabIndex = translator(wordCharacter);
TrieNode *child = children[vocabIndex];
if (child == nullptr)
nit: I'm not a big fan of brace-less conditions, they're often the root of nasty bugs in the future :)
}

static void ReadFromStream(std::istream& is, TrieNode* &obj, int vocab_size) {
int prefixCount;
Having a local variable in a static method named the same as a member of the class is confusing, imho.
But it is eventually the local variable prefixCount.
I renamed member variables to have a _ suffix.
native_client/generate_trie.cpp
Outdated
Model::State in_state = model.NullContextState();
Model::State out;
lm::FullScoreReturn full_score_return;
full_score_return = model.FullScore(in_state, vocab, out);
Why is there this out that we don't use?
Asking because according to that comment f1e859e#diff-7a0b3b152c06fc03b20e63b72c26583aR11 (whose point I don't completely get, so I might just be misunderstanding something there), it seems like out_state has some use.
Keeping state is a performance optimization when scoring a sentence of several words. In this function we want to score the word independently of any context, so that's why we throw it away.
Thanks, maybe worth a comment then ?
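To illustrate the point about discarding the output state, here is a minimal sketch with a toy bigram model. `ToyBigramModel`, `null_context_state` and `full_score` are hypothetical stand-ins chosen to mirror the calls in the hunk above; they are not the KenLM API, and the probabilities are made up for the example.

```python
class ToyBigramModel:
    """Hypothetical stand-in for an n-gram language model (not KenLM)."""

    def __init__(self, unigrams, bigrams):
        self.unigrams = unigrams  # word -> log10 probability
        self.bigrams = bigrams    # (previous word, word) -> log10 probability

    def null_context_state(self):
        # State meaning "no preceding words".
        return None

    def full_score(self, in_state, word):
        """Return (log10 probability of word given in_state, state after word)."""
        if in_state is not None and (in_state, word) in self.bigrams:
            logprob = self.bigrams[(in_state, word)]
        else:
            logprob = self.unigrams[word]
        return logprob, word


def score_word(model, word):
    # Score the word independently of any context: start from the null
    # context and deliberately discard the output state.
    logprob, _unused_out_state = model.full_score(model.null_context_state(), word)
    return logprob


def score_sentence(model, words):
    # Here the output state is threaded through, so each word is scored
    # in the context of the word before it.
    state = model.null_context_state()
    total = 0.0
    for word in words:
        logprob, state = model.full_score(state, word)
        total += logprob
    return total


model = ToyBigramModel(
    unigrams={"dark": -2.0, "suit": -3.0},
    bigrams={("dark", "suit"): -0.5},
)
```

With this toy data, `score_word(model, "suit")` uses only the unigram entry, while `score_sentence(model, ["dark", "suit"])` benefits from the carried state when scoring the second word.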
native_client/beam_search.cc
Outdated
};

// CTC beam search
class CTCBeamSearchDecoderOp : public tf::OpKernel {
That class name is already used by the vanilla TensorFlow CTC Beam Search decoder. Can we rename it to something matching the op? Like CTCBeamSearchDecoderWithLM.
Technically this is ::CTCBeamSearchDecoderOp rather than ::tensorflow::CTCBeamSearchDecoderOp, but you're right, I'll rename it :)
Cool PR!
There's lots of little stuff which I won't go into here.
However, I think the biggest single thing is that the trie needs to support Unicode. Currently it just supports ASCII, see my comments on the code. So it blocks the internationalization work you already did with the alphabet.txt file.
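As a rough illustration of the Unicode concern, the sketch below builds a trie keyed on alphabet labels and walks words one Unicode code point at a time rather than one byte at a time. The `TrieNode` here is a simplified, hypothetical Python analogue of the C++ class under review, and the alphabet is invented for the example.

```python
class TrieNode:
    """Simplified, hypothetical analogue of the trie in trie_node.h."""

    def __init__(self):
        self.children = {}     # alphabet label -> TrieNode
        self.prefix_count = 0  # number of inserted words passing through this node
        self.is_word = False

    def insert(self, word, label_of):
        node = self
        for ch in word:  # iterates Unicode code points, not bytes
            node = node.children.setdefault(label_of(ch), TrieNode())
            node.prefix_count += 1
        node.is_word = True

    def find(self, prefix, label_of):
        node = self
        for ch in prefix:
            node = node.children.get(label_of(ch))
            if node is None:
                return None
        return node


# Assumed alphabet: one label per character, including a non-ASCII one.
alphabet = [" ", "a", "c", "e", "f", "\u00e9"]
label_of = {ch: i for i, ch in enumerate(alphabet)}.__getitem__

root = TrieNode()
root.insert("caf\u00e9", label_of)
root.insert("cafe", label_of)

# The accented word is 4 code points but 5 bytes in UTF-8, so a trie that
# walks raw bytes would see two units where the alphabet has one label.
code_points = len("caf\u00e9")
utf8_bytes = len("caf\u00e9".encode("utf-8"))
```

The design point is only that the unit of iteration has to match the unit the alphabet assigns labels to; how the C++ code should decode UTF-8 is left open here.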
DeepSpeech.py
Outdated
top_paths=1, merge_repeated=True):
"""Performs beam search decoding on the logits given in input.

**Note** The `ctc_greedy_decoder` is a special case of the
Should we keep this comment from the original ctc_beam_search_decoder? I think it's not needed.
DeepSpeech.py
Outdated
@@ -484,7 +538,7 @@ def calculate_mean_edit_distance_and_loss(model_feeder, tower, dropout):
avg_loss = tf.reduce_mean(total_loss)

# Beam search decode the batch
decoded, _ = tf.nn.ctc_beam_search_decoder(logits, batch_seq_len, merge_repeated=False)
decoded, _ = decode_with_lm(logits, batch_seq_len, merge_repeated=False, beam_width=1024)
Should we pull the beam width out as a command line parameter that defaults to 1024?
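A minimal sketch of what exposing the beam width as a flag could look like. It uses the standard-library `argparse` purely as a stand-in so the example runs on its own; the helper names `build_parser` and `parse_beam_width` are hypothetical and not part of DeepSpeech.py.

```python
import argparse


def build_parser():
    # Hypothetical parser; only the flag name and default come from the review.
    parser = argparse.ArgumentParser(description="Decoder options (illustrative)")
    parser.add_argument(
        "--beam_width",
        type=int,
        default=1024,
        help="beam width used in the CTC decoder when building candidate transcriptions",
    )
    return parser


def parse_beam_width(argv):
    args = build_parser().parse_args(argv)
    if args.beam_width < 1:
        raise ValueError("beam_width must be a positive integer")
    return args.beam_width


default_width = parse_beam_width([])
custom_width = parse_beam_width(["--beam_width", "256"])
```

The parsed value would then replace the hard-coded `beam_width=1024` argument at the call site shown in the hunk above.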
native_client/BUILD
Outdated
exclude = ["kenlm/*/*test.cc", "kenlm/*/*main.cc"]),
includes = ["kenlm"],
copts = ['-std=c++11'],
linkopts = ['-lm'],
Ditto previous nit on ' vs "
native_client/alphabet.h
Outdated
@@ -54,6 +54,11 @@ class Alphabet {
return size_;
}

bool IsSpace(unsigned int label) const {
const std::string& str = StringFromLabel(label);
return str.size() == 1 && str[0] == ' ';
@@ -0,0 +1,165 @@
GNU LESSER GENERAL PUBLIC LICENSE
native_client/beam_search.cc
Outdated
// TODO replace with OOV unigram prob?
// If we have no valid prefix we assume a very low log probability
float min_unigram_score = -10.0f;
The choice of OOV unigram prob seems reasonable as it's the min unigram score one can expect walking down the trie assuming the language model was pruned and all words under some min probability epsilon were replaced with unknown.
For KenLM the unknown word has WordIndex model->GetVocabulary().NotFound(). So I'd guess one should be able to set this value more systematically.
// TODO try two options
// 1) unigram score added up to language model score
// 2) language model score of (preceding_words + unigram_word)
to_state->score = min_unigram_score + to_state->language_model_score;
Question: as the current character is not a space, do from_state->language_model_score and to_state->language_model_score differ?
As far as I can see they do not differ, as there's been no call to ScoreIncompleteWord() and UpdateWithLMScore between evaluating from_state->language_model_score and to_state->language_model_score.
They do not. Until there's a space we accumulate min unigram scores and then when a word boundary is reached we score it with the LM and update language_model_score.
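One way to read the behaviour described in this exchange is sketched below. `BeamState`, `expand`, the lookup tables and all score values are hypothetical simplifications, not the code in beam_search.cc; only the -10.0 fallback is taken from the hunk above.

```python
from dataclasses import dataclass

OOV_FLOOR = -10.0  # fallback log probability when no vocabulary word has the prefix


@dataclass
class BeamState:
    incomplete_word: str = ""
    language_model_score: float = 0.0  # sum of LM scores of completed words
    score: float = 0.0                 # provisional score used to rank beams


def expand(from_state, char, min_unigram_for_prefix, score_word):
    """Extend a beam by one character (simplified, assumed logic)."""
    if char == " ":
        # Word boundary: score the completed word and update the LM score.
        lm = from_state.language_model_score + score_word(from_state.incomplete_word)
        return BeamState("", lm, lm)
    # Inside a word: language_model_score is carried over unchanged and only
    # the provisional score moves, using the minimum unigram score under the
    # current prefix (or the floor when the prefix is not in the vocabulary).
    prefix = from_state.incomplete_word + char
    estimate = min_unigram_for_prefix.get(prefix, OOV_FLOOR)
    return BeamState(
        prefix,
        from_state.language_model_score,
        from_state.language_model_score + estimate,
    )


# Made-up tables standing in for the trie and the language model.
prefix_scores = {"a": -1.5, "at": -2.0}
word_scores = {"at": -1.0}


def toy_score_word(word):
    return word_scores.get(word, OOV_FLOOR)


s1 = expand(BeamState(), "a", prefix_scores, toy_score_word)
s2 = expand(s1, "t", prefix_scores, toy_score_word)
s3 = expand(s2, " ", prefix_scores, toy_score_word)
s4 = expand(s3, "z", prefix_scores, toy_score_word)
```

In this sketch `s1` and `s2` share the same `language_model_score`, which matches the answer above: the value only changes at `s3`, when the space closes the word.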
native_client/beam_search.cc
Outdated
float ScoreIncompleteWord(const Model::State& model_state,
const std::string& word,
Model::State& out) const {
lm::FullScoreReturn full_score_return;
As we're doing C++ we could write these lines as
lm::WordIndex vocab = model->GetVocabulary().Index(word);
lm::FullScoreReturn full_score_return = model->FullScore(model_state, vocab, out);
return full_score_return.prob;
native_client/beam_search.cc
Outdated
OP_REQUIRES_OK(ctx, ctx->GetAttr("top_paths", &top_paths));
decode_helper_.SetTopPaths(top_paths);

// const tf::Tensor* model_tensor;
Could you remove this commented out code?
native_client/beam_search.cc
Outdated
}

void Compute(tf::OpKernelContext *ctx) override {
const tf::Tensor *inputs;
Why such C based C++? Introducing all the variables at the start of the function.
This code is copied from CTCBeamSearchDecoder in TensorFlow, mostly unmodified. All I change is which beam scorer to use. I can make the code more C++-y, but then if we ever need to uplift changes from there to here it's harder.
OK, understood.
tc-tests-utils.sh
Outdated
@@ -47,7 +47,7 @@ assert_correct_ldc93s1()
assert_correct_inference "$1" "she had your dark suit in greasy wash water all year"
}

download_material()
download_native_client_files()
I am stealing part of that because I need to perform something similar in #810
@reuben I guess you've seen Bug 1396777. Do you want to fold those changes in to this PR or do a separate PR?
@lissyx the TC build is failing because the version of GCC/libstdc++ on the builders does not support Alternatively, if that's too complicated, I can use Boost.Locale, but that's a new dependency I'd rather avoid.
@kdavis-mozilla I've addressed all review comments. Could you take a second look? In particular, the trie_node.h, beam_search.cc and DeepSpeech.py changes. Thanks!
@reuben You might be able to get
There is nothing newer available for RPi3, switching toolchain for ARM cross compilation might be very tricky.
@lissyx that's what I was afraid of. I remember having trouble building TensorFlow with GCC>4.9, so I think the easier solution is to go with Boost.Locale.
@reuben On the review. Should I wait until the switch to Boost.Locale?
@kdavis-mozilla I'll try to upgrade to GCC 5.0 first. But yeah, I'll let you know once it's ready for review.
13e08c6
to
7af3b4d
Compare
There's no way to pass the appropriate include path to Bazel from the external world, so I vendored the necessary headers instead in native_client/boost_locale.
Just a few little nits, fix them if you want to
tf.app.flags.DEFINE_string ('lm_binary_path', 'data/lm/lm.binary', 'path to the language model binary file created with KenLM')
tf.app.flags.DEFINE_string ('lm_trie_path', 'data/lm/trie', 'path to the language model trie file created with native_client/generate_trie')
tf.app.flags.DEFINE_integer ('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions')
tf.app.flags.DEFINE_float ('lm_weight', 2.15, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
I don't remember and can't find where I got the beta' terminology from. I'll reword the explanation.
tf.app.flags.DEFINE_string ('lm_trie_path', 'data/lm/trie', 'path to the language model trie file created with native_client/generate_trie')
tf.app.flags.DEFINE_integer ('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions')
tf.app.flags.DEFINE_float ('lm_weight', 2.15, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
tf.app.flags.DEFINE_float ('word_count_weight', -0.10, 'the beta hyperparameter of the CTC decoder. Word insertion weight (penalty).')
DeepSpeech.py
Outdated
tf.app.flags.DEFINE_integer ('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions')
tf.app.flags.DEFINE_float ('lm_weight', 2.15, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
tf.app.flags.DEFINE_float ('word_count_weight', -0.10, 'the beta hyperparameter of the CTC decoder. Word insertion weight (penalty).')
tf.app.flags.DEFINE_float ('valid_word_count_weight', 1.10, 'the beta\' hyperparameter of the CTC decoder. Valid word insertion weight.')
native_client/beam_search.cc
Outdated
}

private:
Model *model;
Alphabet *alphabet;
Model model;
Why do some members have a final '_' and some don't?
Oversight! Will fix it.
native_client/beam_search.cc
Outdated
Model *model;
Alphabet *alphabet;
Model model;
Alphabet alphabet;
Why do some members have a final '_' and some don't?
native_client/beam_search.cc
Outdated
Model *model;
Alphabet *alphabet;
Model model;
Alphabet alphabet;
TrieNode *trieRoot;
Why do some members have a final '_' and some don't?
c798f31
to
1bfb028
Compare
1cc5402
to
4ccab71
Compare
This thread has been automatically locked since there has not been any recent activity after it was closed. Please open a new issue for related bugs.
Opening a PR to test the TaskCluster setup.