Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Integration of lightning_utilties function into flash #1457

Merged
merged 23 commits into from
Oct 1, 2022
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
cfb5387
Removed the functions that are integrated in lightning_utilities
uakarsh Sep 17, 2022
eade3fd
imort of `is_overriden` method from `lightning_utilities`
uakarsh Sep 17, 2022
8eb4c87
import of `is_overriden` method from `lightning_utilities`
uakarsh Sep 17, 2022
0f50f56
Merge branch 'master' of https://github.com/uakarsh/lightning-flash
uakarsh Sep 17, 2022
3f89766
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 17, 2022
a19beac
Added `lightning_utilties`
uakarsh Sep 17, 2022
6b0d996
Apply suggestions from code review
Borda Sep 23, 2022
532fd2f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2022
980c1e0
Some updates from suggestion
uakarsh Sep 24, 2022
6bc4401
Merge branch 'master' of https://github.com/uakarsh/lightning-flash
uakarsh Sep 24, 2022
d345d96
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 24, 2022
0e6a692
Apply suggestions from code review
krshrimali Sep 24, 2022
c25d3df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 24, 2022
7d5a533
Modified the `is_overriden` method
uakarsh Sep 24, 2022
2fa6b10
Merge branch 'master' of https://github.com/uakarsh/lightning-flash
uakarsh Sep 24, 2022
8f1f3f4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 24, 2022
2481218
Syntax Error solve
uakarsh Sep 24, 2022
652ba2c
Merge branch 'master' of https://github.com/uakarsh/lightning-flash
uakarsh Sep 24, 2022
8540419
Renamed `is_overriden` to `is_overridden`
uakarsh Sep 24, 2022
9ef1553
Apply suggestions from code review
krshrimali Sep 24, 2022
f0544a6
Install libsndfile for doctests
ethanwharris Oct 1, 2022
b212406
Merge branch 'master' into master
mergify[bot] Oct 1, 2022
3c12949
Merge branch 'master' into master
mergify[bot] Oct 1, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flash/core/data/base_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
# limitations under the License.
from typing import Any, Dict, List, Set, Tuple

from lightning_utilities.core.overrides import is_overridden
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from flash.core.data.callback import BaseDataFetcher
from flash.core.data.utils import _CALLBACK_FUNCS
from flash.core.utilities.apply_func import _is_overridden
from flash.core.utilities.stages import RunningStage


Expand Down Expand Up @@ -124,7 +124,7 @@ def show(

for func_name in func_names_set:
hook_name = f"show_{func_name}"
if _is_overridden(hook_name, self, BaseVisualization):
if is_overridden(hook_name, self, BaseVisualization):
getattr(self, hook_name)(batch[func_name], running_stage, limit_nb_samples, figsize)

def show_load_sample(
Expand Down
11 changes: 1 addition & 10 deletions flash/core/utilities/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from typing import Callable, Dict, Mapping, Sequence, Type, Union
krshrimali marked this conversation as resolved.
Show resolved Hide resolved

from lightning_utilities.core.overrides import is_overridden
krshrimali marked this conversation as resolved.
Show resolved Hide resolved
from torch import nn


Expand All @@ -29,13 +30,3 @@ def get_callable_dict(fn: Union[nn.Module, Callable, Mapping, Sequence]) -> Unio
return {get_callable_name(f): f for f in fn}
if callable(fn):
return {get_callable_name(fn): fn}


def _is_overridden(method_name: str, instance: object, parent: Type[object]) -> bool:
"""Cropped Version of https://github.com/Lightning-
AI/lightning/blob/master/src/pytorch_lightning/utilities/model_helpers.py."""

if not hasattr(instance, method_name):
return False

return getattr(instance, method_name).__code__ != getattr(parent, method_name).__code__
160 changes: 56 additions & 104 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@
import operator
import os
import types
from importlib.util import find_spec
from typing import Callable, List, Tuple, Union
from typing import List, Tuple, Union

import pkg_resources
from lightning_utilities.core.imports import compare_version, module_available
from pkg_resources import DistributionNotFound

try:
Expand All @@ -28,97 +27,50 @@
Version = None


def _module_available(module_path: str) -> bool:
"""Check if a path is available in your environment.

>>> _module_available('os')
True
>>> _module_available('bla.bla')
False
"""
try:
return find_spec(module_path) is not None
except AttributeError:
# Python 3.6
return False
except ModuleNotFoundError:
# Python 3.7+
return False
except ValueError:
# Sometimes __spec__ can be None and gives a ValueError
return True


def _compare_version(package: str, op: Callable, version: str, use_base_version: bool = False) -> bool:
"""Compare package version with some requirements.

>>> _compare_version("torch", operator.ge, "0.1")
True
>>> _compare_version("does_not_exist", operator.ge, "0.0")
False
"""
try:
pkg = importlib.import_module(package)
except (ImportError, DistributionNotFound):
return False
try:
if hasattr(pkg, "__version__"):
pkg_version = Version(pkg.__version__)
else:
# try pkg_resources to infer version
pkg_version = Version(pkg_resources.get_distribution(package).version)
except TypeError:
# this is mocked by Sphinx, so it should return True to generate all summaries
return True
if use_base_version:
pkg_version = Version(pkg_version.base_version)
return op(pkg_version, Version(version))


_TORCH_AVAILABLE = _module_available("torch")
_PL_AVAILABLE = _module_available("pytorch_lightning")
_BOLTS_AVAILABLE = _module_available("pl_bolts") and _compare_version("torch", operator.lt, "1.9.0")
_PANDAS_AVAILABLE = _module_available("pandas")
_SKLEARN_AVAILABLE = _module_available("sklearn")
_PYTORCHTABULAR_AVAILABLE = _module_available("pytorch_tabular")
_FORECASTING_AVAILABLE = _module_available("pytorch_forecasting")
_KORNIA_AVAILABLE = _module_available("kornia")
_COCO_AVAILABLE = _module_available("pycocotools")
_TIMM_AVAILABLE = _module_available("timm")
_TORCHVISION_AVAILABLE = _module_available("torchvision")
_PYTORCHVIDEO_AVAILABLE = _module_available("pytorchvideo")
_MATPLOTLIB_AVAILABLE = _module_available("matplotlib")
_TRANSFORMERS_AVAILABLE = _module_available("transformers")
_PYSTICHE_AVAILABLE = _module_available("pystiche")
_FIFTYONE_AVAILABLE = _module_available("fiftyone")
_FASTAPI_AVAILABLE = _module_available("fastapi")
_PYDANTIC_AVAILABLE = _module_available("pydantic")
_GRAPHVIZ_AVAILABLE = _module_available("graphviz")
_CYTOOLZ_AVAILABLE = _module_available("cytoolz")
_UVICORN_AVAILABLE = _module_available("uvicorn")
_PIL_AVAILABLE = _module_available("PIL")
_OPEN3D_AVAILABLE = _module_available("open3d")
_SEGMENTATION_MODELS_AVAILABLE = _module_available("segmentation_models_pytorch")
_FASTFACE_AVAILABLE = _module_available("fastface") and _compare_version("pytorch_lightning", operator.lt, "1.5.0")
_LIBROSA_AVAILABLE = _module_available("librosa")
_TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter")
_TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse")
_TORCH_GEOMETRIC_AVAILABLE = _module_available("torch_geometric")
_NETWORKX_AVAILABLE = _module_available("networkx")
_TORCHAUDIO_AVAILABLE = _module_available("torchaudio")
_SENTENCEPIECE_AVAILABLE = _module_available("sentencepiece")
_DATASETS_AVAILABLE = _module_available("datasets")
_TM_TEXT_AVAILABLE: bool = _module_available("torchmetrics.text")
_ICEVISION_AVAILABLE = _module_available("icevision")
_ICEDATA_AVAILABLE = _module_available("icedata")
_LEARN2LEARN_AVAILABLE = _module_available("learn2learn") and _compare_version("learn2learn", operator.ge, "0.1.6")
_TORCH_ORT_AVAILABLE = _module_available("torch_ort")
_VISSL_AVAILABLE = _module_available("vissl") and _module_available("classy_vision")
_ALBUMENTATIONS_AVAILABLE = _module_available("albumentations")
_BAAL_AVAILABLE = _module_available("baal")
_TORCH_OPTIMIZER_AVAILABLE = _module_available("torch_optimizer")
_SENTENCE_TRANSFORMERS_AVAILABLE = _module_available("sentence_transformers")
_DEEPSPEED_AVAILABLE = _module_available("deepspeed")
_TORCH_AVAILABLE = module_available("torch")
_PL_AVAILABLE = module_available("pytorch_lightning")
_BOLTS_AVAILABLE = module_available("pl_bolts") and compare_version("torch", operator.lt, "1.9.0")
_PANDAS_AVAILABLE = module_available("pandas")
_SKLEARN_AVAILABLE = module_available("sklearn")
_PYTORCHTABULAR_AVAILABLE = module_available("pytorch_tabular")
_FORECASTING_AVAILABLE = module_available("pytorch_forecasting")
_KORNIA_AVAILABLE = module_available("kornia")
_COCO_AVAILABLE = module_available("pycocotools")
_TIMM_AVAILABLE = module_available("timm")
_TORCHVISION_AVAILABLE = module_available("torchvision")
_PYTORCHVIDEO_AVAILABLE = module_available("pytorchvideo")
_MATPLOTLIB_AVAILABLE = module_available("matplotlib")
_TRANSFORMERS_AVAILABLE = module_available("transformers")
_PYSTICHE_AVAILABLE = module_available("pystiche")
_FIFTYONE_AVAILABLE = module_available("fiftyone")
_FASTAPI_AVAILABLE = module_available("fastapi")
_PYDANTIC_AVAILABLE = module_available("pydantic")
_GRAPHVIZ_AVAILABLE = module_available("graphviz")
_CYTOOLZ_AVAILABLE = module_available("cytoolz")
_UVICORN_AVAILABLE = module_available("uvicorn")
_PIL_AVAILABLE = module_available("PIL")
_OPEN3D_AVAILABLE = module_available("open3d")
_SEGMENTATION_MODELS_AVAILABLE = module_available("segmentation_models_pytorch")
_FASTFACE_AVAILABLE = module_available("fastface") and compare_version("pytorch_lightning", operator.lt, "1.5.0")
_LIBROSA_AVAILABLE = module_available("librosa")
_TORCH_SCATTER_AVAILABLE = module_available("torch_scatter")
_TORCH_SPARSE_AVAILABLE = module_available("torch_sparse")
_TORCH_GEOMETRIC_AVAILABLE = module_available("torch_geometric")
_NETWORKX_AVAILABLE = module_available("networkx")
_TORCHAUDIO_AVAILABLE = module_available("torchaudio")
_SENTENCEPIECE_AVAILABLE = module_available("sentencepiece")
_DATASETS_AVAILABLE = module_available("datasets")
_TM_TEXT_AVAILABLE: bool = module_available("torchmetrics.text")
_ICEVISION_AVAILABLE = module_available("icevision")
_ICEDATA_AVAILABLE = module_available("icedata")
_LEARN2LEARN_AVAILABLE = module_available("learn2learn") and compare_version("learn2learn", operator.ge, "0.1.6")
_TORCH_ORT_AVAILABLE = module_available("torch_ort")
_VISSL_AVAILABLE = module_available("vissl") and module_available("classy_vision")
_ALBUMENTATIONS_AVAILABLE = module_available("albumentations")
_BAAL_AVAILABLE = module_available("baal")
_TORCH_OPTIMIZER_AVAILABLE = module_available("torch_optimizer")
_SENTENCE_TRANSFORMERS_AVAILABLE = module_available("sentence_transformers")
_DEEPSPEED_AVAILABLE = module_available("deepspeed")


if _PIL_AVAILABLE:
Expand All @@ -130,15 +82,15 @@ class Image:


if Version:
_TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0")
_PL_GREATER_EQUAL_1_4_3 = _compare_version("pytorch_lightning", operator.ge, "1.4.3")
_PL_GREATER_EQUAL_1_4_0 = _compare_version("pytorch_lightning", operator.ge, "1.4.0")
_PL_GREATER_EQUAL_1_5_0 = _compare_version("pytorch_lightning", operator.ge, "1.5.0")
_PL_GREATER_EQUAL_1_6_0 = _compare_version("pytorch_lightning", operator.ge, "1.6.0rc0")
_PANDAS_GREATER_EQUAL_1_3_0 = _compare_version("pandas", operator.ge, "1.3.0")
_ICEVISION_GREATER_EQUAL_0_11_0 = _compare_version("icevision", operator.ge, "0.11.0")
_TM_GREATER_EQUAL_0_7_0 = _compare_version("torchmetrics", operator.ge, "0.7.0")
_BAAL_GREATER_EQUAL_1_5_2 = _compare_version("baal", operator.ge, "1.5.2")
_TORCHVISION_GREATER_EQUAL_0_9 = compare_version("torchvision", operator.ge, "0.9.0")
_PL_GREATER_EQUAL_1_4_3 = compare_version("pytorch_lightning", operator.ge, "1.4.3")
_PL_GREATER_EQUAL_1_4_0 = compare_version("pytorch_lightning", operator.ge, "1.4.0")
_PL_GREATER_EQUAL_1_5_0 = compare_version("pytorch_lightning", operator.ge, "1.5.0")
_PL_GREATER_EQUAL_1_6_0 = compare_version("pytorch_lightning", operator.ge, "1.6.0rc0")
_PANDAS_GREATER_EQUAL_1_3_0 = compare_version("pandas", operator.ge, "1.3.0")
_ICEVISION_GREATER_EQUAL_0_11_0 = compare_version("icevision", operator.ge, "0.11.0")
_TM_GREATER_EQUAL_0_7_0 = compare_version("torchmetrics", operator.ge, "0.7.0")
_BAAL_GREATER_EQUAL_1_5_2 = compare_version("baal", operator.ge, "1.5.2")

_TEXT_AVAILABLE = all(
[
Expand Down Expand Up @@ -193,7 +145,7 @@ def decorator(func):
available = False
else:
modules.append(module_path)
if not _module_available(module_path):
if not module_available(module_path):
available = False
else:
available, module_path = module_path
Expand Down
10 changes: 6 additions & 4 deletions flash/image/detection/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from functools import partial
from typing import Optional

from lightning_utilities.core.imports import module_available

from flash.core.adapter import Adapter
from flash.core.integrations.icevision.adapter import IceVisionAdapter, SimpleCOCOMetric
from flash.core.integrations.icevision.backbones import (
Expand All @@ -23,7 +25,7 @@
)
from flash.core.model import Task
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _module_available, _TORCHVISION_AVAILABLE
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.core.utilities.providers import _EFFDET, _ICEVISION, _MMDET, _TORCHVISION, _ULTRALYTICS

if _ICEVISION_AVAILABLE:
Expand Down Expand Up @@ -75,7 +77,7 @@ def from_task(
providers=[_ICEVISION, _TORCHVISION],
)

if _module_available("yolov5"):
if module_available("yolov5"):
model_type = icevision_models.ultralytics.yolov5
OBJECT_DETECTION_HEADS(
partial(load_icevision_with_image_size, model_type),
Expand All @@ -85,7 +87,7 @@ def from_task(
providers=[_ICEVISION, _ULTRALYTICS],
)

if _module_available("mmdet"):
if module_available("mmdet"):
for model_type in [
icevision_models.mmdet.faster_rcnn,
icevision_models.mmdet.retinanet,
Expand All @@ -100,7 +102,7 @@ def from_task(
providers=[_ICEVISION, _MMDET],
)

if _module_available("effdet"):
if module_available("effdet"):

model_type = icevision_models.ross.efficientdet
OBJECT_DETECTION_HEADS(
Expand Down
6 changes: 4 additions & 2 deletions flash/image/instance_segmentation/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
from functools import partial
from typing import Optional

from lightning_utilities.core.imports import module_available

from flash.core.adapter import Adapter
from flash.core.integrations.icevision.adapter import IceVisionAdapter, SimpleCOCOMetric
from flash.core.integrations.icevision.backbones import get_backbones, load_icevision_ignore_image_size
from flash.core.model import Task
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _module_available, _TORCHVISION_AVAILABLE
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.core.utilities.providers import _ICEVISION, _MMDET, _TORCHVISION

if _ICEVISION_AVAILABLE:
Expand Down Expand Up @@ -71,7 +73,7 @@ def from_task(
providers=[_ICEVISION, _TORCHVISION],
)

if _module_available("mmdet"):
if module_available("mmdet"):
model_type = icevision_models.mmdet.mask_rcnn
INSTANCE_SEGMENTATION_HEADS(
partial(load_icevision_ignore_image_size, model_type),
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ jsonargparse[signatures]>=3.17.0, <=4.9.0
click>=7.1.2
protobuf<=3.20.1
fsspec
lightning_utilities>=0.3.0