Skip to content

Commit

Permalink
Improve xarray tests (#411)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo authored Jan 9, 2024
1 parent 77a3bb4 commit fdcafeb
Show file tree
Hide file tree
Showing 11 changed files with 481 additions and 487 deletions.
13 changes: 4 additions & 9 deletions tests/unit/equality/checks/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,14 @@

from coola import objects_are_equal
from coola.testing import numpy_available
from tests.unit.equality.comparators.test_numpy import (
NUMPY_ARRAY_EQUAL,
NUMPY_ARRAY_NOT_EQUAL,
NUMPY_MASKED_ARRAY_EQUAL,
NUMPY_MASKED_ARRAY_NOT_EQUAL,
)
from tests.unit.equality.comparators.test_numpy import NUMPY_EQUAL, NUMPY_NOT_EQUAL

if TYPE_CHECKING:
from tests.unit.equality.comparators.utils import ExamplePair


@numpy_available
@pytest.mark.parametrize("example", NUMPY_ARRAY_EQUAL + NUMPY_MASKED_ARRAY_EQUAL)
@pytest.mark.parametrize("example", NUMPY_EQUAL)
@pytest.mark.parametrize("show_difference", [True, False])
def test_objects_are_equal_true(
example: ExamplePair, show_difference: bool, caplog: pytest.LogCaptureFixture
Expand All @@ -30,15 +25,15 @@ def test_objects_are_equal_true(


@numpy_available
@pytest.mark.parametrize("example", NUMPY_ARRAY_NOT_EQUAL + NUMPY_MASKED_ARRAY_NOT_EQUAL)
@pytest.mark.parametrize("example", NUMPY_NOT_EQUAL)
def test_objects_are_equal_false(example: ExamplePair, caplog: pytest.LogCaptureFixture) -> None:
with caplog.at_level(logging.INFO):
assert not objects_are_equal(example.object1, example.object2)
assert not caplog.messages


@numpy_available
@pytest.mark.parametrize("example", NUMPY_ARRAY_NOT_EQUAL + NUMPY_MASKED_ARRAY_NOT_EQUAL)
@pytest.mark.parametrize("example", NUMPY_NOT_EQUAL)
def test_objects_are_equal_false_show_difference(
example: ExamplePair, caplog: pytest.LogCaptureFixture
) -> None:
Expand Down
13 changes: 4 additions & 9 deletions tests/unit/equality/checks/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,14 @@

from coola import objects_are_equal
from coola.testing import pandas_available
from tests.unit.equality.comparators.test_pandas import (
PANDAS_DATAFRAME_EQUAL,
PANDAS_DATAFRAME_NOT_EQUAL,
PANDAS_SERIES_EQUAL,
PANDAS_SERIES_NOT_EQUAL,
)
from tests.unit.equality.comparators.test_pandas import PANDAS_EQUAL, PANDAS_NOT_EQUAL

if TYPE_CHECKING:
from tests.unit.equality.comparators.utils import ExamplePair


@pandas_available
@pytest.mark.parametrize("example", PANDAS_SERIES_EQUAL + PANDAS_DATAFRAME_EQUAL)
@pytest.mark.parametrize("example", PANDAS_EQUAL)
@pytest.mark.parametrize("show_difference", [True, False])
def test_objects_are_equal_true(
example: ExamplePair, show_difference: bool, caplog: pytest.LogCaptureFixture
Expand All @@ -30,15 +25,15 @@ def test_objects_are_equal_true(


@pandas_available
@pytest.mark.parametrize("example", PANDAS_SERIES_NOT_EQUAL + PANDAS_DATAFRAME_NOT_EQUAL)
@pytest.mark.parametrize("example", PANDAS_NOT_EQUAL)
def test_objects_are_equal_false(example: ExamplePair, caplog: pytest.LogCaptureFixture) -> None:
with caplog.at_level(logging.INFO):
assert not objects_are_equal(example.object1, example.object2)
assert not caplog.messages


@pandas_available
@pytest.mark.parametrize("example", PANDAS_SERIES_NOT_EQUAL + PANDAS_DATAFRAME_NOT_EQUAL)
@pytest.mark.parametrize("example", PANDAS_NOT_EQUAL)
def test_objects_are_equal_false_show_difference(
example: ExamplePair, caplog: pytest.LogCaptureFixture
) -> None:
Expand Down
13 changes: 4 additions & 9 deletions tests/unit/equality/checks/test_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,14 @@

from coola import objects_are_equal
from coola.testing import polars_available
from tests.unit.equality.comparators.test_polars import (
POLARS_DATAFRAME_EQUAL,
POLARS_DATAFRAME_NOT_EQUAL,
POLARS_SERIES_EQUAL,
POLARS_SERIES_NOT_EQUAL,
)
from tests.unit.equality.comparators.test_polars import POLARS_EQUAL, POLARS_NOT_EQUAL

if TYPE_CHECKING:
from tests.unit.equality.comparators.utils import ExamplePair


@polars_available
@pytest.mark.parametrize("example", POLARS_SERIES_EQUAL + POLARS_DATAFRAME_EQUAL)
@pytest.mark.parametrize("example", POLARS_EQUAL)
@pytest.mark.parametrize("show_difference", [True, False])
def test_objects_are_equal_true(
example: ExamplePair, show_difference: bool, caplog: pytest.LogCaptureFixture
Expand All @@ -30,15 +25,15 @@ def test_objects_are_equal_true(


@polars_available
@pytest.mark.parametrize("example", POLARS_SERIES_NOT_EQUAL + POLARS_DATAFRAME_NOT_EQUAL)
@pytest.mark.parametrize("example", POLARS_NOT_EQUAL)
def test_objects_are_equal_false(example: ExamplePair, caplog: pytest.LogCaptureFixture) -> None:
with caplog.at_level(logging.INFO):
assert not objects_are_equal(example.object1, example.object2)
assert not caplog.messages


@polars_available
@pytest.mark.parametrize("example", POLARS_SERIES_NOT_EQUAL + POLARS_DATAFRAME_NOT_EQUAL)
@pytest.mark.parametrize("example", POLARS_NOT_EQUAL)
def test_objects_are_equal_false_show_difference(
example: ExamplePair, caplog: pytest.LogCaptureFixture
) -> None:
Expand Down
13 changes: 4 additions & 9 deletions tests/unit/equality/checks/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,14 @@

from coola import objects_are_equal
from coola.testing import torch_available
from tests.unit.equality.comparators.test_torch import (
TORCH_PACKED_SEQUENCE_EQUAL,
TORCH_PACKED_SEQUENCE_NOT_EQUAL,
TORCH_TENSOR_EQUAL,
TORCH_TENSOR_NOT_EQUAL,
)
from tests.unit.equality.comparators.test_torch import TORCH_EQUAL, TORCH_NOT_EQUAL

if TYPE_CHECKING:
from tests.unit.equality.comparators.utils import ExamplePair


@torch_available
@pytest.mark.parametrize("example", TORCH_TENSOR_EQUAL + TORCH_PACKED_SEQUENCE_EQUAL)
@pytest.mark.parametrize("example", TORCH_EQUAL)
@pytest.mark.parametrize("show_difference", [True, False])
def test_objects_are_equal_true(
example: ExamplePair, show_difference: bool, caplog: pytest.LogCaptureFixture
Expand All @@ -30,15 +25,15 @@ def test_objects_are_equal_true(


@torch_available
@pytest.mark.parametrize("example", TORCH_TENSOR_NOT_EQUAL + TORCH_PACKED_SEQUENCE_NOT_EQUAL)
@pytest.mark.parametrize("example", TORCH_NOT_EQUAL)
def test_objects_are_equal_false(example: ExamplePair, caplog: pytest.LogCaptureFixture) -> None:
with caplog.at_level(logging.INFO):
assert not objects_are_equal(example.object1, example.object2)
assert not caplog.messages


@torch_available
@pytest.mark.parametrize("example", TORCH_TENSOR_NOT_EQUAL + TORCH_PACKED_SEQUENCE_NOT_EQUAL)
@pytest.mark.parametrize("example", TORCH_NOT_EQUAL)
def test_objects_are_equal_false_show_difference(
example: ExamplePair, caplog: pytest.LogCaptureFixture
) -> None:
Expand Down
42 changes: 42 additions & 0 deletions tests/unit/equality/checks/test_xarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import pytest

from coola import objects_are_equal
from coola.testing import xarray_available
from tests.unit.equality.comparators.test_xarray import XARRAY_EQUAL, XARRAY_NOT_EQUAL

if TYPE_CHECKING:
from tests.unit.equality.comparators.utils import ExamplePair


@xarray_available
@pytest.mark.parametrize("example", XARRAY_EQUAL)
@pytest.mark.parametrize("show_difference", [True, False])
def test_objects_are_equal_true(
example: ExamplePair, show_difference: bool, caplog: pytest.LogCaptureFixture
) -> None:
with caplog.at_level(logging.INFO):
assert objects_are_equal(example.object1, example.object2, show_difference)
assert not caplog.messages


@xarray_available
@pytest.mark.parametrize("example", XARRAY_NOT_EQUAL)
def test_objects_are_equal_false(example: ExamplePair, caplog: pytest.LogCaptureFixture) -> None:
with caplog.at_level(logging.INFO):
assert not objects_are_equal(example.object1, example.object2)
assert not caplog.messages


@xarray_available
@pytest.mark.parametrize("example", XARRAY_NOT_EQUAL)
def test_objects_are_equal_false_show_difference(
example: ExamplePair, caplog: pytest.LogCaptureFixture
) -> None:
with caplog.at_level(logging.INFO):
assert not objects_are_equal(example.object1, example.object2, show_difference=True)
assert caplog.messages[-1].startswith(example.expected_message)
38 changes: 26 additions & 12 deletions tests/unit/equality/comparators/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,41 +30,55 @@ def config() -> EqualityConfig:

JAX_ARRAY_EQUAL = [
pytest.param(
ExamplePair(jnp.ones(shape=(2, 3), dtype=float), jnp.ones(shape=(2, 3), dtype=float)),
ExamplePair(
object1=jnp.ones(shape=(2, 3), dtype=float), object2=jnp.ones(shape=(2, 3), dtype=float)
),
id="float dtype",
),
pytest.param(
ExamplePair(jnp.ones(shape=(2, 3), dtype=int), jnp.ones(shape=(2, 3), dtype=int)),
ExamplePair(
object1=jnp.ones(shape=(2, 3), dtype=int), object2=jnp.ones(shape=(2, 3), dtype=int)
),
id="int dtype",
),
pytest.param(ExamplePair(jnp.ones(shape=6), jnp.ones(shape=6)), id="1d array"),
pytest.param(ExamplePair(jnp.ones(shape=(2, 3)), jnp.ones(shape=(2, 3))), id="2d array"),
pytest.param(ExamplePair(object1=jnp.ones(shape=6), object2=jnp.ones(shape=6)), id="1d array"),
pytest.param(
ExamplePair(object1=jnp.ones(shape=(2, 3)), object2=jnp.ones(shape=(2, 3))), id="2d array"
),
]


JAX_ARRAY_NOT_EQUAL = [
pytest.param(
ExamplePair(
jnp.ones(shape=(2, 3), dtype=float),
jnp.ones(shape=(2, 3), dtype=int),
"objects have different data types:",
object1=jnp.ones(shape=(2, 3), dtype=float),
object2=jnp.ones(shape=(2, 3), dtype=int),
expected_message="objects have different data types:",
),
id="different data types",
),
pytest.param(
ExamplePair(jnp.ones(shape=(2, 3)), jnp.ones(shape=6), "objects have different shapes:"),
ExamplePair(
object1=jnp.ones(shape=(2, 3)),
object2=jnp.ones(shape=6),
expected_message="objects have different shapes:",
),
id="different shapes",
),
pytest.param(
ExamplePair(
jnp.ones(shape=(2, 3)),
jnp.zeros(shape=(2, 3)),
"jax.numpy.ndarrays have different elements:",
object1=jnp.ones(shape=(2, 3)),
object2=jnp.zeros(shape=(2, 3)),
expected_message="jax.numpy.ndarrays have different elements:",
),
id="different values",
),
pytest.param(
ExamplePair(jnp.ones(shape=(2, 3)), "meow", "objects have different types:"),
ExamplePair(
object1=jnp.ones(shape=(2, 3)),
object2="meow",
expected_message="objects have different types:",
),
id="different types",
),
]
Expand Down
Loading

0 comments on commit fdcafeb

Please sign in to comment.