Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove runtime estimator state dict #2015

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 6 additions & 18 deletions composer/callbacks/runtime_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import time
import warnings
from typing import Any, Dict, List, Optional
from typing import Dict, List, Optional

from composer.core import Callback, State, TimeUnit
from composer.loggers import Logger
Expand Down Expand Up @@ -80,20 +80,6 @@ def __init__(self, skip_batches: int = 1, time_unit: str = 'hours') -> None:
self.eval_frequency_per_label: Dict[str, float] = {}
self.last_elapsed_fraction: float = 0.0

def state_dict(self) -> Dict[str, Any]:
return {
'total_eval_wct': self.total_eval_wct,
'eval_wct_per_label': self.eval_wct_per_label,
'eval_frequency_per_label': self.eval_frequency_per_label,
'last_elapsed_fraction': self.last_elapsed_fraction,
}

def load_state_dict(self, state: Dict[str, Any]) -> None:
self.total_eval_wct = state['total_eval_wct']
self.eval_wct_per_label = state['eval_wct_per_label']
self.eval_frequency_per_label = state['eval_frequency_per_label']
self.last_elapsed_fraction = state['last_elapsed_fraction']

def _get_elapsed_duration(self, state: State) -> Optional[float]:
"""Get the elapsed duration.

Expand Down Expand Up @@ -130,7 +116,7 @@ def batch_end(self, state: State, logger: Logger) -> None:
return

elapsed_dur = self._get_elapsed_duration(state)
assert elapsed_dur is not None, 'max_duration checked as non-None on batch_start'
assert elapsed_dur is not None, 'max_duration checked as non-None on batch_start if enabled'

assert self.start_dur is not None
assert self.start_time is not None
Expand Down Expand Up @@ -168,7 +154,9 @@ def eval_end(self, state: State, logger: Logger) -> None:
if state.dataloader_label not in self.eval_wct_per_label:
self.eval_wct_per_label[state.dataloader_label] = []
self.eval_wct_per_label[state.dataloader_label].append(state.eval_timestamp.total_wct.total_seconds())
elapsed_fraction = self._get_elapsed_duration(state)
assert elapsed_fraction is not None, 'max_duration checked as non-None on batch_start'
elapsed_dur = self._get_elapsed_duration(state)
assert elapsed_dur is not None, 'max_duration checked as non-None on batch_start if enabled'
assert self.start_dur is not None, 'start_dur is set on batch_start if enabled'
elapsed_fraction = elapsed_dur - self.start_dur
num_evals_finished = len(self.eval_wct_per_label[state.dataloader_label])
self.eval_frequency_per_label[state.dataloader_label] = elapsed_fraction / num_evals_finished