Commit

rm positionwise_feed_forward.py/lfr.py
Mddct committed Oct 23, 2023
1 parent e987b00 commit 3d25e2e
Showing 3 changed files with 101 additions and 107 deletions.
72 changes: 0 additions & 72 deletions wenet/paraformer/ali_paraformer/lfr.py

This file was deleted.

104 changes: 101 additions & 3 deletions wenet/paraformer/ali_paraformer/model.py
@@ -1,16 +1,14 @@
""" NOTE(Mddct): This file is experimental and is used to export paraformer
"""

import math
from typing import Dict, List, Optional, Tuple
import torch
from wenet.cif.predictor import Predictor
from wenet.paraformer.ali_paraformer.attention import (DummyMultiHeadSANM,
                                                        MultiHeadAttentionCross,
                                                        MultiHeadedAttentionSANM
                                                        )
from wenet.paraformer.ali_paraformer.lfr import LFR
from wenet.paraformer.ali_paraformer.positionwise_feed_forward import \
    PositionwiseFeedForwardDecoderSANM
from wenet.transformer.search import DecodeResult
from wenet.transformer.encoder import BaseEncoder
from wenet.transformer.decoder import TransformerDecoder
@@ -20,6 +18,106 @@
from wenet.utils.mask import make_non_pad_mask


class LFR(torch.nn.Module):

    def __init__(self, m: int = 7, n: int = 6) -> None:
        """Low frame rate (LFR): stack m frames, then advance by n frames.

        if m = 1 and n = 1, the original features are returned unchanged.
        if m = 1 and n > 1, it works like frame skipping (subsampling).
        if m > 1 and n = 1, it works like stacking, but only right (future)
            frames are stacked.
        if m > 1 and n > 1, it works like LFR.
        """
        super().__init__()

        self.m = m
        self.n = n

        self.left_padding_nums = math.ceil((self.m - 1) // 2)

    def forward(self, input: torch.Tensor,
                input_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        B, _, D = input.size()
        # number of low-frame-rate output frames per utterance
        n_lfr = torch.ceil(input_lens / self.n)
        # right_padding_nums >= 0
        prepad_nums = input_lens + self.left_padding_nums

        right_padding_nums = torch.where(
            self.m >= (prepad_nums - self.n * (n_lfr - 1)),
            self.m - (prepad_nums - self.n * (n_lfr - 1)),
            0,
        )
        T_all = self.left_padding_nums + input_lens + right_padding_nums

        new_len = T_all // self.n

        T_all_max = T_all.max().int()

        tail_frames_index = (input_lens - 1).view(B, 1, 1).repeat(1, 1,
                                                                  D)  # [B,1,D]

        tail_frames = torch.gather(input, 1, tail_frames_index)
        tail_frames = tail_frames.repeat(1, right_padding_nums.max().int(), 1)
        head_frames = input[:, 0:1, :].repeat(1, self.left_padding_nums, 1)

        # concatenate head padding, original frames and tail padding
        input = torch.cat([head_frames, input, tail_frames], dim=1)

        index = torch.arange(T_all_max,
                             device=input.device,
                             dtype=input_lens.dtype).unsqueeze(0).repeat(
                                 B, 1)  # [B, T_all_max]
        # [B, T_all_max]
        index_mask = index < (self.left_padding_nums + input_lens).unsqueeze(1)

        tail_index_mask = torch.logical_not(
            index >= (T_all.unsqueeze(1))) & index_mask
        tail = torch.ones(T_all_max,
                          dtype=input_lens.dtype,
                          device=input.device).unsqueeze(0).repeat(B, 1) * (
                              T_all_max - 1)  # [B, T_all_max]
        indices = torch.where(torch.logical_or(index_mask, tail_index_mask),
                              index, tail)
        input = torch.gather(input, 1, indices.unsqueeze(2).repeat(1, 1, D))

        input = input.unfold(1, self.m, step=self.n).transpose(2, 3)
        # new len
        return input.reshape(B, -1, D * self.m), new_len
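
# A minimal usage sketch of LFR (assumed example shapes, for illustration
# only): with the defaults m=7, n=6 and an input of shape
# [B, T, D] = [1, 10, 80] with input_lens = [10], we get
# left_padding_nums = 3, right_padding_nums = 0, T_all = 13 and
# new_len = 13 // 6 = 2, so the output has shape [1, 2, 7 * 80] = [1, 2, 560]:
#
#     lfr = LFR(m=7, n=6)
#     feats = torch.randn(1, 10, 80)
#     feats_lens = torch.tensor([10])
#     out, out_lens = lfr(feats, feats_lens)
#     assert out.shape == (1, 2, 560) and int(out_lens.item()) == 2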


class PositionwiseFeedForwardDecoderSANM(torch.nn.Module):
    """Positionwise feed forward layer for the SANM decoder.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        adim (int, optional): Output dimension; defaults to idim.
        activation (torch.nn.Module): Activation applied after the first
            linear layer.
    """

    def __init__(self,
                 idim,
                 hidden_units,
                 dropout_rate,
                 adim=None,
                 activation=torch.nn.ReLU()):
        """Construct a PositionwiseFeedForwardDecoderSANM object."""
        super(PositionwiseFeedForwardDecoderSANM, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.w_2 = torch.nn.Linear(hidden_units,
                                   idim if adim is None else adim,
                                   bias=False)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.activation = activation
        self.norm = torch.nn.LayerNorm(hidden_units)

    def forward(self, x):
        """Forward function: w_2(norm(dropout(activation(w_1(x)))))."""
        return self.w_2(self.norm(self.dropout(self.activation(self.w_1(x)))))
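
# A minimal usage sketch (assumed sizes, for illustration only): with
# idim=512, hidden_units=2048 and adim=320 the layer maps
# [B, T, 512] -> [B, T, 320] via w_2(norm(dropout(activation(w_1(x))))):
#
#     ffn = PositionwiseFeedForwardDecoderSANM(512, 2048, 0.1, adim=320)
#     y = ffn(torch.randn(1, 2, 512))
#     assert y.shape == (1, 2, 320)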


class SinusoidalPositionEncoder(torch.nn.Module):
    """https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/modules/embedding.py#L387
    """
32 changes: 0 additions & 32 deletions wenet/paraformer/ali_paraformer/positionwise_feed_forward.py

This file was deleted.
