diff --git a/src/so_vits_svc_fork/utils.py b/src/so_vits_svc_fork/utils.py index 49a2e554e..6b04517c6 100644 --- a/src/so_vits_svc_fork/utils.py +++ b/src/so_vits_svc_fork/utils.py @@ -238,7 +238,11 @@ def load_checkpoint( "ignore", category=UserWarning, message="TypedStorage is deprecated" ) checkpoint_dict = torch.load(f, map_location="cpu", weights_only=True) - iteration = checkpoint_dict["iteration"] + # The variable in this data structure was originally named 'iteration'. + # It has been renamed to 'epoch' to represent it's actual value. + # However, when loading this data structure from file, we still use 'iteration' to ensure + # backward compatibility with older checkpoint files + epoch = checkpoint_dict["iteration"] learning_rate = checkpoint_dict["learning_rate"] # safe load module @@ -256,20 +260,20 @@ def load_checkpoint( warnings.simplefilter("ignore") safe_load(optimizer, checkpoint_dict["optimizer"]) - LOG.info(f"Loaded checkpoint '{checkpoint_path}' (iteration {iteration})") - return model, optimizer, learning_rate, iteration + LOG.info(f"Loaded checkpoint '{checkpoint_path}' (epoch {epoch})") + return model, optimizer, learning_rate, epoch def save_checkpoint( model: torch.nn.Module, optimizer: torch.optim.Optimizer, learning_rate: float, - iteration: int, + epoch: int, checkpoint_path: Path | str, ) -> None: LOG.info( - "Saving model and optimizer state at iteration {} to {}".format( - iteration, checkpoint_path + "Saving model and optimizer state at epoch {} to {}".format( + epoch, checkpoint_path ) ) if hasattr(model, "module"): @@ -280,7 +284,11 @@ def save_checkpoint( torch.save( { "model": state_dict, - "iteration": iteration, + # The variable in this data structure was originally named 'iteration'. + # It has been renamed to 'epoch' to represent it's actual value. + # However, when saving this data structure to file, we still use 'iteration' to ensure + # backward compatibility with older checkpoint files + "iteration": epoch, "optimizer": optimizer.state_dict(), "learning_rate": learning_rate, },