Skip to content

Commit

Permalink
Minor improvements (#385)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo authored Jan 7, 2024
1 parent dcee4ec commit b7dfa38
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/coola/equality/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
"TrueHandler",
]

from coola.equality.handlers.array import SameDTypeHandler
from coola.equality.handlers.base import AbstractEqualityHandler, BaseEqualityHandler
from coola.equality.handlers.dtype import SameDTypeHandler
from coola.equality.handlers.jax_ import JaxArrayEqualHandler
from coola.equality.handlers.mapping import (
MappingSameKeysHandler,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,17 @@
r"""Implement some handlers for arrays or similar data."""
r"""Implement handlers to check the objects have the same data type."""

from __future__ import annotations

__all__ = ["SameDTypeHandler"]
__all__ = ["SameDTypeHandler", "SupportsDType"]

import logging
from typing import TYPE_CHECKING, Any, Protocol

from coola.equality.handlers.base import AbstractEqualityHandler

if TYPE_CHECKING:
from unittest.mock import Mock

from coola.equality.config import EqualityConfig
from coola.utils import is_numpy_available, is_torch_available

if is_numpy_available():
import numpy as np
else: # pragma: no cover
np = Mock()

if is_torch_available():
import torch
else: # pragma: no cover
torch = Mock()

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion src/coola/equality/handlers/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

__all__ = ["SameShapeHandler"]
__all__ = ["SameShapeHandler", "SupportsShape"]

import logging
from typing import TYPE_CHECKING, Protocol
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ def config() -> EqualityConfig:
return EqualityConfig(tester=EqualityTester())


###########################################
# Tests for ArraySameDTypeHandler #
###########################################
######################################
# Tests for SameDTypeHandler #
######################################


def test_same_dtype_handler_eq_true() -> None:
Expand All @@ -53,7 +53,7 @@ def test_same_dtype_handler_str() -> None:
(np.ones(shape=(2, 3), dtype=bool), np.zeros(shape=(2, 3), dtype=bool)),
],
)
def test_same_dtype_handler_handle_true_ndarray(
def test_same_dtype_handler_handle_true(
object1: np.ndarray, object2: np.ndarray, config: EqualityConfig
) -> None:
assert SameDTypeHandler(next_handler=TrueHandler()).handle(object1, object2, config)
Expand All @@ -68,7 +68,7 @@ def test_same_dtype_handler_handle_true_ndarray(
(np.ones(shape=(2, 3), dtype=bool), np.ones(shape=(2, 3), dtype=float)),
],
)
def test_same_dtype_handler_handle_false_ndarray(
def test_same_dtype_handler_handle_false(
object1: np.ndarray, object2: np.ndarray, config: EqualityConfig
) -> None:
assert not SameDTypeHandler().handle(object1, object2, config)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/equality/handlers/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_same_shape_handler_str() -> None:
(np.ones(shape=(2, 3), dtype=bool), np.zeros(shape=(2, 3), dtype=float)),
],
)
def test_same_shape_handler_handle_true_ndarray(
def test_same_shape_handler_handle_true(
object1: np.ndarray, object2: np.ndarray, config: EqualityConfig
) -> None:
assert SameShapeHandler(next_handler=TrueHandler()).handle(object1, object2, config)
Expand All @@ -84,7 +84,7 @@ def test_same_shape_handler_handle_true_ndarray(
(np.ones(shape=(2, 3)), np.ones(shape=(2, 3, 1))),
],
)
def test_same_shape_handler_handle_false_ndarray(
def test_same_shape_handler_handle_false(
object1: np.ndarray, object2: np.ndarray, config: EqualityConfig
) -> None:
assert not SameShapeHandler().handle(object1, object2, config)
Expand Down

0 comments on commit b7dfa38

Please sign in to comment.