This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Support PL 1.5.0 #933

Merged
merged 17 commits on Nov 5, 2021
50 changes: 41 additions & 9 deletions flash/core/data/data_pipeline.py
@@ -18,24 +18,40 @@
from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, TYPE_CHECKING, Union

import torch
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from torch.utils.data import DataLoader, IterableDataset

import flash
from flash.core.data.auto_dataset import IterableAutoDataset
from flash.core.data.batch import _DeserializeProcessor, _Postprocessor, _Preprocessor, _Sequential, _SerializeProcessor
from flash.core.data.data_source import DataSource
from flash.core.data.process import DefaultPreprocess, Deserializer, Postprocess, Preprocess, Serializer
from flash.core.data.properties import ProcessState
from flash.core.data.utils import _POSTPROCESS_FUNCS, _PREPROCESS_FUNCS, _STAGES_PREFIX
from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_4_3
from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_4_3, _PL_GREATER_EQUAL_1_5_0
from flash.core.utilities.stages import _RUNNING_STAGE_MAPPING, RunningStage

if not _PL_GREATER_EQUAL_1_5_0:
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader

if TYPE_CHECKING:
from flash.core.model import Task


class DataLoaderGetter:
Member commented:

This still wraps and patches the dataloader, right? Is there no way around this?
Will this patch be assigned back to the loader or just used internally?

Collaborator (Author) replied:

Yeah, unfortunately this was the only solution for now, but the upcoming data pipeline refactor should remove this patching entirely.

"""A utility class to be used when patching the ``{stage}_dataloader`` attribute of a LightningModule."""

def __init__(self, dataloader):
self.dataloader = dataloader

# Dummy `__code__` attribute to trick is_overridden
self.__code__ = self.__call__.__code__

def __call__(self):
return self.dataloader
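
As a rough illustration of how this getter is intended to be used (a minimal sketch, not part of the diff: it assumes the ``DataLoaderGetter`` defined above is in scope, and ``_ToyModule`` is a hypothetical stand-in for a LightningModule):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class _ToyModule:
    """Hypothetical stand-in for a LightningModule, used only in this sketch."""


loader = DataLoader(TensorDataset(torch.randn(8, 3)), batch_size=4)
module = _ToyModule()

# Patch the attribute the same way _patch_dataloader does on PL >= 1.5.0.
setattr(module, "train_dataloader", DataLoaderGetter(loader))

# The getter is callable and simply returns the wrapped loader ...
assert module.train_dataloader() is loader
# ... while the dummy __code__ attribute lets is_overridden-style checks pass.
assert hasattr(module.train_dataloader, "__code__")
```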


class DataPipelineState:
"""A class to store and share all process states once a :class:`.DataPipeline` has been initialized."""

@@ -315,16 +331,34 @@ def _get_dataloader(model: "Task", loader_name: str) -> Tuple[DataLoader, str]:
dataloader = getattr(model, loader_name)
attr_name = loader_name

elif model.trainer and hasattr(model.trainer, "datamodule") and model.trainer.datamodule:
dataloader = getattr(model, f"trainer.datamodule.{loader_name}", None)
elif (
model.trainer
and hasattr(model.trainer, "datamodule")
and model.trainer.datamodule
and is_overridden(loader_name, model.trainer.datamodule, flash.DataModule)
):
dataloader = getattr(model.trainer.datamodule, loader_name, None)
attr_name = f"trainer.datamodule.{loader_name}"

elif _PL_GREATER_EQUAL_1_5_0 and model.trainer:
source = getattr(model.trainer._data_connector, f"_{loader_name}_source")
if not source.is_module():
dataloader = source.dataloader()
attr_name = loader_name

if dataloader is not None:
# Update source as wrapped loader will be attached to model
source.instance = model
source.name = loader_name

return dataloader, attr_name

@staticmethod
def _patch_dataloader(model: "Task", dataloader: Union[Callable, DataLoader], stage: RunningStage):
if isinstance(dataloader, DataLoader):
if _PL_GREATER_EQUAL_1_4_3:
if _PL_GREATER_EQUAL_1_5_0:
dataloader = DataLoaderGetter(dataloader)
elif _PL_GREATER_EQUAL_1_4_3:
dataloader = _PatchDataLoader(dataloader, _STAGES_PREFIX[stage])
dataloader.patch(model)
else:
@@ -369,7 +403,7 @@ def _attach_preprocess_to_model(
if not dataloader:
continue

if isinstance(dataloader, (_PatchDataLoader, Callable)):
if callable(dataloader):
dataloader = dataloader()

if dataloader is None:
@@ -504,9 +538,7 @@ def _detach_preprocessing_from_model(self, model: "Task", stage: Optional[RunningStage]
if not dataloader:
continue

if isinstance(dataloader, _PatchDataLoader):
dataloader = dataloader()
elif isinstance(dataloader, Callable):
if callable(dataloader):
dataloader = dataloader()

if isinstance(dataloader, Sequence):
8 changes: 1 addition & 7 deletions flash/core/data/new_data_module.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Tuple, Type, TYPE_CHECKING, Union
from typing import Any, Optional, Tuple, Type, Union

import pytorch_lightning as pl
import torch
@@ -30,14 +30,8 @@
from flash.core.data.datasets import BaseDataset
from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE, InputTransform
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE
from flash.core.utilities.stages import RunningStage

if _FIFTYONE_AVAILABLE and TYPE_CHECKING:
from fiftyone.core.collections import SampleCollection
else:
SampleCollection = None


class DataModule(DataModule):
"""A basic DataModule class for all Flash tasks. This class includes references to a
25 changes: 22 additions & 3 deletions flash/core/trainer.py
@@ -28,8 +28,10 @@
from torch.utils.data import DataLoader

import flash
from flash.core.data.data_module import DataModule
from flash.core.data.new_data_module import DataModule as NewDataModule
from flash.core.finetuning import _DEFAULTS_FINETUNE_STRATEGIES, instantiate_default_finetuning_callbacks
from flash.core.utilities.imports import _SERVE_AVAILABLE
from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_5_0, _SERVE_AVAILABLE


def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
@@ -277,14 +279,31 @@ def request_dataloader(
The dataloader
"""
model, stage, is_legacy = self._parse_request_dataloader_args(args, kwargs)

if is_legacy:
self.call_hook(f"on_{stage}_dataloader")
dataloader = getattr(model, f"{stage}_dataloader")()
else:
hook = f"{stage.dataloader_prefix}_dataloader"
self.call_hook("on_" + hook, pl_module=model)
dataloader = self.call_hook(hook, pl_module=model)

dataloader = None
if _PL_GREATER_EQUAL_1_5_0:
source = getattr(self._data_connector, f"_{stage.dataloader_prefix}_dataloader_source")
if (
not source.is_module()
or not isinstance(source.instance, DataModule)
or isinstance(source.instance, (LightningModule, NewDataModule))
):
dataloader = source.dataloader()

if dataloader is None:
dataloader = self.call_hook(hook, pl_module=model)

if isinstance(dataloader, tuple):
dataloader = list(dataloader)
self.accelerator.barrier("get_dataloaders")
if _PL_GREATER_EQUAL_1_5_0:
self.training_type_plugin.barrier("get_dataloaders")
else:
self.accelerator.barrier("get_dataloaders")
return dataloader
3 changes: 2 additions & 1 deletion flash/core/utilities/imports.py
@@ -88,7 +88,7 @@ def _compare_version(package: str, op, version) -> bool:
_PIL_AVAILABLE = _module_available("PIL")
_OPEN3D_AVAILABLE = _module_available("open3d")
_SEGMENTATION_MODELS_AVAILABLE = _module_available("segmentation_models_pytorch")
_FASTFACE_AVAILABLE = _module_available("fastface")
_FASTFACE_AVAILABLE = _module_available("fastface") and _compare_version("pytorch_lightning", operator.lt, "1.5.0")
_LIBROSA_AVAILABLE = _module_available("librosa")
_TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter")
_TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse")
@@ -118,6 +118,7 @@ class Image:
if Version:
_TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0")
_PL_GREATER_EQUAL_1_4_3 = _compare_version("pytorch_lightning", operator.ge, "1.4.3")
_PL_GREATER_EQUAL_1_5_0 = _compare_version("pytorch_lightning", operator.ge, "1.5.0")

_TEXT_AVAILABLE = all(
[
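
For context, a minimal sketch of how a flag such as ``_PL_GREATER_EQUAL_1_5_0`` can be derived (a sketch under assumptions: ``packaging`` is installed and ``pytorch_lightning`` is importable; the actual ``_compare_version`` helper in Flash may handle edge cases differently):

```python
import operator

import pytorch_lightning
from packaging.version import Version


def _pl_compare(op, version: str) -> bool:
    # Compare the installed pytorch_lightning version against a threshold.
    return op(Version(pytorch_lightning.__version__), Version(version))


_PL_GREATER_EQUAL_1_5_0 = _pl_compare(operator.ge, "1.5.0")
_PL_LESS_THAN_1_5_0 = _pl_compare(operator.lt, "1.5.0")  # mirrors the fastface gate above
```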
6 changes: 5 additions & 1 deletion flash/image/classification/adapters.py
@@ -268,7 +268,11 @@ def _convert_dataset(
devices = 1
if isinstance(trainer.training_type_plugin, DataParallelPlugin):
# when using DP, we need to sample n tasks, so they can be split across multiple devices.
devices = trainer.accelerator_connector.devices
if hasattr(trainer, "accelerator_connector"):
accelerator_connector = trainer.accelerator_connector
else:
accelerator_connector = trainer._accelerator_connector
devices = accelerator_connector.devices
dataset = TaskDataParallel(taskset, epoch_length=epoch_length, devices=devices, collate_fn=None)
self.trainer.accumulated_grad_batches = self.meta_batch_size / devices

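
The ``accelerator_connector`` fallback above reappears in ``flash/image/embedding/vissl/hooks.py`` and ``flash/video/classification/model.py`` below; a small helper like the following (hypothetical, not part of this PR) captures the pattern, since PL 1.5.0 renamed the attribute to the private ``_accelerator_connector``:

```python
def _get_accelerator_connector(trainer):
    """Return the trainer's accelerator connector across PL versions (sketch only)."""
    # Pre-1.5.0 releases expose the public attribute; 1.5.0 moved it behind a
    # private name, so fall back to the underscored attribute.
    if hasattr(trainer, "accelerator_connector"):
        return trainer.accelerator_connector
    return trainer._accelerator_connector
```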
37 changes: 27 additions & 10 deletions flash/image/classification/integrations/baal/loop.py
@@ -15,19 +15,23 @@
from typing import Any, Dict, Optional

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.fit_loop import FitLoop
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.trainer.states import TrainerFn, TrainerStatus

import flash
from flash.core.data.data_pipeline import DataLoaderGetter
from flash.core.data.utils import _STAGES_PREFIX
from flash.core.utilities.imports import requires
from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_5_0, requires
from flash.core.utilities.stages import RunningStage
from flash.image.classification.integrations.baal.data import ActiveLearningDataModule
from flash.image.classification.integrations.baal.dropout import InferenceMCDropoutTask

if not _PL_GREATER_EQUAL_1_5_0:
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader


class ActiveLearningLoop(Loop):
@requires("baal")
@@ -133,35 +137,48 @@ def __getattr__(self, key):
return getattr(self.fit_loop, key)
return self.__dict__[key]

def _connect(self, model: LightningModule):
if _PL_GREATER_EQUAL_1_5_0:
self.trainer.training_type_plugin.connect(model)
else:
self.trainer.accelerator.connect(model)

def _reset_fitting(self):
self.trainer.state.fn = TrainerFn.FITTING
self.trainer.training = True
self.trainer.lightning_module.on_train_dataloader()
self.trainer.accelerator.connect(self._lightning_module)
self._connect(self._lightning_module)
self.fit_loop.epoch_progress = Progress()

def _reset_predicting(self):
self.trainer.state.fn = TrainerFn.PREDICTING
self.trainer.predicting = True
self.trainer.lightning_module.on_predict_dataloader()
self.trainer.accelerator.connect(self.inference_model)
self._connect(self.inference_model)

def _reset_testing(self):
self.trainer.state.fn = TrainerFn.TESTING
self.trainer.state.status = TrainerStatus.RUNNING
self.trainer.testing = True
self.trainer.lightning_module.on_test_dataloader()
self.trainer.accelerator.connect(self._lightning_module)
self._connect(self._lightning_module)

def _reset_dataloader_for_stage(self, running_state: RunningStage):
dataloader_name = f"{_STAGES_PREFIX[running_state]}_dataloader"
# If the dataloader exists, we reset it.
dataloader = getattr(self.trainer.datamodule, dataloader_name, None)
if dataloader:
setattr(
self.trainer.lightning_module,
dataloader_name,
_PatchDataLoader(dataloader(), running_state),
)
if _PL_GREATER_EQUAL_1_5_0:
setattr(
self.trainer.lightning_module,
dataloader_name,
DataLoaderGetter(dataloader()),
)
else:
setattr(
self.trainer.lightning_module,
dataloader_name,
_PatchDataLoader(dataloader(), running_state),
)
setattr(self.trainer, dataloader_name, None)
getattr(self.trainer, f"reset_{dataloader_name}")(self.trainer.lightning_module)
6 changes: 5 additions & 1 deletion flash/image/embedding/vissl/hooks.py
@@ -48,7 +48,11 @@ def on_start(self, task: "flash.image.embedding.vissl.adapter.MockVISSLTask") ->

# get around vissl distributed training by setting MockTask flags
num_nodes = lightning_module.trainer.num_nodes
accelerators_ids = lightning_module.trainer.accelerator_connector.parallel_device_ids
if hasattr(lightning_module.trainer, "accelerator_connector"):
accelerator_connector = lightning_module.trainer.accelerator_connector
else:
accelerator_connector = lightning_module.trainer._accelerator_connector
accelerators_ids = accelerator_connector.parallel_device_ids
accelerator_per_node = len(accelerators_ids) if accelerators_ids is not None else 1
task.world_size = num_nodes * accelerator_per_node
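
As a worked example of the world-size computation above (the values here are purely illustrative, not taken from the PR):

```python
# Hypothetical values: 2 nodes, each exposing 4 parallel device ids.
num_nodes = 2
accelerators_ids = [0, 1, 2, 3]
accelerator_per_node = len(accelerators_ids) if accelerators_ids is not None else 1
world_size = num_nodes * accelerator_per_node  # -> 8
```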

12 changes: 10 additions & 2 deletions flash/video/classification/model.py
@@ -146,13 +146,21 @@
)

def on_train_start(self) -> None:
if self.trainer.accelerator_connector.is_distributed:
if hasattr(self.trainer, "accelerator_connector"):
accelerator_connector = self.trainer.accelerator_connector
else:
accelerator_connector = self.trainer._accelerator_connector
if accelerator_connector.is_distributed:
encoded_dataset = self.trainer.train_dataloader.loaders.dataset.dataset
encoded_dataset._video_sampler = DistributedSampler(encoded_dataset._labeled_videos)
super().on_train_start()

def on_train_epoch_start(self) -> None:
if self.trainer.accelerator_connector.is_distributed:
if hasattr(self.trainer, "accelerator_connector"):
accelerator_connector = self.trainer.accelerator_connector
else:
accelerator_connector = self.trainer._accelerator_connector
if accelerator_connector.is_distributed:
encoded_dataset = self.trainer.train_dataloader.loaders.dataset.dataset
encoded_dataset._video_sampler.set_epoch(self.trainer.current_epoch)
super().on_train_epoch_start()
2 changes: 1 addition & 1 deletion requirements.txt
@@ -2,7 +2,7 @@ packaging
numpy
torch>=1.7.1
torchmetrics>=0.4.0,!=0.5.1
pytorch-lightning==1.4.9
pytorch-lightning>=1.4.0
pyDeprecate
pandas<1.3.0
jsonargparse[signatures]>=3.17.0
7 changes: 0 additions & 7 deletions tests/audio/classification/test_data.py
@@ -58,8 +58,6 @@ def test_from_filepaths_smoke(tmpdir):
num_workers=0,
)
assert spectrograms_data.train_dataloader() is not None
assert spectrograms_data.val_dataloader() is None
assert spectrograms_data.test_dataloader() is None

data = next(iter(spectrograms_data.train_dataloader()))
imgs, labels = data["input"], data["target"]
@@ -130,8 +128,6 @@ def test_from_filepaths_numpy(tmpdir):
num_workers=0,
)
assert spectrograms_data.train_dataloader() is not None
assert spectrograms_data.val_dataloader() is None
assert spectrograms_data.test_dataloader() is None

data = next(iter(spectrograms_data.train_dataloader()))
imgs, labels = data["input"], data["target"]
@@ -323,9 +319,6 @@ def test_from_folders_only_train(tmpdir):
assert imgs.shape == (1, 3, 128, 128)
assert labels.shape == (1,)

assert spectrograms_data.val_dataloader() is None
assert spectrograms_data.test_dataloader() is None


@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
def test_from_folders_train_val(tmpdir):
@@ -16,9 +16,9 @@
from pathlib import Path

import pytest
from pytorch_lightning import Trainer

import flash
from flash import Trainer
from flash.audio import SpeechRecognition, SpeechRecognitionData
from tests.helpers.utils import _AUDIO_TESTING

2 changes: 1 addition & 1 deletion tests/core/data/test_data_pipeline.py
@@ -18,12 +18,12 @@
import numpy as np
import pytest
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import Tensor, tensor
from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_collate

from flash import Trainer
from flash.core.data.auto_dataset import IterableAutoDataset
from flash.core.data.batch import _Postprocessor, _Preprocessor
from flash.core.data.data_module import DataModule