fix(train): properly stop training after the epoch limit has been reached
Lordmau5 authored Apr 19, 2023
1 parent 48e0bbd commit f9bb3d8
62 changes: 37 additions & 25 deletions src/so_vits_svc_fork/train.py
@@ -253,9 +253,42 @@ def stft(

torch.stft = stft

    def on_train_end(self) -> None:
        self.save_checkpoints(adjust=0)

    def save_checkpoints(self, adjust=1):
        # By the time `on_train_end` runs, `current_epoch` is already the final
        # epoch count (not one behind), so that hook calls this with `adjust=0`.
        current_epoch = self.current_epoch + adjust
        total_batch_idx = self.total_batch_idx - 1 + adjust

        utils.save_checkpoint(
            self.net_g,
            self.optim_g,
            self.learning_rate,
            current_epoch,
            Path(self.hparams.model_dir)
            / f"G_{total_batch_idx if self.hparams.train.get('ckpt_name_by_step', False) else current_epoch}.pth",
        )
        utils.save_checkpoint(
            self.net_d,
            self.optim_d,
            self.learning_rate,
            current_epoch,
            Path(self.hparams.model_dir)
            / f"D_{total_batch_idx if self.hparams.train.get('ckpt_name_by_step', False) else current_epoch}.pth",
        )
        keep_ckpts = self.hparams.train.get("keep_ckpts", 0)
        if keep_ckpts > 0:
            utils.clean_checkpoints(
                path_to_models=self.hparams.model_dir,
                n_ckpts_to_keep=keep_ckpts,
                sort_by_time=True,
            )

    def set_current_epoch(self, epoch: int):
        LOG.info(f"Setting current epoch to {epoch}")
        self.trainer.fit_loop.epoch_progress.current.completed = epoch
        self.trainer.fit_loop.epoch_progress.current.processed = epoch
        assert self.current_epoch == epoch, f"{self.current_epoch} != {epoch}"

    def set_global_step(self, global_step: int):
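
This first hunk adds `on_train_end` and factors checkpoint writing into `save_checkpoints`, so a final checkpoint is always written when the epoch limit stops the run. The `adjust` offset exists because `current_epoch` is the zero-based index of the epoch still in progress while training runs, but has already been advanced past the last completed epoch by the time `on_train_end` fires. A minimal sketch of that hook behavior, assuming the standalone `pytorch_lightning` package and Lightning's standard hook ordering (`HookDemo` is a made-up toy module, not from this repo):

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class HookDemo(pl.LightningModule):
    """Toy module that prints the epoch counter from both hooks."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return torch.nn.functional.mse_loss(self.layer(x), x)

    def validation_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def on_validation_end(self):
        if self.trainer.sanity_checking:  # same guard the commit adds below
            return
        # Mid-run: current_epoch is the zero-based index of the epoch
        # still in progress, hence the +1 (`adjust=1`) in the diff above.
        print(f"on_validation_end: epoch label = {self.current_epoch + 1}")

    def on_train_end(self):
        # After fit() finishes, the counter already equals the number of
        # completed epochs, hence `adjust=0`.
        print(f"on_train_end: epoch label = {self.current_epoch}")


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(8, 1)), batch_size=4)
    trainer = pl.Trainer(max_epochs=2, logger=False, enable_checkpointing=False)
    trainer.fit(HookDemo(), train_dataloaders=data, val_dataloaders=data)

With `max_epochs=2` and the default `check_val_every_n_epoch=1`, this should print epoch labels 1 and 2 from validation and 2 from `on_train_end`: both call sites converge on the same label, which is why the two checkpoint paths can share one method.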
@@ -511,28 +544,7 @@ def validation_step(self, batch, batch_idx):
                ),
            }
        )
        if self.current_epoch == 0 or batch_idx != 0:
            return
        utils.save_checkpoint(
            self.net_g,
            self.optim_g,
            self.learning_rate,
            self.current_epoch + 1,  # prioritize prevention of undervaluation
            Path(self.hparams.model_dir)
            / f"G_{self.total_batch_idx if self.hparams.train.get('ckpt_name_by_step', False) else self.current_epoch + 1}.pth",
        )
        utils.save_checkpoint(
            self.net_d,
            self.optim_d,
            self.learning_rate,
            self.current_epoch + 1,
            Path(self.hparams.model_dir)
            / f"D_{self.total_batch_idx if self.hparams.train.get('ckpt_name_by_step', False) else self.current_epoch + 1}.pth",
        )
        keep_ckpts = self.hparams.train.get("keep_ckpts", 0)
        if keep_ckpts > 0:
            utils.clean_checkpoints(
                path_to_models=self.hparams.model_dir,
                n_ckpts_to_keep=keep_ckpts,
                sort_by_time=True,
            )

    def on_validation_end(self) -> None:
        if not self.trainer.sanity_checking:
            self.save_checkpoints()
