Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ASR]add squeezeformer model #2755

Merged
merged 9 commits into from
Mar 15, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions examples/aishell/asr1/conf/chunk_squeezeformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder: squeezeformer
encoder_conf:
encoder_dim: 256 # dimension of attention
output_size: 256 # dimension of output
Expand All @@ -21,7 +21,8 @@ encoder_conf:
normalize_before: false
activation_type: 'swish'
pos_enc_layer_type: 'rel_pos'
time_reduction_layer_type: 'conv2d'
do_rel_shift: false
time_reduction_layer_type: 'stream'
causal: true
use_dynamic_chunk: true
use_dynamic_left_chunk: false
Expand Down
194 changes: 39 additions & 155 deletions paddlespeech/s2t/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,10 @@ def forward(self,
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding."""

def __init__(self, n_head, n_feat, dropout_rate):
def __init__(self, n_head, n_feat, dropout_rate,
do_rel_shift=False,
adaptive_scale=False,
init_weights=False):
"""Construct an RelPositionMultiHeadedAttention object.
Paper: https://arxiv.org/abs/1901.02860
Args:
Expand All @@ -226,151 +229,15 @@ def __init__(self, n_head, n_feat, dropout_rate):
pos_bias_v = self.create_parameter(
(self.h, self.d_k), default_initializer=I.XavierUniform())
self.add_parameter('pos_bias_v', pos_bias_v)

def rel_shift(self, x, zero_triu: bool=False):
"""Compute relative positinal encoding.
Args:
x (paddle.Tensor): Input tensor (batch, head, time1, time1).
zero_triu (bool): If true, return the lower triangular part of
the matrix.
Returns:
paddle.Tensor: Output tensor. (batch, head, time1, time1)
"""
zero_pad = paddle.zeros(
(x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype)
x_padded = paddle.cat([zero_pad, x], dim=-1)

x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1,
x.shape[2])
x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1]

if zero_triu:
ones = paddle.ones((x.shape[2], x.shape[3]))
x = x * paddle.tril(ones, x.shape[3] - x.shape[2])[None, None, :, :]

return x

def forward(self,
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
pos_emb: paddle.Tensor=paddle.empty([0]),
cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (paddle.Tensor): Query tensor (#batch, time1, size).
key (paddle.Tensor): Key tensor (#batch, time2, size).
value (paddle.Tensor): Value tensor (#batch, time2, size).
mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2), (0, 0, 0) means fake mask.
pos_emb (paddle.Tensor): Positional embedding tensor
(#batch, time2, size).
cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
Returns:
paddle.Tensor: Output tensor (#batch, time1, d_model).
paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
q, k, v = self.forward_qkv(query, key, value)
# q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)

# when export onnx model, for 1st chunk, we feed
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
# In all modes, `if cache.size(0) > 0` will alwayse be `True`
# and we will always do splitting and
# concatnation(this will simplify onnx export). Note that
# it's OK to concat & split zero-shaped tensors(see code below).
# when export jit model, for 1st chunk, we always feed
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
# >>> a = torch.ones((1, 2, 0, 4))
# >>> b = torch.ones((1, 2, 3, 4))
# >>> c = torch.cat((a, b), dim=2)
# >>> torch.equal(b, c) # True
# >>> d = torch.split(a, 2, dim=-1)
# >>> torch.equal(d[0], d[1]) # True
if cache.shape[0] > 0:
# last dim `d_k * 2` for (key, val)
key_cache, value_cache = paddle.split(cache, 2, axis=-1)
k = paddle.concat([key_cache, k], axis=2)
v = paddle.concat([value_cache, v], axis=2)
# We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = paddle.concat((k, v), axis=-1)

n_batch_pos = pos_emb.shape[0]
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)

# (batch, head, time1, d_k)
# q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3])
q_with_bias_u = q + self.pos_bias_u.unsqueeze(1)
# (batch, head, time1, d_k)
# q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3])
q_with_bias_v = q + self.pos_bias_v.unsqueeze(1)

# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
# matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2]))
matrix_ac = paddle.matmul(q_with_bias_u, k, transpose_y=True)

# compute matrix b and matrix d
# (batch, head, time1, time2)
# matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2]))
matrix_bd = paddle.matmul(q_with_bias_v, p, transpose_y=True)
# Remove rel_shift since it is useless in speech recognition,
# and it requires special attention for streaming.
# matrix_bd = self.rel_shift(matrix_bd)

scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k) # (batch, head, time1, time2)

return self.forward_attention(v, scores, mask), new_cache


class RelPositionMultiHeadedAttention2(MultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""

def __init__(self,
n_head,
n_feat,
dropout_rate,
do_rel_shift=False,
adaptive_scale=False,
init_weights=False):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate)
# linear transformation for positional encoding
self.linear_pos = Linear(n_feat, n_feat)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.do_rel_shift = do_rel_shift
pos_bias_u = self.create_parameter(
[self.h, self.d_k], default_initializer=I.XavierUniform())
self.add_parameter('pos_bias_u', pos_bias_u)
pos_bias_v = self.create_parameter(
[self.h, self.d_k], default_initializer=I.XavierUniform())
self.add_parameter('pos_bias_v', pos_bias_v)
self.adaptive_scale = adaptive_scale
zh794390558 marked this conversation as resolved.
Show resolved Hide resolved
ada_scale = self.create_parameter(
[1, 1, n_feat], default_initializer=I.Constant(1.0))
self.add_parameter('ada_scale', ada_scale)
ada_bias = self.create_parameter(
[1, 1, n_feat], default_initializer=I.Constant(0.0))
self.add_parameter('ada_bias', ada_bias)
if self.adaptive_scale:
ada_scale = self.create_parameter(
[1, 1, n_feat], default_initializer=I.Constant(1.0))
self.add_parameter('ada_scale', ada_scale)
ada_bias = self.create_parameter(
[1, 1, n_feat], default_initializer=I.Constant(0.0))
self.add_parameter('ada_bias', ada_bias)
if init_weights:
self.init_weights()

Expand Down Expand Up @@ -407,12 +274,12 @@ def rel_shift(self, x, zero_triu: bool=False):
paddle.Tensor: Output tensor. (batch, head, time1, time1)
"""
zero_pad = paddle.zeros(
[x.shape[0], x.shape[1], x.shape[2], 1], dtype=x.dtype)
x_padded = paddle.concat([zero_pad, x], axis=-1)
(x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype)
x_padded = paddle.cat([zero_pad, x], dim=-1)
zxcd marked this conversation as resolved.
Show resolved Hide resolved

x_padded = x_padded.reshape(
[x.shape[0], x.shape[1], x.shape[3] + 1, x.shape[2]])
x = x_padded[:, :, 1:].reshape(paddle.shape(x)) # [B, H, T1, T1]
x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1,
x.shape[2])
x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1]

if zero_triu:
ones = paddle.ones((x.shape[2], x.shape[3]))
Expand All @@ -424,10 +291,10 @@ def forward(self,
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
mask: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool),
mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
pos_emb: paddle.Tensor=paddle.empty([0]),
cache: paddle.Tensor=paddle.zeros(
(0, 0, 0, 0))) -> Tuple[paddle.Tensor, paddle.Tensor]:
cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (paddle.Tensor): Query tensor (#batch, time1, size).
Expand All @@ -452,17 +319,34 @@ def forward(self,
value = self.ada_scale * value + self.ada_bias

q, k, v = self.forward_qkv(query, key, value)
# q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)

# when export onnx model, for 1st chunk, we feed
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
# In all modes, `if cache.size(0) > 0` will alwayse be `True`
# and we will always do splitting and
# concatnation(this will simplify onnx export). Note that
# it's OK to concat & split zero-shaped tensors(see code below).
# when export jit model, for 1st chunk, we always feed
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
# >>> a = torch.ones((1, 2, 0, 4))
# >>> b = torch.ones((1, 2, 3, 4))
# >>> c = torch.cat((a, b), dim=2)
# >>> torch.equal(b, c) # True
# >>> d = torch.split(a, 2, dim=-1)
# >>> torch.equal(d[0], d[1]) # True
if cache.shape[0] > 0:
# last dim `d_k * 2` for (key, val)
key_cache, value_cache = paddle.split(cache, 2, axis=-1)
k = paddle.concat([key_cache, k], axis=2)
v = paddle.concat([value_cache, v], axis=2)
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
# We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = paddle.concat((k, v), axis=-1)

n_batch_pos = pos_emb.shape[0]
p = self.linear_pos(pos_emb).reshape(
[n_batch_pos, -1, self.h, self.d_k])
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)

# (batch, head, time1, d_k)
Expand Down
43 changes: 42 additions & 1 deletion paddlespeech/s2t/modules/conformer_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import paddle
from paddle import nn
from paddle.nn import initializer as I
from typeguard import check_argument_types

from paddlespeech.s2t.modules.align import BatchNorm1D
Expand All @@ -39,7 +40,9 @@ def __init__(self,
activation: nn.Layer=nn.ReLU(),
norm: str="batch_norm",
causal: bool=False,
bias: bool=True):
bias: bool=True,
adaptive_scale: bool=False,
init_weights: bool=False):
"""Construct an ConvolutionModule object.
Args:
channels (int): The number of channels of conv layers.
Expand All @@ -51,6 +54,19 @@ def __init__(self,
"""
assert check_argument_types()
super().__init__()
self.bias = bias
self.channels = channels
self.kernel_size = kernel_size
self.adaptive_scale = adaptive_scale
if self.adaptive_scale:
ada_scale = self.create_parameter(
[1, 1, channels], default_initializer=I.Constant(1.0))
self.add_parameter('ada_scale', ada_scale)
ada_bias = self.create_parameter(
[1, 1, channels], default_initializer=I.Constant(0.0))
self.add_parameter('ada_bias', ada_bias)


self.pointwise_conv1 = Conv1D(
channels,
2 * channels,
Expand Down Expand Up @@ -105,6 +121,28 @@ def __init__(self,
)
self.activation = activation

if init_weights:
self.init_weights()

def init_weights(self):
pw_max = self.channels**-0.5
dw_max = self.kernel_size**-0.5
self.pointwise_conv1._param_attr = paddle.nn.initializer.Uniform(
low=-pw_max, high=pw_max)
if self.bias:
self.pointwise_conv1._bias_attr = paddle.nn.initializer.Uniform(
low=-pw_max, high=pw_max)
self.depthwise_conv._param_attr = paddle.nn.initializer.Uniform(
low=-dw_max, high=dw_max)
if self.bias:
self.depthwise_conv._bias_attr = paddle.nn.initializer.Uniform(
low=-dw_max, high=dw_max)
self.pointwise_conv2._param_attr = paddle.nn.initializer.Uniform(
low=-pw_max, high=pw_max)
if self.bias:
self.pointwise_conv2._bias_attr = paddle.nn.initializer.Uniform(
low=-pw_max, high=pw_max)

def forward(
self,
x: paddle.Tensor,
Expand All @@ -123,6 +161,9 @@ def forward(
paddle.Tensor: Output tensor (#batch, time, channels).
paddle.Tensor: Output cache tensor (#batch, channels, time')
"""
if self.adaptive_scale:
x = self.ada_scale * x + self.ada_bias

# exchange the temporal dimension and the feature dimension
x = x.transpose([0, 2, 1]) # [B, C, T]

Expand Down
Loading