Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cli/paraformer] ali-paraformer inference #2067

Merged
merged 21 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 55 additions & 31 deletions wenet/cif/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,38 @@

import torch
from torch import nn
from torchaudio.compliance.kaldi import Tuple
from wenet.utils.mask import make_pad_mask


class Predictor(nn.Module):
def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1,
smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):

def __init__(self,
idim,
l_order,
r_order,
threshold=1.0,
dropout=0.1,
smooth_factor=1.0,
noise_threshold=0,
tail_threshold=0.45,
residual=True,
cnn_groups=0):
super().__init__()

self.pad = nn.ConstantPad1d((l_order, r_order), 0.0)
self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1,
groups=idim)
self.cif_conv1d = nn.Conv1d(
idim,
idim,
l_order + r_order + 1,
groups=idim if cnn_groups == 0 else cnn_groups)
self.cif_output = nn.Linear(idim, 1)
self.dropout = torch.nn.Dropout(p=dropout)
self.threshold = threshold
self.smooth_factor = smooth_factor
self.noise_threshold = noise_threshold
self.tail_threshold = tail_threshold
self.residual = residual

def forward(self,
hidden,
Expand All @@ -46,7 +61,10 @@ def forward(self,
context = h.transpose(1, 2)
queries = self.pad(context)
memory = self.cif_conv1d(queries)
output = memory + context
if self.residual:
output = memory + context
else:
output = memory
output = self.dropout(output)
output = output.transpose(1, 2)
output = torch.relu(output)
Expand All @@ -55,7 +73,7 @@ def forward(self,
alphas = torch.nn.functional.relu(alphas * self.smooth_factor -
self.noise_threshold)
if mask is not None:
mask = mask.transpose(-1, -2).float()
mask = mask.transpose(-1, -2)
alphas = alphas * mask
if mask_chunk_predictor is not None:
alphas = alphas * mask_chunk_predictor
Expand All @@ -72,10 +90,10 @@ def forward(self,
alphas *= (target_length / token_num)[:, None] \
.repeat(1, alphas.size(1))
elif self.tail_threshold > 0.0:
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas,
hidden, alphas, token_num = self.tail_process_fn(hidden,
alphas,
token_num,
mask=mask)

acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)

if target_length is None and self.tail_threshold > 0.0:
Expand All @@ -84,26 +102,32 @@ def forward(self,

return acoustic_embeds, token_num, alphas, cif_peak

def tail_process_fn(self, hidden, alphas,
token_num: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None):
b, t, d = hidden.size()
tail_threshold = self.tail_threshold
def tail_process_fn(
self,
hidden: torch.Tensor,
alphas: torch.Tensor,
token_num: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
b, _, d = hidden.size()
if mask is not None:
zeros_t = torch.zeros((b, 1), dtype=torch.float32,
zeros_t = torch.zeros((b, 1),
dtype=torch.float32,
device=alphas.device)
mask = mask.to(zeros_t.dtype)
ones_t = torch.ones_like(zeros_t)
mask_1 = torch.cat([mask, zeros_t], dim=1)
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
tail_threshold = mask * self.tail_threshold
alphas = torch.cat([alphas, zeros_t], dim=1)
alphas = torch.add(alphas, tail_threshold)
else:
tail_threshold_tensor = torch.tensor([tail_threshold],
tail_threshold_tensor = torch.tensor([self.tail_threshold],
dtype=alphas.dtype).to(
alphas.device)
tail_threshold_tensor = torch.reshape(tail_threshold_tensor, (1, 1))
alphas.device)
tail_threshold_tensor = torch.reshape(tail_threshold_tensor,
(1, 1))
alphas = torch.cat([alphas, tail_threshold_tensor], dim=1)
zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
hidden = torch.cat([hidden, zeros], dim=1)
Expand Down Expand Up @@ -132,13 +156,15 @@ def gen_frame_alignments(self,

index = torch.ones([batch_size, max_token_num], dtype=int_type)
index = torch.cumsum(index, dim=1)
index = index[:, :, None].repeat(1, 1, maximum_length).to(
alphas_cumsum.device)
index = index[:, :,
None].repeat(1, 1,
maximum_length).to(alphas_cumsum.device)

index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(
int_type)
index_div = torch.floor(torch.true_divide(alphas_cumsum,
index)).type(int_type)
index_div_bool_zeros = index_div.eq(0)
index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
index_div_bool_zeros_count = torch.sum(index_div_bool_zeros,
dim=-1) + 1
index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0,
encoder_sequence_length.max())
token_num_mask = (~make_pad_mask(token_num, max_len=max_token_num)).to(
Expand Down Expand Up @@ -210,19 +236,17 @@ def cif(hidden: torch.Tensor, alphas: torch.Tensor, threshold: float):
list_fires.append(integrate)

fire_place = integrate >= threshold
integrate = torch.where(fire_place, integrate -
torch.ones([batch_size], device=hidden.device),
integrate)
cur = torch.where(fire_place,
distribution_completion,
alpha)
integrate = torch.where(
fire_place,
integrate - torch.ones([batch_size], device=hidden.device),
integrate)
cur = torch.where(fire_place, distribution_completion, alpha)
remainds = alpha - cur

frame += cur[:, None] * hidden[:, t, :]
list_frames.append(frame)
frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
remainds[:, None] * hidden[:, t, :],
frame)
remainds[:, None] * hidden[:, t, :], frame)

fires = torch.stack(list_fires, 1)
frames = torch.stack(list_frames, 1)
Expand Down
66 changes: 66 additions & 0 deletions wenet/cli/paraformer_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os

import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi

from wenet.paraformer.search import paraformer_greedy_search
from wenet.utils.file_utils import read_symbol_table


class Paraformer:

def __init__(self, model_dir: str) -> None:

model_path = os.path.join(model_dir, 'final.zip')
units_path = os.path.join(model_dir, 'units.txt')
self.model = torch.jit.load(model_path)
symbol_table = read_symbol_table(units_path)
self.char_dict = {v: k for k, v in symbol_table.items()}
self.eos = 2

def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
waveform, sample_rate = torchaudio.load(audio_file, normalize=False)
waveform = waveform.to(torch.float)
feats = kaldi.fbank(waveform,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default window in the FunASR frontend is hamming. You can find more details here. However, the default window in kaldi.fbank is povey, as specified here. This different window maybe a little mismatch. As mentioned in line 44 of this document:

"povey" is a window I made to be similar to Hamming but to go to zero at the edges, it's pow((0.5 - 0.5cos(n/N2*pi)), 0.85) I just don't think the Hamming window makes sense as a windowing function.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pr welcome

num_mel_bins=80,
frame_length=25,
frame_shift=10,
energy_floor=0.0,
sample_frequency=16000)
feats = feats.unsqueeze(0)
feats_lens = torch.tensor([feats.size(1)], dtype=torch.int64)

decoder_out, token_num = self.model.forward_paraformer(
feats, feats_lens)

res = paraformer_greedy_search(decoder_out, token_num)[0]

result = {}
result['confidence'] = res.confidence
# # TODO(Mddct): deal with '@@' and 'eos'
result['rec'] = "".join([self.char_dict[x] for x in res.tokens])

if tokens_info:
tokens_info = []
for i, x in enumerate(res.tokens):
tokens_info.append({
'token': self.char_dict[x],
# TODO(Mddct): support times
# 'start': 0,
# 'end': 0,
'confidence': res.tokens_confidence[i]
})
result['tokens'] = tokens_info

# result = ''.join(hyp)
return result

def align(self, audio_file: str, label: str) -> dict:
raise NotImplementedError


def load_model(language: str = None, model_dir: str = None) -> Paraformer:
if model_dir is None:
model_dir = Hub.get_model_by_lang(language)
return Paraformer(model_dir)
10 changes: 9 additions & 1 deletion wenet/cli/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import argparse

from wenet.cli.paraformer_model import load_model as load_paraformer
from wenet.cli.model import load_model


Expand Down Expand Up @@ -41,13 +42,20 @@ def get_args():
action='store_true',
help='force align the input audio and transcript')
parser.add_argument('--label', type=str, help='the input label to align')
parser.add_argument('--paraformer',
action='store_true',
help='whether to use the best chinese model')
args = parser.parse_args()
return args


def main():
args = get_args()
model = load_model(args.language, args.model_dir)

if args.paraformer:
model = load_paraformer(args.language, args.model_dir)
else:
model = load_model(args.language, args.model_dir)
if args.align:
result = model.align(args.audio_file, args.label)
else:
Expand Down
49 changes: 49 additions & 0 deletions wenet/paraformer/ali_paraformer/assets/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# network architecture
# encoder related
encoder: SanmEncoder
encoder_conf:
output_size: 512 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 50 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
input_layer: 'conv2d' # encoder input type, you can chose conv2d, conv2d6 and conv2d8
normalize_before: true
kernel_size: 11
sanm_shfit: 0

input_dim: 80
output_dim: 8404
paraformer: true
is_json_cmvn: True
# decoder related
decoder: SanmDecoder
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 16
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.1
src_attention_dropout_rate: 0.1
att_layer_num: 16
kernel_size: 11
sanm_shfit: 0

lfr_conf:
lfr_m: 7
lfr_n: 6

cif_predictor_conf:
idim: 512
threshold: 1.0
l_order: 1
r_order: 1
tail_threshold: 0.45
cnn_groups: 1
residual: false

model_conf:
ctc_weight: 0.0
Loading
Loading