Add modified_beam_search for streaming decode #489

Merged 5 commits on Jul 25, 2022.

Changes from 1 commit.
@@ -0,0 +1,278 @@
# Copyright 2022 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import List

import k2
import torch
import torch.nn as nn
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
from decode_stream import DecodeStream

from icefall.decode import one_best_decoding
from icefall.utils import get_texts


def greedy_search(
    model: nn.Module,
    encoder_out: torch.Tensor,
    streams: List[DecodeStream],
) -> None:
    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.

    Args:
      model:
        The transducer model.
      encoder_out:
        Output from the encoder. Its shape is (N, T, C), where N >= 1.
      streams:
        A list of DecodeStream objects.
    """
    assert len(streams) == encoder_out.size(0)
    assert encoder_out.ndim == 3

    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size
    device = model.device
    T = encoder_out.size(1)

    decoder_input = torch.tensor(
        [stream.hyp[-context_size:] for stream in streams],
        device=device,
        dtype=torch.int64,
    )
    # decoder_out is of shape (N, decoder_out_dim)
Collaborator:

Suggested change:
-    # decoder_out is of shape (N, decoder_out_dim)
+    # decoder_out is of shape (N, 1, decoder_out_dim)
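(Assuming the recipe's stateless decoder, i.e. an embedding followed by a Conv1d over the context axis: with need_pad=False, an input of shape (N, context_size) collapses to a single output frame, which is why the suggested shape (N, 1, decoder_out_dim) matches what model.decoder returns here.)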

    decoder_out = model.decoder(decoder_input, need_pad=False)

    for t in range(T):
        # current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
        current_encoder_out = encoder_out[:, t : t + 1, :]  # noqa

        logits = model.joiner(
            current_encoder_out.unsqueeze(2),
            decoder_out.unsqueeze(1),
        )
        # logits' shape: (batch_size, vocab_size)
        logits = logits.squeeze(1).squeeze(1)

        assert logits.ndim == 2, logits.shape
        y = logits.argmax(dim=1).tolist()
        emitted = False
        for i, v in enumerate(y):
            if v != blank_id:
                streams[i].hyp.append(v)
                emitted = True
        if emitted:
            # Update the decoder output only when at least one stream
            # emitted a non-blank token.
            decoder_input = torch.tensor(
                [stream.hyp[-context_size:] for stream in streams],
                device=device,
                dtype=torch.int64,
            )
            decoder_out = model.decoder(
                decoder_input,
                need_pad=False,
            )
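As a side note, the emit-and-update logic above can be illustrated in isolation. The following sketch is not part of the PR; the logits values and blank_id are made up, and only the control flow mirrors greedy_search:

import torch

blank_id = 0
# Two streams, a three-token vocabulary; the logits are arbitrary examples.
logits = torch.tensor(
    [
        [2.0, 0.1, 0.3],  # stream 0: argmax is blank -> nothing appended
        [0.1, 3.0, 0.2],  # stream 1: argmax is token 1 -> appended
    ]
)
y = logits.argmax(dim=1).tolist()  # [0, 1]
hyps = [[blank_id], [blank_id]]
emitted = False
for i, v in enumerate(y):
    if v != blank_id:
        hyps[i].append(v)
        emitted = True
# The decoder is re-run only when some stream emitted a non-blank token;
# blank emissions leave every hypothesis, and hence the decoder input,
# unchanged.
assert emitted and hyps == [[0], [0, 1]]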


def modified_beam_search(
    model: nn.Module,
    encoder_out: torch.Tensor,
    streams: List[DecodeStream],
    beam: int = 4,
Collaborator:

I suggest using num_active_paths in the new code. People can get confused with beam and beam_size.

Collaborator (Author):

I know you renamed it to num_active_paths in sherpa; I thought it would be better to keep it consistent with the one in decode.py. So do we need to rename the one in decode.py as well?

Collaborator:

I suggest only doing this for new code. The current file associates two different meanings with beam, which can confuse users, I think.

) -> None:
    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.

    Args:
      model:
        The RNN-T model.
      encoder_out:
        A 3-D tensor of shape (N, T, encoder_out_dim) containing the output
        of the encoder model.
      streams:
        A list of DecodeStream objects.
      beam:
        Number of active paths during the beam search.
    """
    assert encoder_out.ndim == 3, encoder_out.shape
    assert len(streams) == encoder_out.size(0)

    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size
    device = next(model.parameters()).device
    batch_size = len(streams)
    T = encoder_out.size(1)

    B = [stream.hyps for stream in streams]

    for t in range(T):
        current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1)
        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)

        hyps_shape = get_hyps_shape(B).to(device)

        A = [list(b) for b in B]
        B = [HypothesisList() for _ in range(batch_size)]

        ys_log_probs = torch.stack(
            [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
        )  # (num_hyps, 1)

        decoder_input = torch.tensor(
            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
            device=device,
            dtype=torch.int64,
        )  # (num_hyps, context_size)

        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
        # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)

        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
        # as index, so we use `to(torch.int64)` below.
        current_encoder_out = torch.index_select(
            current_encoder_out,
            dim=0,
            index=hyps_shape.row_ids(1).to(torch.int64),
        )  # (num_hyps, 1, 1, encoder_out_dim)

        logits = model.joiner(current_encoder_out, decoder_out)
        # logits is of shape (num_hyps, 1, 1, vocab_size)

        logits = logits.squeeze(1).squeeze(1)

        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)

        log_probs.add_(ys_log_probs)

        vocab_size = log_probs.size(-1)

        log_probs = log_probs.reshape(-1)

        row_splits = hyps_shape.row_splits(1) * vocab_size
        log_probs_shape = k2.ragged.create_ragged_shape2(
            row_splits=row_splits, cached_tot_size=log_probs.numel()
        )
        ragged_log_probs = k2.RaggedTensor(
            shape=log_probs_shape, value=log_probs
        )
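        # Illustrative layout: if stream 0 has 2 active hypotheses and
        # stream 1 has 3, hyps_shape.row_splits(1) is [0, 2, 5]; scaling by
        # vocab_size V gives row_splits [0, 2*V, 5*V], so row i of
        # ragged_log_probs holds all flattened (hypothesis, token) scores
        # for stream i. The topk below therefore searches jointly over
        # hypotheses and tokens within each stream, and an index decodes as
        # hypothesis = index // V, token = index % V.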

        for i in range(batch_size):
            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
                topk_token_indexes = (topk_indexes % vocab_size).tolist()

            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]

                new_ys = hyp.ys[:]
                new_token = topk_token_indexes[k]
                if new_token != blank_id:
                    new_ys.append(new_token)

                new_log_prob = topk_log_probs[k]
                new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
                B[i].add(new_hyp)

    for i in range(batch_size):
        streams[i].hyps = B[i]
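The function returns nothing; each stream simply carries its updated HypothesisList. A rough sketch of how a caller might read out the best path afterwards (assuming HypothesisList.get_most_probable from beam_search.py and the context_size-blank prefix convention; not code from this PR):

# Hypothetical read-out of the final result per stream.
for stream in streams:
    best_hyp = stream.hyps.get_most_probable(length_norm=True)
    # Strip the context_size leading blanks used to prime the decoder.
    tokens = best_hyp.ys[context_size:]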


def fast_beam_search_one_best(
    model: nn.Module,
    encoder_out: torch.Tensor,
    processed_lens: torch.Tensor,
    streams: List[DecodeStream],
    beam: float,
    max_states: int,
    max_contexts: int,
) -> None:
    """It limits the maximum number of symbols per frame to 1.

    A lattice is first generated by FSA-based beam search, then we get the
    recognition results by applying shortest path on the lattice.

    Args:
      model:
        An instance of `Transducer`.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder.
      processed_lens:
        A tensor of shape (N,) containing the number of processed frames
        in `encoder_out` before padding.
      streams:
        A list of DecodeStream objects.
      beam:
        Beam value, similar to the beam used in Kaldi.
      max_states:
        Max states per stream per frame.
      max_contexts:
        Max contexts per stream per frame.
    """
    assert encoder_out.ndim == 3
    B, T, C = encoder_out.shape
    assert B == len(streams)

    context_size = model.decoder.context_size
    vocab_size = model.decoder.vocab_size

    config = k2.RnntDecodingConfig(
        vocab_size=vocab_size,
        decoder_history_len=context_size,
        beam=beam,
        max_contexts=max_contexts,
        max_states=max_states,
    )
    individual_streams = [stream.rnnt_decoding_stream for stream in streams]
    decoding_streams = k2.RnntDecodingStreams(individual_streams, config)

    for t in range(T):
        # shape is a RaggedShape of shape (B, context)
        # contexts is a Tensor of shape (shape.NumElements(), context_size)
        shape, contexts = decoding_streams.get_contexts()
        # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
        contexts = contexts.to(torch.int64)
        # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
        decoder_out = model.decoder(contexts, need_pad=False)
        # current_encoder_out is of shape
        # (shape.NumElements(), 1, joiner_dim)
        # fmt: off
        current_encoder_out = torch.index_select(
            encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
        )
        # fmt: on
        logits = model.joiner(
            current_encoder_out.unsqueeze(2),
            decoder_out.unsqueeze(1),
        )
        logits = logits.squeeze(1).squeeze(1)
        log_probs = logits.log_softmax(dim=-1)
        decoding_streams.advance(log_probs)

    decoding_streams.terminate_and_flush_to_streams()

    lattice = decoding_streams.format_output(processed_lens.tolist())
    best_path = one_best_decoding(lattice)
    hyp_tokens = get_texts(best_path)

    for i in range(B):
        streams[i].hyp = hyp_tokens[i]
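For orientation, a sketch of how a streaming driver might dispatch to these three functions per chunk of encoder output. The names params, decoding_method, and the attribute spellings below are assumptions for illustration, not code from this PR:

# Hypothetical dispatch in a streaming decode loop.
if params.decoding_method == "greedy_search":
    greedy_search(model=model, encoder_out=encoder_out, streams=streams)
elif params.decoding_method == "modified_beam_search":
    modified_beam_search(
        model=model,
        encoder_out=encoder_out,
        streams=streams,
        beam=params.beam_size,
    )
elif params.decoding_method == "fast_beam_search":
    fast_beam_search_one_best(
        model=model,
        encoder_out=encoder_out,
        processed_lens=processed_lens,
        streams=streams,
        beam=params.beam,
        max_states=params.max_states,
        max_contexts=params.max_contexts,
    )
else:
    raise ValueError(f"Unsupported decoding method: {params.decoding_method}")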