[ctl] simplified ctl #2483

Merged
merged 2 commits into from on Apr 17, 2024
wenet/ctl_model/encoder.py — 157 changes: 47 additions & 110 deletions
@@ -15,12 +15,11 @@
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition."""
from typing import Tuple
from typing import Optional, Tuple

import torch

from wenet.utils.mask import make_pad_mask
from wenet.utils.mask import add_optional_chunk_mask
from wenet.transformer.encoder import TransformerEncoder, ConformerEncoder


@@ -44,6 +43,21 @@ def __init__(
use_dynamic_chunk: bool = False,
global_cmvn: torch.nn.Module = None,
use_dynamic_left_chunk: bool = False,
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
activation_type: str = "relu",
gradient_checkpointing: bool = False,
use_sdpa: bool = False,
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
selfattention_layer_type: str = "selfattn",
mlp_type: str = 'position_wise_feed_forward',
mlp_bias: bool = True,
n_expert: int = 8,
n_expert_activated: int = 2,
):
""" Construct DualTransformerEncoder
Support both the full context mode and the streaming mode separately
@@ -53,56 +67,11 @@
positional_dropout_rate, attention_dropout_rate,
input_layer, pos_enc_layer_type, normalize_before,
static_chunk_size, use_dynamic_chunk, global_cmvn,
use_dynamic_left_chunk)

def forward(
self,
xs: torch.Tensor,
xs_lens: torch.Tensor,
decoding_chunk_size: int = 0,
num_decoding_left_chunks: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Embed positions in tensor.

Args:
xs: padded input tensor (B, T, D)
xs_lens: input length (B)
decoding_chunk_size: decoding chunk size for dynamic chunk
0: default for training, use random dynamic chunk.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
num_decoding_left_chunks: number of left chunks, this is for decoding,
the chunk size is decoding_chunk_size.
>=0: use num_decoding_left_chunks
<0: use all left chunks
Returns:
encoder output tensor xs, and subsampled masks
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
masks: torch.Tensor batch padding mask after subsample
(B, 1, T' ~= T/subsample_rate)
"""
T = xs.size(1)
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
xs, pos_emb, masks = self.embed(xs, masks)
mask_pad = masks # (B, 1, T/subsample_rate)
chunk_masks = add_optional_chunk_mask(xs,
masks,
self.use_dynamic_chunk,
self.use_dynamic_left_chunk,
decoding_chunk_size,
self.static_chunk_size,
num_decoding_left_chunks,
enable_full_context=False)
for layer in self.encoders:
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
if self.normalize_before:
xs = self.after_norm(xs)
# Here we assume the mask is not changed in encoder layers, so just
# return the masks before encoder layers, and the masks will be used
# for cross attention with decoder later
return xs, masks
use_dynamic_left_chunk, query_bias, key_bias,
value_bias, activation_type, gradient_checkpointing,
use_sdpa, layer_norm_type, norm_eps, n_kv_head,
head_dim, selfattention_layer_type, mlp_type,
mlp_bias, n_expert, n_expert_activated)

def forward_full(
self,
@@ -152,68 +121,36 @@ def __init__(
cnn_module_kernel: int = 15,
causal: bool = False,
cnn_module_norm: str = "batch_norm",
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
conv_bias: bool = True,
gradient_checkpointing: bool = False,
use_sdpa: bool = False,
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
mlp_type: str = 'position_wise_feed_forward',
mlp_bias: bool = True,
n_expert: int = 8,
n_expert_activated: int = 2,
):
""" Construct DualConformerEncoder
Support both the full context mode and the streaming mode separately
"""
super().__init__(input_size, output_size, attention_heads,
linear_units, num_blocks, dropout_rate,
positional_dropout_rate, attention_dropout_rate,
input_layer, pos_enc_layer_type, normalize_before,
static_chunk_size, use_dynamic_chunk, global_cmvn,
use_dynamic_left_chunk, positionwise_conv_kernel_size,
macaron_style, selfattention_layer_type,
activation_type, use_cnn_module, cnn_module_kernel,
causal, cnn_module_norm)

def forward(
self,
xs: torch.Tensor,
xs_lens: torch.Tensor,
decoding_chunk_size: int = 0,
num_decoding_left_chunks: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Embed positions in tensor.

Args:
xs: padded input tensor (B, T, D)
xs_lens: input length (B)
decoding_chunk_size: decoding chunk size for dynamic chunk
0: default for training, use random dynamic chunk.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
num_decoding_left_chunks: number of left chunks, this is for decoding,
the chunk size is decoding_chunk_size.
>=0: use num_decoding_left_chunks
<0: use all left chunks
Returns:
encoder output tensor xs, and subsampled masks
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
masks: torch.Tensor batch padding mask after subsample
(B, 1, T' ~= T/subsample_rate)
"""
T = xs.size(1)
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
xs, pos_emb, masks = self.embed(xs, masks)
mask_pad = masks # (B, 1, T/subsample_rate)
chunk_masks = add_optional_chunk_mask(xs,
masks,
self.use_dynamic_chunk,
self.use_dynamic_left_chunk,
decoding_chunk_size,
self.static_chunk_size,
num_decoding_left_chunks,
enable_full_context=False)
for layer in self.encoders:
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
if self.normalize_before:
xs = self.after_norm(xs)
# Here we assume the mask is not changed in encoder layers, so just
# return the masks before encoder layers, and the masks will be used
# for cross attention with decoder later
return xs, masks
super().__init__(
input_size, output_size, attention_heads, linear_units, num_blocks,
dropout_rate, positional_dropout_rate, attention_dropout_rate,
input_layer, pos_enc_layer_type, normalize_before,
static_chunk_size, use_dynamic_chunk, global_cmvn,
use_dynamic_left_chunk, positionwise_conv_kernel_size,
macaron_style, selfattention_layer_type, activation_type,
use_cnn_module, cnn_module_kernel, causal, cnn_module_norm,
query_bias, key_bias, value_bias, conv_bias,
gradient_checkpointing, use_sdpa, layer_norm_type, norm_eps,
n_kv_head, head_dim, mlp_type, mlp_bias, n_expert,
n_expert_activated)

def forward_full(
self,
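Taken together, the diff removes the duplicated chunk-masked forward() overrides from both dual encoders and instead passes the newer attention/MLP options (query_bias, use_sdpa, layer_norm_type, n_kv_head, mlp_type, n_expert, ...) straight through to TransformerEncoder / ConformerEncoder. The sketch below is an illustrative usage example, not part of the PR: the constructor arguments follow the signatures shown above, the shapes and hyperparameters are made up, and forward_full is assumed to accept (xs, xs_lens) like the base forward().

```python
# Illustrative only: constructing and calling the simplified dual encoder
# after this change. Values and the forward_full call are assumptions based
# on the signatures visible in this diff, not part of the PR itself.
import torch

from wenet.ctl_model.encoder import DualConformerEncoder

encoder = DualConformerEncoder(
    input_size=80,            # e.g. 80-dim fbank features (illustrative)
    output_size=256,
    attention_heads=4,
    linear_units=2048,
    num_blocks=12,
    use_dynamic_chunk=True,   # dynamic-chunk (streaming-style) training
    # Options newly passed through to ConformerEncoder by this PR:
    use_sdpa=True,
    layer_norm_type='layer_norm',
    mlp_type='position_wise_feed_forward',
)

xs = torch.randn(2, 100, 80)          # (B, T, D) padded features
xs_lens = torch.tensor([100, 80])     # valid frame count per utterance

# Chunk-masked pass: now inherited from ConformerEncoder, since the
# overridden forward() in this file has been removed.
ys, masks = encoder(xs, xs_lens,
                    decoding_chunk_size=16,
                    num_decoding_left_chunks=-1)

# Full-context pass: still provided by forward_full() in this file
# (assumed to take the same (xs, xs_lens) arguments as forward()).
ys_full, masks_full = encoder.forward_full(xs, xs_lens)
```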