Update MLPerfCallback for v2.1 #1607

Merged
merged 14 commits on Oct 11, 2022
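For orientation, a minimal usage sketch of the updated callback (the model and dataloader variables below are placeholders, not part of this PR): with the v2.1 update the callback can target the bert benchmark in addition to resnet, and the new exit_at_target flag optionally halts training once the target metric is reached.

from composer import Trainer
from composer.callbacks import MLPerfCallback

# hypothetical wiring; only the callback arguments reflect changes in this PR
callback = MLPerfCallback(
    root_folder='./mlperf_results',
    index=0,
    benchmark='bert',       # newly supported alongside 'resnet'
    exit_at_target=True,    # new flag: stop training once the target metric is met
)

trainer = Trainer(
    model=my_model,                    # placeholder ComposerModel
    train_dataloader=my_train_loader,  # placeholder
    eval_dataloader=my_eval_loader,    # placeholder
    max_duration='3ep',
    callbacks=[callback],
)
trainer.fit()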
45 changes: 33 additions & 12 deletions composer/callbacks/mlperf.py
@@ -13,7 +13,7 @@
from typing import Any, Dict, Iterable, Optional

import torch
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, IterableDataset

import composer
from composer.core import State
@@ -32,7 +32,7 @@
mlperf_available = False

# this callback only supports the following options:
BENCHMARKS = ('resnet',)
BENCHMARKS = ('resnet', 'bert')
DIVISIONS = ('open',)
STATUS = ('onprem', 'cloud', 'preview')

@@ -126,6 +126,7 @@ class MLPerfCallback(Callback):
cache_clear_cmd (str, optional): Command to invoke during the cache clear. This callback
will call ``os.system(cache_clear_cmd)``. Default is disabled (None)
host_processors_per_node (int, optional): Total number of host processors per node. Default: ``None``.
exit_at_target (bool, optional): Whether to exit training when target metric is met. Default: ``False``.
"""

def __init__(
@@ -142,6 +143,7 @@ def __init__(
status: str = 'onprem',
cache_clear_cmd: Optional[str] = None,
host_processors_per_node: Optional[int] = None,
exit_at_target: bool = False,
) -> None:

_require_mlperf_logging()
@@ -164,6 +166,7 @@ def __init__(
self.root_folder = root_folder
self.metric_name = metric_name
self.metric_label = metric_label
self.exit_at_target = exit_at_target
self._file_handler = None

self.system_desc = get_system_description(submitter, division, status, system_name, host_processors_per_node)
@@ -246,10 +249,21 @@ def _get_accuracy(self, state: State) -> float:
metric = state.eval_metrics[self.metric_label][self.metric_name].compute()
return float(metric)

def _get_time(self, state: State) -> int:
"""Different benchmarks log different units of time."""
benchmark_time = {
'resnet': state.timestamp.epoch.value,
'bert': state.timestamp.sample.value,
}
return benchmark_time[self.benchmark]

def _get_dataloader_stats(self, dataloader: Iterable):
"""Returns a tuple of ``(batch_size, num_samples)``."""
if isinstance(dataloader, DataLoader):
return (dataloader.batch_size, len(dataloader.dataset)) # type: ignore
num_samples = len(dataloader.dataset) # type: ignore
if isinstance(dataloader.dataset, IterableDataset):
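# a sharded IterableDataset reports only its per-rank length, so scale up to the global sample count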
num_samples *= dist.get_world_size()
return (dataloader.batch_size, num_samples)
try:
# attempt to import ffcv and test if it's an ffcv loader.
import ffcv # type: ignore
@@ -297,37 +311,44 @@ def fit_start(self, state: State, logger: Logger) -> None:

def epoch_start(self, state: State, logger: Logger) -> None:
if _global_rank_zero():
self.mllogger.event(key=constants.EPOCH_START, metadata={'epoch_num': state.timestamp.epoch.value})
self.mllogger.event(key=constants.EPOCH_START, metadata={'epoch_num': self._get_time(state)})
self.mllogger.event(key=constants.BLOCK_START,
metadata={
'first_epoch_num': state.timestamp.epoch.value,
'first_epoch_num': self._get_time(state),
'epoch_count': 1
})

def epoch_end(self, state: State, logger: Logger) -> None:
if _global_rank_zero():
self.mllogger.event(key=constants.EPOCH_STOP, metadata={'epoch_num': state.timestamp.epoch.value})
self.mllogger.event(key=constants.EPOCH_STOP, metadata={'epoch_num': self._get_time(state)})
logger.file_artifact(artifact_name=self.upload_name, file_path=self.filename)

def eval_start(self, state: State, logger: Logger) -> None:
if _global_rank_zero():
self.mllogger.event(key=constants.EVAL_START, metadata={'epoch_num': state.timestamp.epoch.value})
self.mllogger.event(key=constants.EVAL_START, metadata={'epoch_num': self._get_time(state)})

def eval_end(self, state: State, logger: Logger) -> None:
if _global_rank_zero():
accuracy = self._get_accuracy(state)
accuracy = self._get_accuracy(state)

self.mllogger.event(key=constants.EVAL_STOP, metadata={'epoch_num': state.timestamp.epoch.value})
if _global_rank_zero():
self.mllogger.event(key=constants.EVAL_STOP, metadata={'epoch_num': self._get_time(state)})
self.mllogger.event(key=constants.EVAL_ACCURACY,
value=accuracy,
metadata={'epoch_num': state.timestamp.epoch.value})
self.mllogger.event(key=constants.BLOCK_STOP, metadata={'first_epoch_num': state.timestamp.epoch.value})
metadata={'epoch_num': self._get_time(state)})
self.mllogger.event(key=constants.BLOCK_STOP, metadata={'first_epoch_num': self._get_time(state)})

if accuracy > self.target and not self.success:
self.mllogger.event(key=constants.RUN_STOP, metadata={'status': 'success'})
self.mllogger.logger.removeHandler(self._file_handler)
self.success = True # only log once

# upload to object store after eval complete
logger.file_artifact(artifact_name=self.upload_name, file_path=self.filename)

if accuracy > self.target and self.exit_at_target:
# stop training
state.max_duration = state.timestamp.batch
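# setting max_duration to the elapsed batch count ends the run after the current batch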

def close(self, state: State, logger: Logger) -> None:
if self._file_handler is not None:
self._file_handler.close()
33 changes: 25 additions & 8 deletions tests/callbacks/test_mlperf_callback.py
@@ -9,7 +9,9 @@

import numpy as np
import pytest
import torch
from torch.utils.data import DataLoader
from torchmetrics import Accuracy

from composer import State, Trainer
from composer.callbacks import MLPerfCallback
@@ -27,6 +29,14 @@ def importor_skip_mlperf_logging():
pytest.importorskip('mlperf_logging')


# MLPerf requires a different number of results
# depending on the benchmark
NUM_TRIALS = {
'resnet': 5,
'bert': 10,
}


class MockMLLogger:
"""Mocks the MLPerf Logger interface."""

@@ -50,7 +60,12 @@ def mlperf_callback(self, monkeypatch, tmp_path) -> MLPerfCallback:
@pytest.fixture
def mock_state(self):
"""Mocks a state at epoch 1 with Accuracy 0.99."""
eval_metrics = {'eval': {'Accuracy': 0.99}}
acc = Accuracy()
eval_metrics = {'eval': {'Accuracy': acc}}
acc.update(
torch.tensor([1, 1], dtype=torch.int8),
torch.tensor([1, 1], dtype=torch.int8),
)
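# identical predictions and targets make the metric compute to 1.0, above any benchmark target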

state = Mock()
state.eval_metrics = eval_metrics
@@ -87,10 +102,11 @@ def test_eval_end(self, mlperf_callback, mock_state):

@world_size(1, 2)
@device('cpu', 'gpu')
@pytest.mark.parametrize('benchmark', ['resnet', 'bert'])
class TestWithMLPerfChecker:
"""Ensures that the logs created by the MLPerfCallback pass the official package checker."""

def test_mlperf_callback_passes(self, tmp_path, monkeypatch, world_size, device):
def test_mlperf_callback_passes(self, tmp_path, monkeypatch, benchmark, world_size, device):

def mock_accuracy(self, state: State):
if state.timestamp.epoch >= 2:
@@ -100,28 +116,29 @@ def mock_accuracy(self, state: State):

monkeypatch.setattr(MLPerfCallback, '_get_accuracy', mock_accuracy)

self.generate_submission(tmp_path, device)
self.generate_submission(tmp_path, device, benchmark)

if rank_zero():
self.run_mlperf_checker(tmp_path, monkeypatch)

def test_mlperf_callback_fails(self, tmp_path, monkeypatch, world_size, device):
def test_mlperf_callback_fails(self, tmp_path, monkeypatch, benchmark, world_size, device):

def mock_accuracy(self, state: State):
return 0.01

monkeypatch.setattr(MLPerfCallback, '_get_accuracy', mock_accuracy)

self.generate_submission(tmp_path, device)
self.generate_submission(tmp_path, device, benchmark)
with pytest.raises(ValueError, match='MLPerf checker failed'):
self.run_mlperf_checker(tmp_path, monkeypatch)

def generate_submission(self, directory, device):
def generate_submission(self, directory, device, benchmark):
"""Generates submission files by training the benchark n=5 times."""

for run in range(5):
for run in range(NUM_TRIALS[benchmark]):

mlperf_callback = MLPerfCallback(
benchmark=benchmark,
root_folder=directory,
index=run,
cache_clear_cmd='sleep 0.1',
@@ -165,7 +182,7 @@ def fail_on_error(msg, *args, **kwargs):
check_training_package(
folder=directory,
usage='training',
ruleset='1.1.0',
ruleset='2.1.0',
werror=True,
quiet=False,
rcp_bypass=False,