fix(train): allow higher segment size #351

Merged · 3 commits · Apr 16, 2023
194 changes: 70 additions & 124 deletions src/so_vits_svc_fork/modules/commons.py
@@ -1,27 +1,79 @@
-import math
+from __future__ import annotations

 import torch
-from torch.nn import functional as F
+import torch.nn.functional as F
+from torch import Tensor


+def slice_segments(x: Tensor, starts: Tensor, length: int) -> Tensor:
+    if length is None:
+        return x
+    length = min(length, x.size(-1))
+    x_slice = torch.zeros((x.size()[:-1] + (length,)), dtype=x.dtype, device=x.device)
+    ends = starts + length
+    for i, (start, end) in enumerate(zip(starts, ends)):
+        # LOG.debug(i, start, end, x.size(), x[i, ..., start:end].size(), x_slice.size())
+        # x_slice[i, ...] = x[i, ..., start:end] need to pad
+        # x_slice[i, ..., :end - start] = x[i, ..., start:end] this does not work
+        x_slice[i, ...] = F.pad(x[i, ..., start:end], (0, max(0, length - x.size(-1))))
+    return x_slice
+
+
+def rand_slice_segments_with_pitch(
+    x: Tensor, f0: Tensor, x_lengths: Tensor | int | None, segment_size: int | None
+):
+    if segment_size is None:
+        return x, f0, torch.arange(x.size(0), device=x.device)
+    if x_lengths is None:
+        x_lengths = x.size(-1) * torch.ones(
+            x.size(0), dtype=torch.long, device=x.device
+        )
+    # slice_starts = (torch.rand(z.size(0), device=z.device) * (z_lengths - segment_size)).long()
+    slice_starts = (
+        torch.rand(x.size(0), device=x.device)
+        * torch.max(
+            x_lengths - segment_size, torch.zeros_like(x_lengths, device=x.device)
+        )
+    ).long()
+    z_slice = slice_segments(x, slice_starts, segment_size)
+    f0_slice = slice_segments(f0, slice_starts, segment_size)
+    return z_slice, f0_slice, slice_starts
+
+
+def slice_2d_segments(x: Tensor, starts: Tensor, length: int) -> Tensor:
+    batch_size, num_features, seq_len = x.shape
+    ends = starts + length
+    idxs = (
+        torch.arange(seq_len, device=x.device)
+        .unsqueeze(0)
+        .unsqueeze(1)
+        .repeat(batch_size, num_features, 1)
+    )
+    mask = (idxs >= starts.unsqueeze(-1).unsqueeze(-1)) & (
+        idxs < ends.unsqueeze(-1).unsqueeze(-1)
+    )
+    return x[mask].reshape(batch_size, num_features, length)


-def slice_pitch_segments(x, ids_str, segment_size=4):
-    ret = torch.zeros_like(x[:, :segment_size])
-    for i in range(x.size(0)):
-        idx_str = ids_str[i]
-        idx_end = idx_str + segment_size
-        ret[i] = x[i, idx_str:idx_end]
-    return ret
+def slice_1d_segments(x: Tensor, starts: Tensor, length: int) -> Tensor:
+    batch_size, seq_len = x.shape
+    ends = starts + length
+    idxs = torch.arange(seq_len, device=x.device).unsqueeze(0).repeat(batch_size, 1)
+    mask = (idxs >= starts.unsqueeze(-1)) & (idxs < ends.unsqueeze(-1))
+    return x[mask].reshape(batch_size, length)


-def rand_slice_segments_with_pitch(x, pitch, x_lengths=None, segment_size=4):
-    b, d, t = x.size()
-    if x_lengths is None:
-        x_lengths = t
-    ids_str_max = x_lengths - segment_size + 1
-    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
-    ret = slice_segments(x, ids_str, segment_size)
-    ret_pitch = slice_pitch_segments(pitch, ids_str, segment_size)
-    return ret, ret_pitch, ids_str
+def _slice_segments_v3(x: Tensor, starts: Tensor, length: int) -> Tensor:
+    shape = x.shape[:-1] + (length,)
+    ends = starts + length
+    idxs = torch.arange(x.shape[-1], device=x.device).unsqueeze(0).unsqueeze(0)
+    unsqueeze_dims = len(shape) - len(
+        x.shape
+    )  # calculate number of dimensions to unsqueeze
+    starts = starts.reshape(starts.shape + (1,) * unsqueeze_dims)
+    ends = ends.reshape(ends.shape + (1,) * unsqueeze_dims)
+    mask = (idxs >= starts) & (idxs < ends)
+    return x[mask].reshape(shape)


 def init_weights(m, mean=0.0, std=0.01):
@@ -40,89 +92,6 @@ def convert_pad_shape(pad_shape):
     return pad_shape


-def intersperse(lst, item):
-    result = [item] * (len(lst) * 2 + 1)
-    result[1::2] = lst
-    return result
-
-
-def kl_divergence(m_p, logs_p, m_q, logs_q):
-    """KL(P||Q)"""
-    kl = (logs_q - logs_p) - 0.5
-    kl += (
-        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
-    )
-    return kl
-
-
-def rand_gumbel(shape):
-    """Sample from the Gumbel distribution, protect from overflows."""
-    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
-    return -torch.log(-torch.log(uniform_samples))
-
-
-def rand_gumbel_like(x):
-    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
-    return g
-
-
-def slice_segments(x, ids_str, segment_size=4):
-    ret = torch.zeros_like(x[:, :, :segment_size])
-    for i in range(x.size(0)):
-        idx_str = ids_str[i]
-        idx_end = idx_str + segment_size
-        ret[i] = x[i, :, idx_str:idx_end]
-    return ret
-
-
-def rand_slice_segments(x, x_lengths=None, segment_size=4):
-    b, d, t = x.size()
-    if x_lengths is None:
-        x_lengths = t
-    ids_str_max = x_lengths - segment_size + 1
-    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
-    ret = slice_segments(x, ids_str, segment_size)
-    return ret, ids_str
-
-
-def rand_spec_segments(x, x_lengths=None, segment_size=4):
-    b, d, t = x.size()
-    if x_lengths is None:
-        x_lengths = t
-    ids_str_max = x_lengths - segment_size
-    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
-    ret = slice_segments(x, ids_str, segment_size)
-    return ret, ids_str
-
-
-def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
-    position = torch.arange(length, dtype=torch.float)
-    num_timescales = channels // 2
-    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
-        num_timescales - 1
-    )
-    inv_timescales = min_timescale * torch.exp(
-        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
-    )
-    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
-    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
-    signal = F.pad(signal, [0, 0, 0, channels % 2])
-    signal = signal.view(1, channels, length)
-    return signal
-
-
-def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
-    b, channels, length = x.size()
-    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
-    return x + signal.to(dtype=x.dtype, device=x.device)
-
-
-def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
-    b, channels, length = x.size()
-    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
-    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
-
-
 def subsequent_mask(length):
     mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
     return mask
@@ -138,36 +107,13 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
     return acts


-def shift_1d(x):
-    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
-    return x
-
-
 def sequence_mask(length, max_length=None):
     if max_length is None:
         max_length = length.max()
     x = torch.arange(max_length, dtype=length.dtype, device=length.device)
     return x.unsqueeze(0) < length.unsqueeze(1)
-
-
-def generate_path(duration, mask):
-    """
-    duration: [b, 1, t_x]
-    mask: [b, 1, t_y, t_x]
-    """
-    duration.device
-
-    b, _, t_y, t_x = mask.shape
-    cum_duration = torch.cumsum(duration, -1)
-
-    cum_duration_flat = cum_duration.view(b * t_x)
-    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
-    path = path.view(b, t_x, t_y)
-    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
-    path = path.unsqueeze(1).transpose(2, 3) * mask
-    return path


 def clip_grad_value_(parameters, clip_value, norm_type=2):
     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
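A minimal usage sketch (not part of the PR; it assumes only the import path shown in the file header above) of why this change allows segment sizes larger than an utterance. The removed rand_slice_segments_with_pitch computed ids_str_max = x_lengths - segment_size + 1, which goes negative when segment_size exceeds x_lengths and yields invalid start indices; the new helpers clamp the starts to zero and the slice length to the available frames:

import torch

from so_vits_svc_fork.modules.commons import rand_slice_segments_with_pitch

# A batch of 2 utterances with only 30 latent frames each.
z = torch.randn(2, 192, 30)
f0 = torch.randn(2, 1, 30)

# The requested segment (40 frames) is longer than the sequence (30 frames).
# Old code: ids_str_max = 30 - 40 + 1 = -9 -> negative starts, broken slices.
# New code: starts clamp to 0 and the returned length clamps to 30.
z_slice, f0_slice, starts = rand_slice_segments_with_pitch(z, f0, None, 40)
print(z_slice.shape)   # torch.Size([2, 192, 30])
print(f0_slice.shape)  # torch.Size([2, 1, 30])
print(starts)          # tensor([0, 0])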
1 change: 1 addition & 0 deletions src/so_vits_svc_fork/train.py
@@ -375,6 +375,7 @@ def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None:
             ids_slice * self.hparams.data.hop_length,
             self.hparams.train.segment_size,
         )
+        y = y[..., : y_hat.shape[-1]]

         # generator loss
         y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.net_d(y, y_hat)
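Because slice_segments now clamps the slice length to the frames actually available, the generator output y_hat can come back shorter than the requested train.segment_size samples, so the ground-truth slice must be trimmed to match before both are fed to the discriminator. A shape-only sketch of the added line (hypothetical sizes, not the real training code):

import torch

hop_length = 512                             # hypothetical hop size
y_hat = torch.randn(2, 1, 30 * hop_length)   # decoder output from a clamped 30-frame slice
y = torch.randn(2, 1, 40 * hop_length)       # ground-truth slice at the requested segment size

y = y[..., : y_hat.shape[-1]]                # the added line: trim to the generator output
assert y.shape == y_hat.shape                # net_d(y, y_hat) now sees equal-length inputs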