Add modified_beam_search for streaming decode #489
@@ -0,0 +1,278 @@
# Copyright 2022 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import List

import k2
import torch
import torch.nn as nn
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
from decode_stream import DecodeStream

from icefall.decode import one_best_decoding
from icefall.utils import get_texts

def greedy_search(
    model: nn.Module,
    encoder_out: torch.Tensor,
    streams: List[DecodeStream],
) -> None:
    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.

    Args:
      model:
        The transducer model.
      encoder_out:
        Output from the encoder. Its shape is (N, T, C), where N >= 1.
      streams:
        A list of DecodeStream objects.
    """
    assert len(streams) == encoder_out.size(0)
    assert encoder_out.ndim == 3

    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size
    device = model.device
    T = encoder_out.size(1)

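    # The decoder is stateless: it conditions on only the last
    # context_size tokens of each stream's running hypothesis, so those
    # tokens are all we need in order to (re)build its output.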
    decoder_input = torch.tensor(
        [stream.hyp[-context_size:] for stream in streams],
        device=device,
        dtype=torch.int64,
    )
    # decoder_out is of shape (N, 1, decoder_out_dim)
    decoder_out = model.decoder(decoder_input, need_pad=False)

    for t in range(T):
        # current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
        current_encoder_out = encoder_out[:, t : t + 1, :]  # noqa

        logits = model.joiner(
            current_encoder_out.unsqueeze(2),
            decoder_out.unsqueeze(1),
        )
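        # Both inputs are unsqueezed to 4-D, (batch_size, 1, 1, dim),
        # which is the layout the joiner expects; it returns a tensor of
        # shape (batch_size, 1, 1, vocab_size).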
        # logits' shape: (batch_size, vocab_size)
        logits = logits.squeeze(1).squeeze(1)

        assert logits.ndim == 2, logits.shape
        y = logits.argmax(dim=1).tolist()
        emitted = False
        for i, v in enumerate(y):
            if v != blank_id:
                streams[i].hyp.append(v)
                emitted = True
        if emitted:
            # update decoder output
            decoder_input = torch.tensor(
                [stream.hyp[-context_size:] for stream in streams],
                device=device,
                dtype=torch.int64,
            )
            decoder_out = model.decoder(
                decoder_input,
                need_pad=False,
            )

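A note on the `emitted` flag above: the decoder output is recomputed only on frames where at least one stream produced a non-blank token. A toy, self-contained sketch of that per-frame update (the blank ID of 0 and all numbers are illustrative, not from the PR):

```python
import torch

blank_id = 0                         # assumed; icefall models typically use 0
hyps = [[0, 0], [0, 0]]              # one running hypothesis per stream
logits = torch.tensor(
    [[0.1, 2.0, 0.3],                # stream 0: argmax is token 1 (non-blank)
     [3.0, 0.2, 0.1]]                # stream 1: argmax is blank
)
y = logits.argmax(dim=1).tolist()    # [1, 0]
emitted = False
for i, v in enumerate(y):
    if v != blank_id:
        hyps[i].append(v)            # only non-blank tokens extend a hypothesis
        emitted = True
# Only when `emitted` is True does the real code rebuild decoder_out; it is
# then recomputed for the whole batch, which keeps the decoder call batched.
assert hyps == [[0, 0, 1], [0, 0]]
```
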
def modified_beam_search(
    model: nn.Module,
    encoder_out: torch.Tensor,
    streams: List[DecodeStream],
    beam: int = 4,

> Review thread on this line:
>
> I suggest using […]. People can get confused with […].
>
> I know you renamed it to […].
>
> I suggest only doing this for new code. The current file associates two
> different meanings with […]. See
> icefall/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py,
> line 203 in 72d76a4.

) -> None:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. | ||||||
|
||||||
Args: | ||||||
model: | ||||||
The RNN-T model. | ||||||
encoder_out: | ||||||
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of | ||||||
the encoder model. | ||||||
streams: | ||||||
A list of stream objects. | ||||||
beam: | ||||||
Number of active paths during the beam search. | ||||||
""" | ||||||
assert encoder_out.ndim == 3, encoder_out.shape | ||||||
assert len(streams) == encoder_out.size(0) | ||||||
|
||||||
blank_id = model.decoder.blank_id | ||||||
context_size = model.decoder.context_size | ||||||
device = next(model.parameters()).device | ||||||
batch_size = len(streams) | ||||||
T = encoder_out.size(1) | ||||||
|
||||||
B = [stream.hyps for stream in streams] | ||||||
|
||||||
for t in range(T): | ||||||
current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1) | ||||||
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) | ||||||
|
||||||
hyps_shape = get_hyps_shape(B).to(device) | ||||||
|
||||||
A = [list(b) for b in B] | ||||||
B = [HypothesisList() for _ in range(batch_size)] | ||||||
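        # A holds the hypotheses that survived the previous frame;
        # B collects, per stream, the hypotheses produced on this frame.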

        ys_log_probs = torch.stack(
            [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
        )  # (num_hyps, 1)

        decoder_input = torch.tensor(
            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
            device=device,
            dtype=torch.int64,
        )  # (num_hyps, context_size)

        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
        # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)

        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
        # as index, so we use `to(torch.int64)` below.
        current_encoder_out = torch.index_select(
            current_encoder_out,
            dim=0,
            index=hyps_shape.row_ids(1).to(torch.int64),
        )  # (num_hyps, 1, 1, encoder_out_dim)

        logits = model.joiner(current_encoder_out, decoder_out)
        # logits is of shape (num_hyps, 1, 1, vocab_size)

        logits = logits.squeeze(1).squeeze(1)

        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)

        log_probs.add_(ys_log_probs)

        vocab_size = log_probs.size(-1)

        log_probs = log_probs.reshape(-1)

        row_splits = hyps_shape.row_splits(1) * vocab_size
        log_probs_shape = k2.ragged.create_ragged_shape2(
            row_splits=row_splits, cached_tot_size=log_probs.numel()
        )
        ragged_log_probs = k2.RaggedTensor(
            shape=log_probs_shape, value=log_probs
        )
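        # Scaling row_splits by vocab_size regroups the flat
        # (num_hyps * vocab_size) vector so that each stream's sublist
        # holds every (hypothesis, token) expansion for that stream.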

        for i in range(batch_size):
            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
                topk_token_indexes = (topk_indexes % vocab_size).tolist()

            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]

                new_ys = hyp.ys[:]
                new_token = topk_token_indexes[k]
                if new_token != blank_id:
                    new_ys.append(new_token)

                new_log_prob = topk_log_probs[k]
                new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
                B[i].add(new_hyp)

    for i in range(batch_size):
        streams[i].hyps = B[i]
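
The per-stream `topk` above ranges jointly over (hypothesis, token) pairs. A self-contained sketch of the flattening trick with made-up sizes (two streams holding 2 and 3 active hypotheses, a vocab of 4; every number is illustrative):

```python
import k2
import torch

vocab_size = 4
# Two streams with 2 and 3 active hypotheses -> row_splits [0, 2, 5].
row_splits = torch.tensor([0, 2, 5], dtype=torch.int32) * vocab_size
log_probs = torch.randn(5 * vocab_size)  # flat (num_hyps * vocab_size,) scores

shape = k2.ragged.create_ragged_shape2(
    row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged = k2.RaggedTensor(shape=shape, value=log_probs)

# For stream 0, topk ranges over all 2 * 4 = 8 (hypothesis, token) pairs:
values, indexes = ragged[0].topk(2)
hyp_idx = indexes // vocab_size   # which hypothesis each expansion extends
token_idx = indexes % vocab_size  # which token does the extending
```
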
def fast_beam_search_one_best(
    model: nn.Module,
    encoder_out: torch.Tensor,
    processed_lens: torch.Tensor,
    streams: List[DecodeStream],
    beam: float,
    max_states: int,
    max_contexts: int,
) -> None:
    """It limits the maximum number of symbols per frame to 1.

    A lattice is first generated by Fsa-based beam search, then we get the
    recognition by applying shortest path on the lattice.

    Args:
      model:
        An instance of `Transducer`.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder.
      processed_lens:
        A tensor of shape (N,) containing the number of processed frames
        in `encoder_out` before padding.
      streams:
        A list of DecodeStream objects.
      beam:
        Beam value, similar to the beam used in Kaldi.
      max_states:
        Max states per stream per frame.
      max_contexts:
        Max contexts per stream per frame.
    """
    assert encoder_out.ndim == 3
    B, T, C = encoder_out.shape
    assert B == len(streams)

    context_size = model.decoder.context_size
    vocab_size = model.decoder.vocab_size

    config = k2.RnntDecodingConfig(
        vocab_size=vocab_size,
        decoder_history_len=context_size,
        beam=beam,
        max_contexts=max_contexts,
        max_states=max_states,
    )
    individual_streams = []
    for i in range(B):
        individual_streams.append(streams[i].rnnt_decoding_stream)
    decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
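    # RnntDecodingStreams batches the per-utterance decoding streams so
    # that k2 can advance all of them in lockstep, one frame at a time.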

    for t in range(T):
        # shape is a RaggedShape of shape (B, context)
        # contexts is a Tensor of shape (shape.NumElements(), context_size)
        shape, contexts = decoding_streams.get_contexts()
        # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
        contexts = contexts.to(torch.int64)
        # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
        decoder_out = model.decoder(contexts, need_pad=False)
        # current_encoder_out is of shape
        # (shape.NumElements(), 1, joiner_dim)
        # fmt: off
        current_encoder_out = torch.index_select(
            encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
        )
        # fmt: on
        logits = model.joiner(
            current_encoder_out.unsqueeze(2),
            decoder_out.unsqueeze(1),
        )
        logits = logits.squeeze(1).squeeze(1)
        log_probs = logits.log_softmax(dim=-1)
        decoding_streams.advance(log_probs)

    decoding_streams.terminate_and_flush_to_streams()

    lattice = decoding_streams.format_output(processed_lens.tolist())
    best_path = one_best_decoding(lattice)
    hyp_tokens = get_texts(best_path)

    for i in range(B):
        streams[i].hyp = hyp_tokens[i]
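
For readers unfamiliar with the k2 side of `fast_beam_search_one_best`, here is a minimal, model-free sketch of the same decoding lifecycle. Random scores stand in for the joiner output, and the trivial (unconstrained) decoding graph, stream count, and sizes are assumptions; only the `k2.Rnnt*` calls mirror the code above.

```python
import k2
import torch

from icefall.decode import one_best_decoding
from icefall.utils import get_texts

vocab_size = 10
device = torch.device("cpu")

# One trivial (unconstrained) decoding graph per "utterance".
graph = k2.trivial_graph(vocab_size - 1, device=device)
streams = [k2.RnntDecodingStream(graph) for _ in range(2)]

config = k2.RnntDecodingConfig(
    vocab_size=vocab_size,
    decoder_history_len=2,  # matches the decoder's context_size
    beam=4.0,
    max_contexts=4,
    max_states=32,
)
decoding_streams = k2.RnntDecodingStreams(streams, config)

T = 5
for t in range(T):
    shape, contexts = decoding_streams.get_contexts()
    # A real system would run the decoder and joiner here; random
    # log-probs stand in just to exercise the API.
    log_probs = torch.randn(contexts.shape[0], vocab_size).log_softmax(dim=-1)
    decoding_streams.advance(log_probs)

decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output([T, T])  # processed frames per stream
best_path = one_best_decoding(lattice)
print(get_texts(best_path))  # token IDs, one list per utterance
```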