Score CTC prefix beams with KenLM #805
Merged
11 commits
af71da0 Import KenLM (reuben)
fc91e3d Write a CTC beam search decoder TF op that scores beams with our LM (reuben)
2cccd33 Remove current re-scoring of decoder output and switch to custom op (reuben)
86b0ed6 Make sure automation works with the new decoder (reuben)
1f3d26d Address review comments (reuben)
b42c83c Package LICENSE and README.mozilla with native_client.tar.xz (reuben)
7d06bd9 Import Boost.Locale files needed for utf_to_utf conversion (reuben)
194de74 Switch from <codecvt> to Boost.Locale for charset transformation (reuben)
d6a2f58 Cleanup deepspeech_utils library definition (reuben)
1bfb028 Address final review comments (reuben)
4ccab71 Expose util/tc.py functionality as externally runnable and document it (reuben)
@@ -1,2 +1,3 @@
 
 *.binary filter=lfs diff=lfs merge=lfs -crlf
+data/lm/trie filter=lfs diff=lfs merge=lfs -crlf

@@ -24,7 +24,6 @@
 from util.feeding import DataSet, ModelFeeder
 from util.gpu import get_available_gpus
 from util.shared_lib import check_cupti
-from util.spell import correction
 from util.text import sparse_tensor_value_to_texts, wer, Alphabet
 from xdg import BaseDirectory as xdg
 import numpy as np
@@ -140,7 +139,16 @@
tf.app.flags.DEFINE_float ('estop_mean_thresh', 0.5, 'mean threshold for loss to determine the condition if early stopping is required')
tf.app.flags.DEFINE_float ('estop_std_thresh', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required')

# Decoder

tf.app.flags.DEFINE_string ('decoder_library_path', 'native_client/libctc_decoder_with_kenlm.so', 'path to the libctc_decoder_with_kenlm.so library containing the decoder implementation.')
tf.app.flags.DEFINE_string ('alphabet_config_path', 'data/alphabet.txt', 'path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.')
tf.app.flags.DEFINE_string ('lm_binary_path', 'data/lm/lm.binary', 'path to the language model binary file created with KenLM')
tf.app.flags.DEFINE_string ('lm_trie_path', 'data/lm/trie', 'path to the language model trie file created with native_client/generate_trie')
tf.app.flags.DEFINE_integer ('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions')
tf.app.flags.DEFINE_float ('lm_weight', 2.15, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
tf.app.flags.DEFINE_float ('word_count_weight', -0.10, 'the beta hyperparameter of the CTC decoder. Word insertion weight (penalty).')
tf.app.flags.DEFINE_float ('valid_word_count_weight', 1.10, 'Valid word insertion weight. This is used to lessen the word insertion penalty when the inserted word is part of the vocabulary.')

for var in ['b1', 'h1', 'b2', 'h2', 'b3', 'h3', 'b5', 'h5', 'b6', 'h6']:
    tf.app.flags.DEFINE_float('%s_stddev' % var, None, 'standard deviation to use when initialising %s' % var)
@@ -452,6 +460,28 @@ def BiRNN(batch_x, seq_length, dropout):
     # Output shape: [n_steps, batch_size, n_hidden_6]
     return layer_6
 
+if not os.path.exists(os.path.abspath(FLAGS.decoder_library_path)):
+    print('ERROR: The decoder library file does not exist. Make sure you have ' \
+          'downloaded or built the native client binaries and pass the ' \
+          'appropriate path to the binaries in the --decoder_library_path parameter.')
+
+custom_op_module = tf.load_op_library(FLAGS.decoder_library_path)
+
+def decode_with_lm(inputs, sequence_length, beam_width=100,
+                   top_paths=1, merge_repeated=True):
+    decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = (
+        custom_op_module.ctc_beam_search_decoder_with_lm(
+            inputs, sequence_length, beam_width=beam_width,
+            model_path=FLAGS.lm_binary_path, trie_path=FLAGS.lm_trie_path, alphabet_path=FLAGS.alphabet_config_path,
+            lm_weight=FLAGS.lm_weight, word_count_weight=FLAGS.word_count_weight, valid_word_count_weight=FLAGS.valid_word_count_weight,
+            top_paths=top_paths, merge_repeated=merge_repeated))
+
+    return (
+        [tf.SparseTensor(ix, val, shape) for (ix, val, shape)
+         in zip(decoded_ixs, decoded_vals, decoded_shapes)],
+        log_probabilities)
+
+
 
 # Accuracy and Loss
 # =================
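In decode_with_lm, the custom op hands back one (indices, values, shape) triple per requested path, and each triple is wrapped in a tf.SparseTensor. As a rough illustration of what such a triple encodes, the pure-Python sketch below expands one into per-utterance label sequences. The helper name `sparse_to_sequences` and the sample data are hypothetical and exist only for this example; this is not the project's sparse_tensor_value_to_texts.

```python
def sparse_to_sequences(indices, values, shape):
    """Expand a sparse (indices, values, shape) triple into label lists.

    indices: (batch_index, time_index) pairs
    values:  label id stored at each index pair
    shape:   (batch_size, max_length)
    """
    sequences = [[] for _ in range(shape[0])]
    # Sort by (batch, time) so labels come out in order even if the
    # input pairs are not already sorted.
    for (batch, _time), label in sorted(zip(indices, values)):
        sequences[batch].append(label)
    return sequences


# Made-up example: two utterances, of length 2 and 3.
indices = [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2)]
values = [7, 8, 2, 0, 19]
shape = (2, 3)
decoded = sparse_to_sequences(indices, values, shape)
```

Mapping the resulting label ids back to characters would then go through the alphabet file named by --alphabet_config_path.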
@@ -485,7 +515,7 @@ def calculate_mean_edit_distance_and_loss(model_feeder, tower, dropout):
     avg_loss = tf.reduce_mean(total_loss)
 
     # Beam search decode the batch
-    decoded, _ = tf.nn.ctc_beam_search_decoder(logits, batch_seq_len, merge_repeated=False)
+    decoded, _ = decode_with_lm(logits, batch_seq_len, merge_repeated=False, beam_width=FLAGS.beam_width)
 
     # Compute the edit (Levenshtein) distance
     distance = tf.edit_distance(tf.cast(decoded[0], tf.int32), batch_y)
@@ -718,9 +748,8 @@ def calculate_report(results_tuple):
     items = list(zip(*results_tuple))
     mean_wer = 0.0
     for label, decoding, distance, loss in items:
-        corrected = correction(decoding, alphabet)
-        sample_wer = wer(label, corrected)
-        sample = Sample(label, corrected, loss, distance, sample_wer)
+        sample_wer = wer(label, decoding)
+        sample = Sample(label, decoding, loss, distance, sample_wer)
         samples.append(sample)
         mean_wer += sample_wer
 
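With the spell-correction step removed, the report scores the decoder output directly against the label. For orientation, a word error rate of this kind is commonly the word-level Levenshtein distance divided by the number of reference words; the sketch below assumes that definition and a non-empty reference. `word_error_rate` is a hypothetical stand-in written for this example and may differ from the `wer` function in util.text.

```python
def word_error_rate(reference, hypothesis):
    """Word-level Levenshtein distance divided by reference length."""
    ref = reference.split()
    hyp = hypothesis.split()
    # prev[j] holds the distance between the reference words processed
    # so far and the first j hypothesis words.
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr.append(min(prev[j] + 1,          # reference word dropped
                            curr[j - 1] + 1,      # extra hypothesis word
                            prev[j - 1] + cost))  # match or substitution
        prev = curr
    return prev[-1] / len(ref)


# Averaging over samples, in the spirit of calculate_report.
pairs = [("hello world", "hello word"),
         ("the cat sat on the mat", "the cat sat on mat")]
mean_wer = sum(word_error_rate(r, h) for r, h in pairs) / len(pairs)
```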
Could you give the reference which defines your α, β, and β'?
For example, the original Deep Speech paper[1] and the Deep Speech 2 paper[2] don't define β'.
I don't remember and can't find where I got the beta' terminology from. I'll reword the explanation.
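For readers following this thread: the Deep Speech paper scores a candidate transcription as log P(c|x) + α log P_lm(c) + β word_count(c). The flag descriptions in this PR suggest a third term that rewards in-vocabulary words. The sketch below only illustrates that reading. `beam_score`, its arguments, and the way the valid-word term enters are assumptions made for this example and are not taken from the op's C++ source; the default weights are the flag defaults from the diff.

```python
def beam_score(log_p_ctc, log_p_lm, word_count, valid_word_count,
               lm_weight=2.15, word_count_weight=-0.10,
               valid_word_count_weight=1.10):
    """Combine acoustic and LM evidence for one candidate (assumed form).

    score = log P_ctc
            + lm_weight * log P_lm
            + word_count_weight * word_count
            + valid_word_count_weight * valid_word_count
    """
    return (log_p_ctc
            + lm_weight * log_p_lm
            + word_count_weight * word_count
            + valid_word_count_weight * valid_word_count)


# Two made-up two-word candidates: one fully in-vocabulary, one with
# an out-of-vocabulary word that the language model finds unlikely.
in_vocab = beam_score(log_p_ctc=-4.0, log_p_lm=-3.0,
                      word_count=2, valid_word_count=2)
out_of_vocab = beam_score(log_p_ctc=-3.5, log_p_lm=-9.0,
                          word_count=2, valid_word_count=1)
```

Under these assumptions the in-vocabulary candidate ranks higher even though its acoustic score is slightly worse, which is the effect the valid-word weight is described as having.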