
Commit

fix(utils): fix epoch variable name in checkpoint save/load functions
Scorpi committed Apr 29, 2023
1 parent ed3240a commit f209593
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions src/so_vits_svc_fork/utils.py
@@ -238,7 +238,11 @@ def load_checkpoint(
                 "ignore", category=UserWarning, message="TypedStorage is deprecated"
             )
             checkpoint_dict = torch.load(f, map_location="cpu", weights_only=True)
-    iteration = checkpoint_dict["iteration"]
+    # The variable in this data structure was originally named 'iteration'.
+    # It has been renamed to 'epoch' to represent its actual value.
+    # However, when loading this data structure from file, we still use 'iteration' to ensure
+    # backward compatibility with older checkpoint files.
+    epoch = checkpoint_dict["iteration"]
     learning_rate = checkpoint_dict["learning_rate"]

     # safe load module
@@ -256,20 +260,20 @@ def load_checkpoint(
             warnings.simplefilter("ignore")
             safe_load(optimizer, checkpoint_dict["optimizer"])

-    LOG.info(f"Loaded checkpoint '{checkpoint_path}' (iteration {iteration})")
-    return model, optimizer, learning_rate, iteration
+    LOG.info(f"Loaded checkpoint '{checkpoint_path}' (epoch {epoch})")
+    return model, optimizer, learning_rate, epoch


 def save_checkpoint(
     model: torch.nn.Module,
     optimizer: torch.optim.Optimizer,
     learning_rate: float,
-    iteration: int,
+    epoch: int,
     checkpoint_path: Path | str,
 ) -> None:
     LOG.info(
-        "Saving model and optimizer state at iteration {} to {}".format(
-            iteration, checkpoint_path
+        "Saving model and optimizer state at epoch {} to {}".format(
+            epoch, checkpoint_path
         )
     )
     if hasattr(model, "module"):
@@ -280,7 +284,11 @@ def save_checkpoint(
     torch.save(
         {
             "model": state_dict,
-            "iteration": iteration,
+            # The variable in this data structure was originally named 'iteration'.
+            # It has been renamed to 'epoch' to represent its actual value.
+            # However, when saving this data structure to file, we still use 'iteration' to ensure
+            # backward compatibility with older checkpoint files.
+            "iteration": epoch,
             "optimizer": optimizer.state_dict(),
             "learning_rate": learning_rate,
         },
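
The patch keeps the on-disk key "iteration" while the in-memory name becomes epoch, so checkpoints written before this change still load. A minimal standalone sketch of that round-trip (illustrative only, not part of this commit; the file name and placeholder values are made up):

import torch

# Write a checkpoint the way the patched save_checkpoint does: the value is an
# epoch count, but it is stored under the legacy "iteration" key.
epoch = 42
torch.save(
    {
        "model": {},          # stand-in for model.state_dict()
        "iteration": epoch,   # legacy key kept for backward compatibility
        "optimizer": {},      # stand-in for optimizer.state_dict()
        "learning_rate": 1e-4,
    },
    "checkpoint_demo.pth",
)

# Read it back the way the patched load_checkpoint does: old and new files alike
# expose the value under "iteration", which is then assigned to `epoch`.
checkpoint_dict = torch.load("checkpoint_demo.pth", map_location="cpu")
epoch = checkpoint_dict["iteration"]
learning_rate = checkpoint_dict["learning_rate"]
print(f"Resumed at epoch {epoch}, learning rate {learning_rate}")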
