diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 448fefc712..a32ad00f56 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -125,7 +125,7 @@ def evaluation(model, criterion, data_loader, global_step): def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, global_step): model.train() - best_loss = float("inf") + best_loss = {"train_loss": None, "eval_loss": float("inf")} avg_loader_time = 0 end_time = time.time() for epoch in range(c.epochs): @@ -248,7 +248,7 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, ) # save the best checkpoint best_loss = save_best_model( - eval_loss, + {"train_loss": None, "eval_loss": eval_loss}, best_loss, c, model, diff --git a/requirements.txt b/requirements.txt index 1f7a44f6d8..23e8d2d013 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,7 +27,7 @@ pandas>=1.4,<2.0 # deps for training matplotlib>=3.7.0 # coqui stack -trainer>=0.0.32 +trainer>=0.0.36 # config management coqpit>=0.0.16 # chinese g2p deps