Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Matcha TTS #3582

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions TTS/tts/configs/matcha_tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from dataclasses import dataclass, field

from TTS.tts.configs.shared_configs import BaseTTSConfig


@dataclass
class MatchaTTSConfig(BaseTTSConfig):
model: str = "matcha_tts"
num_chars: int = None
299 changes: 299 additions & 0 deletions TTS/tts/layers/matcha_tts/UNet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,299 @@
import math
from einops import pack, rearrange
import torch
from torch import nn
import conformer


class PositionalEncoding(torch.nn.Module):
def __init__(self, channels):
super().__init__()
self.channels = channels

def forward(self, x, scale=1000):
if x.ndim < 1:
x = x.unsqueeze(0)
emb = math.log(10000) / (self.channels // 2 - 1)
emb = torch.exp(torch.arange(self.channels // 2, device=x.device).float() * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb

class ConvBlock1D(nn.Module):
def __init__(self, in_channels, out_channels, num_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
nn.GroupNorm(num_groups, out_channels),
nn.Mish()
)

def forward(self, x, mask=None):
if mask is not None:
x = x * mask
output = self.block(x)
if mask is not None:
output = output * mask
return output


class ResNetBlock1D(nn.Module):
def __init__(self, in_channels, out_channels, time_embed_channels, num_groups=8):
super().__init__()
self.block_1 = ConvBlock1D(in_channels, out_channels, num_groups=num_groups)
self.mlp = nn.Sequential(
nn.Mish(),
nn.Linear(time_embed_channels, out_channels)
)
self.block_2 = ConvBlock1D(in_channels=out_channels, out_channels=out_channels, num_groups=num_groups)
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)

def forward(self, x, mask, t):
h = self.block_1(x, mask)
h += self.mlp(t).unsqueeze(-1)
h = self.block_2(h, mask)
output = h + self.conv(x * mask)
return output


class Downsample1D(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=3, stride=2, padding=1)

def forward(self, x):
return self.conv(x)


class Upsample1D(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv = nn.ConvTranspose1d(in_channels=channels, out_channels=channels, kernel_size=4, stride=2, padding=1)

def forward(self, x):
return self.conv(x)


class ConformerBlock(conformer.ConformerBlock):
def __init__(
self,
dim: int,
dim_head: int = 64,
heads: int = 8,
ff_mult: int = 4,
conv_expansion_factor: int = 2,
conv_kernel_size: int = 31,
attn_dropout: float = 0.,
ff_dropout: float = 0.,
conv_dropout: float = 0.,
conv_causal: bool = False,
):
super().__init__(
dim=dim,
dim_head=dim_head,
heads=heads,
ff_mult=ff_mult,
conv_expansion_factor=conv_expansion_factor,
conv_kernel_size=conv_kernel_size,
attn_dropout=attn_dropout,
ff_dropout=ff_dropout,
conv_dropout=conv_dropout,
conv_causal=conv_causal,
)

def forward(self, x, mask,):
x = rearrange(x, "b c t -> b t c")
mask = rearrange(mask, "b 1 t -> b t")
output = super().forward(x=x, mask=mask.bool())
return rearrange(output, "b t c -> b c t")


class UNet(nn.Module):
def __init__(
self,
in_channels: int,
model_channels: int,
out_channels: int,
num_blocks: int,
transformer_num_heads: int = 4,
transformer_dim_head: int = 64,
transformer_ff_mult: int = 1,
transformer_conv_expansion_factor: int = 2,
transformer_conv_kernel_size: int = 31,
transformer_dropout: float = 0.05,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels

self.time_encoder = PositionalEncoding(in_channels)
time_embed_channels = model_channels * 4
self.time_embed = nn.Sequential(
nn.Linear(in_channels, time_embed_channels),
nn.SiLU(),
nn.Linear(time_embed_channels, time_embed_channels),
)

self.input_blocks = nn.ModuleList([])
block_in_channels = in_channels * 2
block_out_channels = model_channels
for level in range(num_blocks):
block = nn.ModuleList([])

block.append(
ResNetBlock1D(
in_channels=block_in_channels,
out_channels=block_out_channels,
time_embed_channels=time_embed_channels
)
)

block.append(
self._create_transformer_block(
block_out_channels,
dim_head=transformer_dim_head,
num_heads=transformer_num_heads,
ff_mult=transformer_ff_mult,
conv_expansion_factor=transformer_conv_expansion_factor,
conv_kernel_size=transformer_conv_kernel_size,
dropout=transformer_dropout,
)
)

if level != num_blocks - 1:
block.append(Downsample1D(block_out_channels))
else:
block.append(None)

block_in_channels = block_out_channels
self.input_blocks.append(block)

self.middle_blocks = nn.ModuleList([])
for i in range(2):
block = nn.ModuleList([])

block.append(
ResNetBlock1D(
in_channels=block_out_channels,
out_channels=block_out_channels,
time_embed_channels=time_embed_channels
)
)

block.append(
self._create_transformer_block(
block_out_channels,
dim_head=transformer_dim_head,
num_heads=transformer_num_heads,
ff_mult=transformer_ff_mult,
conv_expansion_factor=transformer_conv_expansion_factor,
conv_kernel_size=transformer_conv_kernel_size,
dropout=transformer_dropout,
)
)

self.middle_blocks.append(block)

self.output_blocks = nn.ModuleList([])
block_in_channels = block_out_channels * 2
block_out_channels = model_channels
for level in range(num_blocks):
block = nn.ModuleList([])

block.append(
ResNetBlock1D(
in_channels=block_in_channels,
out_channels=block_out_channels,
time_embed_channels=time_embed_channels
)
)

block.append(
self._create_transformer_block(
block_out_channels,
dim_head=transformer_dim_head,
num_heads=transformer_num_heads,
ff_mult=transformer_ff_mult,
conv_expansion_factor=transformer_conv_expansion_factor,
conv_kernel_size=transformer_conv_kernel_size,
dropout=transformer_dropout,
)
)

if level != num_blocks - 1:
block.append(Upsample1D(block_out_channels))
else:
block.append(None)

block_in_channels = block_out_channels * 2
self.output_blocks.append(block)

self.conv_block = ConvBlock1D(model_channels, model_channels)
self.conv = nn.Conv1d(model_channels, self.out_channels, 1)

def _create_transformer_block(
self,
dim,
dim_head: int = 64,
num_heads: int = 4,
ff_mult: int = 1,
conv_expansion_factor: int = 2,
conv_kernel_size: int = 31,
dropout: float = 0.05,
):
return ConformerBlock(
dim=dim,
dim_head=dim_head,
heads=num_heads,
ff_mult=ff_mult,
conv_expansion_factor=conv_expansion_factor,
conv_kernel_size=conv_kernel_size,
attn_dropout=dropout,
ff_dropout=dropout,
conv_dropout=dropout,
conv_causal=False,
)

def forward(self, x_t, mean, mask, t):
t = self.time_encoder(t)
t = self.time_embed(t)

x_t = pack([x_t, mean], "b * t")[0]

hidden_states = []
mask_states = [mask]

for block in self.input_blocks:
res_net_block, transformer, downsample = block

x_t = res_net_block(x_t, mask, t)
x_t = transformer(x_t, mask)

hidden_states.append(x_t)

if downsample is not None:
x_t = downsample(x_t * mask)
mask = mask[:, :, ::2]
mask_states.append(mask)

for block in self.middle_blocks:
res_net_block, transformer = block
mask = mask_states[-1]
x_t = res_net_block(x_t, mask, t)
x_t = transformer(x_t, mask)

for block in self.output_blocks:
res_net_block, transformer, upsample = block

x_t = pack([x_t, hidden_states.pop()], "b * t")[0]
mask = mask_states.pop()
x_t = res_net_block(x_t, mask, t)
x_t = transformer(x_t, mask)

if upsample is not None:
x_t = upsample(x_t * mask)

output = self.conv_block(x_t)
output = self.conv(x_t)

return output * mask
32 changes: 32 additions & 0 deletions TTS/tts/layers/matcha_tts/decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch
from torch import nn
import torch.nn.functional as F

from TTS.tts.layers.matcha_tts.UNet import UNet


class Decoder(nn.Module):
def __init__(self):
super().__init__()
self.sigma_min = 1e-5
self.predictor = UNet(
in_channels=80,
model_channels=256,
out_channels=80,
num_blocks=2
)

def forward(self, x_1, mean, mask):
"""
Shapes:
- x_1: :math:`[B, C, T]`
- mean: :math:`[B, C ,T]`
- mask: :math:`[B, 1, T]`
"""
t = torch.rand([x_1.size(0), 1, 1], device=x_1.device, dtype=x_1.dtype)
x_0 = torch.randn_like(x_1)
x_t = (1 - (1 - self.sigma_min) * t) * x_0 + t * x_1
u_t = x_1 - (1 - self.sigma_min) * x_0
v_t = self.predictor(x_t, mean, mask, t.squeeze())
loss = F.mse_loss(v_t, u_t, reduction="sum") / (torch.sum(mask) * u_t.shape[1])
return loss
Loading
Loading