Skip to content

Commit

Permalink
TYPING: Added types for tests files
Browse files Browse the repository at this point in the history
Working around a strange typing issue. See
pandas-dev#28394 (comment)
for more, but the types on these were being inferred incorrectly by
mypy with just the addition of the `allows_duplicate_labels` kwarg.
  • Loading branch information
TomAugspurger committed Oct 24, 2019
1 parent da1401b commit b4e11f1
Show file tree
Hide file tree
Showing 10 changed files with 140 additions and 148 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,4 @@ doc/build/html/index.html
doc/tmp.sv
env/
doc/source/savefig/
.dmypy.json
2 changes: 2 additions & 0 deletions pandas/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pandas.core.indexes.base import Index # noqa: F401
from pandas.core.series import Series # noqa: F401
from pandas.core.generic import NDFrame # noqa: F401
from pandas.core.base import IndexOpsMixin # noqa: F401


AnyArrayLike = TypeVar("AnyArrayLike", "ExtensionArray", "Index", "Series", np.ndarray)
Expand All @@ -32,6 +33,7 @@
FilePathOrBuffer = Union[str, Path, IO[AnyStr]]

FrameOrSeries = TypeVar("FrameOrSeries", bound="NDFrame")
IndexOrSeries = TypeVar("IndexOrSeries", bound="IndexOpsMixin")
Scalar = Union[str, int, float, bool]
Axis = Union[str, int]
Ordered = Optional[bool]
Expand Down
12 changes: 12 additions & 0 deletions pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import pandas as pd
from pandas import DataFrame
from pandas._typing import IndexOrSeries
from pandas.core import ops
import pandas.util.testing as tm

Expand Down Expand Up @@ -790,6 +791,17 @@ def tick_classes(request):
return request.param


index_or_series_params = [pd.Index, pd.Series] # type: IndexOrSeries


@pytest.fixture(params=index_or_series_params, ids=["series", "index"])
def index_or_series(request) -> IndexOrSeries:
"""
Parametrized fixture providing the Index or Series class.
"""
return request.param


# ----------------------------------------------------------------
# Global setup for tests using Hypothesis

Expand Down
51 changes: 11 additions & 40 deletions pandas/tests/arithmetic/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from decimal import Decimal
from itertools import combinations
import operator
from typing import List, Type, Union

import numpy as np
import pytest
Expand Down Expand Up @@ -74,33 +75,22 @@ def test_compare_invalid(self):

# ------------------------------------------------------------------
# Numeric dtypes Arithmetic with Datetime/Timedelta Scalar
index_or_series_params = [
pd.Series,
pd.Index,
] # type: List[Union[Type[pd.Index], Type[pd.RangeIndex], Type[pd.Series]]]
left = [pd.RangeIndex(10, 40, 10)] # type: List[Union[Index, Series]]
for cls in index_or_series_params:
for dtype in ["i1", "i2", "i4", "i8", "u1", "u2", "u4", "u8", "f2", "f4", "f8"]:
left.append(cls([10, 20, 30], dtype=dtype))


class TestNumericArraylikeArithmeticWithDatetimeLike:

# TODO: also check name retentention
@pytest.mark.parametrize("box_cls", [np.array, pd.Index, pd.Series])
@pytest.mark.parametrize(
"left",
[pd.RangeIndex(10, 40, 10)]
+ [
cls([10, 20, 30], dtype=dtype)
for dtype in [
"i1",
"i2",
"i4",
"i8",
"u1",
"u2",
"u4",
"u8",
"f2",
"f4",
"f8",
]
for cls in [pd.Series, pd.Index]
],
ids=lambda x: type(x).__name__ + str(x.dtype),
"left", left, ids=lambda x: type(x).__name__ + str(x.dtype)
)
def test_mul_td64arr(self, left, box_cls):
# GH#22390
Expand All @@ -120,26 +110,7 @@ def test_mul_td64arr(self, left, box_cls):
# TODO: also check name retentention
@pytest.mark.parametrize("box_cls", [np.array, pd.Index, pd.Series])
@pytest.mark.parametrize(
"left",
[pd.RangeIndex(10, 40, 10)]
+ [
cls([10, 20, 30], dtype=dtype)
for dtype in [
"i1",
"i2",
"i4",
"i8",
"u1",
"u2",
"u4",
"u8",
"f2",
"f4",
"f8",
]
for cls in [pd.Series, pd.Index]
],
ids=lambda x: type(x).__name__ + str(x.dtype),
"left", left, ids=lambda x: type(x).__name__ + str(x.dtype)
)
def test_div_td64arr(self, left, box_cls):
# GH#22390
Expand Down
5 changes: 2 additions & 3 deletions pandas/tests/arrays/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,8 @@ def _from_sequence(cls, scalars, dtype=None, copy=False):
return super()._from_sequence(scalars, dtype=dtype, copy=copy)


@pytest.mark.parametrize("box", [pd.Series, pd.Index])
def test_array_unboxes(box):
data = box([decimal.Decimal("1"), decimal.Decimal("2")])
def test_array_unboxes(index_or_series):
data = index_or_series([decimal.Decimal("1"), decimal.Decimal("2")])
# make sure it works
with pytest.raises(TypeError):
DecimalArray2._from_sequence(data)
Expand Down
7 changes: 3 additions & 4 deletions pandas/tests/dtypes/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pandas.core.dtypes.concat as _concat

from pandas import DatetimeIndex, Index, Period, PeriodIndex, Series, TimedeltaIndex
from pandas import DatetimeIndex, Period, PeriodIndex, Series, TimedeltaIndex


@pytest.mark.parametrize(
Expand Down Expand Up @@ -40,9 +40,8 @@
),
],
)
@pytest.mark.parametrize("klass", [Index, Series])
def test_get_dtype_kinds(klass, to_concat, expected):
to_concat_klass = [klass(c) for c in to_concat]
def test_get_dtype_kinds(index_or_series, to_concat, expected):
to_concat_klass = [index_or_series(c) for c in to_concat]
result = _concat.get_dtype_kinds(to_concat_klass)
assert result == set(expected)

Expand Down
81 changes: 37 additions & 44 deletions pandas/tests/indexing/test_coercion.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,55 +515,52 @@ def _assert_where_conversion(
res = target.where(cond, values)
self._assert(res, expected, expected_dtype)

@pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"])
@pytest.mark.parametrize(
"fill_val,exp_dtype",
[(1, np.object), (1.1, np.object), (1 + 1j, np.object), (True, np.object)],
)
def test_where_object(self, klass, fill_val, exp_dtype):
obj = klass(list("abcd"))
def test_where_object(self, index_or_series, fill_val, exp_dtype):
obj = index_or_series(list("abcd"))
assert obj.dtype == np.object
cond = klass([True, False, True, False])
cond = index_or_series([True, False, True, False])

if fill_val is True and klass is pd.Series:
if fill_val is True and index_or_series is pd.Series:
ret_val = 1
else:
ret_val = fill_val

exp = klass(["a", ret_val, "c", ret_val])
exp = index_or_series(["a", ret_val, "c", ret_val])
self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype)

if fill_val is True:
values = klass([True, False, True, True])
values = index_or_series([True, False, True, True])
else:
values = klass(fill_val * x for x in [5, 6, 7, 8])
values = index_or_series(fill_val * x for x in [5, 6, 7, 8])

exp = klass(["a", values[1], "c", values[3]])
exp = index_or_series(["a", values[1], "c", values[3]])
self._assert_where_conversion(obj, cond, values, exp, exp_dtype)

@pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"])
@pytest.mark.parametrize(
"fill_val,exp_dtype",
[(1, np.int64), (1.1, np.float64), (1 + 1j, np.complex128), (True, np.object)],
)
def test_where_int64(self, klass, fill_val, exp_dtype):
if klass is pd.Index and exp_dtype is np.complex128:
def test_where_int64(self, index_or_series, fill_val, exp_dtype):
if index_or_series is pd.Index and exp_dtype is np.complex128:
pytest.skip("Complex Index not supported")
obj = klass([1, 2, 3, 4])
obj = index_or_series([1, 2, 3, 4])
assert obj.dtype == np.int64
cond = klass([True, False, True, False])
cond = index_or_series([True, False, True, False])

exp = klass([1, fill_val, 3, fill_val])
exp = index_or_series([1, fill_val, 3, fill_val])
self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype)

if fill_val is True:
values = klass([True, False, True, True])
values = index_or_series([True, False, True, True])
else:
values = klass(x * fill_val for x in [5, 6, 7, 8])
exp = klass([1, values[1], 3, values[3]])
values = index_or_series(x * fill_val for x in [5, 6, 7, 8])
exp = index_or_series([1, values[1], 3, values[3]])
self._assert_where_conversion(obj, cond, values, exp, exp_dtype)

@pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"])
@pytest.mark.parametrize(
"fill_val, exp_dtype",
[
Expand All @@ -573,21 +570,21 @@ def test_where_int64(self, klass, fill_val, exp_dtype):
(True, np.object),
],
)
def test_where_float64(self, klass, fill_val, exp_dtype):
if klass is pd.Index and exp_dtype is np.complex128:
def test_where_float64(self, index_or_series, fill_val, exp_dtype):
if index_or_series is pd.Index and exp_dtype is np.complex128:
pytest.skip("Complex Index not supported")
obj = klass([1.1, 2.2, 3.3, 4.4])
obj = index_or_series([1.1, 2.2, 3.3, 4.4])
assert obj.dtype == np.float64
cond = klass([True, False, True, False])
cond = index_or_series([True, False, True, False])

exp = klass([1.1, fill_val, 3.3, fill_val])
exp = index_or_series([1.1, fill_val, 3.3, fill_val])
self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype)

if fill_val is True:
values = klass([True, False, True, True])
values = index_or_series([True, False, True, True])
else:
values = klass(x * fill_val for x in [5, 6, 7, 8])
exp = klass([1.1, values[1], 3.3, values[3]])
values = index_or_series(x * fill_val for x in [5, 6, 7, 8])
exp = index_or_series([1.1, values[1], 3.3, values[3]])
self._assert_where_conversion(obj, cond, values, exp, exp_dtype)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -783,19 +780,17 @@ def _assert_fillna_conversion(self, original, value, expected, expected_dtype):
res = target.fillna(value)
self._assert(res, expected, expected_dtype)

@pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"])
@pytest.mark.parametrize(
"fill_val, fill_dtype",
[(1, np.object), (1.1, np.object), (1 + 1j, np.object), (True, np.object)],
)
def test_fillna_object(self, klass, fill_val, fill_dtype):
obj = klass(["a", np.nan, "c", "d"])
def test_fillna_object(self, index_or_series, fill_val, fill_dtype):
obj = index_or_series(["a", np.nan, "c", "d"])
assert obj.dtype == np.object

exp = klass(["a", fill_val, "c", "d"])
exp = index_or_series(["a", fill_val, "c", "d"])
self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype)

@pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"])
@pytest.mark.parametrize(
"fill_val,fill_dtype",
[
Expand All @@ -805,15 +800,15 @@ def test_fillna_object(self, klass, fill_val, fill_dtype):
(True, np.object),
],
)
def test_fillna_float64(self, klass, fill_val, fill_dtype):
obj = klass([1.1, np.nan, 3.3, 4.4])
def test_fillna_float64(self, index_or_series, fill_val, fill_dtype):
obj = index_or_series([1.1, np.nan, 3.3, 4.4])
assert obj.dtype == np.float64

exp = klass([1.1, fill_val, 3.3, 4.4])
exp = index_or_series([1.1, fill_val, 3.3, 4.4])
# float + complex -> we don't support a complex Index
# complex for Series,
# object for Index
if fill_dtype == np.complex128 and klass == pd.Index:
if fill_dtype == np.complex128 and index_or_series == pd.Index:
fill_dtype = np.object
self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype)

Expand All @@ -833,7 +828,6 @@ def test_fillna_series_complex128(self, fill_val, fill_dtype):
exp = pd.Series([1 + 1j, fill_val, 3 + 3j, 4 + 4j])
self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype)

@pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"])
@pytest.mark.parametrize(
"fill_val,fill_dtype",
[
Expand All @@ -844,8 +838,8 @@ def test_fillna_series_complex128(self, fill_val, fill_dtype):
],
ids=["datetime64", "datetime64tz", "object", "object"],
)
def test_fillna_datetime(self, klass, fill_val, fill_dtype):
obj = klass(
def test_fillna_datetime(self, index_or_series, fill_val, fill_dtype):
obj = index_or_series(
[
pd.Timestamp("2011-01-01"),
pd.NaT,
Expand All @@ -855,7 +849,7 @@ def test_fillna_datetime(self, klass, fill_val, fill_dtype):
)
assert obj.dtype == "datetime64[ns]"

exp = klass(
exp = index_or_series(
[
pd.Timestamp("2011-01-01"),
fill_val,
Expand All @@ -865,7 +859,6 @@ def test_fillna_datetime(self, klass, fill_val, fill_dtype):
)
self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype)

@pytest.mark.parametrize("klass", [pd.Series, pd.Index])
@pytest.mark.parametrize(
"fill_val,fill_dtype",
[
Expand All @@ -876,10 +869,10 @@ def test_fillna_datetime(self, klass, fill_val, fill_dtype):
("x", np.object),
],
)
def test_fillna_datetime64tz(self, klass, fill_val, fill_dtype):
def test_fillna_datetime64tz(self, index_or_series, fill_val, fill_dtype):
tz = "US/Eastern"

obj = klass(
obj = index_or_series(
[
pd.Timestamp("2011-01-01", tz=tz),
pd.NaT,
Expand All @@ -889,7 +882,7 @@ def test_fillna_datetime64tz(self, klass, fill_val, fill_dtype):
)
assert obj.dtype == "datetime64[ns, US/Eastern]"

exp = klass(
exp = index_or_series(
[
pd.Timestamp("2011-01-01", tz=tz),
fill_val,
Expand Down
10 changes: 4 additions & 6 deletions pandas/tests/io/json/test_json_table_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,17 +431,15 @@ def test_date_format_raises(self):
self.df.to_json(orient="table", date_format="iso")
self.df.to_json(orient="table")

@pytest.mark.parametrize("kind", [pd.Series, pd.Index])
def test_convert_pandas_type_to_json_field_int(self, kind):
def test_convert_pandas_type_to_json_field_int(self, index_or_series):
data = [1, 2, 3]
result = convert_pandas_type_to_json_field(kind(data, name="name"))
result = convert_pandas_type_to_json_field(index_or_series(data, name="name"))
expected = {"name": "name", "type": "integer"}
assert result == expected

@pytest.mark.parametrize("kind", [pd.Series, pd.Index])
def test_convert_pandas_type_to_json_field_float(self, kind):
def test_convert_pandas_type_to_json_field_float(self, index_or_series):
data = [1.0, 2.0, 3.0]
result = convert_pandas_type_to_json_field(kind(data, name="name"))
result = convert_pandas_type_to_json_field(index_or_series(data, name="name"))
expected = {"name": "name", "type": "number"}
assert result == expected

Expand Down
Loading

0 comments on commit b4e11f1

Please sign in to comment.