WIP: Write a CTC beam search decoder TF op that scores beams with our LM
reuben committed Sep 1, 2017
1 parent f1e859e commit 895be68
Showing 15 changed files with 854 additions and 72 deletions.
1 change: 1 addition & 0 deletions .gitattributes
@@ -1,2 +1,3 @@
 
 *.binary filter=lfs diff=lfs merge=lfs -crlf
+data/lm/trie filter=lfs diff=lfs merge=lfs -crlf
5 changes: 3 additions & 2 deletions .tc.training.yml
@@ -11,16 +11,17 @@ payload:
   image: "ubuntu:14.04"
   env:
     TENSORFLOW_WHEEL: https://index.taskcluster.net/v1/task/project.deepspeech.tensorflow.pip.master.cpu/artifacts/public/tensorflow_warpctc-1.3.0rc0-cp27-cp27mu-linux_x86_64.whl
+    DEEPSPEECH_ARTIFACTS_ROOT: https://queue.taskcluster.net/v1/task/{{ TASK_ID }}/runs/0/artifacts/public
   command:
     - "/bin/bash"
     - "--login"
     - "-cxe"
-    - apt-get -qq update && apt-get -qq -y install git &&
+    - apt-get -qq update && apt-get -qq -y install git pixz &&
       apt-get -qq -y install make build-essential libssl-dev zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev xz-utils tk-dev &&
       {{ SYSTEM_ADD_USER }} &&
       echo -e "#!/bin/bash\nset -xe\nexport PATH=/home/build-user/bin:$PATH && env && id && wget https://github.com/git-lfs/git-lfs/releases/download/v2.2.1/git-lfs-linux-amd64-2.2.1.tar.gz -O - | tar -C /tmp -zxf - && PREFIX=/home/build-user/ /tmp/git-lfs-2.2.1/install.sh && mkdir ~/DeepSpeech/ && git clone --quiet {{ GITHUB_HEAD_REPO_URL }} ~/DeepSpeech/ds/ && cd ~/DeepSpeech/ds && git checkout --quiet {{ GITHUB_HEAD_SHA }}" > /tmp/clone.sh && chmod +x /tmp/clone.sh &&
       {{ SYSTEM_DO_CLONE }} &&
-      sudo -H -u build-user TENSORFLOW_WHEEL=${TENSORFLOW_WHEEL} /bin/bash /home/build-user/DeepSpeech/ds/tc-train-tests.sh 2.7.13
+      sudo -H -u build-user TENSORFLOW_WHEEL=${TENSORFLOW_WHEEL} {{ TASK_ENV_VARS }} /bin/bash /home/build-user/DeepSpeech/ds/tc-train-tests.sh 2.7.13
   artifacts:
     "public":
       type: "directory"
64 changes: 58 additions & 6 deletions DeepSpeech.py
@@ -23,12 +23,10 @@
 from util.feeding import DataSet, ModelFeeder
 from util.gpu import get_available_gpus
 from util.shared_lib import check_cupti
-from util.spell import correction
 from util.text import sparse_tensor_value_to_texts, wer, Alphabet
 from xdg import BaseDirectory as xdg
 import numpy as np
-
 
 # Importer
 # ========
@@ -139,6 +137,9 @@
 tf.app.flags.DEFINE_float ('estop_mean_thresh', 0.5, 'mean threshold for loss to determine the condition if early stopping is required')
 tf.app.flags.DEFINE_float ('estop_std_thresh', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required')
 
+# Decoder
+
+tf.app.flags.DEFINE_string ('decoder_library_path', 'native_client/libctc_decoder_with_kenlm.so', 'path to the libctc_decoder_with_kenlm.so library containing the decoder implementation.')
 tf.app.flags.DEFINE_string ('alphabet_config_path', 'data/alphabet.txt', 'path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.')
 
 for var in ['b1', 'h1', 'b2', 'h2', 'b3', 'h3', 'b5', 'h5', 'b6', 'h6']:
@@ -451,6 +452,58 @@ def BiRNN(batch_x, seq_length, dropout):
     # Output shape: [n_steps, batch_size, n_hidden_6]
     return layer_6
 
+custom_op_module = tf.load_op_library(FLAGS.decoder_library_path)
+
+def decode_with_lm(inputs, sequence_length, beam_width=100,
+                   top_paths=1, merge_repeated=True):
+    """Performs beam search decoding on the logits given in input.
+
+    **Note** The `ctc_greedy_decoder` is a special case of the
+    `ctc_beam_search_decoder` with `top_paths=1` and `beam_width=1` (but
+    that decoder is faster for this special case).
+
+    If `merge_repeated` is `True`, merge repeated classes in the output beams.
+    This means that if consecutive entries in a beam are the same,
+    only the first of these is emitted. That is, when the top path
+    is `A B B B B`, the return value is:
+
+      * `A B` if `merge_repeated = True`.
+      * `A B B B B` if `merge_repeated = False`.
+
+    Args:
+      inputs: 3-D `float` `Tensor`, size
+        `[max_time x batch_size x num_classes]`. The logits.
+      sequence_length: 1-D `int32` vector containing sequence lengths,
+        having size `[batch_size]`.
+      beam_width: An int scalar >= 0 (beam search beam width).
+      top_paths: An int scalar >= 0, <= beam_width (controls output size).
+      merge_repeated: Boolean. Default: True.
+
+    Returns:
+      A tuple `(decoded, log_probabilities)` where
+
+      decoded: A list of length top_paths, where `decoded[j]`
+        is a `SparseTensor` containing the decoded outputs:
+
+        `decoded[j].indices`: Indices matrix `(total_decoded_outputs[j] x 2)`
+          The rows store: [batch, time].
+        `decoded[j].values`: Values vector, size `(total_decoded_outputs[j])`.
+          The vector stores the decoded classes for beam j.
+        `decoded[j].shape`: Shape vector, size `(2)`.
+          The shape values are: `[batch_size, max_decoded_length[j]]`.
+
+      log_probability: A `float` matrix `(batch_size x top_paths)` containing
+        sequence log-probabilities.
+    """
+    decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = (
+        custom_op_module.ctc_beam_search_decoder_with_lm(
+            inputs, sequence_length, model_path="data/lm/lm.binary",
+            trie_path="data/lm/trie", alphabet_path="data/alphabet.txt",
+            beam_width=beam_width, top_paths=top_paths,
+            merge_repeated=merge_repeated))
+
+    return (
+        [tf.SparseTensor(ix, val, shape) for (ix, val, shape)
+         in zip(decoded_ixs, decoded_vals, decoded_shapes)],
+        log_probabilities)
+
+
 # Accuracy and Loss
 # =================
@@ -484,7 +537,7 @@ def calculate_mean_edit_distance_and_loss(model_feeder, tower, dropout):
     avg_loss = tf.reduce_mean(total_loss)
 
     # Beam search decode the batch
-    decoded, _ = tf.nn.ctc_beam_search_decoder(logits, batch_seq_len, merge_repeated=False)
+    decoded, _ = decode_with_lm(logits, batch_seq_len, merge_repeated=False, beam_width=1024)
 
     # Compute the edit (Levenshtein) distance
     distance = tf.edit_distance(tf.cast(decoded[0], tf.int32), batch_y)
@@ -717,9 +770,8 @@ def calculate_report(results_tuple):
     items = list(zip(*results_tuple))
     mean_wer = 0.0
     for label, decoding, distance, loss in items:
-        corrected = correction(decoding, alphabet)
-        sample_wer = wer(label, corrected)
-        sample = Sample(label, corrected, loss, distance, sample_wer)
+        sample_wer = wer(label, decoding)
+        sample = Sample(label, decoding, loss, distance, sample_wer)
         samples.append(sample)
         mean_wer += sample_wer
 
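Since calculate_report now feeds the raw decoding straight into wer (the spell-correction pass is gone), it is worth recalling what that metric computes. A rough sketch of word error rate as word-level Levenshtein distance follows; this is an illustrative definition, not the code in util.text, which may differ in detail.

    def word_error_rate(reference, hypothesis):
        # Edit (Levenshtein) distance computed over words rather than
        # characters, normalized by the reference length.
        r, h = reference.split(), hypothesis.split()
        d = [[0] * (len(h) + 1) for _ in range(len(r) + 1)]
        for i in range(len(r) + 1):
            d[i][0] = i  # deleting i reference words
        for j in range(len(h) + 1):
            d[0][j] = j  # inserting j hypothesis words
        for i in range(1, len(r) + 1):
            for j in range(1, len(h) + 1):
                cost = 0 if r[i - 1] == h[j - 1] else 1
                d[i][j] = min(d[i - 1][j] + 1,         # deletion
                              d[i][j - 1] + 1,         # insertion
                              d[i - 1][j - 1] + cost)  # substitution
        return d[len(r)][len(h)] / float(max(len(r), 1))

    # e.g. word_error_rate("the cat sat", "the cat sit") == 1/3
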
3 changes: 2 additions & 1 deletion bin/run-tc-ldc93s1.sh
@@ -16,4 +16,5 @@ python -u DeepSpeech.py \
   --test_files ${ldc93s1_csv} --test_batch_size 1 \
   --n_hidden 494 --epoch 75 --random_seed 4567 --default_stddev 0.046875 \
   --max_to_keep 1 --checkpoint_dir '/tmp/ckpt' --checkpoint_secs 0 \
-  --learning_rate 0.001 --dropout_rate 0.05 --export_dir "/tmp/train"
+  --learning_rate 0.001 --dropout_rate 0.05 --export_dir "/tmp/train" \
+  --decoder_library_path "/tmp/ds/libctc_decoder_with_kenlm.so"
3 changes: 3 additions & 0 deletions data/lm/trie
Git LFS file not shown
32 changes: 32 additions & 0 deletions native_client/BUILD
@@ -34,3 +34,35 @@ cc_library(
     copts = [] + if_linux_x86_64(["-mno-fma", "-mno-avx", "-mno-avx2"]),
     nocopts = "(-fstack-protector|-fno-omit-frame-pointer)",
 )
+
+
+cc_library(
+    name = "ctc_decoder_with_kenlm",
+    srcs = ["beam_search.cc",
+            "alphabet.h",
+            "trie_node.h"] +
+           glob(["kenlm/lm/*.cc", "kenlm/util/*.cc", "kenlm/util/double-conversion/*.cc",
+                 "kenlm/lm/*.hh", "kenlm/util/*.hh", "kenlm/util/double-conversion/*.h"],
+                exclude = ["kenlm/*/*test.cc", "kenlm/*/*main.cc"]),
+    includes = ["kenlm"],
+    defines = ["KENLM_MAX_ORDER=6"],
+    deps = ["//tensorflow/core:core",
+            "//tensorflow/core/util/ctc",
+            "//third_party/eigen3",
+    ],
+)
+
+cc_binary(
+    name = "generate_trie",
+    srcs = [
+        "generate_trie.cpp",
+        "trie_node.h",
+        "alphabet.h",
+    ] + glob(["kenlm/lm/*.cc", "kenlm/util/*.cc", "kenlm/util/double-conversion/*.cc",
+              "kenlm/lm/*.hh", "kenlm/util/*.hh", "kenlm/util/double-conversion/*.h"],
+             exclude = ["kenlm/*/*test.cc", "kenlm/*/*main.cc"]),
+    includes = ["kenlm"],
+    copts = ['-std=c++11'],
+    linkopts = ['-lm'],
+    defines = ["KENLM_MAX_ORDER=6"],
+)
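For reference, a sketch of how these targets might be built and smoke-tested. The //tensorflow/... dependencies imply this BUILD file sits inside a TensorFlow source checkout; the exact bazel invocations and the bazel-bin output path below are assumptions, not spelled out in this commit.

    # Hypothetical build commands, run from the TensorFlow source root
    # (cc_library implicitly provides a lib<name>.so output target):
    #   bazel build -c opt //native_client:libctc_decoder_with_kenlm.so
    #   bazel build -c opt //native_client:generate_trie
    import tensorflow as tf

    # Smoke test: the shared object should expose the new op once loaded.
    lib = tf.load_op_library(
        'bazel-bin/native_client/libctc_decoder_with_kenlm.so')
    assert hasattr(lib, 'ctc_beam_search_decoder_with_lm')
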
5 changes: 5 additions & 0 deletions native_client/alphabet.h
@@ -54,6 +54,11 @@ class Alphabet {
     return size_;
   }
 
+  bool IsSpace(unsigned int label) const {
+    const std::string& str = StringFromLabel(label);
+    return str.size() == 1 && str[0] == ' ';
+  }
+
 private:
   size_t size_;
   std::unordered_map<unsigned int, std::string> label_to_str_;
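IsSpace gives the decoder a way to detect word boundaries, the points at which a completed word can be scored against the KenLM model (per the commit title). A rough Python rendering of the same convention, for illustration only; the dict-based mapping here stands in for the alphabet file and is not the util.text Alphabet class.

    # A label counts as "space" iff the alphabet maps it to the single
    # character ' '; the decoder treats that label as a word boundary.
    def is_space(label, label_to_str):
        s = label_to_str[label]
        return len(s) == 1 and s[0] == ' '

    # e.g. with label_to_str = {0: ' ', 1: 'a'}, is_space(0, ...) is True.
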