From ef30a9d5ae60fdde5f6b44d6cea8cee0a40dd3e9 Mon Sep 17 00:00:00 2001 From: BlueAmulet <43395286+BlueAmulet@users.noreply.github.com> Date: Sun, 2 Apr 2023 23:42:10 -0600 Subject: [PATCH] perf(utils): better implementation of repeat_expand_2d (#216) --- src/so_vits_svc_fork/utils.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/so_vits_svc_fork/utils.py b/src/so_vits_svc_fork/utils.py index c84ffbb8..9878825b 100644 --- a/src/so_vits_svc_fork/utils.py +++ b/src/so_vits_svc_fork/utils.py @@ -369,18 +369,12 @@ def get_hparams(config_path: Path | str) -> HParams: def repeat_expand_2d(content: torch.Tensor, target_len: int) -> torch.Tensor: # content : [h, t] src_len = content.shape[-1] - target = torch.zeros( - [content.shape[0], target_len], dtype=content.dtype, device=content.device - ) - temp = torch.arange(src_len + 1) * target_len / src_len - current_pos = 0 - for i in range(target_len): - if i < temp[current_pos + 1]: - target[:, i] = content[:, current_pos] - else: - current_pos += 1 - target[:, i] = content[:, current_pos] - return target + if target_len < src_len: + return content[:, :target_len] + else: + return torch.nn.functional.interpolate( + content.unsqueeze(0), size=target_len, mode="nearest" + ).squeeze(0) def plot_data_to_numpy(x: ndarray, y: ndarray) -> ndarray: