diff --git a/.gitignore b/.gitignore index 6c3c275c48fb7..edf1d07d901e2 100644 --- a/.gitignore +++ b/.gitignore @@ -118,3 +118,4 @@ doc/build/html/index.html doc/tmp.sv env/ doc/source/savefig/ +.dmypy.json diff --git a/pandas/_typing.py b/pandas/_typing.py index 445eff9e19e47..db3dae46ed479 100644 --- a/pandas/_typing.py +++ b/pandas/_typing.py @@ -23,6 +23,7 @@ from pandas.core.indexes.base import Index # noqa: F401 from pandas.core.series import Series # noqa: F401 from pandas.core.generic import NDFrame # noqa: F401 + from pandas.core.base import IndexOpsMixin # noqa: F401 AnyArrayLike = TypeVar("AnyArrayLike", "ExtensionArray", "Index", "Series", np.ndarray) @@ -32,6 +33,7 @@ FilePathOrBuffer = Union[str, Path, IO[AnyStr]] FrameOrSeries = TypeVar("FrameOrSeries", bound="NDFrame") +IndexOrSeries = TypeVar("IndexOrSeries", bound="IndexOpsMixin") Scalar = Union[str, int, float, bool] Axis = Union[str, int] Ordered = Optional[bool] diff --git a/pandas/conftest.py b/pandas/conftest.py index b032e14d8f7e1..cbf82a66d3e17 100644 --- a/pandas/conftest.py +++ b/pandas/conftest.py @@ -14,6 +14,7 @@ import pandas as pd from pandas import DataFrame +from pandas._typing import IndexOrSeries from pandas.core import ops import pandas.util.testing as tm @@ -790,6 +791,17 @@ def tick_classes(request): return request.param +index_or_series_params = [pd.Index, pd.Series] # type: IndexOrSeries + + +@pytest.fixture(params=index_or_series_params, ids=["series", "index"]) +def index_or_series(request) -> IndexOrSeries: + """ + Parametrized fixture providing the Index or Series class. + """ + return request.param + + # ---------------------------------------------------------------- # Global setup for tests using Hypothesis diff --git a/pandas/tests/arithmetic/test_numeric.py b/pandas/tests/arithmetic/test_numeric.py index 584e22f8488f5..7ccb77748b114 100644 --- a/pandas/tests/arithmetic/test_numeric.py +++ b/pandas/tests/arithmetic/test_numeric.py @@ -5,6 +5,7 @@ from decimal import Decimal from itertools import combinations import operator +from typing import List, Type, Union import numpy as np import pytest @@ -74,6 +75,14 @@ def test_compare_invalid(self): # ------------------------------------------------------------------ # Numeric dtypes Arithmetic with Datetime/Timedelta Scalar +index_or_series_params = [ + pd.Series, + pd.Index, +] # type: List[Union[Type[pd.Index], Type[pd.RangeIndex], Type[pd.Series]]] +left = [pd.RangeIndex(10, 40, 10)] # type: List[Union[Index, Series]] +for cls in index_or_series_params: + for dtype in ["i1", "i2", "i4", "i8", "u1", "u2", "u4", "u8", "f2", "f4", "f8"]: + left.append(cls([10, 20, 30], dtype=dtype)) class TestNumericArraylikeArithmeticWithDatetimeLike: @@ -81,26 +90,7 @@ class TestNumericArraylikeArithmeticWithDatetimeLike: # TODO: also check name retentention @pytest.mark.parametrize("box_cls", [np.array, pd.Index, pd.Series]) @pytest.mark.parametrize( - "left", - [pd.RangeIndex(10, 40, 10)] - + [ - cls([10, 20, 30], dtype=dtype) - for dtype in [ - "i1", - "i2", - "i4", - "i8", - "u1", - "u2", - "u4", - "u8", - "f2", - "f4", - "f8", - ] - for cls in [pd.Series, pd.Index] - ], - ids=lambda x: type(x).__name__ + str(x.dtype), + "left", left, ids=lambda x: type(x).__name__ + str(x.dtype) ) def test_mul_td64arr(self, left, box_cls): # GH#22390 @@ -120,26 +110,7 @@ def test_mul_td64arr(self, left, box_cls): # TODO: also check name retentention @pytest.mark.parametrize("box_cls", [np.array, pd.Index, pd.Series]) @pytest.mark.parametrize( - "left", - [pd.RangeIndex(10, 40, 10)] - + [ - cls([10, 20, 30], dtype=dtype) - for dtype in [ - "i1", - "i2", - "i4", - "i8", - "u1", - "u2", - "u4", - "u8", - "f2", - "f4", - "f8", - ] - for cls in [pd.Series, pd.Index] - ], - ids=lambda x: type(x).__name__ + str(x.dtype), + "left", left, ids=lambda x: type(x).__name__ + str(x.dtype) ) def test_div_td64arr(self, left, box_cls): # GH#22390 diff --git a/pandas/tests/arrays/test_array.py b/pandas/tests/arrays/test_array.py index e8d9ecfac61e4..d30c34ecc70ef 100644 --- a/pandas/tests/arrays/test_array.py +++ b/pandas/tests/arrays/test_array.py @@ -272,9 +272,8 @@ def _from_sequence(cls, scalars, dtype=None, copy=False): return super()._from_sequence(scalars, dtype=dtype, copy=copy) -@pytest.mark.parametrize("box", [pd.Series, pd.Index]) -def test_array_unboxes(box): - data = box([decimal.Decimal("1"), decimal.Decimal("2")]) +def test_array_unboxes(index_or_series): + data = index_or_series([decimal.Decimal("1"), decimal.Decimal("2")]) # make sure it works with pytest.raises(TypeError): DecimalArray2._from_sequence(data) diff --git a/pandas/tests/dtypes/test_concat.py b/pandas/tests/dtypes/test_concat.py index 0ca2f7c976535..02daa185b1cdb 100644 --- a/pandas/tests/dtypes/test_concat.py +++ b/pandas/tests/dtypes/test_concat.py @@ -2,7 +2,7 @@ import pandas.core.dtypes.concat as _concat -from pandas import DatetimeIndex, Index, Period, PeriodIndex, Series, TimedeltaIndex +from pandas import DatetimeIndex, Period, PeriodIndex, Series, TimedeltaIndex @pytest.mark.parametrize( @@ -40,9 +40,8 @@ ), ], ) -@pytest.mark.parametrize("klass", [Index, Series]) -def test_get_dtype_kinds(klass, to_concat, expected): - to_concat_klass = [klass(c) for c in to_concat] +def test_get_dtype_kinds(index_or_series, to_concat, expected): + to_concat_klass = [index_or_series(c) for c in to_concat] result = _concat.get_dtype_kinds(to_concat_klass) assert result == set(expected) diff --git a/pandas/tests/indexing/test_coercion.py b/pandas/tests/indexing/test_coercion.py index 4f38d7beb9c0b..aca78f5349b5e 100644 --- a/pandas/tests/indexing/test_coercion.py +++ b/pandas/tests/indexing/test_coercion.py @@ -515,55 +515,52 @@ def _assert_where_conversion( res = target.where(cond, values) self._assert(res, expected, expected_dtype) - @pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"]) @pytest.mark.parametrize( "fill_val,exp_dtype", [(1, np.object), (1.1, np.object), (1 + 1j, np.object), (True, np.object)], ) - def test_where_object(self, klass, fill_val, exp_dtype): - obj = klass(list("abcd")) + def test_where_object(self, index_or_series, fill_val, exp_dtype): + obj = index_or_series(list("abcd")) assert obj.dtype == np.object - cond = klass([True, False, True, False]) + cond = index_or_series([True, False, True, False]) - if fill_val is True and klass is pd.Series: + if fill_val is True and index_or_series is pd.Series: ret_val = 1 else: ret_val = fill_val - exp = klass(["a", ret_val, "c", ret_val]) + exp = index_or_series(["a", ret_val, "c", ret_val]) self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype) if fill_val is True: - values = klass([True, False, True, True]) + values = index_or_series([True, False, True, True]) else: - values = klass(fill_val * x for x in [5, 6, 7, 8]) + values = index_or_series(fill_val * x for x in [5, 6, 7, 8]) - exp = klass(["a", values[1], "c", values[3]]) + exp = index_or_series(["a", values[1], "c", values[3]]) self._assert_where_conversion(obj, cond, values, exp, exp_dtype) - @pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"]) @pytest.mark.parametrize( "fill_val,exp_dtype", [(1, np.int64), (1.1, np.float64), (1 + 1j, np.complex128), (True, np.object)], ) - def test_where_int64(self, klass, fill_val, exp_dtype): - if klass is pd.Index and exp_dtype is np.complex128: + def test_where_int64(self, index_or_series, fill_val, exp_dtype): + if index_or_series is pd.Index and exp_dtype is np.complex128: pytest.skip("Complex Index not supported") - obj = klass([1, 2, 3, 4]) + obj = index_or_series([1, 2, 3, 4]) assert obj.dtype == np.int64 - cond = klass([True, False, True, False]) + cond = index_or_series([True, False, True, False]) - exp = klass([1, fill_val, 3, fill_val]) + exp = index_or_series([1, fill_val, 3, fill_val]) self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype) if fill_val is True: - values = klass([True, False, True, True]) + values = index_or_series([True, False, True, True]) else: - values = klass(x * fill_val for x in [5, 6, 7, 8]) - exp = klass([1, values[1], 3, values[3]]) + values = index_or_series(x * fill_val for x in [5, 6, 7, 8]) + exp = index_or_series([1, values[1], 3, values[3]]) self._assert_where_conversion(obj, cond, values, exp, exp_dtype) - @pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"]) @pytest.mark.parametrize( "fill_val, exp_dtype", [ @@ -573,21 +570,21 @@ def test_where_int64(self, klass, fill_val, exp_dtype): (True, np.object), ], ) - def test_where_float64(self, klass, fill_val, exp_dtype): - if klass is pd.Index and exp_dtype is np.complex128: + def test_where_float64(self, index_or_series, fill_val, exp_dtype): + if index_or_series is pd.Index and exp_dtype is np.complex128: pytest.skip("Complex Index not supported") - obj = klass([1.1, 2.2, 3.3, 4.4]) + obj = index_or_series([1.1, 2.2, 3.3, 4.4]) assert obj.dtype == np.float64 - cond = klass([True, False, True, False]) + cond = index_or_series([True, False, True, False]) - exp = klass([1.1, fill_val, 3.3, fill_val]) + exp = index_or_series([1.1, fill_val, 3.3, fill_val]) self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype) if fill_val is True: - values = klass([True, False, True, True]) + values = index_or_series([True, False, True, True]) else: - values = klass(x * fill_val for x in [5, 6, 7, 8]) - exp = klass([1.1, values[1], 3.3, values[3]]) + values = index_or_series(x * fill_val for x in [5, 6, 7, 8]) + exp = index_or_series([1.1, values[1], 3.3, values[3]]) self._assert_where_conversion(obj, cond, values, exp, exp_dtype) @pytest.mark.parametrize( @@ -783,19 +780,17 @@ def _assert_fillna_conversion(self, original, value, expected, expected_dtype): res = target.fillna(value) self._assert(res, expected, expected_dtype) - @pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"]) @pytest.mark.parametrize( "fill_val, fill_dtype", [(1, np.object), (1.1, np.object), (1 + 1j, np.object), (True, np.object)], ) - def test_fillna_object(self, klass, fill_val, fill_dtype): - obj = klass(["a", np.nan, "c", "d"]) + def test_fillna_object(self, index_or_series, fill_val, fill_dtype): + obj = index_or_series(["a", np.nan, "c", "d"]) assert obj.dtype == np.object - exp = klass(["a", fill_val, "c", "d"]) + exp = index_or_series(["a", fill_val, "c", "d"]) self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype) - @pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"]) @pytest.mark.parametrize( "fill_val,fill_dtype", [ @@ -805,15 +800,15 @@ def test_fillna_object(self, klass, fill_val, fill_dtype): (True, np.object), ], ) - def test_fillna_float64(self, klass, fill_val, fill_dtype): - obj = klass([1.1, np.nan, 3.3, 4.4]) + def test_fillna_float64(self, index_or_series, fill_val, fill_dtype): + obj = index_or_series([1.1, np.nan, 3.3, 4.4]) assert obj.dtype == np.float64 - exp = klass([1.1, fill_val, 3.3, 4.4]) + exp = index_or_series([1.1, fill_val, 3.3, 4.4]) # float + complex -> we don't support a complex Index # complex for Series, # object for Index - if fill_dtype == np.complex128 and klass == pd.Index: + if fill_dtype == np.complex128 and index_or_series == pd.Index: fill_dtype = np.object self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype) @@ -833,7 +828,6 @@ def test_fillna_series_complex128(self, fill_val, fill_dtype): exp = pd.Series([1 + 1j, fill_val, 3 + 3j, 4 + 4j]) self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype) - @pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"]) @pytest.mark.parametrize( "fill_val,fill_dtype", [ @@ -844,8 +838,8 @@ def test_fillna_series_complex128(self, fill_val, fill_dtype): ], ids=["datetime64", "datetime64tz", "object", "object"], ) - def test_fillna_datetime(self, klass, fill_val, fill_dtype): - obj = klass( + def test_fillna_datetime(self, index_or_series, fill_val, fill_dtype): + obj = index_or_series( [ pd.Timestamp("2011-01-01"), pd.NaT, @@ -855,7 +849,7 @@ def test_fillna_datetime(self, klass, fill_val, fill_dtype): ) assert obj.dtype == "datetime64[ns]" - exp = klass( + exp = index_or_series( [ pd.Timestamp("2011-01-01"), fill_val, @@ -865,7 +859,6 @@ def test_fillna_datetime(self, klass, fill_val, fill_dtype): ) self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype) - @pytest.mark.parametrize("klass", [pd.Series, pd.Index]) @pytest.mark.parametrize( "fill_val,fill_dtype", [ @@ -876,10 +869,10 @@ def test_fillna_datetime(self, klass, fill_val, fill_dtype): ("x", np.object), ], ) - def test_fillna_datetime64tz(self, klass, fill_val, fill_dtype): + def test_fillna_datetime64tz(self, index_or_series, fill_val, fill_dtype): tz = "US/Eastern" - obj = klass( + obj = index_or_series( [ pd.Timestamp("2011-01-01", tz=tz), pd.NaT, @@ -889,7 +882,7 @@ def test_fillna_datetime64tz(self, klass, fill_val, fill_dtype): ) assert obj.dtype == "datetime64[ns, US/Eastern]" - exp = klass( + exp = index_or_series( [ pd.Timestamp("2011-01-01", tz=tz), fill_val, diff --git a/pandas/tests/io/json/test_json_table_schema.py b/pandas/tests/io/json/test_json_table_schema.py index 569e299860614..d3c493f35f5c4 100644 --- a/pandas/tests/io/json/test_json_table_schema.py +++ b/pandas/tests/io/json/test_json_table_schema.py @@ -431,17 +431,15 @@ def test_date_format_raises(self): self.df.to_json(orient="table", date_format="iso") self.df.to_json(orient="table") - @pytest.mark.parametrize("kind", [pd.Series, pd.Index]) - def test_convert_pandas_type_to_json_field_int(self, kind): + def test_convert_pandas_type_to_json_field_int(self, index_or_series): data = [1, 2, 3] - result = convert_pandas_type_to_json_field(kind(data, name="name")) + result = convert_pandas_type_to_json_field(index_or_series(data, name="name")) expected = {"name": "name", "type": "integer"} assert result == expected - @pytest.mark.parametrize("kind", [pd.Series, pd.Index]) - def test_convert_pandas_type_to_json_field_float(self, kind): + def test_convert_pandas_type_to_json_field_float(self, index_or_series): data = [1.0, 2.0, 3.0] - result = convert_pandas_type_to_json_field(kind(data, name="name")) + result = convert_pandas_type_to_json_field(index_or_series(data, name="name")) expected = {"name": "name", "type": "number"} assert result == expected diff --git a/pandas/tests/test_base.py b/pandas/tests/test_base.py index 1f19f58e80f26..472551b76b355 100644 --- a/pandas/tests/test_base.py +++ b/pandas/tests/test_base.py @@ -2,6 +2,7 @@ from io import StringIO import re import sys +from typing import List, Type import numpy as np import pytest @@ -35,10 +36,12 @@ ) from pandas.core.accessor import PandasDelegate from pandas.core.arrays import DatetimeArray, PandasArray, TimedeltaArray -from pandas.core.base import NoNewAttributesMixin, PandasObject +from pandas.core.base import IndexOpsMixin, NoNewAttributesMixin, PandasObject from pandas.core.indexes.datetimelike import DatetimeIndexOpsMixin import pandas.util.testing as tm +index_or_series = [Index, Series] # type: List[Type[IndexOpsMixin]] + class CheckStringMixin: def test_string_methods_dont_fail(self): @@ -516,7 +519,7 @@ def test_value_counts_unique_nunique_null(self, null_obj): assert o.nunique() == 8 assert o.nunique(dropna=False) == 9 - @pytest.mark.parametrize("klass", [Index, Series]) + @pytest.mark.parametrize("klass", index_or_series) def test_value_counts_inferred(self, klass): s_values = ["a", "b", "b", "b", "b", "c", "d", "d", "a", "a"] s = klass(s_values) @@ -547,7 +550,7 @@ def test_value_counts_inferred(self, klass): expected = Series([0.4, 0.3, 0.2, 0.1], index=["b", "a", "d", "c"]) tm.assert_series_equal(hist, expected) - @pytest.mark.parametrize("klass", [Index, Series]) + @pytest.mark.parametrize("klass", index_or_series) def test_value_counts_bins(self, klass): s_values = ["a", "b", "b", "b", "b", "c", "d", "d", "a", "a"] s = klass(s_values) @@ -612,7 +615,7 @@ def test_value_counts_bins(self, klass): assert s.nunique() == 0 - @pytest.mark.parametrize("klass", [Index, Series]) + @pytest.mark.parametrize("klass", index_or_series) def test_value_counts_datetime64(self, klass): # GH 3002, datetime64[ns] @@ -1090,7 +1093,7 @@ class TestToIterable: ], ids=["tolist", "to_list", "list", "iter"], ) - @pytest.mark.parametrize("typ", [Series, Index]) + @pytest.mark.parametrize("typ", index_or_series) @pytest.mark.filterwarnings("ignore:\\n Passing:FutureWarning") # TODO(GH-24559): Remove the filterwarnings def test_iterable(self, typ, method, dtype, rdtype): @@ -1120,7 +1123,7 @@ def test_iterable(self, typ, method, dtype, rdtype): ], ids=["tolist", "to_list", "list", "iter"], ) - @pytest.mark.parametrize("typ", [Series, Index]) + @pytest.mark.parametrize("typ", index_or_series) def test_iterable_object_and_category(self, typ, method, dtype, rdtype, obj): # gh-10904 # gh-13258 @@ -1144,7 +1147,7 @@ def test_iterable_items(self, dtype, rdtype): @pytest.mark.parametrize( "dtype, rdtype", dtypes + [("object", int), ("category", int)] ) - @pytest.mark.parametrize("typ", [Series, Index]) + @pytest.mark.parametrize("typ", index_or_series) @pytest.mark.filterwarnings("ignore:\\n Passing:FutureWarning") # TODO(GH-24559): Remove the filterwarnings def test_iterable_map(self, typ, dtype, rdtype): @@ -1332,7 +1335,7 @@ def test_numpy_array_all_dtypes(any_numpy_dtype): ), ], ) -@pytest.mark.parametrize("box", [pd.Series, pd.Index]) +@pytest.mark.parametrize("box", index_or_series) def test_array(array, attr, box): if array.dtype.name in ("Int64", "Sparse[int64, 0]") and box is pd.Index: pytest.skip("No index type for {}".format(array.dtype)) @@ -1396,7 +1399,7 @@ def test_array_multiindex_raises(): ), ], ) -@pytest.mark.parametrize("box", [pd.Series, pd.Index]) +@pytest.mark.parametrize("box", index_or_series) def test_to_numpy(array, expected, box): thing = box(array) diff --git a/pandas/tests/test_strings.py b/pandas/tests/test_strings.py index 53d74f74dc439..03cc80673fd5c 100644 --- a/pandas/tests/test_strings.py +++ b/pandas/tests/test_strings.py @@ -1,5 +1,6 @@ from datetime import datetime, timedelta import re +from typing import List, Type import numpy as np from numpy.random import randint @@ -8,10 +9,13 @@ from pandas._libs import lib from pandas import DataFrame, Index, MultiIndex, Series, concat, isna, notna +from pandas.core.base import PandasObject import pandas.core.strings as strings import pandas.util.testing as tm from pandas.util.testing import assert_index_equal, assert_series_equal +index_or_series_params = [Index, Series] # type: List[Type[PandasObject]] + def assert_series_or_index_equal(left, right): if isinstance(left, Series): @@ -203,12 +207,11 @@ def test_api_mi_raises(self): assert not hasattr(mi, "str") @pytest.mark.parametrize("dtype", [object, "category"]) - @pytest.mark.parametrize("box", [Series, Index]) - def test_api_per_dtype(self, box, dtype, any_skipna_inferred_dtype): + def test_api_per_dtype(self, index_or_series, dtype, any_skipna_inferred_dtype): # one instance of parametrized fixture inferred_dtype, values = any_skipna_inferred_dtype - t = box(values, dtype=dtype) # explicit dtype to avoid casting + t = index_or_series(values, dtype=dtype) # explicit dtype to avoid casting # TODO: get rid of these xfails if dtype == "category" and inferred_dtype in ["period", "interval"]: @@ -237,9 +240,12 @@ def test_api_per_dtype(self, box, dtype, any_skipna_inferred_dtype): assert not hasattr(t, "str") @pytest.mark.parametrize("dtype", [object, "category"]) - @pytest.mark.parametrize("box", [Series, Index]) def test_api_per_method( - self, box, dtype, any_allowed_skipna_inferred_dtype, any_string_method + self, + index_or_series, + dtype, + any_allowed_skipna_inferred_dtype, + any_string_method, ): # this test does not check correctness of the different methods, # just that the methods work on the specified (inferred) dtypes, @@ -252,26 +258,26 @@ def test_api_per_method( # TODO: get rid of these xfails if ( method_name in ["partition", "rpartition"] - and box == Index + and index_or_series == Index and inferred_dtype == "empty" ): pytest.xfail(reason="Method cannot deal with empty Index") if ( method_name == "split" - and box == Index + and index_or_series == Index and values.size == 0 and kwargs.get("expand", None) is not None ): pytest.xfail(reason="Split fails on empty Series when expand=True") if ( method_name == "get_dummies" - and box == Index + and index_or_series == Index and inferred_dtype == "empty" and (dtype == object or values.size == 0) ): pytest.xfail(reason="Need to fortify get_dummies corner cases") - t = box(values, dtype=dtype) # explicit dtype to avoid casting + t = index_or_series(values, dtype=dtype) # explicit dtype to avoid casting method = getattr(t.str, method_name) bytes_allowed = method_name in ["decode", "get", "len", "slice"] @@ -376,23 +382,21 @@ def test_iter_object_try_string(self): assert i == 100 assert s == "h" - @pytest.mark.parametrize("box", [Series, Index]) @pytest.mark.parametrize("other", [None, Series, Index]) - def test_str_cat_name(self, box, other): + def test_str_cat_name(self, index_or_series, other): # GH 21053 values = ["a", "b"] if other: other = other(values) else: other = values - result = box(values, name="name").str.cat(other, sep=",") + result = index_or_series(values, name="name").str.cat(other, sep=",") assert result.name == "name" - @pytest.mark.parametrize("box", [Series, Index]) - def test_str_cat(self, box): + def test_str_cat(self, index_or_series): # test_cat above tests "str_cat" from ndarray; # here testing "str.cat" from Series/Indext to ndarray/list - s = box(["a", "a", "b", "b", "c", np.nan]) + s = index_or_series(["a", "a", "b", "b", "c", np.nan]) # single array result = s.str.cat() @@ -408,7 +412,7 @@ def test_str_cat(self, box): assert result == expected t = np.array(["a", np.nan, "b", "d", "foo", np.nan], dtype=object) - expected = box(["aa", "a-", "bb", "bd", "cfoo", "--"]) + expected = index_or_series(["aa", "a-", "bb", "bd", "cfoo", "--"]) # Series/Index with array result = s.str.cat(t, na_rep="-") @@ -428,10 +432,9 @@ def test_str_cat(self, box): with pytest.raises(ValueError, match=rgx): s.str.cat(list(z)) - @pytest.mark.parametrize("box", [Series, Index]) - def test_str_cat_raises_intuitive_error(self, box): + def test_str_cat_raises_intuitive_error(self, index_or_series): # GH 11334 - s = box(["a", "b", "c", "d"]) + s = index_or_series(["a", "b", "c", "d"]) message = "Did you mean to supply a `sep` keyword?" with pytest.raises(ValueError, match=message): s.str.cat("|") @@ -441,14 +444,15 @@ def test_str_cat_raises_intuitive_error(self, box): @pytest.mark.parametrize("sep", ["", None]) @pytest.mark.parametrize("dtype_target", ["object", "category"]) @pytest.mark.parametrize("dtype_caller", ["object", "category"]) - @pytest.mark.parametrize("box", [Series, Index]) - def test_str_cat_categorical(self, box, dtype_caller, dtype_target, sep): + def test_str_cat_categorical( + self, index_or_series, dtype_caller, dtype_target, sep + ): s = Index(["a", "a", "b", "a"], dtype=dtype_caller) - s = s if box == Index else Series(s, index=s) + s = s if index_or_series == Index else Series(s, index=s) t = Index(["b", "a", "b", "c"], dtype=dtype_target) expected = Index(["ab", "aa", "bb", "ac"]) - expected = expected if box == Index else Series(expected, index=s) + expected = expected if index_or_series == Index else Series(expected, index=s) # Series/Index with unaligned Index -> t.values result = s.str.cat(t.values, sep=sep) @@ -467,7 +471,9 @@ def test_str_cat_categorical(self, box, dtype_caller, dtype_target, sep): t = Series(t.values, index=t.values) expected = Index(["aa", "aa", "aa", "bb", "bb"]) expected = ( - expected if box == Index else Series(expected, index=expected.str[:1]) + expected + if index_or_series == Index + else Series(expected, index=expected.str[:1]) ) result = s.str.cat(t, sep=sep) @@ -495,16 +501,19 @@ def test_str_cat_wrong_dtype_raises(self, box, data): # need to use outer and na_rep, as otherwise Index would not raise s.str.cat(t, join="outer", na_rep="-") - @pytest.mark.parametrize("box", [Series, Index]) - def test_str_cat_mixed_inputs(self, box): + def test_str_cat_mixed_inputs(self, index_or_series): s = Index(["a", "b", "c", "d"]) - s = s if box == Index else Series(s, index=s) + s = s if index_or_series == Index else Series(s, index=s) t = Series(["A", "B", "C", "D"], index=s.values) d = concat([t, Series(s, index=s)], axis=1) expected = Index(["aAa", "bBb", "cCc", "dDd"]) - expected = expected if box == Index else Series(expected.values, index=s.values) + expected = ( + expected + if index_or_series == Index + else Series(expected.values, index=s.values) + ) # Series/Index with DataFrame result = s.str.cat(d) @@ -524,8 +533,12 @@ def test_str_cat_mixed_inputs(self, box): # Series/Index with list of Series; different indexes t.index = ["b", "c", "d", "a"] - expected = box(["aDa", "bAb", "cBc", "dCd"]) - expected = expected if box == Index else Series(expected.values, index=s.values) + expected = index_or_series(["aDa", "bAb", "cBc", "dCd"]) + expected = ( + expected + if index_or_series == Index + else Series(expected.values, index=s.values) + ) result = s.str.cat([t, s]) assert_series_or_index_equal(result, expected) @@ -535,8 +548,12 @@ def test_str_cat_mixed_inputs(self, box): # Series/Index with DataFrame; different indexes d.index = ["b", "c", "d", "a"] - expected = box(["aDd", "bAa", "cBb", "dCc"]) - expected = expected if box == Index else Series(expected.values, index=s.values) + expected = index_or_series(["aDd", "bAa", "cBb", "dCc"]) + expected = ( + expected + if index_or_series == Index + else Series(expected.values, index=s.values) + ) result = s.str.cat(d) assert_series_or_index_equal(result, expected) @@ -597,8 +614,7 @@ def test_str_cat_mixed_inputs(self, box): s.str.cat(iter([t.values, list(s)])) @pytest.mark.parametrize("join", ["left", "outer", "inner", "right"]) - @pytest.mark.parametrize("box", [Series, Index]) - def test_str_cat_align_indexed(self, box, join): + def test_str_cat_align_indexed(self, index_or_series, join): # https://github.com/pandas-dev/pandas/issues/18657 s = Series(["a", "b", "c", "d"], index=["a", "b", "c", "d"]) t = Series(["D", "A", "E", "B"], index=["d", "a", "e", "b"]) @@ -606,7 +622,7 @@ def test_str_cat_align_indexed(self, box, join): # result after manual alignment of inputs expected = sa.str.cat(ta, na_rep="-") - if box == Index: + if index_or_series == Index: s = Index(s) sa = Index(sa) expected = Index(expected) @@ -657,22 +673,20 @@ def test_str_cat_align_mixed_inputs(self, join): with pytest.raises(ValueError, match=rgx): s.str.cat([t, z], join=join) - @pytest.mark.parametrize("box", [Series, Index]) - @pytest.mark.parametrize("other", [Series, Index]) - def test_str_cat_all_na(self, box, other): + @pytest.mark.parametrize("other", index_or_series_params) + def test_str_cat_all_na(self, index_or_series, other): # GH 24044 - # check that all NaNs in caller / target work s = Index(["a", "b", "c", "d"]) - s = s if box == Index else Series(s, index=s) + s = s if index_or_series == Index else Series(s, index=s) t = other([np.nan] * 4, dtype=object) # add index of s for alignment t = t if other == Index else Series(t, index=s) # all-NA target - if box == Series: + if index_or_series == Series: expected = Series([np.nan] * 4, index=s.index, dtype=object) - else: # box == Index + else: # index_or_series == Index expected = Index([np.nan] * 4, dtype=object) result = s.str.cat(t, join="left") assert_series_or_index_equal(result, expected)