From e3a65e8deded198f968d95a48dd20dcf82fb5337 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 21 Feb 2023 16:05:49 -0800 Subject: [PATCH 01/10] add speed monitor refactor --- composer/callbacks/speed_monitor.py | 271 +++++++++++++++++++++----- tests/callbacks/test_speed_monitor.py | 15 +- 2 files changed, 223 insertions(+), 63 deletions(-) diff --git a/composer/callbacks/speed_monitor.py b/composer/callbacks/speed_monitor.py index e7a3c8e517..e1f9f51abc 100644 --- a/composer/callbacks/speed_monitor.py +++ b/composer/callbacks/speed_monitor.py @@ -4,24 +4,140 @@ """Monitor throughput during training.""" from __future__ import annotations +import warnings from collections import deque -from typing import Any, Deque, Dict +from typing import Any, Deque, Dict, Optional, Union + +import torch from composer.core import Callback, State from composer.loggers import Logger +from composer.models.base import ComposerModel +from composer.utils import dist __all__ = ['SpeedMonitor'] +GPU_AVAILABLE_FLOPS = { + # source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet + # nvidia publishes spec sheet with a 2x sparsity factor + 'h100-sxm': { + 'fp64': 67e12, + 'fp32': 67e12, + 'tf32': 989e12 / 2, + 'fp16': 1.979e15 / 2, + 'amp_fp16': 1.979e15 / 2, + 'bf16': 1.979e15 / 2, + 'amp_bf16': 1.979e15 / 2, + 'fp8': 3.958e15 / 2, + 'amp_fp8': 3.958e15 / 2, + 'int8': 3.958e15 / 2, + }, + 'h100-pcie': { + 'fp64': 51e12, + 'fp32': 51e12, + 'tf32': 756e12 / 2, + 'fp16': 1.513e15 / 2, + 'amp_fp16': 1.513e15 / 2, + 'bf16': 1.513e15 / 2, + 'amp_bf16': 1.513e15 / 2, + 'fp8': 3.026e15 / 2, + 'amp_fp8': 3.026e15 / 2, + 'int8': 3.026e15 / 2, + }, + # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf + # sxm and pcie have same flop counts + 'a100': { + 'fp64': 19.5e12, + 'fp32': 19.5e12, + 'tf32': 156e12, + 'fp16': 312e12, + 'amp_fp16': 312e12, + 'bf16': 312e12, + 'amp_bf16': 312e12, + }, + # source: https://images.nvidia.com/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf + 'v100-sxm': { + 'fp64': 7.8e12, + 'fp32': 15.7e12, + 'fp16': 125e12, + 'amp_fp16': 125e12, + }, + 'v100-pcie': { + 'fp64': 7e12, + 'fp32': 14e12, + 'fp16': 112e12, + 'amp_fp16': 112e12, + }, + 'v100s-pcie': { + 'fp64': 8.2e12, + 'fp32': 16.4e12, + 'fp16': 130e12, + 'amp_fp16': 130e12, + }, + # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf + # sxm and pcie have same flop counts + 't4': { + 'fp32': 8.1e12, + 'fp16': 65e12, + 'amp_fp16': 65e12, + 'int8': 130e12, + 'int4': 260e12, + }, +} + + +def get_gpu_flops_available(state: State): + gpu_flops_available = None + + # Return 0 if no CUDA device (e.g., when running with CPU only) + if not torch.cuda.is_available(): + return 0 + + # torch.cuda.get_device_name() ex output: 'NVIDIA A100-SXM4-40GB' + dev_name = torch.cuda.get_device_name().lower() + if 'h100-sxm' in dev_name: + dev_name = 'h100-sxm' + elif 'h100-pcie' in dev_name: + dev_name = 'h100-pcie' + elif 'a100' in dev_name: + dev_name = 'a100' + elif 'v100-sxm' in dev_name: + dev_name = 'v100-sxm' + elif 'v100-pcie' in dev_name: + dev_name = 'v100-pcie' + elif 't4' in dev_name: + dev_name = 't4' + else: + dev_name = None + + if dev_name: + try: + gpu_flops_available = int(GPU_AVAILABLE_FLOPS[dev_name][state.precision.value]) + except: + gpu_flops_available = None + + if gpu_flops_available is None: + warnings.warn( + f'gpu_flop count not 
found for {dev_name=} with precision: {state.precision.value}; ' +\ + f'MFU cannot be calculated and reported. gpu_flops_available can be manually' +\ + f'overridden by setting gpu_flops_available in SpeedMonitor.' + ) + # Setting to 0 will disable MFU computation and prevent + # the speed monitor from running this helper every batch + gpu_flops_available = 0 + + return gpu_flops_available + class SpeedMonitor(Callback): """Logs the training throughput. - The training throughput in terms of number of samples per second is logged on the - :attr:`.Event.BATCH_END` event if we have reached the ``window_size`` threshold. - - The wall clock train time is logged on every :attr:`.Event.BATCH_END` event. + The training throughput is logged on the :attr:`.Event.BATCH_END` event once we have reached + the `window_size` threshold. If a model has `flops_per_batch` attribute, then flops per second + is also logged. If running on a known GPU type or if `gpu_flops_available` is set, then MFU is + also logged. All metrics are also logged as per device by dividing by world size. - The average throughout over an epoch is logged on the :attr:`.Event.EPOCH_END` event. + The wall clock time is logged on every :attr:`.Event.BATCH_END` event. Example: .. doctest:: @@ -41,84 +157,133 @@ class SpeedMonitor(Callback): The training throughput is logged by the :class:`.Logger` to the following keys as described below. - +----------------------------------+-------------------------------------------------------------+ - | Key | Logged data | - +==================================+=============================================================+ - | | Rolling average (over ``window_size`` most recent | - | ``throughput/samples_per_sec`` | batches) of the number of samples processed per second | - | | | - +----------------------------------+-------------------------------------------------------------+ - | ``wall_clock/train`` | Total elapsed training time | - +----------------------------------+-------------------------------------------------------------+ - | ``wall_clock/val`` | Total elapsed validation time | - +----------------------------------+-------------------------------------------------------------+ - | ``wall_clock/total`` | Total elapsed time (wall_clock/train + wall_clock/val) | - +----------------------------------+-------------------------------------------------------------+ + +-------------------------------------+-----------------------------------------------------------+ + | Key | Logged data | + +=====================================+===========================================================+ + | | Rolling average (over `window_size` most recent | + | `throughput/batches_per_sec` | batches) of the number of batches processed per second | + | | | + +-------------------------------------+-----------------------------------------------------------+ + | | Rolling average (over `window_size` most recent | + | `throughput/samples_per_sec` | batches) of the number of samples processed per second | + | | | + +-------------------------------------+-----------------------------------------------------------+ + | | Rolling average (over `window_size` most recent | + | `throughput/tokens_per_sec` | batches) of the number of tokens processed per second. | + | | Only logged when dataloader.dataset has `max_seq_len`. 
|
| | This may include padding depending on dataset |
+-------------------------------------+-----------------------------------------------------------+
| | Estimates flops by `flops_per_batch * samples_per_sec` |
| `throughput/flops_per_sec` | if model has attribute `flops_per_batch` |
| | |
+-------------------------------------+-----------------------------------------------------------+
| | If model has attribute `flops_per_batch`, estimates |
| `throughput/flops_per_sec` | flops per second with `flops_per_batch * samples_per_sec` |
| | |
+=====================================+===========================================================+
| `throughput/device/batches_per_sec` | `throughput/batches_per_sec` divided by world size |
+-------------------------------------+-----------------------------------------------------------+
| `throughput/device/samples_per_sec` | `throughput/samples_per_sec` divided by world size |
+-------------------------------------+-----------------------------------------------------------+
| | `throughput/tokens_per_sec` divided by world size. Only |
| `throughput/device/tokens_per_sec` | logged when dataloader.dataset has `max_seq_len`. This |
| | may include pad tokens depending on how dataset |
+-------------------------------------+-----------------------------------------------------------+
| | `throughput/flops_per_sec` divided by world size. Only |
| `throughput/device/flops_per_sec` | logged when model has attribute `flops_per_batch` |
| | |
+-------------------------------------+-----------------------------------------------------------+
| | `throughput/device/flops_per_sec` divided by `gpu_flops_available`. |
| `throughput/device/mfu` | Only logged when model has attribute `flops_per_batch` |
| | and `gpu_flops_available`, which can be passed as an |
| | argument if not automatically determined by SpeedMonitor |
+=====================================+===========================================================+
| `wall_clock/train` | Total elapsed training time |
+-------------------------------------+-----------------------------------------------------------+
| `wall_clock/val` | Total elapsed validation time |
+-------------------------------------+-----------------------------------------------------------+
| `wall_clock/total` | Total elapsed time (wall_clock/train + wall_clock/val) |
+-------------------------------------+-----------------------------------------------------------+

    Args:
        window_size (int, optional): Number of batches to use for a rolling average of throughput.
            Defaults to 100.
        gpu_flops_available (int | float, optional): Available flops for the GPU. If not set, it is
            determined automatically from the GPU type and training precision. Defaults to None.
""" - def __init__(self, window_size: int = 100): + def __init__(self, window_size: int = 100, gpu_flops_available: Optional[Union[float, int]] = None): # Track the batch num samples and wct to compute throughput over a window of batches - self.batch_start_num_samples = 0 - self.batch_start_wct = 0.0 - self.batch_wct_buffer: Deque[float] = deque(maxlen=window_size) - self.batch_num_samples_buffer: Deque[int] = deque(maxlen=window_size) - self.window_size = window_size + self.history_samples: Deque[int] = deque(maxlen=window_size + 1) + self.history_wct: Deque[float] = deque(maxlen=window_size + 1) + + self.set_gpu_flops_available = False + self.gpu_flops_available = gpu_flops_available # Keep track of time spent evaluating self.total_eval_wct = 0.0 def state_dict(self) -> Dict[str, Any]: return { - 'batch_start_num_samples': self.batch_start_num_samples, - 'batch_start_wct': self.batch_start_wct, - 'batch_wct_buffer': self.batch_wct_buffer, - 'batch_num_samples_buffer': self.batch_num_samples_buffer, - # "window_wct": self.window_wct, - # "window_num_samples": self.window_num_samples, 'total_eval_wct': self.total_eval_wct, } def load_state_dict(self, state: Dict[str, Any]) -> None: - self.batch_start_num_samples = state['batch_start_num_samples'] - self.batch_start_wct = state['batch_start_wct'] - self.batch_wct_buffer = deque( - [x for x in state['batch_wct_buffer']], - maxlen=self.window_size, - ) - self.batch_num_samples_buffer = deque( - [x for x in state['batch_num_samples_buffer']], - maxlen=self.window_size, - ) self.total_eval_wct = state['total_eval_wct'] - def before_dataloader(self, state: State, logger: Logger) -> None: + def init(self, state: State, logger: Logger) -> None: del logger # unused - self.batch_start_wct = state.timestamp.total_wct.total_seconds() - self.batch_start_num_samples = int(state.timestamp.sample) + if self.gpu_flops_available is None: + self.gpu_flops_available = get_gpu_flops_available(state) def batch_end(self, state: State, logger: Logger): - batch_num_samples = int(state.timestamp.sample) - self.batch_start_num_samples - batch_wct = state.timestamp.total_wct.total_seconds() - self.batch_start_wct - # Add the new element - self.batch_wct_buffer.append(batch_wct) - self.batch_num_samples_buffer.append(batch_num_samples) + self.history_samples.append(state.timestamp.sample.value) + self.history_wct.append(state.timestamp.total_wct.total_seconds()) # Log the throughput - if len(self.batch_num_samples_buffer) == self.window_size: - throughput = sum(self.batch_num_samples_buffer) / sum(self.batch_wct_buffer) - logger.log_metrics({'throughput/samples_per_sec': throughput}) + if len(self.history_wct) == self.history_wct.maxlen: + world_size = dist.get_world_size() + elapsed_batches = len(self.history_samples) - 1 + elapsed_samples = int(self.history_samples[-1]) - int(self.history_samples[0]) + elapsed_wct = self.history_wct[-1] - self.history_wct[0] + batches_per_sec = elapsed_batches / elapsed_wct + samples_per_sec = elapsed_samples / elapsed_wct + dev_batches_per_sec = batches_per_sec / world_size + dev_samples_per_sec = samples_per_sec / world_size + logger.log_metrics({'throughput/batches_per_sec': batches_per_sec}) + logger.log_metrics({'throughput/samples_per_sec': samples_per_sec}) + logger.log_metrics({'throughput/device/batches_per_sec': dev_batches_per_sec}) + logger.log_metrics({'throughput/device/samples_per_sec': dev_samples_per_sec}) + + # Compute token stats if dataloader.dataset has max_seq_len. Assumes no padding. 
+ try: + max_seq_len = state.dataloader.dataset.max_seq_len # type: ignore + # Only applicable to seq data / models + logger.log_metrics({'throughput/tokens_per_sec': samples_per_sec * max_seq_len}) + logger.log_metrics({'throughput/device/tokens_per_sec': dev_samples_per_sec * max_seq_len}) + except AttributeError: + pass + + composer_model = state.model + if not isinstance(composer_model, ComposerModel): + composer_model = composer_model.module # Pass through DDP wrapping + if hasattr(composer_model, 'num_fwd_flops'): + num_fwd_flops = composer_model.num_fwd_flops # type: ignore + if not isinstance(num_fwd_flops, (int, float)): + raise TypeError(f'num_fwd_flops must be int or float, got {type(num_fwd_flops)}.') + flops_per_sec = num_fwd_flops * samples_per_sec + logger.log_metrics({'throughput/flops_per_sec': flops_per_sec}) + dev_flops_per_sec = flops_per_sec / world_size + logger.log_metrics({'throughput/device/flops_per_sec': dev_flops_per_sec}) + if self.gpu_flops_available: + mfu = dev_flops_per_sec / self.gpu_flops_available + logger.log_metrics({'throughput/device/mfu': mfu}) # Log the time # `state.timestamp` excludes any time spent in evaluation logger.log_metrics({ 'wall_clock/train': state.timestamp.total_wct.total_seconds(), 'wall_clock/val': self.total_eval_wct, - 'wall_clock/total': (state.timestamp.total_wct.total_seconds() + self.total_eval_wct), + 'wall_clock/total': state.timestamp.total_wct.total_seconds() + self.total_eval_wct, }) def eval_end(self, state: State, logger: Logger): diff --git a/tests/callbacks/test_speed_monitor.py b/tests/callbacks/test_speed_monitor.py index deafb7e7c2..56779c82d6 100644 --- a/tests/callbacks/test_speed_monitor.py +++ b/tests/callbacks/test_speed_monitor.py @@ -40,23 +40,18 @@ def test_speed_monitor(): ) trainer.fit() - wall_clock_train_calls = len(in_memory_logger.data['wall_clock/train']) - wall_clock_val_calls = len(in_memory_logger.data['wall_clock/val']) - wall_clock_total_calls = len(in_memory_logger.data['wall_clock/total']) - throughput_step_calls = len(in_memory_logger.data['throughput/samples_per_sec']) _assert_no_negative_values(in_memory_logger.data['wall_clock/train']) _assert_no_negative_values(in_memory_logger.data['wall_clock/val']) _assert_no_negative_values(in_memory_logger.data['wall_clock/total']) - _assert_no_negative_values(in_memory_logger.data['wall_clock/train']) _assert_no_negative_values(in_memory_logger.data['throughput/samples_per_sec']) assert isinstance(trainer.state.dataloader, collections.abc.Sized) assert trainer.state.dataloader_label is not None assert trainer.state.dataloader_len is not None - expected_step_calls = (trainer.state.dataloader_len - speed_monitor.window_size + 1) * int( + expected_step_calls = (trainer.state.dataloader_len - len(speed_monitor.history_samples)) * int( trainer.state.timestamp.epoch) - assert throughput_step_calls == expected_step_calls + assert len(in_memory_logger.data['throughput/samples_per_sec']) == expected_step_calls num_batches = int(trainer.state.timestamp.batch) - assert wall_clock_total_calls == num_batches - assert wall_clock_train_calls == num_batches - assert wall_clock_val_calls == num_batches + assert len(in_memory_logger.data['wall_clock/total']) == num_batches + assert len(in_memory_logger.data['wall_clock/train']) == num_batches + assert len(in_memory_logger.data['wall_clock/val']) == num_batches From 2cae07e84bd6c5f8c96b0367d7819f74682dfdc9 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 21 Feb 2023 19:43:32 -0800 Subject: [PATCH 02/10] fix docs 
--- composer/callbacks/speed_monitor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/composer/callbacks/speed_monitor.py b/composer/callbacks/speed_monitor.py index e1f9f51abc..ce115b9466 100644 --- a/composer/callbacks/speed_monitor.py +++ b/composer/callbacks/speed_monitor.py @@ -180,7 +180,7 @@ class SpeedMonitor(Callback): | | If model has attribute `flops_per_batch`, estimates | | `throughput/flops_per_sec` | flops per second with `flops_per_batch * samples_per_sec` | | | | - +=====================================+===========================================================+ + +-------------------------------------+-----------------------------------------------------------+ | `throughput/device/batches_per_sec` | `throughput/batches_per_sec` divided by world size | +-------------------------------------+-----------------------------------------------------------+ | `throughput/device/samples_per_sec` | `throughput/samples_per_sec` divided by world size | @@ -197,7 +197,7 @@ class SpeedMonitor(Callback): | `throughput/device/mfu` | Only logged when model has attribute `flops_per_batch` | | | and `gpu_flops_available`, which can be passed as an | | | argument if not automatically determined by SpeedMonitor | - +=====================================+===========================================================+ + +-------------------------------------+-----------------------------------------------------------+ | `wall_clock/train` | Total elapsed training time | +-------------------------------------+-----------------------------------------------------------+ | `wall_clock/val` | Total elapsed validation time | From 8d21a5c63ef9a42df983265ff5faff5fb45721a0 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 21 Feb 2023 19:55:05 -0800 Subject: [PATCH 03/10] fix tests --- tests/callbacks/test_speed_monitor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/callbacks/test_speed_monitor.py b/tests/callbacks/test_speed_monitor.py index 56779c82d6..8a0befbb30 100644 --- a/tests/callbacks/test_speed_monitor.py +++ b/tests/callbacks/test_speed_monitor.py @@ -49,7 +49,7 @@ def test_speed_monitor(): assert trainer.state.dataloader_label is not None assert trainer.state.dataloader_len is not None expected_step_calls = (trainer.state.dataloader_len - len(speed_monitor.history_samples)) * int( - trainer.state.timestamp.epoch) + trainer.state.timestamp.epoch) - 1 assert len(in_memory_logger.data['throughput/samples_per_sec']) == expected_step_calls num_batches = int(trainer.state.timestamp.batch) assert len(in_memory_logger.data['wall_clock/total']) == num_batches From 3b79d01d8bfe9c587a28c015a217b2707e2e1e99 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 21 Feb 2023 20:15:01 -0800 Subject: [PATCH 04/10] fix remove 1 --- tests/callbacks/test_speed_monitor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/callbacks/test_speed_monitor.py b/tests/callbacks/test_speed_monitor.py index 8a0befbb30..ae1d5dd4bb 100644 --- a/tests/callbacks/test_speed_monitor.py +++ b/tests/callbacks/test_speed_monitor.py @@ -48,8 +48,8 @@ def test_speed_monitor(): assert isinstance(trainer.state.dataloader, collections.abc.Sized) assert trainer.state.dataloader_label is not None assert trainer.state.dataloader_len is not None - expected_step_calls = (trainer.state.dataloader_len - len(speed_monitor.history_samples)) * int( - trainer.state.timestamp.epoch) - 1 + expected_step_calls = (trainer.state.dataloader_len - 
len(speed_monitor.history_samples) + 1) * int( + trainer.state.timestamp.epoch) assert len(in_memory_logger.data['throughput/samples_per_sec']) == expected_step_calls num_batches = int(trainer.state.timestamp.batch) assert len(in_memory_logger.data['wall_clock/total']) == num_batches From 30d656003141adb72707f5e2c6943d9a1aad8ca6 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 21 Feb 2023 20:23:35 -0800 Subject: [PATCH 05/10] extend test --- composer/callbacks/speed_monitor.py | 10 +++++----- tests/callbacks/test_speed_monitor.py | 22 ++++++++++++++++++++-- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/composer/callbacks/speed_monitor.py b/composer/callbacks/speed_monitor.py index ce115b9466..fc73308c7a 100644 --- a/composer/callbacks/speed_monitor.py +++ b/composer/callbacks/speed_monitor.py @@ -266,11 +266,11 @@ def batch_end(self, state: State, logger: Logger): composer_model = state.model if not isinstance(composer_model, ComposerModel): composer_model = composer_model.module # Pass through DDP wrapping - if hasattr(composer_model, 'num_fwd_flops'): - num_fwd_flops = composer_model.num_fwd_flops # type: ignore - if not isinstance(num_fwd_flops, (int, float)): - raise TypeError(f'num_fwd_flops must be int or float, got {type(num_fwd_flops)}.') - flops_per_sec = num_fwd_flops * samples_per_sec + if hasattr(composer_model, 'flops_per_batch'): + flops_per_batch = composer_model.flops_per_batch # type: ignore + if not isinstance(flops_per_batch, (int, float)): + raise TypeError(f'flops_per_batch must be int or float, got {type(flops_per_batch)}.') + flops_per_sec = flops_per_batch * samples_per_sec logger.log_metrics({'throughput/flops_per_sec': flops_per_sec}) dev_flops_per_sec = flops_per_sec / world_size logger.log_metrics({'throughput/device/flops_per_sec': dev_flops_per_sec}) diff --git a/tests/callbacks/test_speed_monitor.py b/tests/callbacks/test_speed_monitor.py index ae1d5dd4bb..649290318b 100644 --- a/tests/callbacks/test_speed_monitor.py +++ b/tests/callbacks/test_speed_monitor.py @@ -4,6 +4,7 @@ import collections.abc import datetime +import pytest from torch.utils.data import DataLoader from composer.callbacks import SpeedMonitor @@ -24,14 +25,19 @@ def _assert_no_negative_values(logged_values): assert v >= 0 -def test_speed_monitor(): +@pytest.mark.parametrize('has_flops_per_batch', [True, False]) +def test_speed_monitor(has_flops_per_batch: bool): # Construct the callbacks speed_monitor = SpeedMonitor(window_size=2) in_memory_logger = InMemoryLogger() # track the logged metrics in the in_memory_logger + model = SimpleModel() + if has_flops_per_batch: + model.flops_per_batch = 100 + # Construct the trainer and train trainer = Trainer( - model=SimpleModel(), + model=model, callbacks=speed_monitor, loggers=in_memory_logger, train_dataloader=DataLoader(RandomClassificationDataset()), @@ -43,14 +49,26 @@ def test_speed_monitor(): _assert_no_negative_values(in_memory_logger.data['wall_clock/train']) _assert_no_negative_values(in_memory_logger.data['wall_clock/val']) _assert_no_negative_values(in_memory_logger.data['wall_clock/total']) + _assert_no_negative_values(in_memory_logger.data['throughput/batches_per_sec']) _assert_no_negative_values(in_memory_logger.data['throughput/samples_per_sec']) + _assert_no_negative_values(in_memory_logger.data['throughput/device/batches_per_sec']) + _assert_no_negative_values(in_memory_logger.data['throughput/device/samples_per_sec']) + if has_flops_per_batch: + 
_assert_no_negative_values(in_memory_logger.data['throughput/flops_per_sec']) + _assert_no_negative_values(in_memory_logger.data['throughput/device/flops_per_sec']) assert isinstance(trainer.state.dataloader, collections.abc.Sized) assert trainer.state.dataloader_label is not None assert trainer.state.dataloader_len is not None expected_step_calls = (trainer.state.dataloader_len - len(speed_monitor.history_samples) + 1) * int( trainer.state.timestamp.epoch) + assert len(in_memory_logger.data['throughput/batches_per_sec']) == expected_step_calls assert len(in_memory_logger.data['throughput/samples_per_sec']) == expected_step_calls + assert len(in_memory_logger.data['throughput/device/batches_per_sec']) == expected_step_calls + assert len(in_memory_logger.data['throughput/device/samples_per_sec']) == expected_step_calls + if has_flops_per_batch: + assert len(in_memory_logger.data['throughput/flops_per_sec']) == expected_step_calls + assert len(in_memory_logger.data['throughput/device/flops_per_sec']) == expected_step_calls num_batches = int(trainer.state.timestamp.batch) assert len(in_memory_logger.data['wall_clock/total']) == num_batches assert len(in_memory_logger.data['wall_clock/train']) == num_batches From 50decf562388cfdb55e25d1a885d3c19601ab0b7 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Wed, 22 Feb 2023 10:38:32 -0800 Subject: [PATCH 06/10] format --- composer/callbacks/speed_monitor.py | 18 ++++++++++++++---- composer/trainer/trainer.py | 7 ++++++- tests/callbacks/test_speed_monitor.py | 15 +++++++++------ 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/composer/callbacks/speed_monitor.py b/composer/callbacks/speed_monitor.py index fc73308c7a..0175cdaa6b 100644 --- a/composer/callbacks/speed_monitor.py +++ b/composer/callbacks/speed_monitor.py @@ -6,7 +6,7 @@ import warnings from collections import deque -from typing import Any, Deque, Dict, Optional, Union +from typing import Any, Callable, Deque, Dict, Optional, Union import torch @@ -137,6 +137,10 @@ class SpeedMonitor(Callback): is also logged. If running on a known GPU type or if `gpu_flops_available` is set, then MFU is also logged. All metrics are also logged as per device by dividing by world size. + To specify `flops_per_batch`, the model attribute can either be set as an int or float, which + would be used for every batch, or as a callable which accepts a batch and returns an int or + float. The latter formulation is useful for filtering out the flops of padding tokens. + The wall clock time is logged on every :attr:`.Event.BATCH_END` event. Example: @@ -267,9 +271,15 @@ def batch_end(self, state: State, logger: Logger): if not isinstance(composer_model, ComposerModel): composer_model = composer_model.module # Pass through DDP wrapping if hasattr(composer_model, 'flops_per_batch'): - flops_per_batch = composer_model.flops_per_batch # type: ignore - if not isinstance(flops_per_batch, (int, float)): - raise TypeError(f'flops_per_batch must be int or float, got {type(flops_per_batch)}.') + model_flops_per_batch = composer_model.flops_per_batch # type: ignore + flops_per_batch = None + if isinstance(model_flops_per_batch, (int, float)): + flops_per_batch = model_flops_per_batch + elif isinstance(model_flops_per_batch, Callable): + flops_per_batch = model_flops_per_batch(state.batch) + else: + raise TypeError(f'flops_per_batch must be int, float, or callable accepting a batch and ' + 'returning an int or float. 
Instead, got {type(flops_per_batch)}.') flops_per_sec = flops_per_batch * samples_per_sec logger.log_metrics({'throughput/flops_per_sec': flops_per_sec}) dev_flops_per_sec = flops_per_sec / world_size diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 90cd25c41b..5c0d783c04 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -2207,8 +2207,10 @@ def _train_microbatches(self, except TypeError: optimizer.zero_grad() - # tracker for gradient accumulation + # Tracker for gradient accumulation current_batch_size = sum([self._train_data_spec.get_num_samples_in_batch(batch) for batch in microbatches]) + # Cache batch, which will be overwritten by microbatches. Restore after microbatches complete + current_batch = self.state.batch for microbatch_idx, self.state.batch in enumerate(microbatches): is_final_microbatch = microbatch_idx + 1 == len(microbatches) @@ -2221,6 +2223,9 @@ def _train_microbatches(self, total_loss_dict[loss_key] = self.state.device.tensor_to_device(torch.zeros(size=(1,))) total_loss_dict[loss_key] += microbatch_loss + # Restore batch + self.state.batch = current_batch + # Unscale gradients before `Event.AFTER_TRAIN_BATCH` if use_grad_scaling: for optimizer in ensure_tuple(self.state.optimizers): diff --git a/tests/callbacks/test_speed_monitor.py b/tests/callbacks/test_speed_monitor.py index 649290318b..144d1f509f 100644 --- a/tests/callbacks/test_speed_monitor.py +++ b/tests/callbacks/test_speed_monitor.py @@ -25,15 +25,18 @@ def _assert_no_negative_values(logged_values): assert v >= 0 -@pytest.mark.parametrize('has_flops_per_batch', [True, False]) -def test_speed_monitor(has_flops_per_batch: bool): +@pytest.mark.parametrize('flops_per_batch_float,flops_per_batch_callable', + [[False, False], [True, False], [False, True]]) +def test_speed_monitor(flops_per_batch_float: bool, flops_per_batch_callable: bool): # Construct the callbacks speed_monitor = SpeedMonitor(window_size=2) in_memory_logger = InMemoryLogger() # track the logged metrics in the in_memory_logger model = SimpleModel() - if has_flops_per_batch: - model.flops_per_batch = 100 + if flops_per_batch_float: + model.flops_per_batch = 100.0 + if flops_per_batch_callable: + model.flops_per_batch = lambda _batch: 100.0 # Construct the trainer and train trainer = Trainer( @@ -53,7 +56,7 @@ def test_speed_monitor(has_flops_per_batch: bool): _assert_no_negative_values(in_memory_logger.data['throughput/samples_per_sec']) _assert_no_negative_values(in_memory_logger.data['throughput/device/batches_per_sec']) _assert_no_negative_values(in_memory_logger.data['throughput/device/samples_per_sec']) - if has_flops_per_batch: + if flops_per_batch_float or flops_per_batch_callable: _assert_no_negative_values(in_memory_logger.data['throughput/flops_per_sec']) _assert_no_negative_values(in_memory_logger.data['throughput/device/flops_per_sec']) @@ -66,7 +69,7 @@ def test_speed_monitor(has_flops_per_batch: bool): assert len(in_memory_logger.data['throughput/samples_per_sec']) == expected_step_calls assert len(in_memory_logger.data['throughput/device/batches_per_sec']) == expected_step_calls assert len(in_memory_logger.data['throughput/device/samples_per_sec']) == expected_step_calls - if has_flops_per_batch: + if flops_per_batch_float or flops_per_batch_callable: assert len(in_memory_logger.data['throughput/flops_per_sec']) == expected_step_calls assert len(in_memory_logger.data['throughput/device/flops_per_sec']) == expected_step_calls num_batches = int(trainer.state.timestamp.batch) From 
38126b2f5fb164461a2107351a39dd7e2e31a76c Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Thu, 23 Feb 2023 13:55:20 -0800 Subject: [PATCH 07/10] respond to comments --- composer/callbacks/speed_monitor.py | 63 ++++++++++++--------------- composer/trainer/trainer.py | 5 +-- tests/callbacks/test_speed_monitor.py | 15 +++---- 3 files changed, 34 insertions(+), 49 deletions(-) diff --git a/composer/callbacks/speed_monitor.py b/composer/callbacks/speed_monitor.py index 0175cdaa6b..0b23521e97 100644 --- a/composer/callbacks/speed_monitor.py +++ b/composer/callbacks/speed_monitor.py @@ -94,31 +94,31 @@ def get_gpu_flops_available(state: State): return 0 # torch.cuda.get_device_name() ex output: 'NVIDIA A100-SXM4-40GB' - dev_name = torch.cuda.get_device_name().lower() - if 'h100-sxm' in dev_name: - dev_name = 'h100-sxm' - elif 'h100-pcie' in dev_name: - dev_name = 'h100-pcie' - elif 'a100' in dev_name: - dev_name = 'a100' - elif 'v100-sxm' in dev_name: - dev_name = 'v100-sxm' - elif 'v100-pcie' in dev_name: - dev_name = 'v100-pcie' - elif 't4' in dev_name: - dev_name = 't4' + device_name = torch.cuda.get_device_name().lower() + if 'h100-sxm' in device_name: + device_name = 'h100-sxm' + elif 'h100-pcie' in device_name: + device_name = 'h100-pcie' + elif 'a100' in device_name: + device_name = 'a100' + elif 'v100-sxm' in device_name: + device_name = 'v100-sxm' + elif 'v100-pcie' in device_name: + device_name = 'v100-pcie' + elif 't4' in device_name: + device_name = 't4' else: - dev_name = None + device_name = None - if dev_name: + if device_name is not None: try: - gpu_flops_available = int(GPU_AVAILABLE_FLOPS[dev_name][state.precision.value]) + gpu_flops_available = int(GPU_AVAILABLE_FLOPS[device_name][state.precision.value]) except: gpu_flops_available = None if gpu_flops_available is None: warnings.warn( - f'gpu_flop count not found for {dev_name=} with precision: {state.precision.value}; ' +\ + f'gpu_flop count not found for {device_name} with precision: {state.precision.value}; ' +\ f'MFU cannot be calculated and reported. gpu_flops_available can be manually' +\ f'overridden by setting gpu_flops_available in SpeedMonitor.' ) @@ -137,9 +137,9 @@ class SpeedMonitor(Callback): is also logged. If running on a known GPU type or if `gpu_flops_available` is set, then MFU is also logged. All metrics are also logged as per device by dividing by world size. - To specify `flops_per_batch`, the model attribute can either be set as an int or float, which - would be used for every batch, or as a callable which accepts a batch and returns an int or - float. The latter formulation is useful for filtering out the flops of padding tokens. + To compute `flops_per_sec`, the model attribute `flops_per_batch` should be set to a callable + which accepts a batch and returns the number of flops for that batch. Typically, this should + be flops per sample times the batch size unless pad tokens are used. The wall clock time is logged on every :attr:`.Event.BATCH_END` event. @@ -177,21 +177,17 @@ class SpeedMonitor(Callback): | | Only logged when dataloader.dataset has `max_seq_len`. 
| | | This may include padding depending on dataset | +-------------------------------------+-----------------------------------------------------------+ - | | Estimates flops by `flops_per_batch * samples_per_sec` | + | | Estimates flops by `flops_per_batch * batches_per_sec` | | `throughput/flops_per_sec` | if model has attribute `flops_per_batch` | | | | +-------------------------------------+-----------------------------------------------------------+ - | | If model has attribute `flops_per_batch`, estimates | - | `throughput/flops_per_sec` | flops per second with `flops_per_batch * samples_per_sec` | - | | | - +-------------------------------------+-----------------------------------------------------------+ | `throughput/device/batches_per_sec` | `throughput/batches_per_sec` divided by world size | +-------------------------------------+-----------------------------------------------------------+ | `throughput/device/samples_per_sec` | `throughput/samples_per_sec` divided by world size | +-------------------------------------+-----------------------------------------------------------+ | | `throughput/tokens_per_sec` divided by world size. Only | | `throughput/device/tokens_per_sec` | logged when dataloader.dataset has `max_seq_len`. This | - | | may include pad tokens depending on how dataset | + | | may include pad tokens depending on dataset | +-------------------------------------+-----------------------------------------------------------+ | | `throughput/flops_per_sec` divided by world size. Only | | `throughput/device/flops_per_sec` | logged when model has attribute `flops_per_batch` | @@ -219,7 +215,6 @@ def __init__(self, window_size: int = 100, gpu_flops_available: Optional[Union[f self.history_samples: Deque[int] = deque(maxlen=window_size + 1) self.history_wct: Deque[float] = deque(maxlen=window_size + 1) - self.set_gpu_flops_available = False self.gpu_flops_available = gpu_flops_available # Keep track of time spent evaluating @@ -272,15 +267,11 @@ def batch_end(self, state: State, logger: Logger): composer_model = composer_model.module # Pass through DDP wrapping if hasattr(composer_model, 'flops_per_batch'): model_flops_per_batch = composer_model.flops_per_batch # type: ignore - flops_per_batch = None - if isinstance(model_flops_per_batch, (int, float)): - flops_per_batch = model_flops_per_batch - elif isinstance(model_flops_per_batch, Callable): - flops_per_batch = model_flops_per_batch(state.batch) - else: - raise TypeError(f'flops_per_batch must be int, float, or callable accepting a batch and ' - 'returning an int or float. Instead, got {type(flops_per_batch)}.') - flops_per_sec = flops_per_batch * samples_per_sec + if not isinstance(model_flops_per_batch, Callable): + raise TypeError('flops_per_batch must a callable accepting a batch and ' + f'returning an int or float. 
Instead, got {type(model_flops_per_batch)}.') + flops_per_batch = model_flops_per_batch(state.batch) + flops_per_sec = flops_per_batch * batches_per_sec logger.log_metrics({'throughput/flops_per_sec': flops_per_sec}) dev_flops_per_sec = flops_per_sec / world_size logger.log_metrics({'throughput/device/flops_per_sec': dev_flops_per_sec}) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 7fe4460c15..520fc4511c 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -2248,9 +2248,6 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int, assert self.state.scaler is not None assert self._train_data_spec is not None - # Cache the device batch, because `self.state.batch` gets overridden in microbatching loop - device_batch = deepcopy(self.state.batch) - microbatch_num_samples = self._train_data_spec.get_num_samples_in_batch(self.state.batch) sync_context = contextlib.nullcontext() if self.deepspeed_enabled else ddp_sync_context( self.state, @@ -2320,7 +2317,7 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int, # Use microbatch outputs to update training metrics if self.state.train_metrics is not None: self.state.train_metrics = self._ensure_metrics_device_and_dtype(self.state.train_metrics) - self._eval_train_metrics(device_batch) + self._eval_train_metrics(self.state.batch) if self.deepspeed_enabled: self.state.deepspeed_model.step() diff --git a/tests/callbacks/test_speed_monitor.py b/tests/callbacks/test_speed_monitor.py index 144d1f509f..c442489adb 100644 --- a/tests/callbacks/test_speed_monitor.py +++ b/tests/callbacks/test_speed_monitor.py @@ -25,18 +25,15 @@ def _assert_no_negative_values(logged_values): assert v >= 0 -@pytest.mark.parametrize('flops_per_batch_float,flops_per_batch_callable', - [[False, False], [True, False], [False, True]]) -def test_speed_monitor(flops_per_batch_float: bool, flops_per_batch_callable: bool): +@pytest.mark.parametrize('flops_per_batch', [False, True]) +def test_speed_monitor(flops_per_batch: bool): # Construct the callbacks speed_monitor = SpeedMonitor(window_size=2) in_memory_logger = InMemoryLogger() # track the logged metrics in the in_memory_logger model = SimpleModel() - if flops_per_batch_float: - model.flops_per_batch = 100.0 - if flops_per_batch_callable: - model.flops_per_batch = lambda _batch: 100.0 + if flops_per_batch: + model.flops_per_batch = lambda batch: len(batch) * 100.0 # Construct the trainer and train trainer = Trainer( @@ -56,7 +53,7 @@ def test_speed_monitor(flops_per_batch_float: bool, flops_per_batch_callable: bo _assert_no_negative_values(in_memory_logger.data['throughput/samples_per_sec']) _assert_no_negative_values(in_memory_logger.data['throughput/device/batches_per_sec']) _assert_no_negative_values(in_memory_logger.data['throughput/device/samples_per_sec']) - if flops_per_batch_float or flops_per_batch_callable: + if flops_per_batch: _assert_no_negative_values(in_memory_logger.data['throughput/flops_per_sec']) _assert_no_negative_values(in_memory_logger.data['throughput/device/flops_per_sec']) @@ -69,7 +66,7 @@ def test_speed_monitor(flops_per_batch_float: bool, flops_per_batch_callable: bo assert len(in_memory_logger.data['throughput/samples_per_sec']) == expected_step_calls assert len(in_memory_logger.data['throughput/device/batches_per_sec']) == expected_step_calls assert len(in_memory_logger.data['throughput/device/samples_per_sec']) == expected_step_calls - if flops_per_batch_float or flops_per_batch_callable: + if 
flops_per_batch: assert len(in_memory_logger.data['throughput/flops_per_sec']) == expected_step_calls assert len(in_memory_logger.data['throughput/device/flops_per_sec']) == expected_step_calls num_batches = int(trainer.state.timestamp.batch) From 84dde509c4071dd3443c138849c05c0162311c9e Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Thu, 23 Feb 2023 15:44:20 -0800 Subject: [PATCH 08/10] restore caching --- composer/trainer/trainer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 520fc4511c..5d92cc33ff 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -2248,6 +2248,9 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int, assert self.state.scaler is not None assert self._train_data_spec is not None + # Cache the device batch, because `self.state.batch` gets overridden in microbatching loop + device_batch = self.state.batch + microbatch_num_samples = self._train_data_spec.get_num_samples_in_batch(self.state.batch) sync_context = contextlib.nullcontext() if self.deepspeed_enabled else ddp_sync_context( self.state, @@ -2317,7 +2320,7 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int, # Use microbatch outputs to update training metrics if self.state.train_metrics is not None: self.state.train_metrics = self._ensure_metrics_device_and_dtype(self.state.train_metrics) - self._eval_train_metrics(self.state.batch) + self._eval_train_metrics(device_batch) if self.deepspeed_enabled: self.state.deepspeed_model.step() From bcf004ed80b4892548904ada50c4adaf164570c9 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Thu, 23 Feb 2023 15:56:28 -0800 Subject: [PATCH 09/10] add deepcopy --- composer/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 5d92cc33ff..7fe4460c15 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -2249,7 +2249,7 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int, assert self._train_data_spec is not None # Cache the device batch, because `self.state.batch` gets overridden in microbatching loop - device_batch = self.state.batch + device_batch = deepcopy(self.state.batch) microbatch_num_samples = self._train_data_spec.get_num_samples_in_batch(self.state.batch) sync_context = contextlib.nullcontext() if self.deepspeed_enabled else ddp_sync_context( From 58e38f0b919c901ab6b7967270ed4c7dd1daa557 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Fri, 24 Feb 2023 12:14:52 -0800 Subject: [PATCH 10/10] add comment --- composer/trainer/trainer.py | 3 +- tests/callbacks/test_runtime_estimator.py | 50 +++++++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) create mode 100644 tests/callbacks/test_runtime_estimator.py diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 42e2c0b3d4..cc9072a51b 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -2104,7 +2104,8 @@ def _train_batch(self, use_grad_scaling: bool) -> Dict[str, torch.Tensor]: """ assert self._train_data_spec is not None, 'The train data spec should be set on __init__ or fit()' - # Cache the device batch, because `self.state.batch` gets overridden in microbatching loop + # Cache the device batch, because `self.state.batch` gets overridden in microbatching loop. + # Any in-place changes to a microbatch will be reflected in the device batch. 
device_batch = self.state.batch # Retry until we successfully complete training and return loss diff --git a/tests/callbacks/test_runtime_estimator.py b/tests/callbacks/test_runtime_estimator.py new file mode 100644 index 0000000000..fba4b131d6 --- /dev/null +++ b/tests/callbacks/test_runtime_estimator.py @@ -0,0 +1,50 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +import datetime + +from torch.utils.data import DataLoader + +from composer.callbacks import RuntimeEstimator +from composer.core import Time +from composer.loggers import InMemoryLogger +from composer.trainer import Trainer +from tests.common import RandomClassificationDataset, SimpleModel + + +def _assert_no_negative_values(logged_values): + for timestamp, v in logged_values: + del timestamp # unused + if isinstance(v, Time): + assert int(v) >= 0 + elif isinstance(v, datetime.timedelta): + assert v.total_seconds() >= 0 + else: + assert v >= 0 + + +def test_runtime_estimator(): + # Construct the callbacks + skip_batches = 1 + runtime_estimator = RuntimeEstimator(skip_batches=skip_batches) + in_memory_logger = InMemoryLogger() # track the logged metrics in the in_memory_logger + + # Construct the trainer and train + trainer = Trainer( + model=SimpleModel(), + callbacks=runtime_estimator, + loggers=in_memory_logger, + train_dataloader=DataLoader(RandomClassificationDataset()), + eval_dataloader=DataLoader(RandomClassificationDataset()), + max_duration='2ep', + eval_interval='1ep', + train_subset_num_batches=10, + eval_subset_num_batches=10, + ) + trainer.fit() + + wall_clock_remaining_calls = len(in_memory_logger.data['wall_clock/remaining_estimate']) + _assert_no_negative_values(in_memory_logger.data['wall_clock/remaining_estimate']) + + expected_calls = int(trainer.state.timestamp.batch) - skip_batches + assert wall_clock_remaining_calls == expected_calls
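
For reviewers, a minimal usage sketch of the API introduced in this series, built from the docstring
example and the test fixtures above. The `flops_per_batch` lambda and the 312e12 peak-flops value
(the A100 bf16 entry in GPU_AVAILABLE_FLOPS) are illustrative placeholders, not measured numbers.

from torch.utils.data import DataLoader

from composer.callbacks import SpeedMonitor
from composer.trainer import Trainer
from tests.common import RandomClassificationDataset, SimpleModel

model = SimpleModel()
# A callable taking the batch lets SpeedMonitor log throughput/flops_per_sec and, when peak GPU
# flops are known, throughput/device/mfu. The constant here is only for illustration.
model.flops_per_batch = lambda batch: 100.0 * len(batch)

trainer = Trainer(
    model=model,
    # gpu_flops_available may be omitted; it is then inferred from the GPU type and precision.
    callbacks=SpeedMonitor(window_size=100, gpu_flops_available=312e12),
    train_dataloader=DataLoader(RandomClassificationDataset()),
    max_duration='1ep',
)
trainer.fit()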