Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Experimental] Refactor Dataset to store variables in a manifest #5961

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ def _update_coords(
if dim in variables:
new_coord_names.add(dim)

self._data._manifest.variables = variables
self._data._variables = variables
self._data._coord_names.update(new_coord_names)
self._data._dims = dims
Expand Down
45 changes: 40 additions & 5 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

from ..coding.cftimeindex import _parse_array_of_cftime_strings
from ..plot.dataset_plot import _Dataset_PlotMethods
from ..tree.manifest import DataManifest
from . import (
alignment,
dtypes,
Expand Down Expand Up @@ -705,6 +706,7 @@ class Dataset(DataWithCoords, DatasetArithmetic, Mapping):
_close: Optional[Callable[[], None]]
_indexes: Optional[Dict[Hashable, Index]]
_variables: Dict[Hashable, Variable]
_manifest: DataManifest

__slots__ = (
"_attrs",
Expand All @@ -715,6 +717,7 @@ class Dataset(DataWithCoords, DatasetArithmetic, Mapping):
"_close",
"_indexes",
"_variables",
"_manifest",
"__weakref__",
)

Expand Down Expand Up @@ -752,10 +755,13 @@ def __init__(
data_vars, coords, compat="broadcast_equals"
)

self._manifest = DataManifest(variables=variables)

# The private attributes that effectively define the data model
self._attrs = dict(attrs) if attrs is not None else None
self._close = None
self._encoding = None
self._variables = variables
self._variables = self._manifest.variables
self._coord_names = coord_names
self._dims = dims
self._indexes = indexes
Expand All @@ -781,7 +787,7 @@ def variables(self) -> Mapping[Hashable, Variable]:
constituting the Dataset, including both data variables and
coordinates.
"""
return Frozen(self._variables)
return Frozen(self._manifest.variables)

@property
def attrs(self) -> Dict[Hashable, Any]:
Expand Down Expand Up @@ -1082,7 +1088,34 @@ def _construct_direct(
if dims is None:
dims = calculate_dimensions(variables)
obj = object.__new__(cls)
obj._variables = variables
obj._manifest = DataManifest(variables=variables)
obj._variables = obj._manifest.variables
obj._coord_names = coord_names
obj._dims = dims
obj._indexes = indexes
obj._attrs = attrs
obj._close = close
obj._encoding = encoding
return obj

@classmethod
def _construct_from_manifest(
cls,
manifest,
coord_names,
dims=None,
attrs=None,
indexes=None,
encoding=None,
close=None,
):
"""Creates a Dataset that is forced to be consistent with a DataTree node that shares its manifest."""
Comment on lines +1101 to +1112
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea is to use this when someone calls dt.ds - so that the DataTree returns an actual Dataset, but one whose contents are linked to the wrapping tree node.

variables = manifest.variables
if dims is None:
dims = calculate_dimensions(variables)
obj = object.__new__(cls)
obj._manifest = manifest
obj._variables = obj._manifest.variables
obj._coord_names = coord_names
obj._dims = dims
obj._indexes = indexes
Expand Down Expand Up @@ -1111,7 +1144,9 @@ def _replace(
"""
if inplace:
if variables is not None:
self._variables = variables
self._manifest.variables = variables
# TODO if ds._variables properly pointed to ds._manifest.variables we wouldn't need this line
self._variables = self._manifest.variables
Comment on lines +1147 to +1149
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I wanted was for self._variables to always point to self._manifest.variables, such that updating the former by definition updates the latter. But I'm not really sure if that kind of pointer-like behaviour is possible.

if coord_names is not None:
self._coord_names = coord_names
if dims is not None:
Expand Down Expand Up @@ -1629,7 +1664,7 @@ def _setitem_check(self, key, value):

def __delitem__(self, key: Hashable) -> None:
"""Remove a variable from this dataset."""
del self._variables[key]
del self._manifest[key]
self._coord_names.discard(key)
if key in self.xindexes:
assert self._indexes is not None
Expand Down
2 changes: 2 additions & 0 deletions xarray/tree/datatree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class DataTree:
    """Placeholder tree-node type; it defines no attributes or behaviour yet."""
74 changes: 74 additions & 0 deletions xarray/tree/manifest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from collections.abc import MutableMapping
from typing import Dict, Hashable, Mapping, Union

from xarray.core.variable import Variable
from xarray.tree.datatree import DataTree


class DataManifest(MutableMapping):
    """
    Stores variables and/or child tree nodes.

    When stored inside a DataTree node it prevents name collisions by acting as a common
    record of stored items for both the DataTree instance and its wrapped Dataset instance.

    When stored inside a lone Dataset it acts merely like a dictionary of Variables.

    Parameters
    ----------
    variables : dict, optional
        Mapping from name to Variable. The dict is stored by reference (it is
        not copied), so in-place changes made through the manifest are visible
        to whoever else holds the same dict. Defaults to a new empty dict.
    children : dict, optional
        Mapping from name to child DataTree node, also stored by reference.
        Defaults to a new empty dict.

    Raises
    ------
    KeyError
        If any key is present in both ``variables`` and ``children``.
    """

    def __init__(
        self,
        variables: Union[Dict[Hashable, "Variable"], None] = None,
        children: Union[Dict[Hashable, "DataTree"], None] = None,
    ):
        # Use None sentinels rather than ``{}`` defaults: a literal dict default
        # is created once and shared by every instance built without arguments,
        # so deleting from one manifest would silently alter all the others.
        if variables is None:
            variables = {}
        if children is None:
            children = {}

        if variables and children:
            keys_in_both = set(variables.keys()) & set(children.keys())
            if keys_in_both:
                raise KeyError(
                    f"The keys {keys_in_both} exist in both the variables and child nodes"
                )

        # Keep the caller's objects (no copy) so that holders of the same dict
        # stay in sync with the manifest.
        self._variables = variables
        self._children = children

    @property
    def variables(self) -> Mapping[Hashable, "Variable"]:
        """The stored mapping of names to variables (the same object, not a copy)."""
        return self._variables

    @variables.setter
    def variables(self, variables):
        """Replace the variables mapping wholesale.

        Raises
        ------
        KeyError
            If any key of ``variables`` is already used by a child node.
        """
        for key in variables:
            if key in self.children:
                raise KeyError(
                    f"Cannot set variable under name {key} because a child node "
                    "with that name already exists"
                )
        self._variables = variables

    @property
    def children(self) -> Mapping[Hashable, "DataTree"]:
        """The stored mapping of names to child nodes (the same object, not a copy)."""
        return self._children

    def __getitem__(self, key: Hashable) -> Union["Variable", "DataTree"]:
        """Return the item stored under ``key``, looking in variables first, then children."""
        if key in self._variables:
            return self._variables[key]
        elif key in self._children:
            return self._children[key]
        else:
            raise KeyError(f"{key} is not present")

    def __setitem__(self, key, value):
        # Item assignment is not supported yet.
        raise NotImplementedError

    def __delitem__(self, key: Hashable):
        """Delete ``key`` in place from whichever of variables/children holds it."""
        if key in self.variables:
            del self._variables[key]
        elif key in self.children:
            del self._children[key]
        else:
            raise KeyError(f"Cannot remove item because nothing is stored under {key}")

    def __iter__(self):
        # Iteration is not supported yet.
        raise NotImplementedError

    def __len__(self):
        # Length is not supported yet.
        raise NotImplementedError