From d1aad756bb949e03565d28c1146bf1122aa4e897 Mon Sep 17 00:00:00 2001 From: hungcs Date: Tue, 12 Sep 2023 16:02:13 -0700 Subject: [PATCH] Store steps_per_epoch in Trainer --- ludwig/trainers/trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ludwig/trainers/trainer.py b/ludwig/trainers/trainer.py index ff770b6db8c..6261cb2a845 100644 --- a/ludwig/trainers/trainer.py +++ b/ludwig/trainers/trainer.py @@ -135,6 +135,7 @@ def __init__( self.epochs = config.epochs self.train_steps = config.train_steps + self.steps_per_epoch = 0 # Computed during training, after batcher has been initialized. self.total_steps = 0 # Computed during training, after batcher has been initialized. self.regularization_lambda = config.regularization_lambda @@ -755,6 +756,7 @@ def train( augmentation_pipeline=self.model.get_augmentation_pipelines(), ) as batcher: # ================ Training Loop ================ + self.steps_per_epoch = batcher.steps_per_epoch self.total_steps = get_total_steps(self.epochs, batcher.steps_per_epoch, self.train_steps) # Get the terminal steps per checkpoint.