Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix VISSL on GPU and add VISSL GPU CI (#1256)
Browse files Browse the repository at this point in the history
Co-authored-by: Kushashwa Ravi Shrimali <kushashwaravishrimali@gmail.com>
  • Loading branch information
ethanwharris and krshrimali authored Mar 30, 2022
1 parent 2a09ce0 commit df67f87
Show file tree
Hide file tree
Showing 12 changed files with 111 additions and 34 deletions.
3 changes: 2 additions & 1 deletion .azure-pipelines/gpu-example-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ jobs:
parameters:
configs:
- "image"
- "image,image_extras"
- "icevision"
- "vissl"
- "text"
- "tabular"
- "video"
Expand Down
13 changes: 10 additions & 3 deletions .azure-pipelines/testing-template.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
- bash: |
# python -m pip install "pip==20.1"
if [ "${{config}}" == "image,image_extras" ]; then pip install '.[image]' icevision effdet icedata; else pip install '.[${{config}}]'; fi
if [ "${{config}}" == "icevision" ]; then pip install '.[image]' icevision effdet icedata; elif [ "${{config}}" == "vissl" ]; then pip install '.[image]'; else pip install '.[${{config}}]'; fi
pip install '.[test]' --upgrade-strategy only-if-needed
pip list
displayName: 'Install dependencies'
Expand All @@ -46,11 +46,18 @@ jobs:
pip uninstall -y opencv-python-headless
pip install opencv-python-headless==4.5.5.64
displayName: 'Install OpenCV dependencies'
condition: eq('${{ config }}', 'image,image_extras')
condition: eq('${{ config }}', 'icevision')
- bash: |
pip install fairscale
pip install git+https://github.com/facebookresearch/ClassyVision.git
pip install git+https://github.com/facebookresearch/vissl.git
displayName: 'Install VISSL dependencies'
condition: eq('${{ config }}', 'vissl')
- bash: |
python -c "import torch; print(f'found GPUs: {torch.cuda.device_count()}')"
python -m coverage run --source flash -m pytest tests/examples/test_scripts.py -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=30
python -m coverage run --source flash -m pytest tests/examples/test_scripts.py tests/image/embedding/test_model.py -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=30
env:
CUDA_VISIBLE_DEVICES: ${{gids}}
FLASH_TEST_TOPIC: ${{ config }}
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed GPU support for self-supervised training with the `ImageEmbedder` ([#1256](https://github.com/PyTorchLightning/lightning-flash/pull/1256))

- Fixed a bug where collate functions were never called in the `ImageEmbedder` class. ([#1217](https://github.com/PyTorchLightning/lightning-flash/pull/1217))

- Fixed a bug where `pretraining_transforms` in the `ImageEmbedder` was never called. ([#1196](https://github.com/PyTorchLightning/lightning-flash/pull/1196))
Expand Down
35 changes: 26 additions & 9 deletions docs/source/reference/image_embedder.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_embedder.svg
:tags: Image,Embedding

.. warning::

Multi-gpu training is not currently supported by the :class:`~flash.image.embedding.model.ImageEmbedder` task.

.. _image_embedder:

##############
Expand All @@ -17,7 +21,9 @@ The Task
Image embedding encodes an image into a vector of features which can be used for a downstream task.
This could include: clustering, similarity search, or classification.

The Flash :class:`~flash.image.embedding.model.ImageEmbedder` can be trained with Self Supervised Learning (SSL) to improve the quality of the embeddings it produces for your data.
The :class:`~flash.image.embedding.model.ImageEmbedder` internally relies on `VISSL <https://vissl.ai/>`_.
You can read more about our integration with VISSL here: :ref:`vissl`.

------

Expand All @@ -26,18 +32,29 @@ Example
*******

Let's see how to configure a training strategy for the :class:`~flash.image.embedding.model.ImageEmbedder` task.
A vanilla :class:`~flash.core.data.data_module.DataModule` object be created using standard Datasets as shown below.
Then the user can configure the :class:`~flash.image.embedding.model.ImageEmbedder` task with ``training_strategy``, ``backbone``, ``head`` and ``pretraining_transform``.
There are options provided to send additional arguments to config selections.
This task can now be sent to the ``fit()`` method of :class:`~flash.core.trainer.Trainer`.

.. note::

A lot of VISSL loss functions use hard-coded ``torch.distributed`` methods. The user is suggested to use ``accelerator=ddp`` even with a single GPU.
Only ``barlow_twins`` training strategy works on the CPU. All other loss functions are configured to work on GPUs.
First we create an :class:`~flash.image.classification.data.ImageClassificationData` object using a `Dataset` from torchvision.
Next, we configure the :class:`~flash.image.embedding.model.ImageEmbedder` task with ``training_strategy``, ``backbone``, ``head`` and ``pretraining_transform``.
Finally, we construct a :class:`~flash.core.trainer.Trainer` and call ``fit()``.
Here's the full example:

.. literalinclude:: ../../../flash_examples/image_embedder.py
:language: python
:lines: 14-

To learn how to view the available backbones / heads for this task, see :ref:`backbones_heads`.
You can view the available training strategies with the :meth:`~flash.image.embedding.model.ImageEmbedder.available_training_strategies` method.

.. note::

The ``"dino"`` training strategy only supports single GPU training with ``strategy="ddp"``.

The ``head`` and ``pretraining_transform`` arguments should match the choice of ``training_strategy`` following this table:

===================== ===================== ==========================
``training_strategy`` ``head`` ``pretraining_transform``
===================== ===================== ==========================
``simclr`` ``simclr_head`` ``simclr_transform``
``barlow_twins`` ``barlow_twins_head`` ``barlow_twins_transform``
``swav`` ``swav_head`` ``swav_transform``
``dino`` ``dino_head`` ``dino_transform``
===================== ===================== ==========================
2 changes: 1 addition & 1 deletion flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def _import_module(self):
if "FLASH_TEST_TOPIC" in os.environ:
topic = os.environ["FLASH_TEST_TOPIC"]
_IMAGE_TESTING = topic == "image"
_IMAGE_EXTRAS_TESTING = topic == "image,image_extras"
_IMAGE_EXTRAS_TESTING = topic == "image,image_extras" or topic == "icevision" or topic == "vissl"
_VIDEO_TESTING = topic == "video"
_VIDEO_EXTRAS_TESTING = topic == "video,video_extras"
_TABULAR_TESTING = topic == "tabular"
Expand Down
31 changes: 23 additions & 8 deletions flash/image/embedding/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,24 @@
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _VISSL_AVAILABLE, requires
from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE
from flash.image.embedding.backbones import IMAGE_EMBEDDER_BACKBONES
from flash.image.embedding.strategies import IMAGE_EMBEDDER_STRATEGIES
from flash.image.embedding.transforms import IMAGE_EMBEDDER_TRANSFORMS

if _VISSL_AVAILABLE:
import classy_vision
import classy_vision.generic.distributed_util

from flash.image.embedding.backbones import IMAGE_EMBEDDER_BACKBONES
from flash.image.embedding.strategies import IMAGE_EMBEDDER_STRATEGIES
from flash.image.embedding.transforms import IMAGE_EMBEDDER_TRANSFORMS

# patch this to avoid classy vision/vissl based distributed training
classy_vision.generic.distributed_util.get_world_size = lambda: 1
else:
IMAGE_EMBEDDER_BACKBONES = FlashRegistry("backbones")
IMAGE_EMBEDDER_STRATEGIES = FlashRegistry("embedder_training_strategies")
IMAGE_EMBEDDER_TRANSFORMS = FlashRegistry("embedder_transforms")

# Skip doctests if requirements aren't available
__doctest_skip__ = []
if not _VISSL_AVAILABLE:
__doctest_skip__ += [
"ImageEmbedder",
"ImageEmbedder.*",
]


class ImageEmbedder(AdapterTask):
Expand Down Expand Up @@ -130,6 +133,18 @@ def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloade
@classmethod
@requires(["image", "vissl", "fairscale"])
def available_training_strategies(cls) -> List[str]:
"""Get the list of available training strategies (passed to the ``training_strategy`` argument) for this
task.
Examples
________
.. doctest::
>>> from flash.image import ImageEmbedder
>>> ImageEmbedder.available_training_strategies() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
['barlow_twins', ..., 'swav']
"""
registry: Optional[FlashRegistry] = getattr(cls, "training_strategies", None)
if registry is None:
return []
Expand Down
3 changes: 3 additions & 0 deletions flash/image/embedding/vissl/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def __init__(self, vissl_adapter, vissl_loss, task_config, vissl_model) -> None:
# set for momentum teacher based hooks
self.last_batch = AttrDict({"sample": AttrDict({"input": None, "data_momentum": None})})

# used in dino
self.additional_log_data = {}


class VISSLAdapter(Adapter, AdaptVISSLHooks):
"""The ``VISSLAdapter`` is an :class:`~flash.core.adapter.Adapter` for integrating with VISSL.
Expand Down
3 changes: 3 additions & 0 deletions flash/image/embedding/vissl/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ def on_start(self, task: "flash.image.embedding.vissl.adapter.MockVISSLTask") ->

task.loss.info_criterion.precompute_pos_neg_mask()

# Cast the loss to the correct device / dtype
task.loss.to(lightning_module.device, lightning_module.dtype)


class AdaptVISSLHooks(ModelHooks):
def __init__(self, hooks: List[ClassyHook], task) -> None:
Expand Down
13 changes: 6 additions & 7 deletions flash_examples/image_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,21 @@
# 1. Download the data and prepare the datamodule
datamodule = ImageClassificationData.from_datasets(
train_dataset=CIFAR10(".", download=True),
batch_size=16,
batch_size=4,
)

# 2. Build the task
embedder = ImageEmbedder(
backbone="resnet",
backbone="vision_transformer",
training_strategy="barlow_twins",
head="barlow_twins_head",
pretraining_transform="barlow_twins_transform",
training_strategy_kwargs={"latent_embedding_dim": 128},
pretraining_transform_kwargs={"size_crops": [196]},
pretraining_transform_kwargs={"size_crops": [32]},
)

# 3. Create the trainer and pre-train the encoder
# use accelerator='ddp' when using GPU(s),
# i.e. flash.Trainer(max_epochs=3, gpus=1, accelerator='ddp')
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.fit(embedder, datamodule=datamodule)

# 4. Save the model!
Expand All @@ -50,7 +48,8 @@
predict_files=[
"data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg",
"data/hymenoptera_data/predict/2039585088_c6f47c592e.jpg",
]
],
batch_size=3,
)
embeddings = trainer.predict(embedder, datamodule=datamodule)

Expand Down
13 changes: 11 additions & 2 deletions tests/examples/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@
_TABULAR_TESTING,
_TEXT_TESTING,
_VIDEO_TESTING,
_VISSL_AVAILABLE,
)
from tests.examples.utils import run_test
from tests.helpers.forked import forked
from tests.helpers.decorators import forked

root = Path(__file__).parent.parent.parent

Expand All @@ -56,6 +57,15 @@
"image_classification_multi_label.py",
marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"),
),
pytest.param(
"image_embedder.py",
marks=[
pytest.mark.skipif(
not (_IMAGE_AVAILABLE and _VISSL_AVAILABLE), reason="image libraries aren't installed"
),
pytest.mark.skipif(torch.cuda.device_count() > 1, reason="VISSL integration doesn't support multi-GPU"),
],
),
pytest.param(
"object_detection.py",
marks=pytest.mark.skipif(
Expand All @@ -74,7 +84,6 @@
not (_IMAGE_AVAILABLE and _ICEVISION_AVAILABLE), reason="image libraries aren't installed"
),
),
# pytest.param("finetuning", "object_detection.py"), # TODO: takes too long.
pytest.param(
"question_answering.py",
marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed"),
Expand Down
File renamed without changes.
27 changes: 24 additions & 3 deletions tests/image/embedding/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
import torch

import flash
from flash.core.utilities.imports import _IMAGE_AVAILABLE, _TORCHVISION_AVAILABLE, _VISSL_AVAILABLE
from flash.core.utilities.imports import (
_IMAGE_AVAILABLE,
_PL_GREATER_EQUAL_1_5_0,
_TORCHVISION_AVAILABLE,
_VISSL_AVAILABLE,
)
from flash.image import ImageClassificationData, ImageEmbedder

if _TORCHVISION_AVAILABLE:
Expand Down Expand Up @@ -50,6 +55,7 @@ def test_load_from_checkpoint_dependency_error():
ImageEmbedder.load_from_checkpoint("not_a_real_checkpoint.pt")


@pytest.mark.skipif(torch.cuda.device_count() > 1, reason="VISSL integration doesn't support multi-GPU")
@pytest.mark.skipif(not (_TORCHVISION_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.")
@pytest.mark.parametrize(
"backbone, training_strategy, head, pretraining_transform",
Expand All @@ -70,7 +76,7 @@ def test_vissl_training(backbone, training_strategy, head, pretraining_transform
# moco strategy, transform and head is not added for this test as it doesn't work as of now.
datamodule = ImageClassificationData.from_datasets(
train_dataset=FakeData(16),
predict_dataset=FakeData(4),
predict_dataset=FakeData(8),
batch_size=4,
)

Expand All @@ -81,7 +87,22 @@ def test_vissl_training(backbone, training_strategy, head, pretraining_transform
pretraining_transform=pretraining_transform,
)

trainer = flash.Trainer(max_steps=3, max_epochs=1, gpus=torch.cuda.device_count())
kwargs = {}

# DINO only works with DDP
if training_strategy == "dino":
if _PL_GREATER_EQUAL_1_5_0:
kwargs["strategy"] = "ddp"
else:
kwargs["accelerator"] = "ddp"

trainer = flash.Trainer(
max_steps=3,
max_epochs=1,
gpus=torch.cuda.device_count(),
**kwargs,
)

trainer.fit(embedder, datamodule=datamodule)
predictions = trainer.predict(embedder, datamodule=datamodule)
for prediction_batch in predictions:
Expand Down

0 comments on commit df67f87

Please sign in to comment.