Support PL 1.5.0 (#933)
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
ethanwharris and justusschock authored Nov 5, 2021
1 parent 642e63f commit d0adc61
Showing 17 changed files with 138 additions and 72 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -18,8 +18,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


- Fixed a bug where validation metrics could be aggregated together with test metrics in some cases ([#900](https://github.com/PyTorchLightning/lightning-flash/pull/900))


- Fixed a bug where the latest versions of torchmetrics and Lightning Flash could not be installed together ([#902](https://github.com/PyTorchLightning/lightning-flash/pull/902))


- Fixed compatibility with PyTorch-Lightning 1.5 ([#933](https://github.com/PyTorchLightning/lightning-flash/pull/933))


## [0.5.1] - 2021-10-26

### Added
50 changes: 41 additions & 9 deletions flash/core/data/data_pipeline.py
@@ -18,24 +18,40 @@
from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, TYPE_CHECKING, Union

import torch
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from torch.utils.data import DataLoader, IterableDataset

import flash
from flash.core.data.auto_dataset import IterableAutoDataset
from flash.core.data.batch import _DeserializeProcessor, _Postprocessor, _Preprocessor, _Sequential, _SerializeProcessor
from flash.core.data.data_source import DataSource
from flash.core.data.process import DefaultPreprocess, Deserializer, Postprocess, Preprocess, Serializer
from flash.core.data.properties import ProcessState
from flash.core.data.utils import _POSTPROCESS_FUNCS, _PREPROCESS_FUNCS, _STAGES_PREFIX
from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_4_3
from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_4_3, _PL_GREATER_EQUAL_1_5_0
from flash.core.utilities.stages import _RUNNING_STAGE_MAPPING, RunningStage

if not _PL_GREATER_EQUAL_1_5_0:
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader

if TYPE_CHECKING:
from flash.core.model import Task


class DataLoaderGetter:
"""A utility class to be used when patching the ``{stage}_dataloader`` attribute of a LightningModule."""

def __init__(self, dataloader):
self.dataloader = dataloader

# Dummy `__code__` attribute to trick is_overridden
self.__code__ = self.__call__.__code__

def __call__(self):
return self.dataloader


class DataPipelineState:
"""A class to store and share all process states once a :class:`.DataPipeline` has been initialized."""

@@ -315,16 +331,34 @@ def _get_dataloader(model: "Task", loader_name: str) -> Tuple[DataLoader, str]:
dataloader = getattr(model, loader_name)
attr_name = loader_name

elif model.trainer and hasattr(model.trainer, "datamodule") and model.trainer.datamodule:
dataloader = getattr(model, f"trainer.datamodule.{loader_name}", None)
elif (
model.trainer
and hasattr(model.trainer, "datamodule")
and model.trainer.datamodule
and is_overridden(loader_name, model.trainer.datamodule, flash.DataModule)
):
dataloader = getattr(model.trainer.datamodule, loader_name, None)
attr_name = f"trainer.datamodule.{loader_name}"

elif _PL_GREATER_EQUAL_1_5_0 and model.trainer is not None:
source = getattr(model.trainer._data_connector, f"_{loader_name}_source")
if not source.is_module():
dataloader = source.dataloader()
attr_name = loader_name

if dataloader is not None:
# Update source as wrapped loader will be attached to model
source.instance = model
source.name = loader_name

return dataloader, attr_name

@staticmethod
def _patch_dataloader(model: "Task", dataloader: Union[Callable, DataLoader], stage: RunningStage):
if isinstance(dataloader, DataLoader):
if _PL_GREATER_EQUAL_1_4_3:
if _PL_GREATER_EQUAL_1_5_0:
dataloader = DataLoaderGetter(dataloader)
elif _PL_GREATER_EQUAL_1_4_3:
dataloader = _PatchDataLoader(dataloader, _STAGES_PREFIX[stage])
dataloader.patch(model)
else:
@@ -369,7 +403,7 @@ def _attach_preprocess_to_model(
if not dataloader:
continue

if isinstance(dataloader, (_PatchDataLoader, Callable)):
if callable(dataloader):
dataloader = dataloader()

if dataloader is None:
@@ -504,9 +538,7 @@ def _detach_preprocessing_from_model(self, model: "Task", stage: Optional[Runnin
if not dataloader:
continue

if isinstance(dataloader, _PatchDataLoader):
dataloader = dataloader()
elif isinstance(dataloader, Callable):
if callable(dataloader):
dataloader = dataloader()

if isinstance(dataloader, Sequence):
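The new `DataLoaderGetter` above replaces the `_PatchDataLoader` wrapper that PL 1.5 removed: it wraps an existing `DataLoader` in a callable and borrows a `__code__` attribute so `is_overridden` treats it like a user-defined `{stage}_dataloader` hook. A minimal usage sketch, with a toy dataset and a stand-in module (both illustrative, not part of this commit):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from flash.core.data.data_pipeline import DataLoaderGetter


class ToyModule:
    """Stand-in for a LightningModule, used only to show attribute patching."""


loader = DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=4)
model = ToyModule()

# Attach the loader as a callable "hook" on the module.
model.train_dataloader = DataLoaderGetter(loader)

assert callable(model.train_dataloader)
assert model.train_dataloader() is loader
# The borrowed __code__ attribute is what lets is_overridden-style checks pass.
assert hasattr(model.train_dataloader, "__code__")
```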
8 changes: 1 addition & 7 deletions flash/core/data/new_data_module.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Tuple, Type, TYPE_CHECKING, Union
from typing import Any, Optional, Tuple, Type, Union

import pytorch_lightning as pl
import torch
@@ -30,14 +30,8 @@
from flash.core.data.datasets import BaseDataset
from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE, InputTransform
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE
from flash.core.utilities.stages import RunningStage

if _FIFTYONE_AVAILABLE and TYPE_CHECKING:
from fiftyone.core.collections import SampleCollection
else:
SampleCollection = None


class DataModule(DataModule):
"""A basic DataModule class for all Flash tasks. This class includes references to a
17 changes: 14 additions & 3 deletions flash/core/trainer.py
@@ -25,11 +25,12 @@
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.argparse import add_argparse_args, get_init_arguments_and_types, parse_env_variables
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from torch.utils.data import DataLoader

import flash
from flash.core.finetuning import _DEFAULTS_FINETUNE_STRATEGIES, instantiate_default_finetuning_callbacks
from flash.core.utilities.imports import _SERVE_AVAILABLE
from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_5_0, _SERVE_AVAILABLE


def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
@@ -277,14 +278,24 @@ def request_dataloader(
The dataloader
"""
model, stage, is_legacy = self._parse_request_dataloader_args(args, kwargs)

if is_legacy:
self.call_hook(f"on_{stage}_dataloader")
dataloader = getattr(model, f"{stage}_dataloader")()
else:
hook = f"{stage.dataloader_prefix}_dataloader"
self.call_hook("on_" + hook, pl_module=model)
dataloader = self.call_hook(hook, pl_module=model)

if is_overridden(hook, model):
dataloader = self.call_hook(hook, pl_module=model)
elif _PL_GREATER_EQUAL_1_5_0:
source = getattr(self._data_connector, f"_{stage.dataloader_prefix}_dataloader_source")
dataloader = source.dataloader()

if isinstance(dataloader, tuple):
dataloader = list(dataloader)
self.accelerator.barrier("get_dataloaders")
if _PL_GREATER_EQUAL_1_5_0:
self.training_type_plugin.barrier("get_dataloaders")
else:
self.accelerator.barrier("get_dataloaders")
return dataloader
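With the hunk above, `request_dataloader` only calls the `{stage}_dataloader` hook when the user actually overrode it; on PL 1.5+ it otherwise falls back to the trainer's data-connector source, and the barrier moves from `trainer.accelerator` to `trainer.training_type_plugin`. A hedged sketch of the resolution logic in isolation (the `resolve_dataloader` helper is illustrative and relies on PL's private `_data_connector` attribute):

```python
from pytorch_lightning.utilities.model_helpers import is_overridden

from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_5_0


def resolve_dataloader(trainer, model, prefix: str):
    """Illustrative: prefer a user-overridden hook, else ask the data-connector source (PL >= 1.5)."""
    hook = f"{prefix}_dataloader"
    if is_overridden(hook, model):
        return trainer.call_hook(hook, pl_module=model)
    if _PL_GREATER_EQUAL_1_5_0:
        source = getattr(trainer._data_connector, f"_{prefix}_dataloader_source")
        return source.dataloader()
    return None
```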
20 changes: 20 additions & 0 deletions flash/core/utilities/compatibility.py
@@ -0,0 +1,20 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning import Trainer


def accelerator_connector(trainer: Trainer):
if hasattr(trainer, "_accelerator_connector"):
return trainer._accelerator_connector
return trainer.accelerator_connector
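The helper hides the PL 1.5 rename of `Trainer.accelerator_connector` to the private `Trainer._accelerator_connector`, so the call sites below (learn2learn adapter, VISSL hooks, video classification) stay version-agnostic. A hedged usage sketch:

```python
from pytorch_lightning import Trainer

from flash.core.utilities.compatibility import accelerator_connector

trainer = Trainer(max_epochs=1)

# Resolves to trainer._accelerator_connector on PL >= 1.5,
# and to trainer.accelerator_connector on earlier releases.
connector = accelerator_connector(trainer)
print(type(connector).__name__)
```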
3 changes: 2 additions & 1 deletion flash/core/utilities/imports.py
@@ -88,7 +88,7 @@ def _compare_version(package: str, op, version) -> bool:
_PIL_AVAILABLE = _module_available("PIL")
_OPEN3D_AVAILABLE = _module_available("open3d")
_SEGMENTATION_MODELS_AVAILABLE = _module_available("segmentation_models_pytorch")
_FASTFACE_AVAILABLE = _module_available("fastface")
_FASTFACE_AVAILABLE = _module_available("fastface") and _compare_version("pytorch_lightning", operator.lt, "1.5.0")
_LIBROSA_AVAILABLE = _module_available("librosa")
_TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter")
_TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse")
@@ -118,6 +118,7 @@ class Image:
if Version:
_TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0")
_PL_GREATER_EQUAL_1_4_3 = _compare_version("pytorch_lightning", operator.ge, "1.4.3")
_PL_GREATER_EQUAL_1_5_0 = _compare_version("pytorch_lightning", operator.ge, "1.5.0")

_TEXT_AVAILABLE = all(
[
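`_PL_GREATER_EQUAL_1_5_0` reuses the existing `_compare_version` pattern already used for `_PL_GREATER_EQUAL_1_4_3`. A rough sketch of an equivalent check, assuming the helper ultimately compares the installed version with `packaging` (the `compare_version` name below is a stand-in, not Flash's implementation):

```python
import operator

import pytorch_lightning
from packaging.version import Version


def compare_version(installed: str, op, target: str) -> bool:
    """Stand-in for _compare_version: apply `op` to parsed version strings."""
    return op(Version(installed), Version(target))


PL_GREATER_EQUAL_1_5_0 = compare_version(pytorch_lightning.__version__, operator.ge, "1.5.0")
print(PL_GREATER_EQUAL_1_5_0)
```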
12 changes: 11 additions & 1 deletion flash/image/classification/adapters.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
import os
from collections import defaultdict
@@ -31,6 +32,7 @@
from flash.core.data.data_source import DefaultDataKeys
from flash.core.model import Task
from flash.core.registry import FlashRegistry
from flash.core.utilities.compatibility import accelerator_connector
from flash.core.utilities.imports import _LEARN2LEARN_AVAILABLE
from flash.core.utilities.providers import _LEARN2LEARN
from flash.core.utilities.url_error import catch_url_error
@@ -183,9 +185,17 @@ def __init__(

self.model = self.algorithm_cls(**algorithm_kwargs)

# Patch log to avoid error with learn2learn and PL 1.5
self.model.log = functools.partial(self._patch_log, self.model.log)

# this algorithm requires a special treatment
self._algorithm_has_validated = self.algorithm_cls != l2l.algorithms.LightningPrototypicalNetworks

def _patch_log(self, log, *args, on_step: Optional[bool] = None, on_epoch: Optional[bool] = None, **kwargs):
if not on_step and not on_epoch:
on_epoch = True
return log(*args, on_step=on_step, on_epoch=on_epoch, **kwargs)

def _default_transform(self, dataset, ways: int, shots: int, queries) -> List[Callable]:
return [
l2l.data.transforms.FusedNWaysKShots(dataset, n=ways, k=shots + queries),
@@ -268,7 +278,7 @@ def _convert_dataset(
devices = 1
if isinstance(trainer.training_type_plugin, DataParallelPlugin):
# when using DP, we need to sample n tasks so they can be split across multiple devices.
devices = trainer.accelerator_connector.devices
devices = accelerator_connector(trainer).devices
dataset = TaskDataParallel(taskset, epoch_length=epoch_length, devices=devices, collate_fn=None)
self.trainer.accumulated_grad_batches = self.meta_batch_size / devices

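The learn2learn adapter wraps the inner module's `log` with `functools.partial(self._patch_log, ...)` so that calls specifying neither `on_step` nor `on_epoch` default to epoch-level logging, avoiding an error under PL 1.5. A self-contained sketch of the same wrapping, with an illustrative `fake_log` target:

```python
import functools
from typing import Optional


def patch_log(log, *args, on_step: Optional[bool] = None, on_epoch: Optional[bool] = None, **kwargs):
    # Force epoch-level logging when the caller asked for neither granularity.
    if not on_step and not on_epoch:
        on_epoch = True
    return log(*args, on_step=on_step, on_epoch=on_epoch, **kwargs)


def fake_log(name, value, on_step=None, on_epoch=None):
    print(f"{name}={value} (on_step={on_step}, on_epoch={on_epoch})")


patched_log = functools.partial(patch_log, fake_log)
patched_log("train_loss", 0.25)  # -> on_epoch=True because neither flag was passed
```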
44 changes: 33 additions & 11 deletions flash/image/classification/integrations/baal/loop.py
@@ -15,19 +15,24 @@
from typing import Any, Dict, Optional

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.fit_loop import FitLoop
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.trainer.states import TrainerFn, TrainerStatus
from pytorch_lightning.utilities.model_helpers import is_overridden

import flash
from flash.core.data.data_pipeline import DataLoaderGetter
from flash.core.data.utils import _STAGES_PREFIX
from flash.core.utilities.imports import requires
from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_5_0, requires
from flash.core.utilities.stages import RunningStage
from flash.image.classification.integrations.baal.data import ActiveLearningDataModule
from flash.image.classification.integrations.baal.dropout import InferenceMCDropoutTask

if not _PL_GREATER_EQUAL_1_5_0:
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader


class ActiveLearningLoop(Loop):
@requires("baal")
@@ -133,35 +138,52 @@ def __getattr__(self, key):
return getattr(self.fit_loop, key)
return self.__dict__[key]

def _connect(self, model: LightningModule):
if _PL_GREATER_EQUAL_1_5_0:
self.trainer.training_type_plugin.connect(model)
else:
self.trainer.accelerator.connect(model)

def _reset_fitting(self):
self.trainer.state.fn = TrainerFn.FITTING
self.trainer.training = True
self.trainer.lightning_module.on_train_dataloader()
self.trainer.accelerator.connect(self._lightning_module)
self._connect(self._lightning_module)
self.fit_loop.epoch_progress = Progress()

def _reset_predicting(self):
self.trainer.state.fn = TrainerFn.PREDICTING
self.trainer.predicting = True
self.trainer.lightning_module.on_predict_dataloader()
self.trainer.accelerator.connect(self.inference_model)
self._connect(self.inference_model)

def _reset_testing(self):
self.trainer.state.fn = TrainerFn.TESTING
self.trainer.state.status = TrainerStatus.RUNNING
self.trainer.testing = True
self.trainer.lightning_module.on_test_dataloader()
self.trainer.accelerator.connect(self._lightning_module)
self._connect(self._lightning_module)

def _reset_dataloader_for_stage(self, running_state: RunningStage):
dataloader_name = f"{_STAGES_PREFIX[running_state]}_dataloader"
# If the dataloader exists, we reset it.
dataloader = getattr(self.trainer.datamodule, dataloader_name, None)
dataloader = (
getattr(self.trainer.datamodule, dataloader_name)
if is_overridden(dataloader_name, self.trainer.datamodule)
else None
)
if dataloader:
setattr(
self.trainer.lightning_module,
dataloader_name,
_PatchDataLoader(dataloader(), running_state),
)
if _PL_GREATER_EQUAL_1_5_0:
setattr(
self.trainer.lightning_module,
dataloader_name,
DataLoaderGetter(dataloader()),
)
else:
setattr(
self.trainer.lightning_module,
dataloader_name,
_PatchDataLoader(dataloader(), running_state),
)
setattr(self.trainer, dataloader_name, None)
getattr(self.trainer, f"reset_{dataloader_name}")(self.trainer.lightning_module)
3 changes: 2 additions & 1 deletion flash/image/embedding/vissl/hooks.py
@@ -17,6 +17,7 @@
from pytorch_lightning.core.hooks import ModelHooks

import flash
from flash.core.utilities.compatibility import accelerator_connector
from flash.core.utilities.imports import _VISSL_AVAILABLE

if _VISSL_AVAILABLE:
@@ -48,7 +49,7 @@ def on_start(self, task: "flash.image.embedding.vissl.adapter.MockVISSLTask") ->

# get around vissl distributed training by setting MockTask flags
num_nodes = lightning_module.trainer.num_nodes
accelerators_ids = lightning_module.trainer.accelerator_connector.parallel_device_ids
accelerators_ids = accelerator_connector(lightning_module.trainer).parallel_device_ids
accelerator_per_node = len(accelerators_ids) if accelerators_ids is not None else 1
task.world_size = num_nodes * accelerator_per_node

5 changes: 3 additions & 2 deletions flash/video/classification/model.py
@@ -29,6 +29,7 @@
from flash.core.classification import ClassificationTask, Labels
from flash.core.data.data_source import DefaultDataKeys
from flash.core.registry import FlashRegistry
from flash.core.utilities.compatibility import accelerator_connector
from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE
from flash.core.utilities.providers import _PYTORCHVIDEO
from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE
@@ -146,13 +147,13 @@ def __init__(
)

def on_train_start(self) -> None:
if self.trainer.accelerator_connector.is_distributed:
if accelerator_connector(self.trainer).is_distributed:
encoded_dataset = self.trainer.train_dataloader.loaders.dataset.dataset
encoded_dataset._video_sampler = DistributedSampler(encoded_dataset._labeled_videos)
super().on_train_start()

def on_train_epoch_start(self) -> None:
if self.trainer.accelerator_connector.is_distributed:
if accelerator_connector(self.trainer).is_distributed:
encoded_dataset = self.trainer.train_dataloader.loaders.dataset.dataset
encoded_dataset._video_sampler.set_epoch(self.trainer.current_epoch)
super().on_train_epoch_start()
