diff --git a/src/so_vits_svc_fork/preprocessing/config_templates/quickvc.json b/src/so_vits_svc_fork/preprocessing/config_templates/quickvc.json
index db26f630..1d725259 100644
--- a/src/so_vits_svc_fork/preprocessing/config_templates/quickvc.json
+++ b/src/so_vits_svc_fork/preprocessing/config_templates/quickvc.json
@@ -25,7 +25,9 @@
     "win_lengths": [300, 600, 120],
     "window": "hann_window",
     "num_workers": 4,
-    "log_version": 0
+    "log_version": 0,
+    "ckpt_name_by_step": false,
+    "accumulate_grad_batches": 1
   },
   "data": {
     "training_files": "filelists/44k/train.txt",
diff --git a/src/so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1-legacy.json b/src/so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1-legacy.json
index 93907168..174e1b79 100644
--- a/src/so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1-legacy.json
+++ b/src/so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1-legacy.json
@@ -21,7 +21,9 @@
     "port": "8001",
     "keep_ckpts": 3,
     "num_workers": 4,
-    "log_version": 0
+    "log_version": 0,
+    "ckpt_name_by_step": false,
+    "accumulate_grad_batches": 1
   },
   "data": {
     "training_files": "filelists/44k/train.txt",
diff --git a/src/so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1.json b/src/so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1.json
index 8a481455..0e5522c1 100644
--- a/src/so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1.json
+++ b/src/so_vits_svc_fork/preprocessing/config_templates/so-vits-svc-4.0v1.json
@@ -21,7 +21,9 @@
     "port": "8001",
     "keep_ckpts": 3,
     "num_workers": 4,
-    "log_version": 0
+    "log_version": 0,
+    "ckpt_name_by_step": false,
+    "accumulate_grad_batches": 1
   },
   "data": {
     "training_files": "filelists/44k/train.txt",
diff --git a/src/so_vits_svc_fork/train.py b/src/so_vits_svc_fork/train.py
index c298c0e8..ea8c46a0 100644
--- a/src/so_vits_svc_fork/train.py
+++ b/src/so_vits_svc_fork/train.py
@@ -505,14 +505,16 @@ def validation_step(self, batch, batch_idx):
             self.optim_g,
             self.learning_rate,
             self.current_epoch + 1,  # prioritize prevention of undervaluation
-            Path(self.hparams.model_dir) / f"G_{self.total_batch_idx}.pth",
+            Path(self.hparams.model_dir)
+            / f"G_{self.total_batch_idx if self.hparams.train.get('ckpt_name_by_step', False) else self.current_epoch + 1}.pth",
         )
         utils.save_checkpoint(
             self.net_d,
             self.optim_d,
             self.learning_rate,
             self.current_epoch + 1,
-            Path(self.hparams.model_dir) / f"D_{self.total_batch_idx}.pth",
+            Path(self.hparams.model_dir)
+            / f"D_{self.total_batch_idx if self.hparams.train.get('ckpt_name_by_step', False) else self.current_epoch + 1}.pth",
         )
         keep_ckpts = self.hparams.train.get("keep_ckpts", 0)
         if keep_ckpts > 0:
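
A minimal, standalone sketch of the filename selection introduced in the train.py hunk above, assuming the new "ckpt_name_by_step" key defaults to false as in the updated templates. The helper name checkpoint_name and the example paths/values are illustrative only, not code from the repository:

# Sketch of the checkpoint-naming rule added in train.py above.
# The helper name and example values are illustrative, not repository code.
from pathlib import Path


def checkpoint_name(
    prefix: str,
    model_dir: Path,
    total_batch_idx: int,
    current_epoch: int,
    train_config: dict,
) -> Path:
    # When "ckpt_name_by_step" is true, checkpoints are numbered by the
    # running batch index; otherwise by the 1-based epoch, which matches
    # the template default since the new key ships as false.
    index = (
        total_batch_idx
        if train_config.get("ckpt_name_by_step", False)
        else current_epoch + 1
    )
    return model_dir / f"{prefix}_{index}.pth"


if __name__ == "__main__":
    # Epoch 9 (0-based) with 12345 batches seen so far:
    print(checkpoint_name("G", Path("logs/44k"), 12345, 9, {"ckpt_name_by_step": False}))
    # -> logs/44k/G_10.pth
    print(checkpoint_name("G", Path("logs/44k"), 12345, 9, {"ckpt_name_by_step": True}))
    # -> logs/44k/G_12345.pth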