Skip to content

Commit

Permalink
REF: collect dt64<->dt64tz astype in dtypes.cast (#38662)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Dec 23, 2020
1 parent 573caff commit 0805043
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 27 deletions.
22 changes: 4 additions & 18 deletions pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@
)
from pandas.errors import PerformanceWarning

from pandas.core.dtypes.cast import astype_dt64_to_dt64tz
from pandas.core.dtypes.common import (
DT64NS_DTYPE,
INT64_DTYPE,
is_bool_dtype,
is_categorical_dtype,
is_datetime64_any_dtype,
is_datetime64_dtype,
is_datetime64_ns_dtype,
is_datetime64tz_dtype,
is_dtype_equal,
is_extension_array_dtype,
Expand Down Expand Up @@ -591,24 +593,8 @@ def astype(self, dtype, copy=True):
return self.copy()
return self

elif is_datetime64tz_dtype(dtype) and self.tz is None:
# FIXME: GH#33401 this does not match Series behavior
return self.tz_localize(dtype.tz)

elif is_datetime64tz_dtype(dtype):
# GH#18951: datetime64_ns dtype but not equal means different tz
result = self.tz_convert(dtype.tz)
if copy:
result = result.copy()
return result

elif dtype == "M8[ns]":
# we must have self.tz is None, otherwise we would have gone through
# the is_dtype_equal branch above.
result = self.tz_convert("UTC").tz_localize(None)
if copy:
result = result.copy()
return result
elif is_datetime64_ns_dtype(dtype):
return astype_dt64_to_dt64tz(self, dtype, copy, via_utc=False)

elif is_period_dtype(dtype):
return self.to_period(freq=dtype.freq)
Expand Down
53 changes: 52 additions & 1 deletion pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Tuple,
Type,
Union,
cast,
)
import warnings

Expand Down Expand Up @@ -85,7 +86,7 @@

if TYPE_CHECKING:
from pandas import Series
from pandas.core.arrays import ExtensionArray
from pandas.core.arrays import DatetimeArray, ExtensionArray

_int8_max = np.iinfo(np.int8).max
_int16_max = np.iinfo(np.int16).max
Expand Down Expand Up @@ -920,6 +921,56 @@ def coerce_indexer_dtype(indexer, categories):
return ensure_int64(indexer)


def astype_dt64_to_dt64tz(
    values: ArrayLike, dtype: DtypeObj, copy: bool, via_utc: bool = False
) -> "DatetimeArray":
    """
    Cast datetime64 data between tz-naive and tz-aware dtypes.

    GH#33401: ``DatetimeIndex[naive].astype(tzaware)`` and
    ``Series[dt64].astype(tzaware)`` behave inconsistently.  Both behaviors
    are gathered here so that they do not fragment any further.

    Parameters
    ----------
    values : ArrayLike
        Datetime-like data; it is passed through
        ``ensure_wrapped_if_datetimelike`` before use.
    dtype : DtypeObj
        Target dtype.  It is treated as tz-aware exactly when it is a
        ``DatetimeTZDtype`` instance.
    copy : bool
        Whether to copy the data (the naive -> aware branch of the
        ``via_utc=False`` path does not consult this flag).
    via_utc : bool, default False
        If True, use the Series.astype behavior: the naive values are
        localized to UTC and then converted to ``dtype.tz``.  If False, use
        the DatetimeArray/DatetimeIndex.astype behavior.

    Returns
    -------
    DatetimeArray

    Raises
    ------
    AssertionError
        If ``via_utc`` is True but the values are not tz-naive or ``dtype``
        is not tz-aware; callers are expected to have checked this.
    NotImplementedError
        If none of the conversions applies (the dtype-equal case, which is
        meant to be handled by the caller).
    """
    from pandas.core.construction import ensure_wrapped_if_datetimelike

    dta = cast("DatetimeArray", ensure_wrapped_if_datetimelike(values))
    target_is_aware = isinstance(dtype, DatetimeTZDtype)
    source_is_naive = dta.tz is None

    if via_utc:
        # Series.astype behavior
        # caller is responsible for checking this
        assert source_is_naive and target_is_aware
        tz_dtype = cast(DatetimeTZDtype, dtype)

        # this should be the only copy
        base = dta.copy() if copy else dta
        # FIXME: GH#33401 this doesn't match DatetimeArray.astype, which
        #  goes through the `not via_utc` path
        return base.tz_localize("UTC").tz_convert(tz_dtype.tz)

    # DatetimeArray/DatetimeIndex.astype behavior
    if target_is_aware:
        tz_dtype = cast(DatetimeTZDtype, dtype)
        if source_is_naive:
            return dta.tz_localize(tz_dtype.tz)

        # GH#18951: datetime64_tz dtype but not equal means different tz
        converted = dta.tz_convert(tz_dtype.tz)
        return converted.copy() if copy else converted

    if not source_is_naive:
        # tz-aware -> tz-naive: go through UTC, then drop the timezone
        stripped = dta.tz_convert("UTC").tz_localize(None)
        return stripped.copy() if copy else stripped

    raise NotImplementedError("dtype_equal case should be handled elsewhere")


def astype_td64_unit_conversion(
values: np.ndarray, dtype: np.dtype, copy: bool
) -> np.ndarray:
Expand Down
10 changes: 2 additions & 8 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from pandas.util._validators import validate_bool_kwarg

from pandas.core.dtypes.cast import (
astype_dt64_to_dt64tz,
astype_nansafe,
convert_scalar_for_putitemlike,
find_common_type,
Expand Down Expand Up @@ -649,14 +650,7 @@ def _astype(self, dtype: DtypeObj, copy: bool) -> ArrayLike:
values = self.values

if is_datetime64tz_dtype(dtype) and is_datetime64_dtype(values.dtype):
# if we are passed a datetime64[ns, tz]
if copy:
# this should be the only copy
values = values.copy()
# i.e. values.tz_localize("UTC").tz_convert(dtype.tz)
# FIXME: GH#33401 this doesn't match DatetimeArray.astype, which
# would be self.array_values().tz_localize(dtype.tz)
return DatetimeArray._simple_new(values.view("i8"), dtype=dtype)
return astype_dt64_to_dt64tz(values, dtype, copy, via_utc=True)

if is_dtype_equal(values.dtype, dtype):
if copy:
Expand Down

0 comments on commit 0805043

Please sign in to comment.