diff --git a/.gitignore b/.gitignore index 956f0acea1..3e9a4058cf 100644 --- a/.gitignore +++ b/.gitignore @@ -159,3 +159,7 @@ package-lock.json # pycharm .idea/ + +# composer +data/ +datasets/ diff --git a/composer/core/engine.py b/composer/core/engine.py index ad1189de95..8ebc60d0c8 100644 --- a/composer/core/engine.py +++ b/composer/core/engine.py @@ -9,33 +9,34 @@ :class:`.SelectiveBackprop` algorithm runs on the :attr:`.Event.AFTER_DATALOADER` event and must run before any data augmentations. :class:`.Engine` runs re-ordering passes to resolve such ordering issues or conflicts. -.. note:: +These orderings are enforced by algorithm passes. The default passes registered to the Engine are found in +:mod:`composer.core.passes`. To register a new pass, use :meth:`.Engine.register_pass`, e.g. - * An instance of :class:`.Engine` is automatically constructed by the :class:`.Trainer` - constructor. A user need not instantiate the :class:`.Engine` class. - * The design of :class:`.Engine` is subject to change in future releases - to accommodate more complexity as we investigate composition of algorithms. +.. testsetup:: + # dummy algorithm + MyAlgorithm = None -Currently, the following passes are registered: +.. doctest:: -* **LIFO order for events** + from composer import Engine, Algorithm, Event + from typing import Sequence - For the events that follow the ``before_*`` (e.g., :attr:`.Event.BEFORE_LOSS`) and ``after_*`` (e.g., - :attr:`.Event.AFTER_LOSS`) pattern, the ordering of algorithms is reversed for the ``after_*`` events. For example, - four given algorithms ``A``, ``B``, ``C``, and ``D`` will run in ``ABCD`` ordering on the ``before_*`` event while - ``DCBA`` ordering on the ``after_*`` event. + def run_last(algorithms: Sequence[Algorithm], event: Event) -> Sequence[Algorithm]: + algorithms = sorted(algorithms, key=lambda x: isinstance(x, MyAlgorithm)) - This allows algorithms to "clean up" their changes. For example, :class:`.LabelSmoothing` will smooth the labels - upon the :attr:`.Event.BEFORE_LOSS` event and then restore the original unsmoothed labels on the - :attr:`.Event.AFTER_LOSS` event. + Engine.register_pass(run_last) -* **Run Selective Backprop first** +.. note:: + + * An instance of :class:`.Engine` is automatically constructed by the :class:`.Trainer` + constructor. A user need not instantiate the :class:`.Engine` class. + +.. note:: + * The design of :class:`.Engine` is subject to change in future releases + to accommodate more complexity as we investigate composition of algorithms. - :class:`.SelectiveBackprop` runs after the dataloader returns the batch and executes an extra forward pass to rank - and prune the examples in the batch by loss. To ensure a clean estimate of loss, :class:`.SelectiveBackprop` should - run before any other data augmentations (e.g., :class:`.MixUp`) on the :attr:`.Event.AFTER_DATALOADER` event. Trace ~~~~~ @@ -64,12 +65,12 @@ import atexit import contextlib import logging -import warnings import weakref from collections import OrderedDict from dataclasses import dataclass -from typing import ContextManager, Dict, Optional, Sequence, Union, cast +from typing import ContextManager, Dict, List, Optional, Sequence, TypeVar, Union, cast +from composer.core import passes from composer.core.algorithm import Algorithm from composer.core.callback import Callback from composer.core.event import Event @@ -81,35 +82,15 @@ __all__ = ['Trace', 'Engine', 'Traces'] +T = TypeVar('T') + +_ALWAYS_RECORD_EVENTS = [Event.INIT, Event.FIT_START, Event.EPOCH_START, Event.EPOCH_END] + #: The default traces of an entire run is an OrderedDict. #: The keys are of format ``/`` (e.g., ``Blurpool/INIT``) and values are an instance of #: :class:`Trace`. Traces = Dict[str, 'Trace'] -_ALWAYS_RECORD_EVENTS = [Event.INIT, Event.FIT_START, Event.EPOCH_START, Event.EPOCH_END] -_EVENTS_WHERE_DATALOADER_IS_SET = [e for e in Event if e != Event.INIT] -_EVENTS_WHERE_MAX_DURATION_IS_SET = [ - Event.FIT_START, - Event.EPOCH_START, - Event.BATCH_START, - Event.AFTER_DATALOADER, - Event.BEFORE_TRAIN_BATCH, - Event.BEFORE_FORWARD, - Event.AFTER_FORWARD, - Event.BEFORE_LOSS, - Event.AFTER_LOSS, - Event.BEFORE_BACKWARD, - Event.AFTER_BACKWARD, - Event.AFTER_TRAIN_BATCH, - Event.BATCH_END, - Event.BATCH_CHECKPOINT, - Event.EPOCH_END, - Event.EPOCH_CHECKPOINT, - Event.FIT_END, -] -_EVAL_EVENTS = [e for e in Event if e.name.startswith('EVAL_')] -_PREDICT_EVENTS = [e for e in Event if e.name.startswith('PREDICT_')] - # Track whether atexit triggered _close(), which indicates whether the python process is shutting down # If so, do not run close() again via __del__(), as Python machinery (e.g. the ability to do conditional # imports) are destroyed between close() and __del__(). @@ -127,6 +108,15 @@ def _set_atexit_ran(): atexit.register(_set_atexit_ran) +def _get_default_passes(): + return [ + passes.sort_selective_backprop_first, + passes.sort_fused_layernorm_last, + passes.set_filo_order, + passes.warn_if_multiple_loss_interpolation, + ] + + @dataclass class Trace(): """Record of an algorithm's execution. @@ -173,6 +163,9 @@ def __init__(self, state: State, logger: Logger): self.logger = logger self.state = state self._is_closed = False + + self.algorithm_passes: List[passes.AlgorithmPass] = _get_default_passes() + atexit.register(self._close, state, logger) def run_event( @@ -233,11 +226,7 @@ def run_event( if event.is_after_event and duration_marker is not None: duration_marker.finish() - if event in _EVENTS_WHERE_DATALOADER_IS_SET: - assert self.state.dataloader is not None, f'The trainer should have set state.dataloader for event {event}.' - - if event in _EVENTS_WHERE_MAX_DURATION_IS_SET: - assert self.state.max_duration is not None, f'The trainer should have set state.max_duration for event {event}.' + self._assert_dataloader_and_duration_set(self.state, event) if event == Event.INIT: # For the INIT event, run the callbacks first to initialize the loggers @@ -254,13 +243,37 @@ def run_event( return traces + def register_pass(self, algorithm_pass: passes.AlgorithmPass, index: int = -1): + """Registers an algorithm pass with the Engine. + + Args: + algorithm_pass (passes.AlgorithmPass): A method that maps a list of + algorithms to a list of algorithms. + index (int, optional): The index to insert into the list of passes. + If -1 (default), the pass will be insert to the end of the list. + """ + if index == -1: + index = len(self.algorithm_passes) + + self.algorithm_passes.insert(index, algorithm_pass) + + @staticmethod + def _assert_dataloader_and_duration_set(state: State, event: Event): + # correctness checks that dataloader and max duration need to be set for certain events + + if event != Event.INIT: # datalaoder should be set on all events expect INIT + assert state.dataloader is not None, f'The trainer should have set state.dataloader for event {event}.' + + if event != Event.INIT and not event.is_predict and not event.is_eval: + assert state.max_duration is not None, f'The trainer should have set state.max_duration for event {event}.' + def _run_algorithms( self, event: Event, ) -> Traces: algorithms_to_run = [algo for algo in self.state.algorithms if algo.match(event, self.state)] - # future collision resolution + # apply algorithm passes algorithms_to_run = self._compile(algorithms_to_run, event) trace = _setup_trace(algorithms_to_run, event) @@ -321,30 +334,11 @@ def _compile( Returns: Sequence[Algorithm]: Modified sequence of algorithms. """ - from composer.algorithms import CutMix, FusedLayerNorm, MixUp, SelectiveBackprop, StochasticDepth - - # Move selective backprop to the beginning while maintaining order of other algorithms - algorithms = sorted(algorithms_to_run, - key=lambda x: not isinstance(x, SelectiveBackprop) and not isinstance(x, StochasticDepth)) - - # Move fused layernorm to the end while maintaining order of other algorithms (FLN only does surgery on leaf modules) - algorithms = sorted(algorithms, key=lambda x: isinstance(x, FusedLayerNorm)) - - # Check for multiple algorithms that try to interpolate the loss at the same time - interpolation_settings = [a.interpolate_loss for a in algorithms if isinstance(a, (CutMix, MixUp))] - if sum(interpolation_settings) > 1: - warnings.warn( - 'Multiple algorithms are trying to interpolate the loss. This can result in strange behavior.') - - if event.is_after_event: - """Establish a FILO queue of algorithms ``before_`` and ``after_`` an event. + # run reordering passes on the algorithms + for passes in self.algorithm_passes: + algorithms_to_run = passes(algorithms_to_run, event) - before_loss: A, B, C, D - after_loss: D, C, B, A - """ - algorithms = list(reversed(algorithms)) - - return algorithms + return algorithms_to_run def _run_callbacks( self, @@ -392,32 +386,17 @@ def __del__(self): def _debug_log(self, event: Event, msg: str): """Helper to include timestamp and event info in log messages.""" - if event in _EVAL_EVENTS: - log.debug( - '[ep=%i][ba=%i][eval_ba=%i][event=%s]: %s', - int(self.state.timestamp.epoch), - int(self.state.timestamp.batch), - int(self.state.eval_timestamp.batch), - event.name, - msg, - ) - elif event in _PREDICT_EVENTS: - log.debug( - '[ep=%i][ba=%i][predict_ba=%i][event=%s]: %s', - int(self.state.timestamp.epoch), - int(self.state.timestamp.batch), - int(self.state.predict_timestamp.batch), - event.name, - msg, - ) - else: - log.debug( - '[ep=%i][ba=%i][event=%s]: %s', - int(self.state.timestamp.epoch), - int(self.state.timestamp.batch), - event.name, - msg, - ) + timestamp = f'[ep={int(self.state.timestamp.epoch)}][ba={int(self.state.timestamp.batch)}]' + + # for eval or pr + if event.is_eval: + timestamp += f'[eval_ba={int(self.state.eval_timestamp.batch)}]' + if event.is_predict: + timestamp += f'[predict_ba={int(self.state.predict_timestamp.batch)}]' + + timestamp += f'[event={event.name}]' + + log.debug(f'{timestamp}: {msg}') def close(self) -> None: """Shutdown the engine. diff --git a/composer/core/event.py b/composer/core/event.py index 0472789039..1a15c98969 100644 --- a/composer/core/event.py +++ b/composer/core/event.py @@ -207,6 +207,16 @@ def canonical_name(self) -> str: name = name.replace('_end', '') return name + @property + def is_predict(self) -> bool: + """Whether the event is during the predict loop.""" + return self.value.startswith('predict') + + @property + def is_eval(self) -> bool: + """Whether the event is during the eval loop.""" + return self.value.startswith('eval') + _BEFORE_EVENTS = (Event.FIT_START, Event.EPOCH_START, Event.BATCH_START, Event.BEFORE_TRAIN_BATCH, Event.BEFORE_FORWARD, Event.BEFORE_LOSS, Event.BEFORE_BACKWARD, Event.EVAL_START, Event.EVAL_BATCH_START, diff --git a/composer/core/passes.py b/composer/core/passes.py new file mode 100644 index 0000000000..889755e79e --- /dev/null +++ b/composer/core/passes.py @@ -0,0 +1,134 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +"""Algorithm Passes reorder or modify the execution of algorithms by the Engine. + +The order in which algorithms are run matters significantly during composition. For example, the +:class:`.SelectiveBackprop` algorithm runs on the :attr:`.Event.AFTER_DATALOADER` event and must run before +any data augmentations. :class:`.Engine` runs re-ordering passes to resolve such ordering issues or conflicts. + +These modifications are represented as algorithm passes, which are functions that modify a list of algorithms. + +For example, an algorithm pass that ensures a certain algorithm runs last, would be implemented as: + +.. code-block:: python + + def run_last(algorithms: Sequence[Algorithm], event: Event) -> Sequence[Algorithm]: + algorithms = sorted(algorithms, key=lambda x: isinstance(x, MyAlgorithm)) + +The passes in this module are registered by default into :class:`.Engine`. +""" +import warnings +from typing import Any, Callable, Sequence, TypeVar + +from composer.core.algorithm import Algorithm +from composer.core.event import Event + +T = TypeVar('T') + +AlgorithmPass = Callable[[Sequence[Algorithm], Event], Sequence[Algorithm]] + + +def sort_to_front(list_to_sort: Sequence[T], cls: Any) -> Sequence[T]: + """Helper function to sort instances of a provided class to the front. + + Example: + + .. testsetup:: + + from composer.core.passes import sort_to_front + + .. doctest:: + + >>> sort_to_front([1, 'b', 3], str) + ['b', 1, 3] + + Args: + list_to_sort: list of objects to sort + cls: sorts all objects of this class to the front + + Returns: + sorted_list: Sorted List + + """ + return sorted(list_to_sort, key=lambda x: not isinstance(x, cls)) + + +def sort_to_back(list_to_sort: Sequence[T], cls: Any) -> Sequence[T]: + """Helper function to sort instances of a provided class to the back. + + Example: + + .. testsetup:: + + from composer.core.passes import sort_to_back + + .. doctest:: + + >>> sort_to_back([1, 'b', 3], str) + [1, 3, 'b'] + + Args: + list_to_sort: list of objects to sort + cls: sorts all objects of this class to the back + + Returns: + sorted_list: Sorted List + + """ + return sorted(list_to_sort, key=lambda x: isinstance(x, cls)) + + +def sort_selective_backprop_first(algorithms: Sequence[Algorithm], event: Event) -> Sequence[Algorithm]: + """Selective Backprop should run before any algorithms modify the loss. + + :class:`.SelectiveBackprop` runs after the dataloader returns the batch and executes an extra forward pass to rank + and prune the examples in the batch by loss. To ensure a clean estimate of loss, :class:`.SelectiveBackprop` should + run before any other data augmentations (e.g., :class:`.MixUp`) on the :attr:`.Event.AFTER_DATALOADER` event. + + """ + from composer.algorithms import SelectiveBackprop + return sort_to_front(algorithms, cls=SelectiveBackprop) + + +def sort_fused_layernorm_last(algorithms: Sequence[Algorithm], event: Event) -> Sequence[Algorithm]: #noqa: D403 + """FusedLayerNorm should run after other algorithms that add LayerNorms (e.g. GatedLinearUnits). + + This ensures that all LayerNorms are converted to optimized fused versions. + + """ + from composer.algorithms import FusedLayerNorm + return sort_to_back(algorithms, cls=FusedLayerNorm) + + +def set_filo_order(algorithms: Sequence[Algorithm], event: Event) -> Sequence[Algorithm]: + """Establish a FILO order of algorithms ``before_`` and ``after_`` events. + + For the events that follow the ``before_*`` and ``after_*`` pattern (e.g., :attr:`.Event.BEFORE_LOSS` + and :attr:`.Event.AFTER_LOSS), the ordering of algorithms is reversed for the ``after_*`` events. + For example, four given algorithms ``A``, ``B``, ``C``, and ``D`` will run in ``ABCD`` ordering on + the ``before_*`` event while ``DCBA`` ordering on the ``after_*`` event. + + This allows algorithms to "clean up" their changes. For example, :class:`.LabelSmoothing` will smooth the labels + upon the :attr:`.Event.BEFORE_LOSS` event and then restore the original unsmoothed labels on the + :attr:`.Event.AFTER_LOSS` event. + + Events with the pattern ``_start`` or ``_end`` will not be affected. + """ + if event.name.startswith('AFTER_'): + return list(reversed(algorithms)) + + return algorithms + + +def warn_if_multiple_loss_interpolation(algorithms: Sequence[Algorithm], event: Event) -> Sequence[Algorithm]: + """Multiple algorithms that interpolate the loss may have unexpected behavior.""" + from composer.algorithms.warnings import NotIntendedUseWarning + + is_interpolate = [a for a in algorithms if hasattr(a, 'interpolate_loss') and a.interpolate_loss] # type: ignore + if len(is_interpolate) > 1: + warnings.warn( + NotIntendedUseWarning( + f'Multiple algorithms interpolating the loss can lead to unexpected behavior: {is_interpolate}')) + + return algorithms diff --git a/tests/test_engine.py b/tests/test_engine.py index 9c737ee907..38942a1b99 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -6,13 +6,12 @@ import subprocess import sys import textwrap -from typing import List, Sequence +from typing import List from unittest.mock import Mock import pytest import composer -from composer.algorithms import SelectiveBackprop from composer.core import Engine, Event from composer.core.algorithm import Algorithm from composer.core.callback import Callback @@ -27,6 +26,7 @@ def always_match_algorithms(): Mock(**{ 'match.return.value': True, 'apply.return_value': n, # return encodes order + 'interpolate_loss': False, }) for n in range(5) ] @@ -81,55 +81,6 @@ def test_engine_trace_never(self, event: Event, dummy_state: State, never_match_ assert all([tr.run is False for tr in trace.values()]) -@pytest.mark.parametrize('event', [ - Event.EPOCH_START, - Event.BEFORE_LOSS, - Event.BEFORE_BACKWARD, -]) -def test_engine_lifo_first_in(event: Event, dummy_state: State, dummy_logger: Logger, - always_match_algorithms: List[Algorithm]): - dummy_state.algorithms = always_match_algorithms - trace = run_event(event, dummy_state, dummy_logger) - order = [tr.order for tr in trace.values()] - expected_order = [tr.exit_code for tr in trace.values()] # use exit_code to uniquely label algos - - assert order == expected_order - - -@pytest.mark.parametrize('event', [ - Event.AFTER_LOSS, - Event.AFTER_BACKWARD, - Event.BATCH_END, -]) -def test_engine_lifo_last_out(event: Event, dummy_state: State, always_match_algorithms: List[Algorithm], - dummy_logger: Logger): - dummy_state.algorithms = always_match_algorithms - trace = run_event(event, dummy_state, dummy_logger) - order = [tr.order for tr in trace.values()] - expected_order = list(reversed([tr.exit_code for tr in trace.values()])) - - assert order == expected_order - - -def test_engine_with_selective_backprop(always_match_algorithms: Sequence[Algorithm], dummy_logger: Logger, - dummy_state: State): - sb = SelectiveBackprop(start=0.5, end=0.9, keep=0.5, scale_factor=0.5, interrupt=2) - sb.apply = Mock(return_value='sb') - sb.match = Mock(return_value=True) - - event = Event.INIT # doesn't matter for this test - - algorithms = list(always_match_algorithms[0:2]) + [sb] + list(always_match_algorithms[2:]) - dummy_state.algorithms = algorithms - - trace = run_event(event, dummy_state, dummy_logger) - - expected = ['sb', 0, 1, 2, 3, 4] - actual = [tr.exit_code for tr in trace.values()] - - assert actual == expected - - def test_engine_is_dead_after_close(dummy_state: State, dummy_logger: Logger): # Create the trainer and run an event engine = Engine(dummy_state, dummy_logger) diff --git a/tests/test_passes.py b/tests/test_passes.py new file mode 100644 index 0000000000..728a4111a1 --- /dev/null +++ b/tests/test_passes.py @@ -0,0 +1,129 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Type +from unittest.mock import Mock + +import pytest + +from composer import Algorithm, Engine, Event, Logger, State +from composer.algorithms import FusedLayerNorm, SelectiveBackprop +from composer.core.passes import sort_to_back, sort_to_front + +from .test_engine import run_event + + +@pytest.fixture +def always_match_algorithms(): + return [ + Mock(**{ + 'match.return.value': True, + 'apply.return_value': n, # return encodes order + 'interpolate_loss': False, + }) for n in range(5) + ] + + +@pytest.fixture() +def dummy_logger(dummy_state: State): + return Logger(dummy_state) + + +def test_register_pass(dummy_state, dummy_logger): + + dummy_algorithm = Mock() + dummy_algorithm.match.return_value = True + dummy_algorithm.apply.return_value = 'dummy' + + def insert_dummy_algorithm(algorithms, event): + algorithms.append(dummy_algorithm) + return algorithms + + engine = Engine(dummy_state, dummy_logger) + engine.register_pass(insert_dummy_algorithm) + + trace = engine.run_event(Event.INIT) + + assert 'dummy' in [tr.exit_code for tr in trace.values()] + + +class TestLIFOPass: + + @pytest.mark.parametrize('event', [ + Event.BEFORE_LOSS, + Event.BEFORE_BACKWARD, + ]) + def test_lifo_first_in(self, event: Event, dummy_state: State, dummy_logger: Logger, + always_match_algorithms: List[Algorithm]): + dummy_state.algorithms = always_match_algorithms + trace = run_event(event, dummy_state, dummy_logger) + order = [tr.order for tr in trace.values()] + expected_order = [tr.exit_code for tr in trace.values()] # use exit_code to uniquely label algos + + assert order == expected_order + + @pytest.mark.parametrize('event', [ + Event.AFTER_LOSS, + Event.AFTER_BACKWARD, + ]) + def test_lifo_last_out(self, event: Event, dummy_state: State, always_match_algorithms: List[Algorithm], + dummy_logger: Logger): + dummy_state.algorithms = always_match_algorithms + trace = run_event(event, dummy_state, dummy_logger) + order = [tr.order for tr in trace.values()] + expected_order = list(reversed([tr.exit_code for tr in trace.values()])) + + assert order == expected_order + + +class TestAlgorithmOrderingPasses: + + @pytest.mark.parametrize('algorithm_cls', [FusedLayerNorm]) + def test_algorithm_last(self, algorithm_cls: Type[Algorithm], always_match_algorithms: List[Algorithm], + dummy_logger: Logger, dummy_state: State): + + if algorithm_cls == FusedLayerNorm: + pytest.importorskip('apex') + + algorithm = algorithm_cls() + algorithm.apply = Mock(return_value='algo') + algorithm.match = Mock(return_value=True) + + algortihms = always_match_algorithms[0:2] + [algorithm] + always_match_algorithms[2:] + dummy_state._algorithms = algortihms + + trace = run_event(Event.INIT, dummy_state, dummy_logger) + + expected = [0, 1, 2, 3, 4, 'algo'] + actual = [tr.exit_code for tr in trace.values()] + + assert actual == expected + + @pytest.mark.parametrize('algorithm_cls', [SelectiveBackprop]) + def test_algorithm_first(self, algorithm_cls: Type[Algorithm], always_match_algorithms: List[Algorithm], + dummy_logger: Logger, dummy_state: State): + + algorithm = algorithm_cls() + algorithm.apply = Mock(return_value='algo') + algorithm.match = Mock(return_value=True) + + algortihms = always_match_algorithms[0:2] + [algorithm] + always_match_algorithms[2:] + dummy_state._algorithms = algortihms + + trace = run_event(Event.INIT, dummy_state, dummy_logger) + + expected = ['algo', 0, 1, 2, 3, 4] + actual = [tr.exit_code for tr in trace.values()] + + assert actual == expected + + +class TestSortHelpers: + + def test_sort_to_back(self): + lst = [1, 'a', 'c', 2, 3.0] + assert sort_to_back(lst, int) == ['a', 'c', 3.0, 1, 2] + + def test_sort_to_front(self): + lst = [1, 'a', 'c', 2, 3.0] + assert sort_to_front(lst, int) == [1, 2, 'a', 'c', 3.0]