diff --git a/CHANGELOG.md b/CHANGELOG.md index 97a91fc881..4572d25220 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed a bug where `pretraining_transforms` in the `ImageEmbedder` was never called. ([1196](https://github.com/PyTorchLightning/lightning-flash/pull/1196)) + +- Fixed a bug where `BASE_MODEL_NAME` was not in the dict for dino and moco strategies. ([1196](https://github.com/PyTorchLightning/lightning-flash/pull/1196)) + - Fixed normalizing inputs to video classification ([#1213](https://github.com/PyTorchLightning/lightning-flash/pull/1213)) - Fixed examples (question answering), where NLTK's `punkt` module needs to be downloaded first. ([#1215](https://github.com/PyTorchLightning/lightning-flash/pull/1215/files)) diff --git a/flash/core/data/io/input_transform.py b/flash/core/data/io/input_transform.py index c4aa3e7588..62bbce2258 100644 --- a/flash/core/data/io/input_transform.py +++ b/flash/core/data/io/input_transform.py @@ -1051,6 +1051,9 @@ def create_transform( if inspect.isclass(transform) and issubclass(transform, InputTransform): return transform(running_stage=running_stage, **transform_kwargs) + if isinstance(transform, partial) and transform.func.__name__ == "LambdaInputTransform": + return transform(running_stage=running_stage, **transform_kwargs) + if isinstance(transform, Callable): return LambdaInputTransform( running_stage=running_stage, diff --git a/flash/image/embedding/heads/vissl_heads.py b/flash/image/embedding/heads/vissl_heads.py index 5fb7817f67..cb2798fe38 100644 --- a/flash/image/embedding/heads/vissl_heads.py +++ b/flash/image/embedding/heads/vissl_heads.py @@ -89,7 +89,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def simclr_head( - dims: List[int] = [2048, 2048, 128], + dims: List[int] = [2048, 2048, 256], use_bn: bool = True, **kwargs, ) -> nn.Module: @@ -141,7 +141,7 @@ def swav_head( def barlow_twins_head(**kwargs) -> nn.Module: - return simclr_head(dims=[2048, 8192, 8192, 8192], **kwargs) + return simclr_head(**kwargs) def moco_head(**kwargs) -> nn.Module: @@ -150,7 +150,6 @@ def moco_head(**kwargs) -> nn.Module: def dino_head(**kwargs) -> nn.Module: return swav_head( - dims=[384, 2048, 2048, 256], use_bn=False, return_embeddings=False, activation_name="GELU", diff --git a/flash/image/embedding/losses/vissl_losses.py b/flash/image/embedding/losses/vissl_losses.py index 87dcf5260c..b1ba8f936b 100644 --- a/flash/image/embedding/losses/vissl_losses.py +++ b/flash/image/embedding/losses/vissl_losses.py @@ -18,6 +18,7 @@ if _VISSL_AVAILABLE: import vissl.losses # noqa: F401 + from classy_vision.generic.distributed_util import set_cpu_device from classy_vision.losses import ClassyLoss, LOSS_REGISTRY from vissl.config.attr_dict import AttrDict else: @@ -26,6 +27,7 @@ def get_loss_fn(loss_name: str, cfg: AttrDict): + set_cpu_device() loss_fn = LOSS_REGISTRY[loss_name](cfg) loss_fn.__dict__["loss_name"] = loss_name @@ -79,6 +81,7 @@ def swav_loss( queue_length: int = 0, start_iter: int = 0, local_queue_length: int = 0, + **kwargs, ) -> ClassyLoss: loss_name = "swav_loss" cfg = AttrDict( @@ -108,7 +111,10 @@ def swav_loss( def barlow_twins_loss( - lambda_: float = 0.0051, scale_loss: float = 0.024, latent_embedding_dim: int = 8192 + lambda_: float = 0.0051, + scale_loss: float = 0.024, + latent_embedding_dim: int = 8192, + **kwargs, ) -> ClassyLoss: loss_name = "barlow_twins_loss" cfg = AttrDict( @@ -127,6 +133,7 @@ def simclr_loss( embedding_dim: int = 128, effective_batch_size: int = 1, # set by setup training hook world_size: int = 1, # set by setup training hook + **kwargs, ) -> ClassyLoss: loss_name = "simclr_info_nce_loss" cfg = AttrDict( @@ -151,6 +158,7 @@ def moco_loss( momentum: float = 0.999, temperature: int = 0.2, shuffle_batch: bool = True, + **kwargs, ) -> ClassyLoss: loss_name = "moco_loss" cfg = AttrDict( diff --git a/flash/image/embedding/model.py b/flash/image/embedding/model.py index 521f95e28f..54d01ca043 100644 --- a/flash/image/embedding/model.py +++ b/flash/image/embedding/model.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import warnings +from functools import partial from typing import Any, Dict, List, Optional from flash.core.adapter import AdapterTask from flash.core.data.io.input import DataKeys +from flash.core.data.io.input_transform import LambdaInputTransform from flash.core.data.transforms import ApplyToKeys from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _VISSL_AVAILABLE, requires @@ -111,13 +113,17 @@ def __init__( ) input_transform, self.collate_fn = self.transforms.get(pretraining_transform)(**pretraining_transform_kwargs) - self.input_transform = ApplyToKeys(DataKeys.INPUT, input_transform) + output = ApplyToKeys(DataKeys.INPUT, input_transform) + self.input_transform = partial(LambdaInputTransform, transform=output) warnings.warn( "Warning: VISSL ImageEmbedder overrides any user provided transforms" " with pre-defined transforms for the training strategy." ) + def on_epoch_start(self) -> None: + self.adapter.on_epoch_start() + def on_train_start(self) -> None: self.adapter.on_train_start() diff --git a/flash/image/embedding/vissl/adapter.py b/flash/image/embedding/vissl/adapter.py index af48c4d2f2..119db01974 100644 --- a/flash/image/embedding/vissl/adapter.py +++ b/flash/image/embedding/vissl/adapter.py @@ -122,10 +122,18 @@ def from_task( return result + def on_epoch_start(self) -> None: + use_gpu = self.adapter_task.device != torch.device("cpu") and self.adapter_task.device != "cpu" + if hasattr(self.loss_fn, "info_criterion"): + self.loss_fn.info_criterion.use_gpu = use_gpu + if hasattr(self.loss_fn, "swav_criterion"): + self.loss_fn.swav_criterion.use_gpu = use_gpu + @staticmethod def get_model_config_template(): cfg = AttrDict( { + "BASE_MODEL_NAME": "multi_input_output_model", "SINGLE_PASS_EVERY_CROP": False, "INPUT_TYPE": "rgb", "MULTI_INPUT_HEAD_MAPPING": [], diff --git a/tests/image/embedding/test_model.py b/tests/image/embedding/test_model.py index f17287dafd..cb042e0053 100644 --- a/tests/image/embedding/test_model.py +++ b/tests/image/embedding/test_model.py @@ -51,25 +51,40 @@ def test_load_from_checkpoint_dependency_error(): @pytest.mark.skipif(not (_TORCHVISION_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.") -@pytest.mark.parametrize("backbone, training_strategy", [("resnet", "barlow_twins")]) -def test_vissl_training(tmpdir, backbone, training_strategy): +@pytest.mark.parametrize( + "backbone, training_strategy, head, pretraining_transform", + [ + ("vision_transformer", "simclr", "simclr_head", "simclr_transform"), + pytest.param( + "vision_transformer", + "dino", + "dino_head", + "dino_transform", + marks=pytest.mark.skipif(torch.cuda.device_count() < 1, reason="VISSL DINO calls all_reduce internally."), + ), + ("vision_transformer", "barlow_twins", "simclr_head", "barlow_twins_transform"), + ("vision_transformer", "swav", "swav_head", "swav_transform"), + ], +) +def test_vissl_training(backbone, training_strategy, head, pretraining_transform): + # moco strategy, transform and head is not added for this test as it doesn't work as of now. datamodule = ImageClassificationData.from_datasets( train_dataset=FakeData(), batch_size=4, ) + training_strategy_kwargs = { + "dims": [384, 2048, 2048, 256], + } + dim_key = "latent_embedding_dim" if training_strategy == "barlow_twins" else "embedding_dim" + training_strategy_kwargs[dim_key] = 256 + embedder = ImageEmbedder( backbone=backbone, training_strategy=training_strategy, - head="simclr_head", - pretraining_transform="barlow_twins_transform", - training_strategy_kwargs={"latent_embedding_dim": 128}, - pretraining_transform_kwargs={ - "total_num_crops": 2, - "num_crops": [2], - "size_crops": [96], - "crop_scales": [[0.4, 1]], - }, + head=head, + pretraining_transform=pretraining_transform, + training_strategy_kwargs=training_strategy_kwargs, ) trainer = flash.Trainer(max_steps=3, max_epochs=1, gpus=torch.cuda.device_count())