Skip to content

Commit

Permalink
Improve import utility functions (#676)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo authored Jul 31, 2024
1 parent 2fded78 commit 6008278
Show file tree
Hide file tree
Showing 14 changed files with 45 additions and 43 deletions.
34 changes: 18 additions & 16 deletions src/coola/utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def package_available(name: str) -> bool:
```pycon
>>> from coola.utils.imports import package_available
>>> package_available("os")
True
>>> package_available("missing_package")
Expand All @@ -86,6 +87,7 @@ def module_available(name: str) -> bool:
```pycon
>>> from coola.utils.imports import module_available
>>> module_available("os")
True
>>> module_available("os.missing")
Expand Down Expand Up @@ -209,8 +211,8 @@ def check_jax() -> None:
"""
if not is_jax_available():
msg = (
"`jax` package is required but not installed. "
"You can install `jax` package with the command:\n\n"
"'jax' package is required but not installed. "
"You can install 'jax' package with the command:\n\n"
"pip install jax\n"
)
raise RuntimeError(msg)
Expand Down Expand Up @@ -284,8 +286,8 @@ def check_numpy() -> None:
"""
if not is_numpy_available():
msg = (
"`numpy` package is required but not installed. "
"You can install `numpy` package with the command:\n\n"
"'numpy' package is required but not installed. "
"You can install 'numpy' package with the command:\n\n"
"pip install numpy\n"
)
raise RuntimeError(msg)
Expand Down Expand Up @@ -359,8 +361,8 @@ def check_packaging() -> None:
"""
if not is_packaging_available():
msg = (
"`packaging` package is required but not installed. "
"You can install `packaging` package with the command:\n\n"
"'packaging' package is required but not installed. "
"You can install 'packaging' package with the command:\n\n"
"pip install packaging\n"
)
raise RuntimeError(msg)
Expand Down Expand Up @@ -434,8 +436,8 @@ def check_pandas() -> None:
"""
if not is_pandas_available():
msg = (
"`pandas` package is required but not installed. "
"You can install `pandas` package with the command:\n\n"
"'pandas' package is required but not installed. "
"You can install 'pandas' package with the command:\n\n"
"pip install pandas\n"
)
raise RuntimeError(msg)
Expand Down Expand Up @@ -509,8 +511,8 @@ def check_polars() -> None:
"""
if not is_polars_available():
msg = (
"`polars` package is required but not installed. "
"You can install `polars` package with the command:\n\n"
"'polars' package is required but not installed. "
"You can install 'polars' package with the command:\n\n"
"pip install polars\n"
)
raise RuntimeError(msg)
Expand Down Expand Up @@ -584,8 +586,8 @@ def check_pyarrow() -> None:
"""
if not is_pyarrow_available():
msg = (
"`pyarrow` package is required but not installed. "
"You can install `pyarrow` package with the command:\n\n"
"'pyarrow' package is required but not installed. "
"You can install 'pyarrow' package with the command:\n\n"
"pip install pyarrow\n"
)
raise RuntimeError(msg)
Expand Down Expand Up @@ -659,8 +661,8 @@ def check_torch() -> None:
"""
if not is_torch_available():
msg = (
"`torch` package is required but not installed. "
"You can install `torch` package with the command:\n\n"
"'torch' package is required but not installed. "
"You can install 'torch' package with the command:\n\n"
"pip install torch\n"
)
raise RuntimeError(msg)
Expand Down Expand Up @@ -734,8 +736,8 @@ def check_xarray() -> None:
"""
if not is_xarray_available():
msg = (
"`xarray` package is required but not installed. "
"You can install `xarray` package with the command:\n\n"
"'xarray' package is required but not installed. "
"You can install 'xarray' package with the command:\n\n"
"pip install xarray\n"
)
raise RuntimeError(msg)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/equality/comparators/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def test_jax_array_equality_comparator_equal_true_tolerance(
def test_jax_array_equality_comparator_no_jax() -> None:
with (
patch("coola.utils.imports.is_jax_available", lambda: False),
pytest.raises(RuntimeError, match="`jax` package is required but not installed."),
pytest.raises(RuntimeError, match="'jax' package is required but not installed."),
):
JaxArrayEqualityComparator()

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/equality/comparators/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def test_numpy_array_equality_comparator_equal_true_tolerance(
def test_numpy_array_equality_comparator_no_numpy() -> None:
with (
patch("coola.utils.imports.is_numpy_available", lambda: False),
pytest.raises(RuntimeError, match="`numpy` package is required but not installed."),
pytest.raises(RuntimeError, match="'numpy' package is required but not installed."),
):
NumpyArrayEqualityComparator()

Expand Down Expand Up @@ -477,7 +477,7 @@ def test_numpy_masked_array_equality_comparator_equal_nan(
def test_numpy_masked_array_equality_comparator_no_numpy() -> None:
with (
patch("coola.utils.imports.is_numpy_available", lambda: False),
pytest.raises(RuntimeError, match="`numpy` package is required but not installed."),
pytest.raises(RuntimeError, match="'numpy' package is required but not installed."),
):
NumpyMaskedArrayEqualityComparator()

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/equality/comparators/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def test_pandas_dataframe_equality_comparator_equal_tolerance(
def test_pandas_dataframe_equality_comparator_no_pandas() -> None:
with (
patch("coola.utils.imports.is_pandas_available", lambda: False),
pytest.raises(RuntimeError, match="`pandas` package is required but not installed."),
pytest.raises(RuntimeError, match="'pandas' package is required but not installed."),
):
PandasDataFrameEqualityComparator()

Expand Down Expand Up @@ -485,7 +485,7 @@ def test_pandas_series_equality_comparator_equal_tolerance(
def test_pandas_series_equality_comparator_no_pandas() -> None:
with (
patch("coola.utils.imports.is_pandas_available", lambda: False),
pytest.raises(RuntimeError, match="`pandas` package is required but not installed."),
pytest.raises(RuntimeError, match="'pandas' package is required but not installed."),
):
PandasSeriesEqualityComparator()

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/equality/comparators/test_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def test_polars_dataframe_equality_comparator_equal_tolerance(
def test_polars_dataframe_equality_comparator_no_polars() -> None:
with (
patch("coola.utils.imports.is_polars_available", lambda: False),
pytest.raises(RuntimeError, match="`polars` package is required but not installed."),
pytest.raises(RuntimeError, match="'polars' package is required but not installed."),
):
PolarsDataFrameEqualityComparator()

Expand Down Expand Up @@ -487,7 +487,7 @@ def test_polars_series_equality_comparator_equal_tolerance(
def test_polars_series_equality_comparator_no_polars() -> None:
with (
patch("coola.utils.imports.is_polars_available", lambda: False),
pytest.raises(RuntimeError, match="`polars` package is required but not installed."),
pytest.raises(RuntimeError, match="'polars' package is required but not installed."),
):
PolarsSeriesEqualityComparator()

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/equality/comparators/test_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ def test_pyarrow_equality_comparator_equal_true_tolerance(
def test_pyarrow_equality_comparator_no_pyarrow() -> None:
with (
patch("coola.utils.imports.is_pyarrow_available", lambda: False),
pytest.raises(RuntimeError, match="`pyarrow` package is required but not installed."),
pytest.raises(RuntimeError, match="'pyarrow' package is required but not installed."),
):
PyarrowEqualityComparator()

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/equality/comparators/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ def test_tensor_packed_sequence_equality_comparator_equal_nan_false(
def test_tensor_packed_sequence_equality_comparator_no_torch() -> None:
with (
patch("coola.utils.imports.is_torch_available", lambda: False),
pytest.raises(RuntimeError, match="`torch` package is required but not installed."),
pytest.raises(RuntimeError, match="'torch' package is required but not installed."),
):
TorchPackedSequenceEqualityComparator()

Expand Down Expand Up @@ -613,7 +613,7 @@ def test_torch_tensor_equality_comparator_true_tolerance(
def test_torch_tensor_equality_comparator_no_torch() -> None:
with (
patch("coola.utils.imports.is_torch_available", lambda: False),
pytest.raises(RuntimeError, match="`torch` package is required but not installed."),
pytest.raises(RuntimeError, match="'torch' package is required but not installed."),
):
TorchTensorEqualityComparator()

Expand Down
6 changes: 3 additions & 3 deletions tests/unit/equality/comparators/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def test_xarray_data_array_equality_comparator_equal_true_tolerance(
def test_xarray_data_array_equality_comparator_no_xarray() -> None:
with (
patch("coola.utils.imports.is_xarray_available", lambda: False),
pytest.raises(RuntimeError, match="`xarray` package is required but not installed."),
pytest.raises(RuntimeError, match="'xarray' package is required but not installed."),
):
XarrayDataArrayEqualityComparator()

Expand Down Expand Up @@ -685,7 +685,7 @@ def test_xarray_dataset_equality_comparator_equal_true_tolerance(
def test_xarray_dataset_equality_comparator_no_xarray() -> None:
with (
patch("coola.utils.imports.is_xarray_available", lambda: False),
pytest.raises(RuntimeError, match="`xarray` package is required but not installed."),
pytest.raises(RuntimeError, match="'xarray' package is required but not installed."),
):
XarrayDatasetEqualityComparator()

Expand Down Expand Up @@ -810,7 +810,7 @@ def test_xarray_variable_equality_comparator_equal_true_tolerance(
def test_xarray_variable_equality_comparator_no_xarray() -> None:
with (
patch("coola.utils.imports.is_xarray_available", lambda: False),
pytest.raises(RuntimeError, match="`xarray` package is required but not installed."),
pytest.raises(RuntimeError, match="'xarray' package is required but not installed."),
):
XarrayVariableEqualityComparator()

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/random/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_numpy_random_manager_set_rng_state() -> None:
def test_numpy_random_manager_no_numpy() -> None:
with (
patch("coola.utils.imports.is_numpy_available", lambda: False),
pytest.raises(RuntimeError, match="`numpy` package is required but not installed."),
pytest.raises(RuntimeError, match="'numpy' package is required but not installed."),
):
NumpyRandomManager()

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/random/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_torch_random_manager_set_rng_state() -> None:
def test_torch_random_manager_no_torch() -> None:
with (
patch("coola.utils.imports.is_torch_available", lambda: False),
pytest.raises(RuntimeError, match="`torch` package is required but not installed."),
pytest.raises(RuntimeError, match="'torch' package is required but not installed."),
):
TorchRandomManager()

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/reducers/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,6 @@ def test_numpy_reducer_std_empty(values: Sequence[int | float]) -> None:
def test_numpy_reducer_no_numpy() -> None:
with (
patch("coola.utils.imports.is_numpy_available", lambda: False),
pytest.raises(RuntimeError, match="`numpy` package is required but not installed."),
pytest.raises(RuntimeError, match="'numpy' package is required but not installed."),
):
NumpyReducer()
2 changes: 1 addition & 1 deletion tests/unit/reducers/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,6 @@ def test_torch_reducer_std_empty(values: Sequence[int | float]) -> None:
def test_torch_reducer_no_torch() -> None:
with (
patch("coola.utils.imports.is_torch_available", lambda: False),
pytest.raises(RuntimeError, match="`torch` package is required but not installed."),
pytest.raises(RuntimeError, match="'torch' package is required but not installed."),
):
TorchReducer()
16 changes: 8 additions & 8 deletions tests/unit/utils/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def test_check_jax_with_package() -> None:
def test_check_jax_without_package() -> None:
with (
patch("coola.utils.imports.is_jax_available", lambda: False),
pytest.raises(RuntimeError, match="`jax` package is required but not installed."),
pytest.raises(RuntimeError, match="'jax' package is required but not installed."),
):
check_jax()

Expand Down Expand Up @@ -227,7 +227,7 @@ def test_check_numpy_with_package() -> None:
def test_check_numpy_without_package() -> None:
with (
patch("coola.utils.imports.is_numpy_available", lambda: False),
pytest.raises(RuntimeError, match="`numpy` package is required but not installed."),
pytest.raises(RuntimeError, match="'numpy' package is required but not installed."),
):
check_numpy()

Expand Down Expand Up @@ -281,7 +281,7 @@ def test_check_packaging_with_package() -> None:
def test_check_packaging_without_package() -> None:
with (
patch("coola.utils.imports.is_packaging_available", lambda: False),
pytest.raises(RuntimeError, match="`packaging` package is required but not installed."),
pytest.raises(RuntimeError, match="'packaging' package is required but not installed."),
):
check_packaging()

Expand Down Expand Up @@ -335,7 +335,7 @@ def test_check_pandas_with_package() -> None:
def test_check_pandas_without_package() -> None:
with (
patch("coola.utils.imports.is_pandas_available", lambda: False),
pytest.raises(RuntimeError, match="`pandas` package is required but not installed."),
pytest.raises(RuntimeError, match="'pandas' package is required but not installed."),
):
check_pandas()

Expand Down Expand Up @@ -389,7 +389,7 @@ def test_check_polars_with_package() -> None:
def test_check_polars_without_package() -> None:
with (
patch("coola.utils.imports.is_polars_available", lambda: False),
pytest.raises(RuntimeError, match="`polars` package is required but not installed."),
pytest.raises(RuntimeError, match="'polars' package is required but not installed."),
):
check_polars()

Expand Down Expand Up @@ -443,7 +443,7 @@ def test_check_pyarrow_with_package() -> None:
def test_check_pyarrow_without_package() -> None:
with (
patch("coola.utils.imports.is_pyarrow_available", lambda: False),
pytest.raises(RuntimeError, match="`pyarrow` package is required but not installed."),
pytest.raises(RuntimeError, match="'pyarrow' package is required but not installed."),
):
check_pyarrow()

Expand Down Expand Up @@ -497,7 +497,7 @@ def test_check_torch_with_package() -> None:
def test_check_torch_without_package() -> None:
with (
patch("coola.utils.imports.is_torch_available", lambda: False),
pytest.raises(RuntimeError, match="`torch` package is required but not installed."),
pytest.raises(RuntimeError, match="'torch' package is required but not installed."),
):
check_torch()

Expand Down Expand Up @@ -551,7 +551,7 @@ def test_check_xarray_with_package() -> None:
def test_check_xarray_without_package() -> None:
with (
patch("coola.utils.imports.is_xarray_available", lambda: False),
pytest.raises(RuntimeError, match="`xarray` package is required but not installed."),
pytest.raises(RuntimeError, match="'xarray' package is required but not installed."),
):
check_xarray()

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/utils/test_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_compare_version_false_missing() -> None:
def test_compare_version_missing_packaging() -> None:
with (
patch("coola.utils.imports.is_packaging_available", lambda: False),
pytest.raises(RuntimeError, match="`packaging` package is required but not installed."),
pytest.raises(RuntimeError, match="'packaging' package is required but not installed."),
):
compare_version("my_package", operator.ge, "7.3.0")

Expand All @@ -59,6 +59,6 @@ def test_get_package_version_missing() -> None:
def test_get_package_version_missing_packaging() -> None:
with (
patch("coola.utils.imports.is_packaging_available", lambda: False),
pytest.raises(RuntimeError, match="`packaging` package is required but not installed."),
pytest.raises(RuntimeError, match="'packaging' package is required but not installed."),
):
get_package_version("my_package")

0 comments on commit 6008278

Please sign in to comment.