Skip to content

Commit

Permalink
fix(cluster): fix train_cluster (#250)
Browse files Browse the repository at this point in the history
  • Loading branch information
34j authored Apr 8, 2023
1 parent dbef454 commit b0c93e4
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/so_vits_svc_fork/cluster/train_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ def train_cluster(
LOG.info(f"Loading features from {input_dir}")
features = []
nums = 0
for path in input_dir.glob("*.soft.pt"):
features.append(torch.load(path).squeeze(0).numpy().T)
for path in input_dir.rglob("*.data.pt"):
features.append(
torch.load(path, weights_only=True)["content"].squeeze(0).numpy().T
)
features = np.concatenate(features, axis=0).astype(np.float32)
if features.shape[0] < n_clusters:
raise ValueError(
Expand Down

0 comments on commit b0c93e4

Please sign in to comment.