[cli/paraformer] ali-paraformer inference (#2067)
* [cli/paraformer] ali-paraformer load and infer work

* fix lint

* export jit and load work

* reuse init_model.py

* mv the intermediate files to the assets directory

* model.decode work && recognize.py work

* rm positionwise_feed_forward.py/lfr.py

* refactor search

* merge main

* cli work

* fix lint

* fix att mask && batch infer

* search confidence works

* merge main

* fix linux dtype

* fix label type

* revert init_model.py and add init_model in export_jit
Mddct authored Oct 30, 2023
1 parent 5b5ce3b commit af1315c
Showing 11 changed files with 9,519 additions and 128 deletions.
86 changes: 55 additions & 31 deletions wenet/cif/predictor.py
@@ -17,23 +17,38 @@

import torch
from torch import nn
from typing import Tuple
from wenet.utils.mask import make_pad_mask


class Predictor(nn.Module):
def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1,
smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):

def __init__(self,
idim,
l_order,
r_order,
threshold=1.0,
dropout=0.1,
smooth_factor=1.0,
noise_threshold=0,
tail_threshold=0.45,
residual=True,
cnn_groups=0):
super().__init__()

self.pad = nn.ConstantPad1d((l_order, r_order), 0.0)
self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1,
groups=idim)
self.cif_conv1d = nn.Conv1d(
idim,
idim,
l_order + r_order + 1,
groups=idim if cnn_groups == 0 else cnn_groups)
self.cif_output = nn.Linear(idim, 1)
self.dropout = torch.nn.Dropout(p=dropout)
self.threshold = threshold
self.smooth_factor = smooth_factor
self.noise_threshold = noise_threshold
self.tail_threshold = tail_threshold
self.residual = residual

def forward(self,
hidden,
@@ -46,7 +61,10 @@ def forward(self,
context = h.transpose(1, 2)
queries = self.pad(context)
memory = self.cif_conv1d(queries)
output = memory + context
if self.residual:
output = memory + context
else:
output = memory
output = self.dropout(output)
output = output.transpose(1, 2)
output = torch.relu(output)
@@ -55,7 +73,7 @@ def forward(self,
alphas = torch.nn.functional.relu(alphas * self.smooth_factor -
self.noise_threshold)
if mask is not None:
mask = mask.transpose(-1, -2).float()
mask = mask.transpose(-1, -2)
alphas = alphas * mask
if mask_chunk_predictor is not None:
alphas = alphas * mask_chunk_predictor
@@ -72,10 +90,10 @@ def forward(self,
alphas *= (target_length / token_num)[:, None] \
.repeat(1, alphas.size(1))
elif self.tail_threshold > 0.0:
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas,
hidden, alphas, token_num = self.tail_process_fn(hidden,
alphas,
token_num,
mask=mask)

acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)

if target_length is None and self.tail_threshold > 0.0:
@@ -84,26 +102,32 @@ def forward(self,

return acoustic_embeds, token_num, alphas, cif_peak

def tail_process_fn(self, hidden, alphas,
token_num: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None):
b, t, d = hidden.size()
tail_threshold = self.tail_threshold
def tail_process_fn(
self,
hidden: torch.Tensor,
alphas: torch.Tensor,
token_num: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
b, _, d = hidden.size()
if mask is not None:
zeros_t = torch.zeros((b, 1), dtype=torch.float32,
zeros_t = torch.zeros((b, 1),
dtype=torch.float32,
device=alphas.device)
mask = mask.to(zeros_t.dtype)
ones_t = torch.ones_like(zeros_t)
mask_1 = torch.cat([mask, zeros_t], dim=1)
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
tail_threshold = mask * self.tail_threshold
alphas = torch.cat([alphas, zeros_t], dim=1)
alphas = torch.add(alphas, tail_threshold)
else:
tail_threshold_tensor = torch.tensor([tail_threshold],
tail_threshold_tensor = torch.tensor([self.tail_threshold],
dtype=alphas.dtype).to(
alphas.device)
tail_threshold_tensor = torch.reshape(tail_threshold_tensor, (1, 1))
alphas.device)
tail_threshold_tensor = torch.reshape(tail_threshold_tensor,
(1, 1))
alphas = torch.cat([alphas, tail_threshold_tensor], dim=1)
zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
hidden = torch.cat([hidden, zeros], dim=1)
@@ -132,13 +156,15 @@ def gen_frame_alignments(self,

index = torch.ones([batch_size, max_token_num], dtype=int_type)
index = torch.cumsum(index, dim=1)
index = index[:, :, None].repeat(1, 1, maximum_length).to(
alphas_cumsum.device)
index = index[:, :,
None].repeat(1, 1,
maximum_length).to(alphas_cumsum.device)

index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(
int_type)
index_div = torch.floor(torch.true_divide(alphas_cumsum,
index)).type(int_type)
index_div_bool_zeros = index_div.eq(0)
index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
index_div_bool_zeros_count = torch.sum(index_div_bool_zeros,
dim=-1) + 1
index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0,
encoder_sequence_length.max())
token_num_mask = (~make_pad_mask(token_num, max_len=max_token_num)).to(
@@ -210,19 +236,17 @@ def cif(hidden: torch.Tensor, alphas: torch.Tensor, threshold: float):
list_fires.append(integrate)

fire_place = integrate >= threshold
integrate = torch.where(fire_place, integrate -
torch.ones([batch_size], device=hidden.device),
integrate)
cur = torch.where(fire_place,
distribution_completion,
alpha)
integrate = torch.where(
fire_place,
integrate - torch.ones([batch_size], device=hidden.device),
integrate)
cur = torch.where(fire_place, distribution_completion, alpha)
remainds = alpha - cur

frame += cur[:, None] * hidden[:, t, :]
list_frames.append(frame)
frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
remainds[:, None] * hidden[:, t, :],
frame)
remainds[:, None] * hidden[:, t, :], frame)

fires = torch.stack(list_fires, 1)
frames = torch.stack(list_frames, 1)
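Note: the cif() helper above walks the time axis, accumulating the per-frame weights (alphas) and "firing" one acoustic embedding each time the running sum crosses the threshold, with the leftover weight carried into the next embedding. A minimal single-utterance sketch of the same idea follows; the function name and the simplified loop are illustrative only (batching and tail handling are omitted), not the wenet API.

import torch

def toy_cif(hidden: torch.Tensor, alphas: torch.Tensor, threshold: float = 1.0):
    """hidden: (T, D) encoder frames, alphas: (T,) per-frame weights."""
    fired = []
    integrate = 0.0
    frame = torch.zeros(hidden.size(1))
    for t in range(hidden.size(0)):
        alpha = float(alphas[t])
        if integrate + alpha >= threshold:
            # Fire: spend only what is needed to reach the threshold and
            # seed the next embedding with the remaining weight.
            cur = threshold - integrate
            fired.append(frame + cur * hidden[t])
            remaind = alpha - cur
            frame = remaind * hidden[t]
            integrate = remaind
        else:
            integrate += alpha
            frame = frame + alpha * hidden[t]
    return torch.stack(fired) if fired else hidden.new_zeros(0, hidden.size(1))

# 6 frames whose weights sum to 2.2 cross the threshold twice -> 2 embeddings.
emb = toy_cif(torch.randn(6, 4), torch.tensor([0.3, 0.4, 0.4, 0.2, 0.5, 0.4]))
print(emb.shape)  # torch.Size([2, 4])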
66 changes: 66 additions & 0 deletions wenet/cli/paraformer_model.py
@@ -0,0 +1,66 @@
import os

import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi

from wenet.cli.hub import Hub  # Hub.get_model_by_lang is used by load_model below
from wenet.paraformer.search import paraformer_greedy_search
from wenet.utils.file_utils import read_symbol_table


class Paraformer:

    def __init__(self, model_dir: str) -> None:

        model_path = os.path.join(model_dir, 'final.zip')
        units_path = os.path.join(model_dir, 'units.txt')
        self.model = torch.jit.load(model_path)
        symbol_table = read_symbol_table(units_path)
        self.char_dict = {v: k for k, v in symbol_table.items()}
        self.eos = 2

    def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
        waveform, sample_rate = torchaudio.load(audio_file, normalize=False)
        waveform = waveform.to(torch.float)
        feats = kaldi.fbank(waveform,
                            num_mel_bins=80,
                            frame_length=25,
                            frame_shift=10,
                            energy_floor=0.0,
                            sample_frequency=16000)
        feats = feats.unsqueeze(0)
        feats_lens = torch.tensor([feats.size(1)], dtype=torch.int64)

        decoder_out, token_num = self.model.forward_paraformer(
            feats, feats_lens)

        res = paraformer_greedy_search(decoder_out, token_num)[0]

        result = {}
        result['confidence'] = res.confidence
        # TODO(Mddct): deal with '@@' and 'eos'
        result['rec'] = "".join([self.char_dict[x] for x in res.tokens])

        if tokens_info:
            tokens_info = []
            for i, x in enumerate(res.tokens):
                tokens_info.append({
                    'token': self.char_dict[x],
                    # TODO(Mddct): support times
                    # 'start': 0,
                    # 'end': 0,
                    'confidence': res.tokens_confidence[i]
                })
            result['tokens'] = tokens_info

        # result = ''.join(hyp)
        return result

    def align(self, audio_file: str, label: str) -> dict:
        raise NotImplementedError


def load_model(language: str = None, model_dir: str = None) -> Paraformer:
    if model_dir is None:
        model_dir = Hub.get_model_by_lang(language)
    return Paraformer(model_dir)
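Note: a minimal, hypothetical way to exercise the new wrapper from Python. The model directory and audio path are placeholders; per the code above, the directory must contain final.zip and units.txt, and the fbank settings assume 16 kHz mono input.

from wenet.cli.paraformer_model import load_model

model = load_model(model_dir='/path/to/paraformer_model')  # placeholder directory
result = model.transcribe('/path/to/audio_16k.wav', tokens_info=True)
print(result['rec'])          # recognized text
print(result['confidence'])   # utterance-level confidence
for tok in result['tokens']:  # per-token details when tokens_info=True
    print(tok['token'], tok['confidence'])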
10 changes: 9 additions & 1 deletion wenet/cli/transcribe.py
@@ -14,6 +14,7 @@

import argparse

from wenet.cli.paraformer_model import load_model as load_paraformer
from wenet.cli.model import load_model


@@ -41,13 +42,20 @@ def get_args():
action='store_true',
help='force align the input audio and transcript')
parser.add_argument('--label', type=str, help='the input label to align')
parser.add_argument('--paraformer',
action='store_true',
help='whether to use the best chinese model')
args = parser.parse_args()
return args


def main():
args = get_args()
model = load_model(args.language, args.model_dir)

if args.paraformer:
model = load_paraformer(args.language, args.model_dir)
else:
model = load_model(args.language, args.model_dir)
if args.align:
result = model.align(args.audio_file, args.label)
else:
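Note: with the new flag, main() above swaps in the Paraformer wrapper instead of the default CLI model. A rough plain-Python equivalent of that selection is sketched below; the model directory, audio path, and the 'chinese' language tag are assumptions, not values taken from this commit.

from wenet.cli.paraformer_model import load_model as load_paraformer
from wenet.cli.model import load_model

use_paraformer = True  # mirrors the new args.paraformer flag
model = (load_paraformer(model_dir='/path/to/paraformer_model')  # placeholder dir
         if use_paraformer else load_model('chinese', None))
result = model.transcribe('/path/to/audio_16k.wav')  # placeholder audio path
print(result)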
49 changes: 49 additions & 0 deletions wenet/paraformer/ali_paraformer/assets/config.yaml
@@ -0,0 +1,49 @@
# network architecture
# encoder related
encoder: SanmEncoder
encoder_conf:
    output_size: 512 # dimension of attention
    attention_heads: 4
    linear_units: 2048 # the number of units of position-wise feed forward
    num_blocks: 50 # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: 'conv2d' # encoder input type, you can choose conv2d, conv2d6 or conv2d8
    normalize_before: true
    kernel_size: 11
    sanm_shfit: 0

input_dim: 80
output_dim: 8404
paraformer: true
is_json_cmvn: True
# decoder related
decoder: SanmDecoder
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 16
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.1
    src_attention_dropout_rate: 0.1
    att_layer_num: 16
    kernel_size: 11
    sanm_shfit: 0

lfr_conf:
    lfr_m: 7
    lfr_n: 6

cif_predictor_conf:
    idim: 512
    threshold: 1.0
    l_order: 1
    r_order: 1
    tail_threshold: 0.45
    cnn_groups: 1
    residual: false

model_conf:
    ctc_weight: 0.0
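Note: the cif_predictor_conf keys map directly onto the Predictor constructor updated in wenet/cif/predictor.py above, and lfr_conf presumably configures the low-frame-rate front end (stack lfr_m=7 fbank frames with a hop of lfr_n=6). A sketch of building the predictor straight from this yaml; the path and the by-hand construction are illustrative, since wenet's own entry points normally do this wiring.

import yaml
from wenet.cif.predictor import Predictor

with open('wenet/paraformer/ali_paraformer/assets/config.yaml') as f:
    configs = yaml.safe_load(f)

cif_conf = configs['cif_predictor_conf']
predictor = Predictor(**cif_conf)
# l_order=1, r_order=1 gives a kernel of size l_order + r_order + 1 = 3;
# cnn_groups=1 makes cif_conv1d a full Conv1d instead of a depthwise one,
# and residual=false drops the memory + context skip connection in forward().
print(predictor.cif_conv1d)  # Conv1d(512, 512, kernel_size=(3,), stride=(1,))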