Skip to content

Commit

Permalink
Improve some tests (#410)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo authored Jan 9, 2024
1 parent d13f1d9 commit 77a3bb4
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 105 deletions.
25 changes: 10 additions & 15 deletions tests/unit/equality/comparators/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,21 +164,16 @@ def test_jax_array_equality_comparator_equal_false_show_difference(


@jax_available
def test_jax_array_equality_comparator_equal_nan_false(config: EqualityConfig) -> None:
assert not JaxArrayEqualityComparator().equal(
object1=jnp.array([0.0, jnp.nan, jnp.nan, 1.2]),
object2=jnp.array([0.0, jnp.nan, jnp.nan, 1.2]),
config=config,
)


@jax_available
def test_jax_array_equality_comparator_equal_nan_true(config: EqualityConfig) -> None:
config.equal_nan = True
assert JaxArrayEqualityComparator().equal(
object1=jnp.array([0.0, jnp.nan, jnp.nan, 1.2]),
object2=jnp.array([0.0, jnp.nan, jnp.nan, 1.2]),
config=config,
@pytest.mark.parametrize("equal_nan", [False, True])
def test_jax_array_equality_comparator_equal_nan(config: EqualityConfig, equal_nan: bool) -> None:
config.equal_nan = equal_nan
assert (
JaxArrayEqualityComparator().equal(
object1=jnp.array([0.0, jnp.nan, jnp.nan, 1.2]),
object2=jnp.array([0.0, jnp.nan, jnp.nan, 1.2]),
config=config,
)
== equal_nan
)


Expand Down
54 changes: 24 additions & 30 deletions tests/unit/equality/comparators/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,21 +220,18 @@ def test_numpy_array_equality_comparator_equal_false_show_difference(


@numpy_available
def test_numpy_array_equality_comparator_equal_nan_false(config: EqualityConfig) -> None:
assert not NumpyArrayEqualityComparator().equal(
object1=np.array([0.0, np.nan, np.nan, 1.2]),
object2=np.array([0.0, np.nan, np.nan, 1.2]),
config=config,
)


@numpy_available
def test_numpy_array_equality_comparator_equal_nan_true(config: EqualityConfig) -> None:
config.equal_nan = True
assert NumpyArrayEqualityComparator().equal(
object1=np.array([0.0, np.nan, np.nan, 1.2]),
object2=np.array([0.0, np.nan, np.nan, 1.2]),
config=config,
@pytest.mark.parametrize("equal_nan", [False, True])
def test_numpy_array_equality_comparator_equal_nan_true(
config: EqualityConfig, equal_nan: bool
) -> None:
config.equal_nan = equal_nan
assert (
NumpyArrayEqualityComparator().equal(
object1=np.array([0.0, np.nan, np.nan, 1.2]),
object2=np.array([0.0, np.nan, np.nan, 1.2]),
config=config,
)
== equal_nan
)


Expand Down Expand Up @@ -339,21 +336,18 @@ def test_numpy_masked_array_equality_comparator_equal_false_show_difference(


@numpy_available
def test_numpy_masked_array_equality_comparator_equal_nan_false(config: EqualityConfig) -> None:
assert not NumpyMaskedArrayEqualityComparator().equal(
object1=np.ma.array(data=[0.0, np.nan, np.nan, 1.2], mask=[0, 1, 0, 1]),
object2=np.ma.array(data=[0.0, np.nan, np.nan, 1.2], mask=[0, 1, 0, 1]),
config=config,
)


@numpy_available
def test_numpy_masked_array_equality_comparator_equal_nan_true(config: EqualityConfig) -> None:
config.equal_nan = True
assert NumpyMaskedArrayEqualityComparator().equal(
object1=np.ma.array(data=[0.0, np.nan, np.nan, 1.2], mask=[0, 1, 0, 1]),
object2=np.ma.array(data=[0.0, np.nan, np.nan, 1.2], mask=[0, 1, 0, 1]),
config=config,
@pytest.mark.parametrize("equal_nan", [False, True])
def test_numpy_masked_array_equality_comparator_equal_nan(
config: EqualityConfig, equal_nan: bool
) -> None:
config.equal_nan = equal_nan
assert (
NumpyMaskedArrayEqualityComparator().equal(
object1=np.ma.array(data=[0.0, np.nan, np.nan, 1.2], mask=[0, 1, 0, 1]),
object2=np.ma.array(data=[0.0, np.nan, np.nan, 1.2], mask=[0, 1, 0, 1]),
config=config,
)
== equal_nan
)


Expand Down
54 changes: 24 additions & 30 deletions tests/unit/equality/comparators/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,21 +230,18 @@ def test_pandas_dataframe_equality_comparator_equal_false_show_difference(


@pandas_available
def test_pandas_dataframe_equality_comparator_equal_nan_false(config: EqualityConfig) -> None:
assert not PandasDataFrameEqualityComparator().equal(
object1=pandas.DataFrame({"col": [1, float("nan"), 3]}),
object2=pandas.DataFrame({"col": [1, float("nan"), 3]}),
config=config,
)


@pandas_available
def test_pandas_dataframe_equality_comparator_equal_nan_true(config: EqualityConfig) -> None:
config.equal_nan = True
assert PandasDataFrameEqualityComparator().equal(
object1=pandas.DataFrame({"col": [1, float("nan"), 3]}),
object2=pandas.DataFrame({"col": [1, float("nan"), 3]}),
config=config,
@pytest.mark.parametrize("equal_nan", [False, True])
def test_pandas_dataframe_equality_comparator_equal_nan(
config: EqualityConfig, equal_nan: bool
) -> None:
config.equal_nan = equal_nan
assert (
PandasDataFrameEqualityComparator().equal(
object1=pandas.DataFrame({"col": [1, float("nan"), 3]}),
object2=pandas.DataFrame({"col": [1, float("nan"), 3]}),
config=config,
)
== equal_nan
)


Expand Down Expand Up @@ -352,21 +349,18 @@ def test_pandas_series_equality_comparator_equal_false_show_difference(


@pandas_available
def test_pandas_series_equality_comparator_equal_nan_false(config: EqualityConfig) -> None:
assert not PandasSeriesEqualityComparator().equal(
object1=pandas.Series([0.0, float("nan"), float("nan"), 1.2]),
object2=pandas.Series([0.0, float("nan"), float("nan"), 1.2]),
config=config,
)


@pandas_available
def test_pandas_series_equality_comparator_equal_nan_true(config: EqualityConfig) -> None:
config.equal_nan = True
assert PandasSeriesEqualityComparator().equal(
object1=pandas.Series([0.0, float("nan"), float("nan"), 1.2]),
object2=pandas.Series([0.0, float("nan"), float("nan"), 1.2]),
config=config,
@pytest.mark.parametrize("equal_nan", [False, True])
def test_pandas_series_equality_comparator_equal_nan(
config: EqualityConfig, equal_nan: bool
) -> None:
config.equal_nan = equal_nan
assert (
PandasSeriesEqualityComparator().equal(
object1=pandas.Series([0.0, float("nan"), float("nan"), 1.2]),
object2=pandas.Series([0.0, float("nan"), float("nan"), 1.2]),
config=config,
)
== equal_nan
)


Expand Down
54 changes: 24 additions & 30 deletions tests/unit/equality/comparators/test_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,21 +230,18 @@ def test_polars_dataframe_equality_comparator_equal_false_show_difference(


@polars_available
def test_polars_dataframe_equality_comparator_equal_nan_false(config: EqualityConfig) -> None:
assert not PolarsDataFrameEqualityComparator().equal(
object1=polars.DataFrame({"col": [1, float("nan"), 3]}),
object2=polars.DataFrame({"col": [1, float("nan"), 3]}),
config=config,
)


@polars_available
def test_polars_dataframe_equality_comparator_equal_nan_true(config: EqualityConfig) -> None:
config.equal_nan = True
assert PolarsDataFrameEqualityComparator().equal(
object1=polars.DataFrame({"col": [1, float("nan"), 3]}),
object2=polars.DataFrame({"col": [1, float("nan"), 3]}),
config=config,
@pytest.mark.parametrize("equal_nan", [False, True])
def test_polars_dataframe_equality_comparator_equal_nan(
config: EqualityConfig, equal_nan: bool
) -> None:
config.equal_nan = equal_nan
assert (
PolarsDataFrameEqualityComparator().equal(
object1=polars.DataFrame({"col": [1, float("nan"), 3]}),
object2=polars.DataFrame({"col": [1, float("nan"), 3]}),
config=config,
)
== equal_nan
)


Expand Down Expand Up @@ -352,21 +349,18 @@ def test_polars_series_equality_comparator_equal_false_show_difference(


@polars_available
def test_polars_series_equality_comparator_equal_nan_false(config: EqualityConfig) -> None:
assert not PolarsSeriesEqualityComparator().equal(
object1=polars.Series([0.0, float("nan"), float("nan"), 1.2]),
object2=polars.Series([0.0, float("nan"), float("nan"), 1.2]),
config=config,
)


@polars_available
def test_polars_series_equality_comparator_equal_nan_true(config: EqualityConfig) -> None:
config.equal_nan = True
assert PolarsSeriesEqualityComparator().equal(
object1=polars.Series([0.0, float("nan"), float("nan"), 1.2]),
object2=polars.Series([0.0, float("nan"), float("nan"), 1.2]),
config=config,
@pytest.mark.parametrize("equal_nan", [False, True])
def test_polars_series_equality_comparator_equal_nan(
config: EqualityConfig, equal_nan: bool
) -> None:
config.equal_nan = equal_nan
assert (
PolarsSeriesEqualityComparator().equal(
object1=polars.Series([0.0, float("nan"), float("nan"), 1.2]),
object2=polars.Series([0.0, float("nan"), float("nan"), 1.2]),
config=config,
)
== equal_nan
)


Expand Down

0 comments on commit 77a3bb4

Please sign in to comment.