Runtime estimator #1991

Merged

35 commits
797c50b
checkdown
mvpatel2000 Feb 8, 2023
4a1b452
Merge branch 'dev' into mvpatel2000/bottom-up-runtime-estimator
mvpatel2000 Feb 22, 2023
10ad629
checkdown
mvpatel2000 Feb 22, 2023
cb35e15
add runtime estimator
mvpatel2000 Feb 22, 2023
959dbeb
exprot
mvpatel2000 Feb 22, 2023
f8aeaa1
add prints
mvpatel2000 Feb 22, 2023
a884447
tweak logs
mvpatel2000 Feb 22, 2023
3a958a1
fit start
mvpatel2000 Feb 22, 2023
eb456e4
fix start time
mvpatel2000 Feb 22, 2023
db58030
update
mvpatel2000 Feb 22, 2023
6b80706
update guards
mvpatel2000 Feb 22, 2023
947130a
add eval adjustment
mvpatel2000 Feb 22, 2023
2a04051
update comments
mvpatel2000 Feb 22, 2023
93aa9a7
revert speed monitor changes
mvpatel2000 Feb 22, 2023
40ffdc2
Merge branch 'dev' into mvpatel2000/bottom-up-runtime-estimator
mvpatel2000 Feb 22, 2023
e4f6d1a
add logs
mvpatel2000 Feb 22, 2023
69b5515
Merge branch 'mvpatel2000/bottom-up-runtime-estimator' of github.com:…
mvpatel2000 Feb 22, 2023
4c6eaee
add more logs
mvpatel2000 Feb 22, 2023
3bd132d
move timestamp advance
mvpatel2000 Feb 22, 2023
b9321ab
revert
mvpatel2000 Feb 23, 2023
b86e6f9
Merge branch 'dev' into mvpatel2000/bottom-up-runtime-estimator
mvpatel2000 Feb 23, 2023
eaca75c
simplify ghost batchnorm
mvpatel2000 Feb 23, 2023
ce6abad
scale down image sizes
mvpatel2000 Feb 23, 2023
e7e839d
Merge branch 'dev' into mvpatel2000/bottom-up-runtime-estimator
mvpatel2000 Feb 23, 2023
a0c6d4b
tweak tests
mvpatel2000 Feb 24, 2023
a9e1187
add norms
mvpatel2000 Feb 24, 2023
dfdd866
add norms
mvpatel2000 Feb 24, 2023
4f682a9
reset
mvpatel2000 Feb 24, 2023
2c2645e
make none
mvpatel2000 Feb 24, 2023
e4d92cb
fix change
mvpatel2000 Feb 24, 2023
b85d2f1
add warning
mvpatel2000 Feb 24, 2023
d0e07df
update filter
mvpatel2000 Feb 24, 2023
2369238
Merge branch 'dev' into mvpatel2000/bottom-up-runtime-estimator
mvpatel2000 Feb 24, 2023
5715e53
respond to comments
mvpatel2000 Feb 24, 2023
1b952a9
update ignore warnings
mvpatel2000 Feb 24, 2023
2 changes: 2 additions & 0 deletions composer/callbacks/__init__.py
@@ -14,6 +14,7 @@
from composer.callbacks.memory_monitor import MemoryMonitor
from composer.callbacks.mlperf import MLPerfCallback
from composer.callbacks.optimizer_monitor import OptimizerMonitor
+from composer.callbacks.runtime_estimator import RuntimeEstimator
from composer.callbacks.speed_monitor import SpeedMonitor
from composer.callbacks.threshold_stopper import ThresholdStopper

@@ -28,4 +29,5 @@
'ExportForInferenceCallback',
'ThresholdStopper',
'ImageVisualizer',
+    'RuntimeEstimator',
]
159 changes: 159 additions & 0 deletions composer/callbacks/runtime_estimator.py
@@ -0,0 +1,159 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Estimate total time of training."""
from __future__ import annotations

import time
import warnings
from typing import Any, Dict, List, Optional

from composer.core import Callback, State, TimeUnit
from composer.loggers import Logger

__all__ = ['RuntimeEstimator']


class RuntimeEstimator(Callback):
"""Estimates total training time.

    The total training time is estimated by taking the time elapsed over the duration completed
    so far and extrapolating out to the full duration of the training run.

    This callback provides a best-effort estimate. The estimate may be inaccurate if throughput
    changes during training or if other significant changes are made to the model or dataloader.

Example:
.. doctest::

>>> from composer import Trainer
>>> from composer.callbacks import RuntimeEstimator
>>> # constructing trainer object with this callback
>>> trainer = Trainer(
... model=model,
... train_dataloader=train_dataloader,
... eval_dataloader=eval_dataloader,
... optimizers=optimizer,
... max_duration='1ep',
... callbacks=[RuntimeEstimator()],
... )

The runtime estimate is logged by the :class:`.Logger` to the following key as described below.

+-----------------------------------+---------------------------------------------------------+
| Key | Logged data |
+===================================+=========================================================+
| `wall_clock/remaining_estimate` | Estimated time to completion |
+-----------------------------------+---------------------------------------------------------+

Args:
        skip_batches (int, optional): Number of batches to skip before starting the clock used to
            estimate remaining time. Typically, the first few batches are slower due to dataloader
            warmup, cache warming, and other startup costs. Defaults to 1.
"""

def __init__(self, skip_batches: int = 1) -> None:
self._enabled = True
self.batches_left_to_skip = skip_batches
self.start_time = None
self.start_dur = None

# Keep track of time spent evaluating
self.total_eval_wct = 0.0
self.eval_wct_per_label: Dict[str, List[float]] = {}
# How often eval is called as fraction of total training time
self.eval_frequency_per_label: Dict[str, float] = {}
self.last_elapsed_fraction: float = 0.0

def state_dict(self) -> Dict[str, Any]:
return {
'total_eval_wct': self.total_eval_wct,
'eval_wct_per_label': self.eval_wct_per_label,
'eval_frequency_per_label': self.eval_frequency_per_label,
'last_elapsed_fraction': self.last_elapsed_fraction,
}

def load_state_dict(self, state: Dict[str, Any]) -> None:
self.total_eval_wct = state['total_eval_wct']
self.eval_wct_per_label = state['eval_wct_per_label']
self.eval_frequency_per_label = state['eval_frequency_per_label']
self.last_elapsed_fraction = state['last_elapsed_fraction']

def get_elapsed_duration(self, state: State) -> Optional[float]:
"""Get the elapsed duration.

        Unlike `state.get_elapsed_duration`, this method computes fractional progress within an
        epoch (once at least one epoch has completed) by using the observed number of batches per
        epoch.
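
        For example, at batch 840 of a 10-epoch run whose first 2 epochs took 700 batches
        total (350 per epoch), the elapsed fraction is 840 / (10 * 350) = 0.24.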
"""
if state.max_duration is None:
return None
if state.max_duration.unit == TimeUnit('ep'):
if state.timestamp.epoch.value >= 1:
batches_per_epoch = (state.timestamp.batch -
state.timestamp.batch_in_epoch).value / state.timestamp.epoch.value
return state.timestamp.get('ba').value / (state.max_duration.value * batches_per_epoch)
elif state.dataloader_len is not None:
return state.timestamp.get('ba').value / (state.max_duration.value * state.dataloader_len.value)
elapsed_dur = state.get_elapsed_duration()
if elapsed_dur is not None:
return elapsed_dur.value
return None

def batch_start(self, state: State, logger: Logger) -> None:
if self._enabled and self.start_time is None and self.batches_left_to_skip == 0:
self.start_time = time.time()
self.start_dur = self.get_elapsed_duration(state)
if self.start_dur is None:
warnings.warn('`max_duration` is not set. Cannot estimate remaining time.')
self._enabled = False

def batch_end(self, state: State, logger: Logger) -> None:
if not self._enabled:
return
if self.batches_left_to_skip > 0:
self.batches_left_to_skip -= 1
return

elapsed_dur = self.get_elapsed_duration(state)
assert elapsed_dur is not None, 'max_duration checked as non-None on batch_start'

assert self.start_dur is not None
assert self.start_time is not None
if elapsed_dur > self.start_dur:
elapsed_time = time.time() - self.start_time
elapsed_time -= self.total_eval_wct # Subtract time spent evaluating
rate = elapsed_time / (elapsed_dur - self.start_dur)
remaining_time = rate * (1 - elapsed_dur)
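            # Example: with start_dur = 0.05, elapsed_dur = 0.25, and 600s of measured train
            # time, rate = 600 / 0.20 = 3000s per unit of progress, so 3000 * 0.75 = 2250s remain.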

            # Add remaining time from each evaluator using known frequencies. We explicitly compute
            # frequency instead of using time interpolation to avoid a sawtooth pattern in the estimates
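            # Example: an evaluator with eval_rate = 0.1 is expected to run 10 times in total;
            # if 3 calls averaging 30s have finished, the loop below adds (10 - 3) * 30 = 210s.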
for dataloader_label, eval_wcts in self.eval_wct_per_label.items():
# Discard first eval_wct if possible as it is often slower due to dataset downloading
eval_wct_avg = None
num_evals_finished = len(eval_wcts)
if num_evals_finished > 1:
eval_wct_avg = sum(eval_wcts[1:]) / (num_evals_finished - 1)
else:
eval_wct_avg = sum(eval_wcts) / num_evals_finished
eval_rate = self.eval_frequency_per_label[dataloader_label]
num_total_evals = 1 / eval_rate
remaining_calls = num_total_evals - num_evals_finished
remaining_time += eval_wct_avg * remaining_calls

logger.log_metrics({'wall_clock/remaining_estimate': remaining_time})

def eval_end(self, state: State, logger: Logger) -> None:
# If eval is called before training starts, ignore it
if not self._enabled or self.start_time is None:
return
self.total_eval_wct += state.eval_timestamp.total_wct.total_seconds()
# state.dataloader_label should always be non-None unless user explicitly sets evaluator
# label to None, ignoring type hints
assert state.dataloader_label is not None, 'evaluator label must not be None'
if state.dataloader_label not in self.eval_wct_per_label:
self.eval_wct_per_label[state.dataloader_label] = []
self.eval_wct_per_label[state.dataloader_label].append(state.eval_timestamp.total_wct.total_seconds())
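        # Infer how often this evaluator runs as a fraction of training, e.g. 2 completed evals
        # at an elapsed fraction of 0.4 imply one eval per 0.2 of training.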
elapsed_fraction = self.get_elapsed_duration(state)
assert elapsed_fraction is not None, 'max_duration checked as non-None on batch_start'
num_evals_finished = len(self.eval_wct_per_label[state.dataloader_label])
self.eval_frequency_per_label[state.dataloader_label] = elapsed_fraction / num_evals_finished
4 changes: 2 additions & 2 deletions composer/trainer/trainer.py
@@ -2022,8 +2022,6 @@ def _train_loop(self) -> None:
# This happens if the "break" did not trigger above, or if it
# did (e.g. duration specified in samples/batches/tokens), but it is still
# the end of the dataloader (i.e. next(dataloader) would raise StopIteration)
-                self.state.timestamp = self.state.timestamp.to_next_epoch()

if self.state.train_metrics is not None:
self._compute_and_log_metrics(
dataloader_label='train',
@@ -2034,6 +2032,8 @@
for scheduler in self.state.schedulers:
scheduler.step()

+                self.state.timestamp = self.state.timestamp.to_next_epoch()

self.engine.run_event(Event.EPOCH_END)

# Pause the timing during evaluation
8 changes: 7 additions & 1 deletion pyproject.toml
@@ -65,7 +65,7 @@ reportUnusedCoroutine = "error"
addopts = "--codeblocks --strict-markers -m 'not gpu and not vision and not doctest and not daily and not remote'"

markers = [
-    # !!!!!!!!!!!IMPORTANT!!!!!!!!!: when updating the markers, also make sure to update .ci/Jenkinsfile and meta.yaml
+    # !!!!!!!!!!!IMPORTANT!!!!!!!!!: when updating the markers, also make sure to update meta.yaml
# Tests that require a world_size of two should be annotated with `@pytest.mark.world_size(2)`.
# If not specified, the test will be assumed to have a world-size of one, which is
# equivalent to `@pytest.mark.world_size(1)`
@@ -118,6 +118,12 @@ filterwarnings = [
'ignore:Torchmetrics v0.9 introduced a new argument class property:UserWarning',
'ignore:torch.distributed._all_gather_base is a private function and will be deprecated:UserWarning',
'ignore:torch.distributed._reduce_scatter_base is a private function and will be deprecated:UserWarning',
+    # Ignore tensorboard deprecation warnings
+    'ignore:Call to deprecated create function Descriptor().*:DeprecationWarning:tensorboard',
+    'ignore:Call to deprecated create function EnumDescriptor().*:DeprecationWarning:tensorboard',
+    'ignore:Call to deprecated create function EnumValueDescriptor().*:DeprecationWarning:tensorboard',
+    'ignore:Call to deprecated create function FieldDescriptor().*:DeprecationWarning:tensorboard',
+    'ignore:Call to deprecated create function FileDescriptor().*:DeprecationWarning:tensorboard',
]

# Coverage
9 changes: 3 additions & 6 deletions tests/algorithms/algorithm_settings.py
@@ -114,13 +114,10 @@
FusedLayerNorm: simple_bert_settings,
GatedLinearUnits: simple_bert_settings,
GhostBatchNorm: {
-        'model': (composer_resnet, {
-            'model_name': 'resnet18',
-            'num_classes': 2
-        }),
-        'dataset': (RandomImageDataset, {
-            'shape': (3, 224, 224)
+        'model': (SimpleConvModel, {
+            'norm': 'group',
         }),
+        'dataset': RandomImageDataset,
'kwargs': {
'ghost_batch_size': 2,
}
1 change: 1 addition & 0 deletions tests/algorithms/test_torch_export.py
@@ -139,6 +139,7 @@ def test_surgery_torchfx_eval(
@pytest.mark.parametrize('alg_cls', torchscript_algs_with_marks)
@pytest.mark.filterwarnings(
r'ignore:Converting a tensor to a Python .* might cause the trace to be incorrect:torch.jit._trace.TracerWarning')
+@pytest.mark.filterwarnings('ignore:__floordiv__ is deprecated')
def test_surgery_onnx(
input: Any,
alg_cls: Type[Algorithm],
Expand Down
18 changes: 16 additions & 2 deletions tests/common/models.py
@@ -4,7 +4,7 @@
"""Contains commonly used models that are shared across the test suite."""
import copy
from functools import partial
-from typing import Any, Dict, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union

import pytest
import torch
@@ -146,14 +146,27 @@ class SimpleConvModel(ComposerClassifier):
num_classes (int): number of classes (default: 2)
"""

-    def __init__(self, num_channels: int = 3, num_classes: int = 2) -> None:
+    def __init__(self, num_channels: int = 3, num_classes: int = 2, norm: Optional[str] = None) -> None:

self.num_classes = num_classes
self.num_channels = num_channels

conv_args = {'kernel_size': (3, 3), 'padding': 1, 'stride': 2}
conv1 = torch.nn.Conv2d(in_channels=num_channels, out_channels=8, **conv_args)
conv2 = torch.nn.Conv2d(in_channels=8, out_channels=4, **conv_args)
+        norm_layer = None
+        if norm is None:
+            norm_layer = torch.nn.Identity()
+        elif norm == 'batch':
+            norm_layer = torch.nn.BatchNorm2d(4)
+        elif norm == 'instance':
+            norm_layer = torch.nn.InstanceNorm2d(4)
+        elif norm == 'layer':
+            norm_layer = torch.nn.LayerNorm(4)
+        elif norm == 'group':
+            norm_layer = torch.nn.GroupNorm(2, 4)
+        else:
+            raise ValueError(f'Unknown norm: {norm}')
pool = torch.nn.AdaptiveAvgPool2d(1)
flatten = torch.nn.Flatten()
fc1 = torch.nn.Linear(4, 16)
@@ -162,6 +175,7 @@ def __init__(self, num_channels: int = 3, num_classes: int = 2) -> None:
net = torch.nn.Sequential(
conv1,
conv2,
+            norm_layer,
pool,
flatten,
fc1,