replace package_available with module_available #16607

Closed · wants to merge 7 commits
4 changes: 2 additions & 2 deletions src/lightning/app/__init__.py
@@ -2,7 +2,7 @@
import logging
import os

-from lightning_utilities.core.imports import module_available, package_available
+from lightning_utilities import module_available

_root_logger = logging.getLogger()
_logger = logging.getLogger(__name__)
@@ -28,7 +28,7 @@
if "__version__" not in locals():
if os.path.isfile(os.path.join(os.path.dirname(__file__), "__version__.py")):
from lightning.app.__version__ import version as __version__
-elif package_available("lightning"):
+elif module_available("lightning"):
from lightning import __version__ # noqa: F401

from lightning.app.core.app import LightningApp # noqa: E402
4 changes: 2 additions & 2 deletions src/lightning/app/cli/connect/app.py
@@ -21,7 +21,7 @@

import click
import psutil
-from lightning_utilities.core.imports import package_available
+from lightning_utilities import module_available
from rich.progress import Progress

from lightning.app.utilities.cli_helpers import _get_app_display_name, _LightningAppOpenAPIRetriever
@@ -330,7 +330,7 @@ def _install_missing_requirements(
if requirements:
missing_requirements = []
for req in requirements:
-if not (package_available(req) or package_available(req.replace("-", "_"))):
+if not (module_available(req) or module_available(req.replace("-", "_"))):
missing_requirements.append(req)

if missing_requirements:
4 changes: 2 additions & 2 deletions src/lightning/app/cli/lightning_cli.py
@@ -24,7 +24,7 @@
import rich
from lightning_cloud.openapi import V1LightningappInstanceState, V1LightningworkState
from lightning_cloud.openapi.rest import ApiException
-from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.core.imports import ModuleAvailableCache
from requests.exceptions import ConnectionError

import lightning.app.core.constants as constants
@@ -416,7 +416,7 @@ def run_app(
)


-if RequirementCache("lightning-fabric>=1.9.0") or RequirementCache("lightning>=1.9.0"):
+if ModuleAvailableCache("lightning-fabric>=1.9.0") or ModuleAvailableCache("lightning>=1.9.0"):
Contributor:
Why would this work? module available checks don't support version comparison!
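
This objection is the crux of the review: a module-availability probe only asks importlib whether a name is importable and never consults distribution metadata, so a version bound embedded in the string can never be honored. A minimal sketch of the failure mode (an approximation of what an importlib-based check like `module_available` does, not lightning_utilities' exact code):

```python
import importlib.util

def module_available(module_path: str) -> bool:
    # Treats the argument as a dotted module path and probes importability.
    # A PEP 508 specifier such as ">=1.9.0" is never parsed or compared.
    try:
        return importlib.util.find_spec(module_path) is not None
    except (ImportError, ValueError):
        return False

module_available("lightning.fabric")         # reflects importability only
module_available("lightning-fabric>=1.9.0")  # always False: no module has this name
```

`RequirementCache`, by contrast, resolves its argument as a pip-style requirement against the installed distribution's version, which is exactly what the `>=1.9.0` bound on this line relies on.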

# note it is automatically replaced to `from lightning.fabric.cli` when building monolithic/mirror package
from lightning.fabric.cli import _run_model

4 changes: 2 additions & 2 deletions src/lightning/app/testing/testing.py
@@ -30,7 +30,7 @@
import requests
from lightning_cloud.openapi import V1LightningappInstanceState
from lightning_cloud.openapi.rest import ApiException
-from lightning_utilities.core.imports import package_available
+from lightning_utilities import module_available
from requests import Session
from rich import print
from rich.color import ANSI_COLOR_NAMES
@@ -153,7 +153,7 @@ def application_testing(lit_app_cls: Type[LightningTestApp] = LightningTestApp,

patch1 = mock.patch("lightning.app.LightningApp", lit_app_cls)
# we need to patch both only with the mirror package
-patch2 = mock.patch("lightning.LightningApp", lit_app_cls) if package_available("lightning") else nullcontext()
+patch2 = mock.patch("lightning.LightningApp", lit_app_cls) if module_available("lightning") else nullcontext()
with patch1, patch2:
original = sys.argv
sys.argv = command_line
4 changes: 2 additions & 2 deletions src/lightning/fabric/__init__.py
@@ -2,13 +2,13 @@
import logging
import os

-from lightning_utilities.core.imports import package_available
+from lightning_utilities import module_available

if os.path.isfile(os.path.join(os.path.dirname(__file__), "__about__.py")):
from lightning.fabric.__about__ import * # noqa: F401, F403
if os.path.isfile(os.path.join(os.path.dirname(__file__), "__version__.py")):
from lightning.fabric.__version__ import version as __version__
-elif package_available("lightning"):
+elif module_available("lightning"):
from lightning import __version__ # type: ignore[misc] # noqa: F401

_root_logger = logging.getLogger()
4 changes: 2 additions & 2 deletions src/lightning/fabric/cli.py
@@ -17,7 +17,7 @@
from argparse import Namespace
from typing import Any, List, Optional

-from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.core.imports import ModuleAvailableCache
from typing_extensions import get_args

from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
@@ -27,7 +27,7 @@

_log = logging.getLogger(__name__)

-_CLICK_AVAILABLE = RequirementCache("click")
+_CLICK_AVAILABLE = ModuleAvailableCache("click")

_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu")

6 changes: 3 additions & 3 deletions src/lightning/fabric/loggers/tensorboard.py
@@ -18,7 +18,7 @@
from typing import Any, Dict, Mapping, Optional, TYPE_CHECKING, Union

import numpy as np
-from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.core.imports import ModuleAvailableCache
from torch import Tensor
from torch.nn import Module

@@ -31,8 +31,8 @@

log = logging.getLogger(__name__)

-_TENSORBOARD_AVAILABLE = RequirementCache("tensorboard")
-_TENSORBOARDX_AVAILABLE = RequirementCache("tensorboardX")
+_TENSORBOARD_AVAILABLE = ModuleAvailableCache("tensorboard")
+_TENSORBOARDX_AVAILABLE = ModuleAvailableCache("tensorboardX")
if TYPE_CHECKING:
# assumes at least one will be installed when type checking
if _TENSORBOARD_AVAILABLE:
4 changes: 2 additions & 2 deletions src/lightning/fabric/plugins/environments/mpi.py
@@ -18,14 +18,14 @@
from typing import Optional

import numpy as np
-from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.core.imports import ModuleAvailableCache

from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.fabric.plugins.environments.lightning import find_free_network_port

log = logging.getLogger(__name__)

-_MPI4PY_AVAILABLE = RequirementCache("mpi4py")
+_MPI4PY_AVAILABLE = ModuleAvailableCache("mpi4py")


class MPIEnvironment(ClusterEnvironment):
4 changes: 2 additions & 2 deletions src/lightning/fabric/plugins/io/xla.py
@@ -15,7 +15,7 @@
from typing import Any, Dict, Optional

from lightning_utilities.core.apply_func import apply_to_collection
-from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.core.imports import ModuleAvailableCache

from lightning.fabric.accelerators.tpu import _XLA_AVAILABLE
from lightning.fabric.plugins.io.torch_io import TorchCheckpointIO
@@ -51,7 +51,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio
)
fs = get_filesystem(path)
fs.makedirs(os.path.dirname(path), exist_ok=True)
-if RequirementCache("omegaconf"):
+if ModuleAvailableCache("omegaconf"):
# workaround for https://github.com/pytorch/xla/issues/2773
from omegaconf import DictConfig, ListConfig, OmegaConf

4 changes: 2 additions & 2 deletions src/lightning/fabric/strategies/deepspeed.py
@@ -22,7 +22,7 @@
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union

import torch
-from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.core.imports import ModuleAvailableCache
from torch.nn import Module
from torch.optim import Optimizer

@@ -37,7 +37,7 @@
from lightning.fabric.utilities.types import _PATH

# check packaging because of https://github.com/microsoft/DeepSpeed/pull/2771
-_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed") and RequirementCache("packaging>=20.0")
+_DEEPSPEED_AVAILABLE = ModuleAvailableCache("deepspeed") and ModuleAvailableCache("packaging>=20.0")
if TYPE_CHECKING and _DEEPSPEED_AVAILABLE:
import deepspeed

@@ -16,12 +16,12 @@
import sys
from typing import Any, Callable, Optional, Sequence, Tuple

-from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.core.imports import ModuleAvailableCache

from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.fabric.strategies.launchers.launcher import _Launcher

-_HYDRA_AVAILABLE = RequirementCache("hydra-core")
+_HYDRA_AVAILABLE = ModuleAvailableCache("hydra-core")


class _SubprocessScriptLauncher(_Launcher):
4 changes: 2 additions & 2 deletions src/lightning/pytorch/callbacks/progress/rich_progress.py
@@ -16,13 +16,13 @@
from datetime import timedelta
from typing import Any, cast, Dict, Optional, Union

-from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.core.imports import ModuleAvailableCache

import lightning.pytorch as pl
from lightning.pytorch.callbacks.progress.base import ProgressBarBase
from lightning.pytorch.utilities.types import STEP_OUTPUT

-_RICH_AVAILABLE: bool = RequirementCache("rich>=10.2.2")
+_RICH_AVAILABLE: bool = ModuleAvailableCache("rich>=10.2.2")

if _RICH_AVAILABLE:
from rich import get_console, reconfigure
4 changes: 2 additions & 2 deletions src/lightning/pytorch/cli.py
@@ -18,7 +18,7 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union

import torch
-from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.core.imports import ModuleAvailableCache
from lightning_utilities.core.rank_zero import _warn
from torch.optim import Optimizer

@@ -30,7 +30,7 @@
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_warn

-_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.17.0")
+_JSONARGPARSE_SIGNATURES_AVAILABLE = ModuleAvailableCache("jsonargparse[signatures]>=4.17.0")

if _JSONARGPARSE_SIGNATURES_AVAILABLE:
import docstring_parser
6 changes: 3 additions & 3 deletions src/lightning/pytorch/loggers/mlflow.py
@@ -25,7 +25,7 @@
from typing import Any, Dict, List, Literal, Mapping, Optional, Union

import yaml
-from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.core.imports import ModuleAvailableCache
from torch import Tensor

from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict
@@ -36,8 +36,8 @@

log = logging.getLogger(__name__)
LOCAL_FILE_URI_PREFIX = "file:"
-_MLFLOW_FULL_AVAILABLE = RequirementCache("mlflow>=1.0.0")
-_MLFLOW_SKINNY_AVAILABLE = RequirementCache("mlflow-skinny>=1.0.0")
+_MLFLOW_FULL_AVAILABLE = ModuleAvailableCache("mlflow>=1.0.0")
+_MLFLOW_SKINNY_AVAILABLE = ModuleAvailableCache("mlflow-skinny>=1.0.0")
_MLFLOW_AVAILABLE = _MLFLOW_FULL_AVAILABLE or _MLFLOW_SKINNY_AVAILABLE
if _MLFLOW_AVAILABLE:
from mlflow.entities import Metric, Param
4 changes: 2 additions & 2 deletions src/lightning/pytorch/loggers/neptune.py
@@ -24,7 +24,7 @@
from argparse import Namespace
from typing import Any, Dict, Generator, List, Optional, Set, Union

-from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.core.imports import ModuleAvailableCache
from torch import Tensor

import lightning.pytorch as pl
@@ -34,7 +34,7 @@
from lightning.pytorch.utilities.model_summary import ModelSummary
from lightning.pytorch.utilities.rank_zero import rank_zero_only

-_NEPTUNE_AVAILABLE = RequirementCache("neptune-client")
+_NEPTUNE_AVAILABLE = ModuleAvailableCache("neptune-client")
Contributor:
this is not a module
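
The mismatch here is between distribution and module names: pip installs the project `neptune-client`, but the import statement — and any importlib-based availability check — sees only the module `neptune`. A short illustration, assuming an `importlib.util.find_spec`-style probe like the one behind `ModuleAvailableCache`:

```python
import importlib.util

# "neptune-client" is the PyPI distribution name; the importable module is "neptune".
# A module probe looks up the import name, so the hyphenated distribution name
# can never match, regardless of what is installed.
print(importlib.util.find_spec("neptune") is not None)         # True if neptune-client is installed
print(importlib.util.find_spec("neptune-client") is not None)  # always False: not a valid module name
```

`RequirementCache` avoids this because it queries installed distribution metadata, where the name really is `neptune-client`.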

if _NEPTUNE_AVAILABLE:
from neptune import new as neptune
from neptune.new.run import Run
8 changes: 4 additions & 4 deletions src/lightning/pytorch/loggers/wandb.py
@@ -21,7 +21,7 @@
from typing import Any, Dict, List, Mapping, Optional, Union

import torch.nn as nn
-from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.core.imports import ModuleAvailableCache
from torch import Tensor

from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _sanitize_callable_params
@@ -40,9 +40,9 @@
# needed for test mocks, these tests shall be updated
wandb, Run, RunDisabled = None, None, None

-_WANDB_AVAILABLE = RequirementCache("wandb")
-_WANDB_GREATER_EQUAL_0_10_22 = RequirementCache("wandb>=0.10.22")
-_WANDB_GREATER_EQUAL_0_12_10 = RequirementCache("wandb>=0.12.10")
+_WANDB_AVAILABLE = ModuleAvailableCache("wandb")
+_WANDB_GREATER_EQUAL_0_10_22 = ModuleAvailableCache("wandb>=0.10.22")
+_WANDB_GREATER_EQUAL_0_12_10 = ModuleAvailableCache("wandb>=0.12.10")


class WandbLogger(Logger):
6 changes: 3 additions & 3 deletions src/lightning/pytorch/serve/servable_module_validator.py
@@ -5,7 +5,7 @@

import requests
import torch
-from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.core.imports import ModuleAvailableCache

import lightning.pytorch as pl
from lightning.pytorch.callbacks import Callback
@@ -49,10 +49,10 @@ def __init__(
exit_on_failure: bool = True,
):
super().__init__()
-fastapi_installed = RequirementCache("fastapi")
+fastapi_installed = ModuleAvailableCache("fastapi")
if not fastapi_installed:
raise ModuleNotFoundError(fastapi_installed.message)
-uvicorn_installed = RequirementCache("uvicorn")
+uvicorn_installed = ModuleAvailableCache("uvicorn")
if not uvicorn_installed:
raise ModuleNotFoundError(uvicorn_installed.message)

@@ -16,7 +16,7 @@
import subprocess
from typing import Any, Callable, List, Optional

-from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.core.imports import ModuleAvailableCache

import lightning.pytorch as pl
from lightning.fabric.plugins import ClusterEnvironment
@@ -25,7 +25,7 @@
from lightning.pytorch.trainer.connectors.signal_connector import _SIGNUM

log = logging.getLogger(__name__)
-_HYDRA_AVAILABLE = RequirementCache("hydra-core")
+_HYDRA_AVAILABLE = ModuleAvailableCache("hydra-core")


class _SubprocessScriptLauncher(_Launcher):
4 changes: 2 additions & 2 deletions src/lightning/pytorch/tuner/lr_finder.py
@@ -20,7 +20,7 @@

import numpy as np
import torch
-from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.core.imports import ModuleAvailableCache

import lightning.pytorch as pl
from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
@@ -37,7 +37,7 @@
else:
from tqdm import tqdm

-_MATPLOTLIB_AVAILABLE = RequirementCache("matplotlib")
+_MATPLOTLIB_AVAILABLE = ModuleAvailableCache("matplotlib")
if TYPE_CHECKING and _MATPLOTLIB_AVAILABLE:
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
16 changes: 8 additions & 8 deletions src/lightning/pytorch/utilities/imports.py
@@ -16,22 +16,22 @@
import sys

import torch
-from lightning_utilities.core.imports import compare_version, package_available, RequirementCache
+from lightning_utilities.core.imports import compare_version, module_available, ModuleAvailableCache

_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
_PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11)
# duplicated from fabric because HPU is patching it below
_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0")
-_TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
+_TORCHMETRICS_GREATER_EQUAL_0_9_1 = ModuleAvailableCache("torchmetrics>=0.9.1")

_KINETO_AVAILABLE = torch.profiler.kineto_available()
-_OMEGACONF_AVAILABLE = package_available("omegaconf")
-_POPTORCH_AVAILABLE = package_available("poptorch")
-_PSUTIL_AVAILABLE = package_available("psutil")
-_RICH_AVAILABLE = package_available("rich") and compare_version("rich", operator.ge, "10.2.2")
-_TORCHVISION_AVAILABLE = RequirementCache("torchvision")
-_LIGHTNING_COLOSSALAI_AVAILABLE = RequirementCache("lightning-colossalai")
+_OMEGACONF_AVAILABLE = module_available("omegaconf")
+_POPTORCH_AVAILABLE = module_available("poptorch")
+_PSUTIL_AVAILABLE = module_available("psutil")
+_RICH_AVAILABLE = module_available("rich") and compare_version("rich", operator.ge, "10.2.2")
+_TORCHVISION_AVAILABLE = ModuleAvailableCache("torchvision")
+_LIGHTNING_COLOSSALAI_AVAILABLE = ModuleAvailableCache("lightning-colossalai")

if _POPTORCH_AVAILABLE:
import poptorch
4 changes: 2 additions & 2 deletions src/lightning/pytorch/utilities/model_helpers.py
@@ -13,7 +13,7 @@
# limitations under the License.
from typing import Any, Optional, Type

-from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.core.imports import ModuleAvailableCache
from torch import nn

import lightning.pytorch as pl
@@ -45,7 +45,7 @@ def get_torchvision_model(model_name: str, **kwargs: Any) -> nn.Module:

from torchvision import models

-torchvision_greater_equal_0_14 = RequirementCache("torchvision>=0.14.0")
+torchvision_greater_equal_0_14 = ModuleAvailableCache("torchvision>=0.14.0")
# TODO: deprecate this function when 0.14 is the minimum supported torchvision
if torchvision_greater_equal_0_14:
return models.get_model(model_name, **kwargs)