Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[e_branchformer] simplified e_branchformer #2484

Merged
merged 8 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/aishell/s0/conf/train_ebranchformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ encoder_conf:
activation_type: 'swish'
causal: false
pos_enc_layer_type: 'rel_pos'
attention_layer_type: 'rel_selfattn'
selfattention_layer_type: 'rel_selfattn'

# decoder related
decoder: transformer
Expand Down
307 changes: 68 additions & 239 deletions wenet/e_branchformer/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,19 @@
"""Encoder definition."""

import torch
import torch.nn as nn
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Union

from wenet.e_branchformer.encoder_layer import EBranchformerEncoderLayer
from wenet.branchformer.cgmlp import ConvolutionalGatingMLP
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.utils.mask import make_pad_mask
from wenet.utils.mask import add_optional_chunk_mask
from wenet.transformer.encoder import ConformerEncoder
from wenet.utils.class_utils import (
WENET_ATTENTION_CLASSES,
WENET_EMB_CLASSES,
WENET_SUBSAMPLE_CLASSES,
WENET_ACTIVATION_CLASSES,
WENET_ATTENTION_CLASSES,
WENET_MLP_CLASSES,
)


class EBranchformerEncoder(nn.Module):
class EBranchformerEncoder(ConformerEncoder):
"""E-Branchformer encoder module."""

def __init__(
Expand All @@ -42,20 +38,18 @@ def __init__(
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
attention_layer_type: str = "rel_selfattn",
selfattention_layer_type: str = "rel_selfattn",
pos_enc_layer_type: str = "rel_pos",
activation_type: str = "swish",
cgmlp_linear_units: int = 2048,
cgmlp_conv_kernel: int = 31,
use_linear_after_conv: bool = False,
gate_activation: str = "identity",
merge_method: str = "concat",
num_blocks: int = 12,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: Optional[str] = "conv2d",
padding_idx: int = -1,
input_layer: str = "conv2d",
stochastic_depth_rate: Union[float, List[float]] = 0.0,
static_chunk_size: int = 0,
use_dynamic_chunk: bool = False,
Expand All @@ -65,23 +59,65 @@ def __init__(
merge_conv_kernel: int = 3,
use_ffn: bool = True,
macaron_style: bool = True,
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
conv_bias: bool = True,
gradient_checkpointing: bool = False,
use_sdpa: bool = False,
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
mlp_type: str = 'position_wise_feed_forward',
mlp_bias: bool = True,
n_expert: int = 8,
n_expert_activated: int = 2,
):
super().__init__()
activation = WENET_ACTIVATION_CLASSES[activation_type]()
self._output_size = output_size

self.embed = WENET_SUBSAMPLE_CLASSES[input_layer](
input_size,
output_size,
dropout_rate,
WENET_EMB_CLASSES[pos_enc_layer_type](output_size,
positional_dropout_rate),
)
super().__init__(input_size,
output_size,
attention_heads,
linear_units,
num_blocks,
dropout_rate,
positional_dropout_rate,
attention_dropout_rate,
input_layer,
pos_enc_layer_type,
True,
static_chunk_size,
use_dynamic_chunk,
global_cmvn,
use_dynamic_left_chunk,
1,
macaron_style,
selfattention_layer_type,
activation_type,
query_bias=query_bias,
key_bias=key_bias,
value_bias=value_bias,
conv_bias=conv_bias,
gradient_checkpointing=gradient_checkpointing,
use_sdpa=use_sdpa,
layer_norm_type=layer_norm_type,
norm_eps=norm_eps,
n_kv_head=n_kv_head,
head_dim=head_dim,
mlp_type=mlp_type,
mlp_bias=mlp_bias,
n_expert=n_expert,
n_expert_activated=n_expert_activated)

encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
query_bias,
key_bias,
value_bias,
use_sdpa,
n_kv_head,
head_dim,
)

cgmlp_layer = ConvolutionalGatingMLP
Expand All @@ -90,12 +126,16 @@ def __init__(
gate_activation, causal)

# feed-forward module definition
positionwise_layer = PositionwiseFeedForward
mlp_class = WENET_MLP_CLASSES[mlp_type]
activation = WENET_ACTIVATION_CLASSES[activation_type]()
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
activation,
mlp_bias,
n_expert,
n_expert_activated,
)

if isinstance(stochastic_depth_rate, float):
Expand All @@ -108,226 +148,15 @@ def __init__(
self.encoders = torch.nn.ModuleList([
EBranchformerEncoderLayer(
output_size,
WENET_ATTENTION_CLASSES[attention_layer_type](
WENET_ATTENTION_CLASSES[selfattention_layer_type](
*encoder_selfattn_layer_args),
cgmlp_layer(*cgmlp_layer_args),
positionwise_layer(
*positionwise_layer_args) if use_ffn else None,
positionwise_layer(*positionwise_layer_args)
mlp_class(*positionwise_layer_args) if use_ffn else None,
mlp_class(*positionwise_layer_args)
if use_ffn and macaron_style else None,
dropout_rate,
merge_conv_kernel=merge_conv_kernel,
causal=causal,
stochastic_depth_rate=stochastic_depth_rate[lnum],
) for lnum in range(num_blocks)
])

self.after_norm = nn.LayerNorm(output_size)
self.static_chunk_size = static_chunk_size
self.global_cmvn = global_cmvn
self.use_dynamic_chunk = use_dynamic_chunk
self.use_dynamic_left_chunk = use_dynamic_left_chunk

def output_size(self) -> int:
    """Return the feature dimension of the encoder output."""
    dim = self._output_size
    return dim

def forward(
    self,
    xs: torch.Tensor,
    ilens: torch.Tensor,
    decoding_chunk_size: int = 0,
    num_decoding_left_chunks: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Calculate forward propagation.

    Args:
        xs (torch.Tensor): Input tensor (B, T, D).
        ilens (torch.Tensor): Input length (#batch).
        decoding_chunk_size: decoding chunk size for dynamic chunk
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        num_decoding_left_chunks: number of left chunks, this is for
            decoding, the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks

    Returns:
        encoder output tensor xs, and subsampled masks
        xs: padded output tensor (B, T' ~= T/subsample_rate, D)
        masks: torch.Tensor batch padding mask after subsample
            (B, 1, T' ~= T/subsample_rate)
    """
    # Build the padding mask from the raw (pre-subsampling) lengths.
    max_len = xs.size(1)
    masks = ~make_pad_mask(ilens, max_len).unsqueeze(1)  # (B, 1, T)

    # Optional global CMVN normalization before subsampling.
    if self.global_cmvn is not None:
        xs = self.global_cmvn(xs)

    # Subsampling/embedding; masks are downsampled accordingly.
    xs, pos_emb, masks = self.embed(xs, masks)
    mask_pad = masks  # (B, 1, T/subsample_rate)

    # Chunk mask controls the attention span (static/dynamic chunking).
    chunk_masks = add_optional_chunk_mask(xs, masks,
                                          self.use_dynamic_chunk,
                                          self.use_dynamic_left_chunk,
                                          decoding_chunk_size,
                                          self.static_chunk_size,
                                          num_decoding_left_chunks)

    # Run every E-Branchformer layer in sequence.
    for block in self.encoders:
        xs, chunk_masks, _, _ = block(xs, chunk_masks, pos_emb, mask_pad)

    xs = self.after_norm(xs)
    # The mask is assumed unchanged by the encoder layers; the
    # pre-encoder masks are returned for later decoder cross-attention.
    return xs, masks

def forward_chunk(
    self,
    xs: torch.Tensor,
    offset: int,
    required_cache_size: int,
    att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
    cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
    att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """ Forward just one chunk

    Args:
        xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
            where `time == (chunk_size - 1) * subsample_rate + \
                    subsample.right_context + 1`
        offset (int): current offset in encoder output time stamp
        required_cache_size (int): cache size required for next chunk
            computation
            >=0: actual cache size
            <0: means all history cache is required
        att_cache (torch.Tensor): cache tensor for KEY & VALUE in
            transformer/conformer attention, with shape
            (elayers, head, cache_t1, d_k * 2), where
            `head * d_k == hidden-dim` and
            `cache_t1 == chunk_size * num_decoding_left_chunks`.
        cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
            (elayers, b=1, hidden-dim, cache_t2), where
            `cache_t2 == cnn.lorder - 1`

    Returns:
        torch.Tensor: output of current input xs,
            with shape (b=1, chunk_size, hidden-dim).
        torch.Tensor: new attention cache required for next chunk, with
            dynamic shape (elayers, head, ?, d_k * 2)
            depending on required_cache_size.
        torch.Tensor: new conformer cnn cache required for next chunk, with
            same shape as the original cnn_cache.

    """
    # Streaming decode works on a single utterance at a time.
    assert xs.size(0) == 1
    # tmp_masks is just for interface compatibility
    tmp_masks = torch.ones(1,
                           xs.size(1),
                           device=xs.device,
                           dtype=torch.bool)
    tmp_masks = tmp_masks.unsqueeze(1)
    if self.global_cmvn is not None:
        xs = self.global_cmvn(xs)
    # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
    xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
    # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
    elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
    chunk_size = xs.size(1)
    attention_key_size = cache_t1 + chunk_size
    # Positional encoding must cover the cached history plus this chunk,
    # hence the negative offset shift by cache_t1.
    pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
                                           size=attention_key_size)
    # Determine where the cache kept for the NEXT chunk starts.
    if required_cache_size < 0:
        next_cache_start = 0
    elif required_cache_size == 0:
        next_cache_start = attention_key_size
    else:
        next_cache_start = max(attention_key_size - required_cache_size, 0)
    r_att_cache = []
    r_cnn_cache = []
    for i, layer in enumerate(self.encoders):
        # NOTE(xcsong): Before layer.forward
        #   shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
        #   shape(cnn_cache[i])       is (b=1, hidden-dim, cache_t2)
        # An empty cache tensor (elayers == 0) is passed through as-is on
        # the very first chunk.
        xs, _, new_att_cache, new_cnn_cache = layer(
            xs,
            att_mask,
            pos_emb,
            att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
            cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache)
        # NOTE(xcsong): After layer.forward
        #   shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
        #   shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
        r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
        r_cnn_cache.append(new_cnn_cache.unsqueeze(0))

    xs = self.after_norm(xs)

    # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
    #   ? may be larger than cache_t1, it depends on required_cache_size
    r_att_cache = torch.cat(r_att_cache, dim=0)
    # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
    r_cnn_cache = torch.cat(r_cnn_cache, dim=0)

    return (xs, r_att_cache, r_cnn_cache)

def forward_chunk_by_chunk(
    self,
    xs: torch.Tensor,
    decoding_chunk_size: int,
    num_decoding_left_chunks: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """ Forward input chunk by chunk with chunk_size like a streaming
        fashion

    Here we should pay special attention to computation cache in the
    streaming style forward chunk by chunk. Three things should be taken
    into account for computation in the current network:
        1. transformer/conformer encoder layers output cache
        2. convolution in conformer
        3. convolution in subsampling

    However, we don't implement subsampling cache for:
        1. We can control subsampling module to output the right result by
           overlapping input instead of cache left context, even though it
           wastes some computation, but subsampling only takes a very
           small fraction of computation in the whole model.
        2. Typically, there are several convolution layers with subsampling
           in subsampling module, it is tricky and complicated to do cache
           with different convolution layers with different subsampling
           rate.
        3. Currently, nn.Sequential is used to stack all the convolution
           layers in subsampling, we need to rewrite it to make it work
           with cache, which is not preferred.
    Args:
        xs (torch.Tensor): (1, max_len, dim)
        chunk_size (int): decoding chunk size
    """
    assert decoding_chunk_size > 0
    # The model is trained by static or dynamic chunk
    assert self.static_chunk_size > 0 or self.use_dynamic_chunk
    subsampling = self.embed.subsampling_rate
    context = self.embed.right_context + 1  # Add current frame
    # stride: how many raw frames we advance per decoded chunk;
    # decoding_window: raw frames fed per call (overlap covers context).
    stride = subsampling * decoding_chunk_size
    decoding_window = (decoding_chunk_size - 1) * subsampling + context
    num_frames = xs.size(1)
    # Empty caches on the first chunk; forward_chunk refills them.
    att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
    cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
    outputs = []
    offset = 0
    required_cache_size = decoding_chunk_size * num_decoding_left_chunks

    # Feed forward overlap input step by step
    for cur in range(0, num_frames - context + 1, stride):
        end = min(cur + decoding_window, num_frames)
        chunk_xs = xs[:, cur:end, :]
        (y, att_cache,
         cnn_cache) = self.forward_chunk(chunk_xs, offset,
                                         required_cache_size, att_cache,
                                         cnn_cache)
        outputs.append(y)
        # Offset advances in post-subsampling (encoder output) frames.
        offset += y.size(1)
    ys = torch.cat(outputs, 1)
    # Whole input is one utterance, so the returned mask is all-True.
    masks = torch.ones((1, 1, ys.size(1)),
                       device=ys.device,
                       dtype=torch.bool)
    return ys, masks
Loading