Speed monitor refactor #1987

Merged: 17 commits, Feb 24, 2023
270 changes: 218 additions & 52 deletions composer/callbacks/speed_monitor.py
@@ -4,24 +4,144 @@
"""Monitor throughput during training."""
from __future__ import annotations

import warnings
from collections import deque
from typing import Any, Deque, Dict
from typing import Any, Callable, Deque, Dict, Optional, Union

import torch

from composer.core import Callback, State
from composer.loggers import Logger
from composer.models.base import ComposerModel
from composer.utils import dist

__all__ = ['SpeedMonitor']

GPU_AVAILABLE_FLOPS = {
# source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
# nvidia publishes spec sheet with a 2x sparsity factor
'h100-sxm': {
'fp64': 67e12,
'fp32': 67e12,
'tf32': 989e12 / 2,
'fp16': 1.979e15 / 2,
'amp_fp16': 1.979e15 / 2,
'bf16': 1.979e15 / 2,
'amp_bf16': 1.979e15 / 2,
'fp8': 3.958e15 / 2,
'amp_fp8': 3.958e15 / 2,
'int8': 3.958e15 / 2,
},
'h100-pcie': {
'fp64': 51e12,
'fp32': 51e12,
'tf32': 756e12 / 2,
'fp16': 1.513e15 / 2,
'amp_fp16': 1.513e15 / 2,
'bf16': 1.513e15 / 2,
'amp_bf16': 1.513e15 / 2,
'fp8': 3.026e15 / 2,
'amp_fp8': 3.026e15 / 2,
'int8': 3.026e15 / 2,
},
# source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
# sxm and pcie have same flop counts
'a100': {
'fp64': 19.5e12,
'fp32': 19.5e12,
'tf32': 156e12,
'fp16': 312e12,
'amp_fp16': 312e12,
'bf16': 312e12,
'amp_bf16': 312e12,
},
# source: https://images.nvidia.com/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf
'v100-sxm': {
'fp64': 7.8e12,
'fp32': 15.7e12,
'fp16': 125e12,
'amp_fp16': 125e12,
},
'v100-pcie': {
'fp64': 7e12,
'fp32': 14e12,
'fp16': 112e12,
'amp_fp16': 112e12,
},
'v100s-pcie': {
'fp64': 8.2e12,
'fp32': 16.4e12,
'fp16': 130e12,
'amp_fp16': 130e12,
},
# source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf
# sxm and pcie have same flop counts
't4': {
'fp32': 8.1e12,
'fp16': 65e12,
'amp_fp16': 65e12,
'int8': 130e12,
'int4': 260e12,
},
}


def get_gpu_flops_available(state: State):
gpu_flops_available = None

# Return 0 if no CUDA device (e.g., when running with CPU only)
if not torch.cuda.is_available():
return 0

# torch.cuda.get_device_name() ex output: 'NVIDIA A100-SXM4-40GB'
device_name = torch.cuda.get_device_name().lower()
if 'h100-sxm' in device_name:
device_name = 'h100-sxm'
elif 'h100-pcie' in device_name:
device_name = 'h100-pcie'
elif 'a100' in device_name:
device_name = 'a100'
elif 'v100-sxm' in device_name:
device_name = 'v100-sxm'
elif 'v100-pcie' in device_name:
device_name = 'v100-pcie'
elif 't4' in device_name:
device_name = 't4'
else:
device_name = None

if device_name is not None:
try:
gpu_flops_available = int(GPU_AVAILABLE_FLOPS[device_name][state.precision.value])
except:
gpu_flops_available = None

if gpu_flops_available is None:
warnings.warn(
f'gpu_flop count not found for {device_name} with precision: {state.precision.value}; ' +\
f'MFU cannot be calculated and reported. gpu_flops_available can be manually ' +\
f'overridden by setting gpu_flops_available in SpeedMonitor.'
)
# Setting to 0 will disable MFU computation and prevent
# the speed monitor from running this helper every batch
gpu_flops_available = 0

return gpu_flops_available
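
For intuition, a small worked sketch of how the peak numbers above feed the MFU metric reported below, assuming a hypothetical job that sustains 150e12 flops/sec per device on an A100 in amp_bf16:

peak_flops_per_device = 312e12       # A100 amp_bf16 entry from the table above
achieved_flops_per_sec = 150e12      # hypothetical measured per-device throughput
mfu = achieved_flops_per_sec / peak_flops_per_device
print(mfu)                           # ~0.48, logged as throughput/device/mfu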


class SpeedMonitor(Callback):
"""Logs the training throughput.

The training throughput in terms of number of samples per second is logged on the
:attr:`.Event.BATCH_END` event if we have reached the ``window_size`` threshold.
The training throughput is logged on the :attr:`.Event.BATCH_END` event once we have reached
the `window_size` threshold. If the model has a `flops_per_batch` attribute, then flops per second
is also logged. If running on a known GPU type or if `gpu_flops_available` is set, then MFU is
also logged. All metrics are also logged per device by dividing by world size.

The wall clock train time is logged on every :attr:`.Event.BATCH_END` event.
To compute `flops_per_sec`, the model attribute `flops_per_batch` should be set to a callable
which accepts a batch and returns the number of flops for that batch. Typically, this should
be flops per sample times the batch size unless pad tokens are used.

The average throughput over an epoch is logged on the :attr:`.Event.EPOCH_END` event.
The wall clock time is logged on every :attr:`.Event.BATCH_END` event.

Example:
.. doctest::
@@ -41,84 +161,130 @@ class SpeedMonitor(Callback):
The training throughput is logged by the :class:`.Logger` to the following keys as
described below.

+----------------------------------+-------------------------------------------------------------+
| Key | Logged data |
+==================================+=============================================================+
| | Rolling average (over ``window_size`` most recent |
| ``throughput/samples_per_sec`` | batches) of the number of samples processed per second |
| | |
+----------------------------------+-------------------------------------------------------------+
| ``wall_clock/train`` | Total elapsed training time |
+----------------------------------+-------------------------------------------------------------+
| ``wall_clock/val`` | Total elapsed validation time |
+----------------------------------+-------------------------------------------------------------+
| ``wall_clock/total`` | Total elapsed time (wall_clock/train + wall_clock/val) |
+----------------------------------+-------------------------------------------------------------+
+-------------------------------------+-----------------------------------------------------------+
| Key | Logged data |
+=====================================+===========================================================+
| | Rolling average (over `window_size` most recent |
| `throughput/batches_per_sec` | batches) of the number of batches processed per second |
| | |
+-------------------------------------+-----------------------------------------------------------+
| | Rolling average (over `window_size` most recent |
| `throughput/samples_per_sec` | batches) of the number of samples processed per second |
| | |
+-------------------------------------+-----------------------------------------------------------+
| | Rolling average (over `window_size` most recent |
| `throughput/tokens_per_sec` | batches) of the number of tokens processed per second. |
| | Only logged when dataloader.dataset has `max_seq_len`. |
| | This may include padding depending on dataset |
+-------------------------------------+-----------------------------------------------------------+
| | Estimates flops by `flops_per_batch * batches_per_sec` |
| `throughput/flops_per_sec` | if model has attribute `flops_per_batch` |
| | |
+-------------------------------------+-----------------------------------------------------------+
| `throughput/device/batches_per_sec` | `throughput/batches_per_sec` divided by world size |
+-------------------------------------+-----------------------------------------------------------+
| `throughput/device/samples_per_sec` | `throughput/samples_per_sec` divided by world size |
+-------------------------------------+-----------------------------------------------------------+
| | `throughput/tokens_per_sec` divided by world size. Only |
| `throughput/device/tokens_per_sec` | logged when dataloader.dataset has `max_seq_len`. This |
| | may include pad tokens depending on dataset |
+-------------------------------------+-----------------------------------------------------------+
| | `throughput/flops_per_sec` divided by world size. Only |
| `throughput/device/flops_per_sec` | logged when model has attribute `flops_per_batch` |
| | |
+-------------------------------------+-----------------------------------------------------------+
| | `throughput/device/flops_per_sec` divided by world size. |
| `throughput/device/mfu` | Only logged when model has attribute `flops_per_batch` |
| | and `gpu_flops_available`, which can be passed as an |
| | argument if not automatically determined by SpeedMonitor |
+-------------------------------------+-----------------------------------------------------------+
| `wall_clock/train` | Total elapsed training time |
+-------------------------------------+-----------------------------------------------------------+
| `wall_clock/val` | Total elapsed validation time |
+-------------------------------------+-----------------------------------------------------------+
| `wall_clock/total` | Total elapsed time (wall_clock/train + wall_clock/val) |
+-------------------------------------+-----------------------------------------------------------+

Args:
window_size (int, optional): Number of batches to use for a rolling average of throughput.
Defaults to 100.
"""

def __init__(self, window_size: int = 100):
def __init__(self, window_size: int = 100, gpu_flops_available: Optional[Union[float, int]] = None):
# Track the batch num samples and wct to compute throughput over a window of batches
self.batch_start_num_samples = 0
self.batch_start_wct = 0.0
self.batch_wct_buffer: Deque[float] = deque(maxlen=window_size)
self.batch_num_samples_buffer: Deque[int] = deque(maxlen=window_size)
self.window_size = window_size
self.history_samples: Deque[int] = deque(maxlen=window_size + 1)
self.history_wct: Deque[float] = deque(maxlen=window_size + 1)

self.gpu_flops_available = gpu_flops_available

# Keep track of time spent evaluating
self.total_eval_wct = 0.0

def state_dict(self) -> Dict[str, Any]:
return {
'batch_start_num_samples': self.batch_start_num_samples,
'batch_start_wct': self.batch_start_wct,
'batch_wct_buffer': self.batch_wct_buffer,
'batch_num_samples_buffer': self.batch_num_samples_buffer,
# "window_wct": self.window_wct,
# "window_num_samples": self.window_num_samples,
'total_eval_wct': self.total_eval_wct,
}

def load_state_dict(self, state: Dict[str, Any]) -> None:
self.batch_start_num_samples = state['batch_start_num_samples']
self.batch_start_wct = state['batch_start_wct']
self.batch_wct_buffer = deque(
[x for x in state['batch_wct_buffer']],
maxlen=self.window_size,
)
self.batch_num_samples_buffer = deque(
[x for x in state['batch_num_samples_buffer']],
maxlen=self.window_size,
)
self.total_eval_wct = state['total_eval_wct']

def before_dataloader(self, state: State, logger: Logger) -> None:
def init(self, state: State, logger: Logger) -> None:
del logger # unused
self.batch_start_wct = state.timestamp.total_wct.total_seconds()
self.batch_start_num_samples = int(state.timestamp.sample)
if self.gpu_flops_available is None:
self.gpu_flops_available = get_gpu_flops_available(state)

def batch_end(self, state: State, logger: Logger):
batch_num_samples = int(state.timestamp.sample) - self.batch_start_num_samples
batch_wct = state.timestamp.total_wct.total_seconds() - self.batch_start_wct

# Add the new element
self.batch_wct_buffer.append(batch_wct)
self.batch_num_samples_buffer.append(batch_num_samples)
self.history_samples.append(state.timestamp.sample.value)
self.history_wct.append(state.timestamp.total_wct.total_seconds())

# Log the throughput
if len(self.batch_num_samples_buffer) == self.window_size:
throughput = sum(self.batch_num_samples_buffer) / sum(self.batch_wct_buffer)
logger.log_metrics({'throughput/samples_per_sec': throughput})
if len(self.history_wct) == self.history_wct.maxlen:
world_size = dist.get_world_size()
elapsed_batches = len(self.history_samples) - 1
elapsed_samples = int(self.history_samples[-1]) - int(self.history_samples[0])
elapsed_wct = self.history_wct[-1] - self.history_wct[0]
batches_per_sec = elapsed_batches / elapsed_wct
samples_per_sec = elapsed_samples / elapsed_wct
dev_batches_per_sec = batches_per_sec / world_size
dev_samples_per_sec = samples_per_sec / world_size
logger.log_metrics({'throughput/batches_per_sec': batches_per_sec})
logger.log_metrics({'throughput/samples_per_sec': samples_per_sec})
logger.log_metrics({'throughput/device/batches_per_sec': dev_batches_per_sec})
logger.log_metrics({'throughput/device/samples_per_sec': dev_samples_per_sec})

# Compute token stats if dataloader.dataset has max_seq_len. Assumes no padding.
try:
max_seq_len = state.dataloader.dataset.max_seq_len # type: ignore
# Only applicable to seq data / models
logger.log_metrics({'throughput/tokens_per_sec': samples_per_sec * max_seq_len})
logger.log_metrics({'throughput/device/tokens_per_sec': dev_samples_per_sec * max_seq_len})
except AttributeError:
pass

composer_model = state.model
if not isinstance(composer_model, ComposerModel):
composer_model = composer_model.module # Pass through DDP wrapping
if hasattr(composer_model, 'flops_per_batch'):
model_flops_per_batch = composer_model.flops_per_batch # type: ignore
if not isinstance(model_flops_per_batch, Callable):
raise TypeError('flops_per_batch must be a callable accepting a batch and '
f'returning an int or float. Instead, got {type(model_flops_per_batch)}.')
flops_per_batch = model_flops_per_batch(state.batch)
flops_per_sec = flops_per_batch * batches_per_sec
logger.log_metrics({'throughput/flops_per_sec': flops_per_sec})
dev_flops_per_sec = flops_per_sec / world_size
logger.log_metrics({'throughput/device/flops_per_sec': dev_flops_per_sec})
if self.gpu_flops_available:
mfu = dev_flops_per_sec / self.gpu_flops_available
logger.log_metrics({'throughput/device/mfu': mfu})

# Log the time
# `state.timestamp` excludes any time spent in evaluation
logger.log_metrics({
'wall_clock/train': state.timestamp.total_wct.total_seconds(),
'wall_clock/val': self.total_eval_wct,
'wall_clock/total': (state.timestamp.total_wct.total_seconds() + self.total_eval_wct),
'wall_clock/total': state.timestamp.total_wct.total_seconds() + self.total_eval_wct,
})

def eval_end(self, state: State, logger: Logger):
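The docstring above notes that `flops_per_sec` and MFU are only logged when the model exposes a `flops_per_batch` callable. A minimal sketch of wiring that up, where the class, the single Linear layer, and the rough 6 * parameters * samples flop estimate are illustrative assumptions rather than anything prescribed by this PR:

import torch
from composer.models.base import ComposerModel


class TinyClassifier(ComposerModel):
    """Toy model that exposes flops_per_batch so SpeedMonitor can log flops and MFU."""

    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(784, 10)
        self.n_params = sum(p.numel() for p in self.parameters())

    def forward(self, batch):
        inputs, _ = batch
        return self.net(inputs)

    def loss(self, outputs, batch):
        _, targets = batch
        return torch.nn.functional.cross_entropy(outputs, targets)

    def flops_per_batch(self, batch):
        # Rough fwd+bwd estimate: ~6 * parameter count * samples in the batch.
        inputs, _ = batch
        return 6 * self.n_params * inputs.shape[0]


# Attach the callback when building the trainer, e.g.:
#   trainer = Trainer(model=TinyClassifier(), ..., callbacks=[SpeedMonitor(window_size=100)])
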
10 changes: 8 additions & 2 deletions composer/trainer/trainer.py
@@ -2104,7 +2104,8 @@ def _train_batch(self, use_grad_scaling: bool) -> Dict[str, torch.Tensor]:
"""
assert self._train_data_spec is not None, 'The train data spec should be set on __init__ or fit()'

# Cache the device batch, because `self.state.batch` gets overridden in microbatching loop
# Cache the device batch, because `self.state.batch` gets overridden in microbatching loop.
# Any in-place changes to a microbatch will be reflected in the device batch.
device_batch = self.state.batch

# Retry until we successfully complete training and return loss
@@ -2212,8 +2213,10 @@ def _train_microbatches(self,
except TypeError:
optimizer.zero_grad()

# tracker for gradient accumulation
# Tracker for gradient accumulation
current_batch_size = sum([self._train_data_spec.get_num_samples_in_batch(batch) for batch in microbatches])
# Cache batch, which will be overwritten by microbatches. Restore after microbatches complete
current_batch = self.state.batch

for microbatch_idx, self.state.batch in enumerate(microbatches):
is_final_microbatch = microbatch_idx + 1 == len(microbatches)
@@ -2226,6 +2229,9 @@
total_loss_dict[loss_key] = self.state.device.tensor_to_device(torch.zeros(size=(1,)))
total_loss_dict[loss_key] += microbatch_loss

# Restore batch
self.state.batch = current_batch

# Unscale gradients before `Event.AFTER_TRAIN_BATCH`
if use_grad_scaling:
for optimizer in ensure_tuple(self.state.optimizers):
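
A standalone sketch of the cache-and-restore pattern the trainer change above introduces (the names here are illustrative, not Composer's API): the loop variable overwrites state.batch on every iteration, so the full device batch is saved before the microbatch loop and put back once it finishes.

class _State:
    """Stand-in for the trainer state; only the batch attribute matters here."""

    def __init__(self, batch):
        self.batch = batch


def train_microbatches(state, microbatches, step):
    current_batch = state.batch          # cache the full device batch
    for microbatch in microbatches:
        state.batch = microbatch         # overwritten on every iteration
        step(state)                      # callbacks see only the current microbatch
    state.batch = current_batch          # restore so later events see the full batch


full_batch = list(range(8))
state = _State(full_batch)
train_microbatches(state, [full_batch[:4], full_batch[4:]], step=lambda s: None)
assert state.batch == full_batch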