Commit

e-branchformer works
Mddct committed Apr 17, 2024
1 parent 89d99c6 commit 63d9a3b
Showing 3 changed files with 36 additions and 52 deletions.
9 changes: 6 additions & 3 deletions wenet/branchformer/encoder.py
@@ -126,9 +126,12 @@ def __init__(
WENET_ATTENTION_CLASSES[selfattention_layer_type](
*encoder_selfattn_layer_args) if use_attn else None,
cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None,
- dropout_rate, merge_method, cgmlp_weight[lnum],
- attn_branch_drop_rate[lnum], stochastic_depth_rate[lnum],
- gradient_checkpointing) for lnum in range(num_blocks)
+ dropout_rate,
+ merge_method,
+ cgmlp_weight[lnum],
+ attn_branch_drop_rate[lnum],
+ stochastic_depth_rate[lnum],
+ ) for lnum in range(num_blocks)
])

@torch.jit.ignore(drop=True)
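The per-layer arguments above are indexed by lnum, which assumes every float-or-list hyperparameter has already been expanded to one value per block (the length check quoted in the next file shows exactly this constraint for stochastic_depth_rate). A minimal self-contained sketch of that normalization; the helper name per_block is hypothetical, and the real WeNet code may do this inline:

from typing import List, Union

def per_block(value: Union[float, List[float]], num_blocks: int) -> List[float]:
    """Expand a scalar hyperparameter to one value per encoder block."""
    if isinstance(value, float):
        return [value] * num_blocks
    if len(value) != num_blocks:
        raise ValueError(f"Length of value ({len(value)}) "
                         f"should be equal to num_blocks ({num_blocks})")
    return list(value)

# e.g. per_block(0.1, 3) -> [0.1, 0.1, 0.1], so cgmlp_weight[lnum],
# attn_branch_drop_rate[lnum] and stochastic_depth_rate[lnum] are all valid.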
40 changes: 25 additions & 15 deletions wenet/e_branchformer/encoder.py
@@ -18,6 +18,7 @@

import torch
from typing import List, Optional, Union
+ from wenet.branchformer.encoder import LayerDropModuleList

from wenet.e_branchformer.encoder_layer import EBranchformerEncoderLayer
from wenet.branchformer.cgmlp import ConvolutionalGatingMLP
@@ -145,18 +146,27 @@ def __init__(
f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
f"should be equal to num_blocks ({num_blocks})")

- self.encoders = torch.nn.ModuleList([
- EBranchformerEncoderLayer(
- output_size,
- WENET_ATTENTION_CLASSES[selfattention_layer_type](
- *encoder_selfattn_layer_args),
- cgmlp_layer(*cgmlp_layer_args),
- mlp_class(*positionwise_layer_args) if use_ffn else None,
- mlp_class(*positionwise_layer_args)
- if use_ffn and macaron_style else None,
- dropout_rate,
- merge_conv_kernel=merge_conv_kernel,
- causal=causal,
- stochastic_depth_rate=stochastic_depth_rate[lnum],
- ) for lnum in range(num_blocks)
- ])
+ self.encoders = LayerDropModuleList(
+ p=stochastic_depth_rate,
+ modules=[
+ EBranchformerEncoderLayer(
+ output_size,
+ WENET_ATTENTION_CLASSES[selfattention_layer_type](
+ *encoder_selfattn_layer_args),
+ cgmlp_layer(*cgmlp_layer_args),
+ mlp_class(*positionwise_layer_args) if use_ffn else None,
+ mlp_class(*positionwise_layer_args)
+ if use_ffn and macaron_style else None,
+ dropout_rate,
+ merge_conv_kernel=merge_conv_kernel,
+ causal=causal,
+ stochastic_depth_rate=stochastic_depth_rate[lnum],
+ ) for lnum in range(num_blocks)
+ ])

+ @torch.jit.ignore(drop=True)
+ def forward_layers_checkpointed(self, xs: torch.Tensor,
+ chunk_masks: torch.Tensor,
+ pos_emb: torch.Tensor,
+ mask_pad: torch.Tensor) -> torch.Tensor:
+ return self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
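The key change above is that the layers now live in a LayerDropModuleList, imported from wenet/branchformer/encoder.py but not shown in this diff, so the stochastic-depth skip decision moves from the individual layer into the container that iterates over the layers. A minimal sketch of the fairseq-style pattern the import suggests; the actual WeNet class may differ in detail:

from typing import List, Optional

import torch

class LayerDropModuleList(torch.nn.ModuleList):
    """ModuleList that, in training mode, skips layer i with probability p[i]."""

    def __init__(self,
                 p: List[float],
                 modules: Optional[List[torch.nn.Module]] = None):
        super().__init__(modules)
        assert len(p) == len(self), "need one drop probability per layer"
        self.p = p

    def __iter__(self):
        # One uniform sample per layer; at eval time every layer is kept.
        keep = torch.empty(len(self)).uniform_()
        for i, layer in enumerate(super().__iter__()):
            if not self.training or keep[i].item() > self.p[i]:
                yield layer

With a container like this, simply iterating over self.encoders in forward_layers applies layer drop, which is why the per-layer skip logic is deleted from the encoder layer in the next file.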
39 changes: 5 additions & 34 deletions wenet/e_branchformer/encoder_layer.py
@@ -20,8 +20,10 @@
import torch.nn as nn
from typing import Optional, Tuple

+ from wenet.branchformer.encoder_layer import BranchformerEncoderLayer
+
- class EBranchformerEncoderLayer(torch.nn.Module):

+ class EBranchformerEncoderLayer(BranchformerEncoderLayer):
"""E-Branchformer encoder layer module.
Args:
@@ -88,47 +90,16 @@ def __init__(
self.merge_proj = torch.nn.Linear(size + size, size)
self.stochastic_depth_rate = stochastic_depth_rate

- def forward(
+ def _forward(
self,
x: torch.Tensor,
mask: torch.Tensor,
pos_emb: torch.Tensor,
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ stoch_layer_coeff: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute encoded features.
Args:
x (Union[Tuple, torch.Tensor]): Input tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, time, time).
pos_emb (torch.Tensor): positional encoding, must not be None
for BranchformerEncoderLayer.
mask_pad (torch.Tensor): batch padding mask used for conv module.
(#batch, 1,time), (0, 0, 0) means fake mask.
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
cnn_cache (torch.Tensor): Convolution cache in cgmlp layer
(#batch=1, size, cache_t2)
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time, time.
torch.Tensor: att_cache tensor,
(#batch=1, head, cache_t1 + time, d_k * 2).
torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
"""

stoch_layer_coeff = 1.0
skip_layer = False
# with stochastic depth, residual connection `x + f(x)` becomes
# `x <- x + 1 / (1 - p) * f(x)` at training time.
if self.training and self.stochastic_depth_rate > 0:
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

if skip_layer:
return x, mask, att_cache, cnn_cache

if self.feed_forward_macaron is not None:
residual = x
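The deleted block above implemented the standard stochastic-depth residual: at training time x + f(x) becomes x + 1/(1 - p) * f(x), and the whole branch is skipped with probability p. After this commit the skip presumably happens in LayerDropModuleList, while the surviving _forward receives only the 1/(1 - p) scaling through the new stoch_layer_coeff argument, most likely set by the parent BranchformerEncoderLayer.forward. A minimal self-contained sketch of the pattern the deleted code described (function name and arguments are illustrative, not the WeNet API):

import torch

def stochastic_depth_residual(x: torch.Tensor,
                              branch_fn,
                              p: float,
                              training: bool) -> torch.Tensor:
    """Residual connection x + f(x) with stochastic depth rate p."""
    if training and p > 0:
        if torch.rand(1).item() < p:
            return x  # skip the whole branch for this step
        # rescale the kept branch so the expected output matches the no-drop case
        return x + (1.0 / (1.0 - p)) * branch_fn(x)
    return x + branch_fn(x)

# usage sketch:
# y = stochastic_depth_residual(x, lambda t: feed_forward(t), p=0.1, training=True)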
