Skip to content

Commit

Permalink
ENH: switch Dataset and DataArray to use explicit indexes (#2639)
Browse files Browse the repository at this point in the history
* ENH: switch Dataset and DataArray to use explicit indexes

This change switches Dataset.indexes and DataArray.indexes to be backed by
explicit dictionaries of indexes, instead of being implicitly defined by
the set of coordinates with names matching dimensions.

There are no changes to the public interface yet: these will come later.

For now, indexes are recreated from coordinates every time a new DataArray
or Dataset is created. In follow-up PRs, I will refactor indexes to be
propagated explicitly in xarray operations. This will facilitate future API
changes, when indexes will no longer only be associated with dimensions.

* Add xarray.core.indexes

* Fixes per review
  • Loading branch information
shoyer authored Jan 4, 2019
1 parent 28123bb commit 06244df
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 50 deletions.
39 changes: 1 addition & 38 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def _update_coords(self, coords):
self._data._variables = variables
self._data._coord_names.update(new_coord_names)
self._data._dims = dict(dims)
self._data._indexes = None

def __delitem__(self, key):
if key in self:
Expand Down Expand Up @@ -276,44 +277,6 @@ def __iter__(self):
return iter(self._data._level_coords)


class Indexes(Mapping, formatting.ReprMixin):
"""Ordered Mapping[str, pandas.Index] for xarray objects.
"""

def __init__(self, variables, sizes):
"""Not for public consumption.
Parameters
----------
variables : OrderedDict[Any, Variable]
Reference to OrderedDict holding variable objects. Should be the
same dictionary used by the source object.
sizes : OrderedDict[Any, int]
Map from dimension names to sizes.
"""
self._variables = variables
self._sizes = sizes

def __iter__(self):
for key in self._sizes:
if key in self._variables:
yield key

def __len__(self):
return sum(key in self._variables for key in self._sizes)

def __contains__(self, key):
return key in self._sizes and key in self._variables

def __getitem__(self, key):
if key not in self._sizes:
raise KeyError(key)
return self._variables[key].to_index()

def __unicode__(self):
return formatting.indexes_repr(self)


def assert_coordinate_consistent(obj, coords):
""" Maeke sure the dimension coordinate of obj is
consistent with coords.
Expand Down
15 changes: 11 additions & 4 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
from .alignment import align, reindex_like_indexers
from .common import AbstractArray, DataWithCoords
from .coordinates import (
DataArrayCoordinates, Indexes, LevelCoordinatesSource,
DataArrayCoordinates, LevelCoordinatesSource,
assert_coordinate_consistent, remap_label_indexers)
from .dataset import Dataset, merge_indexes, split_indexes
from .formatting import format_item
from .indexes import default_indexes, Indexes
from .options import OPTIONS
from .pycompat import OrderedDict, basestring, iteritems, range, zip
from .utils import (
Expand Down Expand Up @@ -165,7 +166,7 @@ class DataArray(AbstractArray, DataWithCoords):
dt = property(DatetimeAccessor)

def __init__(self, data, coords=None, dims=None, name=None,
attrs=None, encoding=None, fastpath=False):
attrs=None, encoding=None, indexes=None, fastpath=False):
"""
Parameters
----------
Expand Down Expand Up @@ -237,6 +238,10 @@ def __init__(self, data, coords=None, dims=None, name=None,
self._coords = coords
self._name = name

# TODO(shoyer): document this argument, once it becomes part of the
# public interface.
self._indexes = indexes

self._file_obj = None

self._initialized = True
Expand Down Expand Up @@ -534,9 +539,11 @@ def encoding(self, value):

@property
def indexes(self):
"""OrderedDict of pandas.Index objects used for label based indexing
"""Mapping of pandas.Index objects used for label based indexing
"""
return Indexes(self._coords, self.sizes)
if self._indexes is None:
self._indexes = default_indexes(self._coords, self.dims)
return Indexes(self._indexes)

@property
def coords(self):
Expand Down
27 changes: 19 additions & 8 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,17 @@
import xarray as xr

from . import (
alignment, dtypes, duck_array_ops, formatting, groupby, indexing, ops,
pdcompat, resample, rolling, utils)
alignment, dtypes, duck_array_ops, formatting, groupby,
indexing, ops, pdcompat, resample, rolling, utils)
from ..coding.cftimeindex import _parse_array_of_cftime_strings
from .alignment import align
from .common import (
ALL_DIMS, DataWithCoords, ImplementsDatasetReduce,
_contains_datetime_like_objects)
from .coordinates import (
DatasetCoordinates, Indexes, LevelCoordinatesSource,
DatasetCoordinates, LevelCoordinatesSource,
assert_coordinate_consistent, remap_label_indexers)
from .indexes import Indexes, default_indexes
from .merge import (
dataset_merge_method, dataset_update_method, merge_data_and_coords,
merge_variables)
Expand Down Expand Up @@ -364,6 +365,10 @@ def __init__(self, data_vars=None, coords=None, attrs=None,
coords = {}
if data_vars is not None or coords is not None:
self._set_init_vars_and_dims(data_vars, coords, compat)

# TODO(shoyer): expose indexes as a public argument in __init__
self._indexes = None

if attrs is not None:
self.attrs = attrs
self._encoding = None
Expand Down Expand Up @@ -642,14 +647,15 @@ def persist(self, **kwargs):

@classmethod
def _construct_direct(cls, variables, coord_names, dims=None, attrs=None,
file_obj=None, encoding=None):
indexes=None, file_obj=None, encoding=None):
"""Shortcut around __init__ for internal use when we want to skip
costly validation
"""
obj = object.__new__(cls)
obj._variables = variables
obj._coord_names = coord_names
obj._dims = dims
obj._indexes = indexes
obj._attrs = attrs
obj._file_obj = file_obj
obj._encoding = encoding
Expand All @@ -664,7 +670,8 @@ def _from_vars_and_coord_names(cls, variables, coord_names, attrs=None):
return cls._construct_direct(variables, coord_names, dims, attrs)

def _replace_vars_and_dims(self, variables, coord_names=None, dims=None,
attrs=__default_attrs, inplace=False):
attrs=__default_attrs, indexes=None,
inplace=False):
"""Fastpath constructor for internal use.
Preserves coord names and attributes. If not provided explicitly,
Expand Down Expand Up @@ -693,13 +700,15 @@ def _replace_vars_and_dims(self, variables, coord_names=None, dims=None,
self._coord_names = coord_names
if attrs is not self.__default_attrs:
self._attrs = attrs
self._indexes = indexes
obj = self
else:
if coord_names is None:
coord_names = self._coord_names.copy()
if attrs is self.__default_attrs:
attrs = self._attrs_copy()
obj = self._construct_direct(variables, coord_names, dims, attrs)
obj = self._construct_direct(
variables, coord_names, dims, attrs, indexes)
return obj

def _replace_indexes(self, indexes):
Expand Down Expand Up @@ -1064,9 +1073,11 @@ def identical(self, other):

@property
def indexes(self):
"""OrderedDict of pandas.Index objects used for label based indexing
"""Mapping of pandas.Index objects used for label based indexing
"""
return Indexes(self._variables, self._dims)
if self._indexes is None:
self._indexes = default_indexes(self._variables, self._dims)
return Indexes(self._indexes)

@property
def coords(self):
Expand Down
55 changes: 55 additions & 0 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from __future__ import absolute_import, division, print_function
try:
from collections.abc import Mapping
except ImportError:
from collections import Mapping
from collections import OrderedDict

from . import formatting


class Indexes(Mapping, formatting.ReprMixin):
"""Immutable proxy for Dataset or DataArrary indexes."""
def __init__(self, indexes):
"""Not for public consumption.
Parameters
----------
indexes : Dict[Any, pandas.Index]
Indexes held by this object.
"""
self._indexes = indexes

def __iter__(self):
return iter(self._indexes)

def __len__(self):
return len(self._indexes)

def __contains__(self, key):
return key in self._indexes

def __getitem__(self, key):
return self._indexes[key]

def __unicode__(self):
return formatting.indexes_repr(self)


def default_indexes(coords, dims):
"""Default indexes for a Dataset/DataArray.
Parameters
----------
coords : Mapping[Any, xarray.Variable]
Coordinate variables from which to draw default indexes.
dims : iterable
Iterable of dimension names.
Returns
-------
Mapping[Any, pandas.Index] mapping indexing keys (levels/dimension names)
to indexes used for indexing along that dimension.
"""
return OrderedDict((key, coords[key].to_index())
for key in dims if key in coords)

0 comments on commit 06244df

Please sign in to comment.