diff --git a/pytorch_forecasting/__init__.py b/pytorch_forecasting/__init__.py index 693dce3c..47f92761 100644 --- a/pytorch_forecasting/__init__.py +++ b/pytorch_forecasting/__init__.py @@ -2,6 +2,8 @@ PyTorch Forecasting package for timeseries forecasting with PyTorch. """ +__version__ = "1.1.1" + from pytorch_forecasting.data import ( EncoderNormalizer, GroupNormalizer, @@ -59,6 +61,7 @@ to_list, unpack_sequence, ) +from pytorch_forecasting.utils._maint._show_versions import show_versions __all__ = [ "TimeSeriesDataSet", @@ -109,7 +112,6 @@ "integer_histogram", "groupby_apply", "profile", + "show_versions", "unpack_sequence", ] - -__version__ = "1.1.1" diff --git a/pytorch_forecasting/utils/_maint/__init__.py b/pytorch_forecasting/utils/_maint/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pytorch_forecasting/utils/_maint/_show_versions.py b/pytorch_forecasting/utils/_maint/_show_versions.py new file mode 100644 index 00000000..1d8a803f --- /dev/null +++ b/pytorch_forecasting/utils/_maint/_show_versions.py @@ -0,0 +1,142 @@ +# License: BSD 3 clause +"""Utility methods to print system info for debugging. + +adapted from +:func: `sklearn.show_versions` and `sktime.show_versions` +""" + +__all__ = ["show_versions"] + +import importlib +import platform +import sys + + +def _get_sys_info(): + """System information. + + Return + ------ + sys_info : dict + system and Python version information + """ + python = sys.version.replace("\n", " ") + + blob = [ + ("python", python), + ("executable", sys.executable), + ("machine", platform.platform()), + ] + + return dict(blob) + + +# dependencies to print versions of, by default +DEFAULT_DEPS_TO_SHOW = [ + "pip", + "pytorch-forecasting", + "torch", + "lightning", + "numpy", + "scipy", + "pandas", + "cpflows", + "matplotlib", + "optuna", + "optuna-integration", + "pytorch_optimizer", + "scikit-learn", + "scikit-base", + "statsmodels", +] + + +def _get_deps_info(deps=None, source="distributions"): + """Overview of the installed version of main dependencies. + + Parameters + ---------- + deps : optional, list of strings with package names + if None, behaves as deps = ["sktime"]. + + source : str, optional one of "distributions" (default) or "import" + source of version information + + * "distributions" - uses importlib.distributions. In this case, + strings in deps are assumed to be PEP 440 package strings, + e.g., scikit-learn, not sklearn. + * "import" - uses the __version__ attribute of the module. + In this case, strings in deps are assumed to be import names, + e.g., sklearn, not scikit-learn. + + Returns + ------- + deps_info: dict + version information on libraries in `deps` + keys are package names, import names if source is "import", + and PEP 440 package strings if source is "distributions"; + values are PEP 440 version strings + of the import as present in the current python environment + """ + if deps is None: + deps = ["pytorch-forecasting"] + + if source == "distributions": + from pytorch_forecasting.utils._dependencies import _get_installed_packages + + KEY_ALIAS = {"sklearn": "scikit-learn", "skbase": "scikit-base"} + + pkgs = _get_installed_packages() + + deps_info = {} + for modname in deps: + pkg_name = KEY_ALIAS.get(modname, modname) + deps_info[modname] = pkgs.get(pkg_name, None) + + return deps_info + + def get_version(module): + return getattr(module, "__version__", None) + + deps_info = {} + + for modname in deps: + try: + if modname in sys.modules: + mod = sys.modules[modname] + else: + mod = importlib.import_module(modname) + except ImportError: + deps_info[modname] = None + else: + ver = get_version(mod) + deps_info[modname] = ver + + return deps_info + + +def show_versions(): + """Print python version, OS version, sktime version, selected dependency versions. + + Pretty prints: + + * python version of environment + * python executable location + * OS version + * list of import name and version number for selected python dependencies + + Developer note: + Python version/executable and OS version are from `_get_sys_info` + Package versions are retrieved by `_get_deps_info` + Selected dependencies are as in the DEFAULT_DEPS_TO_SHOW variable + """ + sys_info = _get_sys_info() + deps_info = _get_deps_info(deps=DEFAULT_DEPS_TO_SHOW) + + print("\nSystem:") # noqa: T001, T201 + for k, stat in sys_info.items(): + print(f"{k:>10}: {stat}") # noqa: T001, T201 + + print("\nPython dependencies:") # noqa: T001, T201 + for k, stat in deps_info.items(): + print(f"{k:>13}: {stat}") # noqa: T001, T201 diff --git a/tests/test_utils/test_show_versions.py b/tests/test_utils/test_show_versions.py new file mode 100644 index 00000000..7fa3626f --- /dev/null +++ b/tests/test_utils/test_show_versions.py @@ -0,0 +1,42 @@ +"""Tests for the show_versions utility.""" + +import pathlib +import uuid + +from pytorch_forecasting.utils._maint._show_versions import DEFAULT_DEPS_TO_SHOW, _get_deps_info, show_versions + + +def test_show_versions_runs(): + """Test that show_versions runs without exceptions.""" + # only prints, should return None + assert show_versions() is None + + +def test_show_versions_import_loc(): + """Test that show_version can be imported from root.""" + from pytorch_forecasting import show_versions as show_versions_imported + + assert show_versions == show_versions_imported + + +def test_deps_info(): + """Test that _get_deps_info returns package/version dict as per contract.""" + deps_info = _get_deps_info() + assert isinstance(deps_info, dict) + assert set(deps_info.keys()) == {"pytorch-forecasting"} + + deps_info_default = _get_deps_info(DEFAULT_DEPS_TO_SHOW) + assert isinstance(deps_info_default, dict) + assert set(deps_info_default.keys()) == set(DEFAULT_DEPS_TO_SHOW) + + +def test_deps_info_deps_missing_package_present_directory(): + """Test that _get_deps_info does not fail if a dependency is missing.""" + dummy_package_name = uuid.uuid4().hex + + dummy_folder_path = pathlib.Path(dummy_package_name) + dummy_folder_path.mkdir() + + assert _get_deps_info([dummy_package_name]) == {dummy_package_name: None} + + dummy_folder_path.rmdir()