diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 9dfd64e9c99..0ad6b0a2abc 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -284,6 +284,7 @@ def _update_coords( if dim in variables: new_coord_names.add(dim) + self._data._manifest.variables = variables self._data._variables = variables self._data._coord_names.update(new_coord_names) self._data._dims = dims diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index e882495dce5..60b640ca448 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -37,6 +37,7 @@ from ..coding.cftimeindex import _parse_array_of_cftime_strings from ..plot.dataset_plot import _Dataset_PlotMethods +from ..tree.manifest import DataManifest from . import ( alignment, dtypes, @@ -705,6 +706,7 @@ class Dataset(DataWithCoords, DatasetArithmetic, Mapping): _close: Optional[Callable[[], None]] _indexes: Optional[Dict[Hashable, Index]] _variables: Dict[Hashable, Variable] + _manifest: DataManifest __slots__ = ( "_attrs", @@ -715,6 +717,7 @@ class Dataset(DataWithCoords, DatasetArithmetic, Mapping): "_close", "_indexes", "_variables", + "_manifest", "__weakref__", ) @@ -752,10 +755,13 @@ def __init__( data_vars, coords, compat="broadcast_equals" ) + self._manifest = DataManifest(variables=variables) + + # The private attributes that effectively define the data model self._attrs = dict(attrs) if attrs is not None else None self._close = None self._encoding = None - self._variables = variables + self._variables = self._manifest.variables self._coord_names = coord_names self._dims = dims self._indexes = indexes @@ -781,7 +787,7 @@ def variables(self) -> Mapping[Hashable, Variable]: constituting the Dataset, including both data variables and coordinates. """ - return Frozen(self._variables) + return Frozen(self._manifest.variables) @property def attrs(self) -> Dict[Hashable, Any]: @@ -1082,7 +1088,34 @@ def _construct_direct( if dims is None: dims = calculate_dimensions(variables) obj = object.__new__(cls) - obj._variables = variables + obj._manifest = DataManifest(variables=variables) + obj._variables = obj._manifest.variables + obj._coord_names = coord_names + obj._dims = dims + obj._indexes = indexes + obj._attrs = attrs + obj._close = close + obj._encoding = encoding + return obj + + @classmethod + def _construct_from_manifest( + cls, + manifest, + coord_names, + dims=None, + attrs=None, + indexes=None, + encoding=None, + close=None, + ): + """Creates a Dataset that is forced to be consistent with a DataTree node that shares its manifest.""" + variables = manifest.variables + if dims is None: + dims = calculate_dimensions(variables) + obj = object.__new__(cls) + obj._manifest = manifest + obj._variables = obj._manifest.variables obj._coord_names = coord_names obj._dims = dims obj._indexes = indexes @@ -1111,7 +1144,9 @@ def _replace( """ if inplace: if variables is not None: - self._variables = variables + self._manifest.variables = variables + # TODO if ds._variables properly pointed to ds._manifest.variables we wouldn't need this line + self._variables = self._manifest.variables if coord_names is not None: self._coord_names = coord_names if dims is not None: @@ -1629,7 +1664,7 @@ def _setitem_check(self, key, value): def __delitem__(self, key: Hashable) -> None: """Remove a variable from this dataset.""" - del self._variables[key] + del self._manifest[key] self._coord_names.discard(key) if key in self.xindexes: assert self._indexes is not None diff --git a/xarray/tree/datatree.py b/xarray/tree/datatree.py new file mode 100644 index 00000000000..2d553748d9e --- /dev/null +++ b/xarray/tree/datatree.py @@ -0,0 +1,2 @@ +class DataTree: + ... diff --git a/xarray/tree/manifest.py b/xarray/tree/manifest.py new file mode 100644 index 00000000000..9dae00557b0 --- /dev/null +++ b/xarray/tree/manifest.py @@ -0,0 +1,74 @@ +from collections.abc import MutableMapping +from typing import Dict, Hashable, Mapping, Union + +from xarray.core.variable import Variable +from xarray.tree.datatree import DataTree + + +class DataManifest(MutableMapping): + """ + Stores variables and/or child tree nodes. + + When stored inside a DataTree node it prevents name collisions by acting as a common + record of stored items for both the DataTree instance and its wrapped Dataset instance. + + When stored inside a lone Dataset it acts merely like a dictionary of Variables. + """ + + def __init__( + self, + variables: Dict[Hashable, Variable] = {}, + children: Dict[Hashable, DataTree] = {}, + ): + if variables and children: + keys_in_both = set(variables.keys()) & set(children.keys()) + if keys_in_both: + raise KeyError( + f"The keys {keys_in_both} exist in both the variables and child nodes" + ) + + self._variables = variables + self._children = children + + @property + def variables(self) -> Mapping[Hashable, Variable]: + return self._variables + + @variables.setter + def variables(self, variables): + for key in variables: + if key in self.children: + raise KeyError( + f"Cannot set variable under name {key} because a child node " + "with that name already exists" + ) + self._variables = variables + + @property + def children(self) -> Mapping[Hashable, DataTree]: + return self._children + + def __getitem__(self, key: Hashable) -> Union[Variable, DataTree]: + if key in self._variables: + return self._variables[key] + elif key in self._children: + return self._children[key] + else: + raise KeyError(f"{key} is not present") + + def __setitem__(self, key, value): + raise NotImplementedError + + def __delitem__(self, key: Hashable): + if key in self.variables: + del self._variables[key] + elif key in self.children: + del self._children[key] + else: + raise KeyError(f"Cannot remove item because nothing is stored under {key}") + + def __iter__(self): + raise NotImplementedError + + def __len__(self): + raise NotImplementedError