diff --git a/pandas/core/arrays/datetimes.py b/pandas/core/arrays/datetimes.py index aa1d8f6254e2c..2b2be214428d2 100644 --- a/pandas/core/arrays/datetimes.py +++ b/pandas/core/arrays/datetimes.py @@ -24,6 +24,7 @@ ) from pandas.errors import PerformanceWarning +from pandas.core.dtypes.cast import astype_dt64_to_dt64tz from pandas.core.dtypes.common import ( DT64NS_DTYPE, INT64_DTYPE, @@ -31,6 +32,7 @@ is_categorical_dtype, is_datetime64_any_dtype, is_datetime64_dtype, + is_datetime64_ns_dtype, is_datetime64tz_dtype, is_dtype_equal, is_extension_array_dtype, @@ -591,24 +593,8 @@ def astype(self, dtype, copy=True): return self.copy() return self - elif is_datetime64tz_dtype(dtype) and self.tz is None: - # FIXME: GH#33401 this does not match Series behavior - return self.tz_localize(dtype.tz) - - elif is_datetime64tz_dtype(dtype): - # GH#18951: datetime64_ns dtype but not equal means different tz - result = self.tz_convert(dtype.tz) - if copy: - result = result.copy() - return result - - elif dtype == "M8[ns]": - # we must have self.tz is None, otherwise we would have gone through - # the is_dtype_equal branch above. - result = self.tz_convert("UTC").tz_localize(None) - if copy: - result = result.copy() - return result + elif is_datetime64_ns_dtype(dtype): + return astype_dt64_to_dt64tz(self, dtype, copy, via_utc=False) elif is_period_dtype(dtype): return self.to_period(freq=dtype.freq) diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index ce50b81be7d63..4d7a06b691fa3 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -15,6 +15,7 @@ Tuple, Type, Union, + cast, ) import warnings @@ -85,7 +86,7 @@ if TYPE_CHECKING: from pandas import Series - from pandas.core.arrays import ExtensionArray + from pandas.core.arrays import DatetimeArray, ExtensionArray _int8_max = np.iinfo(np.int8).max _int16_max = np.iinfo(np.int16).max @@ -920,6 +921,56 @@ def coerce_indexer_dtype(indexer, categories): return ensure_int64(indexer) +def astype_dt64_to_dt64tz( + values: ArrayLike, dtype: DtypeObj, copy: bool, via_utc: bool = False +) -> "DatetimeArray": + # GH#33401 we have inconsistent behaviors between + # Datetimeindex[naive].astype(tzaware) + # Series[dt64].astype(tzaware) + # This collects them in one place to prevent further fragmentation. + + from pandas.core.construction import ensure_wrapped_if_datetimelike + + values = ensure_wrapped_if_datetimelike(values) + values = cast("DatetimeArray", values) + aware = isinstance(dtype, DatetimeTZDtype) + + if via_utc: + # Series.astype behavior + assert values.tz is None and aware # caller is responsible for checking this + dtype = cast(DatetimeTZDtype, dtype) + + if copy: + # this should be the only copy + values = values.copy() + # FIXME: GH#33401 this doesn't match DatetimeArray.astype, which + # goes through the `not via_utc` path + return values.tz_localize("UTC").tz_convert(dtype.tz) + + else: + # DatetimeArray/DatetimeIndex.astype behavior + + if values.tz is None and aware: + dtype = cast(DatetimeTZDtype, dtype) + return values.tz_localize(dtype.tz) + + elif aware: + # GH#18951: datetime64_tz dtype but not equal means different tz + dtype = cast(DatetimeTZDtype, dtype) + result = values.tz_convert(dtype.tz) + if copy: + result = result.copy() + return result + + elif values.tz is not None and not aware: + result = values.tz_convert("UTC").tz_localize(None) + if copy: + result = result.copy() + return result + + raise NotImplementedError("dtype_equal case should be handled elsewhere") + + def astype_td64_unit_conversion( values: np.ndarray, dtype: np.dtype, copy: bool ) -> np.ndarray: diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index 8752224356f61..27acd720b6d71 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -22,6 +22,7 @@ from pandas.util._validators import validate_bool_kwarg from pandas.core.dtypes.cast import ( + astype_dt64_to_dt64tz, astype_nansafe, convert_scalar_for_putitemlike, find_common_type, @@ -649,14 +650,7 @@ def _astype(self, dtype: DtypeObj, copy: bool) -> ArrayLike: values = self.values if is_datetime64tz_dtype(dtype) and is_datetime64_dtype(values.dtype): - # if we are passed a datetime64[ns, tz] - if copy: - # this should be the only copy - values = values.copy() - # i.e. values.tz_localize("UTC").tz_convert(dtype.tz) - # FIXME: GH#33401 this doesn't match DatetimeArray.astype, which - # would be self.array_values().tz_localize(dtype.tz) - return DatetimeArray._simple_new(values.view("i8"), dtype=dtype) + return astype_dt64_to_dt64tz(values, dtype, copy, via_utc=True) if is_dtype_equal(values.dtype, dtype): if copy: