diff --git a/src/so_vits_svc_fork/utils.py b/src/so_vits_svc_fork/utils.py index c84ffbb8..9878825b 100644 --- a/src/so_vits_svc_fork/utils.py +++ b/src/so_vits_svc_fork/utils.py @@ -369,18 +369,12 @@ def get_hparams(config_path: Path | str) -> HParams: def repeat_expand_2d(content: torch.Tensor, target_len: int) -> torch.Tensor: # content : [h, t] src_len = content.shape[-1] - target = torch.zeros( - [content.shape[0], target_len], dtype=content.dtype, device=content.device - ) - temp = torch.arange(src_len + 1) * target_len / src_len - current_pos = 0 - for i in range(target_len): - if i < temp[current_pos + 1]: - target[:, i] = content[:, current_pos] - else: - current_pos += 1 - target[:, i] = content[:, current_pos] - return target + if target_len < src_len: + return content[:, :target_len] + else: + return torch.nn.functional.interpolate( + content.unsqueeze(0), size=target_len, mode="nearest" + ).squeeze(0) def plot_data_to_numpy(x: ndarray, y: ndarray) -> ndarray: