Skip to content

Commit

Permalink
Backport PR #56445: Adjust merge tests for new string option (#56938)
Browse files Browse the repository at this point in the history
Co-authored-by: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
  • Loading branch information
lithomas1 and phofl authored Jan 18, 2024
1 parent 988c3a4 commit 74fa740
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 25 deletions.
4 changes: 2 additions & 2 deletions pandas/tests/reshape/merge/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ def test_mixed_type_join_with_suffix(self):
df.insert(5, "dt", "foo")

grouped = df.groupby("id")
msg = re.escape("agg function failed [how->mean,dtype->object]")
msg = re.escape("agg function failed [how->mean,dtype->")
with pytest.raises(TypeError, match=msg):
grouped.mean()
mn = grouped.mean(numeric_only=True)
Expand Down Expand Up @@ -776,7 +776,7 @@ def test_join_on_tz_aware_datetimeindex(self):
)
result = df1.join(df2.set_index("date"), on="date")
expected = df1.copy()
expected["vals_2"] = Series([np.nan] * 2 + list("tuv"), dtype=object)
expected["vals_2"] = Series([np.nan] * 2 + list("tuv"))
tm.assert_frame_equal(result, expected)

def test_join_datetime_string(self):
Expand Down
46 changes: 28 additions & 18 deletions pandas/tests/reshape/merge/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
import numpy as np
import pytest

from pandas.core.dtypes.common import is_object_dtype
from pandas.core.dtypes.common import (
is_object_dtype,
is_string_dtype,
)
from pandas.core.dtypes.dtypes import CategoricalDtype

import pandas as pd
Expand Down Expand Up @@ -316,14 +319,15 @@ def test_merge_copy(self):
merged["d"] = "peekaboo"
assert (right["d"] == "bar").all()

def test_merge_nocopy(self, using_array_manager):
def test_merge_nocopy(self, using_array_manager, using_infer_string):
left = DataFrame({"a": 0, "b": 1}, index=range(10))
right = DataFrame({"c": "foo", "d": "bar"}, index=range(10))

merged = merge(left, right, left_index=True, right_index=True, copy=False)

assert np.shares_memory(merged["a"]._values, left["a"]._values)
assert np.shares_memory(merged["d"]._values, right["d"]._values)
if not using_infer_string:
assert np.shares_memory(merged["d"]._values, right["d"]._values)

def test_intelligently_handle_join_key(self):
# #733, be a bit more 1337 about not returning unconsolidated DataFrame
Expand Down Expand Up @@ -667,11 +671,13 @@ def test_merge_nan_right(self):
"i1_": {0: 0, 1: np.nan},
"i3": {0: 0.0, 1: np.nan},
None: {0: 0, 1: 0},
}
},
columns=Index(["i1", "i2", "i1_", "i3", None], dtype=object),
)
.set_index(None)
.reset_index()[["i1", "i2", "i1_", "i3"]]
)
result.columns = result.columns.astype("object")
tm.assert_frame_equal(result, expected, check_dtype=False)

def test_merge_nan_right2(self):
Expand Down Expand Up @@ -820,7 +826,7 @@ def test_overlapping_columns_error_message(self):

# #2649, #10639
df2.columns = ["key1", "foo", "foo"]
msg = r"Data columns not unique: Index\(\['foo'\], dtype='object'\)"
msg = r"Data columns not unique: Index\(\['foo'\], dtype='object|string'\)"
with pytest.raises(MergeError, match=msg):
merge(df, df2)

Expand Down Expand Up @@ -1498,7 +1504,7 @@ def test_different(self, right_vals):
# We allow merging on object and categorical cols and cast
# categorical cols to object
result = merge(left, right, on="A")
assert is_object_dtype(result.A.dtype)
assert is_object_dtype(result.A.dtype) or is_string_dtype(result.A.dtype)

@pytest.mark.parametrize(
"d1", [np.int64, np.int32, np.intc, np.int16, np.int8, np.uint8]
Expand Down Expand Up @@ -1637,7 +1643,7 @@ def test_merge_incompat_dtypes_are_ok(self, df1_vals, df2_vals):
result = merge(df1, df2, on=["A"])
assert is_object_dtype(result.A.dtype)
result = merge(df2, df1, on=["A"])
assert is_object_dtype(result.A.dtype)
assert is_object_dtype(result.A.dtype) or is_string_dtype(result.A.dtype)

@pytest.mark.parametrize(
"df1_vals, df2_vals",
Expand Down Expand Up @@ -1867,25 +1873,27 @@ def right():


class TestMergeCategorical:
def test_identical(self, left):
def test_identical(self, left, using_infer_string):
# merging on the same, should preserve dtypes
merged = merge(left, left, on="X")
result = merged.dtypes.sort_index()
dtype = np.dtype("O") if not using_infer_string else "string"
expected = Series(
[CategoricalDtype(categories=["foo", "bar"]), np.dtype("O"), np.dtype("O")],
[CategoricalDtype(categories=["foo", "bar"]), dtype, dtype],
index=["X", "Y_x", "Y_y"],
)
tm.assert_series_equal(result, expected)

def test_basic(self, left, right):
def test_basic(self, left, right, using_infer_string):
# we have matching Categorical dtypes in X
# so should preserve the merged column
merged = merge(left, right, on="X")
result = merged.dtypes.sort_index()
dtype = np.dtype("O") if not using_infer_string else "string"
expected = Series(
[
CategoricalDtype(categories=["foo", "bar"]),
np.dtype("O"),
dtype,
np.dtype("int64"),
],
index=["X", "Y", "Z"],
Expand Down Expand Up @@ -1989,16 +1997,17 @@ def test_multiindex_merge_with_unordered_categoricalindex(self, ordered):
).set_index(["id", "p"])
tm.assert_frame_equal(result, expected)

def test_other_columns(self, left, right):
def test_other_columns(self, left, right, using_infer_string):
# non-merge columns should preserve if possible
right = right.assign(Z=right.Z.astype("category"))

merged = merge(left, right, on="X")
result = merged.dtypes.sort_index()
dtype = np.dtype("O") if not using_infer_string else "string"
expected = Series(
[
CategoricalDtype(categories=["foo", "bar"]),
np.dtype("O"),
dtype,
CategoricalDtype(categories=[1, 2]),
],
index=["X", "Y", "Z"],
Expand All @@ -2017,7 +2026,9 @@ def test_other_columns(self, left, right):
lambda x: x.astype(CategoricalDtype(ordered=True)),
],
)
def test_dtype_on_merged_different(self, change, join_type, left, right):
def test_dtype_on_merged_different(
self, change, join_type, left, right, using_infer_string
):
# our merging columns, X now has 2 different dtypes
# so we must be object as a result

Expand All @@ -2029,9 +2040,8 @@ def test_dtype_on_merged_different(self, change, join_type, left, right):
merged = merge(left, right, on="X", how=join_type)

result = merged.dtypes.sort_index()
expected = Series(
[np.dtype("O"), np.dtype("O"), np.dtype("int64")], index=["X", "Y", "Z"]
)
dtype = np.dtype("O") if not using_infer_string else "string"
expected = Series([dtype, dtype, np.dtype("int64")], index=["X", "Y", "Z"])
tm.assert_series_equal(result, expected)

def test_self_join_multiple_categories(self):
Expand Down Expand Up @@ -2499,7 +2509,7 @@ def test_merge_multiindex_columns():
expected_index = MultiIndex.from_tuples(tuples, names=["outer", "inner"])
expected = DataFrame(columns=expected_index)

tm.assert_frame_equal(result, expected)
tm.assert_frame_equal(result, expected, check_dtype=False)


def test_merge_datetime_upcast_dtype():
Expand Down
15 changes: 11 additions & 4 deletions pandas/tests/reshape/merge/test_merge_asof.py
Original file line number Diff line number Diff line change
Expand Up @@ -3081,8 +3081,11 @@ def test_on_float_by_int(self):

tm.assert_frame_equal(result, expected)

def test_merge_datatype_error_raises(self):
msg = r"Incompatible merge dtype, .*, both sides must have numeric dtype"
def test_merge_datatype_error_raises(self, using_infer_string):
if using_infer_string:
msg = "incompatible merge keys"
else:
msg = r"Incompatible merge dtype, .*, both sides must have numeric dtype"

left = pd.DataFrame({"left_val": [1, 5, 10], "a": ["a", "b", "c"]})
right = pd.DataFrame({"right_val": [1, 2, 3, 6, 7], "a": [1, 2, 3, 6, 7]})
Expand Down Expand Up @@ -3134,7 +3137,7 @@ def test_merge_on_nans(self, func, side):
else:
merge_asof(df, df_null, on="a")

def test_by_nullable(self, any_numeric_ea_dtype):
def test_by_nullable(self, any_numeric_ea_dtype, using_infer_string):
# Note: this test passes if instead of using pd.array we use
# np.array([np.nan, 1]). Other than that, I (@jbrockmendel)
# have NO IDEA what the expected behavior is.
Expand Down Expand Up @@ -3176,6 +3179,8 @@ def test_by_nullable(self, any_numeric_ea_dtype):
}
)
expected["value_y"] = np.array([np.nan, np.nan, np.nan], dtype=object)
if using_infer_string:
expected["value_y"] = expected["value_y"].astype("string[pyarrow_numpy]")
tm.assert_frame_equal(result, expected)

def test_merge_by_col_tz_aware(self):
Expand All @@ -3201,7 +3206,7 @@ def test_merge_by_col_tz_aware(self):
)
tm.assert_frame_equal(result, expected)

def test_by_mixed_tz_aware(self):
def test_by_mixed_tz_aware(self, using_infer_string):
# GH 26649
left = pd.DataFrame(
{
Expand All @@ -3225,6 +3230,8 @@ def test_by_mixed_tz_aware(self):
columns=["by_col1", "by_col2", "on_col", "value_x"],
)
expected["value_y"] = np.array([np.nan], dtype=object)
if using_infer_string:
expected["value_y"] = expected["value_y"].astype("string[pyarrow_numpy]")
tm.assert_frame_equal(result, expected)

@pytest.mark.parametrize("dtype", ["float64", "int16", "m8[ns]", "M8[us]"])
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/reshape/merge/test_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ def test_join_multi_levels_outer(self, portfolio, household, expected):
axis=0,
sort=True,
).reindex(columns=expected.columns)
tm.assert_frame_equal(result, expected)
tm.assert_frame_equal(result, expected, check_index_type=False)

def test_join_multi_levels_invalid(self, portfolio, household):
portfolio = portfolio.copy()
Expand Down

0 comments on commit 74fa740

Please sign in to comment.