diff --git a/pandas/_libs/lib.pyx b/pandas/_libs/lib.pyx index 1f0f0a408aee8..34ceeb20e260e 100644 --- a/pandas/_libs/lib.pyx +++ b/pandas/_libs/lib.pyx @@ -939,6 +939,7 @@ _TYPE_MAP = { 'float32': 'floating', 'float64': 'floating', 'f': 'floating', + 'complex64': 'complex', 'complex128': 'complex', 'c': 'complex', 'string': 'string' if PY2 else 'bytes', @@ -1305,6 +1306,9 @@ def infer_dtype(value: object, skipna: object=None) -> str: elif is_decimal(val): return 'decimal' + elif is_complex(val): + return 'complex' + elif util.is_float_object(val): if is_float_array(values): return 'floating' diff --git a/pandas/tests/dtypes/test_inference.py b/pandas/tests/dtypes/test_inference.py index 49a66efaffc11..187b37d4f788e 100644 --- a/pandas/tests/dtypes/test_inference.py +++ b/pandas/tests/dtypes/test_inference.py @@ -618,6 +618,37 @@ def test_decimals(self): result = lib.infer_dtype(arr, skipna=True) assert result == 'decimal' + # complex is compatible with nan, so skipna has no effect + @pytest.mark.parametrize('skipna', [True, False]) + def test_complex(self, skipna): + # gets cast to complex on array construction + arr = np.array([1.0, 2.0, 1 + 1j]) + result = lib.infer_dtype(arr, skipna=skipna) + assert result == 'complex' + + arr = np.array([1.0, 2.0, 1 + 1j], dtype='O') + result = lib.infer_dtype(arr, skipna=skipna) + assert result == 'mixed' + + # gets cast to complex on array construction + arr = np.array([1, np.nan, 1 + 1j]) + result = lib.infer_dtype(arr, skipna=skipna) + assert result == 'complex' + + arr = np.array([1.0, np.nan, 1 + 1j], dtype='O') + result = lib.infer_dtype(arr, skipna=skipna) + assert result == 'mixed' + + # complex with nans stays complex + arr = np.array([1 + 1j, np.nan, 3 + 3j], dtype='O') + result = lib.infer_dtype(arr, skipna=skipna) + assert result == 'complex' + + # test smaller complex dtype; will pass through _try_infer_map fastpath + arr = np.array([1 + 1j, np.nan, 3 + 3j], dtype=np.complex64) + result = lib.infer_dtype(arr, skipna=skipna) + assert result == 'complex' + def test_string(self): pass