Add E-Branchformer module
TeaPoly committed Sep 15, 2023
1 parent 80e9a05 commit 89962d1
Showing 6 changed files with 676 additions and 0 deletions.
15 changes: 15 additions & 0 deletions examples/aishell/s0/README.md
@@ -203,4 +203,19 @@
| attention rescoring | 4.81 |
| LM + attention rescoring | 4.46 |

## E-Branchformer Result

* Feature info: using fbank feature, dither=1.0, cmvn, online speed perturb
* Model info:
    * Model Params: 47,570,132
    * Num Encoder Layer: 17
    * CNN Kernel Size: 31
* Training info: lr 0.001, weight_decay: 0.000001, batch size 16, 4 gpu, acc_grad 1, 240 epochs
* Decoding info: ctc_weight 0.3, average_num 30

| decoding mode | CER |
| ---------------------- | ---- |
| attention decoder | 4.70 |
| ctc greedy search | 4.78 |
| ctc prefix beam search | 4.78 |
| attention rescoring | 4.41 |
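
For context on the last row, attention rescoring re-ranks the n-best hypotheses produced by CTC prefix beam search with the attention decoder, adding a CTC score weighted by the same `ctc_weight` used at decoding time. A minimal sketch of that combination (the `hyps` and `decoder_score` names are hypothetical helpers, not the actual WeNet recognition code):

```python
# Sketch: re-rank CTC n-best hypotheses with the attention decoder.
# `hyps` is assumed to be a list of (token_ids, ctc_score) pairs from CTC
# prefix beam search, and `decoder_score(token_ids)` is assumed to return the
# attention decoder's log-probability for that sequence (both hypothetical).
def attention_rescoring(hyps, decoder_score, ctc_weight=0.3):
    best_tokens, best_score = None, float("-inf")
    for token_ids, ctc_score in hyps:
        score = decoder_score(token_ids) + ctc_weight * ctc_score
        if score > best_score:
            best_tokens, best_score = token_ids, score
    return best_tokens
```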
88 changes: 88 additions & 0 deletions examples/aishell/s0/conf/train_ebranchformer.yaml
@@ -0,0 +1,88 @@
# network architecture
# encoder related
encoder: e_branchformer
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 8
    linear_units: 1024  # the number of units of position-wise feed forward
    num_blocks: 17      # the number of encoder blocks
    cgmlp_linear_units: 1024
    cgmlp_conv_kernel: 31
    use_linear_after_conv: false
    gate_activation: identity
    merge_conv_kernel: 31
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 or conv2d8
    activation_type: 'swish'
    causal: false
    pos_enc_layer_type: 'rel_pos'
    attention_layer_type: 'rel_selfattn'

# decoder related
decoder: transformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.1
    src_attention_dropout_rate: 0.1

# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false

dataset_conf:
    filter_conf:
        max_length: 40960
        min_length: 0
        token_max_length: 200
        token_min_length: 1
    resample_conf:
        resample_rate: 16000
    speed_perturb: true
    fbank_conf:
        num_mel_bins: 80
        frame_shift: 10
        frame_length: 25
        dither: 1.0
    spec_aug: true
    spec_aug_conf:
        num_t_mask: 2
        num_f_mask: 2
        max_t: 50
        max_f: 10
    spec_sub: true
    spec_sub_conf:
        num_t_sub: 3
        max_t: 30
    spec_trim: false
    spec_trim_conf:
        max_t: 50
    shuffle: true
    shuffle_conf:
        shuffle_size: 1500
    sort: true
    sort_conf:
        sort_size: 500  # sort_size should be less than shuffle_size
    batch_conf:
        batch_type: 'static' # static or dynamic
        batch_size: 16

grad_clip: 5
accum_grad: 1
max_epoch: 240
log_interval: 100

optim: adam
optim_conf:
    lr: 0.001
    weight_decay: 0.000001
scheduler: warmuplr     # pytorch v1.1.0+ required
scheduler_conf:
    warmup_steps: 35000
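
For orientation, each E-Branchformer block runs a self-attention branch and a cgMLP branch in parallel and merges them with a depthwise convolution over time (width `merge_conv_kernel`) followed by a linear projection back to the model dimension. A rough PyTorch sketch of that merge step as described in the E-Branchformer paper (illustrative names only, not WeNet's exact module):

```python
import torch


class BranchMerge(torch.nn.Module):
    """Sketch of the E-Branchformer merge step: concatenate the attention and
    cgMLP branch outputs, apply a depthwise conv over time as a residual, then
    project back to the model dimension (hypothetical names, not WeNet's API)."""

    def __init__(self, size: int, merge_conv_kernel: int = 31, dropout_rate: float = 0.1):
        super().__init__()
        self.depthwise_conv = torch.nn.Conv1d(
            2 * size, 2 * size, merge_conv_kernel,
            padding=(merge_conv_kernel - 1) // 2, groups=2 * size)
        self.merge_proj = torch.nn.Linear(2 * size, size)
        self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(self, x: torch.Tensor, x_att: torch.Tensor, x_mlp: torch.Tensor) -> torch.Tensor:
        # x: block input; x_att / x_mlp: branch outputs; all (batch, time, size)
        x_concat = torch.cat([x_att, x_mlp], dim=-1)   # (B, T, 2*size)
        x_conv = self.depthwise_conv(x_concat.transpose(1, 2)).transpose(1, 2)
        x_merge = x_concat + x_conv                    # local merging over time
        return x + self.dropout(self.merge_proj(x_merge))  # back to (B, T, size)
```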
2 changes: 2 additions & 0 deletions wenet/branchformer/cgmlp.py
@@ -143,6 +143,7 @@ def __init__(
        dropout_rate: float,
        use_linear_after_conv: bool,
        gate_activation: str,
        causal: bool = True,
    ):
        super().__init__()

@@ -155,6 +156,7 @@ def __init__(
            dropout_rate=dropout_rate,
            use_linear_after_conv=use_linear_after_conv,
            gate_activation=gate_activation,
            causal=causal,
        )
        self.channel_proj2 = torch.nn.Linear(linear_units // 2, size)
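
The new `causal` flag presumably lets the depthwise convolution inside the convolutional spatial gating unit be built with left-only padding so it never looks at future frames (useful for streaming encoders), mirroring the `causal: false` entry in the config above. A minimal sketch of the two padding schemes (illustrative helper, not the actual ConvolutionalSpatialGatingUnit code):

```python
import torch

def depthwise_conv(channels: int, kernel_size: int, causal: bool) -> torch.nn.Module:
    """Depthwise 1-D conv over time; causal=True pads only on the left."""
    if causal:
        # Output frame t depends only on frames <= t.
        return torch.nn.Sequential(
            torch.nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
            torch.nn.Conv1d(channels, channels, kernel_size, groups=channels),
        )
    # Symmetric padding: each frame may look (kernel_size - 1) // 2 frames ahead.
    return torch.nn.Conv1d(
        channels, channels, kernel_size,
        padding=(kernel_size - 1) // 2, groups=channels,
    )
```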

