diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index e31918c21c2ac..f5406d9f89af8 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -358,6 +358,7 @@ def maybe_promote(dtype, fill_value=np.nan): fill_value = NaT elif is_extension_array_dtype(dtype) and isna(fill_value): fill_value = dtype.na_value + elif is_float(fill_value): if issubclass(dtype.type, np.bool_): dtype = np.object_ @@ -366,6 +367,8 @@ def maybe_promote(dtype, fill_value=np.nan): elif is_bool(fill_value): if not issubclass(dtype.type, np.bool_): dtype = np.object_ + else: + fill_value = np.bool_(fill_value) elif is_integer(fill_value): if issubclass(dtype.type, np.bool_): dtype = np.object_ @@ -374,6 +377,10 @@ def maybe_promote(dtype, fill_value=np.nan): arr = np.asarray(fill_value) if arr != arr.astype(dtype): dtype = arr.dtype + elif issubclass(dtype.type, np.floating): + # check if we can cast + if _check_lossless_cast(fill_value, dtype): + fill_value = dtype.type(fill_value) elif is_complex(fill_value): if issubclass(dtype.type, np.bool_): dtype = np.object_ @@ -404,6 +411,25 @@ def maybe_promote(dtype, fill_value=np.nan): return dtype, fill_value +def _check_lossless_cast(value, dtype: np.dtype) -> bool: + """ + Check if we can cast the given value to the given dtype _losslessly_. + + Parameters + ---------- + value : object + dtype : np.dtype + + Returns + ------- + bool + """ + casted = dtype.type(value) + if casted == value: + return True + return False + + def infer_dtype_from(val, pandas_dtype=False): """ interpret the dtype from a scalar or array. This is a convenience diff --git a/pandas/tests/dtypes/cast/test_promote.py b/pandas/tests/dtypes/cast/test_promote.py index 44aebd4d277f2..1ea49602a8b78 100644 --- a/pandas/tests/dtypes/cast/test_promote.py +++ b/pandas/tests/dtypes/cast/test_promote.py @@ -23,6 +23,7 @@ is_timedelta64_dtype, ) from pandas.core.dtypes.dtypes import DatetimeTZDtype, PandasExtensionDtype +from pandas.core.dtypes.missing import isna import pandas as pd @@ -95,6 +96,7 @@ def _safe_dtype_assert(left_dtype, right_dtype): """ Compare two dtypes without raising TypeError. """ + __tracebackhide__ = True if isinstance(right_dtype, PandasExtensionDtype): # switch order of equality check because numpy dtypes (e.g. if # left_dtype is np.object_) do not know some expected dtypes (e.g. @@ -157,20 +159,17 @@ def _check_promote( _safe_dtype_assert(result_dtype, expected_dtype) - # for equal values, also check type (relevant e.g. for int vs float, resp. - # for different datetimes and timedeltas) - match_value = ( - result_fill_value - == expected_fill_value - # disabled type check due to too many xfails; GH 23982/25425 - # and type(result_fill_value) == type(expected_fill_value) - ) + # GH#23982/25425 require the same type in addition to equality/NA-ness + res_type = type(result_fill_value) + ex_type = type(expected_fill_value) + assert res_type == ex_type + + match_value = result_fill_value == expected_fill_value + # Note: type check above ensures that we have the _same_ NA value # for missing values, None == None and iNaT == iNaT (which is checked # through match_value above), but np.nan != np.nan and pd.NaT != pd.NaT - match_missing = (result_fill_value is np.nan and expected_fill_value is np.nan) or ( - result_fill_value is NaT and expected_fill_value is NaT - ) + match_missing = isna(result_fill_value) and isna(expected_fill_value) assert match_value or match_missing