Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix pretraining_transforms for ImageEmbedder #1196

Merged
merged 30 commits into from
Mar 4, 2022
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
1912a23
Fix transforms, use manual setters and getters
krshrimali Feb 25, 2022
bbafe62
Remove unused import
krshrimali Feb 25, 2022
34ecadd
Change tests and attempt a fix
krshrimali Feb 25, 2022
b09d286
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 25, 2022
fb67d64
fix running_stage, use partial
krshrimali Feb 28, 2022
c813060
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 28, 2022
0bd1eb7
remove unused import
krshrimali Feb 28, 2022
be97303
Merge branch 'fix/ImageEmbedder/transforms' of github.com:PyTorchLigh…
krshrimali Feb 28, 2022
cdaa003
Remove unused part of code
krshrimali Mar 1, 2022
4ef4c72
Modify tests, mix dimensions
krshrimali Mar 2, 2022
d98b582
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 2, 2022
be0b8b9
Modify tests, fix kwargs (dims), not working for moco as of now
krshrimali Mar 2, 2022
60eed8a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 2, 2022
093e2c8
Merge branch 'master' into fix/ImageEmbedder/transforms
krshrimali Mar 2, 2022
344bbd6
unwanted changes...
krshrimali Mar 2, 2022
8282f74
Merge branch 'fix/ImageEmbedder/transforms' of github.com:PyTorchLigh…
krshrimali Mar 2, 2022
ae9cc0e
Add changelog entry
krshrimali Mar 2, 2022
001dfc3
changelog update
krshrimali Mar 2, 2022
cd0b09e
Per review, bring back some code + import from existing file
krshrimali Mar 3, 2022
8c8a70b
Merge branch 'fix/ImageEmbedder/transforms' of github.com:PyTorchLigh…
krshrimali Mar 3, 2022
ad2f90e
unused imports...
krshrimali Mar 3, 2022
005a71a
Merge branch 'master' into fix/ImageEmbedder/transforms
ethanwharris Mar 3, 2022
3556f18
Merge branch 'master' into fix/ImageEmbedder/transforms
krshrimali Mar 3, 2022
f36c865
testing...
krshrimali Mar 3, 2022
b3671a2
Merge branch 'fix/ImageEmbedder/transforms' of github.com:PyTorchLigh…
krshrimali Mar 3, 2022
ef0296f
use_gpu: True, losses are nn modules
krshrimali Mar 3, 2022
aba00f7
Fixes
ethanwharris Mar 3, 2022
e6f004f
Add back flaky reruns
ethanwharris Mar 3, 2022
ffe8736
Merge branch 'master' into fix/ImageEmbedder/transforms
ethanwharris Mar 3, 2022
bb27c2a
Update tests/image/embedding/test_model.py
ethanwharris Mar 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed a bug where input transforms in the `ImageEmbedder` were not called. ([#1196](https://github.com/PyTorchLightning/lightning-flash/pull/1196))
krshrimali marked this conversation as resolved.
Show resolved Hide resolved

- Fixed a bug where `BASE_MODEL_NAME` was not in the dict for dino and moco strategies. ([#1196](https://github.com/PyTorchLightning/lightning-flash/pull/1196))

- Fixed a bug where DDP would not work with Flash tasks ([#1182](https://github.com/PyTorchLightning/lightning-flash/pull/1182))

- Fixed DDP support for `VideoClassifier` ([#1189](https://github.com/PyTorchLightning/lightning-flash/pull/1189))
Expand Down
10 changes: 0 additions & 10 deletions flash/core/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import flash
from flash.core.data.io.input import InputBase
from flash.core.model import DatasetProcessor, ModuleWrapperBase, Task
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE


class Adapter(DatasetProcessor, ModuleWrapperBase, nn.Module):
Expand Down Expand Up @@ -79,15 +78,6 @@ def __init__(self, adapter: Adapter, **kwargs):

self.adapter = adapter

@torch.jit.unused
@property
def input_transform(self) -> Optional[INPUT_TRANSFORM_TYPE]:
return self.adapter.input_transform

@input_transform.setter
def input_transform(self, input_transform: INPUT_TRANSFORM_TYPE) -> None:
self.adapter.input_transform = input_transform

krshrimali marked this conversation as resolved.
Show resolved Hide resolved
@torch.jit.unused
@property
def backbone(self) -> nn.Module:
Expand Down
3 changes: 3 additions & 0 deletions flash/core/data/io/input_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,9 @@ def create_transform(
if inspect.isclass(transform) and issubclass(transform, InputTransform):
return transform(running_stage=running_stage, **transform_kwargs)

if isinstance(transform, partial) and transform.func.__name__ == "LambdaInputTransform":
return transform(running_stage=running_stage, **transform_kwargs)

if isinstance(transform, Callable):
return LambdaInputTransform(
running_stage=running_stage,
Expand Down
5 changes: 2 additions & 3 deletions flash/image/embedding/heads/vissl_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


def simclr_head(
dims: List[int] = [2048, 2048, 128],
dims: List[int] = [2048, 2048, 256],
use_bn: bool = True,
**kwargs,
) -> nn.Module:
Expand Down Expand Up @@ -141,7 +141,7 @@ def swav_head(


def barlow_twins_head(**kwargs) -> nn.Module:
return simclr_head(dims=[2048, 8192, 8192, 8192], **kwargs)
return simclr_head(**kwargs)


def moco_head(**kwargs) -> nn.Module:
Expand All @@ -150,7 +150,6 @@ def moco_head(**kwargs) -> nn.Module:

def dino_head(**kwargs) -> nn.Module:
return swav_head(
dims=[384, 2048, 2048, 256],
use_bn=False,
return_embeddings=False,
activation_name="GELU",
Expand Down
8 changes: 7 additions & 1 deletion flash/image/embedding/losses/vissl_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def swav_loss(
queue_length: int = 0,
start_iter: int = 0,
local_queue_length: int = 0,
**kwargs,
) -> ClassyLoss:
loss_name = "swav_loss"
cfg = AttrDict(
Expand Down Expand Up @@ -108,7 +109,10 @@ def swav_loss(


def barlow_twins_loss(
lambda_: float = 0.0051, scale_loss: float = 0.024, latent_embedding_dim: int = 8192
lambda_: float = 0.0051,
scale_loss: float = 0.024,
latent_embedding_dim: int = 8192,
**kwargs,
) -> ClassyLoss:
loss_name = "barlow_twins_loss"
cfg = AttrDict(
Expand All @@ -127,6 +131,7 @@ def simclr_loss(
embedding_dim: int = 128,
effective_batch_size: int = 1, # set by setup training hook
world_size: int = 1, # set by setup training hook
**kwargs,
) -> ClassyLoss:
loss_name = "simclr_info_nce_loss"
cfg = AttrDict(
Expand All @@ -151,6 +156,7 @@ def moco_loss(
momentum: float = 0.999,
temperature: int = 0.2,
shuffle_batch: bool = True,
**kwargs,
) -> ClassyLoss:
loss_name = "moco_loss"
cfg = AttrDict(
Expand Down
16 changes: 14 additions & 2 deletions flash/image/embedding/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any, Dict, List, Optional
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional

from flash.core.adapter import AdapterTask
from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.transforms import ApplyToKeys
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _VISSL_AVAILABLE, requires
Expand Down Expand Up @@ -110,8 +113,17 @@ def __init__(
learning_rate=learning_rate,
)

@dataclass
class LambdaInputTransform(InputTransform):
krshrimali marked this conversation as resolved.
Show resolved Hide resolved

transform: Callable = InputTransform._identity

def per_sample_transform(self) -> Callable:
return self.transform

input_transform, self.collate_fn = self.transforms.get(pretraining_transform)(**pretraining_transform_kwargs)
self.input_transform = ApplyToKeys(DataKeys.INPUT, input_transform)
output = ApplyToKeys(DataKeys.INPUT, input_transform)
self.input_transform = partial(LambdaInputTransform, transform=output)

warnings.warn(
"Warning: VISSL ImageEmbedder overrides any user provided transforms"
Expand Down
1 change: 1 addition & 0 deletions flash/image/embedding/vissl/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def from_task(
def get_model_config_template():
cfg = AttrDict(
{
"BASE_MODEL_NAME": "multi_input_output_model",
"SINGLE_PASS_EVERY_CROP": False,
"INPUT_TYPE": "rgb",
"MULTI_INPUT_HEAD_MAPPING": [],
Expand Down
33 changes: 21 additions & 12 deletions tests/image/embedding/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,26 +51,35 @@ def test_load_from_checkpoint_dependency_error():


@pytest.mark.skipif(not (_TORCHVISION_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.")
@pytest.mark.parametrize("backbone, training_strategy", [("resnet", "barlow_twins")])
def test_vissl_training(tmpdir, backbone, training_strategy):
@pytest.mark.parametrize(
"backbone, training_strategy, head, pretraining_transform",
[
("vision_transformer", "simclr", "simclr_head", "simclr_transform"),
("vision_transformer", "dino", "dino_head", "dino_transform"),
("vision_transformer", "barlow_twins", "simclr_head", "barlow_twins_transform"),
("vision_transformer", "swav", "swav_head", "swav_transform"),
],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

)
def test_vissl_training(backbone, training_strategy, head, pretraining_transform):
# moco strategy, transform and head is not added for this test as it doesn't work as of now.
datamodule = ImageClassificationData.from_datasets(
train_dataset=FakeData(),
batch_size=4,
)

training_strategy_kwargs = {
"dims": [384, 2048, 2048, 256],
}
dim_key = "latent_embedding_dim" if training_strategy == "barlow_twins" else "embedding_dim"
training_strategy_kwargs[dim_key] = 256

embedder = ImageEmbedder(
backbone=backbone,
training_strategy=training_strategy,
head="simclr_head",
pretraining_transform="barlow_twins_transform",
training_strategy_kwargs={"latent_embedding_dim": 128},
pretraining_transform_kwargs={
"total_num_crops": 2,
"num_crops": [2],
"size_crops": [96],
"crop_scales": [[0.4, 1]],
},
head=head,
pretraining_transform=pretraining_transform,
training_strategy_kwargs=training_strategy_kwargs,
)

trainer = flash.Trainer(max_steps=3, max_epochs=1, gpus=torch.cuda.device_count())
trainer = flash.Trainer(max_steps=3, max_epochs=1, gpus=torch.cuda.device_count(), strategy="ddp")
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
trainer.fit(embedder, datamodule=datamodule)