Skip to content

Commit

Permalink
Make style
Browse files Browse the repository at this point in the history
  • Loading branch information
Edresson committed Oct 6, 2023
1 parent 4a6103f commit 24af6a1
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 244 deletions.
1 change: 0 additions & 1 deletion TTS/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ def is_multi_lingual(self):
return self.synthesizer.tts_model.language_manager.num_languages > 1
return False


@property
def speakers(self):
if not self.is_multi_speaker:
Expand Down
1 change: 1 addition & 0 deletions TTS/tts/layers/xtts/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def init_gpt_for_inference(self, kv_cache=True, use_deepspeed=False):

if use_deepspeed:
import deepspeed

self.ds_engine = deepspeed.init_inference(
model=self.gpt_inference.half(), # Transformers models
mp_size=1, # Number of GPU
Expand Down
29 changes: 9 additions & 20 deletions TTS/tts/layers/xtts/hifigan_decoder.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import torch
import torchaudio
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
import torchaudio

from TTS.utils.io import load_fsspec


LRELU_SLOPE = 0.1


Expand Down Expand Up @@ -224,9 +223,7 @@ def __init__(
self.cond_in_each_up_layer = cond_in_each_up_layer

# initial upsampling layers
self.conv_pre = weight_norm(
Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
)
self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
resblock = ResBlock1 if resblock_type == "1" else ResBlock2
# upsampling layers
self.ups = nn.ModuleList()
Expand All @@ -246,14 +243,10 @@ def __init__(
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for _, (k, d) in enumerate(
zip(resblock_kernel_sizes, resblock_dilation_sizes)
):
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))
# post convolution layer
self.conv_post = weight_norm(
Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)
)
self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
if cond_channels > 0:
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)

Expand Down Expand Up @@ -318,9 +311,7 @@ def inference(self, c):
Tensor: [B, 1, T]
"""
c = c.to(self.conv_pre.weight.device)
c = torch.nn.functional.pad(
c, (self.inference_padding, self.inference_padding), "replicate"
)
c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
return self.forward(c)

def remove_weight_norm(self):
Expand All @@ -342,6 +333,7 @@ def load_checkpoint(
assert not self.training
self.remove_weight_norm()


class SELayer(nn.Module):
def __init__(self, channel, reduction=8):
super(SELayer, self).__init__()
Expand Down Expand Up @@ -425,10 +417,8 @@ def forward(self, x):
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)



class ResNetSpeakerEncoder(nn.Module):
"""This is copied from 🐸TTS to remove it from the dependencies.
"""
"""This is copied from 🐸TTS to remove it from the dependencies."""

# pylint: disable=W0102
def __init__(
Expand Down Expand Up @@ -620,6 +610,7 @@ def load_checkpoint(
return criterion, state["step"]
return criterion


class HifiDecoder(torch.nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -724,9 +715,7 @@ def inference(self, c, g):
"""
return self.forward(c, g=g)

def load_checkpoint(
self, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
def load_checkpoint(self, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
# remove unused keys
state = state["model"]
Expand Down
Loading

0 comments on commit 24af6a1

Please sign in to comment.