Skip to content

Commit

Permalink
Improve torch tests (#409)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo authored Jan 9, 2024
1 parent 15de2c5 commit d13f1d9
Show file tree
Hide file tree
Showing 2 changed files with 320 additions and 335 deletions.
47 changes: 47 additions & 0 deletions tests/unit/equality/checks/test_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import pytest

from coola import objects_are_equal
from coola.testing import torch_available
from tests.unit.equality.comparators.test_torch import (
TORCH_PACKED_SEQUENCE_EQUAL,
TORCH_PACKED_SEQUENCE_NOT_EQUAL,
TORCH_TENSOR_EQUAL,
TORCH_TENSOR_NOT_EQUAL,
)

if TYPE_CHECKING:
from tests.unit.equality.comparators.utils import ExamplePair


@torch_available
@pytest.mark.parametrize("example", TORCH_TENSOR_EQUAL + TORCH_PACKED_SEQUENCE_EQUAL)
@pytest.mark.parametrize("show_difference", [True, False])
def test_objects_are_equal_true(
example: ExamplePair, show_difference: bool, caplog: pytest.LogCaptureFixture
) -> None:
with caplog.at_level(logging.INFO):
assert objects_are_equal(example.object1, example.object2, show_difference)
assert not caplog.messages


@torch_available
@pytest.mark.parametrize("example", TORCH_TENSOR_NOT_EQUAL + TORCH_PACKED_SEQUENCE_NOT_EQUAL)
def test_objects_are_equal_false(example: ExamplePair, caplog: pytest.LogCaptureFixture) -> None:
with caplog.at_level(logging.INFO):
assert not objects_are_equal(example.object1, example.object2)
assert not caplog.messages


@torch_available
@pytest.mark.parametrize("example", TORCH_TENSOR_NOT_EQUAL + TORCH_PACKED_SEQUENCE_NOT_EQUAL)
def test_objects_are_equal_false_show_difference(
example: ExamplePair, caplog: pytest.LogCaptureFixture
) -> None:
with caplog.at_level(logging.INFO):
assert not objects_are_equal(example.object1, example.object2, show_difference=True)
assert caplog.messages[-1].startswith(example.expected_message)
Loading

0 comments on commit d13f1d9

Please sign in to comment.