Skip to content

Commit

Permalink
[ctl] simplified ctl (#2483)
Browse files Browse the repository at this point in the history
* [ctl] simplified ctl

* [ctl] unify
  • Loading branch information
Mddct authored Apr 17, 2024
1 parent 6aae6ef commit 1f0fba4
Showing 1 changed file with 47 additions and 110 deletions.
157 changes: 47 additions & 110 deletions wenet/ctl_model/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition."""
from typing import Tuple
from typing import Optional, Tuple

import torch

from wenet.utils.mask import make_pad_mask
from wenet.utils.mask import add_optional_chunk_mask
from wenet.transformer.encoder import TransformerEncoder, ConformerEncoder


Expand All @@ -44,6 +43,21 @@ def __init__(
use_dynamic_chunk: bool = False,
global_cmvn: torch.nn.Module = None,
use_dynamic_left_chunk: bool = False,
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
activation_type: str = "relu",
gradient_checkpointing: bool = False,
use_sdpa: bool = False,
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
selfattention_layer_type: str = "selfattn",
mlp_type: str = 'position_wise_feed_forward',
mlp_bias: bool = True,
n_expert: int = 8,
n_expert_activated: int = 2,
):
""" Construct DualTransformerEncoder
Support both the full context mode and the streaming mode separately
Expand All @@ -53,56 +67,11 @@ def __init__(
positional_dropout_rate, attention_dropout_rate,
input_layer, pos_enc_layer_type, normalize_before,
static_chunk_size, use_dynamic_chunk, global_cmvn,
use_dynamic_left_chunk)

def forward(
self,
xs: torch.Tensor,
xs_lens: torch.Tensor,
decoding_chunk_size: int = 0,
num_decoding_left_chunks: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Embed positions in tensor.
Args:
xs: padded input tensor (B, T, D)
xs_lens: input length (B)
decoding_chunk_size: decoding chunk size for dynamic chunk
0: default for training, use random dynamic chunk.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
num_decoding_left_chunks: number of left chunks, this is for decoding,
the chunk size is decoding_chunk_size.
>=0: use num_decoding_left_chunks
<0: use all left chunks
Returns:
encoder output tensor xs, and subsampled masks
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
masks: torch.Tensor batch padding mask after subsample
(B, 1, T' ~= T/subsample_rate)
"""
T = xs.size(1)
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
xs, pos_emb, masks = self.embed(xs, masks)
mask_pad = masks # (B, 1, T/subsample_rate)
chunk_masks = add_optional_chunk_mask(xs,
masks,
self.use_dynamic_chunk,
self.use_dynamic_left_chunk,
decoding_chunk_size,
self.static_chunk_size,
num_decoding_left_chunks,
enable_full_context=False)
for layer in self.encoders:
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
if self.normalize_before:
xs = self.after_norm(xs)
# Here we assume the mask is not changed in encoder layers, so just
# return the masks before encoder layers, and the masks will be used
# for cross attention with decoder later
return xs, masks
use_dynamic_left_chunk, query_bias, key_bias,
value_bias, activation_type, gradient_checkpointing,
use_sdpa, layer_norm_type, norm_eps, n_kv_head,
head_dim, selfattention_layer_type, mlp_type,
mlp_bias, n_expert, n_expert_activated)

def forward_full(
self,
Expand Down Expand Up @@ -152,68 +121,36 @@ def __init__(
cnn_module_kernel: int = 15,
causal: bool = False,
cnn_module_norm: str = "batch_norm",
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
conv_bias: bool = True,
gradient_checkpointing: bool = False,
use_sdpa: bool = False,
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
mlp_type: str = 'position_wise_feed_forward',
mlp_bias: bool = True,
n_expert: int = 8,
n_expert_activated: int = 2,
):
""" Construct DualConformerEncoder
Support both the full context mode and the streaming mode separately
"""
super().__init__(input_size, output_size, attention_heads,
linear_units, num_blocks, dropout_rate,
positional_dropout_rate, attention_dropout_rate,
input_layer, pos_enc_layer_type, normalize_before,
static_chunk_size, use_dynamic_chunk, global_cmvn,
use_dynamic_left_chunk, positionwise_conv_kernel_size,
macaron_style, selfattention_layer_type,
activation_type, use_cnn_module, cnn_module_kernel,
causal, cnn_module_norm)

def forward(
self,
xs: torch.Tensor,
xs_lens: torch.Tensor,
decoding_chunk_size: int = 0,
num_decoding_left_chunks: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Embed positions in tensor.
Args:
xs: padded input tensor (B, T, D)
xs_lens: input length (B)
decoding_chunk_size: decoding chunk size for dynamic chunk
0: default for training, use random dynamic chunk.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
num_decoding_left_chunks: number of left chunks, this is for decoding,
the chunk size is decoding_chunk_size.
>=0: use num_decoding_left_chunks
<0: use all left chunks
Returns:
encoder output tensor xs, and subsampled masks
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
masks: torch.Tensor batch padding mask after subsample
(B, 1, T' ~= T/subsample_rate)
"""
T = xs.size(1)
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
xs, pos_emb, masks = self.embed(xs, masks)
mask_pad = masks # (B, 1, T/subsample_rate)
chunk_masks = add_optional_chunk_mask(xs,
masks,
self.use_dynamic_chunk,
self.use_dynamic_left_chunk,
decoding_chunk_size,
self.static_chunk_size,
num_decoding_left_chunks,
enable_full_context=False)
for layer in self.encoders:
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
if self.normalize_before:
xs = self.after_norm(xs)
# Here we assume the mask is not changed in encoder layers, so just
# return the masks before encoder layers, and the masks will be used
# for cross attention with decoder later
return xs, masks
super().__init__(
input_size, output_size, attention_heads, linear_units, num_blocks,
dropout_rate, positional_dropout_rate, attention_dropout_rate,
input_layer, pos_enc_layer_type, normalize_before,
static_chunk_size, use_dynamic_chunk, global_cmvn,
use_dynamic_left_chunk, positionwise_conv_kernel_size,
macaron_style, selfattention_layer_type, activation_type,
use_cnn_module, cnn_module_kernel, causal, cnn_module_norm,
query_bias, key_bias, value_bias, conv_bias,
gradient_checkpointing, use_sdpa, layer_norm_type, norm_eps,
n_kv_head, head_dim, mlp_type, mlp_bias, n_expert,
n_expert_activated)

def forward_full(
self,
Expand Down

0 comments on commit 1f0fba4

Please sign in to comment.