[TTS]Fix diffusion wavenet denoiser final conv init param #2868

Merged: 3 commits, Feb 2, 2023
paddlespeech/t2s/modules/diffusion.py: 34 changes (14 additions, 20 deletions)
@@ -40,7 +40,7 @@ class WaveNetDenoiser(nn.Layer):
         layers (int, optional):
             Number of residual blocks inside, by default 20
         stacks (int, optional):
-            The number of groups to split the residual blocks into, by default 4
+            The number of groups to split the residual blocks into, by default 5
             Within each group, the dilation of the residual block grows exponentially.
         residual_channels (int, optional):
             Residual channel of the residual blocks, by default 256
@@ -64,15 +64,15 @@ def __init__(
             out_channels: int=80,
             kernel_size: int=3,
             layers: int=20,
-            stacks: int=4,
+            stacks: int=5,
             residual_channels: int=256,
             gate_channels: int=512,
             skip_channels: int=256,
             aux_channels: int=256,
             dropout: float=0.,
             bias: bool=True,
             use_weight_norm: bool=False,
-            init_type: str="kaiming_uniform", ):
+            init_type: str="kaiming_normal", ):
         super().__init__()

         # initialize parameters
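
For context on the stacks default changing from 4 to 5: assuming the usual WaveNet dilation rule dilation = 2 ** (layer % (layers // stacks)), which other WaveNet-style modules commonly use but which is not shown in this diff, the new default splits the 20 residual blocks into five groups of four, so the per-block dilations cycle through 1, 2, 4, 8. A minimal sketch:

    # Illustrative helper only, not part of the diff; assumes the
    # conventional WaveNet schedule dilation = 2 ** (layer % (layers // stacks)).
    def dilation_schedule(layers: int, stacks: int) -> list:
        layers_per_stack = layers // stacks
        return [2**(i % layers_per_stack) for i in range(layers)]

    print(dilation_schedule(20, 5))  # [1, 2, 4, 8] repeated 5 times (new default)
    print(dilation_schedule(20, 4))  # [1, 2, 4, 8, 16] repeated 4 times (old default)
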
@@ -118,18 +118,15 @@ def __init__(
                 bias=bias)
             self.conv_layers.append(conv)

+        final_conv = nn.Conv1D(skip_channels, out_channels, 1, bias_attr=True)
+        nn.initializer.Constant(0.0)(final_conv.weight)
         self.last_conv_layers = nn.Sequential(nn.ReLU(),
                                               nn.Conv1D(
                                                   skip_channels,
                                                   skip_channels,
                                                   1,
                                                   bias_attr=True),
-                                              nn.ReLU(),
-                                              nn.Conv1D(
-                                                  skip_channels,
-                                                  out_channels,
-                                                  1,
-                                                  bias_attr=True))
+                                              nn.ReLU(), final_conv)

         if use_weight_norm:
             self.apply_weight_norm()
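
This hunk is the fix named in the PR title: the last 1x1 projection becomes a named final_conv so its weights can be zero-initialised before it is placed in the Sequential stack. Zero-initialising the final projection is a common stabilisation trick in diffusion denoisers, since the model then predicts near-zero noise at the start of training. A minimal standalone sketch of the same pattern, using the default channel sizes from the docstring rather than values from a real config:

    import paddle
    import paddle.nn as nn

    # Standalone sketch of the pattern applied in the hunk above.
    skip_channels, out_channels = 256, 80
    final_conv = nn.Conv1D(skip_channels, out_channels, 1, bias_attr=True)
    nn.initializer.Constant(0.0)(final_conv.weight)  # weights start at zero

    print(float(final_conv.weight.abs().sum()))  # 0.0
    y = final_conv(paddle.randn([1, skip_channels, 10]))
    # with zero weights, the initial output depends only on the bias term
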
@@ -200,10 +197,6 @@ class GaussianDiffusion(nn.Layer):
     Args:
         denoiser (Layer, optional):
             The model used for denoising noises.
-            In fact, the denoiser model performs the operation
-            of producing a output with more noises from the noisy input.
-            Then we use the diffusion algorithm to calculate
-            the input with the output to get the denoised result.
         num_train_timesteps (int, optional):
             The number of timesteps between the noise and the real during training, by default 1000.
         beta_start (float, optional):
@@ -233,7 +226,8 @@ class GaussianDiffusion(nn.Layer):
         >>> def callback(index, timestep, num_timesteps, sample):
         >>>     nonlocal pbar
         >>>     if pbar is None:
-        >>>         pbar = tqdm(total=num_timesteps-index)
+        >>>         pbar = tqdm(total=num_timesteps)
+        >>>         pbar.update(index)
         >>>     pbar.update()
         >>>
         >>>     return callback
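
The reworked callback sizes the bar by the full num_timesteps and pre-advances it by the starting index, so timesteps skipped before inference begins show as already completed; that is why the totals in the sample outputs below grow from 25/25 to 34/34 and from 5/5 to 14/14 while the number of actual ticks stays the same. A toy illustration, with the values inferred from the first example (34 total, 34 - 25 = 9 skipped) rather than computed by a scheduler:

    from tqdm import tqdm

    # Toy illustration of the new progress-bar bookkeeping.
    num_timesteps, start_index = 34, 9
    pbar = tqdm(total=num_timesteps)
    pbar.update(start_index)                  # skipped steps count as done
    for _ in range(num_timesteps - start_index):
        pbar.update()                         # one tick per denoising step
    pbar.close()                              # bar ends at 34/34
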
@@ -247,7 +241,7 @@ class GaussianDiffusion(nn.Layer):
         >>> diffusion = GaussianDiffusion(denoiser, num_train_timesteps=ds, num_max_timesteps=K_step)
         >>> with paddle.no_grad():
         >>>     sample = diffusion.inference(
-        >>>         paddle.randn(x.shape), c, x,
+        >>>         paddle.randn(x.shape), c, ref_x=x_in,
         >>>         num_inference_steps=infer_steps,
         >>>         scheduler_type=scheduler_type,
         >>>         callback=create_progress_callback())
@@ -262,7 +256,7 @@ class GaussianDiffusion(nn.Layer):
         >>> diffusion = GaussianDiffusion(denoiser, num_train_timesteps=ds, num_max_timesteps=K_step)
         >>> with paddle.no_grad():
         >>>     sample = diffusion.inference(
-        >>>         paddle.randn(x.shape), c, x_in,
+        >>>         paddle.randn(x.shape), c, ref_x=x_in,
         >>>         num_inference_steps=infer_steps,
         >>>         scheduler_type=scheduler_type,
         >>>         callback=create_progress_callback())
@@ -277,11 +271,11 @@ class GaussianDiffusion(nn.Layer):
         >>> diffusion = GaussianDiffusion(denoiser, num_train_timesteps=ds, num_max_timesteps=K_step)
         >>> with paddle.no_grad():
         >>>     sample = diffusion.inference(
-        >>>         paddle.randn(x.shape), c, None,
+        >>>         paddle.randn(x.shape), c, ref_x=x_in,
         >>>         num_inference_steps=infer_steps,
         >>>         scheduler_type=scheduler_type,
         >>>         callback=create_progress_callback())
-        100%|█████| 25/25 [00:01<00:00, 19.75it/s]
+        100%|█████| 34/34 [00:01<00:00, 19.75it/s]
         >>>
         >>> # ds=1000, K_step=100, scheduler=pndm, infer_step=50, from aux fs2 mel output
         >>> ds = 1000
@@ -292,11 +286,11 @@ class GaussianDiffusion(nn.Layer):
         >>> diffusion = GaussianDiffusion(denoiser, num_train_timesteps=ds, num_max_timesteps=K_step)
         >>> with paddle.no_grad():
         >>>     sample = diffusion.inference(
-        >>>         paddle.randn(x.shape), c, x,
+        >>>         paddle.randn(x.shape), c, ref_x=x_in,
         >>>         num_inference_steps=infer_steps,
         >>>         scheduler_type=scheduler_type,
         >>>         callback=create_progress_callback())
-        100%|█████| 5/5 [00:00<00:00, 23.80it/s]
+        100%|█████| 14/14 [00:00<00:00, 23.80it/s]

     """