Update MLPerfCallback for v2.1 (#1607)
hanlint authored Oct 11, 2022
1 parent 65527d4 commit 50253c1
Showing 2 changed files with 58 additions and 20 deletions.
45 changes: 33 additions & 12 deletions composer/callbacks/mlperf.py
@@ -13,7 +13,7 @@
from typing import Any, Dict, Iterable, Optional

import torch
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, IterableDataset

import composer
from composer.core import State
@@ -32,7 +32,7 @@
mlperf_available = False

# this callback only supports the following options:
BENCHMARKS = ('resnet',)
BENCHMARKS = ('resnet', 'bert')
DIVISIONS = ('open',)
STATUS = ('onprem', 'cloud', 'preview')

@@ -126,6 +126,7 @@ class MLPerfCallback(Callback):
cache_clear_cmd (str, optional): Command to invoke during the cache clear. This callback
will call ``os.system(cache_clear_cmd)``. Default is disabled (None)
host_processors_per_node (int, optional): Total number of host processors per node. Default: ``None``.
exit_at_target (bool, optional): Whether to exit training when target metric is met. Default: ``False``.
"""

def __init__(
@@ -142,6 +143,7 @@ def __init__(
status: str = 'onprem',
cache_clear_cmd: Optional[str] = None,
host_processors_per_node: Optional[int] = None,
exit_at_target: bool = False,
) -> None:

_require_mlperf_logging()
@@ -164,6 +166,7 @@ def __init__(
self.root_folder = root_folder
self.metric_name = metric_name
self.metric_label = metric_label
self.exit_at_target = exit_at_target
self._file_handler = None

self.system_desc = get_system_description(submitter, division, status, system_name, host_processors_per_node)
@@ -246,10 +249,21 @@ def _get_accuracy(self, state: State) -> float:
metric = state.eval_metrics[self.metric_label][self.metric_name].compute()
return float(metric)

def _get_time(self, state: State) -> int:
"""Different benchmarks log different units of time."""
benchmark_time = {
'resnet': state.timestamp.epoch.value,
'bert': state.timestamp.sample.value,
}
return benchmark_time[self.benchmark]

def _get_dataloader_stats(self, dataloader: Iterable):
"""Returns a tuple of ``(batch_size, num_samples)``."""
if isinstance(dataloader, DataLoader):
return (dataloader.batch_size, len(dataloader.dataset)) # type: ignore
num_samples = len(dataloader.dataset) # type: ignore
if isinstance(dataloader.dataset, IterableDataset):
num_samples *= dist.get_world_size()
return (dataloader.batch_size, num_samples)
try:
# attempt to import ffcv and test if its an ffcv loader.
import ffcv # type: ignore
@@ -297,37 +311,44 @@ def fit_start(self, state: State, logger: Logger) -> None:

def epoch_start(self, state: State, logger: Logger) -> None:
if _global_rank_zero():
self.mllogger.event(key=constants.EPOCH_START, metadata={'epoch_num': state.timestamp.epoch.value})
self.mllogger.event(key=constants.EPOCH_START, metadata={'epoch_num': self._get_time(state)})
self.mllogger.event(key=constants.BLOCK_START,
metadata={
'first_epoch_num': state.timestamp.epoch.value,
'first_epoch_num': self._get_time(state),
'epoch_count': 1
})

def epoch_end(self, state: State, logger: Logger) -> None:
if _global_rank_zero():
self.mllogger.event(key=constants.EPOCH_STOP, metadata={'epoch_num': state.timestamp.epoch.value})
self.mllogger.event(key=constants.EPOCH_STOP, metadata={'epoch_num': self._get_time(state)})
logger.upload_file(remote_file_name=self.upload_name, file_path=self.filename)

def eval_start(self, state: State, logger: Logger) -> None:
if _global_rank_zero():
self.mllogger.event(key=constants.EVAL_START, metadata={'epoch_num': state.timestamp.epoch.value})
self.mllogger.event(key=constants.EVAL_START, metadata={'epoch_num': self._get_time(state)})

def eval_end(self, state: State, logger: Logger) -> None:
if _global_rank_zero():
accuracy = self._get_accuracy(state)
accuracy = self._get_accuracy(state)

self.mllogger.event(key=constants.EVAL_STOP, metadata={'epoch_num': state.timestamp.epoch.value})
if _global_rank_zero():
self.mllogger.event(key=constants.EVAL_STOP, metadata={'epoch_num': self._get_time(state)})
self.mllogger.event(key=constants.EVAL_ACCURACY,
value=accuracy,
metadata={'epoch_num': state.timestamp.epoch.value})
self.mllogger.event(key=constants.BLOCK_STOP, metadata={'first_epoch_num': state.timestamp.epoch.value})
metadata={'epoch_num': self._get_time(state)})
self.mllogger.event(key=constants.BLOCK_STOP, metadata={'first_epoch_num': self._get_time(state)})

if accuracy > self.target and not self.success:
self.mllogger.event(key=constants.RUN_STOP, metadata={'status': 'success'})
self.mllogger.logger.removeHandler(self._file_handler)
self.success = True # only log once

# upload to object store after eval complete
logger.upload_file(remote_file_name=self.upload_name, file_path=self.filename)

if accuracy > self.target and self.exit_at_target:
# stop training
state.max_duration = state.timestamp.batch

def close(self, state: State, logger: Logger) -> None:
if self._file_handler is not None:
self._file_handler.close()
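A minimal usage sketch of the updated callback (not part of this commit): the model and dataloaders are placeholders, the Trainer keyword names are assumed rather than taken from this diff, and only constructor arguments visible above are passed.

from composer import Trainer
from composer.callbacks import MLPerfCallback

# Placeholders: a real BERT model and dataloaders would go here.
model = ...
train_dataloader = ...
eval_dataloader = ...

mlperf = MLPerfCallback(
    root_folder='./mlperf_submission',
    index=0,
    benchmark='bert',        # newly accepted alongside 'resnet'
    cache_clear_cmd='sleep 0.1',
    exit_at_target=True,     # new flag: stop training once the target metric is reached
)

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration='3ep',
    callbacks=[mlperf],
)
trainer.fit()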
33 changes: 25 additions & 8 deletions tests/callbacks/test_mlperf_callback.py
@@ -9,7 +9,9 @@

import numpy as np
import pytest
import torch
from torch.utils.data import DataLoader
from torchmetrics import Accuracy

from composer import State, Trainer
from composer.callbacks import MLPerfCallback
@@ -27,6 +29,14 @@ def importor_skip_mlperf_logging():
pytest.importorskip('mlperf_logging')


# MLPerf requires a different number of results
# depending on the benchmark
NUM_TRIALS = {
'resnet': 5,
'bert': 10,
}


class MockMLLogger:
"""Mocks the MLPerf Logger interface."""

@@ -50,7 +60,12 @@ def mlperf_callback(self, monkeypatch, tmp_path) -> MLPerfCallback:
@pytest.fixture
def mock_state(self):
"""Mocks a state at epoch 1 with Accuracy 0.99."""
eval_metrics = {'eval': {'Accuracy': 0.99}}
acc = Accuracy()
eval_metrics = {'eval': {'Accuracy': acc}}
acc.update(
torch.tensor([1, 1], dtype=torch.int8),
torch.tensor([1, 1], dtype=torch.int8),
)

state = Mock()
state.eval_metrics = eval_metrics
@@ -87,10 +102,11 @@ def test_eval_end(self, mlperf_callback, mock_state):

@world_size(1, 2)
@device('cpu', 'gpu')
@pytest.mark.parametrize('benchmark', ['resnet', 'bert'])
class TestWithMLPerfChecker:
"""Ensures that the logs created by the MLPerfCallback pass the official package checker."""

def test_mlperf_callback_passes(self, tmp_path, monkeypatch, world_size, device):
def test_mlperf_callback_passes(self, tmp_path, monkeypatch, benchmark, world_size, device):

def mock_accuracy(self, state: State):
if state.timestamp.epoch >= 2:
@@ -100,28 +116,29 @@ def mock_accuracy(self, state: State):

monkeypatch.setattr(MLPerfCallback, '_get_accuracy', mock_accuracy)

self.generate_submission(tmp_path, device)
self.generate_submission(tmp_path, device, benchmark)

if rank_zero():
self.run_mlperf_checker(tmp_path, monkeypatch)

def test_mlperf_callback_fails(self, tmp_path, monkeypatch, world_size, device):
def test_mlperf_callback_fails(self, tmp_path, monkeypatch, benchmark, world_size, device):

def mock_accuracy(self, state: State):
return 0.01

monkeypatch.setattr(MLPerfCallback, '_get_accuracy', mock_accuracy)

self.generate_submission(tmp_path, device)
self.generate_submission(tmp_path, device, benchmark)
with pytest.raises(ValueError, match='MLPerf checker failed'):
self.run_mlperf_checker(tmp_path, monkeypatch)

def generate_submission(self, directory, device):
def generate_submission(self, directory, device, benchmark):
"""Generates submission files by training the benchark n=5 times."""

for run in range(5):
for run in range(NUM_TRIALS[benchmark]):

mlperf_callback = MLPerfCallback(
benchmark=benchmark,
root_folder=directory,
index=run,
cache_clear_cmd='sleep 0.1',
@@ -165,7 +182,7 @@ def fail_on_error(msg, *args, **kwargs):
check_training_package(
folder=directory,
usage='training',
ruleset='1.1.0',
ruleset='2.1.0',
werror=True,
quiet=False,
rcp_bypass=False,
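For reference, a small standalone sketch of the dataloader bookkeeping the updated _get_dataloader_stats relies on: an IterableDataset only knows its per-rank length, so the per-rank count is scaled by the world size. The world size of 8 here is a stand-in assumption; the callback itself calls dist.get_world_size().

from torch.utils.data import DataLoader, IterableDataset

class ShardedStream(IterableDataset):
    # Hypothetical per-rank shard: each rank iterates 1000 samples of its own.
    def __iter__(self):
        return iter(range(1000))

    def __len__(self):
        return 1000

loader = DataLoader(ShardedStream(), batch_size=16)

world_size = 8  # assumed; the callback uses dist.get_world_size() instead
num_samples = len(loader.dataset)
if isinstance(loader.dataset, IterableDataset):
    num_samples *= world_size  # per-rank length -> global sample count

print(loader.batch_size, num_samples)  # 16 8000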
