From b7dfa389c888f760aaddf741baad98d6eaa5c337 Mon Sep 17 00:00:00 2001 From: Thibaut Durand Date: Sun, 7 Jan 2024 00:04:20 -0800 Subject: [PATCH] Minor improvements (#385) --- src/coola/equality/handlers/__init__.py | 2 +- .../equality/handlers/{array.py => dtype.py} | 16 ++-------------- src/coola/equality/handlers/shape.py | 2 +- .../handlers/{test_array.py => test_dtype.py} | 10 +++++----- tests/unit/equality/handlers/test_shape.py | 4 ++-- 5 files changed, 11 insertions(+), 23 deletions(-) rename src/coola/equality/handlers/{array.py => dtype.py} (84%) rename tests/unit/equality/handlers/{test_array.py => test_dtype.py} (94%) diff --git a/src/coola/equality/handlers/__init__.py b/src/coola/equality/handlers/__init__.py index 73d26a71..76b0cc1e 100644 --- a/src/coola/equality/handlers/__init__.py +++ b/src/coola/equality/handlers/__init__.py @@ -26,8 +26,8 @@ "TrueHandler", ] -from coola.equality.handlers.array import SameDTypeHandler from coola.equality.handlers.base import AbstractEqualityHandler, BaseEqualityHandler +from coola.equality.handlers.dtype import SameDTypeHandler from coola.equality.handlers.jax_ import JaxArrayEqualHandler from coola.equality.handlers.mapping import ( MappingSameKeysHandler, diff --git a/src/coola/equality/handlers/array.py b/src/coola/equality/handlers/dtype.py similarity index 84% rename from src/coola/equality/handlers/array.py rename to src/coola/equality/handlers/dtype.py index 1cadd9e2..40044fa2 100644 --- a/src/coola/equality/handlers/array.py +++ b/src/coola/equality/handlers/dtype.py @@ -1,8 +1,8 @@ -r"""Implement some handlers for arrays or similar data.""" +r"""Implement handlers to check the objects have the same data type.""" from __future__ import annotations -__all__ = ["SameDTypeHandler"] +__all__ = ["SameDTypeHandler", "SupportsDType"] import logging from typing import TYPE_CHECKING, Any, Protocol @@ -10,20 +10,8 @@ from coola.equality.handlers.base import AbstractEqualityHandler if TYPE_CHECKING: - from unittest.mock import Mock - from coola.equality.config import EqualityConfig - from coola.utils import is_numpy_available, is_torch_available - - if is_numpy_available(): - import numpy as np - else: # pragma: no cover - np = Mock() - if is_torch_available(): - import torch - else: # pragma: no cover - torch = Mock() logger = logging.getLogger(__name__) diff --git a/src/coola/equality/handlers/shape.py b/src/coola/equality/handlers/shape.py index 700321fc..a346abea 100644 --- a/src/coola/equality/handlers/shape.py +++ b/src/coola/equality/handlers/shape.py @@ -2,7 +2,7 @@ from __future__ import annotations -__all__ = ["SameShapeHandler"] +__all__ = ["SameShapeHandler", "SupportsShape"] import logging from typing import TYPE_CHECKING, Protocol diff --git a/tests/unit/equality/handlers/test_array.py b/tests/unit/equality/handlers/test_dtype.py similarity index 94% rename from tests/unit/equality/handlers/test_array.py rename to tests/unit/equality/handlers/test_dtype.py index 64584d40..dc693939 100644 --- a/tests/unit/equality/handlers/test_array.py +++ b/tests/unit/equality/handlers/test_dtype.py @@ -27,9 +27,9 @@ def config() -> EqualityConfig: return EqualityConfig(tester=EqualityTester()) -########################################### -# Tests for ArraySameDTypeHandler # -########################################### +###################################### +# Tests for SameDTypeHandler # +###################################### def test_same_dtype_handler_eq_true() -> None: @@ -53,7 +53,7 @@ def test_same_dtype_handler_str() -> None: (np.ones(shape=(2, 3), dtype=bool), np.zeros(shape=(2, 3), dtype=bool)), ], ) -def test_same_dtype_handler_handle_true_ndarray( +def test_same_dtype_handler_handle_true( object1: np.ndarray, object2: np.ndarray, config: EqualityConfig ) -> None: assert SameDTypeHandler(next_handler=TrueHandler()).handle(object1, object2, config) @@ -68,7 +68,7 @@ def test_same_dtype_handler_handle_true_ndarray( (np.ones(shape=(2, 3), dtype=bool), np.ones(shape=(2, 3), dtype=float)), ], ) -def test_same_dtype_handler_handle_false_ndarray( +def test_same_dtype_handler_handle_false( object1: np.ndarray, object2: np.ndarray, config: EqualityConfig ) -> None: assert not SameDTypeHandler().handle(object1, object2, config) diff --git a/tests/unit/equality/handlers/test_shape.py b/tests/unit/equality/handlers/test_shape.py index 946a46d6..d7ffe2b3 100644 --- a/tests/unit/equality/handlers/test_shape.py +++ b/tests/unit/equality/handlers/test_shape.py @@ -69,7 +69,7 @@ def test_same_shape_handler_str() -> None: (np.ones(shape=(2, 3), dtype=bool), np.zeros(shape=(2, 3), dtype=float)), ], ) -def test_same_shape_handler_handle_true_ndarray( +def test_same_shape_handler_handle_true( object1: np.ndarray, object2: np.ndarray, config: EqualityConfig ) -> None: assert SameShapeHandler(next_handler=TrueHandler()).handle(object1, object2, config) @@ -84,7 +84,7 @@ def test_same_shape_handler_handle_true_ndarray( (np.ones(shape=(2, 3)), np.ones(shape=(2, 3, 1))), ], ) -def test_same_shape_handler_handle_false_ndarray( +def test_same_shape_handler_handle_false( object1: np.ndarray, object2: np.ndarray, config: EqualityConfig ) -> None: assert not SameShapeHandler().handle(object1, object2, config)