From ab16d8ce3c4487574b7ef157b07be092fe41c5d2 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Thu, 17 Mar 2022 03:22:33 +0000 Subject: [PATCH 1/5] change default initializer to kaiming_uniform, test=asr --- examples/aishell/asr1/conf/conformer.yaml | 1 + paddlespeech/s2t/__init__.py | 6 + paddlespeech/s2t/models/u2/u2.py | 9 +- paddlespeech/s2t/modules/attention.py | 2 +- .../s2t/modules/conformer_convolution.py | 18 +- paddlespeech/s2t/modules/decoder.py | 15 +- paddlespeech/s2t/modules/decoder_layer.py | 24 +- paddlespeech/s2t/modules/encoder.py | 9 +- paddlespeech/s2t/modules/encoder_layer.py | 42 ++- paddlespeech/s2t/modules/initializer.py | 272 ++++++++++++++++++ paddlespeech/s2t/modules/nets_utils.py | 44 +++ 11 files changed, 423 insertions(+), 19 deletions(-) create mode 100644 paddlespeech/s2t/modules/initializer.py create mode 100644 paddlespeech/s2t/modules/nets_utils.py diff --git a/examples/aishell/asr1/conf/conformer.yaml b/examples/aishell/asr1/conf/conformer.yaml index 775a4527d49..679a5bf6613 100644 --- a/examples/aishell/asr1/conf/conformer.yaml +++ b/examples/aishell/asr1/conf/conformer.yaml @@ -37,6 +37,7 @@ model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option length_normalized_loss: false + init_type: 'kaiming_uniform' ########################################### # Data # diff --git a/paddlespeech/s2t/__init__.py b/paddlespeech/s2t/__init__.py index 855ceef96f5..26f59dfb13f 100644 --- a/paddlespeech/s2t/__init__.py +++ b/paddlespeech/s2t/__init__.py @@ -21,6 +21,7 @@ from paddle.fluid import core from paddle.nn import functional as F +from paddlespeech.s2t.modules import initializer from paddlespeech.s2t.utils.log import Log #TODO(Hui Zhang): remove fluid import @@ -505,3 +506,8 @@ def update(self, modules: Mapping[str, Layer]) -> None: logger.debug( "register user LayerDict to paddle.nn, remove this when fixed!") setattr(paddle.nn, 'LayerDict', LayerDict) + +""" + hack KaiminigUniform: change limit from np.sqrt(6.0 / float(fan_in)) to np.sqrt(1.0 / float(fan_in)) +""" +paddle.nn.initializer.KaimingUniform = initializer.KaimingUniform diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 910798127ee..67ec5924af8 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -41,6 +41,7 @@ from paddlespeech.s2t.modules.mask import mask_finished_preds from paddlespeech.s2t.modules.mask import mask_finished_scores from paddlespeech.s2t.modules.mask import subsequent_mask +from paddlespeech.s2t.modules.nets_utils import initialize from paddlespeech.s2t.utils import checkpoint from paddlespeech.s2t.utils import layer_tools from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank @@ -72,6 +73,7 @@ def __init__(self, assert 0.0 <= ctc_weight <= 1.0, ctc_weight nn.Layer.__init__(self) + # note that eos is the same as sos (equivalent ID) self.sos = vocab_size - 1 self.eos = vocab_size - 1 @@ -780,9 +782,14 @@ def encode(self, x): class U2Model(U2DecodeModel): def __init__(self, configs: dict): + model_conf = configs.get('model_conf', dict()) + init_type = model_conf.get("init_type", None) + if init_type is not None: + logger.info(f"Use {init_type} initializer as default initializer") + initialize(self, init_type) vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs) + nn.initializer.set_global_initializer(None) - model_conf = configs.get('model_conf', dict()) super().__init__( vocab_size=vocab_size, encoder=encoder, diff --git a/paddlespeech/s2t/modules/attention.py 
b/paddlespeech/s2t/modules/attention.py index 3d5f8cd1d3a..c2f5e503eb4 100644 --- a/paddlespeech/s2t/modules/attention.py +++ b/paddlespeech/s2t/modules/attention.py @@ -95,7 +95,7 @@ def forward_attention(self, mask (paddle.Tensor): Mask, size (#batch, 1, time2) or (#batch, time1, time2). Returns: - paddle.Tensor: Transformed value weighted + paddle.Tensor: Transformed value weighted by the attention score, (#batch, time1, d_model). """ n_batch = value.shape[0] diff --git a/paddlespeech/s2t/modules/conformer_convolution.py b/paddlespeech/s2t/modules/conformer_convolution.py index 7ec92554eec..256d187c907 100644 --- a/paddlespeech/s2t/modules/conformer_convolution.py +++ b/paddlespeech/s2t/modules/conformer_convolution.py @@ -60,8 +60,8 @@ def __init__(self, ) # self.lorder is used to distinguish if it's a causal convolution, - # if self.lorder > 0: - # it's a causal convolution, the input will be padded with + # if self.lorder > 0: + # it's a causal convolution, the input will be padded with # `self.lorder` frames on the left in forward (causal conv impl). # else: it's a symmetrical convolution if causal: @@ -87,10 +87,20 @@ def __init__(self, assert norm in ['batch_norm', 'layer_norm'] if norm == "batch_norm": self.use_layer_norm = False - self.norm = nn.BatchNorm1D(channels) + self.norm = nn.BatchNorm1D( + channels, + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.Constant(1.0)), + bias_attr=paddle.ParamAttr( + initializer=nn.initializer.Constant(0.0))) else: self.use_layer_norm = True - self.norm = nn.LayerNorm(channels) + self.norm = nn.LayerNorm( + channels, + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.Constant(1.0)), + bias_attr=paddle.ParamAttr( + initializer=nn.initializer.Constant(0.0))) self.pointwise_conv2 = nn.Conv1D( channels, diff --git a/paddlespeech/s2t/modules/decoder.py b/paddlespeech/s2t/modules/decoder.py index 6b4d959123b..b0ae27e52e9 100644 --- a/paddlespeech/s2t/modules/decoder.py +++ b/paddlespeech/s2t/modules/decoder.py @@ -76,19 +76,30 @@ def __init__( concat_after: bool=False, ): assert check_argument_types() + nn.Layer.__init__(self) self.selfattention_layer_type = 'selfattn' attention_dim = encoder_output_size if input_layer == "embed": self.embed = nn.Sequential( - nn.Embedding(vocab_size, attention_dim), + nn.Embedding( + vocab_size, + attention_dim, + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.Normal())), PositionalEncoding(attention_dim, positional_dropout_rate), ) else: raise ValueError(f"only 'embed' is supported: {input_layer}") self.normalize_before = normalize_before - self.after_norm = nn.LayerNorm(attention_dim, epsilon=1e-12) + self.after_norm = nn.LayerNorm( + attention_dim, + epsilon=1e-12, + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.Constant(1.0)), + bias_attr=paddle.ParamAttr( + initializer=nn.initializer.Constant(0.0))) self.use_output_layer = use_output_layer self.output_layer = nn.Linear(attention_dim, vocab_size) diff --git a/paddlespeech/s2t/modules/decoder_layer.py b/paddlespeech/s2t/modules/decoder_layer.py index 520b18dea17..8eee5ceb1bc 100644 --- a/paddlespeech/s2t/modules/decoder_layer.py +++ b/paddlespeech/s2t/modules/decoder_layer.py @@ -62,9 +62,27 @@ def __init__( self.self_attn = self_attn self.src_attn = src_attn self.feed_forward = feed_forward - self.norm1 = nn.LayerNorm(size, epsilon=1e-12) - self.norm2 = nn.LayerNorm(size, epsilon=1e-12) - self.norm3 = nn.LayerNorm(size, epsilon=1e-12) + self.norm1 = nn.LayerNorm( + size, + epsilon=1e-12, + 
weight_attr=paddle.ParamAttr( + initializer=nn.initializer.Constant(1.0)), + bias_attr=paddle.ParamAttr( + initializer=nn.initializer.Constant(0.0))) + self.norm2 = nn.LayerNorm( + size, + epsilon=1e-12, + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.Constant(1.0)), + bias_attr=paddle.ParamAttr( + initializer=nn.initializer.Constant(0.0))) + self.norm3 = nn.LayerNorm( + size, + epsilon=1e-12, + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.Constant(1.0)), + bias_attr=paddle.ParamAttr( + initializer=nn.initializer.Constant(0.0))) self.dropout = nn.Dropout(dropout_rate) self.normalize_before = normalize_before self.concat_after = concat_after diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index 5c8ba0810d0..5f7b8e99dfc 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -129,7 +129,13 @@ def __init__( d_model=output_size, dropout_rate=positional_dropout_rate), ) self.normalize_before = normalize_before - self.after_norm = nn.LayerNorm(output_size, epsilon=1e-12) + self.after_norm = nn.LayerNorm( + output_size, + epsilon=1e-12, + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.Constant(1.0)), + bias_attr=paddle.ParamAttr( + initializer=nn.initializer.Constant(0.0))) self.static_chunk_size = static_chunk_size self.use_dynamic_chunk = use_dynamic_chunk self.use_dynamic_left_chunk = use_dynamic_left_chunk @@ -457,6 +463,7 @@ def __init__( cnn_module_norm (str): cnn conv norm type, Optional['batch_norm','layer_norm'] """ assert check_argument_types() + super().__init__(input_size, output_size, attention_heads, linear_units, num_blocks, dropout_rate, positional_dropout_rate, attention_dropout_rate, input_layer, diff --git a/paddlespeech/s2t/modules/encoder_layer.py b/paddlespeech/s2t/modules/encoder_layer.py index d39c0695a04..69a3f67bb4e 100644 --- a/paddlespeech/s2t/modules/encoder_layer.py +++ b/paddlespeech/s2t/modules/encoder_layer.py @@ -39,7 +39,7 @@ def __init__( normalize_before: bool=True, concat_after: bool=False, ): """Construct an EncoderLayer object. - + Args: size (int): Input dimension. self_attn (nn.Layer): Self-attention module instance. @@ -147,7 +147,7 @@ def __init__( normalize_before: bool=True, concat_after: bool=False, ): """Construct an EncoderLayer object. - + Args: size (int): Input dimension. self_attn (nn.Layer): Self-attention module instance. 
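# A minimal, standalone sketch of the LayerNorm pattern repeated in the hunks
# above and below (illustrative only; assumes Paddle 2.x, and the feature size
# 256 is an arbitrary example value, not taken from the patch):
import paddle
from paddle import nn

norm = nn.LayerNorm(
    256,
    epsilon=1e-12,
    weight_attr=paddle.ParamAttr(
        initializer=nn.initializer.Constant(1.0)),
    bias_attr=paddle.ParamAttr(
        initializer=nn.initializer.Constant(0.0)))
# Explicit ParamAttr initializers take priority over any default installed via
# nn.initializer.set_global_initializer(), so norm.weight starts at ones and
# norm.bias at zeros even when kaiming_uniform is configured as the global
# default initializer.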
@@ -174,18 +174,46 @@ def __init__( self.feed_forward = feed_forward self.feed_forward_macaron = feed_forward_macaron self.conv_module = conv_module - self.norm_ff = nn.LayerNorm(size, epsilon=1e-12) # for the FNN module - self.norm_mha = nn.LayerNorm(size, epsilon=1e-12) # for the MHA module + self.norm_ff = nn.LayerNorm( + size, + epsilon=1e-12, + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.Constant(1.0)), + bias_attr=paddle.ParamAttr( + initializer=nn.initializer.Constant(0.0))) # for the FNN module + self.norm_mha = nn.LayerNorm( + size, + epsilon=1e-12, + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.Constant(1.0)), + bias_attr=paddle.ParamAttr( + initializer=nn.initializer.Constant(0.0))) # for the MHA module if feed_forward_macaron is not None: - self.norm_ff_macaron = nn.LayerNorm(size, epsilon=1e-12) + self.norm_ff_macaron = nn.LayerNorm( + size, + epsilon=1e-12, + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.Constant(1.0)), + bias_attr=paddle.ParamAttr( + initializer=nn.initializer.Constant(0.0))) self.ff_scale = 0.5 else: self.ff_scale = 1.0 if self.conv_module is not None: self.norm_conv = nn.LayerNorm( - size, epsilon=1e-12) # for the CNN module + size, + epsilon=1e-12, + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.Constant(1.0)), + bias_attr=paddle.ParamAttr(initializer=nn.initializer.Constant( + 0.0))) # for the CNN module self.norm_final = nn.LayerNorm( - size, epsilon=1e-12) # for the final output of the block + size, + epsilon=1e-12, + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.Constant(1.0)), + bias_attr=paddle.ParamAttr(initializer=nn.initializer.Constant( + 0.0))) # for the final output of the block self.dropout = nn.Dropout(dropout_rate) self.size = size self.normalize_before = normalize_before diff --git a/paddlespeech/s2t/modules/initializer.py b/paddlespeech/s2t/modules/initializer.py new file mode 100644 index 00000000000..c91ab231741 --- /dev/null +++ b/paddlespeech/s2t/modules/initializer.py @@ -0,0 +1,272 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +from paddle.fluid import framework +from paddle.fluid.framework import in_dygraph_mode, default_main_program +import numpy as np +from paddle.fluid.core import VarDesc +from paddle.fluid import unique_name + +__all__ = [ + 'MSRAInitializer' +] + + +class Initializer(object): + """Base class for variable initializers + + Defines the common interface of variable initializers. + They add operations to the init program that are used + to initialize variables. Users should not use this class + directly, but need to use one of its implementations. 
+ """ + + def __init__(self): + pass + + def __call__(self, param, block=None): + """Add corresponding initialization operations to the network + """ + raise NotImplementedError() + + def _check_block(self, block): + if block is None: + block = default_main_program().global_block() + + return block + + def _compute_fans(self, var): + """Compute the fan_in and the fan_out for layers + + This method computes the fan_in and the fan_out + for neural network layers, if not specified. It is + not possible to perfectly estimate fan_in and fan_out. + This method will estimate it correctly for matrix multiply and + convolutions. + + Args: + var: variable for which fan_in and fan_out have to be computed + + Returns: + tuple of two integers (fan_in, fan_out) + """ + shape = var.shape + if not shape or len(shape) == 0: + fan_in = fan_out = 1 + elif len(shape) == 1: + fan_in = fan_out = shape[0] + elif len(shape) == 2: + # This is the case for simple matrix multiply + fan_in = shape[0] + fan_out = shape[1] + else: + # Assume this to be a convolutional kernel + # In PaddlePaddle, the shape of the kernel is like: + # [num_filters, num_filter_channels, ...] where the remaining + # dimensions are the filter_size + receptive_field_size = np.prod(shape[2:]) + fan_in = shape[1] * receptive_field_size + fan_out = shape[0] * receptive_field_size + + return (fan_in, fan_out) + + + +class MSRAInitializer(Initializer): + r"""Implements the MSRA initializer a.k.a. Kaiming Initializer + + This class implements the weight initialization from the paper + `Delving Deep into Rectifiers: Surpassing Human-Level Performance on + ImageNet Classification `_ + by Kaiming He, Xiangyu Zhang, Shaoqing Ren and Jian Sun. This is a + robust initialization method that particularly considers the rectifier + nonlinearities. In case of Uniform distribution, the range is [-x, x], where + + .. math:: + + x = \sqrt{\\frac{6.0}{fan\_in}} + + In case of Normal distribution, the mean is 0 and the standard deviation + is + + .. math:: + + \sqrt{\\frac{2.0}{fan\_in}} + + Args: + uniform (bool): whether to use uniform or normal distribution + fan_in (float32|None): fan_in for MSRAInitializer. If None, it is\ + inferred from the variable. default is None. + seed (int32): random seed + + Note: + It is recommended to set fan_in to None for most cases. + + Examples: + .. code-block:: python + + import paddle + import paddle.fluid as fluid + paddle.enable_static() + x = fluid.data(name="data", shape=[8, 32, 32], dtype="float32") + fc = fluid.layers.fc(input=x, size=10, + param_attr=fluid.initializer.MSRA(uniform=False)) + + """ + + def __init__(self, uniform=True, fan_in=None, seed=0): + """Constructor for MSRAInitializer + """ + assert uniform is not None + assert seed is not None + super(MSRAInitializer, self).__init__() + self._uniform = uniform + self._fan_in = fan_in + self._seed = seed + + def __call__(self, var, block=None): + """Initialize the input tensor with MSRA initialization. + + Args: + var(Tensor): Tensor that needs to be initialized. + block(Block, optional): The block in which initialization ops + should be added. Used in static graph only, default None. 
+ + Returns: + The initialization op + """ + block = self._check_block(block) + + assert isinstance(var, framework.Variable) + assert isinstance(block, framework.Block) + f_in, f_out = self._compute_fans(var) + + # If fan_in is passed, use it + fan_in = f_in if self._fan_in is None else self._fan_in + + if self._seed == 0: + self._seed = block.program.random_seed + + # to be compatible of fp16 initalizers + if var.dtype == VarDesc.VarType.FP16 or ( + var.dtype == VarDesc.VarType.BF16 and not self._uniform): + out_dtype = VarDesc.VarType.FP32 + out_var = block.create_var( + name=unique_name.generate(".".join( + ['masra_init', var.name, 'tmp'])), + shape=var.shape, + dtype=out_dtype, + type=VarDesc.VarType.LOD_TENSOR, + persistable=False) + else: + out_dtype = var.dtype + out_var = var + + if self._uniform: + limit = np.sqrt(1.0 / float(fan_in)) + op = block.append_op( + type="uniform_random", + inputs={}, + outputs={"Out": out_var}, + attrs={ + "shape": out_var.shape, + "dtype": int(out_dtype), + "min": -limit, + "max": limit, + "seed": self._seed + }, + stop_gradient=True) + + else: + std = np.sqrt(2.0 / float(fan_in)) + op = block.append_op( + type="gaussian_random", + outputs={"Out": out_var}, + attrs={ + "shape": out_var.shape, + "dtype": int(out_dtype), + "mean": 0.0, + "std": std, + "seed": self._seed + }, + stop_gradient=True) + + if var.dtype == VarDesc.VarType.FP16 or ( + var.dtype == VarDesc.VarType.BF16 and not self._uniform): + block.append_op( + type="cast", + inputs={"X": out_var}, + outputs={"Out": var}, + attrs={"in_dtype": out_var.dtype, + "out_dtype": var.dtype}) + + if not framework.in_dygraph_mode(): + var.op = op + return op + +class KaimingUniform(MSRAInitializer): + r"""Implements the Kaiming Uniform initializer + + This class implements the weight initialization from the paper + `Delving Deep into Rectifiers: Surpassing Human-Level Performance on + ImageNet Classification `_ + by Kaiming He, Xiangyu Zhang, Shaoqing Ren and Jian Sun. This is a + robust initialization method that particularly considers the rectifier + nonlinearities. + + In case of Uniform distribution, the range is [-x, x], where + + .. math:: + + x = \sqrt{\frac{6.0}{fan\_in}} + + Args: + fan_in (float32|None): fan_in for Kaiming uniform Initializer. If None, it is\ + inferred from the variable. default is None. + + Note: + It is recommended to set fan_in to None for most cases. + + Examples: + .. code-block:: python + + import paddle + import paddle.nn as nn + + linear = nn.Linear(2, + 4, + weight_attr=nn.initializer.KaimingUniform()) + data = paddle.rand([30, 10, 2], dtype='float32') + res = linear(data) + + """ + + def __init__(self, fan_in=None): + super(KaimingUniform, self).__init__( + uniform=True, fan_in=fan_in, seed=0) + + + +# We short the class name, since users will use the initializer with the package +# name. The sample code: +# +# import paddle.fluid as fluid +# +# hidden = fluid.layers.fc(..., +# param_attr=ParamAttr(fluid.initializer.Xavier())) +# +# It is no need to add an `Initializer` as the class suffix +MSRA = MSRAInitializer diff --git a/paddlespeech/s2t/modules/nets_utils.py b/paddlespeech/s2t/modules/nets_utils.py new file mode 100644 index 00000000000..10915c8c3ea --- /dev/null +++ b/paddlespeech/s2t/modules/nets_utils.py @@ -0,0 +1,44 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from espnet(https://github.com/espnet/espnet) +from paddle import nn +from typeguard import check_argument_types + +def initialize(model: nn.Layer, init: str): + """Initialize weights of a neural network module. + + Parameters are initialized using the given method or distribution. + + Custom initialization routines can be implemented into submodules + + Args: + model (nn.Layer): Target. + init (str): Method of initialization. + """ + assert check_argument_types() + + if init == "xavier_uniform": + nn.initializer.set_global_initializer(nn.initializer.XavierUniform(), + nn.initializer.Constant()) + elif init == "xavier_normal": + nn.initializer.set_global_initializer(nn.initializer.XavierNormal(), + nn.initializer.Constant()) + elif init == "kaiming_uniform": + nn.initializer.set_global_initializer(nn.initializer.KaimingUniform(), + nn.initializer.KaimingUniform()) + elif init == "kaiming_normal": + nn.initializer.set_global_initializer(nn.initializer.KaimingNormal(), + nn.initializer.Constant()) + else: + raise ValueError("Unknown initialization: " + init) From d53e1163a60ff992d4ee4790e4cc3f02793e0c7c Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Tue, 22 Mar 2022 05:12:21 +0000 Subject: [PATCH 2/5] update the code, test=asr --- paddlespeech/s2t/__init__.py | 6 - paddlespeech/s2t/exps/u2/model.py | 4 +- paddlespeech/s2t/models/u2/u2.py | 10 +- paddlespeech/s2t/modules/activation.py | 13 +- paddlespeech/s2t/modules/align.py | 74 +++++++ paddlespeech/s2t/modules/attention.py | 11 +- .../s2t/modules/conformer_convolution.py | 25 +-- paddlespeech/s2t/modules/ctc.py | 3 +- paddlespeech/s2t/modules/decoder.py | 19 +- paddlespeech/s2t/modules/decoder_layer.py | 30 +-- paddlespeech/s2t/modules/encoder.py | 10 +- paddlespeech/s2t/modules/encoder_layer.py | 52 ++--- paddlespeech/s2t/modules/initializer.py | 185 +++++------------- paddlespeech/s2t/modules/nets_utils.py | 44 ----- .../s2t/modules/positionwise_feed_forward.py | 5 +- paddlespeech/s2t/modules/subsampling.py | 29 +-- 16 files changed, 196 insertions(+), 324 deletions(-) create mode 100644 paddlespeech/s2t/modules/align.py delete mode 100644 paddlespeech/s2t/modules/nets_utils.py diff --git a/paddlespeech/s2t/__init__.py b/paddlespeech/s2t/__init__.py index 26f59dfb13f..855ceef96f5 100644 --- a/paddlespeech/s2t/__init__.py +++ b/paddlespeech/s2t/__init__.py @@ -21,7 +21,6 @@ from paddle.fluid import core from paddle.nn import functional as F -from paddlespeech.s2t.modules import initializer from paddlespeech.s2t.utils.log import Log #TODO(Hui Zhang): remove fluid import @@ -506,8 +505,3 @@ def update(self, modules: Mapping[str, Layer]) -> None: logger.debug( "register user LayerDict to paddle.nn, remove this when fixed!") setattr(paddle.nn, 'LayerDict', LayerDict) - -""" - hack KaiminigUniform: change limit from np.sqrt(6.0 / float(fan_in)) to np.sqrt(1.0 / float(fan_in)) -""" -paddle.nn.initializer.KaimingUniform = initializer.KaimingUniform diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index d7bee6d7fe7..bcbc15d64ed 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ 
b/paddlespeech/s2t/exps/u2/model.py @@ -239,7 +239,7 @@ def setup_dataloader(self): n_iter_processes=config.num_workers, subsampling_factor=1, num_encs=1, - dist_sampler=False, + dist_sampler=True, shortest_first=False) self.valid_loader = BatchDataLoader( @@ -260,7 +260,7 @@ def setup_dataloader(self): n_iter_processes=config.num_workers, subsampling_factor=1, num_encs=1, - dist_sampler=False, + dist_sampler=True, shortest_first=False) logger.info("Setup train/valid Dataloader!") else: diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 67ec5924af8..e077cd5b7cc 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -41,7 +41,6 @@ from paddlespeech.s2t.modules.mask import mask_finished_preds from paddlespeech.s2t.modules.mask import mask_finished_scores from paddlespeech.s2t.modules.mask import subsequent_mask -from paddlespeech.s2t.modules.nets_utils import initialize from paddlespeech.s2t.utils import checkpoint from paddlespeech.s2t.utils import layer_tools from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank @@ -51,6 +50,8 @@ from paddlespeech.s2t.utils.tensor_utils import th_accuracy from paddlespeech.s2t.utils.utility import log_add from paddlespeech.s2t.utils.utility import UpdateConfig +from paddlespeech.s2t.modules.initializer import DefaultInitializerContext +# from paddlespeech.s2t.modules.initializer import initialize __all__ = ["U2Model", "U2InferModel"] @@ -784,11 +785,8 @@ class U2Model(U2DecodeModel): def __init__(self, configs: dict): model_conf = configs.get('model_conf', dict()) init_type = model_conf.get("init_type", None) - if init_type is not None: - logger.info(f"Use {init_type} initializer as default initializer") - initialize(self, init_type) - vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs) - nn.initializer.set_global_initializer(None) + with DefaultInitializerContext(init_type): + vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs) super().__init__( vocab_size=vocab_size, diff --git a/paddlespeech/s2t/modules/activation.py b/paddlespeech/s2t/modules/activation.py index 4081f7f81a5..48c84fa6345 100644 --- a/paddlespeech/s2t/modules/activation.py +++ b/paddlespeech/s2t/modules/activation.py @@ -16,7 +16,8 @@ import paddle from paddle import nn from paddle.nn import functional as F - +from paddlespeech.s2t.modules.align import Linear +from paddlespeech.s2t.modules.align import Conv2D from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() @@ -51,7 +52,7 @@ def __init__(self, idim: int): idim (int): input and output dimension """ super().__init__() - self.fc = nn.Linear(idim, idim * 2) + self.fc = Linear(idim, idim * 2) def forward(self, xs): return glu(self.fc(xs), dim=-1) @@ -75,7 +76,7 @@ def __init__(self, kernel_size, in_ch, out_ch, bottlececk_dim=0, self.conv_residual = None if in_ch != out_ch: self.conv_residual = nn.utils.weight_norm( - nn.Conv2D( + Conv2D( in_channels=in_ch, out_channels=out_ch, kernel_size=(1, 1)), name='weight', dim=0) @@ -86,7 +87,7 @@ def __init__(self, kernel_size, in_ch, out_ch, bottlececk_dim=0, layers = OrderedDict() if bottlececk_dim == 0: layers['conv'] = nn.utils.weight_norm( - nn.Conv2D( + Conv2D( in_channels=in_ch, out_channels=out_ch * 2, kernel_size=(kernel_size, 1)), @@ -106,7 +107,7 @@ def __init__(self, kernel_size, in_ch, out_ch, bottlececk_dim=0, dim=0) layers['dropout_in'] = nn.Dropout(p=dropout) layers['conv_bottleneck'] = nn.utils.weight_norm( - nn.Conv2D( + Conv2D( 
in_channels=bottlececk_dim, out_channels=bottlececk_dim, kernel_size=(kernel_size, 1)), @@ -115,7 +116,7 @@ def __init__(self, kernel_size, in_ch, out_ch, bottlececk_dim=0, layers['dropout'] = nn.Dropout(p=dropout) layers['glu'] = GLU() layers['conv_out'] = nn.utils.weight_norm( - nn.Conv2D( + Conv2D( in_channels=bottlececk_dim, out_channels=out_ch * 2, kernel_size=(1, 1)), diff --git a/paddlespeech/s2t/modules/align.py b/paddlespeech/s2t/modules/align.py new file mode 100644 index 00000000000..575773d70b0 --- /dev/null +++ b/paddlespeech/s2t/modules/align.py @@ -0,0 +1,74 @@ +import paddle +from paddle import nn +from paddlespeech.s2t.modules.initializer import KaimingUniform + +""" + To align the initializer between paddle and torch, + the API below are set defalut initializer with priority higger than global initializer. +""" +global_init_type = None + + +class LayerNorm(nn.LayerNorm): + def __init__(self, normalized_shape, epsilon=1e-05, weight_attr=None, bias_attr=None, name=None): + if weight_attr is None: + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.Constant(1.0)) + if bias_attr is None: + bias_attr = paddle.ParamAttr( + initializer=nn.initializer.Constant(0.0)) + super(LayerNorm, self).__init__(normalized_shape, epsilon, weight_attr, bias_attr, name) + +class BatchNorm1D(nn.BatchNorm1D): + def __init__(self, num_features, momentum=0.9, epsilon=1e-05, weight_attr=None, bias_attr=None, data_format='NCL', name=None): + if weight_attr is None: + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.Constant(1.0)) + if bias_attr is None: + bias_attr = paddle.ParamAttr( + initializer=nn.initializer.Constant(0.0)) + super(BatchNorm1D, self).__init__(num_features, momentum, epsilon, weight_attr, bias_attr, data_format, name) + +class Embedding(nn.Embedding): + def __init__(self, num_embeddings, embedding_dim, padding_idx=None, sparse=False, weight_attr=None, name=None): + if weight_attr is None: + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.Normal()) + super(Embedding, self).__init__(num_embeddings, embedding_dim, padding_idx, sparse, weight_attr, name) + +class Linear(nn.Linear): + def __init__(self, in_features, out_features, weight_attr=None, bias_attr=None, name=None): + if weight_attr is None: + if global_init_type == "kaiming_uniform": + weight_attr = paddle.ParamAttr( + initializer=KaimingUniform()) + if bias_attr is None: + if global_init_type == "kaiming_uniform": + bias_attr = paddle.ParamAttr( + initializer=KaimingUniform()) + super(Linear, self).__init__(in_features, out_features, weight_attr, bias_attr, name) + +class Conv1D(nn.Conv1D): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', weight_attr=None, bias_attr=None, data_format='NCL'): + if weight_attr is None: + if global_init_type == "kaiming_uniform": + print("set kaiming_uniform") + weight_attr = paddle.ParamAttr( + initializer=KaimingUniform()) + if bias_attr is None: + if global_init_type == "kaiming_uniform": + bias_attr = paddle.ParamAttr( + initializer=KaimingUniform()) + super(Conv1D, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, padding_mode, weight_attr, bias_attr, data_format) + +class Conv2D(nn.Conv2D): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', weight_attr=None, bias_attr=None, data_format='NCHW'): + if weight_attr is None: + if global_init_type == "kaiming_uniform": + 
weight_attr = paddle.ParamAttr( + initializer=KaimingUniform()) + if bias_attr is None: + if global_init_type == "kaiming_uniform": + bias_attr = paddle.ParamAttr( + initializer=KaimingUniform()) + super(Conv2D, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, padding_mode, weight_attr, bias_attr, data_format) diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py index c2f5e503eb4..438efd2a141 100644 --- a/paddlespeech/s2t/modules/attention.py +++ b/paddlespeech/s2t/modules/attention.py @@ -22,6 +22,7 @@ from paddle import nn from paddle.nn import initializer as I +from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() @@ -48,10 +49,10 @@ def __init__(self, n_head: int, n_feat: int, dropout_rate: float): # We assume d_v always equals d_k self.d_k = n_feat // n_head self.h = n_head - self.linear_q = nn.Linear(n_feat, n_feat) - self.linear_k = nn.Linear(n_feat, n_feat) - self.linear_v = nn.Linear(n_feat, n_feat) - self.linear_out = nn.Linear(n_feat, n_feat) + self.linear_q = Linear(n_feat, n_feat) + self.linear_k = Linear(n_feat, n_feat) + self.linear_v = Linear(n_feat, n_feat) + self.linear_out = Linear(n_feat, n_feat) self.dropout = nn.Dropout(p=dropout_rate) def forward_qkv(self, @@ -150,7 +151,7 @@ def __init__(self, n_head, n_feat, dropout_rate): """ super().__init__(n_head, n_feat, dropout_rate) # linear transformation for positional encoding - self.linear_pos = nn.Linear(n_feat, n_feat, bias_attr=False) + self.linear_pos = Linear(n_feat, n_feat, bias_attr=False) # these two learnable bias are used in matrix c and matrix d # as described in https://arxiv.org/abs/1901.02860 Section 3.3 #self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) diff --git a/paddlespeech/s2t/modules/conformer_convolution.py b/paddlespeech/s2t/modules/conformer_convolution.py index 256d187c907..89e6526885a 100644 --- a/paddlespeech/s2t/modules/conformer_convolution.py +++ b/paddlespeech/s2t/modules/conformer_convolution.py @@ -21,6 +21,9 @@ from paddle import nn from typeguard import check_argument_types +from paddlespeech.s2t.modules.align import BatchNorm1D +from paddlespeech.s2t.modules.align import Conv1D +from paddlespeech.s2t.modules.align import LayerNorm from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() @@ -49,7 +52,7 @@ def __init__(self, """ assert check_argument_types() super().__init__() - self.pointwise_conv1 = nn.Conv1D( + self.pointwise_conv1 = Conv1D( channels, 2 * channels, kernel_size=1, @@ -73,7 +76,7 @@ def __init__(self, padding = (kernel_size - 1) // 2 self.lorder = 0 - self.depthwise_conv = nn.Conv1D( + self.depthwise_conv = Conv1D( channels, channels, kernel_size, @@ -87,22 +90,12 @@ def __init__(self, assert norm in ['batch_norm', 'layer_norm'] if norm == "batch_norm": self.use_layer_norm = False - self.norm = nn.BatchNorm1D( - channels, - weight_attr=paddle.ParamAttr( - initializer=nn.initializer.Constant(1.0)), - bias_attr=paddle.ParamAttr( - initializer=nn.initializer.Constant(0.0))) + self.norm = BatchNorm1D(channels) else: self.use_layer_norm = True - self.norm = nn.LayerNorm( - channels, - weight_attr=paddle.ParamAttr( - initializer=nn.initializer.Constant(1.0)), - bias_attr=paddle.ParamAttr( - initializer=nn.initializer.Constant(0.0))) - - self.pointwise_conv2 = nn.Conv1D( + self.norm = LayerNorm(channels) + + self.pointwise_conv2 = Conv1D( channels, channels, kernel_size=1, diff --git 
a/paddlespeech/s2t/modules/ctc.py b/paddlespeech/s2t/modules/ctc.py index 2094182af1a..33ad472defb 100644 --- a/paddlespeech/s2t/modules/ctc.py +++ b/paddlespeech/s2t/modules/ctc.py @@ -18,6 +18,7 @@ from paddle.nn import functional as F from typeguard import check_argument_types +from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.modules.loss import CTCLoss from paddlespeech.s2t.utils import ctc_utils from paddlespeech.s2t.utils.log import Log @@ -69,7 +70,7 @@ def __init__(self, self.blank_id = blank_id self.odim = odim self.dropout = nn.Dropout(dropout_rate) - self.ctc_lo = nn.Linear(enc_n_units, self.odim) + self.ctc_lo = Linear(enc_n_units, self.odim) reduction_type = "sum" if reduction else "none" self.criterion = CTCLoss( blank=self.blank_id, diff --git a/paddlespeech/s2t/modules/decoder.py b/paddlespeech/s2t/modules/decoder.py index b0ae27e52e9..3a851ec62c3 100644 --- a/paddlespeech/s2t/modules/decoder.py +++ b/paddlespeech/s2t/modules/decoder.py @@ -24,6 +24,9 @@ from typeguard import check_argument_types from paddlespeech.s2t.decoders.scorers.scorer_interface import BatchScorerInterface +from paddlespeech.s2t.modules.align import Embedding +from paddlespeech.s2t.modules.align import LayerNorm +from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.modules.attention import MultiHeadedAttention from paddlespeech.s2t.modules.decoder_layer import DecoderLayer from paddlespeech.s2t.modules.embedding import PositionalEncoding @@ -83,25 +86,15 @@ def __init__( if input_layer == "embed": self.embed = nn.Sequential( - nn.Embedding( - vocab_size, - attention_dim, - weight_attr=paddle.ParamAttr( - initializer=nn.initializer.Normal())), + Embedding(vocab_size, attention_dim), PositionalEncoding(attention_dim, positional_dropout_rate), ) else: raise ValueError(f"only 'embed' is supported: {input_layer}") self.normalize_before = normalize_before - self.after_norm = nn.LayerNorm( - attention_dim, - epsilon=1e-12, - weight_attr=paddle.ParamAttr( - initializer=nn.initializer.Constant(1.0)), - bias_attr=paddle.ParamAttr( - initializer=nn.initializer.Constant(0.0))) + self.after_norm = LayerNorm(attention_dim, epsilon=1e-12) self.use_output_layer = use_output_layer - self.output_layer = nn.Linear(attention_dim, vocab_size) + self.output_layer = Linear(attention_dim, vocab_size) self.decoders = nn.LayerList([ DecoderLayer( diff --git a/paddlespeech/s2t/modules/decoder_layer.py b/paddlespeech/s2t/modules/decoder_layer.py index 8eee5ceb1bc..b7f8694c126 100644 --- a/paddlespeech/s2t/modules/decoder_layer.py +++ b/paddlespeech/s2t/modules/decoder_layer.py @@ -20,6 +20,8 @@ import paddle from paddle import nn +from paddlespeech.s2t.modules.align import LayerNorm +from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() @@ -62,32 +64,14 @@ def __init__( self.self_attn = self_attn self.src_attn = src_attn self.feed_forward = feed_forward - self.norm1 = nn.LayerNorm( - size, - epsilon=1e-12, - weight_attr=paddle.ParamAttr( - initializer=nn.initializer.Constant(1.0)), - bias_attr=paddle.ParamAttr( - initializer=nn.initializer.Constant(0.0))) - self.norm2 = nn.LayerNorm( - size, - epsilon=1e-12, - weight_attr=paddle.ParamAttr( - initializer=nn.initializer.Constant(1.0)), - bias_attr=paddle.ParamAttr( - initializer=nn.initializer.Constant(0.0))) - self.norm3 = nn.LayerNorm( - size, - epsilon=1e-12, - weight_attr=paddle.ParamAttr( - initializer=nn.initializer.Constant(1.0)), - bias_attr=paddle.ParamAttr( - 
initializer=nn.initializer.Constant(0.0))) + self.norm1 = LayerNorm(size, epsilon=1e-12) + self.norm2 = LayerNorm(size, epsilon=1e-12) + self.norm3 = LayerNorm(size, epsilon=1e-12) self.dropout = nn.Dropout(dropout_rate) self.normalize_before = normalize_before self.concat_after = concat_after - self.concat_linear1 = nn.Linear(size + size, size) - self.concat_linear2 = nn.Linear(size + size, size) + self.concat_linear1 = Linear(size + size, size) + self.concat_linear2 = Linear(size + size, size) def forward( self, diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index 5f7b8e99dfc..71a2bad40bf 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -23,6 +23,8 @@ from typeguard import check_argument_types from paddlespeech.s2t.modules.activation import get_activation +from paddlespeech.s2t.modules.align import LayerNorm +from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.modules.attention import MultiHeadedAttention from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule @@ -129,13 +131,7 @@ def __init__( d_model=output_size, dropout_rate=positional_dropout_rate), ) self.normalize_before = normalize_before - self.after_norm = nn.LayerNorm( - output_size, - epsilon=1e-12, - weight_attr=paddle.ParamAttr( - initializer=nn.initializer.Constant(1.0)), - bias_attr=paddle.ParamAttr( - initializer=nn.initializer.Constant(0.0))) + self.after_norm = LayerNorm(output_size, epsilon=1e-12) self.static_chunk_size = static_chunk_size self.use_dynamic_chunk = use_dynamic_chunk self.use_dynamic_left_chunk = use_dynamic_left_chunk diff --git a/paddlespeech/s2t/modules/encoder_layer.py b/paddlespeech/s2t/modules/encoder_layer.py index 69a3f67bb4e..e80a298d621 100644 --- a/paddlespeech/s2t/modules/encoder_layer.py +++ b/paddlespeech/s2t/modules/encoder_layer.py @@ -20,6 +20,8 @@ import paddle from paddle import nn +from paddlespeech.s2t.modules.align import LayerNorm +from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() @@ -59,15 +61,15 @@ def __init__( super().__init__() self.self_attn = self_attn self.feed_forward = feed_forward - self.norm1 = nn.LayerNorm(size, epsilon=1e-12) - self.norm2 = nn.LayerNorm(size, epsilon=1e-12) + self.norm1 = LayerNorm(size, epsilon=1e-12) + self.norm2 = LayerNorm(size, epsilon=1e-12) self.dropout = nn.Dropout(dropout_rate) self.size = size self.normalize_before = normalize_before self.concat_after = concat_after # concat_linear may be not used in forward fuction, # but will be saved in the *.pt - self.concat_linear = nn.Linear(size + size, size) + self.concat_linear = Linear(size + size, size) def forward( self, @@ -174,51 +176,23 @@ def __init__( self.feed_forward = feed_forward self.feed_forward_macaron = feed_forward_macaron self.conv_module = conv_module - self.norm_ff = nn.LayerNorm( - size, - epsilon=1e-12, - weight_attr=paddle.ParamAttr( - initializer=nn.initializer.Constant(1.0)), - bias_attr=paddle.ParamAttr( - initializer=nn.initializer.Constant(0.0))) # for the FNN module - self.norm_mha = nn.LayerNorm( - size, - epsilon=1e-12, - weight_attr=paddle.ParamAttr( - initializer=nn.initializer.Constant(1.0)), - bias_attr=paddle.ParamAttr( - initializer=nn.initializer.Constant(0.0))) # for the MHA module + self.norm_ff = LayerNorm(size, epsilon=1e-12) # for the FNN module + self.norm_mha = LayerNorm(size, 
epsilon=1e-12) # for the MHA module if feed_forward_macaron is not None: - self.norm_ff_macaron = nn.LayerNorm( - size, - epsilon=1e-12, - weight_attr=paddle.ParamAttr( - initializer=nn.initializer.Constant(1.0)), - bias_attr=paddle.ParamAttr( - initializer=nn.initializer.Constant(0.0))) + self.norm_ff_macaron = LayerNorm(size, epsilon=1e-12) self.ff_scale = 0.5 else: self.ff_scale = 1.0 if self.conv_module is not None: - self.norm_conv = nn.LayerNorm( - size, - epsilon=1e-12, - weight_attr=paddle.ParamAttr( - initializer=nn.initializer.Constant(1.0)), - bias_attr=paddle.ParamAttr(initializer=nn.initializer.Constant( - 0.0))) # for the CNN module - self.norm_final = nn.LayerNorm( - size, - epsilon=1e-12, - weight_attr=paddle.ParamAttr( - initializer=nn.initializer.Constant(1.0)), - bias_attr=paddle.ParamAttr(initializer=nn.initializer.Constant( - 0.0))) # for the final output of the block + self.norm_conv = LayerNorm( + size, epsilon=1e-12) # for the CNN module + self.norm_final = LayerNorm( + size, epsilon=1e-12) # for the final output of the block self.dropout = nn.Dropout(dropout_rate) self.size = size self.normalize_before = normalize_before self.concat_after = concat_after - self.concat_linear = nn.Linear(size + size, size) + self.concat_linear = Linear(size + size, size) def forward( self, diff --git a/paddlespeech/s2t/modules/initializer.py b/paddlespeech/s2t/modules/initializer.py index c91ab231741..3fbab285320 100644 --- a/paddlespeech/s2t/modules/initializer.py +++ b/paddlespeech/s2t/modules/initializer.py @@ -11,93 +11,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from __future__ import print_function - -from paddle.fluid import framework -from paddle.fluid.framework import in_dygraph_mode, default_main_program import numpy as np -from paddle.fluid.core import VarDesc +from paddle import nn +from paddle.fluid import framework from paddle.fluid import unique_name +from paddle.fluid.core import VarDesc +from paddle.fluid.framework import default_main_program +from paddle.fluid.framework import in_dygraph_mode +from paddle.fluid.initializer import Initializer +from paddle.fluid.initializer import MSRAInitializer +from typeguard import check_argument_types -__all__ = [ - 'MSRAInitializer' -] - - -class Initializer(object): - """Base class for variable initializers - - Defines the common interface of variable initializers. - They add operations to the init program that are used - to initialize variables. Users should not use this class - directly, but need to use one of its implementations. - """ - - def __init__(self): - pass - - def __call__(self, param, block=None): - """Add corresponding initialization operations to the network - """ - raise NotImplementedError() - - def _check_block(self, block): - if block is None: - block = default_main_program().global_block() - - return block - - def _compute_fans(self, var): - """Compute the fan_in and the fan_out for layers - - This method computes the fan_in and the fan_out - for neural network layers, if not specified. It is - not possible to perfectly estimate fan_in and fan_out. - This method will estimate it correctly for matrix multiply and - convolutions. 
- - Args: - var: variable for which fan_in and fan_out have to be computed - - Returns: - tuple of two integers (fan_in, fan_out) - """ - shape = var.shape - if not shape or len(shape) == 0: - fan_in = fan_out = 1 - elif len(shape) == 1: - fan_in = fan_out = shape[0] - elif len(shape) == 2: - # This is the case for simple matrix multiply - fan_in = shape[0] - fan_out = shape[1] - else: - # Assume this to be a convolutional kernel - # In PaddlePaddle, the shape of the kernel is like: - # [num_filters, num_filter_channels, ...] where the remaining - # dimensions are the filter_size - receptive_field_size = np.prod(shape[2:]) - fan_in = shape[1] * receptive_field_size - fan_out = shape[0] * receptive_field_size - - return (fan_in, fan_out) - +__all__ = ['KaimingUniform'] -class MSRAInitializer(Initializer): - r"""Implements the MSRA initializer a.k.a. Kaiming Initializer +class KaimingUniform(MSRAInitializer): + r"""Implements the Kaiming Uniform initializer This class implements the weight initialization from the paper `Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification `_ by Kaiming He, Xiangyu Zhang, Shaoqing Ren and Jian Sun. This is a robust initialization method that particularly considers the rectifier - nonlinearities. In case of Uniform distribution, the range is [-x, x], where + nonlinearities. + + In case of Uniform distribution, the range is [-x, x], where .. math:: - x = \sqrt{\\frac{6.0}{fan\_in}} + x = \sqrt{\frac{1.0}{fan\_in}} In case of Normal distribution, the mean is 0 and the standard deviation is @@ -107,10 +49,8 @@ class MSRAInitializer(Initializer): \sqrt{\\frac{2.0}{fan\_in}} Args: - uniform (bool): whether to use uniform or normal distribution - fan_in (float32|None): fan_in for MSRAInitializer. If None, it is\ + fan_in (float32|None): fan_in for Kaiming uniform Initializer. If None, it is\ inferred from the variable. default is None. - seed (int32): random seed Note: It is recommended to set fan_in to None for most cases. @@ -119,23 +59,19 @@ class MSRAInitializer(Initializer): .. code-block:: python import paddle - import paddle.fluid as fluid - paddle.enable_static() - x = fluid.data(name="data", shape=[8, 32, 32], dtype="float32") - fc = fluid.layers.fc(input=x, size=10, - param_attr=fluid.initializer.MSRA(uniform=False)) + import paddle.nn as nn + + linear = nn.Linear(2, + 4, + weight_attr=nn.initializer.KaimingUniform()) + data = paddle.rand([30, 10, 2], dtype='float32') + res = linear(data) """ - def __init__(self, uniform=True, fan_in=None, seed=0): - """Constructor for MSRAInitializer - """ - assert uniform is not None - assert seed is not None - super(MSRAInitializer, self).__init__() - self._uniform = uniform - self._fan_in = fan_in - self._seed = seed + def __init__(self, fan_in=None): + super(KaimingUniform, self).__init__( + uniform=True, fan_in=fan_in, seed=0) def __call__(self, var, block=None): """Initialize the input tensor with MSRA initialization. 
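# A quick numeric comparison of the uniform bound used by this initializer
# versus the stock MSRA/Kaiming uniform bound that the commit message says it
# replaces (illustrative only; fan_in=512 is an arbitrary example value):
import numpy as np

fan_in = 512
print(np.sqrt(6.0 / fan_in))  # ~0.108: stock MSRA/KaimingUniform bound
print(np.sqrt(1.0 / fan_in))  # ~0.044: bound used by this patched initializer
# sqrt(1/fan_in) matches the bound PyTorch's nn.Linear applies by default
# (kaiming_uniform_ with a=sqrt(5)), which appears to be what "align the
# initializer between paddle and torch" refers to elsewhere in this series.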
@@ -165,8 +101,8 @@ def __call__(self, var, block=None): var.dtype == VarDesc.VarType.BF16 and not self._uniform): out_dtype = VarDesc.VarType.FP32 out_var = block.create_var( - name=unique_name.generate(".".join( - ['masra_init', var.name, 'tmp'])), + name=unique_name.generate( + ".".join(['masra_init', var.name, 'tmp'])), shape=var.shape, dtype=out_dtype, type=VarDesc.VarType.LOD_TENSOR, @@ -217,56 +153,23 @@ def __call__(self, var, block=None): var.op = op return op -class KaimingUniform(MSRAInitializer): - r"""Implements the Kaiming Uniform initializer - - This class implements the weight initialization from the paper - `Delving Deep into Rectifiers: Surpassing Human-Level Performance on - ImageNet Classification `_ - by Kaiming He, Xiangyu Zhang, Shaoqing Ren and Jian Sun. This is a - robust initialization method that particularly considers the rectifier - nonlinearities. - - In case of Uniform distribution, the range is [-x, x], where - - .. math:: - - x = \sqrt{\frac{6.0}{fan\_in}} - - Args: - fan_in (float32|None): fan_in for Kaiming uniform Initializer. If None, it is\ - inferred from the variable. default is None. - - Note: - It is recommended to set fan_in to None for most cases. - - Examples: - .. code-block:: python - - import paddle - import paddle.nn as nn - - linear = nn.Linear(2, - 4, - weight_attr=nn.initializer.KaimingUniform()) - data = paddle.rand([30, 10, 2], dtype='float32') - res = linear(data) +class DefaultInitializerContext(object): """ - - def __init__(self, fan_in=None): - super(KaimingUniform, self).__init__( - uniform=True, fan_in=fan_in, seed=0) - + egs: + with DefaultInitializerContext("kaiming_uniform"): + code for setup_model + """ + def __init__(self, init_type=None): + self.init_type = init_type + + def __enter__(self): + from paddlespeech.s2t.modules import align + align.global_init_type = self.init_type + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + from paddlespeech.s2t.modules import align + align.global_init_type = None -# We short the class name, since users will use the initializer with the package -# name. The sample code: -# -# import paddle.fluid as fluid -# -# hidden = fluid.layers.fc(..., -# param_attr=ParamAttr(fluid.initializer.Xavier())) -# -# It is no need to add an `Initializer` as the class suffix -MSRA = MSRAInitializer diff --git a/paddlespeech/s2t/modules/nets_utils.py b/paddlespeech/s2t/modules/nets_utils.py deleted file mode 100644 index 10915c8c3ea..00000000000 --- a/paddlespeech/s2t/modules/nets_utils.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Modified from espnet(https://github.com/espnet/espnet) -from paddle import nn -from typeguard import check_argument_types - -def initialize(model: nn.Layer, init: str): - """Initialize weights of a neural network module. - - Parameters are initialized using the given method or distribution. 
- - Custom initialization routines can be implemented into submodules - - Args: - model (nn.Layer): Target. - init (str): Method of initialization. - """ - assert check_argument_types() - - if init == "xavier_uniform": - nn.initializer.set_global_initializer(nn.initializer.XavierUniform(), - nn.initializer.Constant()) - elif init == "xavier_normal": - nn.initializer.set_global_initializer(nn.initializer.XavierNormal(), - nn.initializer.Constant()) - elif init == "kaiming_uniform": - nn.initializer.set_global_initializer(nn.initializer.KaimingUniform(), - nn.initializer.KaimingUniform()) - elif init == "kaiming_normal": - nn.initializer.set_global_initializer(nn.initializer.KaimingNormal(), - nn.initializer.Constant()) - else: - raise ValueError("Unknown initialization: " + init) diff --git a/paddlespeech/s2t/modules/positionwise_feed_forward.py b/paddlespeech/s2t/modules/positionwise_feed_forward.py index e2619cd49dc..c2725dc5cc4 100644 --- a/paddlespeech/s2t/modules/positionwise_feed_forward.py +++ b/paddlespeech/s2t/modules/positionwise_feed_forward.py @@ -17,6 +17,7 @@ import paddle from paddle import nn +from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() @@ -44,10 +45,10 @@ def __init__(self, activation (paddle.nn.Layer): Activation function """ super().__init__() - self.w_1 = nn.Linear(idim, hidden_units) + self.w_1 = Linear(idim, hidden_units) self.activation = activation self.dropout = nn.Dropout(dropout_rate) - self.w_2 = nn.Linear(hidden_units, idim) + self.w_2 = Linear(hidden_units, idim) def forward(self, xs: paddle.Tensor) -> paddle.Tensor: """Forward function. diff --git a/paddlespeech/s2t/modules/subsampling.py b/paddlespeech/s2t/modules/subsampling.py index 99a8300f246..88451ddd77f 100644 --- a/paddlespeech/s2t/modules/subsampling.py +++ b/paddlespeech/s2t/modules/subsampling.py @@ -19,6 +19,9 @@ import paddle from paddle import nn +from paddlespeech.s2t.modules.align import Conv2D +from paddlespeech.s2t.modules.align import LayerNorm +from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.modules.embedding import PositionalEncoding from paddlespeech.s2t.utils.log import Log @@ -60,8 +63,8 @@ def __init__(self, """ super().__init__(pos_enc_class) self.out = nn.Sequential( - nn.Linear(idim, odim), - nn.LayerNorm(odim, epsilon=1e-12), + Linear(idim, odim), + LayerNorm(odim, epsilon=1e-12), nn.Dropout(dropout_rate), nn.ReLU(), ) self.right_context = 0 @@ -108,12 +111,12 @@ def __init__(self, """ super().__init__(pos_enc_class) self.conv = nn.Sequential( - nn.Conv2D(1, odim, 3, 2), + Conv2D(1, odim, 3, 2), nn.ReLU(), - nn.Conv2D(odim, odim, 3, 2), + Conv2D(odim, odim, 3, 2), nn.ReLU(), ) self.out = nn.Sequential( - nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)) + Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)) self.subsampling_rate = 4 # The right context for every conv layer is computed by: # (kernel_size - 1) * frame_rate_of_this_layer @@ -160,13 +163,13 @@ def __init__(self, """ super().__init__(pos_enc_class) self.conv = nn.Sequential( - nn.Conv2D(1, odim, 3, 2), + Conv2D(1, odim, 3, 2), nn.ReLU(), - nn.Conv2D(odim, odim, 5, 3), + Conv2D(odim, odim, 5, 3), nn.ReLU(), ) # O = (I - F + Pstart + Pend) // S + 1 # when Padding == 0, O = (I - F - S) // S - self.linear = nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim) + self.linear = Linear(odim * (((idim - 1) // 2 - 2) // 3), odim) # The right context for every conv layer is computed by: # (kernel_size - 1) * 
frame_rate_of_this_layer # 10 = (3 - 1) * 1 + (5 - 1) * 2 @@ -212,14 +215,14 @@ def __init__(self, """ super().__init__(pos_enc_class) self.conv = nn.Sequential( - nn.Conv2D(1, odim, 3, 2), + Conv2D(1, odim, 3, 2), nn.ReLU(), - nn.Conv2D(odim, odim, 3, 2), + Conv2D(odim, odim, 3, 2), nn.ReLU(), - nn.Conv2D(odim, odim, 3, 2), + Conv2D(odim, odim, 3, 2), nn.ReLU(), ) - self.linear = nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), - odim) + self.linear = Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), + odim) self.subsampling_rate = 8 # The right context for every conv layer is computed by: # (kernel_size - 1) * frame_rate_of_this_layer From a4f5a680742240a471c6264432106eb14ad678d1 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Tue, 22 Mar 2022 07:14:40 +0000 Subject: [PATCH 3/5] fix some format, test=asr --- paddlespeech/s2t/models/u2/u2.py | 5 +- paddlespeech/s2t/modules/activation.py | 3 +- paddlespeech/s2t/modules/align.py | 139 +++++++++++++++++------- paddlespeech/s2t/modules/encoder.py | 1 - paddlespeech/s2t/modules/initializer.py | 12 +- 5 files changed, 110 insertions(+), 50 deletions(-) diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index e077cd5b7cc..e94a127db60 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -36,6 +36,7 @@ from paddlespeech.s2t.modules.decoder import TransformerDecoder from paddlespeech.s2t.modules.encoder import ConformerEncoder from paddlespeech.s2t.modules.encoder import TransformerEncoder +from paddlespeech.s2t.modules.initializer import DefaultInitializerContext from paddlespeech.s2t.modules.loss import LabelSmoothingLoss from paddlespeech.s2t.modules.mask import make_pad_mask from paddlespeech.s2t.modules.mask import mask_finished_preds @@ -50,7 +51,6 @@ from paddlespeech.s2t.utils.tensor_utils import th_accuracy from paddlespeech.s2t.utils.utility import log_add from paddlespeech.s2t.utils.utility import UpdateConfig -from paddlespeech.s2t.modules.initializer import DefaultInitializerContext # from paddlespeech.s2t.modules.initializer import initialize __all__ = ["U2Model", "U2InferModel"] @@ -786,7 +786,8 @@ def __init__(self, configs: dict): model_conf = configs.get('model_conf', dict()) init_type = model_conf.get("init_type", None) with DefaultInitializerContext(init_type): - vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs) + vocab_size, encoder, decoder, ctc = U2Model._init_from_config( + configs) super().__init__( vocab_size=vocab_size, diff --git a/paddlespeech/s2t/modules/activation.py b/paddlespeech/s2t/modules/activation.py index 48c84fa6345..2f387b0d99b 100644 --- a/paddlespeech/s2t/modules/activation.py +++ b/paddlespeech/s2t/modules/activation.py @@ -16,8 +16,9 @@ import paddle from paddle import nn from paddle.nn import functional as F -from paddlespeech.s2t.modules.align import Linear + from paddlespeech.s2t.modules.align import Conv2D +from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() diff --git a/paddlespeech/s2t/modules/align.py b/paddlespeech/s2t/modules/align.py index 575773d70b0..f8891679361 100644 --- a/paddlespeech/s2t/modules/align.py +++ b/paddlespeech/s2t/modules/align.py @@ -1,7 +1,20 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import paddle from paddle import nn -from paddlespeech.s2t.modules.initializer import KaimingUniform +from paddlespeech.s2t.modules.initializer import KaimingUniform """ To align the initializer between paddle and torch, the API below are set defalut initializer with priority higger than global initializer. @@ -10,65 +23,117 @@ class LayerNorm(nn.LayerNorm): - def __init__(self, normalized_shape, epsilon=1e-05, weight_attr=None, bias_attr=None, name=None): + def __init__(self, + normalized_shape, + epsilon=1e-05, + weight_attr=None, + bias_attr=None, + name=None): if weight_attr is None: weight_attr = paddle.ParamAttr( initializer=nn.initializer.Constant(1.0)) if bias_attr is None: bias_attr = paddle.ParamAttr( initializer=nn.initializer.Constant(0.0)) - super(LayerNorm, self).__init__(normalized_shape, epsilon, weight_attr, bias_attr, name) + super(LayerNorm, self).__init__(normalized_shape, epsilon, weight_attr, + bias_attr, name) + -class BatchNorm1D(nn.BatchNorm1D): - def __init__(self, num_features, momentum=0.9, epsilon=1e-05, weight_attr=None, bias_attr=None, data_format='NCL', name=None): +class BatchNorm1D(nn.BatchNorm1D): + def __init__(self, + num_features, + momentum=0.9, + epsilon=1e-05, + weight_attr=None, + bias_attr=None, + data_format='NCL', + name=None): if weight_attr is None: weight_attr = paddle.ParamAttr( initializer=nn.initializer.Constant(1.0)) if bias_attr is None: bias_attr = paddle.ParamAttr( initializer=nn.initializer.Constant(0.0)) - super(BatchNorm1D, self).__init__(num_features, momentum, epsilon, weight_attr, bias_attr, data_format, name) + super(BatchNorm1D, + self).__init__(num_features, momentum, epsilon, weight_attr, + bias_attr, data_format, name) + class Embedding(nn.Embedding): - def __init__(self, num_embeddings, embedding_dim, padding_idx=None, sparse=False, weight_attr=None, name=None): + def __init__(self, + num_embeddings, + embedding_dim, + padding_idx=None, + sparse=False, + weight_attr=None, + name=None): if weight_attr is None: - weight_attr = paddle.ParamAttr( - initializer=nn.initializer.Normal()) - super(Embedding, self).__init__(num_embeddings, embedding_dim, padding_idx, sparse, weight_attr, name) + weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal()) + super(Embedding, self).__init__(num_embeddings, embedding_dim, + padding_idx, sparse, weight_attr, name) + class Linear(nn.Linear): - def __init__(self, in_features, out_features, weight_attr=None, bias_attr=None, name=None): - if weight_attr is None: - if global_init_type == "kaiming_uniform": - weight_attr = paddle.ParamAttr( - initializer=KaimingUniform()) - if bias_attr is None: - if global_init_type == "kaiming_uniform": - bias_attr = paddle.ParamAttr( - initializer=KaimingUniform()) - super(Linear, self).__init__(in_features, out_features, weight_attr, bias_attr, name) + def __init__(self, + in_features, + out_features, + weight_attr=None, + bias_attr=None, + name=None): + if weight_attr is None: + if global_init_type == "kaiming_uniform": + weight_attr = paddle.ParamAttr(initializer=KaimingUniform()) + if bias_attr is None: + if global_init_type == 
"kaiming_uniform": + bias_attr = paddle.ParamAttr(initializer=KaimingUniform()) + super(Linear, self).__init__(in_features, out_features, weight_attr, + bias_attr, name) + class Conv1D(nn.Conv1D): - def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', weight_attr=None, bias_attr=None, data_format='NCL'): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + padding_mode='zeros', + weight_attr=None, + bias_attr=None, + data_format='NCL'): if weight_attr is None: if global_init_type == "kaiming_uniform": print("set kaiming_uniform") - weight_attr = paddle.ParamAttr( - initializer=KaimingUniform()) + weight_attr = paddle.ParamAttr(initializer=KaimingUniform()) if bias_attr is None: if global_init_type == "kaiming_uniform": - bias_attr = paddle.ParamAttr( - initializer=KaimingUniform()) - super(Conv1D, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, padding_mode, weight_attr, bias_attr, data_format) - + bias_attr = paddle.ParamAttr(initializer=KaimingUniform()) + super(Conv1D, self).__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, padding_mode, weight_attr, bias_attr, data_format) + + class Conv2D(nn.Conv2D): - def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', weight_attr=None, bias_attr=None, data_format='NCHW'): - if weight_attr is None: - if global_init_type == "kaiming_uniform": - weight_attr = paddle.ParamAttr( - initializer=KaimingUniform()) - if bias_attr is None: - if global_init_type == "kaiming_uniform": - bias_attr = paddle.ParamAttr( - initializer=KaimingUniform()) - super(Conv2D, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, padding_mode, weight_attr, bias_attr, data_format) + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + padding_mode='zeros', + weight_attr=None, + bias_attr=None, + data_format='NCHW'): + if weight_attr is None: + if global_init_type == "kaiming_uniform": + weight_attr = paddle.ParamAttr(initializer=KaimingUniform()) + if bias_attr is None: + if global_init_type == "kaiming_uniform": + bias_attr = paddle.ParamAttr(initializer=KaimingUniform()) + super(Conv2D, self).__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, padding_mode, weight_attr, bias_attr, data_format) diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index 71a2bad40bf..c843c0e2070 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -24,7 +24,6 @@ from paddlespeech.s2t.modules.activation import get_activation from paddlespeech.s2t.modules.align import LayerNorm -from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.modules.attention import MultiHeadedAttention from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule diff --git a/paddlespeech/s2t/modules/initializer.py b/paddlespeech/s2t/modules/initializer.py index 3fbab285320..98466ebdb0a 100644 --- a/paddlespeech/s2t/modules/initializer.py +++ b/paddlespeech/s2t/modules/initializer.py @@ -12,15 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import numpy as np -from paddle import nn from paddle.fluid import framework from paddle.fluid import unique_name from paddle.fluid.core import VarDesc -from paddle.fluid.framework import default_main_program -from paddle.fluid.framework import in_dygraph_mode -from paddle.fluid.initializer import Initializer from paddle.fluid.initializer import MSRAInitializer -from typeguard import check_argument_types __all__ = ['KaimingUniform'] @@ -160,16 +155,15 @@ class DefaultInitializerContext(object): with DefaultInitializerContext("kaiming_uniform"): code for setup_model """ + def __init__(self, init_type=None): self.init_type = init_type - + def __enter__(self): from paddlespeech.s2t.modules import align align.global_init_type = self.init_type return self - + def __exit__(self, exc_type, exc_val, exc_tb): from paddlespeech.s2t.modules import align align.global_init_type = None - - From 6da8465f146afaeeaf253c9acd13c7397df50065 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Tue, 22 Mar 2022 08:10:52 +0000 Subject: [PATCH 4/5] add dist_sampler args, test=asr --- examples/aishell/asr1/conf/chunk_conformer.yaml | 3 ++- examples/aishell/asr1/conf/conformer.yaml | 3 ++- examples/aishell/asr1/conf/transformer.yaml | 5 +++-- paddlespeech/s2t/models/u2/u2.py | 1 - 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/aishell/asr1/conf/chunk_conformer.yaml b/examples/aishell/asr1/conf/chunk_conformer.yaml index 68e852ba777..1ad77f97e98 100644 --- a/examples/aishell/asr1/conf/chunk_conformer.yaml +++ b/examples/aishell/asr1/conf/chunk_conformer.yaml @@ -70,7 +70,7 @@ batch_bins: 0 batch_frames_in: 0 batch_frames_out: 0 batch_frames_inout: 0 -num_workers: 0 +num_workers: 2 subsampling_factor: 1 num_encs: 1 @@ -80,6 +80,7 @@ num_encs: 1 n_epoch: 240 accum_grad: 2 global_grad_clip: 5.0 +dist_sampler: True optim: adam optim_conf: lr: 0.002 diff --git a/examples/aishell/asr1/conf/conformer.yaml b/examples/aishell/asr1/conf/conformer.yaml index 679a5bf6613..d5d883a031f 100644 --- a/examples/aishell/asr1/conf/conformer.yaml +++ b/examples/aishell/asr1/conf/conformer.yaml @@ -76,6 +76,7 @@ num_encs: 1 n_epoch: 240 accum_grad: 2 global_grad_clip: 5.0 +dist_sampler: True optim: adam optim_conf: lr: 0.002 @@ -84,7 +85,7 @@ scheduler: warmuplr scheduler_conf: warmup_steps: 25000 lr_decay: 1.0 -log_interval: 100 +log_interval: 1 checkpoint: kbest_n: 50 latest_n: 5 diff --git a/examples/aishell/asr1/conf/transformer.yaml b/examples/aishell/asr1/conf/transformer.yaml index 9d2946537b4..9e08ea0ec79 100644 --- a/examples/aishell/asr1/conf/transformer.yaml +++ b/examples/aishell/asr1/conf/transformer.yaml @@ -61,16 +61,17 @@ batch_frames_in: 0 batch_frames_out: 0 batch_frames_inout: 0 preprocess_config: conf/preprocess.yaml -num_workers: 0 +num_workers: 2 subsampling_factor: 1 num_encs: 1 ########################################### # Training # ########################################### -n_epoch: 240 +n_epoch: 30 accum_grad: 2 global_grad_clip: 5.0 +dist_sampler: False optim: adam optim_conf: lr: 0.002 diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index e94a127db60..51388586f97 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -51,7 +51,6 @@ from paddlespeech.s2t.utils.tensor_utils import th_accuracy from paddlespeech.s2t.utils.utility import log_add from paddlespeech.s2t.utils.utility import UpdateConfig -# from paddlespeech.s2t.modules.initializer import initialize __all__ = ["U2Model", "U2InferModel"] From 
e1b581b622608820b0c36971a2ad64cf2e0923c5 Mon Sep 17 00:00:00 2001
From: huangyuxin
Date: Tue, 22 Mar 2022 08:59:26 +0000
Subject: [PATCH 5/5] fix some bug, test=asr

---
 examples/aishell/asr1/conf/conformer.yaml | 2 +-
 paddlespeech/s2t/exps/u2/model.py         | 4 ++--
 paddlespeech/s2t/modules/initializer.py   | 9 ++++++---
 3 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/examples/aishell/asr1/conf/conformer.yaml b/examples/aishell/asr1/conf/conformer.yaml
index d5d883a031f..a150a04d556 100644
--- a/examples/aishell/asr1/conf/conformer.yaml
+++ b/examples/aishell/asr1/conf/conformer.yaml
@@ -85,7 +85,7 @@ scheduler: warmuplr
 scheduler_conf:
   warmup_steps: 25000
   lr_decay: 1.0
-log_interval: 1
+log_interval: 100
 checkpoint:
   kbest_n: 50
   latest_n: 5
diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py
index bcbc15d64ed..efcc9629fdb 100644
--- a/paddlespeech/s2t/exps/u2/model.py
+++ b/paddlespeech/s2t/exps/u2/model.py
@@ -239,7 +239,7 @@ def setup_dataloader(self):
                 n_iter_processes=config.num_workers,
                 subsampling_factor=1,
                 num_encs=1,
-                dist_sampler=True,
+                dist_sampler=config.get('dist_sampler', False),
                 shortest_first=False)

             self.valid_loader = BatchDataLoader(
@@ -260,7 +260,7 @@ def setup_dataloader(self):
                 n_iter_processes=config.num_workers,
                 subsampling_factor=1,
                 num_encs=1,
-                dist_sampler=True,
+                dist_sampler=config.get('dist_sampler', False),
                 shortest_first=False)
             logger.info("Setup train/valid Dataloader!")
         else:
diff --git a/paddlespeech/s2t/modules/initializer.py b/paddlespeech/s2t/modules/initializer.py
index 98466ebdb0a..30a04e44fb2 100644
--- a/paddlespeech/s2t/modules/initializer.py
+++ b/paddlespeech/s2t/modules/initializer.py
@@ -160,9 +160,12 @@ def __init__(self, init_type=None):
         self.init_type = init_type

     def __enter__(self):
-        from paddlespeech.s2t.modules import align
-        align.global_init_type = self.init_type
-        return self
+        if self.init_type is None:
+            return
+        else:
+            from paddlespeech.s2t.modules import align
+            align.global_init_type = self.init_type
+            return

     def __exit__(self, exc_type, exc_val, exc_tb):
         from paddlespeech.s2t.modules import align
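With the guard added to __enter__ above, recipes whose model_conf carries no init_type leave align.global_init_type untouched and keep Paddle's stock initializers, while dist_sampler now falls back to False when the YAML omits it. A small sketch of the intended flow; the configs dict below is an illustrative stand-in for the parsed YAML, not part of the patch:

from paddlespeech.s2t.modules import align
from paddlespeech.s2t.modules.initializer import DefaultInitializerContext

configs = {"model_conf": {"init_type": "kaiming_uniform"}, "dist_sampler": True}

init_type = configs.get("model_conf", {}).get("init_type", None)
with DefaultInitializerContext(init_type):
    # layers created inside the context pick up the kaiming_uniform defaults
    proj = align.Linear(256, 512)
assert align.global_init_type is None          # reset by __exit__

# the dataloader option uses the same fallback pattern
dist_sampler = configs.get("dist_sampler", False)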