diff --git a/composer/callbacks/runtime_estimator.py b/composer/callbacks/runtime_estimator.py index 7fcfeaac81..8833c8c8a2 100644 --- a/composer/callbacks/runtime_estimator.py +++ b/composer/callbacks/runtime_estimator.py @@ -6,7 +6,7 @@ import time import warnings -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional from composer.core import Callback, State, TimeUnit from composer.loggers import Logger @@ -80,20 +80,6 @@ def __init__(self, skip_batches: int = 1, time_unit: str = 'hours') -> None: self.eval_frequency_per_label: Dict[str, float] = {} self.last_elapsed_fraction: float = 0.0 - def state_dict(self) -> Dict[str, Any]: - return { - 'total_eval_wct': self.total_eval_wct, - 'eval_wct_per_label': self.eval_wct_per_label, - 'eval_frequency_per_label': self.eval_frequency_per_label, - 'last_elapsed_fraction': self.last_elapsed_fraction, - } - - def load_state_dict(self, state: Dict[str, Any]) -> None: - self.total_eval_wct = state['total_eval_wct'] - self.eval_wct_per_label = state['eval_wct_per_label'] - self.eval_frequency_per_label = state['eval_frequency_per_label'] - self.last_elapsed_fraction = state['last_elapsed_fraction'] - def _get_elapsed_duration(self, state: State) -> Optional[float]: """Get the elapsed duration. @@ -130,7 +116,7 @@ def batch_end(self, state: State, logger: Logger) -> None: return elapsed_dur = self._get_elapsed_duration(state) - assert elapsed_dur is not None, 'max_duration checked as non-None on batch_start' + assert elapsed_dur is not None, 'max_duration checked as non-None on batch_start if enabled' assert self.start_dur is not None assert self.start_time is not None @@ -168,7 +154,9 @@ def eval_end(self, state: State, logger: Logger) -> None: if state.dataloader_label not in self.eval_wct_per_label: self.eval_wct_per_label[state.dataloader_label] = [] self.eval_wct_per_label[state.dataloader_label].append(state.eval_timestamp.total_wct.total_seconds()) - elapsed_fraction = self._get_elapsed_duration(state) - assert elapsed_fraction is not None, 'max_duration checked as non-None on batch_start' + elapsed_dur = self._get_elapsed_duration(state) + assert elapsed_dur is not None, 'max_duration checked as non-None on batch_start if enabled' + assert self.start_dur is not None, 'start_dur is set on batch_start if enabled' + elapsed_fraction = elapsed_dur - self.start_dur num_evals_finished = len(self.eval_wct_per_label[state.dataloader_label]) self.eval_frequency_per_label[state.dataloader_label] = elapsed_fraction / num_evals_finished