diff --git a/composer/callbacks/speed_monitor.py b/composer/callbacks/speed_monitor.py
index e7a3c8e517..0b23521e97 100644
--- a/composer/callbacks/speed_monitor.py
+++ b/composer/callbacks/speed_monitor.py
@@ -4,24 +4,144 @@
 """Monitor throughput during training."""
 
 from __future__ import annotations
 
+import warnings
 from collections import deque
-from typing import Any, Deque, Dict
+from typing import Any, Callable, Deque, Dict, Optional, Union
+
+import torch
 
 from composer.core import Callback, State
 from composer.loggers import Logger
+from composer.models.base import ComposerModel
+from composer.utils import dist
 
 __all__ = ['SpeedMonitor']
 
+GPU_AVAILABLE_FLOPS = {
+    # source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
+    # nvidia publishes spec sheet with a 2x sparsity factor
+    'h100-sxm': {
+        'fp64': 67e12,
+        'fp32': 67e12,
+        'tf32': 989e12 / 2,
+        'fp16': 1.979e15 / 2,
+        'amp_fp16': 1.979e15 / 2,
+        'bf16': 1.979e15 / 2,
+        'amp_bf16': 1.979e15 / 2,
+        'fp8': 3.958e15 / 2,
+        'amp_fp8': 3.958e15 / 2,
+        'int8': 3.958e15 / 2,
+    },
+    'h100-pcie': {
+        'fp64': 51e12,
+        'fp32': 51e12,
+        'tf32': 756e12 / 2,
+        'fp16': 1.513e15 / 2,
+        'amp_fp16': 1.513e15 / 2,
+        'bf16': 1.513e15 / 2,
+        'amp_bf16': 1.513e15 / 2,
+        'fp8': 3.026e15 / 2,
+        'amp_fp8': 3.026e15 / 2,
+        'int8': 3.026e15 / 2,
+    },
+    # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
+    # sxm and pcie have same flop counts
+    'a100': {
+        'fp64': 19.5e12,
+        'fp32': 19.5e12,
+        'tf32': 156e12,
+        'fp16': 312e12,
+        'amp_fp16': 312e12,
+        'bf16': 312e12,
+        'amp_bf16': 312e12,
+    },
+    # source: https://images.nvidia.com/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf
+    'v100-sxm': {
+        'fp64': 7.8e12,
+        'fp32': 15.7e12,
+        'fp16': 125e12,
+        'amp_fp16': 125e12,
+    },
+    'v100-pcie': {
+        'fp64': 7e12,
+        'fp32': 14e12,
+        'fp16': 112e12,
+        'amp_fp16': 112e12,
+    },
+    'v100s-pcie': {
+        'fp64': 8.2e12,
+        'fp32': 16.4e12,
+        'fp16': 130e12,
+        'amp_fp16': 130e12,
+    },
+    # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf
+    # sxm and pcie have same flop counts
+    't4': {
+        'fp32': 8.1e12,
+        'fp16': 65e12,
+        'amp_fp16': 65e12,
+        'int8': 130e12,
+        'int4': 260e12,
+    },
+}
+
+
+def get_gpu_flops_available(state: State):
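+    """Return peak flops per second for the current GPU at ``state.precision``, or 0 if unknown."""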
+    gpu_flops_available = None
+
+    # Return 0 if no CUDA device (e.g., when running with CPU only)
+    if not torch.cuda.is_available():
+        return 0
+
+    # torch.cuda.get_device_name() ex output: 'NVIDIA A100-SXM4-40GB'
+    device_name = torch.cuda.get_device_name().lower()
+    if 'h100-sxm' in device_name:
+        device_name = 'h100-sxm'
+    elif 'h100-pcie' in device_name:
+        device_name = 'h100-pcie'
+    elif 'a100' in device_name:
+        device_name = 'a100'
+    elif 'v100-sxm' in device_name:
+        device_name = 'v100-sxm'
+    elif 'v100-pcie' in device_name:
+        device_name = 'v100-pcie'
+    elif 't4' in device_name:
+        device_name = 't4'
+    else:
+        device_name = None
+
+    if device_name is not None:
+        try:
+            gpu_flops_available = int(GPU_AVAILABLE_FLOPS[device_name][state.precision.value])
+        except KeyError:
+            gpu_flops_available = None
+
+    if gpu_flops_available is None:
+        warnings.warn(
+            f'gpu_flop count not found for {device_name} with precision: {state.precision.value}; '
+            'MFU cannot be calculated and reported. gpu_flops_available can be manually '
+            'overridden by setting gpu_flops_available in SpeedMonitor.',
+        )
+        # Setting to 0 will disable MFU computation and prevent
+        # the speed monitor from running this helper every batch
+        gpu_flops_available = 0
+
+    return gpu_flops_available
+
 
 class SpeedMonitor(Callback):
     """Logs the training throughput.
 
-    The training throughput in terms of number of samples per second is logged on the
-    :attr:`.Event.BATCH_END` event if we have reached the ``window_size`` threshold.
+    The training throughput is logged on the :attr:`.Event.BATCH_END` event once we have reached
+    the ``window_size`` threshold. If the model has a ``flops_per_batch`` attribute, then flops
+    per second is also logged. If running on a known GPU type, or if ``gpu_flops_available`` is
+    set, then MFU is also logged. All throughput metrics are also logged per device by dividing
+    by the world size.
 
-    The wall clock train time is logged on every :attr:`.Event.BATCH_END` event.
+    To compute ``flops_per_sec``, the model attribute ``flops_per_batch`` should be set to a
+    callable which accepts a batch and returns the number of flops for that batch. Typically,
+    this should be flops per sample times the batch size, unless pad tokens are used.
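+
+    For example, for a decoder-only transformer, the common ~6 * parameters * tokens estimate
+    of combined forward and backward flops might look like this (``n_params`` is an assumed
+    parameter count, and the batch format is hypothetical):
+
+    .. code-block:: python
+
+        model.flops_per_batch = lambda batch: 6 * n_params * batch['input_ids'].numel()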
 
-    The average throughout over an epoch is logged on the :attr:`.Event.EPOCH_END` event.
+    The wall clock time is logged on every :attr:`.Event.BATCH_END` event.
 
     Example:
         .. doctest::
@@ -41,84 +161,130 @@ class SpeedMonitor(Callback):
 
     The training throughput is logged by the :class:`.Logger` to the following keys as described
     below.
 
-    +----------------------------------+-------------------------------------------------------------+
-    | Key                              | Logged data                                                 |
-    +==================================+=============================================================+
-    |                                  | Rolling average (over ``window_size`` most recent           |
-    | ``throughput/samples_per_sec``   | batches) of the number of samples processed per second      |
-    |                                  |                                                             |
-    +----------------------------------+-------------------------------------------------------------+
-    | ``wall_clock/train``             | Total elapsed training time                                 |
-    +----------------------------------+-------------------------------------------------------------+
-    | ``wall_clock/val``               | Total elapsed validation time                               |
-    +----------------------------------+-------------------------------------------------------------+
-    | ``wall_clock/total``             | Total elapsed time (wall_clock/train + wall_clock/val)      |
-    +----------------------------------+-------------------------------------------------------------+
+    +---------------------------------------+-----------------------------------------------------------+
+    | Key                                   | Logged data                                               |
+    +=======================================+===========================================================+
+    |                                       | Rolling average (over ``window_size`` most recent         |
+    | ``throughput/batches_per_sec``        | batches) of the number of batches processed per second    |
+    |                                       |                                                           |
+    +---------------------------------------+-----------------------------------------------------------+
+    |                                       | Rolling average (over ``window_size`` most recent         |
+    | ``throughput/samples_per_sec``        | batches) of the number of samples processed per second    |
+    |                                       |                                                           |
+    +---------------------------------------+-----------------------------------------------------------+
+    |                                       | Rolling average (over ``window_size`` most recent         |
+    | ``throughput/tokens_per_sec``         | batches) of the number of tokens processed per second.    |
+    |                                       | Only logged when dataloader.dataset has ``max_seq_len``.  |
+    |                                       | This may include padding depending on dataset             |
+    +---------------------------------------+-----------------------------------------------------------+
+    |                                       | Estimates flops by ``flops_per_batch * batches_per_sec``  |
+    | ``throughput/flops_per_sec``          | if model has attribute ``flops_per_batch``                |
+    |                                       |                                                           |
+    +---------------------------------------+-----------------------------------------------------------+
+    | ``throughput/device/batches_per_sec`` | ``throughput/batches_per_sec`` divided by world size      |
+    +---------------------------------------+-----------------------------------------------------------+
+    | ``throughput/device/samples_per_sec`` | ``throughput/samples_per_sec`` divided by world size      |
+    +---------------------------------------+-----------------------------------------------------------+
+    |                                       | ``throughput/tokens_per_sec`` divided by world size. Only |
+    | ``throughput/device/tokens_per_sec``  | logged when dataloader.dataset has ``max_seq_len``. This  |
+    |                                       | may include pad tokens depending on dataset               |
+    +---------------------------------------+-----------------------------------------------------------+
+    |                                       | ``throughput/flops_per_sec`` divided by world size. Only  |
+    | ``throughput/device/flops_per_sec``   | logged when model has attribute ``flops_per_batch``       |
+    |                                       |                                                           |
+    +---------------------------------------+-----------------------------------------------------------+
+    |                                       | ``throughput/device/flops_per_sec`` divided by            |
+    | ``throughput/device/mfu``             | ``gpu_flops_available``. Only logged when model has       |
+    |                                       | attribute ``flops_per_batch`` and ``gpu_flops_available``,|
+    |                                       | which can be passed as an argument if not automatically   |
+    |                                       | determined by SpeedMonitor                                |
+    +---------------------------------------+-----------------------------------------------------------+
+    | ``wall_clock/train``                  | Total elapsed training time                               |
+    +---------------------------------------+-----------------------------------------------------------+
+    | ``wall_clock/val``                    | Total elapsed validation time                             |
+    +---------------------------------------+-----------------------------------------------------------+
+    | ``wall_clock/total``                  | Total elapsed time (wall_clock/train + wall_clock/val)    |
+    +---------------------------------------+-----------------------------------------------------------+
 
     Args:
         window_size (int, optional): Number of batches to use for a rolling average of throughput.
             Defaults to 100.
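+        gpu_flops_available (float, optional): Peak flops per second available per GPU. If not
+            set, SpeedMonitor attempts to determine it from the device name and training
+            precision via ``get_gpu_flops_available``. Defaults to None.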
""" - def __init__(self, window_size: int = 100): + def __init__(self, window_size: int = 100, gpu_flops_available: Optional[Union[float, int]] = None): # Track the batch num samples and wct to compute throughput over a window of batches - self.batch_start_num_samples = 0 - self.batch_start_wct = 0.0 - self.batch_wct_buffer: Deque[float] = deque(maxlen=window_size) - self.batch_num_samples_buffer: Deque[int] = deque(maxlen=window_size) - self.window_size = window_size + self.history_samples: Deque[int] = deque(maxlen=window_size + 1) + self.history_wct: Deque[float] = deque(maxlen=window_size + 1) + + self.gpu_flops_available = gpu_flops_available # Keep track of time spent evaluating self.total_eval_wct = 0.0 def state_dict(self) -> Dict[str, Any]: return { - 'batch_start_num_samples': self.batch_start_num_samples, - 'batch_start_wct': self.batch_start_wct, - 'batch_wct_buffer': self.batch_wct_buffer, - 'batch_num_samples_buffer': self.batch_num_samples_buffer, - # "window_wct": self.window_wct, - # "window_num_samples": self.window_num_samples, 'total_eval_wct': self.total_eval_wct, } def load_state_dict(self, state: Dict[str, Any]) -> None: - self.batch_start_num_samples = state['batch_start_num_samples'] - self.batch_start_wct = state['batch_start_wct'] - self.batch_wct_buffer = deque( - [x for x in state['batch_wct_buffer']], - maxlen=self.window_size, - ) - self.batch_num_samples_buffer = deque( - [x for x in state['batch_num_samples_buffer']], - maxlen=self.window_size, - ) self.total_eval_wct = state['total_eval_wct'] - def before_dataloader(self, state: State, logger: Logger) -> None: + def init(self, state: State, logger: Logger) -> None: del logger # unused - self.batch_start_wct = state.timestamp.total_wct.total_seconds() - self.batch_start_num_samples = int(state.timestamp.sample) + if self.gpu_flops_available is None: + self.gpu_flops_available = get_gpu_flops_available(state) def batch_end(self, state: State, logger: Logger): - batch_num_samples = int(state.timestamp.sample) - self.batch_start_num_samples - batch_wct = state.timestamp.total_wct.total_seconds() - self.batch_start_wct - # Add the new element - self.batch_wct_buffer.append(batch_wct) - self.batch_num_samples_buffer.append(batch_num_samples) + self.history_samples.append(state.timestamp.sample.value) + self.history_wct.append(state.timestamp.total_wct.total_seconds()) # Log the throughput - if len(self.batch_num_samples_buffer) == self.window_size: - throughput = sum(self.batch_num_samples_buffer) / sum(self.batch_wct_buffer) - logger.log_metrics({'throughput/samples_per_sec': throughput}) + if len(self.history_wct) == self.history_wct.maxlen: + world_size = dist.get_world_size() + elapsed_batches = len(self.history_samples) - 1 + elapsed_samples = int(self.history_samples[-1]) - int(self.history_samples[0]) + elapsed_wct = self.history_wct[-1] - self.history_wct[0] + batches_per_sec = elapsed_batches / elapsed_wct + samples_per_sec = elapsed_samples / elapsed_wct + dev_batches_per_sec = batches_per_sec / world_size + dev_samples_per_sec = samples_per_sec / world_size + logger.log_metrics({'throughput/batches_per_sec': batches_per_sec}) + logger.log_metrics({'throughput/samples_per_sec': samples_per_sec}) + logger.log_metrics({'throughput/device/batches_per_sec': dev_batches_per_sec}) + logger.log_metrics({'throughput/device/samples_per_sec': dev_samples_per_sec}) + + # Compute token stats if dataloader.dataset has max_seq_len. Assumes no padding. 
+            world_size = dist.get_world_size()
+            elapsed_batches = len(self.history_samples) - 1
+            elapsed_samples = int(self.history_samples[-1]) - int(self.history_samples[0])
+            elapsed_wct = self.history_wct[-1] - self.history_wct[0]
+            batches_per_sec = elapsed_batches / elapsed_wct
+            samples_per_sec = elapsed_samples / elapsed_wct
+            dev_batches_per_sec = batches_per_sec / world_size
+            dev_samples_per_sec = samples_per_sec / world_size
+            logger.log_metrics({'throughput/batches_per_sec': batches_per_sec})
+            logger.log_metrics({'throughput/samples_per_sec': samples_per_sec})
+            logger.log_metrics({'throughput/device/batches_per_sec': dev_batches_per_sec})
+            logger.log_metrics({'throughput/device/samples_per_sec': dev_samples_per_sec})
+
+            # Compute token stats if dataloader.dataset has max_seq_len. Assumes each sample
+            # has max_seq_len tokens, so the count may include padding.
+            try:
+                max_seq_len = state.dataloader.dataset.max_seq_len  # type: ignore
+                # Only applicable to seq data / models
+                logger.log_metrics({'throughput/tokens_per_sec': samples_per_sec * max_seq_len})
+                logger.log_metrics({'throughput/device/tokens_per_sec': dev_samples_per_sec * max_seq_len})
+            except AttributeError:
+                pass
+
+            composer_model = state.model
+            if not isinstance(composer_model, ComposerModel):
+                composer_model = composer_model.module  # Pass through DDP wrapping
+            if hasattr(composer_model, 'flops_per_batch'):
+                model_flops_per_batch = composer_model.flops_per_batch  # type: ignore
+                if not isinstance(model_flops_per_batch, Callable):
+                    raise TypeError('flops_per_batch must be a callable accepting a batch and '
                                    f'returning an int or float. Instead, got {type(model_flops_per_batch)}.')
+                flops_per_batch = model_flops_per_batch(state.batch)
+                flops_per_sec = flops_per_batch * batches_per_sec
+                logger.log_metrics({'throughput/flops_per_sec': flops_per_sec})
+                dev_flops_per_sec = flops_per_sec / world_size
+                logger.log_metrics({'throughput/device/flops_per_sec': dev_flops_per_sec})
+                if self.gpu_flops_available:
+                    mfu = dev_flops_per_sec / self.gpu_flops_available
+                    logger.log_metrics({'throughput/device/mfu': mfu})
 
         # Log the time
         # `state.timestamp` excludes any time spent in evaluation
         logger.log_metrics({
             'wall_clock/train': state.timestamp.total_wct.total_seconds(),
             'wall_clock/val': self.total_eval_wct,
-            'wall_clock/total': (state.timestamp.total_wct.total_seconds() + self.total_eval_wct),
+            'wall_clock/total': state.timestamp.total_wct.total_seconds() + self.total_eval_wct,
         })
 
     def eval_end(self, state: State, logger: Logger):
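
To exercise the new metrics end to end, a minimal usage sketch mirroring the test below (the flop count and `gpu_flops_available` values are illustrative, not recommendations; 312e12 is the A100 bf16 peak from `GPU_AVAILABLE_FLOPS`):

```python
from torch.utils.data import DataLoader

from composer.callbacks import SpeedMonitor
from composer.trainer import Trainer
from tests.common import RandomClassificationDataset, SimpleModel

model = SimpleModel()
# flops_per_batch: any callable mapping a batch to its flop count (toy estimate here)
model.flops_per_batch = lambda batch: len(batch) * 100.0

trainer = Trainer(
    model=model,
    callbacks=SpeedMonitor(window_size=2, gpu_flops_available=312e12),
    train_dataloader=DataLoader(RandomClassificationDataset()),
    max_duration='1ep',
)
trainer.fit()
```
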
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 658835327d..cc9072a51b 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -2104,7 +2104,8 @@ def _train_batch(self, use_grad_scaling: bool) -> Dict[str, torch.Tensor]:
         """
         assert self._train_data_spec is not None, 'The train data spec should be set on __init__ or fit()'
 
-        # Cache the device batch, because `self.state.batch` gets overridden in microbatching loop
+        # Cache the device batch, because `self.state.batch` gets overridden in microbatching loop.
+        # Any in-place changes to a microbatch will be reflected in the device batch.
         device_batch = self.state.batch
 
         # Retry until we successfully complete training and return loss
@@ -2212,8 +2213,10 @@ def _train_microbatches(self,
             except TypeError:
                 optimizer.zero_grad()
 
-        # tracker for gradient accumulation
+        # Tracker for gradient accumulation
         current_batch_size = sum([self._train_data_spec.get_num_samples_in_batch(batch) for batch in microbatches])
+        # Cache the batch, which will be overwritten by the microbatches; restore it after the loop completes
+        current_batch = self.state.batch
 
         for microbatch_idx, self.state.batch in enumerate(microbatches):
             is_final_microbatch = microbatch_idx + 1 == len(microbatches)
@@ -2226,6 +2229,9 @@
                     total_loss_dict[loss_key] = self.state.device.tensor_to_device(torch.zeros(size=(1,)))
                 total_loss_dict[loss_key] += microbatch_loss
 
+        # Restore batch
+        self.state.batch = current_batch
+
         # Unscale gradients before `Event.AFTER_TRAIN_BATCH`
         if use_grad_scaling:
             for optimizer in ensure_tuple(self.state.optimizers):
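
The `current_batch` save/restore above guards against a rebinding pitfall: the microbatch loop assigns to `self.state.batch`, so anything reading the state after the loop would otherwise see only the final microbatch. A self-contained sketch of the failure mode (toy types and sizes, not Composer code):

```python
class State:
    """Toy stand-in for composer.core.State."""
    batch = None

state = State()
device_batch = list(range(8))
microbatches = [device_batch[i:i + 4] for i in range(0, 8, 4)]

# The loop target is an attribute, so each iteration rebinds state.batch,
# just like the trainer's `for microbatch_idx, self.state.batch in enumerate(...)`.
for state.batch in microbatches:
    pass

assert state.batch == [4, 5, 6, 7]  # only the last microbatch is left behind

# The fix in this diff: cache the full batch before the loop, restore it after.
state.batch = device_batch
assert state.batch == list(range(8))
```
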
diff --git a/tests/callbacks/test_runtime_estimator.py b/tests/callbacks/test_runtime_estimator.py
new file mode 100644
index 0000000000..fba4b131d6
--- /dev/null
+++ b/tests/callbacks/test_runtime_estimator.py
@@ -0,0 +1,50 @@
+# Copyright 2022 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+import datetime
+
+from torch.utils.data import DataLoader
+
+from composer.callbacks import RuntimeEstimator
+from composer.core import Time
+from composer.loggers import InMemoryLogger
+from composer.trainer import Trainer
+from tests.common import RandomClassificationDataset, SimpleModel
+
+
+def _assert_no_negative_values(logged_values):
+    for timestamp, v in logged_values:
+        del timestamp  # unused
+        if isinstance(v, Time):
+            assert int(v) >= 0
+        elif isinstance(v, datetime.timedelta):
+            assert v.total_seconds() >= 0
+        else:
+            assert v >= 0
+
+
+def test_runtime_estimator():
+    # Construct the callbacks
+    skip_batches = 1
+    runtime_estimator = RuntimeEstimator(skip_batches=skip_batches)
+    in_memory_logger = InMemoryLogger()  # track the logged metrics in the in_memory_logger
+
+    # Construct the trainer and train
+    trainer = Trainer(
+        model=SimpleModel(),
+        callbacks=runtime_estimator,
+        loggers=in_memory_logger,
+        train_dataloader=DataLoader(RandomClassificationDataset()),
+        eval_dataloader=DataLoader(RandomClassificationDataset()),
+        max_duration='2ep',
+        eval_interval='1ep',
+        train_subset_num_batches=10,
+        eval_subset_num_batches=10,
+    )
+    trainer.fit()
+
+    wall_clock_remaining_calls = len(in_memory_logger.data['wall_clock/remaining_estimate'])
+    _assert_no_negative_values(in_memory_logger.data['wall_clock/remaining_estimate'])
+
+    expected_calls = int(trainer.state.timestamp.batch) - skip_batches
+    assert wall_clock_remaining_calls == expected_calls
diff --git a/tests/callbacks/test_speed_monitor.py b/tests/callbacks/test_speed_monitor.py
index deafb7e7c2..c442489adb 100644
--- a/tests/callbacks/test_speed_monitor.py
+++ b/tests/callbacks/test_speed_monitor.py
@@ -4,6 +4,7 @@
 import collections.abc
 import datetime
 
+import pytest
 from torch.utils.data import DataLoader
 
 from composer.callbacks import SpeedMonitor
@@ -24,14 +25,19 @@ def _assert_no_negative_values(logged_values):
         assert v >= 0
 
 
-def test_speed_monitor():
+@pytest.mark.parametrize('flops_per_batch', [False, True])
+def test_speed_monitor(flops_per_batch: bool):
     # Construct the callbacks
     speed_monitor = SpeedMonitor(window_size=2)
     in_memory_logger = InMemoryLogger()  # track the logged metrics in the in_memory_logger
 
+    model = SimpleModel()
+    if flops_per_batch:
+        model.flops_per_batch = lambda batch: len(batch) * 100.0
+
     # Construct the trainer and train
     trainer = Trainer(
-        model=SimpleModel(),
+        model=model,
         callbacks=speed_monitor,
         loggers=in_memory_logger,
         train_dataloader=DataLoader(RandomClassificationDataset()),
@@ -40,23 +46,30 @@
     )
     trainer.fit()
 
-    wall_clock_train_calls = len(in_memory_logger.data['wall_clock/train'])
-    wall_clock_val_calls = len(in_memory_logger.data['wall_clock/val'])
-    wall_clock_total_calls = len(in_memory_logger.data['wall_clock/total'])
-    throughput_step_calls = len(in_memory_logger.data['throughput/samples_per_sec'])
     _assert_no_negative_values(in_memory_logger.data['wall_clock/train'])
     _assert_no_negative_values(in_memory_logger.data['wall_clock/val'])
     _assert_no_negative_values(in_memory_logger.data['wall_clock/total'])
-    _assert_no_negative_values(in_memory_logger.data['wall_clock/train'])
+    _assert_no_negative_values(in_memory_logger.data['throughput/batches_per_sec'])
     _assert_no_negative_values(in_memory_logger.data['throughput/samples_per_sec'])
+    _assert_no_negative_values(in_memory_logger.data['throughput/device/batches_per_sec'])
+    _assert_no_negative_values(in_memory_logger.data['throughput/device/samples_per_sec'])
+    if flops_per_batch:
+        _assert_no_negative_values(in_memory_logger.data['throughput/flops_per_sec'])
+        _assert_no_negative_values(in_memory_logger.data['throughput/device/flops_per_sec'])
 
     assert isinstance(trainer.state.dataloader, collections.abc.Sized)
    assert trainer.state.dataloader_label is not None
     assert trainer.state.dataloader_len is not None
-    expected_step_calls = (trainer.state.dataloader_len - speed_monitor.window_size + 1) * int(
+    expected_step_calls = (trainer.state.dataloader_len - len(speed_monitor.history_samples) + 1) * int(
         trainer.state.timestamp.epoch)
-    assert throughput_step_calls == expected_step_calls
+    assert len(in_memory_logger.data['throughput/batches_per_sec']) == expected_step_calls
+    assert len(in_memory_logger.data['throughput/samples_per_sec']) == expected_step_calls
+    assert len(in_memory_logger.data['throughput/device/batches_per_sec']) == expected_step_calls
+    assert len(in_memory_logger.data['throughput/device/samples_per_sec']) == expected_step_calls
+    if flops_per_batch:
+        assert len(in_memory_logger.data['throughput/flops_per_sec']) == expected_step_calls
+        assert len(in_memory_logger.data['throughput/device/flops_per_sec']) == expected_step_calls
     num_batches = int(trainer.state.timestamp.batch)
-    assert wall_clock_total_calls == num_batches
-    assert wall_clock_train_calls == num_batches
-    assert wall_clock_val_calls == num_batches
+    assert len(in_memory_logger.data['wall_clock/total']) == num_batches
+    assert len(in_memory_logger.data['wall_clock/train']) == num_batches
+    assert len(in_memory_logger.data['wall_clock/val']) == num_batches
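
As a quick sanity check on the MFU arithmetic these tests exercise, a worked example (the measured throughput is an illustrative number; the peak comes from `GPU_AVAILABLE_FLOPS`):

```python
# Peak dense bf16 throughput for an A100, i.e. GPU_AVAILABLE_FLOPS['a100']['bf16'].
peak_flops_per_gpu = 312e12

# Suppose SpeedMonitor measures 120 TFLOP/s per device (illustrative).
dev_flops_per_sec = 120e12

mfu = dev_flops_per_sec / peak_flops_per_gpu
assert abs(mfu - 0.3846) < 1e-3  # throughput/device/mfu would log ~0.385
```
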