[ASR] change default initializer to kaiming_uniform #1577

Merged: 6 commits, Mar 22, 2022
1 change: 1 addition & 0 deletions examples/aishell/asr1/conf/conformer.yaml
@@ -37,6 +37,7 @@ model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
init_type: 'kaiming_uniform'

###########################################
# Data #
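For context, the new option sits under model_conf, which U2Model reads at construction time (see the u2.py change below). A minimal sketch of pulling it out of the parsed config, assuming the YAML is loaded with PyYAML into a plain dict:

    import yaml  # assumption: the recipe parses this file with PyYAML (or equivalent) into a dict

    with open("examples/aishell/asr1/conf/conformer.yaml") as f:
        configs = yaml.safe_load(f)

    model_conf = configs.get("model_conf", dict())
    init_type = model_conf.get("init_type", None)  # 'kaiming_uniform' with this change
    print(init_type)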
6 changes: 6 additions & 0 deletions paddlespeech/s2t/__init__.py
@@ -21,6 +21,7 @@
from paddle.fluid import core
from paddle.nn import functional as F

from paddlespeech.s2t.modules import initializer
from paddlespeech.s2t.utils.log import Log

#TODO(Hui Zhang): remove fluid import
@@ -505,3 +506,8 @@ def update(self, modules: Mapping[str, Layer]) -> None:
logger.debug(
"register user LayerDict to paddle.nn, remove this when fixed!")
setattr(paddle.nn, 'LayerDict', LayerDict)

"""
hack KaiminigUniform: change limit from np.sqrt(6.0 / float(fan_in)) to np.sqrt(1.0 / float(fan_in))
"""
paddle.nn.initializer.KaimingUniform = initializer.KaimingUniform
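The monkey patch above swaps Paddle's built-in KaimingUniform for the project's own variant from paddlespeech.s2t.modules.initializer. Per the comment, the only conceptual difference is a narrower sampling bound; a minimal NumPy sketch of that bound (illustrative only, not the actual initializer implementation):

    import numpy as np

    def kaiming_uniform_bound(fan_in: int, patched: bool = True) -> float:
        # Stock KaimingUniform samples from U(-sqrt(6/fan_in), +sqrt(6/fan_in));
        # the patched variant narrows the limit to sqrt(1/fan_in).
        scale = 1.0 if patched else 6.0
        return np.sqrt(scale / float(fan_in))

    def sample_weights(shape, fan_in):
        limit = kaiming_uniform_bound(fan_in)
        return np.random.uniform(low=-limit, high=limit, size=shape)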
9 changes: 8 additions & 1 deletion paddlespeech/s2t/models/u2/u2.py
@@ -41,6 +41,7 @@
from paddlespeech.s2t.modules.mask import mask_finished_preds
from paddlespeech.s2t.modules.mask import mask_finished_scores
from paddlespeech.s2t.modules.mask import subsequent_mask
from paddlespeech.s2t.modules.nets_utils import initialize
from paddlespeech.s2t.utils import checkpoint
from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
@@ -72,6 +73,7 @@ def __init__(self,
assert 0.0 <= ctc_weight <= 1.0, ctc_weight

nn.Layer.__init__(self)

# note that eos is the same as sos (equivalent ID)
self.sos = vocab_size - 1
self.eos = vocab_size - 1
@@ -780,9 +782,14 @@ def encode(self, x):

class U2Model(U2DecodeModel):
def __init__(self, configs: dict):
model_conf = configs.get('model_conf', dict())
init_type = model_conf.get("init_type", None)
if init_type is not None:
logger.info(f"Use {init_type} initializer as default initializer")
initialize(self, init_type)
vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs)
nn.initializer.set_global_initializer(None)

model_conf = configs.get('model_conf', dict())
super().__init__(
vocab_size=vocab_size,
encoder=encoder,
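The construction flow above is: read init_type from model_conf, install the matching initializer as the global default, build the submodules, then call nn.initializer.set_global_initializer(None) so models built afterwards fall back to Paddle's defaults. A condensed sketch of that pattern, assuming initialize(self, init_type) essentially sets a global KaimingUniform initializer (build_submodules below is a hypothetical stand-in for U2Model._init_from_config):

    import paddle
    from paddle import nn

    def construct_with_default_init(build_submodules, init_type=None):
        if init_type == "kaiming_uniform":
            # Assumption: roughly what initialize(model, init_type) amounts to.
            nn.initializer.set_global_initializer(
                nn.initializer.KaimingUniform(), nn.initializer.Constant(0.0))
        modules = build_submodules()
        # Reset so later models are unaffected by the global default.
        nn.initializer.set_global_initializer(None)
        return modules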
2 changes: 1 addition & 1 deletion paddlespeech/s2t/modules/attention.py
@@ -95,7 +95,7 @@ def forward_attention(self,
mask (paddle.Tensor): Mask, size (#batch, 1, time2) or
(#batch, time1, time2).
Returns:
paddle.Tensor: Transformed value weighted
paddle.Tensor: Transformed value weighted
by the attention score, (#batch, time1, d_model).
"""
n_batch = value.shape[0]
18 changes: 14 additions & 4 deletions paddlespeech/s2t/modules/conformer_convolution.py
@@ -60,8 +60,8 @@ def __init__(self,
)

# self.lorder is used to distinguish if it's a causal convolution,
# if self.lorder > 0:
# it's a causal convolution, the input will be padded with
# if self.lorder > 0:
# it's a causal convolution, the input will be padded with
# `self.lorder` frames on the left in forward (causal conv impl).
# else: it's a symmetrical convolution
if causal:
@@ -87,10 +87,20 @@ def __init__(self,
assert norm in ['batch_norm', 'layer_norm']
if norm == "batch_norm":
self.use_layer_norm = False
self.norm = nn.BatchNorm1D(channels)
self.norm = nn.BatchNorm1D(
channels,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0)))
else:
self.use_layer_norm = True
self.norm = nn.LayerNorm(channels)
self.norm = nn.LayerNorm(
channels,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0)))

self.pointwise_conv2 = nn.Conv1D(
channels,
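Because the global initializer would otherwise also reinitialize normalization parameters, this PR pins every BatchNorm1D/LayerNorm weight to 1.0 and bias to 0.0 via explicit ParamAttr. The same pattern repeats in decoder.py, decoder_layer.py, encoder.py, and encoder_layer.py below; a small helper like this (not part of the PR, just a sketch of the repeated snippet) captures it once:

    import paddle
    from paddle import nn

    def layer_norm_default_init(size: int, epsilon: float = 1e-12) -> nn.LayerNorm:
        # Pin scale/bias to the usual 1.0/0.0 so a global initializer
        # (e.g. KaimingUniform) does not override them.
        return nn.LayerNorm(
            size,
            epsilon=epsilon,
            weight_attr=paddle.ParamAttr(
                initializer=nn.initializer.Constant(1.0)),
            bias_attr=paddle.ParamAttr(
                initializer=nn.initializer.Constant(0.0)))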
15 changes: 13 additions & 2 deletions paddlespeech/s2t/modules/decoder.py
@@ -76,19 +76,30 @@ def __init__(
concat_after: bool=False, ):

assert check_argument_types()

nn.Layer.__init__(self)
self.selfattention_layer_type = 'selfattn'
attention_dim = encoder_output_size

if input_layer == "embed":
self.embed = nn.Sequential(
nn.Embedding(vocab_size, attention_dim),
nn.Embedding(
vocab_size,
attention_dim,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Normal())),
PositionalEncoding(attention_dim, positional_dropout_rate), )
else:
raise ValueError(f"only 'embed' is supported: {input_layer}")

self.normalize_before = normalize_before
self.after_norm = nn.LayerNorm(attention_dim, epsilon=1e-12)
self.after_norm = nn.LayerNorm(
attention_dim,
epsilon=1e-12,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0)))
self.use_output_layer = use_output_layer
self.output_layer = nn.Linear(attention_dim, vocab_size)

24 changes: 21 additions & 3 deletions paddlespeech/s2t/modules/decoder_layer.py
@@ -62,9 +62,27 @@ def __init__(
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.norm1 = nn.LayerNorm(size, epsilon=1e-12)
self.norm2 = nn.LayerNorm(size, epsilon=1e-12)
self.norm3 = nn.LayerNorm(size, epsilon=1e-12)
self.norm1 = nn.LayerNorm(
size,
epsilon=1e-12,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0)))
self.norm2 = nn.LayerNorm(
size,
epsilon=1e-12,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0)))
self.norm3 = nn.LayerNorm(
size,
epsilon=1e-12,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0)))
self.dropout = nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
self.concat_after = concat_after
9 changes: 8 additions & 1 deletion paddlespeech/s2t/modules/encoder.py
@@ -129,7 +129,13 @@ def __init__(
d_model=output_size, dropout_rate=positional_dropout_rate), )

self.normalize_before = normalize_before
self.after_norm = nn.LayerNorm(output_size, epsilon=1e-12)
self.after_norm = nn.LayerNorm(
output_size,
epsilon=1e-12,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0)))
self.static_chunk_size = static_chunk_size
self.use_dynamic_chunk = use_dynamic_chunk
self.use_dynamic_left_chunk = use_dynamic_left_chunk
@@ -457,6 +463,7 @@ def __init__(
cnn_module_norm (str): cnn conv norm type, Optional['batch_norm','layer_norm']
"""
assert check_argument_types()

super().__init__(input_size, output_size, attention_heads, linear_units,
num_blocks, dropout_rate, positional_dropout_rate,
attention_dropout_rate, input_layer,
42 changes: 35 additions & 7 deletions paddlespeech/s2t/modules/encoder_layer.py
@@ -39,7 +39,7 @@ def __init__(
normalize_before: bool=True,
concat_after: bool=False, ):
"""Construct an EncoderLayer object.

Args:
size (int): Input dimension.
self_attn (nn.Layer): Self-attention module instance.
@@ -147,7 +147,7 @@ def __init__(
normalize_before: bool=True,
concat_after: bool=False, ):
"""Construct an EncoderLayer object.

Args:
size (int): Input dimension.
self_attn (nn.Layer): Self-attention module instance.
@@ -174,18 +174,46 @@ def __init__(
self.feed_forward = feed_forward
self.feed_forward_macaron = feed_forward_macaron
self.conv_module = conv_module
self.norm_ff = nn.LayerNorm(size, epsilon=1e-12) # for the FNN module
self.norm_mha = nn.LayerNorm(size, epsilon=1e-12) # for the MHA module
self.norm_ff = nn.LayerNorm(
size,
epsilon=1e-12,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0))) # for the FNN module
self.norm_mha = nn.LayerNorm(
size,
epsilon=1e-12,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0))) # for the MHA module
if feed_forward_macaron is not None:
self.norm_ff_macaron = nn.LayerNorm(size, epsilon=1e-12)
self.norm_ff_macaron = nn.LayerNorm(
size,
epsilon=1e-12,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0)))
self.ff_scale = 0.5
else:
self.ff_scale = 1.0
if self.conv_module is not None:
self.norm_conv = nn.LayerNorm(
size, epsilon=1e-12) # for the CNN module
size,
epsilon=1e-12,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(initializer=nn.initializer.Constant(
0.0))) # for the CNN module
self.norm_final = nn.LayerNorm(
size, epsilon=1e-12) # for the final output of the block
size,
epsilon=1e-12,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0)),
bias_attr=paddle.ParamAttr(initializer=nn.initializer.Constant(
0.0))) # for the final output of the block
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before