Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix pretraining_transforms for ImageEmbedder (#1196)
Browse files Browse the repository at this point in the history
Co-authored-by: Ethan Harris <ethanwharris@gmail.com>
  • Loading branch information
krshrimali and ethanwharris committed Mar 30, 2022
1 parent 63c4841 commit 2f12b35
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 16 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed examples (question answering), where NLTK's `punkt` module needs to be downloaded first. ([#1215](https://github.com/PyTorchLightning/lightning-flash/pull/1215/files))
- Fixed normalizing inputs to video classification ([#1213](https://github.com/PyTorchLightning/lightning-flash/pull/1213))
- Fixed a bug where `pretraining_transforms` in the `ImageEmbedder` was never called. ([#1196](https://github.com/PyTorchLightning/lightning-flash/pull/1196))
- Fixed a bug where `BASE_MODEL_NAME` was not in the dict for dino and moco strategies. ([#1196](https://github.com/PyTorchLightning/lightning-flash/pull/1196))

### Removed

Expand Down
3 changes: 3 additions & 0 deletions flash/core/data/io/input_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,9 @@ def create_transform(
if inspect.isclass(transform) and issubclass(transform, InputTransform):
return transform(running_stage=running_stage, **transform_kwargs)

if isinstance(transform, partial) and transform.func.__name__ == "LambdaInputTransform":
return transform(running_stage=running_stage, **transform_kwargs)

if isinstance(transform, Callable):
return LambdaInputTransform(
running_stage=running_stage,
Expand Down
5 changes: 2 additions & 3 deletions flash/image/embedding/heads/vissl_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


def simclr_head(
dims: List[int] = [2048, 2048, 128],
dims: List[int] = [2048, 2048, 256],
use_bn: bool = True,
**kwargs,
) -> nn.Module:
Expand Down Expand Up @@ -141,7 +141,7 @@ def swav_head(


# Builds the Barlow Twins head by delegating to ``simclr_head``.
# NOTE(review): this is a diff rendering without +/- markers; presumably the
# first ``return`` (explicit ``dims``) is the removed line and the second
# (``simclr_head`` defaults plus caller kwargs) is its replacement — confirm
# against the commit.
def barlow_twins_head(**kwargs) -> nn.Module:
return simclr_head(dims=[2048, 8192, 8192, 8192], **kwargs)
return simclr_head(**kwargs)


def moco_head(**kwargs) -> nn.Module:
Expand All @@ -150,7 +150,6 @@ def moco_head(**kwargs) -> nn.Module:

def dino_head(**kwargs) -> nn.Module:
return swav_head(
dims=[384, 2048, 2048, 256],
use_bn=False,
return_embeddings=False,
activation_name="GELU",
Expand Down
10 changes: 9 additions & 1 deletion flash/image/embedding/losses/vissl_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

if _VISSL_AVAILABLE:
import vissl.losses # noqa: F401
from classy_vision.generic.distributed_util import set_cpu_device
from classy_vision.losses import ClassyLoss, LOSS_REGISTRY
from vissl.config.attr_dict import AttrDict
else:
Expand All @@ -26,6 +27,7 @@


def get_loss_fn(loss_name: str, cfg: AttrDict):
set_cpu_device()
loss_fn = LOSS_REGISTRY[loss_name](cfg)
loss_fn.__dict__["loss_name"] = loss_name

Expand Down Expand Up @@ -79,6 +81,7 @@ def swav_loss(
queue_length: int = 0,
start_iter: int = 0,
local_queue_length: int = 0,
**kwargs,
) -> ClassyLoss:
loss_name = "swav_loss"
cfg = AttrDict(
Expand Down Expand Up @@ -108,7 +111,10 @@ def swav_loss(


def barlow_twins_loss(
lambda_: float = 0.0051, scale_loss: float = 0.024, latent_embedding_dim: int = 8192
lambda_: float = 0.0051,
scale_loss: float = 0.024,
latent_embedding_dim: int = 8192,
**kwargs,
) -> ClassyLoss:
loss_name = "barlow_twins_loss"
cfg = AttrDict(
Expand All @@ -127,6 +133,7 @@ def simclr_loss(
embedding_dim: int = 128,
effective_batch_size: int = 1, # set by setup training hook
world_size: int = 1, # set by setup training hook
**kwargs,
) -> ClassyLoss:
loss_name = "simclr_info_nce_loss"
cfg = AttrDict(
Expand All @@ -151,6 +158,7 @@ def moco_loss(
momentum: float = 0.999,
temperature: int = 0.2,
shuffle_batch: bool = True,
**kwargs,
) -> ClassyLoss:
loss_name = "moco_loss"
cfg = AttrDict(
Expand Down
8 changes: 7 additions & 1 deletion flash/image/embedding/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from functools import partial
from typing import Any, Dict, List, Optional

from flash.core.adapter import AdapterTask
from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import LambdaInputTransform
from flash.core.data.transforms import ApplyToKeys
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _VISSL_AVAILABLE, requires
Expand Down Expand Up @@ -111,13 +113,17 @@ def __init__(
)

input_transform, self.collate_fn = self.transforms.get(pretraining_transform)(**pretraining_transform_kwargs)
self.input_transform = ApplyToKeys(DataKeys.INPUT, input_transform)
output = ApplyToKeys(DataKeys.INPUT, input_transform)
self.input_transform = partial(LambdaInputTransform, transform=output)

warnings.warn(
"Warning: VISSL ImageEmbedder overrides any user provided transforms"
" with pre-defined transforms for the training strategy."
)

# Hook that simply forwards to ``self.adapter.on_epoch_start()``; it takes no
# arguments besides ``self`` and returns nothing.
def on_epoch_start(self) -> None:
self.adapter.on_epoch_start()

# Hook that simply forwards to ``self.adapter.on_train_start()``; it takes no
# arguments besides ``self`` and returns nothing.
def on_train_start(self) -> None:
self.adapter.on_train_start()

Expand Down
8 changes: 8 additions & 0 deletions flash/image/embedding/vissl/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,18 @@ def from_task(

return result

# Sets the ``use_gpu`` flag on the loss function's criteria according to the
# device of ``self.adapter_task``.
def on_epoch_start(self) -> None:
# ``use_gpu`` is False only when the task device is the CPU; the device is
# compared against both ``torch.device("cpu")`` and the plain string "cpu".
use_gpu = self.adapter_task.device != torch.device("cpu") and self.adapter_task.device != "cpu"
# Only criteria that actually exist on the loss object are updated.
if hasattr(self.loss_fn, "info_criterion"):
self.loss_fn.info_criterion.use_gpu = use_gpu
if hasattr(self.loss_fn, "swav_criterion"):
self.loss_fn.swav_criterion.use_gpu = use_gpu

@staticmethod
def get_model_config_template():
cfg = AttrDict(
{
"BASE_MODEL_NAME": "multi_input_output_model",
"SINGLE_PASS_EVERY_CROP": False,
"INPUT_TYPE": "rgb",
"MULTI_INPUT_HEAD_MAPPING": [],
Expand Down
37 changes: 26 additions & 11 deletions tests/image/embedding/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,40 @@ def test_load_from_checkpoint_dependency_error():


@pytest.mark.skipif(not (_TORCHVISION_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.")
@pytest.mark.parametrize("backbone, training_strategy", [("resnet", "barlow_twins")])
def test_vissl_training(tmpdir, backbone, training_strategy):
@pytest.mark.parametrize(
"backbone, training_strategy, head, pretraining_transform",
[
("vision_transformer", "simclr", "simclr_head", "simclr_transform"),
pytest.param(
"vision_transformer",
"dino",
"dino_head",
"dino_transform",
marks=pytest.mark.skipif(torch.cuda.device_count() < 1, reason="VISSL DINO calls all_reduce internally."),
),
("vision_transformer", "barlow_twins", "simclr_head", "barlow_twins_transform"),
("vision_transformer", "swav", "swav_head", "swav_transform"),
],
)
def test_vissl_training(backbone, training_strategy, head, pretraining_transform):
# The moco strategy, transform and head are not included in this test as they don't work as of now.
datamodule = ImageClassificationData.from_datasets(
train_dataset=FakeData(),
batch_size=4,
)

training_strategy_kwargs = {
"dims": [384, 2048, 2048, 256],
}
dim_key = "latent_embedding_dim" if training_strategy == "barlow_twins" else "embedding_dim"
training_strategy_kwargs[dim_key] = 256

embedder = ImageEmbedder(
backbone=backbone,
training_strategy=training_strategy,
head="simclr_head",
pretraining_transform="barlow_twins_transform",
training_strategy_kwargs={"latent_embedding_dim": 128},
pretraining_transform_kwargs={
"total_num_crops": 2,
"num_crops": [2],
"size_crops": [96],
"crop_scales": [[0.4, 1]],
},
head=head,
pretraining_transform=pretraining_transform,
training_strategy_kwargs=training_strategy_kwargs,
)

trainer = flash.Trainer(max_steps=3, max_epochs=1, gpus=torch.cuda.device_count())
Expand Down

0 comments on commit 2f12b35

Please sign in to comment.