Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Dataset.dtypes property #6706

Merged
merged 7 commits into from
Jun 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ Attributes

Dataset.dims
Dataset.sizes
Dataset.dtypes
Dataset.data_vars
Dataset.coords
Dataset.attrs
Expand Down
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ v2022.06.0 (unreleased)
New Features
~~~~~~~~~~~~

- Add :py:attr:`Dataset.dtypes`, :py:attr:`DatasetCoordinates.dtypes`,
  :py:attr:`DataArrayCoordinates.dtypes` properties: Mapping from variable names to dtypes.
  (:pull:`6706`)
  By `Michael Niklas <https://github.com/headtr1ck>`_.

Deprecations
~~~~~~~~~~~~
Expand Down
34 changes: 34 additions & 0 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ def _names(self) -> set[Hashable]:
def dims(self) -> Mapping[Hashable, int] | tuple[Hashable, ...]:
raise NotImplementedError()

@property
def dtypes(self) -> Frozen[Hashable, np.dtype]:
    """Mapping from variable names to dtypes; concrete subclasses must override."""
    raise NotImplementedError

@property
def indexes(self) -> Indexes[pd.Index]:
return self._data.indexes # type: ignore[attr-defined]
Expand Down Expand Up @@ -242,6 +246,24 @@ def _names(self) -> set[Hashable]:
def dims(self) -> Mapping[Hashable, int]:
return self._data.dims

@property
def dtypes(self) -> Frozen[Hashable, np.dtype]:
    """Mapping from coordinate names to dtypes.

    Cannot be modified directly, but is updated when adding new variables.

    See Also
    --------
    Dataset.dtypes
    """
    # Coordinates live alongside data variables in `_variables`;
    # keep only the names registered as coordinates.
    coord_names = self._data._coord_names
    coord_dtypes = {
        name: var.dtype
        for name, var in self._data._variables.items()
        if name in coord_names
    }
    return Frozen(coord_dtypes)

@property
def variables(self) -> Mapping[Hashable, Variable]:
return Frozen(
Expand Down Expand Up @@ -313,6 +335,18 @@ def __init__(self, dataarray: DataArray):
def dims(self) -> tuple[Hashable, ...]:
return self._data.dims

@property
def dtypes(self) -> Frozen[Hashable, np.dtype]:
    """Mapping from coordinate names to dtypes.

    Cannot be modified directly, but is updated when adding new variables.

    See Also
    --------
    DataArray.dtype
    """
    # DataArray keeps its coordinate variables directly in `_coords`.
    coords = self._data._coords
    return Frozen({name: coords[name].dtype for name in coords})

@property
def _names(self) -> set[Hashable]:
return set(self._data._coords)
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@

def _infer_coords_and_dims(
shape, coords, dims
) -> tuple[dict[Any, Variable], tuple[Hashable, ...]]:
) -> tuple[dict[Hashable, Variable], tuple[Hashable, ...]]:
max-sixty marked this conversation as resolved.
Show resolved Hide resolved
"""All the logic for creating a new DataArray"""

if (
Expand Down Expand Up @@ -140,7 +140,7 @@ def _infer_coords_and_dims(
if not isinstance(d, str):
raise TypeError(f"dimension {d} is not a string")

new_coords: dict[Any, Variable] = {}
new_coords: dict[Hashable, Variable] = {}

if utils.is_dict_like(coords):
for k, v in coords.items():
Expand Down
60 changes: 46 additions & 14 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
from ..coding.calendar_ops import convert_calendar, interp_calendar
from ..coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings
from ..plot.dataset_plot import _Dataset_PlotMethods
from . import alignment
from . import dtypes as xrdtypes
from . import (
alignment,
dtypes,
duck_array_ops,
formatting,
formatting_html,
Expand Down Expand Up @@ -385,6 +385,18 @@ def variables(self) -> Mapping[Hashable, Variable]:
all_variables = self._dataset.variables
return Frozen({k: all_variables[k] for k in self})

@property
def dtypes(self) -> Frozen[Hashable, np.dtype]:
    """Mapping from data variable names to dtypes.

    Cannot be modified directly, but is updated when adding new variables.

    See Also
    --------
    Dataset.dtypes
    """
    return self._dataset.dtypes

def _ipython_key_completions_(self):
"""Provide method for the key-autocompletions in IPython."""
return [
Expand Down Expand Up @@ -677,6 +689,24 @@ def sizes(self) -> Frozen[Hashable, int]:
"""
return self.dims

@property
def dtypes(self) -> Frozen[Hashable, np.dtype]:
    """Mapping from data variable names to dtypes.

    Cannot be modified directly, but is updated when adding new variables.

    See Also
    --------
    DataArray.dtype
    """
    # `_variables` holds both data variables and coordinates;
    # exclude the coordinate names to report only data variables.
    coord_names = self._coord_names
    data_var_dtypes = {
        name: var.dtype
        for name, var in self._variables.items()
        if name not in coord_names
    }
    return Frozen(data_var_dtypes)

def load(self: T_Dataset, **kwargs) -> T_Dataset:
"""Manually trigger loading and/or computation of this dataset's data
from disk or a remote source into memory and return this dataset.
Expand Down Expand Up @@ -2791,7 +2821,7 @@ def reindex_like(
method: ReindexMethodOptions = None,
tolerance: int | float | Iterable[int | float] | None = None,
copy: bool = True,
fill_value: Any = dtypes.NA,
fill_value: Any = xrdtypes.NA,
) -> T_Dataset:
"""Conform this object onto the indexes of another object, filling in
missing values with ``fill_value``. The default fill value is NaN.
Expand Down Expand Up @@ -2857,7 +2887,7 @@ def reindex(
method: ReindexMethodOptions = None,
tolerance: int | float | Iterable[int | float] | None = None,
copy: bool = True,
fill_value: Any = dtypes.NA,
fill_value: Any = xrdtypes.NA,
**indexers_kwargs: Any,
) -> T_Dataset:
"""Conform this object onto a new set of indexes, filling in
Expand Down Expand Up @@ -3073,7 +3103,7 @@ def _reindex(
method: str = None,
tolerance: int | float | Iterable[int | float] | None = None,
copy: bool = True,
fill_value: Any = dtypes.NA,
fill_value: Any = xrdtypes.NA,
sparse: bool = False,
**indexers_kwargs: Any,
) -> T_Dataset:
Expand Down Expand Up @@ -4531,7 +4561,7 @@ def _unstack_full_reindex(
def unstack(
self: T_Dataset,
dim: Hashable | Iterable[Hashable] | None = None,
fill_value: Any = dtypes.NA,
fill_value: Any = xrdtypes.NA,
sparse: bool = False,
) -> T_Dataset:
"""
Expand Down Expand Up @@ -4676,7 +4706,7 @@ def merge(
overwrite_vars: Hashable | Iterable[Hashable] = frozenset(),
compat: CompatOptions = "no_conflicts",
join: JoinOptions = "outer",
fill_value: Any = dtypes.NA,
fill_value: Any = xrdtypes.NA,
combine_attrs: CombineAttrsOptions = "override",
) -> T_Dataset:
"""Merge the arrays of two datasets into a single dataset.
Expand Down Expand Up @@ -5885,7 +5915,7 @@ def _set_sparse_data_from_dataframe(
# missing values and needs a fill_value. For consistency, don't
# special case the rare exceptions (e.g., dtype=int without a
# MultiIndex).
dtype, fill_value = dtypes.maybe_promote(values.dtype)
dtype, fill_value = xrdtypes.maybe_promote(values.dtype)
values = np.asarray(values, dtype=dtype)

data = COO(
Expand Down Expand Up @@ -5923,7 +5953,7 @@ def _set_numpy_data_from_dataframe(
# fill in missing values:
# https://stackoverflow.com/a/35049899/809705
if missing_values:
dtype, fill_value = dtypes.maybe_promote(values.dtype)
dtype, fill_value = xrdtypes.maybe_promote(values.dtype)
data = np.full(shape, fill_value, dtype)
else:
# If there are no missing values, keep the existing dtype
Expand Down Expand Up @@ -6414,7 +6444,7 @@ def diff(
def shift(
self: T_Dataset,
shifts: Mapping[Any, int] | None = None,
fill_value: Any = dtypes.NA,
fill_value: Any = xrdtypes.NA,
**shifts_kwargs: int,
) -> T_Dataset:

Expand Down Expand Up @@ -6469,7 +6499,7 @@ def shift(
for name, var in self.variables.items():
if name in self.data_vars:
fill_value_ = (
fill_value.get(name, dtypes.NA)
fill_value.get(name, xrdtypes.NA)
if isinstance(fill_value, dict)
else fill_value
)
Expand Down Expand Up @@ -6930,7 +6960,9 @@ def differentiate(
dim = coord_var.dims[0]
if _contains_datetime_like_objects(coord_var):
if coord_var.dtype.kind in "mM" and datetime_unit is None:
datetime_unit, _ = np.datetime_data(coord_var.dtype)
datetime_unit = cast(
"DatetimeUnitOptions", np.datetime_data(coord_var.dtype)[0]
)
elif datetime_unit is None:
datetime_unit = "s" # Default to seconds for cftime objects
coord_var = coord_var._to_numeric(datetime_unit=datetime_unit)
Expand Down Expand Up @@ -7743,7 +7775,7 @@ def idxmin(
self: T_Dataset,
dim: Hashable | None = None,
skipna: bool | None = None,
fill_value: Any = dtypes.NA,
fill_value: Any = xrdtypes.NA,
keep_attrs: bool | None = None,
) -> T_Dataset:
"""Return the coordinate label of the minimum value along a dimension.
Expand Down Expand Up @@ -7840,7 +7872,7 @@ def idxmax(
self: T_Dataset,
dim: Hashable | None = None,
skipna: bool | None = None,
fill_value: Any = dtypes.NA,
fill_value: Any = xrdtypes.NA,
keep_attrs: bool | None = None,
) -> T_Dataset:
"""Return the coordinate label of the maximum value along a dimension.
Expand Down
10 changes: 8 additions & 2 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,9 +1321,11 @@ def test_coords(self) -> None:
]
da = DataArray(np.random.randn(2, 3), coords, name="foo")

assert 2 == len(da.coords)
# len
assert len(da.coords) == 2

assert ["x", "y"] == list(da.coords)
# iter
assert list(da.coords) == ["x", "y"]

assert coords[0].identical(da.coords["x"])
assert coords[1].identical(da.coords["y"])
Expand All @@ -1337,6 +1339,7 @@ def test_coords(self) -> None:
with pytest.raises(KeyError):
da.coords["foo"]

# repr
expected_repr = dedent(
"""\
Coordinates:
Expand All @@ -1346,6 +1349,9 @@ def test_coords(self) -> None:
actual = repr(da.coords)
assert expected_repr == actual

# dtypes
assert da.coords.dtypes == {"x": np.dtype("int64"), "y": np.dtype("int64")}

del da.coords["x"]
da._indexes = filter_indexes_from_coords(da.xindexes, set(da.coords))
expected = DataArray(da.values, {"y": [0, 1, 2]}, dims=["x", "y"], name="foo")
Expand Down
Loading