diff --git a/imaginairy/utils/log_utils.py b/imaginairy/utils/log_utils.py index f294ca0d..0b55cb39 100644 --- a/imaginairy/utils/log_utils.py +++ b/imaginairy/utils/log_utils.py @@ -179,6 +179,9 @@ def get_performance_stats(self) -> dict[str, dict[str, float]]: self.summary_context.stop() self.timing_contexts["total"] = self.summary_context + # move total to the end + self.timing_contexts["total"] = self.timing_contexts.pop("total") + if torch.cuda.is_available(): self.summary_context.memory_peak = max( max(context.memory_peak, context.memory_start, context.memory_end)