This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[PoC] Add MetaLearning support through learn2learn #737

Merged
63 commits merged into master from learn2learn on Sep 20, 2021
Commits (63)
4272598
update
tchaton Sep 6, 2021
73ec02e
update
tchaton Sep 7, 2021
986dfe0
update
tchaton Sep 7, 2021
17267a8
update
tchaton Sep 7, 2021
a3364a5
Merge branch 'master' into learn2learn
tchaton Sep 7, 2021
b33beb7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 7, 2021
1bc4298
wip
tchaton Sep 8, 2021
70ff518
Merge branch 'learn2learn' of https://github.com/PyTorchLightning/lig…
tchaton Sep 8, 2021
1af0544
update
tchaton Sep 8, 2021
084721c
Merge branch 'master' into learn2learn
tchaton Sep 8, 2021
c9c3a21
update imports
tchaton Sep 8, 2021
3b762f2
simplification
tchaton Sep 8, 2021
12d2668
Merge branch 'learn2learn' of https://github.com/PyTorchLightning/lig…
tchaton Sep 8, 2021
529d462
wip
tchaton Sep 8, 2021
5e202c4
update
tchaton Sep 8, 2021
73e4aa8
Fix JIT issues
ethanwharris Sep 8, 2021
593e0c9
Fix test
ethanwharris Sep 8, 2021
004e399
add ddp test
tchaton Sep 8, 2021
a65d23f
update
tchaton Sep 8, 2021
84bed01
test
tchaton Sep 8, 2021
3b6d919
update
tchaton Sep 8, 2021
38d5eee
add persistant workers
tchaton Sep 8, 2021
fffbaa6
update
tchaton Sep 8, 2021
2d819ca
update changelog
tchaton Sep 8, 2021
7e51199
update
tchaton Sep 8, 2021
2580063
Update flash_examples/image_classification.py
tchaton Sep 8, 2021
eaf8dfc
Update flash_examples/image_classification_meta_learning.py
tchaton Sep 8, 2021
9097697
repair the sampling
tchaton Sep 10, 2021
62476d5
update
tchaton Sep 10, 2021
3991f12
Merge branch 'learn2learn' of https://github.com/PyTorchLightning/lig…
tchaton Sep 10, 2021
4a550e4
update
tchaton Sep 10, 2021
c5b5940
update
tchaton Sep 10, 2021
40d0dca
update
tchaton Sep 10, 2021
dd0cb79
update
tchaton Sep 10, 2021
7dc1d34
update
tchaton Sep 10, 2021
482f576
update
tchaton Sep 10, 2021
72e1cb7
update
tchaton Sep 10, 2021
afc6219
update
tchaton Sep 11, 2021
cd63701
update
tchaton Sep 11, 2021
8dff389
update
tchaton Sep 12, 2021
ce54995
update
tchaton Sep 12, 2021
d0bd09c
update
tchaton Sep 12, 2021
dad26cc
Merge branch 'master' into learn2learn
tchaton Sep 13, 2021
a802ec7
Update CHANGELOG.md
ethanwharris Sep 13, 2021
43d201f
Update CHANGELOG.md
ethanwharris Sep 13, 2021
47dace8
update
tchaton Sep 14, 2021
baa1fcf
update example
tchaton Sep 20, 2021
8280807
Merge branch 'master' into learn2learn
tchaton Sep 20, 2021
01b2049
update
tchaton Sep 20, 2021
6928d10
update on comments
tchaton Sep 20, 2021
40c827d
Merge branch 'learn2learn' of https://github.com/PyTorchLightning/lig…
tchaton Sep 20, 2021
c040b5c
update
tchaton Sep 20, 2021
e95d565
update
tchaton Sep 20, 2021
a903cd2
remove typing
tchaton Sep 20, 2021
cd91b53
update
tchaton Sep 20, 2021
7376197
Update gpu-tests.yml
ethanwharris Sep 20, 2021
d2e22ec
update
tchaton Sep 20, 2021
9b2b2bf
Merge branch 'learn2learn' of https://github.com/PyTorchLightning/lig…
tchaton Sep 20, 2021
1c8660f
Apply suggestions from code review
ethanwharris Sep 20, 2021
3b28379
resolve test
tchaton Sep 20, 2021
65d3b06
Merge branch 'learn2learn' of https://github.com/PyTorchLightning/lig…
tchaton Sep 20, 2021
a06546a
Merge branch 'master' into learn2learn
mergify[bot] Sep 20, 2021
fd2cce5
update
tchaton Sep 20, 2021
Files changed
5 changes: 5 additions & 0 deletions .azure-pipelines/gpu-tests.yml
@@ -59,6 +59,11 @@ jobs:
python -m coverage run --source flash -m pytest flash tests/examples/test_scripts.py -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=30
displayName: 'Testing'

- bash: |
pip install git+https://github.com/tchaton/learn2learn@flash
bash tests/special_tests.sh
displayName: 'Testing: special'

- bash: |
python -m coverage report
python -m coverage xml
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -8,10 +8,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for the `learn2learn` `training_strategy` for `ImageClassifier` ([#737](https://github.com/PyTorchLightning/lightning-flash/pull/737))

### Changed


### Fixed


## [0.5.0] - 2021-09-07

### Added
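The changelog entry above is the user-facing summary: `ImageClassifier` gains a `training_strategy` argument backed by `learn2learn`. As a rough orientation, here is a usage sketch — the strategy name `"prototypicalnetworks"` and the episodic kwargs (`ways`, `shots`, `meta_batch_size`) are assumptions inferred from the `image_classification_meta_learning.py` example this PR touches, not a confirmed API:

```python
import flash
from flash.image import ImageClassificationData, ImageClassifier

# Hypothetical folder layout; substitute a real dataset.
datamodule = ImageClassificationData.from_folders(
    train_folder="data/train",
    val_folder="data/val",
    batch_size=1,  # episodes are assembled by the strategy, not the loader
)

model = ImageClassifier(
    backbone="resnet18",
    num_classes=datamodule.num_classes,
    training_strategy="prototypicalnetworks",  # assumed strategy name
    training_strategy_kwargs={"ways": 5, "shots": 1, "meta_batch_size": 4},
)

trainer = flash.Trainer(max_epochs=1)
trainer.fit(model, datamodule=datamodule)
```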
28 changes: 22 additions & 6 deletions flash/core/adapter.py
@@ -14,6 +14,7 @@
from abc import abstractmethod
from typing import Any, Callable, Optional

import torch.jit
from torch import nn
from torch.utils.data import DataLoader, Sampler

@@ -59,6 +60,10 @@ def test_epoch_end(self, outputs) -> None:
pass


def identity_collate_fn(x):
return x


class AdapterTask(Task):
"""The ``AdapterTask`` is a :class:`~flash.core.model.Task` which wraps an :class:`~flash.core.adapter.Adapter`
and forwards all of the hooks.
@@ -73,11 +78,12 @@ def __init__(self, adapter: Adapter, **kwargs):

self.adapter = adapter

@torch.jit.unused
@property
def backbone(self) -> nn.Module:
return self.adapter.backbone

def forward(self, x: Any) -> Any:
def forward(self, x: torch.Tensor) -> Any:
return self.adapter.forward(x)

def training_step(self, batch: Any, batch_idx: int) -> Any:
@@ -104,6 +110,7 @@ def test_epoch_end(self, outputs) -> None:
def process_train_dataset(
self,
dataset: BaseAutoDataset,
trainer: flash.Trainer,
batch_size: int,
num_workers: int,
pin_memory: bool,
@@ -113,12 +120,13 @@
sampler: Optional[Sampler] = None,
) -> DataLoader:
return self.adapter.process_train_dataset(
dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
dataset, trainer, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
)

def process_val_dataset(
self,
dataset: BaseAutoDataset,
trainer: flash.Trainer,
batch_size: int,
num_workers: int,
pin_memory: bool,
@@ -128,12 +136,13 @@
sampler: Optional[Sampler] = None,
) -> DataLoader:
return self.adapter.process_val_dataset(
dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
dataset, trainer, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
)

def process_test_dataset(
self,
dataset: BaseAutoDataset,
trainer: flash.Trainer,
batch_size: int,
num_workers: int,
pin_memory: bool,
@@ -143,7 +152,7 @@
sampler: Optional[Sampler] = None,
) -> DataLoader:
return self.adapter.process_test_dataset(
dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
dataset, trainer, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
)

def process_predict_dataset(
@@ -152,11 +161,18 @@
batch_size: int = 1,
num_workers: int = 0,
pin_memory: bool = False,
collate_fn: Callable = lambda x: x,
collate_fn: Callable = identity_collate_fn,
shuffle: bool = False,
drop_last: bool = True,
sampler: Optional[Sampler] = None,
) -> DataLoader:
return self.adapter.process_predict_dataset(
dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler
dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
collate_fn=collate_fn,
shuffle=shuffle,
drop_last=drop_last,
sampler=sampler,
)
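A side note on the `adapter.py` changes above: the default `collate_fn` moves from an inline `lambda x: x` to the module-level `identity_collate_fn`, and the `backbone` property gains `@torch.jit.unused` so TorchScript skips it (the "Fix JIT issues" commit). A plausible reason for the lambda swap — not stated in the PR — is picklability: DataLoader workers and spawn-based DDP pickle the collate function, and lambdas cannot be pickled. A minimal demonstration:

```python
import pickle

def identity_collate_fn(x):
    # A module-level function pickles by qualified name, so it round-trips.
    return x

identity_lambda = lambda x: x

pickle.dumps(identity_collate_fn)  # fine

try:
    pickle.dumps(identity_lambda)
except (pickle.PicklingError, AttributeError) as err:
    # Lambdas have no importable qualified name; with num_workers > 0 or
    # DDP-spawn, this failure would surface inside the DataLoader.
    print(f"not picklable: {err}")
```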
58 changes: 48 additions & 10 deletions flash/core/classification.py
@@ -18,6 +18,7 @@
import torchmetrics
from pytorch_lightning.utilities import rank_zero_warn

from flash.core.adapter import AdapterTask
from flash.core.data.data_source import DefaultDataKeys, LabelsState
from flash.core.data.process import Serializer
from flash.core.model import Task
@@ -37,7 +38,29 @@ def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.
return F.binary_cross_entropy_with_logits(x, y.float())


class ClassificationTask(Task):
class ClassificationMixin:
def _build(
self,
num_classes: Optional[int] = None,
loss_fn: Optional[Callable] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
multi_label: bool = False,
):
if metrics is None:
metrics = torchmetrics.F1(num_classes) if (multi_label and num_classes) else torchmetrics.Accuracy()

if loss_fn is None:
loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy

return metrics, loss_fn

def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
if getattr(self.hparams, "multi_label", False):
return torch.sigmoid(x)
return torch.softmax(x, dim=1)


class ClassificationTask(Task, ClassificationMixin):
def __init__(
self,
*args,
@@ -48,11 +71,9 @@ def __init__(
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
**kwargs,
) -> None:
if metrics is None:
metrics = torchmetrics.F1(num_classes) if (multi_label and num_classes) else torchmetrics.Accuracy()

if loss_fn is None:
loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy
metrics, loss_fn = ClassificationMixin._build(self, num_classes, loss_fn, metrics, multi_label)

super().__init__(
*args,
loss_fn=loss_fn,
@@ -61,11 +82,28 @@ def __init__(
**kwargs,
)

def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
if getattr(self.hparams, "multi_label", False):
return torch.sigmoid(x)
# we'll assume that the data always comes as `(B, C, ...)`
return torch.softmax(x, dim=1)

class ClassificationAdapterTask(AdapterTask, ClassificationMixin):
def __init__(
self,
*args,
num_classes: Optional[int] = None,
loss_fn: Optional[Callable] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
multi_label: bool = False,
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
**kwargs,
) -> None:

metrics, loss_fn = ClassificationMixin._build(self, num_classes, loss_fn, metrics, multi_label)

super().__init__(
*args,
loss_fn=loss_fn,
metrics=metrics,
serializer=serializer or Classes(multi_label=multi_label),
**kwargs,
)


class ClassificationSerializer(Serializer):
3 changes: 3 additions & 0 deletions flash/core/data/data_module.py
@@ -302,6 +302,7 @@ def _train_dataloader(self) -> DataLoader:
if isinstance(getattr(self, "trainer", None), pl.Trainer):
return self.trainer.lightning_module.process_train_dataset(
train_ds,
trainer=self.trainer,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=pin_memory,
@@ -330,6 +331,7 @@ def _val_dataloader(self) -> DataLoader:
if isinstance(getattr(self, "trainer", None), pl.Trainer):
return self.trainer.lightning_module.process_val_dataset(
val_ds,
trainer=self.trainer,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=pin_memory,
@@ -352,6 +354,7 @@ def _test_dataloader(self) -> DataLoader:
if isinstance(getattr(self, "trainer", None), pl.Trainer):
return self.trainer.lightning_module.process_test_dataset(
test_ds,
trainer=self.trainer,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=pin_memory,
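The `data_module.py` edits thread the active `Trainer` into each `process_*_dataset` call. The benefit (inferred from the learn2learn use case rather than stated here) is that an adapter can consult trainer state, such as the distributed world size, when it builds episodic dataloaders. A sketch of an adapter using the new signature — `EpisodicAdapter` is hypothetical:

```python
from typing import Callable, Optional

from torch.utils.data import DataLoader, Sampler

import flash
from flash.core.adapter import Adapter
from flash.core.data.auto_dataset import BaseAutoDataset


class EpisodicAdapter(Adapter):  # hypothetical; other required hooks omitted
    def process_train_dataset(
        self,
        dataset: BaseAutoDataset,
        trainer: "flash.Trainer",  # newly available with this PR
        batch_size: int,
        num_workers: int,
        pin_memory: bool,
        collate_fn: Callable,
        shuffle: bool = False,
        drop_last: bool = False,
        sampler: Optional[Sampler] = None,
    ) -> DataLoader:
        # A real meta-learning adapter would build an episodic task sampler
        # here, e.g. sized by `trainer.world_size` under distributed training.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
            shuffle=shuffle,
            drop_last=drop_last,
            sampler=sampler,
        )
```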
2 changes: 1 addition & 1 deletion flash/core/data/data_pipeline.py
@@ -534,7 +534,7 @@ def _detach_preprocessing_from_model(self, model: "Task", stage: Optional[Runnin
if isinstance(dl_args["collate_fn"], _Preprocessor):
dl_args["collate_fn"] = dl_args["collate_fn"]._original_collate_fn

if isinstance(dl_args["dataset"], IterableAutoDataset):
if isinstance(dl_args["dataset"], (IterableAutoDataset, IterableDataset)):
del dl_args["sampler"]

del dl_args["batch_sampler"]
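Context for the one-line `data_pipeline.py` change: learn2learn-style episodic datasets arrive as plain `IterableDataset`s rather than `IterableAutoDataset`s, and `DataLoader` refuses a `sampler` for any iterable-style dataset — hence the broadened `isinstance` check before the re-created loader's arguments are assembled. The underlying PyTorch behaviour:

```python
from torch.utils.data import DataLoader, IterableDataset, SequentialSampler


class Stream(IterableDataset):
    def __iter__(self):
        return iter(range(4))


try:
    # DataLoader rejects sampler options for iterable-style datasets,
    # which is why `sampler` must be dropped from the re-built kwargs.
    DataLoader(Stream(), sampler=SequentialSampler(range(4)))
except ValueError as err:
    print(err)
```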
13 changes: 13 additions & 0 deletions flash/core/data/data_source.py
@@ -464,6 +464,9 @@ def load_data(

data = make_dataset(data, class_to_idx, extensions=self.extensions)
return [{DefaultDataKeys.INPUT: input, DefaultDataKeys.TARGET: target} for input, target in data]
elif dataset is not None:
dataset.num_classes = len(np.unique(data[1]))

return list(
filter(
lambda sample: has_file_allowed_extension(sample[DefaultDataKeys.INPUT], self.extensions),
@@ -622,6 +625,16 @@ class TensorDataSource(SequenceDataSource[torch.Tensor]):
"""The ``TensorDataSource`` is a ``SequenceDataSource`` which expects the input to
:meth:`~flash.core.data.data_source.DataSource.load_data` to be a sequence of ``torch.Tensor`` objects."""

def load_data(
self,
data: Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence]],
dataset: Optional[Any] = None,
) -> Sequence[Mapping[str, Any]]:
# TODO: Bring back the code to work out how many classes there are
if len(data) == 2:
dataset.num_classes = len(torch.unique(data[1]))
return super().load_data(data, dataset)


class NumpyDataSource(SequenceDataSource[np.ndarray]):
"""The ``NumpyDataSource`` is a ``SequenceDataSource`` which expects the input to
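The `TensorDataSource.load_data` override above infers `num_classes` by counting distinct target values when an `(inputs, targets)` pair is given. The same inference in isolation:

```python
import torch

inputs = torch.randn(6, 3, 32, 32)
targets = torch.tensor([0, 1, 2, 1, 0, 2])
data = (inputs, targets)

# Mirrors `dataset.num_classes = len(torch.unique(data[1]))` from the diff.
if len(data) == 2:
    num_classes = len(torch.unique(data[1]))
    print(num_classes)  # 3
```

One caveat worth knowing: counting unique labels only sees the classes present in that particular split, so a sparse split could under-report `num_classes`.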
11 changes: 8 additions & 3 deletions flash/core/data/process.py
@@ -342,17 +342,22 @@ def default_transforms() -> Optional[Dict[str, Callable]]:
"""
return None

def _apply_sample_transform(self, sample: Any) -> Any:
if isinstance(sample, list):
return [self.current_transform(s) for s in sample]
return self.current_transform(sample)

def pre_tensor_transform(self, sample: Any) -> Any:
"""Transforms to apply on a single object."""
return self.current_transform(sample)
return self._apply_sample_transform(sample)

def to_tensor_transform(self, sample: Any) -> Tensor:
"""Transforms to convert single object to a tensor."""
return self.current_transform(sample)
return self._apply_sample_transform(sample)

def post_tensor_transform(self, sample: Tensor) -> Tensor:
"""Transforms to apply on a tensor."""
return self.current_transform(sample)
return self._apply_sample_transform(sample)

def per_batch_transform(self, batch: Any) -> Any:
"""Transforms to apply to a whole batch (if possible use this for efficiency).
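The `_apply_sample_transform` helper above lets the per-sample hooks (`pre_tensor_transform`, `to_tensor_transform`, `post_tensor_transform`) accept either a single sample or a list of samples — presumably because in the learn2learn flow the "sample" reaching these hooks can be a whole episode. The dispatch in isolation:

```python
from typing import Any, Callable


def apply_sample_transform(transform: Callable, sample: Any) -> Any:
    # Map over a list of samples (an episode); otherwise apply directly.
    if isinstance(sample, list):
        return [transform(s) for s in sample]
    return transform(sample)


double = lambda x: 2 * x
print(apply_sample_transform(double, 3))       # 6
print(apply_sample_transform(double, [1, 2]))  # [2, 4]
```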
2 changes: 2 additions & 0 deletions flash/core/data/transforms.py
@@ -111,6 +111,8 @@ def kornia_collate(samples: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
This function removes that dimension and then
applies ``torch.utils.data._utils.collate.default_collate``.
"""
if len(samples) == 1 and isinstance(samples[0], list):
samples = samples[0]
for sample in samples:
for key in sample.keys():
if torch.is_tensor(sample[key]) and sample[key].ndim == 4:
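In the same vein, the `kornia_collate` tweak above unwraps a batch that arrives as a single element which is itself a list of samples (the episodic case, presumably), so the usual per-key collation can proceed:

```python
samples = [[{"input": 1}, {"input": 2}]]  # a batch holding one episode

if len(samples) == 1 and isinstance(samples[0], list):
    samples = samples[0]

print(samples)  # [{'input': 1}, {'input': 2}]
```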
4 changes: 4 additions & 0 deletions flash/core/integrations/icevision/adapter.py
@@ -16,6 +16,7 @@

from torch.utils.data import DataLoader, Sampler

import flash
from flash.core.adapter import Adapter
from flash.core.data.auto_dataset import BaseAutoDataset
from flash.core.data.data_source import DefaultDataKeys
@@ -91,6 +92,7 @@ def _collate_fn(collate_fn, samples, metadata: Optional[List[Dict[str, Any]]] =
def process_train_dataset(
self,
dataset: BaseAutoDataset,
trainer: "flash.Trainer",
batch_size: int,
num_workers: int,
pin_memory: bool,
@@ -114,6 +116,7 @@ def process_train_dataset(
def process_val_dataset(
self,
dataset: BaseAutoDataset,
trainer: "flash.Trainer",
batch_size: int,
num_workers: int,
pin_memory: bool,
@@ -137,6 +140,7 @@ def process_val_dataset(
def process_test_dataset(
self,
dataset: BaseAutoDataset,
trainer: "flash.Trainer",
batch_size: int,
num_workers: int,
pin_memory: bool,