Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to register algorithm passes #1377

Merged
merged 4 commits into from
Aug 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,7 @@ package-lock.json

# pycharm
.idea/

# composer
data/
datasets/
177 changes: 78 additions & 99 deletions composer/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,34 @@
:class:`.SelectiveBackprop` algorithm runs on the :attr:`.Event.AFTER_DATALOADER` event and must run before
any data augmentations. :class:`.Engine` runs re-ordering passes to resolve such ordering issues or conflicts.

.. note::
These orderings are enforced by algorithm passes. The default passes registered to the Engine are found in
:mod:`composer.core.passes`. To register a new pass, use :meth:`.Engine.register_pass`, e.g.

* An instance of :class:`.Engine` is automatically constructed by the :class:`.Trainer`
constructor. A user need not instantiate the :class:`.Engine` class.

* The design of :class:`.Engine` is subject to change in future releases
to accommodate more complexity as we investigate composition of algorithms.
.. testsetup::

# dummy algorithm
MyAlgorithm = None

Currently, the following passes are registered:
.. doctest::

* **LIFO order for events**
from composer import Engine, Algorithm, Event
from typing import Sequence

For the events that follow the ``before_*`` (e.g., :attr:`.Event.BEFORE_LOSS`) and ``after_*`` (e.g.,
:attr:`.Event.AFTER_LOSS`) pattern, the ordering of algorithms is reversed for the ``after_*`` events. For example,
four given algorithms ``A``, ``B``, ``C``, and ``D`` will run in ``ABCD`` ordering on the ``before_*`` event while
``DCBA`` ordering on the ``after_*`` event.
def run_last(algorithms: Sequence[Algorithm], event: Event) -> Sequence[Algorithm]:
algorithms = sorted(algorithms, key=lambda x: isinstance(x, MyAlgorithm))

This allows algorithms to "clean up" their changes. For example, :class:`.LabelSmoothing` will smooth the labels
upon the :attr:`.Event.BEFORE_LOSS` event and then restore the original unsmoothed labels on the
:attr:`.Event.AFTER_LOSS` event.
Engine.register_pass(run_last)

* **Run Selective Backprop first**
.. note::

* An instance of :class:`.Engine` is automatically constructed by the :class:`.Trainer`
constructor. A user need not instantiate the :class:`.Engine` class.

.. note::
* The design of :class:`.Engine` is subject to change in future releases
to accommodate more complexity as we investigate composition of algorithms.

:class:`.SelectiveBackprop` runs after the dataloader returns the batch and executes an extra forward pass to rank
and prune the examples in the batch by loss. To ensure a clean estimate of loss, :class:`.SelectiveBackprop` should
run before any other data augmentations (e.g., :class:`.MixUp`) on the :attr:`.Event.AFTER_DATALOADER` event.

Trace
~~~~~
Expand Down Expand Up @@ -64,12 +65,12 @@
import atexit
import contextlib
import logging
import warnings
import weakref
from collections import OrderedDict
from dataclasses import dataclass
from typing import ContextManager, Dict, Optional, Sequence, Union, cast
from typing import ContextManager, Dict, List, Optional, Sequence, TypeVar, Union, cast

from composer.core import passes
from composer.core.algorithm import Algorithm
from composer.core.callback import Callback
from composer.core.event import Event
Expand All @@ -81,35 +82,15 @@

__all__ = ['Trace', 'Engine', 'Traces']

T = TypeVar('T')

_ALWAYS_RECORD_EVENTS = [Event.INIT, Event.FIT_START, Event.EPOCH_START, Event.EPOCH_END]

#: The default traces of an entire run is an OrderedDict.
#: The keys are of format ``<algorithm_name>/<event>`` (e.g., ``Blurpool/INIT``) and values are an instance of
#: :class:`Trace`.
Traces = Dict[str, 'Trace']

_ALWAYS_RECORD_EVENTS = [Event.INIT, Event.FIT_START, Event.EPOCH_START, Event.EPOCH_END]
_EVENTS_WHERE_DATALOADER_IS_SET = [e for e in Event if e != Event.INIT]
_EVENTS_WHERE_MAX_DURATION_IS_SET = [
Event.FIT_START,
Event.EPOCH_START,
Event.BATCH_START,
Event.AFTER_DATALOADER,
Event.BEFORE_TRAIN_BATCH,
Event.BEFORE_FORWARD,
Event.AFTER_FORWARD,
Event.BEFORE_LOSS,
Event.AFTER_LOSS,
Event.BEFORE_BACKWARD,
Event.AFTER_BACKWARD,
Event.AFTER_TRAIN_BATCH,
Event.BATCH_END,
Event.BATCH_CHECKPOINT,
Event.EPOCH_END,
Event.EPOCH_CHECKPOINT,
Event.FIT_END,
]
_EVAL_EVENTS = [e for e in Event if e.name.startswith('EVAL_')]
_PREDICT_EVENTS = [e for e in Event if e.name.startswith('PREDICT_')]

# Track whether atexit triggered _close(), which indicates whether the python process is shutting down
# If so, do not run close() again via __del__(), as Python machinery (e.g. the ability to do conditional
# imports) are destroyed between close() and __del__().
Expand All @@ -127,6 +108,15 @@ def _set_atexit_ran():
atexit.register(_set_atexit_ran)


def _get_default_passes():
return [
passes.sort_selective_backprop_first,
passes.sort_fused_layernorm_last,
passes.set_filo_order,
passes.warn_if_multiple_loss_interpolation,
]
hanlint marked this conversation as resolved.
Show resolved Hide resolved


@dataclass
class Trace():
"""Record of an algorithm's execution.
Expand Down Expand Up @@ -173,6 +163,9 @@ def __init__(self, state: State, logger: Logger):
self.logger = logger
self.state = state
self._is_closed = False

self.algorithm_passes: List[passes.AlgorithmPass] = _get_default_passes()
hanlint marked this conversation as resolved.
Show resolved Hide resolved

atexit.register(self._close, state, logger)

def run_event(
Expand Down Expand Up @@ -233,11 +226,7 @@ def run_event(
if event.is_after_event and duration_marker is not None:
duration_marker.finish()

if event in _EVENTS_WHERE_DATALOADER_IS_SET:
assert self.state.dataloader is not None, f'The trainer should have set state.dataloader for event {event}.'

if event in _EVENTS_WHERE_MAX_DURATION_IS_SET:
assert self.state.max_duration is not None, f'The trainer should have set state.max_duration for event {event}.'
self._assert_dataloader_and_duration_set(self.state, event)

if event == Event.INIT:
# For the INIT event, run the callbacks first to initialize the loggers
Expand All @@ -254,13 +243,37 @@ def run_event(

return traces

def register_pass(self, algorithm_pass: passes.AlgorithmPass, index: int = -1):
hanlint marked this conversation as resolved.
Show resolved Hide resolved
"""Registers an algorithm pass with the Engine.

Args:
algorithm_pass (passes.AlgorithmPass): A method that maps a list of
algorithms to a list of algorithms.
index (int, optional): The index to insert into the list of passes.
If -1 (default), the pass will be insert to the end of the list.
"""
if index == -1:
index = len(self.algorithm_passes)

self.algorithm_passes.insert(index, algorithm_pass)

@staticmethod
def _assert_dataloader_and_duration_set(state: State, event: Event):
# correctness checks that dataloader and max duration need to be set for certain events

if event != Event.INIT: # datalaoder should be set on all events expect INIT
assert state.dataloader is not None, f'The trainer should have set state.dataloader for event {event}.'

if event != Event.INIT and not event.is_predict and not event.is_eval:
assert state.max_duration is not None, f'The trainer should have set state.max_duration for event {event}.'

def _run_algorithms(
self,
event: Event,
) -> Traces:
algorithms_to_run = [algo for algo in self.state.algorithms if algo.match(event, self.state)]

# future collision resolution
# apply algorithm passes
algorithms_to_run = self._compile(algorithms_to_run, event)

trace = _setup_trace(algorithms_to_run, event)
Expand Down Expand Up @@ -321,30 +334,11 @@ def _compile(
Returns:
Sequence[Algorithm]: Modified sequence of algorithms.
"""
from composer.algorithms import CutMix, FusedLayerNorm, MixUp, SelectiveBackprop, StochasticDepth

# Move selective backprop to the beginning while maintaining order of other algorithms
algorithms = sorted(algorithms_to_run,
key=lambda x: not isinstance(x, SelectiveBackprop) and not isinstance(x, StochasticDepth))

# Move fused layernorm to the end while maintaining order of other algorithms (FLN only does surgery on leaf modules)
algorithms = sorted(algorithms, key=lambda x: isinstance(x, FusedLayerNorm))

# Check for multiple algorithms that try to interpolate the loss at the same time
interpolation_settings = [a.interpolate_loss for a in algorithms if isinstance(a, (CutMix, MixUp))]
if sum(interpolation_settings) > 1:
warnings.warn(
'Multiple algorithms are trying to interpolate the loss. This can result in strange behavior.')

if event.is_after_event:
"""Establish a FILO queue of algorithms ``before_`` and ``after_`` an event.
# run reordering passes on the algorithms
for passes in self.algorithm_passes:
algorithms_to_run = passes(algorithms_to_run, event)

before_loss: A, B, C, D
after_loss: D, C, B, A
"""
algorithms = list(reversed(algorithms))

return algorithms
return algorithms_to_run

def _run_callbacks(
self,
Expand Down Expand Up @@ -392,32 +386,17 @@ def __del__(self):

def _debug_log(self, event: Event, msg: str):
"""Helper to include timestamp and event info in log messages."""
if event in _EVAL_EVENTS:
log.debug(
'[ep=%i][ba=%i][eval_ba=%i][event=%s]: %s',
int(self.state.timestamp.epoch),
int(self.state.timestamp.batch),
int(self.state.eval_timestamp.batch),
event.name,
msg,
)
elif event in _PREDICT_EVENTS:
log.debug(
'[ep=%i][ba=%i][predict_ba=%i][event=%s]: %s',
int(self.state.timestamp.epoch),
int(self.state.timestamp.batch),
int(self.state.predict_timestamp.batch),
event.name,
msg,
)
else:
log.debug(
'[ep=%i][ba=%i][event=%s]: %s',
int(self.state.timestamp.epoch),
int(self.state.timestamp.batch),
event.name,
msg,
)
timestamp = f'[ep={int(self.state.timestamp.epoch)}][ba={int(self.state.timestamp.batch)}]'

# for eval or pr
if event.is_eval:
timestamp += f'[eval_ba={int(self.state.eval_timestamp.batch)}]'
if event.is_predict:
timestamp += f'[predict_ba={int(self.state.predict_timestamp.batch)}]'

timestamp += f'[event={event.name}]'

log.debug(f'{timestamp}: {msg}')

def close(self) -> None:
"""Shutdown the engine.
Expand Down
10 changes: 10 additions & 0 deletions composer/core/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,16 @@ def canonical_name(self) -> str:
name = name.replace('_end', '')
return name

@property
def is_predict(self) -> bool:
"""Whether the event is during the predict loop."""
return self.value.startswith('predict')

@property
def is_eval(self) -> bool:
"""Whether the event is during the eval loop."""
return self.value.startswith('eval')


_BEFORE_EVENTS = (Event.FIT_START, Event.EPOCH_START, Event.BATCH_START, Event.BEFORE_TRAIN_BATCH, Event.BEFORE_FORWARD,
Event.BEFORE_LOSS, Event.BEFORE_BACKWARD, Event.EVAL_START, Event.EVAL_BATCH_START,
Expand Down
Loading