Skip to content

Commit

Permalink
Add Torch ORT Callback, re-write callback section (#720)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean Naren authored Aug 31, 2021
1 parent b8a6afa commit 7c0fcdd
Show file tree
Hide file tree
Showing 15 changed files with 341 additions and 50 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Advantage Actor-Critic (A2C) Model [#598](https://github.com/PyTorchLightning/lightning-bolts/pull/598))


- Added Torch ORT Callback [#720](https://github.com/PyTorchLightning/lightning-bolts/pull/720))


### Changed

- Changed the default values `pin_memory=False`, `shuffle=False` and `num_workers=16` to `pin_memory=True`, `shuffle=True` and `num_workers=0` of datamodules ([#701](https://github.com/PyTorchLightning/lightning-bolts/pull/701))
Expand Down
40 changes: 0 additions & 40 deletions docs/source/callbacks.rst

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
.. role:: hidden
:class: hidden-section

Info Callbacks
==============
Monitoring Callbacks
====================

These callbacks give all sorts of useful information during training.

Expand Down
File renamed without changes.
35 changes: 35 additions & 0 deletions docs/source/callbacks/torch_ort.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
==================
Torch ORT Callback
==================

`Torch ORT <https://cloudblogs.microsoft.com/opensource/2021/07/13/accelerate-pytorch-training-with-torch-ort/>`__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. See installation instructions `here <https://github.com/pytorch/ort#install-in-a-local-python-environment>`__.

This is primarily useful for when training with a Transformer model. The ORT callback works when a single model is specified as `self.model` within the ``LightningModule`` as shown below.

.. note::

Not all Transformer models are supported. See `this table <https://github.com/microsoft/onnxruntime-training-examples#examples>`__ for supported models + branches containing fixes for certain models.

.. code-block:: python
from pytorch_lightning import LightningModule, Trainer
from transformers import AutoModel
from pl_bolts.callbacks import ORTCallback
class MyTransformerModel(LightningModule):
def __init__(self):
super().__init__()
self.model = AutoModel.from_pretrained('bert-base-cased')
...
model = MyTransformerModel()
trainer = Trainer(gpus=1, callbacks=ORTCallback())
trainer.fit(model)
For even easier setup and integration, have a look at our Lightning Flash integration for :ref:`Text Classification <lightning_flash:text_classification_ort>`, :ref:`Translation <lightning_flash:translation_ort>` and :ref:`Summarization <lightning_flash:summarization_ort>`.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Interpolates latent dims.

Example output:

.. image:: _images/gans/basic_gan_interpolate.jpg
.. image:: ../_images/gans/basic_gan_interpolate.jpg
:width: 400
:alt: Example latent space interpolation

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Shows how the input would have to change to move the prediction from one logit t

Example outputs:

.. image:: _images/vision/confused_logit.png
.. image:: ../_images/vision/confused_logit.png
:width: 400
:alt: Example of prediction confused between 5 and 8

Expand Down
3 changes: 2 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
"pytorch_lightning": ("https://pytorch-lightning.readthedocs.io/en/stable/", None),
"lightning_flash": ("https://lightning-flash.readthedocs.io/en/latest/", None),
"python": ("https://docs.python.org/3", None),
"torch": ("https://pytorch.org/docs/stable/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
Expand Down Expand Up @@ -385,7 +386,7 @@ def find_source():
# This value determines the text for the permalink; it defaults to "¶". Set it to None or the empty
# string to disable permalinks.
# https://www.sphinx-doc.org/en/master/usage/configuration.html#confval-html_add_permalinks
html_add_permalinks = "¶"
html_permalinks_icon = "¶"

# True to prefix each section label with the name of the document it is in, followed by a colon.
# For example, index:Introduction for a section called Introduction that appears in document index.rst.
Expand Down
10 changes: 5 additions & 5 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ Lightning-Bolts documentation
:name: callbacks
:caption: Callbacks

callbacks
info_callbacks
self_supervised_callbacks
variational_callbacks
vision_callbacks
callbacks/monitor_callbacks
callbacks/self_supervised_callbacks
callbacks/variational_callbacks
callbacks/vision_callbacks
callbacks/torch_ort

.. toctree::
:maxdepth: 2
Expand Down
2 changes: 2 additions & 0 deletions pl_bolts/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pl_bolts.callbacks.data_monitor import ModuleDataMonitor, TrainingDataMonitor
from pl_bolts.callbacks.printing import PrintTableMetricsCallback
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
from pl_bolts.callbacks.torch_ort import ORTCallback
from pl_bolts.callbacks.variational import LatentDimInterpolator
from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerificationCallback # type: ignore
from pl_bolts.callbacks.vision.confused_logit import ConfusedLogitCallback
Expand All @@ -18,4 +19,5 @@
"LatentDimInterpolator",
"ConfusedLogitCallback",
"TensorboardGenerativeModelImageSampler",
"ORTCallback",
]
51 changes: 51 additions & 0 deletions pl_bolts/callbacks/torch_ort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from pl_bolts.utils import _TORCH_ORT_AVAILABLE

if _TORCH_ORT_AVAILABLE:
from torch_ort import ORTModule


class ORTCallback(Callback):
"""Enables Torch ORT: Accelerate PyTorch models with ONNX Runtime.
Wraps a model with the ORT wrapper, lazily converting your module into an ONNX export, to optimize for
training and inference.
Usage:
# via Transformer Tasks
model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True)
# or via the trainer
trainer = flash.Trainer(callbacks=ORTCallback())
"""

def __init__(self):
if not _TORCH_ORT_AVAILABLE:
raise MisconfigurationException(
"Torch ORT is required to use ORT. See here for installation: https://github.com/pytorch/ort"
)

def on_before_accelerator_backend_setup(self, trainer: Trainer, pl_module: LightningModule) -> None:
if not hasattr(pl_module, "model"):
raise MisconfigurationException(
"Torch ORT requires to wrap a single model that defines a forward function "
"assigned as `model` inside the `LightningModule`."
)
if not isinstance(pl_module.model, ORTModule):
pl_module.model = ORTModule(pl_module.model)
1 change: 1 addition & 0 deletions pl_bolts/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,6 @@ def _compare_version(package: str, op, version) -> bool:
_MATPLOTLIB_AVAILABLE: bool = _module_available("matplotlib")
_TORCHVISION_LESS_THAN_0_9_1: bool = _compare_version("torchvision", operator.lt, "0.9.1")
_PL_GREATER_EQUAL_1_4 = _compare_version("pytorch_lightning", operator.ge, "1.4.0")
_TORCH_ORT_AVAILABLE = _module_available("torch_ort")

__all__ = ["BatchGradientVerification"]
55 changes: 55 additions & 0 deletions tests/callbacks/test_ort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from pl_bolts.callbacks import ORTCallback
from pl_bolts.utils import _TORCH_ORT_AVAILABLE
from tests.helpers.boring_model import BoringModel

if _TORCH_ORT_AVAILABLE:
from torch_ort import ORTModule


@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.")
def test_init_train_enable_ort(tmpdir):
class TestCallback(Callback):
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
assert isinstance(pl_module.model, ORTModule)

class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.model = self.layer

def forward(self, x):
return self.model(x)

model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=[ORTCallback(), TestCallback()])
trainer.fit(model)
trainer.test(model)


@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.")
def test_ort_callback_fails_no_model(tmpdir):
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=ORTCallback())
with pytest.raises(MisconfigurationException, match="Torch ORT requires to wrap a single model"):
trainer.fit(
model,
)
Empty file added tests/helpers/__init__.py
Empty file.
Loading

0 comments on commit 7c0fcdd

Please sign in to comment.