From 6d867f714d70689cad25ce6086b1458fe5552915 Mon Sep 17 00:00:00 2001
From: yeyupiaoling
Date: Tue, 20 Dec 2022 09:56:53 +0800
Subject: [PATCH 1/9] add squeezeformer model

---
 .../asr1/conf/chunk_squeezeformer.yaml        |  98 +++++
 examples/aishell/asr1/conf/squeezeformer.yaml |  93 +++++
 paddlespeech/s2t/models/u2/u2.py              |   5 +-
 paddlespeech/s2t/modules/attention.py         | 133 ++++++
 paddlespeech/s2t/modules/conv2d.py            |  56 +++
 paddlespeech/s2t/modules/convolution.py       | 172 ++++++++
 paddlespeech/s2t/modules/encoder.py           | 383 +++++++++++++++++-
 paddlespeech/s2t/modules/encoder_layer.py     | 123 ++++++
 .../s2t/modules/positionwise_feed_forward.py  |  61 ++-
 paddlespeech/s2t/modules/subsampling.py       | 259 +++++++++++-
 paddlespeech/s2t/utils/utility.py             |  12 +-
 11 files changed, 1384 insertions(+), 11 deletions(-)
 create mode 100644 examples/aishell/asr1/conf/chunk_squeezeformer.yaml
 create mode 100644 examples/aishell/asr1/conf/squeezeformer.yaml
 create mode 100644 paddlespeech/s2t/modules/conv2d.py
 create mode 100644 paddlespeech/s2t/modules/convolution.py

diff --git a/examples/aishell/asr1/conf/chunk_squeezeformer.yaml b/examples/aishell/asr1/conf/chunk_squeezeformer.yaml
new file mode 100644
index 00000000000..691d9046162
--- /dev/null
+++ b/examples/aishell/asr1/conf/chunk_squeezeformer.yaml
@@ -0,0 +1,98 @@
+############################################
+#          Network Architecture            #
+############################################
+cmvn_file:
+cmvn_file_type: "json"
+# encoder related
+encoder: squeezeformer
+encoder_conf:
+    encoder_dim: 256    # dimension of attention
+    output_size: 256    # dimension of output
+    attention_heads: 4
+    num_blocks: 12      # the number of encoder blocks
+    reduce_idx: 5
+    recover_idx: 11
+    feed_forward_expansion_factor: 4
+    input_dropout_rate: 0.1
+    feed_forward_dropout_rate: 0.1
+    attention_dropout_rate: 0.1
+    adaptive_scale: true
+    cnn_module_kernel: 31
+    normalize_before: false
+    activation_type: 'swish'
+    pos_enc_layer_type: 'rel_pos'
+    time_reduction_layer_type: 'conv2d'
+    causal: true
+    use_dynamic_chunk: true
+    use_dynamic_left_chunk: false

+# decoder related
+decoder: transformer
+decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 6
+    dropout_rate: 0.1    # sublayer output dropout
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.0
+    src_attention_dropout_rate: 0.0
+# hybrid CTC/attention
+model_conf:
+    ctc_weight: 0.3
+    lsm_weight: 0.1     # label smoothing option
+    length_normalized_loss: false
+    init_type: 'kaiming_uniform'    # !Warning: needed for convergence

+###########################################
+#                 Data                    #
+###########################################

+train_manifest: data/manifest.train
+dev_manifest: data/manifest.dev
+test_manifest: data/manifest.test


+###########################################
+#              Dataloader                 #
+###########################################

+vocab_filepath: data/lang_char/vocab.txt
+spm_model_prefix: ''
+unit_type: 'char'
+preprocess_config: conf/preprocess.yaml
+feat_dim: 80
+stride_ms: 10.0
+window_ms: 25.0
+sortagrad: 0    # Feed samples from shortest to longest; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
+batch_size: 32
+maxlen_in: 512  # if input length  > maxlen-in, batch size is automatically reduced
+maxlen_out: 150 # if output length > maxlen-out, batch size is automatically reduced
+minibatches: 0  # for debug
+batch_count: auto
+batch_bins: 0
+batch_frames_in: 0
+batch_frames_out: 0
+batch_frames_inout: 0
+num_workers: 2
+subsampling_factor: 1
+num_encs: 1

+###########################################
+#               Training                  #
+###########################################
+n_epoch: 240
+accum_grad: 1
+global_grad_clip: 5.0
+dist_sampler: True
+optim: adam
+optim_conf:
+    lr: 0.001
+    weight_decay: 1.0e-6
+scheduler: warmuplr
+scheduler_conf:
+    warmup_steps: 25000
+    lr_decay: 1.0
+log_interval: 100
+checkpoint:
+    kbest_n: 50
+    latest_n: 5
diff --git a/examples/aishell/asr1/conf/squeezeformer.yaml b/examples/aishell/asr1/conf/squeezeformer.yaml
new file mode 100644
index 00000000000..db8ef7c2df2
--- /dev/null
+++ b/examples/aishell/asr1/conf/squeezeformer.yaml
@@ -0,0 +1,93 @@
+############################################
+#          Network Architecture            #
+############################################
+cmvn_file:
+cmvn_file_type: "json"
+# encoder related
+encoder: squeezeformer
+encoder_conf:
+    encoder_dim: 256    # dimension of attention
+    output_size: 256    # dimension of output
+    attention_heads: 4
+    num_blocks: 12      # the number of encoder blocks
+    reduce_idx: 5
+    recover_idx: 11
+    feed_forward_expansion_factor: 4
+    input_dropout_rate: 0.1
+    feed_forward_dropout_rate: 0.1
+    attention_dropout_rate: 0.1
+    adaptive_scale: true
+    cnn_module_kernel: 31
+    normalize_before: false
+    activation_type: 'swish'
+    pos_enc_layer_type: 'rel_pos'
+    time_reduction_layer_type: 'conv2d'

+# decoder related
+decoder: transformer
+decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 6
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.0
+    src_attention_dropout_rate: 0.0

+# hybrid CTC/attention
+model_conf:
+    ctc_weight: 0.3
+    lsm_weight: 0.1     # label smoothing option
+    length_normalized_loss: false
+    init_type: 'kaiming_uniform'    # !Warning: needed for convergence

+###########################################
+#                 Data                    #
+###########################################
+train_manifest: data/manifest.train
+dev_manifest: data/manifest.dev
+test_manifest: data/manifest.test

+###########################################
+#              Dataloader                 #
+###########################################
+vocab_filepath: data/lang_char/vocab.txt
+spm_model_prefix: ''
+unit_type: 'char'
+preprocess_config: conf/preprocess.yaml
+feat_dim: 80
+stride_ms: 10.0
+window_ms: 25.0
+sortagrad: 0    # Feed samples from shortest to longest; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
+batch_size: 32
+maxlen_in: 512  # if input length  > maxlen-in, batch size is automatically reduced
+maxlen_out: 150 # if output length > maxlen-out, batch size is automatically reduced
+minibatches: 0  # for debug
+batch_count: auto
+batch_bins: 0
+batch_frames_in: 0
+batch_frames_out: 0
+batch_frames_inout: 0
+num_workers: 2
+subsampling_factor: 1
+num_encs: 1

+###########################################
+#               Training                  #
+###########################################
+n_epoch: 150
+accum_grad: 8
+global_grad_clip: 5.0
+dist_sampler: False
+optim: adam
+optim_conf:
+    lr: 0.002
+    weight_decay: 1.0e-6
+scheduler: warmuplr
+scheduler_conf:
+    warmup_steps: 25000
+    lr_decay: 1.0
+log_interval: 100
+checkpoint:
+    kbest_n: 50
+    latest_n: 5
diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py
index 544c1e8367e..7e72a01fd9c 100644
--- a/paddlespeech/s2t/models/u2/u2.py
+++ b/paddlespeech/s2t/models/u2/u2.py
@@ -42,7 +42,7 @@
 from paddlespeech.s2t.modules.ctc import CTCDecoderBase
 from paddlespeech.s2t.modules.decoder import BiTransformerDecoder
 from paddlespeech.s2t.modules.decoder import TransformerDecoder
-from paddlespeech.s2t.modules.encoder import ConformerEncoder
+from paddlespeech.s2t.modules.encoder import ConformerEncoder, SqueezeformerEncoder
 from paddlespeech.s2t.modules.encoder import TransformerEncoder
 from paddlespeech.s2t.modules.initializer import DefaultInitializerContext
 from paddlespeech.s2t.modules.loss import LabelSmoothingLoss
@@ -905,6 +905,9 @@ def _init_from_config(cls, configs: dict):
         elif encoder_type == 'conformer':
             encoder = ConformerEncoder(
                 input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
+        elif encoder_type == 'squeezeformer':
+            encoder = SqueezeformerEncoder(
+                input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
         else:
             raise ValueError(f"not support encoder type:{encoder_type}")
 
diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py
index d9568dcc90d..e149c504169 100644
--- a/paddlespeech/s2t/modules/attention.py
+++ b/paddlespeech/s2t/modules/attention.py
@@ -330,3 +330,136 @@ def forward(self,
                                  self.d_k)  # (batch, head, time1, time2)
 
         return self.forward_attention(v, scores, mask), new_cache
+
+
+class RelPositionMultiHeadedAttention2(MultiHeadedAttention):
+    """Multi-Head Attention layer with relative position encoding.
+    Paper: https://arxiv.org/abs/1901.02860
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+    """
+
+    def __init__(self, n_head, n_feat, dropout_rate, do_rel_shift=False, adaptive_scale=False, init_weights=False):
+        """Construct an RelPositionMultiHeadedAttention object."""
+        super().__init__(n_head, n_feat, dropout_rate)
+        # linear transformation for positional encoding
+        self.linear_pos = Linear(n_feat, n_feat)
+        # these two learnable bias are used in matrix c and matrix d
+        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+        self.do_rel_shift = do_rel_shift
+        pos_bias_u = self.create_parameter([self.h, self.d_k], default_initializer=I.XavierUniform())
+        self.add_parameter('pos_bias_u', pos_bias_u)
+        pos_bias_v = self.create_parameter([self.h, self.d_k], default_initializer=I.XavierUniform())
+        self.add_parameter('pos_bias_v', pos_bias_v)
+        self.adaptive_scale = adaptive_scale
+        ada_scale = self.create_parameter([1, 1, n_feat], default_initializer=I.Constant(1.0))
+        self.add_parameter('ada_scale', ada_scale)
+        ada_bias = self.create_parameter([1, 1, n_feat], default_initializer=I.Constant(0.0))
+        self.add_parameter('ada_bias', ada_bias)
+        if init_weights:
+            self.init_weights()
+
+    def init_weights(self):
+        input_max = (self.h * self.d_k) ** -0.5
+        self.linear_q._param_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max)
+        self.linear_q._bias_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max)
+        self.linear_k._param_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max)
+        self.linear_k._bias_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max)
+        self.linear_v._param_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max)
+        self.linear_v._bias_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max)
+        self.linear_pos._param_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max)
+        self.linear_pos._bias_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max)
+        self.linear_out._param_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max)
+        self.linear_out._bias_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max)
+
+    def rel_shift(self, x, zero_triu: bool = False):
+        """Compute relative positional encoding.
+ Args: + x (paddle.Tensor): Input tensor (batch, head, time1, time1). + zero_triu (bool): If true, return the lower triangular part of + the matrix. + Returns: + paddle.Tensor: Output tensor. (batch, head, time1, time1) + """ + zero_pad = paddle.zeros([x.shape[0], x.shape[1], x.shape[2], 1], dtype=x.dtype) + x_padded = paddle.concat([zero_pad, x], axis=-1) + + x_padded = x_padded.reshape([x.shape[0], x.shape[1], x.shape[3] + 1, x.shape[2]]) + x = x_padded[:, :, 1:].reshape(paddle.shape(x)) # [B, H, T1, T1] + + if zero_triu: + ones = paddle.ones((x.shape[2], x.shape[3])) + x = x * paddle.tril(ones, x.shape[3] - x.shape[2])[None, None, :, :] + + return x + + def forward(self, query: paddle.Tensor, + key: paddle.Tensor, value: paddle.Tensor, + mask: paddle.Tensor = paddle.ones((0, 0, 0), dtype=paddle.bool), + pos_emb: paddle.Tensor = paddle.empty([0]), + cache: paddle.Tensor = paddle.zeros((0, 0, 0, 0)) + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + Args: + query (paddle.Tensor): Query tensor (#batch, time1, size). + key (paddle.Tensor): Key tensor (#batch, time2, size). + value (paddle.Tensor): Value tensor (#batch, time2, size). + mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + pos_emb (paddle.Tensor): Positional embedding tensor + (#batch, time2, size). + cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + Returns: + paddle.Tensor: Output tensor (#batch, time1, d_model). + paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + """ + if self.adaptive_scale: + query = self.ada_scale * query + self.ada_bias + key = self.ada_scale * key + self.ada_bias + value = self.ada_scale * value + self.ada_bias + + q, k, v = self.forward_qkv(query, key, value) + if cache.shape[0] > 0: + key_cache, value_cache = paddle.split(cache, 2, axis=-1) + k = paddle.concat([key_cache, k], axis=2) + v = paddle.concat([value_cache, v], axis=2) + # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's + # non-trivial to calculate `next_cache_start` here. + new_cache = paddle.concat((k, v), axis=-1) + + n_batch_pos = pos_emb.shape[0] + p = self.linear_pos(pos_emb).reshape([n_batch_pos, -1, self.h, self.d_k]) + p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + # q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3]) + q_with_bias_u = q + self.pos_bias_u.unsqueeze(1) + # (batch, head, time1, d_k) + # q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3]) + q_with_bias_v = q + self.pos_bias_v.unsqueeze(1) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + # matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2])) + matrix_ac = paddle.matmul(q_with_bias_u, k, transpose_y=True) + + # compute matrix b and matrix d + # (batch, head, time1, time2) + # matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2])) + matrix_bd = paddle.matmul(q_with_bias_v, p, transpose_y=True) + # Remove rel_shift since it is useless in speech recognition, + # and it requires special attention for streaming. 
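+        # The shift is therefore gated by do_rel_shift (False by default), so
+        # the default configuration keeps plain matrix_ac + matrix_bd scores.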
+ if self.do_rel_shift: + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask), new_cache diff --git a/paddlespeech/s2t/modules/conv2d.py b/paddlespeech/s2t/modules/conv2d.py new file mode 100644 index 00000000000..4b41d80a412 --- /dev/null +++ b/paddlespeech/s2t/modules/conv2d.py @@ -0,0 +1,56 @@ +from typing import Union, Optional + +import paddle +import paddle.nn.functional as F +from paddle.nn.layer.conv import _ConvNd + +__all__ = ['Conv2DValid'] + + +class Conv2DValid(_ConvNd): + """ + Conv2d operator for VALID mode padding. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: Union[str, int] = 0, + dilation: int = 1, + groups: int = 1, + padding_mode: str = 'zeros', + weight_attr=None, + bias_attr=None, + data_format="NCHW", + valid_trigx: bool = False, + valid_trigy: bool = False + ) -> None: + super(Conv2DValid, self).__init__(in_channels, + out_channels, + kernel_size, + False, + 2, + stride=stride, + padding=padding, + padding_mode=padding_mode, + dilation=dilation, + groups=groups, + weight_attr=weight_attr, + bias_attr=bias_attr, + data_format=data_format) + self.valid_trigx = valid_trigx + self.valid_trigy = valid_trigy + + def _conv_forward(self, input: paddle.Tensor, weight: paddle.Tensor, bias: Optional[paddle.Tensor]): + validx, validy = 0, 0 + if self.valid_trigx: + validx = (input.shape[-2] * (self._stride[-2] - 1) - 1 + self._kernel_size[-2]) // 2 + if self.valid_trigy: + validy = (input.shape[-1] * (self._stride[-1] - 1) - 1 + self._kernel_size[-1]) // 2 + return F.conv2d(input, weight, bias, self._stride, (validx, validy), self._dilation, self._groups) + + def forward(self, input: paddle.Tensor) -> paddle.Tensor: + return self._conv_forward(input, self.weight, self.bias) diff --git a/paddlespeech/s2t/modules/convolution.py b/paddlespeech/s2t/modules/convolution.py new file mode 100644 index 00000000000..0b5ab0356c5 --- /dev/null +++ b/paddlespeech/s2t/modules/convolution.py @@ -0,0 +1,172 @@ +from typing import Tuple + +import paddle +from paddle import nn +from paddle.nn import initializer as I +from typeguard import check_argument_types + +__all__ = ['ConvolutionModule'] + +from paddlespeech.s2t import masked_fill +from paddlespeech.s2t.modules.align import Conv1D, BatchNorm1D, LayerNorm + + +class ConvolutionModule2(nn.Layer): + """ConvolutionModule in Conformer model.""" + + def __init__(self, + channels: int, + kernel_size: int = 15, + activation: nn.Layer = nn.ReLU(), + norm: str = "batch_norm", + causal: bool = False, + bias: bool = True, + adaptive_scale: bool = False, + init_weights: bool = False): + """Construct an ConvolutionModule object. + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernel size of conv layers. 
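+            activation (nn.Layer): Activation function applied after
+                the depthwise convolution and normalization.
+            norm (str): Normalization type, 'batch_norm' or 'layer_norm'.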
+            causal (bool): Whether to use causal convolution or not.
+        """
+        assert check_argument_types()
+        super().__init__()
+        self.bias = bias
+        self.channels = channels
+        self.kernel_size = kernel_size
+        self.adaptive_scale = adaptive_scale
+        ada_scale = self.create_parameter([1, 1, channels], default_initializer=I.Constant(1.0))
+        self.add_parameter('ada_scale', ada_scale)
+        ada_bias = self.create_parameter([1, 1, channels], default_initializer=I.Constant(0.0))
+        self.add_parameter('ada_bias', ada_bias)
+
+        self.pointwise_conv1 = Conv1D(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias_attr=None
+            if bias else False,  # None for True, using bias as default config
+        )
+
+        # self.lorder is used to distinguish if it's a causal convolution,
+        # if self.lorder > 0: it's a causal convolution, the input will be
+        # padded with self.lorder frames on the left in forward.
+        # else: it's a symmetrical convolution
+        if causal:
+            padding = 0
+            self.lorder = kernel_size - 1
+        else:
+            # kernel_size should be an odd number for non-causal convolution
+            assert (kernel_size - 1) % 2 == 0
+            padding = (kernel_size - 1) // 2
+            self.lorder = 0
+        self.depthwise_conv = Conv1D(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=padding,
+            groups=channels,
+            bias_attr=None
+            if bias else False,  # None for True, using bias as default config
+        )
+
+        assert norm in ['batch_norm', 'layer_norm']
+        if norm == "batch_norm":
+            self.use_layer_norm = False
+            self.norm = BatchNorm1D(channels)
+        else:
+            self.use_layer_norm = True
+            self.norm = LayerNorm(channels)
+
+        self.pointwise_conv2 = Conv1D(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias_attr=None
+            if bias else False,  # None for True, using bias as default config
+        )
+        self.activation = activation
+
+        if init_weights:
+            self.init_weights()
+
+    def init_weights(self):
+        pw_max = self.channels ** -0.5
+        dw_max = self.kernel_size ** -0.5
+        self.pointwise_conv1._param_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max)
+        if self.bias:
+            self.pointwise_conv1._bias_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max)
+        self.depthwise_conv._param_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max)
+        if self.bias:
+            self.depthwise_conv._bias_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max)
+        self.pointwise_conv2._param_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max)
+        if self.bias:
+            self.pointwise_conv2._bias_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max)
+
+    def forward(
+            self,
+            x: paddle.Tensor,
+            mask_pad: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool),
+            cache: paddle.Tensor = paddle.zeros([0, 0, 0]),
+    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """Compute convolution module.
+        Args:
+            x (paddle.Tensor): Input tensor (#batch, time, channels).
+            mask_pad (paddle.Tensor): used for batch padding (#batch, 1, time),
+                (0, 0, 0) means fake mask.
+            cache (paddle.Tensor): left context cache, it is only
+                used in causal convolution (#batch, channels, cache_t),
+                (0, 0, 0) means fake cache.
+        Returns:
+            paddle.Tensor: Output tensor (#batch, time, channels).
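+            paddle.Tensor: New cache tensor for the causal left context
+                (#batch, channels, cache_t), (0, 0, 0) if no cache is needed.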
+ """ + if self.adaptive_scale: + x = self.ada_scale * x + self.ada_bias + + # exchange the temporal dimension and the feature dimension + x = x.transpose([0, 2, 1]) # [B, C, T] + + # mask batch padding + if mask_pad.shape[2] > 0: # time > 0 + x = masked_fill(x, mask_pad, 0.0) + + if self.lorder > 0: + if cache.shape[2] == 0: # cache_t == 0 + x = nn.functional.pad(x, [self.lorder, 0], 'constant', 0.0, data_format='NCL') + else: + assert cache.shape[0] == x.shape[0] # B + assert cache.shape[1] == x.shape[1] # C + x = paddle.concat((cache, x), axis=2) + + assert (x.shape[2] > self.lorder) + new_cache = x[:, :, -self.lorder:] # [B, C, T] + else: + # It's better we just return None if no cache is required, + # However, for JIT export, here we just fake one tensor instead of + # None. + new_cache = paddle.zeros([0, 0, 0], dtype=x.dtype) + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + x = nn.functional.glu(x, axis=1) # (batch, channel, dim) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + if self.use_layer_norm: + x = x.transpose([0, 2, 1]) # [B, T, C] + x = self.activation(self.norm(x)) + if self.use_layer_norm: + x = x.transpose([0, 2, 1]) # [B, C, T] + x = self.pointwise_conv2(x) + + # mask batch padding + if mask_pad.shape[2] > 0: # time > 0 + x = masked_fill(x, mask_pad, 0.0) + + x = x.transpose([0, 2, 1]) # [B, T, C] + return x, new_cache diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index fd7bd7b9a1d..654c1607ad9 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -14,26 +14,28 @@ # limitations under the License. # Modified from wenet(https://github.com/wenet-e2e/wenet) """Encoder definition.""" -from typing import Tuple +from typing import Tuple, Union, Optional, List import paddle from paddle import nn from typeguard import check_argument_types from paddlespeech.s2t.modules.activation import get_activation -from paddlespeech.s2t.modules.align import LayerNorm -from paddlespeech.s2t.modules.attention import MultiHeadedAttention +from paddlespeech.s2t.modules.align import LayerNorm, Linear +from paddlespeech.s2t.modules.attention import MultiHeadedAttention, RelPositionMultiHeadedAttention2 from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule +from paddlespeech.s2t.modules.convolution import ConvolutionModule2 from paddlespeech.s2t.modules.embedding import NoPositionalEncoding from paddlespeech.s2t.modules.embedding import PositionalEncoding from paddlespeech.s2t.modules.embedding import RelPositionalEncoding -from paddlespeech.s2t.modules.encoder_layer import ConformerEncoderLayer +from paddlespeech.s2t.modules.encoder_layer import ConformerEncoderLayer, SqueezeformerEncoderLayer from paddlespeech.s2t.modules.encoder_layer import TransformerEncoderLayer from paddlespeech.s2t.modules.mask import add_optional_chunk_mask from paddlespeech.s2t.modules.mask import make_non_pad_mask -from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward -from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling4 +from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward, PositionwiseFeedForward2 +from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling4, TimeReductionLayerStream, TimeReductionLayer1D, \ + DepthwiseConv2DSubsampling4, TimeReductionLayer2D from paddlespeech.s2t.modules.subsampling import 
Conv2dSubsampling6 from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling8 from paddlespeech.s2t.modules.subsampling import LinearNoSubsampling @@ -487,3 +489,372 @@ def __init__(self, normalize_before=normalize_before, concat_after=concat_after) for _ in range(num_blocks) ]) + + +class SqueezeformerEncoder(nn.Layer): + def __init__( + self, + input_size: int, + encoder_dim: int = 256, + output_size: int = 256, + attention_heads: int = 4, + num_blocks: int = 12, + reduce_idx: Optional[Union[int, List[int]]] = 5, + recover_idx: Optional[Union[int, List[int]]] = 11, + feed_forward_expansion_factor: int = 4, + dw_stride: bool = False, + input_dropout_rate: float = 0.1, + pos_enc_layer_type: str = "rel_pos", + time_reduction_layer_type: str = "conv1d", + do_rel_shift: bool = True, + feed_forward_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.1, + cnn_module_kernel: int = 31, + cnn_norm_type: str = "layer_norm", + dropout: float = 0.1, + causal: bool = False, + adaptive_scale: bool = True, + activation_type: str = "swish", + init_weights: bool = True, + global_cmvn: paddle.nn.Layer = None, + normalize_before: bool = False, + use_dynamic_chunk: bool = False, + concat_after: bool = False, + static_chunk_size: int = 0, + use_dynamic_left_chunk: bool = False + ): + """Construct SqueezeformerEncoder + + Args: + input_size to use_dynamic_chunk, see in Transformer BaseEncoder. + encoder_dim (int): The hidden dimension of encoder layer. + output_size (int): The output dimension of final projection layer. + attention_heads (int): Num of attention head in attention module. + num_blocks (int): Num of encoder layers. + reduce_idx Optional[Union[int, List[int]]]: + reduce layer index, from 40ms to 80ms per frame. + recover_idx Optional[Union[int, List[int]]]: + recover layer index, from 80ms to 40ms per frame. + feed_forward_expansion_factor (int): Enlarge coefficient of FFN. + dw_stride (bool): Whether do depthwise convolution + on subsampling module. + input_dropout_rate (float): Dropout rate of input projection layer. + pos_enc_layer_type (str): Self attention type. + time_reduction_layer_type (str): Conv1d or Conv2d reduction layer. + do_rel_shift (bool): Whether to do relative shift + operation on rel-attention module. + cnn_module_kernel (int): Kernel size of CNN module. + activation_type (str): Encoder activation function type. + cnn_module_kernel (int): Kernel size of convolution module. + adaptive_scale (bool): Whether to use adaptive scale. + init_weights (bool): Whether to initialize weights. + causal (bool): whether to use causal convolution or not. 
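+            cnn_norm_type (str): Normalization type of the convolution
+                module, 'batch_norm' or 'layer_norm'.
+            dropout (float): Dropout rate of each encoder sub-layer output.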
+ """ + assert check_argument_types() + super().__init__() + self.global_cmvn = global_cmvn + self.reduce_idx: Optional[Union[int, List[int]]] = [reduce_idx] \ + if type(reduce_idx) == int else reduce_idx + self.recover_idx: Optional[Union[int, List[int]]] = [recover_idx] \ + if type(recover_idx) == int else recover_idx + self.check_ascending_list() + if reduce_idx is None: + self.time_reduce = None + else: + if recover_idx is None: + self.time_reduce = 'normal' # no recovery at the end + else: + self.time_reduce = 'recover' # recovery at the end + assert len(self.reduce_idx) == len(self.recover_idx) + self.reduce_stride = 2 + self._output_size = output_size + self.normalize_before = normalize_before + self.static_chunk_size = static_chunk_size + self.use_dynamic_chunk = use_dynamic_chunk + self.use_dynamic_left_chunk = use_dynamic_left_chunk + activation = get_activation(activation_type) + + # self-attention module definition + if pos_enc_layer_type != "rel_pos": + encoder_selfattn_layer = MultiHeadedAttention + encoder_selfattn_layer_args = (attention_heads, + output_size, + attention_dropout_rate) + else: + encoder_selfattn_layer = RelPositionMultiHeadedAttention2 + encoder_selfattn_layer_args = (attention_heads, + encoder_dim, + attention_dropout_rate, + do_rel_shift, + adaptive_scale, + init_weights) + + # feed-forward module definition + positionwise_layer = PositionwiseFeedForward2 + positionwise_layer_args = (encoder_dim, + encoder_dim * feed_forward_expansion_factor, + feed_forward_dropout_rate, + activation, + adaptive_scale, + init_weights) + + # convolution module definition + convolution_layer = ConvolutionModule2 + convolution_layer_args = (encoder_dim, cnn_module_kernel, activation, + cnn_norm_type, causal, True, adaptive_scale, init_weights) + + self.embed = DepthwiseConv2DSubsampling4(1, encoder_dim, + RelPositionalEncoding(encoder_dim, dropout_rate=0.1), + dw_stride, + input_size, + input_dropout_rate, + init_weights) + + self.preln = LayerNorm(encoder_dim) + self.encoders = paddle.nn.LayerList([SqueezeformerEncoderLayer( + encoder_dim, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + convolution_layer(*convolution_layer_args), + positionwise_layer(*positionwise_layer_args), + normalize_before, + dropout, + concat_after) for _ in range(num_blocks) + ]) + if time_reduction_layer_type == 'conv1d': + time_reduction_layer = TimeReductionLayer1D + time_reduction_layer_args = { + 'channel': encoder_dim, + 'out_dim': encoder_dim, + } + elif time_reduction_layer_type == 'stream': + time_reduction_layer = TimeReductionLayerStream + time_reduction_layer_args = { + 'channel': encoder_dim, + 'out_dim': encoder_dim, + } + else: + time_reduction_layer = TimeReductionLayer2D + time_reduction_layer_args = {'encoder_dim': encoder_dim} + + self.time_reduction_layer = time_reduction_layer(**time_reduction_layer_args) + self.time_recover_layer = Linear(encoder_dim, encoder_dim) + self.final_proj = None + if output_size != encoder_dim: + self.final_proj = Linear(encoder_dim, output_size) + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs: paddle.Tensor, + xs_lens: paddle.Tensor, + decoding_chunk_size: int = 0, + num_decoding_left_chunks: int = -1, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Embed positions in tensor. 
+ Args: + xs: padded input tensor (B, L, D) + xs_lens: input length (B) + decoding_chunk_size: decoding chunk size for dynamic chunk + 0: default for training, use random dynamic chunk. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + num_decoding_left_chunks: number of left chunks, this is for decoding, + the chunk size is decoding_chunk_size. + >=0: use num_decoding_left_chunks + <0: use all left chunks + Returns: + encoder output tensor, lens and mask + """ + masks = make_non_pad_mask(xs_lens).unsqueeze(1) # (B, 1, L) + + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + xs, pos_emb, masks = self.embed(xs, masks) + mask_pad = ~masks + chunk_masks = add_optional_chunk_mask(xs, masks, + self.use_dynamic_chunk, + self.use_dynamic_left_chunk, + decoding_chunk_size, + self.static_chunk_size, + num_decoding_left_chunks) + xs_lens = chunk_masks.squeeze(1).sum(1) + xs = self.preln(xs) + recover_activations: \ + List[Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]] = [] + index = 0 + for i, layer in enumerate(self.encoders): + if self.reduce_idx is not None: + if self.time_reduce is not None and i in self.reduce_idx: + recover_activations.append((xs, chunk_masks, pos_emb, mask_pad)) + xs, xs_lens, chunk_masks, mask_pad = self.time_reduction_layer(xs, xs_lens, chunk_masks, mask_pad) + pos_emb = pos_emb[:, ::2, :] + index += 1 + + if self.recover_idx is not None: + if self.time_reduce == 'recover' and i in self.recover_idx: + index -= 1 + recover_tensor, recover_chunk_masks, recover_pos_emb, recover_mask_pad = recover_activations[index] + # recover output length for ctc decode + xs = paddle.repeat_interleave(xs, repeats=2, axis=1) + xs = self.time_recover_layer(xs) + recoverd_t = recover_tensor.shape[1] + xs = recover_tensor + xs[:, :recoverd_t, :] + chunk_masks = recover_chunk_masks + pos_emb = recover_pos_emb + mask_pad = recover_mask_pad + + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + + if self.final_proj is not None: + xs = self.final_proj(xs) + return xs, masks + + def check_ascending_list(self): + if self.reduce_idx is not None: + assert self.reduce_idx == sorted(self.reduce_idx), \ + "reduce_idx should be int or ascending list" + if self.recover_idx is not None: + assert self.recover_idx == sorted(self.recover_idx), \ + "recover_idx should be int or ascending list" + + def calculate_downsampling_factor(self, i: int) -> int: + if self.reduce_idx is None: + return 1 + else: + reduce_exp, recover_exp = 0, 0 + for exp, rd_idx in enumerate(self.reduce_idx): + if i >= rd_idx: + reduce_exp = exp + 1 + if self.recover_idx is not None: + for exp, rc_idx in enumerate(self.recover_idx): + if i >= rc_idx: + recover_exp = exp + 1 + return int(2 ** (reduce_exp - recover_exp)) + + def forward_chunk( + self, + xs: paddle.Tensor, + offset: int, + required_cache_size: int, + att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]), + cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]), + att_mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool), + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """ Forward just one chunk + + Args: + xs (paddle.Tensor): chunk input, with shape (b=1, time, mel-dim), + where `time == (chunk_size - 1) * subsample_rate + \ + subsample.right_context + 1` + offset (int): current offset in encoder output time stamp + required_cache_size (int): cache size required for next chunk + compuation + >=0: actual cache size + <0: means all history cache is required + att_cache 
(paddle.Tensor): cache tensor for KEY & VALUE in + transformer/conformer attention, with shape + (elayers, head, cache_t1, d_k * 2), where + `head * d_k == hidden-dim` and + `cache_t1 == chunk_size * num_decoding_left_chunks`. + cnn_cache (paddle.Tensor): cache tensor for cnn_module in conformer, + (elayers, b=1, hidden-dim, cache_t2), where + `cache_t2 == cnn.lorder - 1` + + Returns: + paddle.Tensor: output of current input xs, + with shape (b=1, chunk_size, hidden-dim). + paddle.Tensor: new attention cache required for next chunk, with + dynamic shape (elayers, head, ?, d_k * 2) + depending on required_cache_size. + paddle.Tensor: new conformer cnn cache required for next chunk, with + same shape as the original cnn_cache. + """ + assert xs.shape[0] == 1 # batch size must be one + + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + + # tmp_masks is just for interface compatibility, [B=1, C=1, T] + tmp_masks = paddle.ones([1, 1, xs.shape[1]], dtype=paddle.bool) + # before embed, xs=(B, T, D1), pos_emb=(B=1, T, D) + xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset) + + # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim) + elayers, cache_t1 = att_cache.shape[0], att_cache.shape[2] + chunk_size = xs.shape[1] + attention_key_size = cache_t1 + chunk_size + pos_emb = self.embed.position_encoding(offset=offset - cache_t1, size=attention_key_size) + if required_cache_size < 0: + next_cache_start = 0 + elif required_cache_size == 0: + next_cache_start = attention_key_size + else: + next_cache_start = max(attention_key_size - required_cache_size, 0) + + r_att_cache = [] + r_cnn_cache = [] + + mask_pad = paddle.ones([1, xs.shape[1]], dtype=paddle.bool) + mask_pad = mask_pad.unsqueeze(1) + max_att_len: int = 0 + recover_activations: \ + List[Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]] = [] + index = 0 + xs_lens = paddle.to_tensor([xs.shape[1]], dtype=paddle.int32) + xs = self.preln(xs) + for i, layer in enumerate(self.encoders): + # NOTE(xcsong): Before layer.forward + # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2), + # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2) + if self.reduce_idx is not None: + if self.time_reduce is not None and i in self.reduce_idx: + recover_activations.append((xs, att_mask, pos_emb, mask_pad)) + xs, xs_lens, att_mask, mask_pad = self.time_reduction_layer(xs, xs_lens, att_mask, mask_pad) + pos_emb = pos_emb[:, ::2, :] + index += 1 + + if self.recover_idx is not None: + if self.time_reduce == 'recover' and i in self.recover_idx: + index -= 1 + recover_tensor, recover_att_mask, recover_pos_emb, recover_mask_pad = recover_activations[index] + # recover output length for ctc decode + xs = paddle.repeat_interleave(xs, repeats=2, axis=1) + xs = self.time_recover_layer(xs) + recoverd_t = recover_tensor.shape[1] + xs = recover_tensor + xs[:, :recoverd_t, :] + att_mask = recover_att_mask + pos_emb = recover_pos_emb + mask_pad = recover_mask_pad + + factor = self.calculate_downsampling_factor(i) + att_cache1 = att_cache[i:i + 1][:, :, ::factor, :][:, :, :pos_emb.shape[1] - xs.shape[1], :] + cnn_cache1 = cnn_cache[i] if cnn_cache.shape[0] > 0 else cnn_cache + xs, _, new_att_cache, new_cnn_cache = layer( + xs, + att_mask, + pos_emb, + att_cache=att_cache1, + cnn_cache=cnn_cache1) + # NOTE(xcsong): After layer.forward + # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2), + # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2) + cached_att = new_att_cache[:, :, next_cache_start // factor:, :] + 
cached_cnn = new_cnn_cache.unsqueeze(0)
+            cached_att = cached_att.repeat_interleave(repeats=factor, axis=2)
+            if i == 0:
+                # record length for the first block as max length
+                max_att_len = cached_att.shape[2]
+            r_att_cache.append(cached_att[:, :, :max_att_len, :])
+            r_cnn_cache.append(cached_cnn)
+        # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
+        #   ? may be larger than cache_t1, it depends on required_cache_size
+        r_att_cache = paddle.concat(r_att_cache, axis=0)
+        # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
+        r_cnn_cache = paddle.concat(r_cnn_cache, axis=0)
+
+        if self.final_proj is not None:
+            xs = self.final_proj(xs)
+        return xs, r_att_cache, r_cnn_cache
diff --git a/paddlespeech/s2t/modules/encoder_layer.py b/paddlespeech/s2t/modules/encoder_layer.py
index dac62bce3e3..971afcd8ead 100644
--- a/paddlespeech/s2t/modules/encoder_layer.py
+++ b/paddlespeech/s2t/modules/encoder_layer.py
@@ -276,3 +276,126 @@ def forward(
             x = self.norm_final(x)
 
         return x, mask, new_att_cache, new_cnn_cache
+
+
+class SqueezeformerEncoderLayer(nn.Layer):
+    """Encoder layer module."""
+
+    def __init__(
+            self,
+            size: int,
+            self_attn: paddle.nn.Layer,
+            feed_forward1: Optional[nn.Layer] = None,
+            conv_module: Optional[nn.Layer] = None,
+            feed_forward2: Optional[nn.Layer] = None,
+            normalize_before: bool = False,
+            dropout_rate: float = 0.1,
+            concat_after: bool = False):
+        """Construct an EncoderLayer object.
+
+        Args:
+            size (int): Input dimension.
+            self_attn (paddle.nn.Layer): Self-attention module instance.
+                `MultiHeadedAttention` or `RelPositionMultiHeadedAttention2`
+                instance can be used as the argument.
+            feed_forward1 (paddle.nn.Layer): Feed-forward module instance.
+                `PositionwiseFeedForward2` instance can be used as the argument.
+            conv_module (paddle.nn.Layer): Convolution module instance.
+                `ConvolutionModule2` instance can be used as the argument.
+            feed_forward2 (paddle.nn.Layer): Feed-forward module instance.
+                `PositionwiseFeedForward2` instance can be used as the argument.
+            dropout_rate (float): Dropout rate.
+            normalize_before (bool):
+                True: use layer_norm before each sub-block.
+                False: use layer_norm after each sub-block.
+        """
+        super().__init__()
+        self.size = size
+        self.self_attn = self_attn
+        self.layer_norm1 = LayerNorm(size)
+        self.ffn1 = feed_forward1
+        self.layer_norm2 = LayerNorm(size)
+        self.conv_module = conv_module
+        self.layer_norm3 = LayerNorm(size)
+        self.ffn2 = feed_forward2
+        self.layer_norm4 = LayerNorm(size)
+        self.normalize_before = normalize_before
+        self.dropout = nn.Dropout(dropout_rate)
+        self.concat_after = concat_after
+        if concat_after:
+            self.concat_linear = Linear(size + size, size)
+        else:
+            self.concat_linear = nn.Identity()
+
+    def forward(
+            self,
+            x: paddle.Tensor,
+            mask: paddle.Tensor,
+            pos_emb: paddle.Tensor,
+            mask_pad: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool),
+            att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
+            cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+        """Compute encoded features.
+        Args:
+            x (paddle.Tensor): Input tensor (#batch, time, size).
+            mask (paddle.Tensor): Mask tensor for the input (#batch, time, time).
+                (0, 0, 0) means fake mask.
+            pos_emb (paddle.Tensor): positional encoding, must not be None
+                for SqueezeformerEncoderLayer.
+            mask_pad (paddle.Tensor): batch padding mask used for conv module.
+                (#batch, 1, time), (0, 0, 0) means fake mask.
+            att_cache (paddle.Tensor): Cache tensor of the KEY & VALUE
+                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
+            cnn_cache (paddle.Tensor): Convolution cache in conformer layer
+                (1, #batch=1, size, cache_t2). First dim will not be used, just
+                for dy2st.
+        Returns:
+            paddle.Tensor: Output tensor (#batch, time, size).
+            paddle.Tensor: Mask tensor (#batch, time, time).
+            paddle.Tensor: att_cache tensor,
+                (#batch=1, head, cache_t1 + time, d_k * 2).
+            paddle.Tensor: cnn_cache tensor (#batch, size, cache_t2).
+        """
+        # self attention module
+        residual = x
+        if self.normalize_before:
+            x = self.layer_norm1(x)
+        x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache)
+        if self.concat_after:
+            x_concat = paddle.concat((x, x_att), axis=-1)
+            x = residual + self.concat_linear(x_concat)
+        else:
+            x = residual + self.dropout(x_att)
+        if not self.normalize_before:
+            x = self.layer_norm1(x)
+
+        # ffn module
+        residual = x
+        if self.normalize_before:
+            x = self.layer_norm2(x)
+        x = self.ffn1(x)
+        x = residual + self.dropout(x)
+        if not self.normalize_before:
+            x = self.layer_norm2(x)
+
+        # conv module
+        residual = x
+        if self.normalize_before:
+            x = self.layer_norm3(x)
+        x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
+        x = residual + self.dropout(x)
+        if not self.normalize_before:
+            x = self.layer_norm3(x)
+
+        # ffn module
+        residual = x
+        if self.normalize_before:
+            x = self.layer_norm4(x)
+        x = self.ffn2(x)
+        x = residual + self.dropout(x)
+        if not self.normalize_before:
+            x = self.layer_norm4(x)
+
+        return x, mask, new_att_cache, new_cnn_cache
diff --git a/paddlespeech/s2t/modules/positionwise_feed_forward.py b/paddlespeech/s2t/modules/positionwise_feed_forward.py
index c2725dc5cc4..336199a58e6 100644
--- a/paddlespeech/s2t/modules/positionwise_feed_forward.py
+++ b/paddlespeech/s2t/modules/positionwise_feed_forward.py
@@ -17,6 +17,7 @@
 
 import paddle
 from paddle import nn
+from paddle.nn import initializer as I
 
 from paddlespeech.s2t.modules.align import Linear
 from paddlespeech.s2t.utils.log import Log
@@ -32,7 +33,7 @@ def __init__(self,
                  idim: int,
                  hidden_units: int,
                  dropout_rate: float,
-                 activation: nn.Layer=nn.ReLU()):
+                 activation: nn.Layer = nn.ReLU()):
         """Construct a PositionwiseFeedForward object.
 
         FeedForward are appied on each position of the sequence.
@@ -58,3 +59,61 @@ def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
             output tensor, (B, Lmax, D)
         """
         return self.w_2(self.dropout(self.activation(self.w_1(xs))))
+
+
+class PositionwiseFeedForward2(paddle.nn.Layer):
+    """Positionwise feed forward layer.
+
+    FeedForward is applied on each position of the sequence.
+    The output dim is the same as the input dim.
+
+    Args:
+        idim (int): Input dimension.
+        hidden_units (int): The number of hidden units.
+        dropout_rate (float): Dropout rate.
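+        adaptive_scale (bool): Whether to scale and shift the input with
+            the learnable ada_scale and ada_bias parameters.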
+ activation (paddle.nn.Layer): Activation function + """ + + def __init__(self, + idim: int, + hidden_units: int, + dropout_rate: float, + activation: paddle.nn.Layer = paddle.nn.ReLU(), + adaptive_scale: bool = False, + init_weights: bool = False): + """Construct a PositionwiseFeedForward object.""" + super(PositionwiseFeedForward2, self).__init__() + self.idim = idim + self.hidden_units = hidden_units + self.w_1 = Linear(idim, hidden_units) + self.activation = activation + self.dropout = paddle.nn.Dropout(dropout_rate) + self.w_2 = Linear(hidden_units, idim) + self.adaptive_scale = adaptive_scale + ada_scale = self.create_parameter([1, 1, idim], default_initializer=I.XavierUniform()) + self.add_parameter('ada_scale', ada_scale) + ada_bias = self.create_parameter([1, 1, idim], default_initializer=I.XavierUniform()) + self.add_parameter('ada_bias', ada_bias) + + if init_weights: + self.init_weights() + + def init_weights(self): + ffn1_max = self.idim ** -0.5 + ffn2_max = self.hidden_units ** -0.5 + self.w_1._param_attr = paddle.nn.initializer.Uniform(low=-ffn1_max, high=ffn1_max) + self.w_1._bias_attr = paddle.nn.initializer.Uniform(low=-ffn1_max, high=ffn1_max) + self.w_2._param_attr = paddle.nn.initializer.Uniform(low=-ffn2_max, high=ffn2_max) + self.w_2._bias_attr = paddle.nn.initializer.Uniform(low=-ffn2_max, high=ffn2_max) + + def forward(self, xs: paddle.Tensor) -> paddle.Tensor: + """Forward function. + + Args: + xs: input tensor (B, L, D) + Returns: + output tensor, (B, L, D) + """ + if self.adaptive_scale: + xs = self.ada_scale * xs + self.ada_bias + return self.w_2(self.dropout(self.activation(self.w_1(xs)))) diff --git a/paddlespeech/s2t/modules/subsampling.py b/paddlespeech/s2t/modules/subsampling.py index 782a437ee85..09f92acca8e 100644 --- a/paddlespeech/s2t/modules/subsampling.py +++ b/paddlespeech/s2t/modules/subsampling.py @@ -17,11 +17,14 @@ from typing import Tuple import paddle +import paddle.nn.functional as F from paddle import nn -from paddlespeech.s2t.modules.align import Conv2D +from paddlespeech.s2t import masked_fill +from paddlespeech.s2t.modules.align import Conv2D, Conv1D from paddlespeech.s2t.modules.align import LayerNorm from paddlespeech.s2t.modules.align import Linear +from paddlespeech.s2t.modules.conv2d import Conv2DValid from paddlespeech.s2t.modules.embedding import PositionalEncoding from paddlespeech.s2t.utils.log import Log @@ -249,3 +252,257 @@ def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0 x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, -1, c * f])) x, pos_emb = self.pos_enc(x, offset) return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2] + + +class DepthwiseConv2DSubsampling4(BaseSubsampling): + """Depthwise Convolutional 2D subsampling (to 1/4 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + pos_enc_class (nn.Layer): position encoding class. + dw_stride (int): Whether do depthwise convolution. + input_size (int): filter bank dimension. 
+ + """ + + def __init__( + self, idim: int, odim: int, + pos_enc_class: nn.Layer, + dw_stride: bool = False, + input_size: int = 80, + input_dropout_rate: float = 0.1, + init_weights: bool = True): + super(DepthwiseConv2DSubsampling4, self).__init__() + self.idim = idim + self.odim = odim + self.pw_conv = Conv2D(in_channels=idim, out_channels=odim, kernel_size=3, stride=2) + self.act1 = nn.ReLU() + self.dw_conv = Conv2D(in_channels=odim, out_channels=odim, kernel_size=3, stride=2, + groups=odim if dw_stride else 1) + self.act2 = nn.ReLU() + self.pos_enc = pos_enc_class + self.input_proj = nn.Sequential( + Linear(odim * (((input_size - 1) // 2 - 1) // 2), odim), + nn.Dropout(p=input_dropout_rate)) + if init_weights: + linear_max = (odim * input_size / 4) ** -0.5 + self.input_proj.state_dict()['0.weight'] = paddle.nn.initializer.Uniform(low=-linear_max, high=linear_max) + self.input_proj.state_dict()['0.bias'] = paddle.nn.initializer.Uniform(low=-linear_max, high=linear_max) + + self.subsampling_rate = 4 + # 6 = (3 - 1) * 1 + (3 - 1) * 2 + self.right_context = 6 + + def forward( + self, + x: paddle.Tensor, + x_mask: paddle.Tensor, + offset: int = 0 + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + x = x.unsqueeze(1) # (b, c=1, t, f) + x = self.pw_conv(x) + x = self.act1(x) + x = self.dw_conv(x) + x = self.act2(x) + b, c, t, f = x.shape + x = x.transpose([0, 2, 1, 3]).reshape([b, -1, c * f]) + x, pos_emb = self.pos_enc(x, offset) + x = self.input_proj(x) + return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2] + + +class TimeReductionLayer1D(nn.Layer): + """ + Modified NeMo, + Squeezeformer Time Reduction procedure. + Downsamples the audio by `stride` in the time dimension. + Args: + channel (int): input dimension of + MultiheadAttentionMechanism and PositionwiseFeedForward + out_dim (int): Output dimension of the module. + kernel_size (int): Conv kernel size for + depthwise convolution in convolution module + stride (int): Downsampling factor in time dimension. 
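+
+    The output sequence is trimmed or zero-padded so that its length matches
+    the strided padding mask.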
+ """ + + def __init__(self, channel: int, out_dim: int, kernel_size: int = 5, stride: int = 2): + super(TimeReductionLayer1D, self).__init__() + + self.channel = channel + self.out_dim = out_dim + self.kernel_size = kernel_size + self.stride = stride + self.padding = max(0, self.kernel_size - self.stride) + + self.dw_conv = Conv1D( + in_channels=channel, + out_channels=channel, + kernel_size=kernel_size, + stride=stride, + padding=self.padding, + groups=channel, + ) + + self.pw_conv = Conv1D( + in_channels=channel, out_channels=out_dim, + kernel_size=1, stride=1, padding=0, groups=1, + ) + + self.init_weights() + + def init_weights(self): + dw_max = self.kernel_size ** -0.5 + pw_max = self.channel ** -0.5 + self.dw_conv._param_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max) + self.dw_conv._bias_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max) + self.pw_conv._param_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) + self.pw_conv._bias_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) + + def forward(self, xs, xs_lens: paddle.Tensor, + mask: paddle.Tensor = paddle.ones((0, 0, 0), dtype=paddle.bool), + mask_pad: paddle.Tensor = paddle.ones((0, 0, 0), dtype=paddle.bool), + ): + xs = xs.transpose([0, 2, 1]) # [B, C, T] + xs = masked_fill(xs, mask_pad.equal(0), 0.0) + + xs = self.dw_conv(xs) + xs = self.pw_conv(xs) + + xs = xs.transpose([0, 2, 1]) # [B, T, C] + + B, T, D = xs.shape + mask = mask[:, ::self.stride, ::self.stride] + mask_pad = mask_pad[:, :, ::self.stride] + L = mask_pad.shape[-1] + # For JIT exporting, we remove F.pad operator. + if L - T < 0: + xs = xs[:, :L - T, :] + else: + dummy_pad = paddle.zeros([B, L - T, D], dtype=paddle.float32) + xs = paddle.concat([xs, dummy_pad], axis=1) + + xs_lens = (xs_lens + 1) // 2 + return xs, xs_lens, mask, mask_pad + + +class TimeReductionLayer2D(nn.Layer): + def __init__(self, kernel_size: int = 5, stride: int = 2, encoder_dim: int = 256): + super(TimeReductionLayer2D, self).__init__() + self.encoder_dim = encoder_dim + self.kernel_size = kernel_size + self.dw_conv = Conv2DValid(in_channels=encoder_dim, + out_channels=encoder_dim, + kernel_size=(kernel_size, 1), + stride=stride, + valid_trigy=True) + self.pw_conv = Conv2DValid(in_channels=encoder_dim, + out_channels=encoder_dim, + kernel_size=1, + stride=1, + valid_trigx=False, + valid_trigy=False) + + self.kernel_size = kernel_size + self.stride = stride + self.init_weights() + + def init_weights(self): + dw_max = self.kernel_size ** -0.5 + pw_max = self.encoder_dim ** -0.5 + self.dw_conv._param_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max) + self.dw_conv._bias_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max) + self.pw_conv._param_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) + self.pw_conv._bias_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) + + def forward( + self, xs: paddle.Tensor, xs_lens: paddle.Tensor, + mask: paddle.Tensor = paddle.ones((0, 0, 0), dtype=paddle.bool), + mask_pad: paddle.Tensor = paddle.ones((0, 0, 0), dtype=paddle.bool), + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + xs = masked_fill(xs, mask_pad.transpose([0, 2, 1]).equal(0), 0.0) + xs = xs.unsqueeze(1) + padding1 = self.kernel_size - self.stride + xs = F.pad(xs, (0, 0, 0, 0, 0, padding1, 0, 0), mode='constant', value=0.) 
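+        # [B, 1, T, D] -> [B, D, T, 1]: move features onto the channel axis so
+        # the (kernel_size, 1) depthwise conv slides over the time axis only.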
+ xs = self.dw_conv(xs.transpose([0, 3, 2, 1])) + xs = self.pw_conv(xs).transpose([0, 3, 2, 1]).squeeze(1) + tmp_length = xs.shape[1] + xs_lens = (xs_lens + 1) // 2 + padding2 = max(0, (xs_lens.max() - tmp_length).item()) + batch_size, hidden = xs.shape[0], xs.shape[-1] + dummy_pad = paddle.zeros([batch_size, padding2, hidden], dtype=paddle.float32) + xs = paddle.concat([xs, dummy_pad], axis=1) + mask = mask[:, ::2, ::2] + mask_pad = mask_pad[:, :, ::2] + return xs, xs_lens, mask, mask_pad + + +class TimeReductionLayerStream(nn.Layer): + """ + Squeezeformer Time Reduction procedure. + Downsamples the audio by `stride` in the time dimension. + Args: + channel (int): input dimension of + MultiheadAttentionMechanism and PositionwiseFeedForward + out_dim (int): Output dimension of the module. + kernel_size (int): Conv kernel size for + depthwise convolution in convolution module + stride (int): Downsampling factor in time dimension. + """ + + def __init__(self, channel: int, out_dim: int, + kernel_size: int = 1, stride: int = 2): + super(TimeReductionLayerStream, self).__init__() + + self.channel = channel + self.out_dim = out_dim + self.kernel_size = kernel_size + self.stride = stride + + self.dw_conv = Conv1D(in_channels=channel, + out_channels=channel, + kernel_size=kernel_size, + stride=stride, + padding=0, + groups=channel) + + self.pw_conv = Conv1D(in_channels=channel, + out_channels=out_dim, + kernel_size=1, + stride=1, + padding=0, + groups=1) + self.init_weights() + + def init_weights(self): + dw_max = self.kernel_size ** -0.5 + pw_max = self.channel ** -0.5 + self.dw_conv._param_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max) + self.dw_conv._bias_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max) + self.pw_conv._param_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) + self.pw_conv._bias_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) + + def forward(self, xs, xs_lens: paddle.Tensor, + mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool), + mask_pad: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool)): + xs = xs.transpose([0, 2, 1]) # [B, C, T] + xs = masked_fill(xs, mask_pad.equal(0), 0.0) + + xs = self.dw_conv(xs) + xs = self.pw_conv(xs) + + xs = xs.transpose([0, 2, 1]) # [B, T, C] + + B, T, D = xs.shape + mask = mask[:, ::self.stride, ::self.stride] + mask_pad = mask_pad[:, :, ::self.stride] + L = mask_pad.shape[-1] + # For JIT exporting, we remove F.pad operator. + if L - T < 0: + xs = xs[:, :L - T, :] + else: + dummy_pad = paddle.zeros([B, L - T, D], dtype=paddle.float32) + xs = paddle.concat([xs, dummy_pad], axis=1) + + xs_lens = (xs_lens + 1) // 2 + return xs, xs_lens, mask, mask_pad diff --git a/paddlespeech/s2t/utils/utility.py b/paddlespeech/s2t/utils/utility.py index fdd8c029232..fe2fa9167a5 100644 --- a/paddlespeech/s2t/utils/utility.py +++ b/paddlespeech/s2t/utils/utility.py @@ -130,11 +130,19 @@ def get_subsample(config): Returns: int: subsample rate. 
""" - input_layer = config["encoder_conf"]["input_layer"] - assert input_layer in ["conv2d", "conv2d6", "conv2d8"] + if config['encoder'] == 'squeezeformer': + input_layer = config["encoder_conf"]["time_reduction_layer_type"] + assert input_layer in ["conv2d", "conv1d", "stream"] + else: + input_layer = config["encoder_conf"]["input_layer"] + assert input_layer in ["conv2d", "conv2d6", "conv2d8"] if input_layer == "conv2d": return 4 elif input_layer == "conv2d6": return 6 elif input_layer == "conv2d8": return 8 + elif input_layer == "conv1d": + return 6 + elif input_layer == "stream": + return 8 From 2aa84571c07bce9cb439f83a1c4d4923410013c0 Mon Sep 17 00:00:00 2001 From: yeyupiaoling Date: Tue, 20 Dec 2022 10:38:21 +0800 Subject: [PATCH 2/9] change CodeStyle, test=asr --- paddlespeech/s2t/models/u2/u2.py | 3 ++- paddlespeech/s2t/modules/attention.py | 2 +- paddlespeech/s2t/modules/convolution.py | 2 +- paddlespeech/s2t/modules/encoder.py | 2 +- paddlespeech/s2t/modules/encoder_layer.py | 2 +- paddlespeech/s2t/modules/positionwise_feed_forward.py | 2 +- paddlespeech/s2t/modules/subsampling.py | 3 ++- 7 files changed, 9 insertions(+), 7 deletions(-) diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 7e72a01fd9c..6494b5304c4 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -42,7 +42,8 @@ from paddlespeech.s2t.modules.ctc import CTCDecoderBase from paddlespeech.s2t.modules.decoder import BiTransformerDecoder from paddlespeech.s2t.modules.decoder import TransformerDecoder -from paddlespeech.s2t.modules.encoder import ConformerEncoder, SqueezeformerEncoder +from paddlespeech.s2t.modules.encoder import ConformerEncoder +from paddlespeech.s2t.modules.encoder import SqueezeformerEncoder from paddlespeech.s2t.modules.encoder import TransformerEncoder from paddlespeech.s2t.modules.initializer import DefaultInitializerContext from paddlespeech.s2t.modules.loss import LabelSmoothingLoss diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py index e149c504169..29d26c60cf1 100644 --- a/paddlespeech/s2t/modules/attention.py +++ b/paddlespeech/s2t/modules/attention.py @@ -26,7 +26,7 @@ logger = Log(__name__).getlog() -__all__ = ["MultiHeadedAttention", "RelPositionMultiHeadedAttention"] +__all__ = ["MultiHeadedAttention", "RelPositionMultiHeadedAttention", "RelPositionMultiHeadedAttention2"] # Relative Positional Encodings # https://www.jianshu.com/p/c0608efcc26f diff --git a/paddlespeech/s2t/modules/convolution.py b/paddlespeech/s2t/modules/convolution.py index 0b5ab0356c5..47018676505 100644 --- a/paddlespeech/s2t/modules/convolution.py +++ b/paddlespeech/s2t/modules/convolution.py @@ -5,7 +5,7 @@ from paddle.nn import initializer as I from typeguard import check_argument_types -__all__ = ['ConvolutionModule'] +__all__ = ['ConvolutionModule2'] from paddlespeech.s2t import masked_fill from paddlespeech.s2t.modules.align import Conv1D, BatchNorm1D, LayerNorm diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index 654c1607ad9..6063e95dcf6 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -43,7 +43,7 @@ logger = Log(__name__).getlog() -__all__ = ["BaseEncoder", 'TransformerEncoder', "ConformerEncoder"] +__all__ = ["BaseEncoder", 'TransformerEncoder', "ConformerEncoder", "SqueezeformerEncoder"] class BaseEncoder(nn.Layer): diff --git a/paddlespeech/s2t/modules/encoder_layer.py b/paddlespeech/s2t/modules/encoder_layer.py 
index 971afcd8ead..08304210ae9 100644 --- a/paddlespeech/s2t/modules/encoder_layer.py +++ b/paddlespeech/s2t/modules/encoder_layer.py @@ -26,7 +26,7 @@ logger = Log(__name__).getlog() -__all__ = ["TransformerEncoderLayer", "ConformerEncoderLayer"] +__all__ = ["TransformerEncoderLayer", "ConformerEncoderLayer", "SqueezeformerEncoderLayer"] class TransformerEncoderLayer(nn.Layer): diff --git a/paddlespeech/s2t/modules/positionwise_feed_forward.py b/paddlespeech/s2t/modules/positionwise_feed_forward.py index 336199a58e6..28488d06f8e 100644 --- a/paddlespeech/s2t/modules/positionwise_feed_forward.py +++ b/paddlespeech/s2t/modules/positionwise_feed_forward.py @@ -23,7 +23,7 @@ logger = Log(__name__).getlog() -__all__ = ["PositionwiseFeedForward"] +__all__ = ["PositionwiseFeedForward", "PositionwiseFeedForward2"] class PositionwiseFeedForward(nn.Layer): diff --git a/paddlespeech/s2t/modules/subsampling.py b/paddlespeech/s2t/modules/subsampling.py index 09f92acca8e..97c226150d5 100644 --- a/paddlespeech/s2t/modules/subsampling.py +++ b/paddlespeech/s2t/modules/subsampling.py @@ -32,7 +32,8 @@ __all__ = [ "LinearNoSubsampling", "Conv2dSubsampling4", "Conv2dSubsampling6", - "Conv2dSubsampling8" + "Conv2dSubsampling8", "TimeReductionLayerStream", "TimeReductionLayer1D", + "TimeReductionLayer2D", "DepthwiseConv2DSubsampling4" ] From 34acf5f970203627f05f31d5698d1b7c8fe9da8d Mon Sep 17 00:00:00 2001 From: yeyupiaoling Date: Tue, 20 Dec 2022 10:45:58 +0800 Subject: [PATCH 3/9] change CodeStyle, test=asr --- paddlespeech/s2t/modules/attention.py | 87 ++++--- paddlespeech/s2t/modules/conv2d.py | 74 +++--- paddlespeech/s2t/modules/convolution.py | 49 ++-- paddlespeech/s2t/modules/encoder.py | 190 ++++++++-------- paddlespeech/s2t/modules/encoder_layer.py | 30 +-- .../s2t/modules/positionwise_feed_forward.py | 32 +-- paddlespeech/s2t/modules/subsampling.py | 212 +++++++++++------- 7 files changed, 389 insertions(+), 285 deletions(-) diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py index 29d26c60cf1..6347bdb12ee 100644 --- a/paddlespeech/s2t/modules/attention.py +++ b/paddlespeech/s2t/modules/attention.py @@ -26,7 +26,10 @@ logger = Log(__name__).getlog() -__all__ = ["MultiHeadedAttention", "RelPositionMultiHeadedAttention", "RelPositionMultiHeadedAttention2"] +__all__ = [ + "MultiHeadedAttention", "RelPositionMultiHeadedAttention", + "RelPositionMultiHeadedAttention2" +] # Relative Positional Encodings # https://www.jianshu.com/p/c0608efcc26f @@ -341,7 +344,13 @@ class RelPositionMultiHeadedAttention2(MultiHeadedAttention): dropout_rate (float): Dropout rate. 
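For reference, the subsampling-rate lookup added in PATCH 1/9 can be sketched standalone. The config dicts below are illustrative only, and PATCH 4/9 further down simplifies the squeezeformer branch to a constant 4:

    # Sketch of get_subsample()'s mapping as of PATCH 1/9; illustrative only.
    def subsample_rate_sketch(config: dict) -> int:
        if config['encoder'] == 'squeezeformer':
            layer = config["encoder_conf"]["time_reduction_layer_type"]
            return {"conv2d": 4, "conv1d": 6, "stream": 8}[layer]
        layer = config["encoder_conf"]["input_layer"]
        return {"conv2d": 4, "conv2d6": 6, "conv2d8": 8}[layer]

    # e.g. squeezeformer.yaml above sets time_reduction_layer_type: 'conv2d'
    cfg = {'encoder': 'squeezeformer',
           'encoder_conf': {'time_reduction_layer_type': 'conv2d'}}
    assert subsample_rate_sketch(cfg) == 4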
""" - def __init__(self, n_head, n_feat, dropout_rate, do_rel_shift=False, adaptive_scale=False, init_weights=False): + def __init__(self, + n_head, + n_feat, + dropout_rate, + do_rel_shift=False, + adaptive_scale=False, + init_weights=False): """Construct an RelPositionMultiHeadedAttention object.""" super().__init__(n_head, n_feat, dropout_rate) # linear transformation for positional encoding @@ -349,32 +358,46 @@ def __init__(self, n_head, n_feat, dropout_rate, do_rel_shift=False, adaptive_sc # these two learnable bias are used in matrix c and matrix d # as described in https://arxiv.org/abs/1901.02860 Section 3.3 self.do_rel_shift = do_rel_shift - pos_bias_u = self.create_parameter([self.h, self.d_k], default_initializer=I.XavierUniform()) + pos_bias_u = self.create_parameter( + [self.h, self.d_k], default_initializer=I.XavierUniform()) self.add_parameter('pos_bias_u', pos_bias_u) - pos_bias_v = self.create_parameter([self.h, self.d_k], default_initializer=I.XavierUniform()) + pos_bias_v = self.create_parameter( + [self.h, self.d_k], default_initializer=I.XavierUniform()) self.add_parameter('pos_bias_v', pos_bias_v) self.adaptive_scale = adaptive_scale - ada_scale = self.create_parameter([1, 1, n_feat], default_initializer=I.Constant(1.0)) + ada_scale = self.create_parameter( + [1, 1, n_feat], default_initializer=I.Constant(1.0)) self.add_parameter('ada_scale', ada_scale) - ada_bias = self.create_parameter([1, 1, n_feat], default_initializer=I.Constant(0.0)) + ada_bias = self.create_parameter( + [1, 1, n_feat], default_initializer=I.Constant(0.0)) self.add_parameter('ada_bias', ada_bias) if init_weights: self.init_weights() def init_weights(self): - input_max = (self.h * self.d_k) ** -0.5 - self.linear_q._param_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max) - self.linear_q._bias_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max) - self.linear_k._param_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max) - self.linear_k._bias_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max) - self.linear_v._param_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max) - self.linear_v._bias_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max) - self.linear_pos._param_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max) - self.linear_pos._bias_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max) - self.linear_out._param_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max) - self.linear_out._bias_attr = paddle.nn.initializer.Uniform(low=-input_max, high=input_max) - - def rel_shift(self, x, zero_triu: bool = False): + input_max = (self.h * self.d_k)**-0.5 + self.linear_q._param_attr = paddle.nn.initializer.Uniform( + low=-input_max, high=input_max) + self.linear_q._bias_attr = paddle.nn.initializer.Uniform( + low=-input_max, high=input_max) + self.linear_k._param_attr = paddle.nn.initializer.Uniform( + low=-input_max, high=input_max) + self.linear_k._bias_attr = paddle.nn.initializer.Uniform( + low=-input_max, high=input_max) + self.linear_v._param_attr = paddle.nn.initializer.Uniform( + low=-input_max, high=input_max) + self.linear_v._bias_attr = paddle.nn.initializer.Uniform( + low=-input_max, high=input_max) + self.linear_pos._param_attr = paddle.nn.initializer.Uniform( + low=-input_max, high=input_max) + self.linear_pos._bias_attr = paddle.nn.initializer.Uniform( + low=-input_max, high=input_max) + self.linear_out._param_attr = 
paddle.nn.initializer.Uniform( + low=-input_max, high=input_max) + self.linear_out._bias_attr = paddle.nn.initializer.Uniform( + low=-input_max, high=input_max) + + def rel_shift(self, x, zero_triu: bool=False): """Compute relative positinal encoding. Args: x (paddle.Tensor): Input tensor (batch, head, time1, time1). @@ -383,10 +406,12 @@ def rel_shift(self, x, zero_triu: bool = False): Returns: paddle.Tensor: Output tensor. (batch, head, time1, time1) """ - zero_pad = paddle.zeros([x.shape[0], x.shape[1], x.shape[2], 1], dtype=x.dtype) + zero_pad = paddle.zeros( + [x.shape[0], x.shape[1], x.shape[2], 1], dtype=x.dtype) x_padded = paddle.concat([zero_pad, x], axis=-1) - x_padded = x_padded.reshape([x.shape[0], x.shape[1], x.shape[3] + 1, x.shape[2]]) + x_padded = x_padded.reshape( + [x.shape[0], x.shape[1], x.shape[3] + 1, x.shape[2]]) x = x_padded[:, :, 1:].reshape(paddle.shape(x)) # [B, H, T1, T1] if zero_triu: @@ -395,12 +420,14 @@ def rel_shift(self, x, zero_triu: bool = False): return x - def forward(self, query: paddle.Tensor, - key: paddle.Tensor, value: paddle.Tensor, - mask: paddle.Tensor = paddle.ones((0, 0, 0), dtype=paddle.bool), - pos_emb: paddle.Tensor = paddle.empty([0]), - cache: paddle.Tensor = paddle.zeros((0, 0, 0, 0)) - ) -> Tuple[paddle.Tensor, paddle.Tensor]: + def forward(self, + query: paddle.Tensor, + key: paddle.Tensor, + value: paddle.Tensor, + mask: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool), + pos_emb: paddle.Tensor=paddle.empty([0]), + cache: paddle.Tensor=paddle.zeros( + (0, 0, 0, 0))) -> Tuple[paddle.Tensor, paddle.Tensor]: """Compute 'Scaled Dot Product Attention' with rel. positional encoding. Args: query (paddle.Tensor): Query tensor (#batch, time1, size). @@ -434,7 +461,8 @@ def forward(self, query: paddle.Tensor, new_cache = paddle.concat((k, v), axis=-1) n_batch_pos = pos_emb.shape[0] - p = self.linear_pos(pos_emb).reshape([n_batch_pos, -1, self.h, self.d_k]) + p = self.linear_pos(pos_emb).reshape( + [n_batch_pos, -1, self.h, self.d_k]) p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) # (batch, head, time1, d_k) @@ -460,6 +488,7 @@ def forward(self, query: paddle.Tensor, if self.do_rel_shift: matrix_bd = self.rel_shift(matrix_bd) - scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2) + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k) # (batch, head, time1, time2) return self.forward_attention(v, scores, mask), new_cache diff --git a/paddlespeech/s2t/modules/conv2d.py b/paddlespeech/s2t/modules/conv2d.py index 4b41d80a412..ca6e136ad6c 100644 --- a/paddlespeech/s2t/modules/conv2d.py +++ b/paddlespeech/s2t/modules/conv2d.py @@ -1,4 +1,5 @@ -from typing import Union, Optional +from typing import Optional +from typing import Union import paddle import paddle.nn.functional as F @@ -12,45 +13,50 @@ class Conv2DValid(_ConvNd): Conv2d operator for VALID mode padding. 
""" - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: int = 1, - padding: Union[str, int] = 0, - dilation: int = 1, - groups: int = 1, - padding_mode: str = 'zeros', - weight_attr=None, - bias_attr=None, - data_format="NCHW", - valid_trigx: bool = False, - valid_trigy: bool = False - ) -> None: - super(Conv2DValid, self).__init__(in_channels, - out_channels, - kernel_size, - False, - 2, - stride=stride, - padding=padding, - padding_mode=padding_mode, - dilation=dilation, - groups=groups, - weight_attr=weight_attr, - bias_attr=bias_attr, - data_format=data_format) + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int=1, + padding: Union[str, int]=0, + dilation: int=1, + groups: int=1, + padding_mode: str='zeros', + weight_attr=None, + bias_attr=None, + data_format="NCHW", + valid_trigx: bool=False, + valid_trigy: bool=False) -> None: + super(Conv2DValid, self).__init__( + in_channels, + out_channels, + kernel_size, + False, + 2, + stride=stride, + padding=padding, + padding_mode=padding_mode, + dilation=dilation, + groups=groups, + weight_attr=weight_attr, + bias_attr=bias_attr, + data_format=data_format) self.valid_trigx = valid_trigx self.valid_trigy = valid_trigy - def _conv_forward(self, input: paddle.Tensor, weight: paddle.Tensor, bias: Optional[paddle.Tensor]): + def _conv_forward(self, + input: paddle.Tensor, + weight: paddle.Tensor, + bias: Optional[paddle.Tensor]): validx, validy = 0, 0 if self.valid_trigx: - validx = (input.shape[-2] * (self._stride[-2] - 1) - 1 + self._kernel_size[-2]) // 2 + validx = (input.shape[-2] * + (self._stride[-2] - 1) - 1 + self._kernel_size[-2]) // 2 if self.valid_trigy: - validy = (input.shape[-1] * (self._stride[-1] - 1) - 1 + self._kernel_size[-1]) // 2 - return F.conv2d(input, weight, bias, self._stride, (validx, validy), self._dilation, self._groups) + validy = (input.shape[-1] * + (self._stride[-1] - 1) - 1 + self._kernel_size[-1]) // 2 + return F.conv2d(input, weight, bias, self._stride, (validx, validy), + self._dilation, self._groups) def forward(self, input: paddle.Tensor) -> paddle.Tensor: return self._conv_forward(input, self.weight, self.bias) diff --git a/paddlespeech/s2t/modules/convolution.py b/paddlespeech/s2t/modules/convolution.py index 47018676505..caaa985668a 100644 --- a/paddlespeech/s2t/modules/convolution.py +++ b/paddlespeech/s2t/modules/convolution.py @@ -16,13 +16,13 @@ class ConvolutionModule2(nn.Layer): def __init__(self, channels: int, - kernel_size: int = 15, - activation: nn.Layer = nn.ReLU(), - norm: str = "batch_norm", - causal: bool = False, - bias: bool = True, - adaptive_scale: bool = False, - init_weights: bool = False): + kernel_size: int=15, + activation: nn.Layer=nn.ReLU(), + norm: str="batch_norm", + causal: bool=False, + bias: bool=True, + adaptive_scale: bool=False, + init_weights: bool=False): """Construct an ConvolutionModule object. Args: channels (int): The number of channels of conv layers. 
@@ -35,9 +35,11 @@ def __init__(self, self.channels = channels self.kernel_size = kernel_size self.adaptive_scale = adaptive_scale - ada_scale = self.create_parameter([1, 1, channels], default_initializer=I.Constant(1.0)) + ada_scale = self.create_parameter( + [1, 1, channels], default_initializer=I.Constant(1.0)) self.add_parameter('ada_scale', ada_scale) - ada_bias = self.create_parameter([1, 1, channels], default_initializer=I.Constant(0.0)) + ada_bias = self.create_parameter( + [1, 1, channels], default_initializer=I.Constant(0.0)) self.add_parameter('ada_bias', ada_bias) self.pointwise_conv1 = Conv1D( @@ -96,23 +98,29 @@ def __init__(self, self.init_weights() def init_weights(self): - pw_max = self.channels ** -0.5 - dw_max = self.kernel_size ** -0.5 - self.pointwise_conv1._param_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) + pw_max = self.channels**-0.5 + dw_max = self.kernel_size**-0.5 + self.pointwise_conv1._param_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) if self.bias: - self.pointwise_conv1._bias_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) - self.depthwise_conv._param_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max) + self.pointwise_conv1._bias_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) + self.depthwise_conv._param_attr = paddle.nn.initializer.Uniform( + low=-dw_max, high=dw_max) if self.bias: - self.depthwise_conv._bias_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max) - self.pointwise_conv2._param_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) + self.depthwise_conv._bias_attr = paddle.nn.initializer.Uniform( + low=-dw_max, high=dw_max) + self.pointwise_conv2._param_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) if self.bias: - self.pointwise_conv2._bias_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) + self.pointwise_conv2._bias_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) def forward( self, x: paddle.Tensor, - mask_pad: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool), - cache: paddle.Tensor = paddle.zeros([0, 0, 0]), + mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), + cache: paddle.Tensor=paddle.zeros([0, 0, 0]), ) -> Tuple[paddle.Tensor, paddle.Tensor]: """Compute convolution module. Args: @@ -137,7 +145,8 @@ def forward( if self.lorder > 0: if cache.shape[2] == 0: # cache_t == 0 - x = nn.functional.pad(x, [self.lorder, 0], 'constant', 0.0, data_format='NCL') + x = nn.functional.pad( + x, [self.lorder, 0], 'constant', 0.0, data_format='NCL') else: assert cache.shape[0] == x.shape[0] # B assert cache.shape[1] == x.shape[1] # C diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index 6063e95dcf6..f19ecfe4171 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -14,36 +14,49 @@ # limitations under the License. 
# Modified from wenet(https://github.com/wenet-e2e/wenet) """Encoder definition.""" -from typing import Tuple, Union, Optional, List +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union import paddle from paddle import nn from typeguard import check_argument_types from paddlespeech.s2t.modules.activation import get_activation -from paddlespeech.s2t.modules.align import LayerNorm, Linear -from paddlespeech.s2t.modules.attention import MultiHeadedAttention, RelPositionMultiHeadedAttention2 +from paddlespeech.s2t.modules.align import LayerNorm +from paddlespeech.s2t.modules.align import Linear +from paddlespeech.s2t.modules.attention import MultiHeadedAttention from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention +from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention2 from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule from paddlespeech.s2t.modules.convolution import ConvolutionModule2 from paddlespeech.s2t.modules.embedding import NoPositionalEncoding from paddlespeech.s2t.modules.embedding import PositionalEncoding from paddlespeech.s2t.modules.embedding import RelPositionalEncoding -from paddlespeech.s2t.modules.encoder_layer import ConformerEncoderLayer, SqueezeformerEncoderLayer +from paddlespeech.s2t.modules.encoder_layer import ConformerEncoderLayer +from paddlespeech.s2t.modules.encoder_layer import SqueezeformerEncoderLayer from paddlespeech.s2t.modules.encoder_layer import TransformerEncoderLayer from paddlespeech.s2t.modules.mask import add_optional_chunk_mask from paddlespeech.s2t.modules.mask import make_non_pad_mask -from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward, PositionwiseFeedForward2 -from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling4, TimeReductionLayerStream, TimeReductionLayer1D, \ - DepthwiseConv2DSubsampling4, TimeReductionLayer2D +from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward +from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward2 +from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling4 from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling6 from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling8 +from paddlespeech.s2t.modules.subsampling import DepthwiseConv2DSubsampling4 from paddlespeech.s2t.modules.subsampling import LinearNoSubsampling +from paddlespeech.s2t.modules.subsampling import TimeReductionLayer1D +from paddlespeech.s2t.modules.subsampling import TimeReductionLayer2D +from paddlespeech.s2t.modules.subsampling import TimeReductionLayerStream from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() -__all__ = ["BaseEncoder", 'TransformerEncoder', "ConformerEncoder", "SqueezeformerEncoder"] +__all__ = [ + "BaseEncoder", 'TransformerEncoder', "ConformerEncoder", + "SqueezeformerEncoder" +] class BaseEncoder(nn.Layer): @@ -492,37 +505,35 @@ def __init__(self, class SqueezeformerEncoder(nn.Layer): - def __init__( - self, - input_size: int, - encoder_dim: int = 256, - output_size: int = 256, - attention_heads: int = 4, - num_blocks: int = 12, - reduce_idx: Optional[Union[int, List[int]]] = 5, - recover_idx: Optional[Union[int, List[int]]] = 11, - feed_forward_expansion_factor: int = 4, - dw_stride: bool = False, - input_dropout_rate: float = 0.1, - pos_enc_layer_type: str = "rel_pos", - time_reduction_layer_type: str = "conv1d", - do_rel_shift: bool = 
True, - feed_forward_dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.1, - cnn_module_kernel: int = 31, - cnn_norm_type: str = "layer_norm", - dropout: float = 0.1, - causal: bool = False, - adaptive_scale: bool = True, - activation_type: str = "swish", - init_weights: bool = True, - global_cmvn: paddle.nn.Layer = None, - normalize_before: bool = False, - use_dynamic_chunk: bool = False, - concat_after: bool = False, - static_chunk_size: int = 0, - use_dynamic_left_chunk: bool = False - ): + def __init__(self, + input_size: int, + encoder_dim: int=256, + output_size: int=256, + attention_heads: int=4, + num_blocks: int=12, + reduce_idx: Optional[Union[int, List[int]]]=5, + recover_idx: Optional[Union[int, List[int]]]=11, + feed_forward_expansion_factor: int=4, + dw_stride: bool=False, + input_dropout_rate: float=0.1, + pos_enc_layer_type: str="rel_pos", + time_reduction_layer_type: str="conv1d", + do_rel_shift: bool=True, + feed_forward_dropout_rate: float=0.1, + attention_dropout_rate: float=0.1, + cnn_module_kernel: int=31, + cnn_norm_type: str="layer_norm", + dropout: float=0.1, + causal: bool=False, + adaptive_scale: bool=True, + activation_type: str="swish", + init_weights: bool=True, + global_cmvn: paddle.nn.Layer=None, + normalize_before: bool=False, + use_dynamic_chunk: bool=False, + concat_after: bool=False, + static_chunk_size: int=0, + use_dynamic_left_chunk: bool=False): """Construct SqueezeformerEncoder Args: @@ -577,49 +588,40 @@ def __init__( # self-attention module definition if pos_enc_layer_type != "rel_pos": encoder_selfattn_layer = MultiHeadedAttention - encoder_selfattn_layer_args = (attention_heads, - output_size, + encoder_selfattn_layer_args = (attention_heads, output_size, attention_dropout_rate) else: encoder_selfattn_layer = RelPositionMultiHeadedAttention2 - encoder_selfattn_layer_args = (attention_heads, - encoder_dim, - attention_dropout_rate, - do_rel_shift, - adaptive_scale, - init_weights) + encoder_selfattn_layer_args = (attention_heads, encoder_dim, + attention_dropout_rate, do_rel_shift, + adaptive_scale, init_weights) # feed-forward module definition positionwise_layer = PositionwiseFeedForward2 - positionwise_layer_args = (encoder_dim, - encoder_dim * feed_forward_expansion_factor, - feed_forward_dropout_rate, - activation, - adaptive_scale, - init_weights) + positionwise_layer_args = ( + encoder_dim, encoder_dim * feed_forward_expansion_factor, + feed_forward_dropout_rate, activation, adaptive_scale, init_weights) # convolution module definition convolution_layer = ConvolutionModule2 convolution_layer_args = (encoder_dim, cnn_module_kernel, activation, - cnn_norm_type, causal, True, adaptive_scale, init_weights) + cnn_norm_type, causal, True, adaptive_scale, + init_weights) - self.embed = DepthwiseConv2DSubsampling4(1, encoder_dim, - RelPositionalEncoding(encoder_dim, dropout_rate=0.1), - dw_stride, - input_size, - input_dropout_rate, - init_weights) + self.embed = DepthwiseConv2DSubsampling4( + 1, encoder_dim, + RelPositionalEncoding(encoder_dim, dropout_rate=0.1), dw_stride, + input_size, input_dropout_rate, init_weights) self.preln = LayerNorm(encoder_dim) - self.encoders = paddle.nn.LayerList([SqueezeformerEncoderLayer( - encoder_dim, - encoder_selfattn_layer(*encoder_selfattn_layer_args), - positionwise_layer(*positionwise_layer_args), - convolution_layer(*convolution_layer_args), - positionwise_layer(*positionwise_layer_args), - normalize_before, - dropout, - concat_after) for _ in range(num_blocks) + self.encoders = 
paddle.nn.LayerList([ + SqueezeformerEncoderLayer( + encoder_dim, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + convolution_layer(*convolution_layer_args), + positionwise_layer(*positionwise_layer_args), normalize_before, + dropout, concat_after) for _ in range(num_blocks) ]) if time_reduction_layer_type == 'conv1d': time_reduction_layer = TimeReductionLayer1D @@ -637,7 +639,8 @@ def __init__( time_reduction_layer = TimeReductionLayer2D time_reduction_layer_args = {'encoder_dim': encoder_dim} - self.time_reduction_layer = time_reduction_layer(**time_reduction_layer_args) + self.time_reduction_layer = time_reduction_layer( + **time_reduction_layer_args) self.time_recover_layer = Linear(encoder_dim, encoder_dim) self.final_proj = None if output_size != encoder_dim: @@ -650,8 +653,8 @@ def forward( self, xs: paddle.Tensor, xs_lens: paddle.Tensor, - decoding_chunk_size: int = 0, - num_decoding_left_chunks: int = -1, + decoding_chunk_size: int=0, + num_decoding_left_chunks: int=-1, ) -> Tuple[paddle.Tensor, paddle.Tensor]: """Embed positions in tensor. Args: @@ -674,12 +677,10 @@ def forward( xs = self.global_cmvn(xs) xs, pos_emb, masks = self.embed(xs, masks) mask_pad = ~masks - chunk_masks = add_optional_chunk_mask(xs, masks, - self.use_dynamic_chunk, - self.use_dynamic_left_chunk, - decoding_chunk_size, - self.static_chunk_size, - num_decoding_left_chunks) + chunk_masks = add_optional_chunk_mask( + xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk, + decoding_chunk_size, self.static_chunk_size, + num_decoding_left_chunks) xs_lens = chunk_masks.squeeze(1).sum(1) xs = self.preln(xs) recover_activations: \ @@ -688,15 +689,18 @@ def forward( for i, layer in enumerate(self.encoders): if self.reduce_idx is not None: if self.time_reduce is not None and i in self.reduce_idx: - recover_activations.append((xs, chunk_masks, pos_emb, mask_pad)) - xs, xs_lens, chunk_masks, mask_pad = self.time_reduction_layer(xs, xs_lens, chunk_masks, mask_pad) + recover_activations.append( + (xs, chunk_masks, pos_emb, mask_pad)) + xs, xs_lens, chunk_masks, mask_pad = self.time_reduction_layer( + xs, xs_lens, chunk_masks, mask_pad) pos_emb = pos_emb[:, ::2, :] index += 1 if self.recover_idx is not None: if self.time_reduce == 'recover' and i in self.recover_idx: index -= 1 - recover_tensor, recover_chunk_masks, recover_pos_emb, recover_mask_pad = recover_activations[index] + recover_tensor, recover_chunk_masks, recover_pos_emb, recover_mask_pad = recover_activations[ + index] # recover output length for ctc decode xs = paddle.repeat_interleave(xs, repeats=2, axis=1) xs = self.time_recover_layer(xs) @@ -732,16 +736,16 @@ def calculate_downsampling_factor(self, i: int) -> int: for exp, rc_idx in enumerate(self.recover_idx): if i >= rc_idx: recover_exp = exp + 1 - return int(2 ** (reduce_exp - recover_exp)) + return int(2**(reduce_exp - recover_exp)) def forward_chunk( self, xs: paddle.Tensor, offset: int, required_cache_size: int, - att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]), - cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]), - att_mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool), + att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), + cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), + att_mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: """ Forward just one chunk @@ -786,7 +790,8 @@ def forward_chunk( elayers, cache_t1 = att_cache.shape[0], 
att_cache.shape[2] chunk_size = xs.shape[1] attention_key_size = cache_t1 + chunk_size - pos_emb = self.embed.position_encoding(offset=offset - cache_t1, size=attention_key_size) + pos_emb = self.embed.position_encoding( + offset=offset - cache_t1, size=attention_key_size) if required_cache_size < 0: next_cache_start = 0 elif required_cache_size == 0: @@ -811,15 +816,18 @@ def forward_chunk( # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2) if self.reduce_idx is not None: if self.time_reduce is not None and i in self.reduce_idx: - recover_activations.append((xs, att_mask, pos_emb, mask_pad)) - xs, xs_lens, att_mask, mask_pad = self.time_reduction_layer(xs, xs_lens, att_mask, mask_pad) + recover_activations.append( + (xs, att_mask, pos_emb, mask_pad)) + xs, xs_lens, att_mask, mask_pad = self.time_reduction_layer( + xs, xs_lens, att_mask, mask_pad) pos_emb = pos_emb[:, ::2, :] index += 1 if self.recover_idx is not None: if self.time_reduce == 'recover' and i in self.recover_idx: index -= 1 - recover_tensor, recover_att_mask, recover_pos_emb, recover_mask_pad = recover_activations[index] + recover_tensor, recover_att_mask, recover_pos_emb, recover_mask_pad = recover_activations[ + index] # recover output length for ctc decode xs = paddle.repeat_interleave(xs, repeats=2, axis=1) xs = self.time_recover_layer(xs) @@ -830,7 +838,9 @@ def forward_chunk( mask_pad = recover_mask_pad factor = self.calculate_downsampling_factor(i) - att_cache1 = att_cache[i:i + 1][:, :, ::factor, :][:, :, :pos_emb.shape[1] - xs.shape[1], :] + att_cache1 = att_cache[ + i:i + 1][:, :, ::factor, :][:, :, :pos_emb.shape[1] - xs.shape[ + 1], :] cnn_cache1 = cnn_cache[i] if cnn_cache.shape[0] > 0 else cnn_cache xs, _, new_att_cache, new_cnn_cache = layer( xs, diff --git a/paddlespeech/s2t/modules/encoder_layer.py b/paddlespeech/s2t/modules/encoder_layer.py index 08304210ae9..ecba95e85bb 100644 --- a/paddlespeech/s2t/modules/encoder_layer.py +++ b/paddlespeech/s2t/modules/encoder_layer.py @@ -26,7 +26,10 @@ logger = Log(__name__).getlog() -__all__ = ["TransformerEncoderLayer", "ConformerEncoderLayer", "SqueezeformerEncoderLayer"] +__all__ = [ + "TransformerEncoderLayer", "ConformerEncoderLayer", + "SqueezeformerEncoderLayer" +] class TransformerEncoderLayer(nn.Layer): @@ -281,16 +284,15 @@ def forward( class SqueezeformerEncoderLayer(nn.Layer): """Encoder layer module.""" - def __init__( - self, - size: int, - self_attn: paddle.nn.Layer, - feed_forward1: Optional[nn.Layer] = None, - conv_module: Optional[nn.Layer] = None, - feed_forward2: Optional[nn.Layer] = None, - normalize_before: bool = False, - dropout_rate: float = 0.1, - concat_after: bool = False): + def __init__(self, + size: int, + self_attn: paddle.nn.Layer, + feed_forward1: Optional[nn.Layer]=None, + conv_module: Optional[nn.Layer]=None, + feed_forward2: Optional[nn.Layer]=None, + normalize_before: bool=False, + dropout_rate: float=0.1, + concat_after: bool=False): """Construct an EncoderLayer object. Args: @@ -332,9 +334,9 @@ def forward( x: paddle.Tensor, mask: paddle.Tensor, pos_emb: paddle.Tensor, - mask_pad: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool), - att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]), - cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]), + mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), + att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), + cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: """Compute encoded features. 
Args: diff --git a/paddlespeech/s2t/modules/positionwise_feed_forward.py b/paddlespeech/s2t/modules/positionwise_feed_forward.py index 28488d06f8e..39d8b1893e4 100644 --- a/paddlespeech/s2t/modules/positionwise_feed_forward.py +++ b/paddlespeech/s2t/modules/positionwise_feed_forward.py @@ -16,8 +16,8 @@ """Positionwise feed forward layer definition.""" import paddle from paddle import nn - from paddle.nn import initializer as I + from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.utils.log import Log @@ -33,7 +33,7 @@ def __init__(self, idim: int, hidden_units: int, dropout_rate: float, - activation: nn.Layer = nn.ReLU()): + activation: nn.Layer=nn.ReLU()): """Construct a PositionwiseFeedForward object. FeedForward are appied on each position of the sequence. @@ -78,9 +78,9 @@ def __init__(self, idim: int, hidden_units: int, dropout_rate: float, - activation: paddle.nn.Layer = paddle.nn.ReLU(), - adaptive_scale: bool = False, - init_weights: bool = False): + activation: paddle.nn.Layer=paddle.nn.ReLU(), + adaptive_scale: bool=False, + init_weights: bool=False): """Construct a PositionwiseFeedForward object.""" super(PositionwiseFeedForward2, self).__init__() self.idim = idim @@ -90,21 +90,27 @@ def __init__(self, self.dropout = paddle.nn.Dropout(dropout_rate) self.w_2 = Linear(hidden_units, idim) self.adaptive_scale = adaptive_scale - ada_scale = self.create_parameter([1, 1, idim], default_initializer=I.XavierUniform()) + ada_scale = self.create_parameter( + [1, 1, idim], default_initializer=I.XavierUniform()) self.add_parameter('ada_scale', ada_scale) - ada_bias = self.create_parameter([1, 1, idim], default_initializer=I.XavierUniform()) + ada_bias = self.create_parameter( + [1, 1, idim], default_initializer=I.XavierUniform()) self.add_parameter('ada_bias', ada_bias) if init_weights: self.init_weights() def init_weights(self): - ffn1_max = self.idim ** -0.5 - ffn2_max = self.hidden_units ** -0.5 - self.w_1._param_attr = paddle.nn.initializer.Uniform(low=-ffn1_max, high=ffn1_max) - self.w_1._bias_attr = paddle.nn.initializer.Uniform(low=-ffn1_max, high=ffn1_max) - self.w_2._param_attr = paddle.nn.initializer.Uniform(low=-ffn2_max, high=ffn2_max) - self.w_2._bias_attr = paddle.nn.initializer.Uniform(low=-ffn2_max, high=ffn2_max) + ffn1_max = self.idim**-0.5 + ffn2_max = self.hidden_units**-0.5 + self.w_1._param_attr = paddle.nn.initializer.Uniform( + low=-ffn1_max, high=ffn1_max) + self.w_1._bias_attr = paddle.nn.initializer.Uniform( + low=-ffn1_max, high=ffn1_max) + self.w_2._param_attr = paddle.nn.initializer.Uniform( + low=-ffn2_max, high=ffn2_max) + self.w_2._bias_attr = paddle.nn.initializer.Uniform( + low=-ffn2_max, high=ffn2_max) def forward(self, xs: paddle.Tensor) -> paddle.Tensor: """Forward function. 
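Before moving on to the subsampling changes, a minimal sketch of what PositionwiseFeedForward2 computes per position, given the definitions above. The shapes and the swish activation are assumed for illustration, and the plain tensors stand in for the module's learned create_parameter() weights:

    import paddle
    import paddle.nn.functional as F

    idim, hidden = 256, 1024                # hidden = idim * expansion factor 4
    x = paddle.randn([2, 50, idim])         # (batch, time, idim), made up
    ada_scale = paddle.ones([1, 1, idim])   # stand-ins for the parameters
    ada_bias = paddle.zeros([1, 1, idim])
    w_1 = paddle.nn.Linear(idim, hidden)
    w_2 = paddle.nn.Linear(hidden, idim)

    y = ada_scale * x + ada_bias            # adaptive scale, if enabled
    y = w_2(F.dropout(F.silu(w_1(y)), p=0.1))   # silu == swish
    assert y.shape == [2, 50, idim]

The uniform init bounds in init_weights follow the same fan-in rule used elsewhere in this series: idim ** -0.5 for w_1 and hidden_units ** -0.5 for w_2.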
diff --git a/paddlespeech/s2t/modules/subsampling.py b/paddlespeech/s2t/modules/subsampling.py index 97c226150d5..51322d324f2 100644 --- a/paddlespeech/s2t/modules/subsampling.py +++ b/paddlespeech/s2t/modules/subsampling.py @@ -21,7 +21,8 @@ from paddle import nn from paddlespeech.s2t import masked_fill -from paddlespeech.s2t.modules.align import Conv2D, Conv1D +from paddlespeech.s2t.modules.align import Conv1D +from paddlespeech.s2t.modules.align import Conv2D from paddlespeech.s2t.modules.align import LayerNorm from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.modules.conv2d import Conv2DValid @@ -267,40 +268,46 @@ class DepthwiseConv2DSubsampling4(BaseSubsampling): """ - def __init__( - self, idim: int, odim: int, - pos_enc_class: nn.Layer, - dw_stride: bool = False, - input_size: int = 80, - input_dropout_rate: float = 0.1, - init_weights: bool = True): + def __init__(self, + idim: int, + odim: int, + pos_enc_class: nn.Layer, + dw_stride: bool=False, + input_size: int=80, + input_dropout_rate: float=0.1, + init_weights: bool=True): super(DepthwiseConv2DSubsampling4, self).__init__() self.idim = idim self.odim = odim - self.pw_conv = Conv2D(in_channels=idim, out_channels=odim, kernel_size=3, stride=2) + self.pw_conv = Conv2D( + in_channels=idim, out_channels=odim, kernel_size=3, stride=2) self.act1 = nn.ReLU() - self.dw_conv = Conv2D(in_channels=odim, out_channels=odim, kernel_size=3, stride=2, - groups=odim if dw_stride else 1) + self.dw_conv = Conv2D( + in_channels=odim, + out_channels=odim, + kernel_size=3, + stride=2, + groups=odim if dw_stride else 1) self.act2 = nn.ReLU() self.pos_enc = pos_enc_class self.input_proj = nn.Sequential( Linear(odim * (((input_size - 1) // 2 - 1) // 2), odim), nn.Dropout(p=input_dropout_rate)) if init_weights: - linear_max = (odim * input_size / 4) ** -0.5 - self.input_proj.state_dict()['0.weight'] = paddle.nn.initializer.Uniform(low=-linear_max, high=linear_max) - self.input_proj.state_dict()['0.bias'] = paddle.nn.initializer.Uniform(low=-linear_max, high=linear_max) + linear_max = (odim * input_size / 4)**-0.5 + self.input_proj.state_dict()[ + '0.weight'] = paddle.nn.initializer.Uniform( + low=-linear_max, high=linear_max) + self.input_proj.state_dict()[ + '0.bias'] = paddle.nn.initializer.Uniform( + low=-linear_max, high=linear_max) self.subsampling_rate = 4 # 6 = (3 - 1) * 1 + (3 - 1) * 2 self.right_context = 6 - def forward( - self, - x: paddle.Tensor, - x_mask: paddle.Tensor, - offset: int = 0 - ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0 + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: x = x.unsqueeze(1) # (b, c=1, t, f) x = self.pw_conv(x) x = self.act1(x) @@ -327,7 +334,11 @@ class TimeReductionLayer1D(nn.Layer): stride (int): Downsampling factor in time dimension. 
""" - def __init__(self, channel: int, out_dim: int, kernel_size: int = 5, stride: int = 2): + def __init__(self, + channel: int, + out_dim: int, + kernel_size: int=5, + stride: int=2): super(TimeReductionLayer1D, self).__init__() self.channel = channel @@ -342,28 +353,37 @@ def __init__(self, channel: int, out_dim: int, kernel_size: int = 5, stride: int kernel_size=kernel_size, stride=stride, padding=self.padding, - groups=channel, - ) + groups=channel, ) self.pw_conv = Conv1D( - in_channels=channel, out_channels=out_dim, - kernel_size=1, stride=1, padding=0, groups=1, - ) + in_channels=channel, + out_channels=out_dim, + kernel_size=1, + stride=1, + padding=0, + groups=1, ) self.init_weights() def init_weights(self): - dw_max = self.kernel_size ** -0.5 - pw_max = self.channel ** -0.5 - self.dw_conv._param_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max) - self.dw_conv._bias_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max) - self.pw_conv._param_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) - self.pw_conv._bias_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) - - def forward(self, xs, xs_lens: paddle.Tensor, - mask: paddle.Tensor = paddle.ones((0, 0, 0), dtype=paddle.bool), - mask_pad: paddle.Tensor = paddle.ones((0, 0, 0), dtype=paddle.bool), - ): + dw_max = self.kernel_size**-0.5 + pw_max = self.channel**-0.5 + self.dw_conv._param_attr = paddle.nn.initializer.Uniform( + low=-dw_max, high=dw_max) + self.dw_conv._bias_attr = paddle.nn.initializer.Uniform( + low=-dw_max, high=dw_max) + self.pw_conv._param_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) + self.pw_conv._bias_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) + + def forward( + self, + xs, + xs_lens: paddle.Tensor, + mask: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool), + mask_pad: paddle.Tensor=paddle.ones((0, 0, 0), + dtype=paddle.bool), ): xs = xs.transpose([0, 2, 1]) # [B, C, T] xs = masked_fill(xs, mask_pad.equal(0), 0.0) @@ -388,50 +408,60 @@ def forward(self, xs, xs_lens: paddle.Tensor, class TimeReductionLayer2D(nn.Layer): - def __init__(self, kernel_size: int = 5, stride: int = 2, encoder_dim: int = 256): + def __init__(self, kernel_size: int=5, stride: int=2, encoder_dim: int=256): super(TimeReductionLayer2D, self).__init__() self.encoder_dim = encoder_dim self.kernel_size = kernel_size - self.dw_conv = Conv2DValid(in_channels=encoder_dim, - out_channels=encoder_dim, - kernel_size=(kernel_size, 1), - stride=stride, - valid_trigy=True) - self.pw_conv = Conv2DValid(in_channels=encoder_dim, - out_channels=encoder_dim, - kernel_size=1, - stride=1, - valid_trigx=False, - valid_trigy=False) + self.dw_conv = Conv2DValid( + in_channels=encoder_dim, + out_channels=encoder_dim, + kernel_size=(kernel_size, 1), + stride=stride, + valid_trigy=True) + self.pw_conv = Conv2DValid( + in_channels=encoder_dim, + out_channels=encoder_dim, + kernel_size=1, + stride=1, + valid_trigx=False, + valid_trigy=False) self.kernel_size = kernel_size self.stride = stride self.init_weights() def init_weights(self): - dw_max = self.kernel_size ** -0.5 - pw_max = self.encoder_dim ** -0.5 - self.dw_conv._param_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max) - self.dw_conv._bias_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max) - self.pw_conv._param_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) - self.pw_conv._bias_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) + dw_max = self.kernel_size**-0.5 
+ pw_max = self.encoder_dim**-0.5 + self.dw_conv._param_attr = paddle.nn.initializer.Uniform( + low=-dw_max, high=dw_max) + self.dw_conv._bias_attr = paddle.nn.initializer.Uniform( + low=-dw_max, high=dw_max) + self.pw_conv._param_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) + self.pw_conv._bias_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) def forward( - self, xs: paddle.Tensor, xs_lens: paddle.Tensor, - mask: paddle.Tensor = paddle.ones((0, 0, 0), dtype=paddle.bool), - mask_pad: paddle.Tensor = paddle.ones((0, 0, 0), dtype=paddle.bool), + self, + xs: paddle.Tensor, + xs_lens: paddle.Tensor, + mask: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool), + mask_pad: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool), ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: xs = masked_fill(xs, mask_pad.transpose([0, 2, 1]).equal(0), 0.0) xs = xs.unsqueeze(1) padding1 = self.kernel_size - self.stride - xs = F.pad(xs, (0, 0, 0, 0, 0, padding1, 0, 0), mode='constant', value=0.) + xs = F.pad( + xs, (0, 0, 0, 0, 0, padding1, 0, 0), mode='constant', value=0.) xs = self.dw_conv(xs.transpose([0, 3, 2, 1])) xs = self.pw_conv(xs).transpose([0, 3, 2, 1]).squeeze(1) tmp_length = xs.shape[1] xs_lens = (xs_lens + 1) // 2 padding2 = max(0, (xs_lens.max() - tmp_length).item()) batch_size, hidden = xs.shape[0], xs.shape[-1] - dummy_pad = paddle.zeros([batch_size, padding2, hidden], dtype=paddle.float32) + dummy_pad = paddle.zeros( + [batch_size, padding2, hidden], dtype=paddle.float32) xs = paddle.concat([xs, dummy_pad], axis=1) mask = mask[:, ::2, ::2] mask_pad = mask_pad[:, :, ::2] @@ -451,8 +481,11 @@ class TimeReductionLayerStream(nn.Layer): stride (int): Downsampling factor in time dimension. """ - def __init__(self, channel: int, out_dim: int, - kernel_size: int = 1, stride: int = 2): + def __init__(self, + channel: int, + out_dim: int, + kernel_size: int=1, + stride: int=2): super(TimeReductionLayerStream, self).__init__() self.channel = channel @@ -460,32 +493,41 @@ def __init__(self, channel: int, out_dim: int, self.kernel_size = kernel_size self.stride = stride - self.dw_conv = Conv1D(in_channels=channel, - out_channels=channel, - kernel_size=kernel_size, - stride=stride, - padding=0, - groups=channel) - - self.pw_conv = Conv1D(in_channels=channel, - out_channels=out_dim, - kernel_size=1, - stride=1, - padding=0, - groups=1) + self.dw_conv = Conv1D( + in_channels=channel, + out_channels=channel, + kernel_size=kernel_size, + stride=stride, + padding=0, + groups=channel) + + self.pw_conv = Conv1D( + in_channels=channel, + out_channels=out_dim, + kernel_size=1, + stride=1, + padding=0, + groups=1) self.init_weights() def init_weights(self): - dw_max = self.kernel_size ** -0.5 - pw_max = self.channel ** -0.5 - self.dw_conv._param_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max) - self.dw_conv._bias_attr = paddle.nn.initializer.Uniform(low=-dw_max, high=dw_max) - self.pw_conv._param_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) - self.pw_conv._bias_attr = paddle.nn.initializer.Uniform(low=-pw_max, high=pw_max) - - def forward(self, xs, xs_lens: paddle.Tensor, - mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool), - mask_pad: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool)): + dw_max = self.kernel_size**-0.5 + pw_max = self.channel**-0.5 + self.dw_conv._param_attr = paddle.nn.initializer.Uniform( + low=-dw_max, high=dw_max) + self.dw_conv._bias_attr = paddle.nn.initializer.Uniform( + 
low=-dw_max, high=dw_max) + self.pw_conv._param_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) + self.pw_conv._bias_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) + + def forward( + self, + xs, + xs_lens: paddle.Tensor, + mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), + mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool)): xs = xs.transpose([0, 2, 1]) # [B, C, T] xs = masked_fill(xs, mask_pad.equal(0), 0.0) From 1c156bfe4d6339c3f263310b7e6fa164b8cbfb07 Mon Sep 17 00:00:00 2001 From: yeyupiaoling Date: Tue, 20 Dec 2022 11:00:14 +0800 Subject: [PATCH 4/9] fix subsample rate error, test=asr --- paddlespeech/s2t/utils/utility.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/paddlespeech/s2t/utils/utility.py b/paddlespeech/s2t/utils/utility.py index fe2fa9167a5..d7e7c6ca27e 100644 --- a/paddlespeech/s2t/utils/utility.py +++ b/paddlespeech/s2t/utils/utility.py @@ -131,8 +131,7 @@ def get_subsample(config): int: subsample rate. """ if config['encoder'] == 'squeezeformer': - input_layer = config["encoder_conf"]["time_reduction_layer_type"] - assert input_layer in ["conv2d", "conv1d", "stream"] + return 4 else: input_layer = config["encoder_conf"]["input_layer"] assert input_layer in ["conv2d", "conv2d6", "conv2d8"] @@ -142,7 +141,3 @@ def get_subsample(config): return 6 elif input_layer == "conv2d8": return 8 - elif input_layer == "conv1d": - return 6 - elif input_layer == "stream": - return 8 From c1df5b7985eb8a28328e2b7b4685e984c77fcb2f Mon Sep 17 00:00:00 2001 From: yeyupiaoling Date: Wed, 4 Jan 2023 16:37:18 +0800 Subject: [PATCH 5/9] merge classes as required, test=asr --- .../asr1/conf/chunk_squeezeformer.yaml | 5 +- paddlespeech/s2t/modules/attention.py | 194 ++++-------------- .../s2t/modules/conformer_convolution.py | 43 +++- paddlespeech/s2t/modules/convolution.py | 181 ---------------- paddlespeech/s2t/modules/encoder.py | 10 +- .../s2t/modules/positionwise_feed_forward.py | 52 +---- 6 files changed, 95 insertions(+), 390 deletions(-) delete mode 100644 paddlespeech/s2t/modules/convolution.py diff --git a/examples/aishell/asr1/conf/chunk_squeezeformer.yaml b/examples/aishell/asr1/conf/chunk_squeezeformer.yaml index 691d9046162..2533eacfcda 100644 --- a/examples/aishell/asr1/conf/chunk_squeezeformer.yaml +++ b/examples/aishell/asr1/conf/chunk_squeezeformer.yaml @@ -4,7 +4,7 @@ cmvn_file: cmvn_file_type: "json" # encoder related -encoder: conformer +encoder: squeezeformer encoder_conf: encoder_dim: 256 # dimension of attention output_size: 256 # dimension of output @@ -21,7 +21,8 @@ encoder_conf: normalize_before: false activation_type: 'swish' pos_enc_layer_type: 'rel_pos' - time_reduction_layer_type: 'conv2d' + do_rel_shift: false + time_reduction_layer_type: 'stream' causal: true use_dynamic_chunk: true use_dynamic_left_chunk: false diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py index 6347bdb12ee..b2184dbc7a3 100644 --- a/paddlespeech/s2t/modules/attention.py +++ b/paddlespeech/s2t/modules/attention.py @@ -203,7 +203,10 @@ def forward(self, class RelPositionMultiHeadedAttention(MultiHeadedAttention): """Multi-Head Attention layer with relative position encoding.""" - def __init__(self, n_head, n_feat, dropout_rate): + def __init__(self, n_head, n_feat, dropout_rate, + do_rel_shift=False, + adaptive_scale=False, + init_weights=False): """Construct an RelPositionMultiHeadedAttention object. 
Paper: https://arxiv.org/abs/1901.02860 Args: @@ -226,151 +229,15 @@ def __init__(self, n_head, n_feat, dropout_rate): pos_bias_v = self.create_parameter( (self.h, self.d_k), default_initializer=I.XavierUniform()) self.add_parameter('pos_bias_v', pos_bias_v) - - def rel_shift(self, x, zero_triu: bool=False): - """Compute relative positinal encoding. - Args: - x (paddle.Tensor): Input tensor (batch, head, time1, time1). - zero_triu (bool): If true, return the lower triangular part of - the matrix. - Returns: - paddle.Tensor: Output tensor. (batch, head, time1, time1) - """ - zero_pad = paddle.zeros( - (x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype) - x_padded = paddle.cat([zero_pad, x], dim=-1) - - x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1, - x.shape[2]) - x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1] - - if zero_triu: - ones = paddle.ones((x.shape[2], x.shape[3])) - x = x * paddle.tril(ones, x.shape[3] - x.shape[2])[None, None, :, :] - - return x - - def forward(self, - query: paddle.Tensor, - key: paddle.Tensor, - value: paddle.Tensor, - mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), - pos_emb: paddle.Tensor=paddle.empty([0]), - cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]) - ) -> Tuple[paddle.Tensor, paddle.Tensor]: - """Compute 'Scaled Dot Product Attention' with rel. positional encoding. - Args: - query (paddle.Tensor): Query tensor (#batch, time1, size). - key (paddle.Tensor): Key tensor (#batch, time2, size). - value (paddle.Tensor): Value tensor (#batch, time2, size). - mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or - (#batch, time1, time2), (0, 0, 0) means fake mask. - pos_emb (paddle.Tensor): Positional embedding tensor - (#batch, time2, size). - cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2), - where `cache_t == chunk_size * num_decoding_left_chunks` - and `head * d_k == size` - Returns: - paddle.Tensor: Output tensor (#batch, time1, d_model). - paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) - where `cache_t == chunk_size * num_decoding_left_chunks` - and `head * d_k == size` - """ - q, k, v = self.forward_qkv(query, key, value) - # q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k) - - # when export onnx model, for 1st chunk, we feed - # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) - # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). - # In all modes, `if cache.size(0) > 0` will alwayse be `True` - # and we will always do splitting and - # concatnation(this will simplify onnx export). Note that - # it's OK to concat & split zero-shaped tensors(see code below). - # when export jit model, for 1st chunk, we always feed - # cache(0, 0, 0, 0) since jit supports dynamic if-branch. - # >>> a = torch.ones((1, 2, 0, 4)) - # >>> b = torch.ones((1, 2, 3, 4)) - # >>> c = torch.cat((a, b), dim=2) - # >>> torch.equal(b, c) # True - # >>> d = torch.split(a, 2, dim=-1) - # >>> torch.equal(d[0], d[1]) # True - if cache.shape[0] > 0: - # last dim `d_k * 2` for (key, val) - key_cache, value_cache = paddle.split(cache, 2, axis=-1) - k = paddle.concat([key_cache, k], axis=2) - v = paddle.concat([value_cache, v], axis=2) - # We do cache slicing in encoder.forward_chunk, since it's - # non-trivial to calculate `next_cache_start` here. 
- new_cache = paddle.concat((k, v), axis=-1) - - n_batch_pos = pos_emb.shape[0] - p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) - p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) - - # (batch, head, time1, d_k) - # q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3]) - q_with_bias_u = q + self.pos_bias_u.unsqueeze(1) - # (batch, head, time1, d_k) - # q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3]) - q_with_bias_v = q + self.pos_bias_v.unsqueeze(1) - - # compute attention score - # first compute matrix a and matrix c - # as described in https://arxiv.org/abs/1901.02860 Section 3.3 - # (batch, head, time1, time2) - # matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2])) - matrix_ac = paddle.matmul(q_with_bias_u, k, transpose_y=True) - - # compute matrix b and matrix d - # (batch, head, time1, time2) - # matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2])) - matrix_bd = paddle.matmul(q_with_bias_v, p, transpose_y=True) - # Remove rel_shift since it is useless in speech recognition, - # and it requires special attention for streaming. - # matrix_bd = self.rel_shift(matrix_bd) - - scores = (matrix_ac + matrix_bd) / math.sqrt( - self.d_k) # (batch, head, time1, time2) - - return self.forward_attention(v, scores, mask), new_cache - - -class RelPositionMultiHeadedAttention2(MultiHeadedAttention): - """Multi-Head Attention layer with relative position encoding. - Paper: https://arxiv.org/abs/1901.02860 - Args: - n_head (int): The number of heads. - n_feat (int): The number of features. - dropout_rate (float): Dropout rate. - """ - - def __init__(self, - n_head, - n_feat, - dropout_rate, - do_rel_shift=False, - adaptive_scale=False, - init_weights=False): - """Construct an RelPositionMultiHeadedAttention object.""" - super().__init__(n_head, n_feat, dropout_rate) - # linear transformation for positional encoding - self.linear_pos = Linear(n_feat, n_feat) - # these two learnable bias are used in matrix c and matrix d - # as described in https://arxiv.org/abs/1901.02860 Section 3.3 self.do_rel_shift = do_rel_shift - pos_bias_u = self.create_parameter( - [self.h, self.d_k], default_initializer=I.XavierUniform()) - self.add_parameter('pos_bias_u', pos_bias_u) - pos_bias_v = self.create_parameter( - [self.h, self.d_k], default_initializer=I.XavierUniform()) - self.add_parameter('pos_bias_v', pos_bias_v) self.adaptive_scale = adaptive_scale - ada_scale = self.create_parameter( - [1, 1, n_feat], default_initializer=I.Constant(1.0)) - self.add_parameter('ada_scale', ada_scale) - ada_bias = self.create_parameter( - [1, 1, n_feat], default_initializer=I.Constant(0.0)) - self.add_parameter('ada_bias', ada_bias) + if self.adaptive_scale: + ada_scale = self.create_parameter( + [1, 1, n_feat], default_initializer=I.Constant(1.0)) + self.add_parameter('ada_scale', ada_scale) + ada_bias = self.create_parameter( + [1, 1, n_feat], default_initializer=I.Constant(0.0)) + self.add_parameter('ada_bias', ada_bias) if init_weights: self.init_weights() @@ -407,12 +274,12 @@ def rel_shift(self, x, zero_triu: bool=False): paddle.Tensor: Output tensor. 
(batch, head, time1, time1) """ zero_pad = paddle.zeros( - [x.shape[0], x.shape[1], x.shape[2], 1], dtype=x.dtype) - x_padded = paddle.concat([zero_pad, x], axis=-1) + (x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype) + x_padded = paddle.cat([zero_pad, x], dim=-1) - x_padded = x_padded.reshape( - [x.shape[0], x.shape[1], x.shape[3] + 1, x.shape[2]]) - x = x_padded[:, :, 1:].reshape(paddle.shape(x)) # [B, H, T1, T1] + x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1, + x.shape[2]) + x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1] if zero_triu: ones = paddle.ones((x.shape[2], x.shape[3])) @@ -424,10 +291,10 @@ def forward(self, query: paddle.Tensor, key: paddle.Tensor, value: paddle.Tensor, - mask: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool), + mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), pos_emb: paddle.Tensor=paddle.empty([0]), - cache: paddle.Tensor=paddle.zeros( - (0, 0, 0, 0))) -> Tuple[paddle.Tensor, paddle.Tensor]: + cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]) + ) -> Tuple[paddle.Tensor, paddle.Tensor]: """Compute 'Scaled Dot Product Attention' with rel. positional encoding. Args: query (paddle.Tensor): Query tensor (#batch, time1, size). @@ -452,17 +319,34 @@ def forward(self, value = self.ada_scale * value + self.ada_bias q, k, v = self.forward_qkv(query, key, value) + # q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k) + + # when export onnx model, for 1st chunk, we feed + # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) + # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). + # In all modes, `if cache.size(0) > 0` will alwayse be `True` + # and we will always do splitting and + # concatnation(this will simplify onnx export). Note that + # it's OK to concat & split zero-shaped tensors(see code below). + # when export jit model, for 1st chunk, we always feed + # cache(0, 0, 0, 0) since jit supports dynamic if-branch. + # >>> a = torch.ones((1, 2, 0, 4)) + # >>> b = torch.ones((1, 2, 3, 4)) + # >>> c = torch.cat((a, b), dim=2) + # >>> torch.equal(b, c) # True + # >>> d = torch.split(a, 2, dim=-1) + # >>> torch.equal(d[0], d[1]) # True if cache.shape[0] > 0: + # last dim `d_k * 2` for (key, val) key_cache, value_cache = paddle.split(cache, 2, axis=-1) k = paddle.concat([key_cache, k], axis=2) v = paddle.concat([value_cache, v], axis=2) - # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's + # We do cache slicing in encoder.forward_chunk, since it's # non-trivial to calculate `next_cache_start` here. 
new_cache = paddle.concat((k, v), axis=-1) n_batch_pos = pos_emb.shape[0] - p = self.linear_pos(pos_emb).reshape( - [n_batch_pos, -1, self.h, self.d_k]) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) # (batch, head, time1, d_k) diff --git a/paddlespeech/s2t/modules/conformer_convolution.py b/paddlespeech/s2t/modules/conformer_convolution.py index 09d903eee34..e4196e3d4d5 100644 --- a/paddlespeech/s2t/modules/conformer_convolution.py +++ b/paddlespeech/s2t/modules/conformer_convolution.py @@ -18,6 +18,7 @@ import paddle from paddle import nn +from paddle.nn import initializer as I from typeguard import check_argument_types from paddlespeech.s2t.modules.align import BatchNorm1D @@ -39,7 +40,9 @@ def __init__(self, activation: nn.Layer=nn.ReLU(), norm: str="batch_norm", causal: bool=False, - bias: bool=True): + bias: bool=True, + adaptive_scale: bool=False, + init_weights: bool=False): """Construct an ConvolutionModule object. Args: channels (int): The number of channels of conv layers. @@ -51,6 +54,19 @@ def __init__(self, """ assert check_argument_types() super().__init__() + self.bias = bias + self.channels = channels + self.kernel_size = kernel_size + self.adaptive_scale = adaptive_scale + if self.adaptive_scale: + ada_scale = self.create_parameter( + [1, 1, channels], default_initializer=I.Constant(1.0)) + self.add_parameter('ada_scale', ada_scale) + ada_bias = self.create_parameter( + [1, 1, channels], default_initializer=I.Constant(0.0)) + self.add_parameter('ada_bias', ada_bias) + + self.pointwise_conv1 = Conv1D( channels, 2 * channels, @@ -105,6 +121,28 @@ def __init__(self, ) self.activation = activation + if init_weights: + self.init_weights() + + def init_weights(self): + pw_max = self.channels**-0.5 + dw_max = self.kernel_size**-0.5 + self.pointwise_conv1._param_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) + if self.bias: + self.pointwise_conv1._bias_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) + self.depthwise_conv._param_attr = paddle.nn.initializer.Uniform( + low=-dw_max, high=dw_max) + if self.bias: + self.depthwise_conv._bias_attr = paddle.nn.initializer.Uniform( + low=-dw_max, high=dw_max) + self.pointwise_conv2._param_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) + if self.bias: + self.pointwise_conv2._bias_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) + def forward( self, x: paddle.Tensor, @@ -123,6 +161,9 @@ def forward( paddle.Tensor: Output tensor (#batch, time, channels). 
paddle.Tensor: Output cache tensor (#batch, channels, time') """ + if self.adaptive_scale: + x = self.ada_scale * x + self.ada_bias + # exchange the temporal dimension and the feature dimension x = x.transpose([0, 2, 1]) # [B, C, T] diff --git a/paddlespeech/s2t/modules/convolution.py b/paddlespeech/s2t/modules/convolution.py deleted file mode 100644 index caaa985668a..00000000000 --- a/paddlespeech/s2t/modules/convolution.py +++ /dev/null @@ -1,181 +0,0 @@ -from typing import Tuple - -import paddle -from paddle import nn -from paddle.nn import initializer as I -from typeguard import check_argument_types - -__all__ = ['ConvolutionModule2'] - -from paddlespeech.s2t import masked_fill -from paddlespeech.s2t.modules.align import Conv1D, BatchNorm1D, LayerNorm - - -class ConvolutionModule2(nn.Layer): - """ConvolutionModule in Conformer model.""" - - def __init__(self, - channels: int, - kernel_size: int=15, - activation: nn.Layer=nn.ReLU(), - norm: str="batch_norm", - causal: bool=False, - bias: bool=True, - adaptive_scale: bool=False, - init_weights: bool=False): - """Construct an ConvolutionModule object. - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernel size of conv layers. - causal (int): Whether use causal convolution or not - """ - assert check_argument_types() - super().__init__() - self.bias = bias - self.channels = channels - self.kernel_size = kernel_size - self.adaptive_scale = adaptive_scale - ada_scale = self.create_parameter( - [1, 1, channels], default_initializer=I.Constant(1.0)) - self.add_parameter('ada_scale', ada_scale) - ada_bias = self.create_parameter( - [1, 1, channels], default_initializer=I.Constant(0.0)) - self.add_parameter('ada_bias', ada_bias) - - self.pointwise_conv1 = Conv1D( - channels, - 2 * channels, - kernel_size=1, - stride=1, - padding=0, - bias_attr=None - if bias else False, # None for True, using bias as default config - ) - - # self.lorder is used to distinguish if it's a causal convolution, - # if self.lorder > 0: it's a causal convolution, the input will be - # padded with self.lorder frames on the left in forward. 
- # else: it's a symmetrical convolution - if causal: - padding = 0 - self.lorder = kernel_size - 1 - else: - # kernel_size should be an odd number for none causal convolution - assert (kernel_size - 1) % 2 == 0 - padding = (kernel_size - 1) // 2 - self.lorder = 0 - self.depthwise_conv = Conv1D( - channels, - channels, - kernel_size, - stride=1, - padding=padding, - groups=channels, - bias_attr=None - if bias else False, # None for True, using bias as default config - ) - - assert norm in ['batch_norm', 'layer_norm'] - if norm == "batch_norm": - self.use_layer_norm = False - self.norm = BatchNorm1D(channels) - else: - self.use_layer_norm = True - self.norm = LayerNorm(channels) - - self.pointwise_conv2 = Conv1D( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - bias_attr=None - if bias else False, # None for True, using bias as default config - ) - self.activation = activation - - if init_weights: - self.init_weights() - - def init_weights(self): - pw_max = self.channels**-0.5 - dw_max = self.kernel_size**-0.5 - self.pointwise_conv1._param_attr = paddle.nn.initializer.Uniform( - low=-pw_max, high=pw_max) - if self.bias: - self.pointwise_conv1._bias_attr = paddle.nn.initializer.Uniform( - low=-pw_max, high=pw_max) - self.depthwise_conv._param_attr = paddle.nn.initializer.Uniform( - low=-dw_max, high=dw_max) - if self.bias: - self.depthwise_conv._bias_attr = paddle.nn.initializer.Uniform( - low=-dw_max, high=dw_max) - self.pointwise_conv2._param_attr = paddle.nn.initializer.Uniform( - low=-pw_max, high=pw_max) - if self.bias: - self.pointwise_conv2._bias_attr = paddle.nn.initializer.Uniform( - low=-pw_max, high=pw_max) - - def forward( - self, - x: paddle.Tensor, - mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), - cache: paddle.Tensor=paddle.zeros([0, 0, 0]), - ) -> Tuple[paddle.Tensor, paddle.Tensor]: - """Compute convolution module. - Args: - x (torch.Tensor): Input tensor (#batch, time, channels). - mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), - (0, 0, 0) means fake mask. - cache (torch.Tensor): left context cache, it is only - used in causal convolution (#batch, channels, cache_t), - (0, 0, 0) meas fake cache. - Returns: - torch.Tensor: Output tensor (#batch, time, channels). - """ - if self.adaptive_scale: - x = self.ada_scale * x + self.ada_bias - - # exchange the temporal dimension and the feature dimension - x = x.transpose([0, 2, 1]) # [B, C, T] - - # mask batch padding - if mask_pad.shape[2] > 0: # time > 0 - x = masked_fill(x, mask_pad, 0.0) - - if self.lorder > 0: - if cache.shape[2] == 0: # cache_t == 0 - x = nn.functional.pad( - x, [self.lorder, 0], 'constant', 0.0, data_format='NCL') - else: - assert cache.shape[0] == x.shape[0] # B - assert cache.shape[1] == x.shape[1] # C - x = paddle.concat((cache, x), axis=2) - - assert (x.shape[2] > self.lorder) - new_cache = x[:, :, -self.lorder:] # [B, C, T] - else: - # It's better we just return None if no cache is required, - # However, for JIT export, here we just fake one tensor instead of - # None. 
- new_cache = paddle.zeros([0, 0, 0], dtype=x.dtype) - - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channel, dim) - x = nn.functional.glu(x, axis=1) # (batch, channel, dim) - - # 1D Depthwise Conv - x = self.depthwise_conv(x) - if self.use_layer_norm: - x = x.transpose([0, 2, 1]) # [B, T, C] - x = self.activation(self.norm(x)) - if self.use_layer_norm: - x = x.transpose([0, 2, 1]) # [B, C, T] - x = self.pointwise_conv2(x) - - # mask batch padding - if mask_pad.shape[2] > 0: # time > 0 - x = masked_fill(x, mask_pad, 0.0) - - x = x.transpose([0, 2, 1]) # [B, T, C] - return x, new_cache diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index f19ecfe4171..d133735b204 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -30,7 +30,6 @@ from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention2 from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule -from paddlespeech.s2t.modules.convolution import ConvolutionModule2 from paddlespeech.s2t.modules.embedding import NoPositionalEncoding from paddlespeech.s2t.modules.embedding import PositionalEncoding from paddlespeech.s2t.modules.embedding import RelPositionalEncoding @@ -40,7 +39,6 @@ from paddlespeech.s2t.modules.mask import add_optional_chunk_mask from paddlespeech.s2t.modules.mask import make_non_pad_mask from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward -from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward2 from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling4 from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling6 from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling8 @@ -591,19 +589,19 @@ def __init__(self, encoder_selfattn_layer_args = (attention_heads, output_size, attention_dropout_rate) else: - encoder_selfattn_layer = RelPositionMultiHeadedAttention2 + encoder_selfattn_layer = RelPositionMultiHeadedAttention encoder_selfattn_layer_args = (attention_heads, encoder_dim, attention_dropout_rate, do_rel_shift, adaptive_scale, init_weights) # feed-forward module definition - positionwise_layer = PositionwiseFeedForward2 + positionwise_layer = PositionwiseFeedForward positionwise_layer_args = ( encoder_dim, encoder_dim * feed_forward_expansion_factor, feed_forward_dropout_rate, activation, adaptive_scale, init_weights) # convolution module definition - convolution_layer = ConvolutionModule2 + convolution_layer = ConvolutionModule convolution_layer_args = (encoder_dim, cnn_module_kernel, activation, cnn_norm_type, causal, True, adaptive_scale, init_weights) @@ -676,7 +674,7 @@ def forward( if self.global_cmvn is not None: xs = self.global_cmvn(xs) xs, pos_emb, masks = self.embed(xs, masks) - mask_pad = ~masks + mask_pad = masks chunk_masks = add_optional_chunk_mask( xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk, decoding_chunk_size, self.static_chunk_size, diff --git a/paddlespeech/s2t/modules/positionwise_feed_forward.py b/paddlespeech/s2t/modules/positionwise_feed_forward.py index 39d8b1893e4..b5395f049c6 100644 --- a/paddlespeech/s2t/modules/positionwise_feed_forward.py +++ b/paddlespeech/s2t/modules/positionwise_feed_forward.py @@ -23,7 +23,7 @@ logger = Log(__name__).getlog() -__all__ = ["PositionwiseFeedForward", "PositionwiseFeedForward2"] +__all__ = ["PositionwiseFeedForward"] class 
PositionwiseFeedForward(nn.Layer): @@ -33,7 +33,9 @@ def __init__(self, idim: int, hidden_units: int, dropout_rate: float, - activation: nn.Layer=nn.ReLU()): + activation: nn.Layer=nn.ReLU(), + adaptive_scale: bool=False, + init_weights: bool=False): """Construct a PositionwiseFeedForward object. FeedForward are appied on each position of the sequence. @@ -46,48 +48,11 @@ def __init__(self, activation (paddle.nn.Layer): Activation function """ super().__init__() - self.w_1 = Linear(idim, hidden_units) - self.activation = activation - self.dropout = nn.Dropout(dropout_rate) - self.w_2 = Linear(hidden_units, idim) - - def forward(self, xs: paddle.Tensor) -> paddle.Tensor: - """Forward function. - Args: - xs: input tensor (B, Lmax, D) - Returns: - output tensor, (B, Lmax, D) - """ - return self.w_2(self.dropout(self.activation(self.w_1(xs)))) - - -class PositionwiseFeedForward2(paddle.nn.Layer): - """Positionwise feed forward layer. - - FeedForward are appied on each position of the sequence. - The output dim is same with the input dim. - - Args: - idim (int): Input dimenstion. - hidden_units (int): The number of hidden units. - dropout_rate (float): Dropout rate. - activation (paddle.nn.Layer): Activation function - """ - - def __init__(self, - idim: int, - hidden_units: int, - dropout_rate: float, - activation: paddle.nn.Layer=paddle.nn.ReLU(), - adaptive_scale: bool=False, - init_weights: bool=False): - """Construct a PositionwiseFeedForward object.""" - super(PositionwiseFeedForward2, self).__init__() self.idim = idim self.hidden_units = hidden_units self.w_1 = Linear(idim, hidden_units) self.activation = activation - self.dropout = paddle.nn.Dropout(dropout_rate) + self.dropout = nn.Dropout(dropout_rate) self.w_2 = Linear(hidden_units, idim) self.adaptive_scale = adaptive_scale ada_scale = self.create_parameter( @@ -114,12 +79,9 @@ def init_weights(self): def forward(self, xs: paddle.Tensor) -> paddle.Tensor: """Forward function. 
- Args: - xs: input tensor (B, L, D) + xs: input tensor (B, Lmax, D) Returns: - output tensor, (B, L, D) + output tensor, (B, Lmax, D) """ - if self.adaptive_scale: - xs = self.ada_scale * xs + self.ada_bias return self.w_2(self.dropout(self.activation(self.w_1(xs)))) From ccc15715946f56d440403537af3258664f24b7b7 Mon Sep 17 00:00:00 2001 From: yeyupiaoling Date: Wed, 4 Jan 2023 16:55:12 +0800 Subject: [PATCH 6/9] change CodeStyle, test=asr --- paddlespeech/s2t/modules/attention.py | 10 +++++----- paddlespeech/s2t/modules/conformer_convolution.py | 1 - 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py index b2184dbc7a3..43700ca1ecc 100644 --- a/paddlespeech/s2t/modules/attention.py +++ b/paddlespeech/s2t/modules/attention.py @@ -26,10 +26,7 @@ logger = Log(__name__).getlog() -__all__ = [ - "MultiHeadedAttention", "RelPositionMultiHeadedAttention", - "RelPositionMultiHeadedAttention2" -] +__all__ = ["MultiHeadedAttention", "RelPositionMultiHeadedAttention"] # Relative Positional Encodings # https://www.jianshu.com/p/c0608efcc26f @@ -203,7 +200,10 @@ def forward(self, class RelPositionMultiHeadedAttention(MultiHeadedAttention): """Multi-Head Attention layer with relative position encoding.""" - def __init__(self, n_head, n_feat, dropout_rate, + def __init__(self, + n_head, + n_feat, + dropout_rate, do_rel_shift=False, adaptive_scale=False, init_weights=False): diff --git a/paddlespeech/s2t/modules/conformer_convolution.py b/paddlespeech/s2t/modules/conformer_convolution.py index e4196e3d4d5..7a0c72f3b22 100644 --- a/paddlespeech/s2t/modules/conformer_convolution.py +++ b/paddlespeech/s2t/modules/conformer_convolution.py @@ -66,7 +66,6 @@ def __init__(self, [1, 1, channels], default_initializer=I.Constant(0.0)) self.add_parameter('ada_bias', ada_bias) - self.pointwise_conv1 = Conv1D( channels, 2 * channels, From 7b1519b8580dff9194148474b6b18315c89769d4 Mon Sep 17 00:00:00 2001 From: yeyupiaoling Date: Wed, 4 Jan 2023 17:12:33 +0800 Subject: [PATCH 7/9] fix missing code, test=asr --- .../s2t/modules/positionwise_feed_forward.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/paddlespeech/s2t/modules/positionwise_feed_forward.py b/paddlespeech/s2t/modules/positionwise_feed_forward.py index b5395f049c6..9ebd5d638b4 100644 --- a/paddlespeech/s2t/modules/positionwise_feed_forward.py +++ b/paddlespeech/s2t/modules/positionwise_feed_forward.py @@ -55,12 +55,13 @@ def __init__(self, self.dropout = nn.Dropout(dropout_rate) self.w_2 = Linear(hidden_units, idim) self.adaptive_scale = adaptive_scale - ada_scale = self.create_parameter( - [1, 1, idim], default_initializer=I.XavierUniform()) - self.add_parameter('ada_scale', ada_scale) - ada_bias = self.create_parameter( - [1, 1, idim], default_initializer=I.XavierUniform()) - self.add_parameter('ada_bias', ada_bias) + if self.adaptive_scale: + ada_scale = self.create_parameter( + [1, 1, idim], default_initializer=I.XavierUniform()) + self.add_parameter('ada_scale', ada_scale) + ada_bias = self.create_parameter( + [1, 1, idim], default_initializer=I.XavierUniform()) + self.add_parameter('ada_bias', ada_bias) if init_weights: self.init_weights() @@ -84,4 +85,6 @@ def forward(self, xs: paddle.Tensor) -> paddle.Tensor: Returns: output tensor, (B, Lmax, D) """ + if self.adaptive_scale: + xs = self.ada_scale * xs + self.ada_bias return self.w_2(self.dropout(self.activation(self.w_1(xs)))) From b297635a7574cab1d0ddcc206254d94ce036df1a Mon 
Sep 17 00:00:00 2001 From: yeyupiaoling Date: Thu, 5 Jan 2023 15:00:17 +0800 Subject: [PATCH 8/9] split code to new file, test=asr --- .../asr1/conf/chunk_squeezeformer.yaml | 2 +- examples/aishell/asr1/conf/squeezeformer.yaml | 2 +- paddlespeech/s2t/modules/encoder.py | 7 +- paddlespeech/s2t/modules/subsampling.py | 238 +--------------- paddlespeech/s2t/modules/time_reduction.py | 263 ++++++++++++++++++ 5 files changed, 269 insertions(+), 243 deletions(-) create mode 100644 paddlespeech/s2t/modules/time_reduction.py diff --git a/examples/aishell/asr1/conf/chunk_squeezeformer.yaml b/examples/aishell/asr1/conf/chunk_squeezeformer.yaml index 2533eacfcda..45a2ac965c8 100644 --- a/examples/aishell/asr1/conf/chunk_squeezeformer.yaml +++ b/examples/aishell/asr1/conf/chunk_squeezeformer.yaml @@ -12,7 +12,7 @@ encoder_conf: num_blocks: 12 # the number of encoder blocks reduce_idx: 5 recover_idx: 11 - feed_forward_expansion_factor: 4 + feed_forward_expansion_factor: 8 input_dropout_rate: 0.1 feed_forward_dropout_rate: 0.1 attention_dropout_rate: 0.1 diff --git a/examples/aishell/asr1/conf/squeezeformer.yaml b/examples/aishell/asr1/conf/squeezeformer.yaml index db8ef7c2df2..49a837a8271 100644 --- a/examples/aishell/asr1/conf/squeezeformer.yaml +++ b/examples/aishell/asr1/conf/squeezeformer.yaml @@ -12,7 +12,7 @@ encoder_conf: num_blocks: 12 # the number of encoder blocks reduce_idx: 5 recover_idx: 11 - feed_forward_expansion_factor: 4 + feed_forward_expansion_factor: 8 input_dropout_rate: 0.1 feed_forward_dropout_rate: 0.1 attention_dropout_rate: 0.1 diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index d133735b204..7be1925751b 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -28,7 +28,6 @@ from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.modules.attention import MultiHeadedAttention from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention -from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention2 from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule from paddlespeech.s2t.modules.embedding import NoPositionalEncoding from paddlespeech.s2t.modules.embedding import PositionalEncoding @@ -44,9 +43,9 @@ from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling8 from paddlespeech.s2t.modules.subsampling import DepthwiseConv2DSubsampling4 from paddlespeech.s2t.modules.subsampling import LinearNoSubsampling -from paddlespeech.s2t.modules.subsampling import TimeReductionLayer1D -from paddlespeech.s2t.modules.subsampling import TimeReductionLayer2D -from paddlespeech.s2t.modules.subsampling import TimeReductionLayerStream +from paddlespeech.s2t.modules.time_reduction import TimeReductionLayer1D +from paddlespeech.s2t.modules.time_reduction import TimeReductionLayer2D +from paddlespeech.s2t.modules.time_reduction import TimeReductionLayerStream from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() diff --git a/paddlespeech/s2t/modules/subsampling.py b/paddlespeech/s2t/modules/subsampling.py index 51322d324f2..ef60bdf0aca 100644 --- a/paddlespeech/s2t/modules/subsampling.py +++ b/paddlespeech/s2t/modules/subsampling.py @@ -17,15 +17,11 @@ from typing import Tuple import paddle -import paddle.nn.functional as F from paddle import nn -from paddlespeech.s2t import masked_fill -from paddlespeech.s2t.modules.align import Conv1D from paddlespeech.s2t.modules.align import Conv2D from paddlespeech.s2t.modules.align 
import LayerNorm from paddlespeech.s2t.modules.align import Linear -from paddlespeech.s2t.modules.conv2d import Conv2DValid from paddlespeech.s2t.modules.embedding import PositionalEncoding from paddlespeech.s2t.utils.log import Log @@ -33,8 +29,7 @@ __all__ = [ "LinearNoSubsampling", "Conv2dSubsampling4", "Conv2dSubsampling6", - "Conv2dSubsampling8", "TimeReductionLayerStream", "TimeReductionLayer1D", - "TimeReductionLayer2D", "DepthwiseConv2DSubsampling4" + "Conv2dSubsampling8", "DepthwiseConv2DSubsampling4" ] @@ -318,234 +313,3 @@ def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0 x, pos_emb = self.pos_enc(x, offset) x = self.input_proj(x) return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2] - - -class TimeReductionLayer1D(nn.Layer): - """ - Modified NeMo, - Squeezeformer Time Reduction procedure. - Downsamples the audio by `stride` in the time dimension. - Args: - channel (int): input dimension of - MultiheadAttentionMechanism and PositionwiseFeedForward - out_dim (int): Output dimension of the module. - kernel_size (int): Conv kernel size for - depthwise convolution in convolution module - stride (int): Downsampling factor in time dimension. - """ - - def __init__(self, - channel: int, - out_dim: int, - kernel_size: int=5, - stride: int=2): - super(TimeReductionLayer1D, self).__init__() - - self.channel = channel - self.out_dim = out_dim - self.kernel_size = kernel_size - self.stride = stride - self.padding = max(0, self.kernel_size - self.stride) - - self.dw_conv = Conv1D( - in_channels=channel, - out_channels=channel, - kernel_size=kernel_size, - stride=stride, - padding=self.padding, - groups=channel, ) - - self.pw_conv = Conv1D( - in_channels=channel, - out_channels=out_dim, - kernel_size=1, - stride=1, - padding=0, - groups=1, ) - - self.init_weights() - - def init_weights(self): - dw_max = self.kernel_size**-0.5 - pw_max = self.channel**-0.5 - self.dw_conv._param_attr = paddle.nn.initializer.Uniform( - low=-dw_max, high=dw_max) - self.dw_conv._bias_attr = paddle.nn.initializer.Uniform( - low=-dw_max, high=dw_max) - self.pw_conv._param_attr = paddle.nn.initializer.Uniform( - low=-pw_max, high=pw_max) - self.pw_conv._bias_attr = paddle.nn.initializer.Uniform( - low=-pw_max, high=pw_max) - - def forward( - self, - xs, - xs_lens: paddle.Tensor, - mask: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool), - mask_pad: paddle.Tensor=paddle.ones((0, 0, 0), - dtype=paddle.bool), ): - xs = xs.transpose([0, 2, 1]) # [B, C, T] - xs = masked_fill(xs, mask_pad.equal(0), 0.0) - - xs = self.dw_conv(xs) - xs = self.pw_conv(xs) - - xs = xs.transpose([0, 2, 1]) # [B, T, C] - - B, T, D = xs.shape - mask = mask[:, ::self.stride, ::self.stride] - mask_pad = mask_pad[:, :, ::self.stride] - L = mask_pad.shape[-1] - # For JIT exporting, we remove F.pad operator. 
- if L - T < 0: - xs = xs[:, :L - T, :] - else: - dummy_pad = paddle.zeros([B, L - T, D], dtype=paddle.float32) - xs = paddle.concat([xs, dummy_pad], axis=1) - - xs_lens = (xs_lens + 1) // 2 - return xs, xs_lens, mask, mask_pad - - -class TimeReductionLayer2D(nn.Layer): - def __init__(self, kernel_size: int=5, stride: int=2, encoder_dim: int=256): - super(TimeReductionLayer2D, self).__init__() - self.encoder_dim = encoder_dim - self.kernel_size = kernel_size - self.dw_conv = Conv2DValid( - in_channels=encoder_dim, - out_channels=encoder_dim, - kernel_size=(kernel_size, 1), - stride=stride, - valid_trigy=True) - self.pw_conv = Conv2DValid( - in_channels=encoder_dim, - out_channels=encoder_dim, - kernel_size=1, - stride=1, - valid_trigx=False, - valid_trigy=False) - - self.kernel_size = kernel_size - self.stride = stride - self.init_weights() - - def init_weights(self): - dw_max = self.kernel_size**-0.5 - pw_max = self.encoder_dim**-0.5 - self.dw_conv._param_attr = paddle.nn.initializer.Uniform( - low=-dw_max, high=dw_max) - self.dw_conv._bias_attr = paddle.nn.initializer.Uniform( - low=-dw_max, high=dw_max) - self.pw_conv._param_attr = paddle.nn.initializer.Uniform( - low=-pw_max, high=pw_max) - self.pw_conv._bias_attr = paddle.nn.initializer.Uniform( - low=-pw_max, high=pw_max) - - def forward( - self, - xs: paddle.Tensor, - xs_lens: paddle.Tensor, - mask: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool), - mask_pad: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool), - ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: - xs = masked_fill(xs, mask_pad.transpose([0, 2, 1]).equal(0), 0.0) - xs = xs.unsqueeze(1) - padding1 = self.kernel_size - self.stride - xs = F.pad( - xs, (0, 0, 0, 0, 0, padding1, 0, 0), mode='constant', value=0.) - xs = self.dw_conv(xs.transpose([0, 3, 2, 1])) - xs = self.pw_conv(xs).transpose([0, 3, 2, 1]).squeeze(1) - tmp_length = xs.shape[1] - xs_lens = (xs_lens + 1) // 2 - padding2 = max(0, (xs_lens.max() - tmp_length).item()) - batch_size, hidden = xs.shape[0], xs.shape[-1] - dummy_pad = paddle.zeros( - [batch_size, padding2, hidden], dtype=paddle.float32) - xs = paddle.concat([xs, dummy_pad], axis=1) - mask = mask[:, ::2, ::2] - mask_pad = mask_pad[:, :, ::2] - return xs, xs_lens, mask, mask_pad - - -class TimeReductionLayerStream(nn.Layer): - """ - Squeezeformer Time Reduction procedure. - Downsamples the audio by `stride` in the time dimension. - Args: - channel (int): input dimension of - MultiheadAttentionMechanism and PositionwiseFeedForward - out_dim (int): Output dimension of the module. - kernel_size (int): Conv kernel size for - depthwise convolution in convolution module - stride (int): Downsampling factor in time dimension. 
- """ - - def __init__(self, - channel: int, - out_dim: int, - kernel_size: int=1, - stride: int=2): - super(TimeReductionLayerStream, self).__init__() - - self.channel = channel - self.out_dim = out_dim - self.kernel_size = kernel_size - self.stride = stride - - self.dw_conv = Conv1D( - in_channels=channel, - out_channels=channel, - kernel_size=kernel_size, - stride=stride, - padding=0, - groups=channel) - - self.pw_conv = Conv1D( - in_channels=channel, - out_channels=out_dim, - kernel_size=1, - stride=1, - padding=0, - groups=1) - self.init_weights() - - def init_weights(self): - dw_max = self.kernel_size**-0.5 - pw_max = self.channel**-0.5 - self.dw_conv._param_attr = paddle.nn.initializer.Uniform( - low=-dw_max, high=dw_max) - self.dw_conv._bias_attr = paddle.nn.initializer.Uniform( - low=-dw_max, high=dw_max) - self.pw_conv._param_attr = paddle.nn.initializer.Uniform( - low=-pw_max, high=pw_max) - self.pw_conv._bias_attr = paddle.nn.initializer.Uniform( - low=-pw_max, high=pw_max) - - def forward( - self, - xs, - xs_lens: paddle.Tensor, - mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), - mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool)): - xs = xs.transpose([0, 2, 1]) # [B, C, T] - xs = masked_fill(xs, mask_pad.equal(0), 0.0) - - xs = self.dw_conv(xs) - xs = self.pw_conv(xs) - - xs = xs.transpose([0, 2, 1]) # [B, T, C] - - B, T, D = xs.shape - mask = mask[:, ::self.stride, ::self.stride] - mask_pad = mask_pad[:, :, ::self.stride] - L = mask_pad.shape[-1] - # For JIT exporting, we remove F.pad operator. - if L - T < 0: - xs = xs[:, :L - T, :] - else: - dummy_pad = paddle.zeros([B, L - T, D], dtype=paddle.float32) - xs = paddle.concat([xs, dummy_pad], axis=1) - - xs_lens = (xs_lens + 1) // 2 - return xs, xs_lens, mask, mask_pad diff --git a/paddlespeech/s2t/modules/time_reduction.py b/paddlespeech/s2t/modules/time_reduction.py new file mode 100644 index 00000000000..d3393f108a1 --- /dev/null +++ b/paddlespeech/s2t/modules/time_reduction.py @@ -0,0 +1,263 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2019 Mobvoi Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from wenet(https://github.com/wenet-e2e/wenet) +"""Subsampling layer definition.""" +from typing import Tuple + +import paddle +import paddle.nn.functional as F +from paddle import nn + +from paddlespeech.s2t import masked_fill +from paddlespeech.s2t.modules.align import Conv1D +from paddlespeech.s2t.modules.conv2d import Conv2DValid +from paddlespeech.s2t.utils.log import Log + +logger = Log(__name__).getlog() + +__all__ = [ + "TimeReductionLayerStream", "TimeReductionLayer1D", "TimeReductionLayer2D" +] + + +class TimeReductionLayer1D(nn.Layer): + """ + Modified NeMo, + Squeezeformer Time Reduction procedure. + Downsamples the audio by `stride` in the time dimension. + Args: + channel (int): input dimension of + MultiheadAttentionMechanism and PositionwiseFeedForward + out_dim (int): Output dimension of the module. 
+ kernel_size (int): Conv kernel size for + depthwise convolution in convolution module + stride (int): Downsampling factor in time dimension. + """ + + def __init__(self, + channel: int, + out_dim: int, + kernel_size: int=5, + stride: int=2): + super(TimeReductionLayer1D, self).__init__() + + self.channel = channel + self.out_dim = out_dim + self.kernel_size = kernel_size + self.stride = stride + self.padding = max(0, self.kernel_size - self.stride) + + self.dw_conv = Conv1D( + in_channels=channel, + out_channels=channel, + kernel_size=kernel_size, + stride=stride, + padding=self.padding, + groups=channel, ) + + self.pw_conv = Conv1D( + in_channels=channel, + out_channels=out_dim, + kernel_size=1, + stride=1, + padding=0, + groups=1, ) + + self.init_weights() + + def init_weights(self): + dw_max = self.kernel_size**-0.5 + pw_max = self.channel**-0.5 + self.dw_conv._param_attr = paddle.nn.initializer.Uniform( + low=-dw_max, high=dw_max) + self.dw_conv._bias_attr = paddle.nn.initializer.Uniform( + low=-dw_max, high=dw_max) + self.pw_conv._param_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) + self.pw_conv._bias_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) + + def forward( + self, + xs, + xs_lens: paddle.Tensor, + mask: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool), + mask_pad: paddle.Tensor=paddle.ones((0, 0, 0), + dtype=paddle.bool), ): + xs = xs.transpose([0, 2, 1]) # [B, C, T] + xs = masked_fill(xs, mask_pad.equal(0), 0.0) + + xs = self.dw_conv(xs) + xs = self.pw_conv(xs) + + xs = xs.transpose([0, 2, 1]) # [B, T, C] + + B, T, D = xs.shape + mask = mask[:, ::self.stride, ::self.stride] + mask_pad = mask_pad[:, :, ::self.stride] + L = mask_pad.shape[-1] + # For JIT exporting, we remove F.pad operator. 
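+        # L is the output length implied by the already-strided mask. Trim
+        # xs when the conv produced extra frames (L - T < 0); otherwise
+        # zero-pad it (possibly by zero frames) so xs and mask_pad stay
+        # aligned without calling F.pad.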
+ if L - T < 0: + xs = xs[:, :L - T, :] + else: + dummy_pad = paddle.zeros([B, L - T, D], dtype=paddle.float32) + xs = paddle.concat([xs, dummy_pad], axis=1) + + xs_lens = (xs_lens + 1) // 2 + return xs, xs_lens, mask, mask_pad + + +class TimeReductionLayer2D(nn.Layer): + def __init__(self, kernel_size: int=5, stride: int=2, encoder_dim: int=256): + super(TimeReductionLayer2D, self).__init__() + self.encoder_dim = encoder_dim + self.kernel_size = kernel_size + self.dw_conv = Conv2DValid( + in_channels=encoder_dim, + out_channels=encoder_dim, + kernel_size=(kernel_size, 1), + stride=stride, + valid_trigy=True) + self.pw_conv = Conv2DValid( + in_channels=encoder_dim, + out_channels=encoder_dim, + kernel_size=1, + stride=1, + valid_trigx=False, + valid_trigy=False) + + self.kernel_size = kernel_size + self.stride = stride + self.init_weights() + + def init_weights(self): + dw_max = self.kernel_size**-0.5 + pw_max = self.encoder_dim**-0.5 + self.dw_conv._param_attr = paddle.nn.initializer.Uniform( + low=-dw_max, high=dw_max) + self.dw_conv._bias_attr = paddle.nn.initializer.Uniform( + low=-dw_max, high=dw_max) + self.pw_conv._param_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) + self.pw_conv._bias_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) + + def forward( + self, + xs: paddle.Tensor, + xs_lens: paddle.Tensor, + mask: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool), + mask_pad: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool), + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + xs = masked_fill(xs, mask_pad.transpose([0, 2, 1]).equal(0), 0.0) + xs = xs.unsqueeze(1) + padding1 = self.kernel_size - self.stride + xs = F.pad( + xs, (0, 0, 0, 0, 0, padding1, 0, 0), mode='constant', value=0.) + xs = self.dw_conv(xs.transpose([0, 3, 2, 1])) + xs = self.pw_conv(xs).transpose([0, 3, 2, 1]).squeeze(1) + tmp_length = xs.shape[1] + xs_lens = (xs_lens + 1) // 2 + padding2 = max(0, (xs_lens.max() - tmp_length).item()) + batch_size, hidden = xs.shape[0], xs.shape[-1] + dummy_pad = paddle.zeros( + [batch_size, padding2, hidden], dtype=paddle.float32) + xs = paddle.concat([xs, dummy_pad], axis=1) + mask = mask[:, ::2, ::2] + mask_pad = mask_pad[:, :, ::2] + return xs, xs_lens, mask, mask_pad + + +class TimeReductionLayerStream(nn.Layer): + """ + Squeezeformer Time Reduction procedure. + Downsamples the audio by `stride` in the time dimension. + Args: + channel (int): input dimension of + MultiheadAttentionMechanism and PositionwiseFeedForward + out_dim (int): Output dimension of the module. + kernel_size (int): Conv kernel size for + depthwise convolution in convolution module + stride (int): Downsampling factor in time dimension. 
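+    Note:
+        With the default kernel_size=1 and padding=0, every output frame
+        depends on exactly one input frame, so no left context needs to be
+        cached when this layer runs in streaming mode.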
+ """ + + def __init__(self, + channel: int, + out_dim: int, + kernel_size: int=1, + stride: int=2): + super(TimeReductionLayerStream, self).__init__() + + self.channel = channel + self.out_dim = out_dim + self.kernel_size = kernel_size + self.stride = stride + + self.dw_conv = Conv1D( + in_channels=channel, + out_channels=channel, + kernel_size=kernel_size, + stride=stride, + padding=0, + groups=channel) + + self.pw_conv = Conv1D( + in_channels=channel, + out_channels=out_dim, + kernel_size=1, + stride=1, + padding=0, + groups=1) + self.init_weights() + + def init_weights(self): + dw_max = self.kernel_size**-0.5 + pw_max = self.channel**-0.5 + self.dw_conv._param_attr = paddle.nn.initializer.Uniform( + low=-dw_max, high=dw_max) + self.dw_conv._bias_attr = paddle.nn.initializer.Uniform( + low=-dw_max, high=dw_max) + self.pw_conv._param_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) + self.pw_conv._bias_attr = paddle.nn.initializer.Uniform( + low=-pw_max, high=pw_max) + + def forward( + self, + xs, + xs_lens: paddle.Tensor, + mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), + mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool)): + xs = xs.transpose([0, 2, 1]) # [B, C, T] + xs = masked_fill(xs, mask_pad.equal(0), 0.0) + + xs = self.dw_conv(xs) + xs = self.pw_conv(xs) + + xs = xs.transpose([0, 2, 1]) # [B, T, C] + + B, T, D = xs.shape + mask = mask[:, ::self.stride, ::self.stride] + mask_pad = mask_pad[:, :, ::self.stride] + L = mask_pad.shape[-1] + # For JIT exporting, we remove F.pad operator. + if L - T < 0: + xs = xs[:, :L - T, :] + else: + dummy_pad = paddle.zeros([B, L - T, D], dtype=paddle.float32) + xs = paddle.concat([xs, dummy_pad], axis=1) + + xs_lens = (xs_lens + 1) // 2 + return xs, xs_lens, mask, mask_pad From fe8bbcc226f422ede56199563f000f722e16c7c3 Mon Sep 17 00:00:00 2001 From: yeyupiaoling Date: Wed, 11 Jan 2023 10:54:45 +0800 Subject: [PATCH 9/9] remove rel_shift, test=asr --- examples/aishell/asr1/conf/chunk_squeezeformer.yaml | 1 - examples/aishell/asr1/conf/squeezeformer.yaml | 2 +- paddlespeech/s2t/modules/attention.py | 5 +---- paddlespeech/s2t/modules/encoder.py | 5 +---- 4 files changed, 3 insertions(+), 10 deletions(-) diff --git a/examples/aishell/asr1/conf/chunk_squeezeformer.yaml b/examples/aishell/asr1/conf/chunk_squeezeformer.yaml index 45a2ac965c8..35a90b7d697 100644 --- a/examples/aishell/asr1/conf/chunk_squeezeformer.yaml +++ b/examples/aishell/asr1/conf/chunk_squeezeformer.yaml @@ -21,7 +21,6 @@ encoder_conf: normalize_before: false activation_type: 'swish' pos_enc_layer_type: 'rel_pos' - do_rel_shift: false time_reduction_layer_type: 'stream' causal: true use_dynamic_chunk: true diff --git a/examples/aishell/asr1/conf/squeezeformer.yaml b/examples/aishell/asr1/conf/squeezeformer.yaml index 49a837a8271..b7841aca50b 100644 --- a/examples/aishell/asr1/conf/squeezeformer.yaml +++ b/examples/aishell/asr1/conf/squeezeformer.yaml @@ -21,7 +21,7 @@ encoder_conf: normalize_before: false activation_type: 'swish' pos_enc_layer_type: 'rel_pos' - time_reduction_layer_type: 'conv2d' + time_reduction_layer_type: 'conv1d' # decoder related decoder: transformer diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py index 43700ca1ecc..14336c03d93 100644 --- a/paddlespeech/s2t/modules/attention.py +++ b/paddlespeech/s2t/modules/attention.py @@ -204,7 +204,6 @@ def __init__(self, n_head, n_feat, dropout_rate, - do_rel_shift=False, adaptive_scale=False, init_weights=False): """Construct an 
RelPositionMultiHeadedAttention object. @@ -229,7 +228,6 @@ def __init__(self, pos_bias_v = self.create_parameter( (self.h, self.d_k), default_initializer=I.XavierUniform()) self.add_parameter('pos_bias_v', pos_bias_v) - self.do_rel_shift = do_rel_shift self.adaptive_scale = adaptive_scale if self.adaptive_scale: ada_scale = self.create_parameter( @@ -369,8 +367,7 @@ def forward(self, matrix_bd = paddle.matmul(q_with_bias_v, p, transpose_y=True) # Remove rel_shift since it is useless in speech recognition, # and it requires special attention for streaming. - if self.do_rel_shift: - matrix_bd = self.rel_shift(matrix_bd) + # matrix_bd = self.rel_shift(matrix_bd) scores = (matrix_ac + matrix_bd) / math.sqrt( self.d_k) # (batch, head, time1, time2) diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index 7be1925751b..d90d69d7744 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -515,7 +515,6 @@ def __init__(self, input_dropout_rate: float=0.1, pos_enc_layer_type: str="rel_pos", time_reduction_layer_type: str="conv1d", - do_rel_shift: bool=True, feed_forward_dropout_rate: float=0.1, attention_dropout_rate: float=0.1, cnn_module_kernel: int=31, @@ -549,8 +548,6 @@ def __init__(self, input_dropout_rate (float): Dropout rate of input projection layer. pos_enc_layer_type (str): Self attention type. time_reduction_layer_type (str): Conv1d or Conv2d reduction layer. - do_rel_shift (bool): Whether to do relative shift - operation on rel-attention module. cnn_module_kernel (int): Kernel size of CNN module. activation_type (str): Encoder activation function type. cnn_module_kernel (int): Kernel size of convolution module. @@ -590,7 +587,7 @@ def __init__(self, else: encoder_selfattn_layer = RelPositionMultiHeadedAttention encoder_selfattn_layer_args = (attention_heads, encoder_dim, - attention_dropout_rate, do_rel_shift, + attention_dropout_rate, adaptive_scale, init_weights) # feed-forward module definition
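
A note for readers tracing the `adaptive_scale` flag that these patches
thread through the attention, feed-forward, and convolution modules: in the
Squeezeformer design followed here, each encoder sub-module rescales its
input with a learnable per-channel scale and bias (the ada_scale/ada_bias
parameters) rather than relying on pre-LayerNorm, which the configs disable
via `normalize_before: false`. Below is a minimal standalone sketch of that
mechanism, using the constant initialization from the convolution module;
the AdaptiveScale wrapper itself is illustrative and not part of this PR.

import paddle
from paddle import nn
from paddle.nn import initializer as I


class AdaptiveScale(nn.Layer):
    """Learnable per-channel affine applied to a (B, T, D) input,
    initialized to the identity transform (scale 1, bias 0)."""

    def __init__(self, dim: int):
        super().__init__()
        self.ada_scale = self.create_parameter(
            [1, 1, dim], default_initializer=I.Constant(1.0))
        self.ada_bias = self.create_parameter(
            [1, 1, dim], default_initializer=I.Constant(0.0))

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        # Broadcasts over batch and time: (B, T, D) -> (B, T, D).
        return self.ada_scale * x + self.ada_bias


# Shapes are illustrative: batch 2, 16 frames, encoder_dim 256.
x = paddle.randn([2, 16, 256])
print(AdaptiveScale(256)(x).shape)  # [2, 16, 256]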