
[ASR] change default initializer to kaiming_uniform #1577

Merged · 6 commits · Mar 22, 2022
1 change: 1 addition & 0 deletions examples/aishell/asr1/conf/conformer.yaml
@@ -37,6 +37,7 @@ model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
+ init_type: 'kaiming_uniform'

###########################################
# Data #
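Note: the new `init_type` key is consumed in U2Model.__init__ (see the u2.py hunk below). A minimal sketch of that lookup, with a literal dict standing in for the parsed YAML:

# Minimal sketch of how the new YAML key is read downstream; `configs`
# stands in for the parsed conformer.yaml.
configs = {"model_conf": {"init_type": "kaiming_uniform"}}
model_conf = configs.get("model_conf", dict())
init_type = model_conf.get("init_type", None)  # None keeps Paddle's stock defaults
assert init_type == "kaiming_uniform"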
4 changes: 2 additions & 2 deletions paddlespeech/s2t/exps/u2/model.py
@@ -239,7 +239,7 @@ def setup_dataloader(self):
n_iter_processes=config.num_workers,
subsampling_factor=1,
num_encs=1,
- dist_sampler=False,
+ dist_sampler=True,
shortest_first=False)

self.valid_loader = BatchDataLoader(
@@ -260,7 +260,7 @@ def setup_dataloader(self):
n_iter_processes=config.num_workers,
subsampling_factor=1,
num_encs=1,
- dist_sampler=False,
+ dist_sampler=True,
shortest_first=False)
logger.info("Setup train/valid Dataloader!")
else:
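Note: switching dist_sampler to True shards the train/valid data across trainer ranks instead of feeding every rank the full set. An illustrative sketch using Paddle's public DistributedBatchSampler (BatchDataLoader's internal wiring may differ):

import paddle
from paddle.io import DataLoader, Dataset, DistributedBatchSampler

class ToyDataset(Dataset):
    """Stand-in dataset: 128 scalar samples."""
    def __getitem__(self, idx):
        return paddle.to_tensor([float(idx)])

    def __len__(self):
        return 128

# With N trainers, each rank draws a disjoint ~1/N slice per epoch; on a
# single process this degenerates to an ordinary shuffled batch sampler.
dataset = ToyDataset()
sampler = DistributedBatchSampler(
    dataset, batch_size=16, shuffle=True, drop_last=False)
loader = DataLoader(dataset, batch_sampler=sampler)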
10 changes: 8 additions & 2 deletions paddlespeech/s2t/models/u2/u2.py
@@ -36,6 +36,7 @@
from paddlespeech.s2t.modules.decoder import TransformerDecoder
from paddlespeech.s2t.modules.encoder import ConformerEncoder
from paddlespeech.s2t.modules.encoder import TransformerEncoder
+ from paddlespeech.s2t.modules.initializer import DefaultInitializerContext
from paddlespeech.s2t.modules.loss import LabelSmoothingLoss
from paddlespeech.s2t.modules.mask import make_pad_mask
from paddlespeech.s2t.modules.mask import mask_finished_preds
@@ -50,6 +51,7 @@
from paddlespeech.s2t.utils.tensor_utils import th_accuracy
from paddlespeech.s2t.utils.utility import log_add
from paddlespeech.s2t.utils.utility import UpdateConfig
+ # from paddlespeech.s2t.modules.initializer import initialize

__all__ = ["U2Model", "U2InferModel"]

@@ -72,6 +74,7 @@ def __init__(self,
assert 0.0 <= ctc_weight <= 1.0, ctc_weight

nn.Layer.__init__(self)

# note that eos is the same as sos (equivalent ID)
self.sos = vocab_size - 1
self.eos = vocab_size - 1
@@ -780,9 +783,12 @@ def encode(self, x):

class U2Model(U2DecodeModel):
def __init__(self, configs: dict):
- vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs)

+ model_conf = configs.get('model_conf', dict())
+ init_type = model_conf.get("init_type", None)
+ with DefaultInitializerContext(init_type):
+     vocab_size, encoder, decoder, ctc = U2Model._init_from_config(
+         configs)

super().__init__(
vocab_size=vocab_size,
encoder=encoder,
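Note: DefaultInitializerContext is imported above but its body is not part of this diff. A plausible minimal sketch, assuming it only toggles the global_init_type flag that the new align.py module (below) consults; the real class in paddlespeech/s2t/modules/initializer.py may differ:

from paddlespeech.s2t.modules import align

class DefaultInitializerContext:
    """Hypothetical sketch: set the process-wide default initializer while
    the model graph is built, then restore Paddle's defaults."""

    def __init__(self, init_type=None):
        self.init_type = init_type

    def __enter__(self):
        if self.init_type is not None:
            align.global_init_type = self.init_type
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Layers created outside the context keep Paddle's stock initializers.
        align.global_init_type = None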
12 changes: 7 additions & 5 deletions paddlespeech/s2t/modules/activation.py
@@ -17,6 +17,8 @@
from paddle import nn
from paddle.nn import functional as F

+ from paddlespeech.s2t.modules.align import Conv2D
+ from paddlespeech.s2t.modules.align import Linear
from paddlespeech.s2t.utils.log import Log

logger = Log(__name__).getlog()
@@ -51,7 +53,7 @@ def __init__(self, idim: int):
idim (int): input and output dimension
"""
super().__init__()
- self.fc = nn.Linear(idim, idim * 2)
+ self.fc = Linear(idim, idim * 2)

def forward(self, xs):
return glu(self.fc(xs), dim=-1)
@@ -75,7 +77,7 @@ def __init__(self, kernel_size, in_ch, out_ch, bottlececk_dim=0,
self.conv_residual = None
if in_ch != out_ch:
self.conv_residual = nn.utils.weight_norm(
- nn.Conv2D(
+ Conv2D(
in_channels=in_ch, out_channels=out_ch, kernel_size=(1, 1)),
name='weight',
dim=0)
@@ -86,7 +88,7 @@ def __init__(self, kernel_size, in_ch, out_ch, bottlececk_dim=0,
layers = OrderedDict()
if bottlececk_dim == 0:
layers['conv'] = nn.utils.weight_norm(
- nn.Conv2D(
+ Conv2D(
in_channels=in_ch,
out_channels=out_ch * 2,
kernel_size=(kernel_size, 1)),
@@ -106,7 +108,7 @@ def __init__(self, kernel_size, in_ch, out_ch, bottlececk_dim=0,
dim=0)
layers['dropout_in'] = nn.Dropout(p=dropout)
layers['conv_bottleneck'] = nn.utils.weight_norm(
- nn.Conv2D(
+ Conv2D(
in_channels=bottlececk_dim,
out_channels=bottlececk_dim,
kernel_size=(kernel_size, 1)),
@@ -115,7 +117,7 @@ def __init__(self, kernel_size, in_ch, out_ch, bottlececk_dim=0,
layers['dropout'] = nn.Dropout(p=dropout)
layers['glu'] = GLU()
layers['conv_out'] = nn.utils.weight_norm(
- nn.Conv2D(
+ Conv2D(
in_channels=bottlececk_dim,
out_channels=out_ch * 2,
kernel_size=(1, 1)),
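Note: the align wrappers keep the nn.* constructor signatures, so existing decorators such as weight_norm keep working unchanged. A small sanity sketch (assumes paddlespeech is importable):

import paddle.nn as nn
from paddlespeech.s2t.modules.align import Conv2D

# Drop-in replacement: weight_norm wraps the custom Conv2D exactly as it
# wrapped nn.Conv2D before this change.
conv = nn.utils.weight_norm(
    Conv2D(in_channels=8, out_channels=16, kernel_size=(1, 1)),
    name='weight',
    dim=0)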
139 changes: 139 additions & 0 deletions paddlespeech/s2t/modules/align.py
@@ -0,0 +1,139 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn

from paddlespeech.s2t.modules.initializer import KaimingUniform
"""
To align the initializer between paddle and torch,
the API below are set defalut initializer with priority higger than global initializer.
"""
global_init_type = None


class LayerNorm(nn.LayerNorm):
def __init__(self,
normalized_shape,
epsilon=1e-05,
weight_attr=None,
bias_attr=None,
name=None):
if weight_attr is None:
weight_attr = paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0))
if bias_attr is None:
bias_attr = paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0))
super(LayerNorm, self).__init__(normalized_shape, epsilon, weight_attr,
bias_attr, name)


class BatchNorm1D(nn.BatchNorm1D):
def __init__(self,
num_features,
momentum=0.9,
epsilon=1e-05,
weight_attr=None,
bias_attr=None,
data_format='NCL',
name=None):
if weight_attr is None:
weight_attr = paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0))
if bias_attr is None:
bias_attr = paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0))
super(BatchNorm1D,
self).__init__(num_features, momentum, epsilon, weight_attr,
bias_attr, data_format, name)


class Embedding(nn.Embedding):
def __init__(self,
num_embeddings,
embedding_dim,
padding_idx=None,
sparse=False,
weight_attr=None,
name=None):
if weight_attr is None:
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal())
super(Embedding, self).__init__(num_embeddings, embedding_dim,
padding_idx, sparse, weight_attr, name)


class Linear(nn.Linear):
def __init__(self,
in_features,
out_features,
weight_attr=None,
bias_attr=None,
name=None):
if weight_attr is None:
if global_init_type == "kaiming_uniform":
weight_attr = paddle.ParamAttr(initializer=KaimingUniform())
if bias_attr is None:
if global_init_type == "kaiming_uniform":
bias_attr = paddle.ParamAttr(initializer=KaimingUniform())
super(Linear, self).__init__(in_features, out_features, weight_attr,
bias_attr, name)


class Conv1D(nn.Conv1D):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
padding_mode='zeros',
weight_attr=None,
bias_attr=None,
data_format='NCL'):
if weight_attr is None:
if global_init_type == "kaiming_uniform":
print("set kaiming_uniform")
weight_attr = paddle.ParamAttr(initializer=KaimingUniform())
if bias_attr is None:
if global_init_type == "kaiming_uniform":
bias_attr = paddle.ParamAttr(initializer=KaimingUniform())
super(Conv1D, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
groups, padding_mode, weight_attr, bias_attr, data_format)


class Conv2D(nn.Conv2D):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
padding_mode='zeros',
weight_attr=None,
bias_attr=None,
data_format='NCHW'):
if weight_attr is None:
if global_init_type == "kaiming_uniform":
weight_attr = paddle.ParamAttr(initializer=KaimingUniform())
if bias_attr is None:
if global_init_type == "kaiming_uniform":
bias_attr = paddle.ParamAttr(initializer=KaimingUniform())
super(Conv2D, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
groups, padding_mode, weight_attr, bias_attr, data_format)
13 changes: 7 additions & 6 deletions paddlespeech/s2t/modules/attention.py
@@ -22,6 +22,7 @@
from paddle import nn
from paddle.nn import initializer as I

+ from paddlespeech.s2t.modules.align import Linear
from paddlespeech.s2t.utils.log import Log

logger = Log(__name__).getlog()
@@ -48,10 +49,10 @@ def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
- self.linear_q = nn.Linear(n_feat, n_feat)
- self.linear_k = nn.Linear(n_feat, n_feat)
- self.linear_v = nn.Linear(n_feat, n_feat)
- self.linear_out = nn.Linear(n_feat, n_feat)
+ self.linear_q = Linear(n_feat, n_feat)
+ self.linear_k = Linear(n_feat, n_feat)
+ self.linear_v = Linear(n_feat, n_feat)
+ self.linear_out = Linear(n_feat, n_feat)
self.dropout = nn.Dropout(p=dropout_rate)

def forward_qkv(self,
Expand Down Expand Up @@ -95,7 +96,7 @@ def forward_attention(self,
mask (paddle.Tensor): Mask, size (#batch, 1, time2) or
(#batch, time1, time2).
Returns:
paddle.Tensor: Transformed value weighted
by the attention score, (#batch, time1, d_model).
"""
n_batch = value.shape[0]
@@ -150,7 +151,7 @@ def __init__(self, n_head, n_feat, dropout_rate):
"""
super().__init__(n_head, n_feat, dropout_rate)
# linear transformation for positional encoding
- self.linear_pos = nn.Linear(n_feat, n_feat, bias_attr=False)
+ self.linear_pos = Linear(n_feat, n_feat, bias_attr=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
#self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
17 changes: 10 additions & 7 deletions paddlespeech/s2t/modules/conformer_convolution.py
@@ -21,6 +21,9 @@
from paddle import nn
from typeguard import check_argument_types

+ from paddlespeech.s2t.modules.align import BatchNorm1D
+ from paddlespeech.s2t.modules.align import Conv1D
+ from paddlespeech.s2t.modules.align import LayerNorm
from paddlespeech.s2t.utils.log import Log

logger = Log(__name__).getlog()
@@ -49,7 +52,7 @@ def __init__(self,
"""
assert check_argument_types()
super().__init__()
- self.pointwise_conv1 = nn.Conv1D(
+ self.pointwise_conv1 = Conv1D(
channels,
2 * channels,
kernel_size=1,
@@ -60,8 +63,8 @@ def __init__(self,
)

# self.lorder is used to distinguish if it's a causal convolution,
# if self.lorder > 0:
# it's a causal convolution, the input will be padded with
# `self.lorder` frames on the left in forward (causal conv impl).
# else: it's a symmetrical convolution
if causal:
@@ -73,7 +76,7 @@ def __init__(self,
padding = (kernel_size - 1) // 2
self.lorder = 0

- self.depthwise_conv = nn.Conv1D(
+ self.depthwise_conv = Conv1D(
channels,
channels,
kernel_size,
@@ -87,12 +90,12 @@ def __init__(self,
assert norm in ['batch_norm', 'layer_norm']
if norm == "batch_norm":
self.use_layer_norm = False
- self.norm = nn.BatchNorm1D(channels)
+ self.norm = BatchNorm1D(channels)
else:
self.use_layer_norm = True
- self.norm = nn.LayerNorm(channels)
+ self.norm = LayerNorm(channels)

- self.pointwise_conv2 = nn.Conv1D(
+ self.pointwise_conv2 = Conv1D(
channels,
channels,
kernel_size=1,
3 changes: 2 additions & 1 deletion paddlespeech/s2t/modules/ctc.py
@@ -18,6 +18,7 @@
from paddle.nn import functional as F
from typeguard import check_argument_types

+ from paddlespeech.s2t.modules.align import Linear
from paddlespeech.s2t.modules.loss import CTCLoss
from paddlespeech.s2t.utils import ctc_utils
from paddlespeech.s2t.utils.log import Log
@@ -69,7 +70,7 @@ def __init__(self,
self.blank_id = blank_id
self.odim = odim
self.dropout = nn.Dropout(dropout_rate)
- self.ctc_lo = nn.Linear(enc_n_units, self.odim)
+ self.ctc_lo = Linear(enc_n_units, self.odim)
reduction_type = "sum" if reduction else "none"
self.criterion = CTCLoss(
blank=self.blank_id,
Expand Down
Loading