diff --git a/src/so_vits_svc_fork/cluster/train_cluster.py b/src/so_vits_svc_fork/cluster/train_cluster.py index 4e7e5974..fe961eb6 100644 --- a/src/so_vits_svc_fork/cluster/train_cluster.py +++ b/src/so_vits_svc_fork/cluster/train_cluster.py @@ -24,8 +24,10 @@ def train_cluster( LOG.info(f"Loading features from {input_dir}") features = [] nums = 0 - for path in input_dir.glob("*.soft.pt"): - features.append(torch.load(path).squeeze(0).numpy().T) + for path in input_dir.rglob("*.data.pt"): + features.append( + torch.load(path, weights_only=True)["content"].squeeze(0).numpy().T + ) features = np.concatenate(features, axis=0).astype(np.float32) if features.shape[0] < n_clusters: raise ValueError(