Remove the references to _file_obj outside low level code paths, change to `_close` (#4809)

* Move from _file_obj object to _close function

* Remove all references to _close outside of low level

* Fix type hints

* Cleanup code style

* Fix non-trivial type hint problem

* Revert adding the `close` argument and add a set_close instead

* Replace the helper class with a simpler helper function + code style

* Add set_close docstring

* Code style

* Revert changes in _replace to keep close as an exception

See: https://github.com/pydata/xarray/pull/4809/files#r557628298

* One more bit to revert

* One more bit to revert

* Add What's New entry

* Use set_close setter

* Apply suggestions from code review

Co-authored-by: Stephan Hoyer <shoyer@google.com>

* Rename user-visible argument

* Sync wording in docstrings.

Co-authored-by: Stephan Hoyer <shoyer@google.com>
alexamici and shoyer authored Jan 18, 2021
1 parent a2b1712 commit 2a43385
Showing 9 changed files with 53 additions and 38 deletions.
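
The gist of the change: instead of stashing a file-like object on every Dataset and DataArray and calling its `.close()` method, xarray now stores an optional zero-argument callable and invokes it on `close()`. A self-contained sketch of the new pattern (the `DataWithClose` class below is illustrative only, not xarray code):

    from typing import Callable, Optional

    class DataWithClose:
        """Toy stand-in for the resource handling on DataWithCoords."""

        def __init__(self) -> None:
            self._close: Optional[Callable[[], None]] = None

        def set_close(self, close: Optional[Callable[[], None]]) -> None:
            # Register the function that releases this object's resources.
            self._close = close

        def close(self) -> None:
            # Invoke the stored closer once, then drop the reference.
            if self._close is not None:
                self._close()
            self._close = None

    obj = DataWithClose()
    obj.set_close(lambda: print("resources released"))
    obj.close()  # prints once; a second close() is a no-op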
2 changes: 2 additions & 0 deletions doc/whats-new.rst
@@ -108,6 +108,8 @@ Internal Changes
By `Maximilian Roos <https://github.com/max-sixty>`_.
- Speed up attribute style access (e.g. ``ds.somevar`` instead of ``ds["somevar"]``) and tab completion
in ipython (:issue:`4741`, :pull:`4742`). By `Richard Kleijn <https://github.com/rhkleijn>`_.
+- Added the ``set_close`` method to ``Dataset`` and ``DataArray`` for backends to specify how to voluntarily
+  release all resources (:pull:`4809`). By `Alessandro Amici <https://github.com/alexamici>`_.

.. _whats-new.0.16.2:

25 changes: 9 additions & 16 deletions xarray/backends/api.py
@@ -522,7 +522,7 @@ def maybe_decode_store(store, chunks):

else:
ds2 = ds
-ds2._file_obj = ds._file_obj
+ds2.set_close(ds._close)
return ds2

filename_or_obj = _normalize_path(filename_or_obj)
@@ -701,7 +701,7 @@ def open_dataarray(
else:
(data_array,) = dataset.data_vars.values()

-data_array._file_obj = dataset._file_obj
+data_array.set_close(dataset._close)

# Reset names if they were changed during saving
# to ensure that we can 'roundtrip' perfectly
@@ -715,17 +715,6 @@
return data_array


-class _MultiFileCloser:
-    __slots__ = ("file_objs",)
-
-    def __init__(self, file_objs):
-        self.file_objs = file_objs
-
-    def close(self):
-        for f in self.file_objs:
-            f.close()


def open_mfdataset(
paths,
chunks=None,
@@ -918,14 +907,14 @@ def open_mfdataset(
getattr_ = getattr

datasets = [open_(p, **open_kwargs) for p in paths]
-file_objs = [getattr_(ds, "_file_obj") for ds in datasets]
+closers = [getattr_(ds, "_close") for ds in datasets]
if preprocess is not None:
datasets = [preprocess(ds) for ds in datasets]

if parallel:
# calling compute here will return the datasets/file_objs lists,
# the underlying datasets will still be stored as dask arrays
-datasets, file_objs = dask.compute(datasets, file_objs)
+datasets, closers = dask.compute(datasets, closers)

# Combine all datasets, closing them in case of a ValueError
try:
@@ -963,7 +952,11 @@
ds.close()
raise

-combined._file_obj = _MultiFileCloser(file_objs)
+def multi_file_closer():
+    for closer in closers:
+        closer()
+
+combined.set_close(multi_file_closer)

# read global attributes from the attrs_file or from the first dataset
if attrs_file is not None:
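With one closer collected per input file, the `_MultiFileCloser` helper class is no longer needed: a plain closure over the list of closers does the same job, as the new `multi_file_closer` above shows. A minimal sketch of that composition, with stand-in lambdas for the per-dataset `_close` callables:

    # Each lambda stands in for one dataset's `_close` callable.
    closers = [lambda: print("closing file 1"), lambda: print("closing file 2")]

    def multi_file_closer():
        # Release every underlying file when the combined dataset is closed.
        for closer in closers:
            closer()

    multi_file_closer()  # prints both messages, in order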
2 changes: 1 addition & 1 deletion xarray/backends/apiv2.py
@@ -90,7 +90,7 @@ def _dataset_from_backend_dataset(
**extra_tokens,
)

-ds._file_obj = backend_ds._file_obj
+ds.set_close(backend_ds._close)

# Ensure source filename always stored in dataset object (GH issue #2550)
if "source" not in ds.encoding:
2 changes: 1 addition & 1 deletion xarray/backends/rasterio_.py
@@ -361,6 +361,6 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc
result = result.chunk(chunks, name_prefix=name_prefix, token=token)

# Make the file closeable
-result._file_obj = manager
+result.set_close(manager.close)

return result
3 changes: 1 addition & 2 deletions xarray/backends/store.py
@@ -19,7 +19,6 @@ def open_backend_dataset_store(
decode_timedelta=None,
):
vars, attrs = store.load()
-file_obj = store
encoding = store.get_encoding()

vars, attrs, coord_names = conventions.decode_cf_variables(
@@ -36,7 +35,7 @@

ds = Dataset(vars, attrs=attrs)
ds = ds.set_coords(coord_names.intersection(vars))
-ds._file_obj = file_obj
+ds.set_close(store.close)
ds.encoding = encoding

return ds
6 changes: 3 additions & 3 deletions xarray/conventions.py
@@ -576,12 +576,12 @@ def decode_cf(
vars = obj._variables
attrs = obj.attrs
extra_coords = set(obj.coords)
-file_obj = obj._file_obj
+close = obj._close
encoding = obj.encoding
elif isinstance(obj, AbstractDataStore):
vars, attrs = obj.load()
extra_coords = set()
-file_obj = obj
+close = obj.close
encoding = obj.get_encoding()
else:
raise TypeError("can only decode Dataset or DataStore objects")
@@ -599,7 +599,7 @@
)
ds = Dataset(vars, attrs=attrs)
ds = ds.set_coords(coord_names.union(extra_coords).intersection(vars))
-ds._file_obj = file_obj
+ds.set_close(close)
ds.encoding = encoding

return ds
29 changes: 24 additions & 5 deletions xarray/core/common.py
@@ -11,6 +11,7 @@
Iterator,
List,
Mapping,
+Optional,
Tuple,
TypeVar,
Union,
@@ -330,7 +331,9 @@ def get_squeeze_dims(
class DataWithCoords(SupportsArithmetic, AttrAccessMixin):
"""Shared base class for Dataset and DataArray."""

-__slots__ = ()
+_close: Optional[Callable[[], None]]
+
+__slots__ = ("_close",)

_rolling_exp_cls = RollingExp

@@ -1263,11 +1266,27 @@ def where(self, cond, other=dtypes.NA, drop: bool = False):

return ops.where_method(self, cond, other)

+def set_close(self, close: Optional[Callable[[], None]]) -> None:
+    """Register the function that releases any resources linked to this object.
+
+    This method controls how xarray cleans up resources associated
+    with this object when the ``.close()`` method is called. It is mostly
+    intended for backend developers and it is rarely needed by regular
+    end-users.
+
+    Parameters
+    ----------
+    close : callable
+        The function that when called like ``close()`` releases
+        any resources linked to this object.
+    """
+    self._close = close

def close(self: Any) -> None:
"""Close any files linked to this object"""
if self._file_obj is not None:
self._file_obj.close()
self._file_obj = None
"""Release any resources linked to this object."""
if self._close is not None:
self._close()
self._close = None

def isnull(self, keep_attrs: bool = None):
"""Test each value in the array for whether it is a missing value.
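For backend authors, the new public hook reads like this (a sketch assuming a build of xarray that includes this PR; the temporary file is only a stand-in for a real backend resource):

    import os
    import tempfile

    import xarray as xr

    ds = xr.Dataset({"x": ("t", [1, 2, 3])})

    fd, path = tempfile.mkstemp()
    f = os.fdopen(fd, "w")  # some open resource the dataset should own

    ds.set_close(f.close)  # register how to release the resource
    ds.close()             # invokes f.close() exactly once
    assert f.closed

    os.remove(path)  # clean up the stand-in file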
5 changes: 3 additions & 2 deletions xarray/core/dataarray.py
@@ -344,14 +344,15 @@ class DataArray(AbstractArray, DataWithCoords):

_cache: Dict[str, Any]
_coords: Dict[Any, Variable]
+_close: Optional[Callable[[], None]]
_indexes: Optional[Dict[Hashable, pd.Index]]
_name: Optional[Hashable]
_variable: Variable

__slots__ = (
"_cache",
"_coords",
"_file_obj",
"_close",
"_indexes",
"_name",
"_variable",
@@ -421,7 +422,7 @@ def __init__(
# public interface.
self._indexes = indexes

-self._file_obj = None
+self._close = None

def _replace(
self,
17 changes: 9 additions & 8 deletions xarray/core/dataset.py
@@ -636,6 +636,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords):
_coord_names: Set[Hashable]
_dims: Dict[Hashable, int]
_encoding: Optional[Dict[Hashable, Any]]
+_close: Optional[Callable[[], None]]
_indexes: Optional[Dict[Hashable, pd.Index]]
_variables: Dict[Hashable, Variable]

@@ -645,7 +646,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords):
"_coord_names",
"_dims",
"_encoding",
"_file_obj",
"_close",
"_indexes",
"_variables",
"__weakref__",
@@ -687,7 +688,7 @@ def __init__(
)

self._attrs = dict(attrs) if attrs is not None else None
-self._file_obj = None
+self._close = None
self._encoding = None
self._variables = variables
self._coord_names = coord_names
@@ -703,7 +704,7 @@ def load_store(cls, store, decoder=None) -> "Dataset":
if decoder:
variables, attributes = decoder(variables, attributes)
obj = cls(variables, attrs=attributes)
-obj._file_obj = store
+obj.set_close(store.close)
return obj

@property
@@ -876,7 +877,7 @@ def __dask_postcompute__(self):
self._attrs,
self._indexes,
self._encoding,
-self._file_obj,
+self._close,
)
return self._dask_postcompute, args

@@ -896,7 +897,7 @@ def __dask_postpersist__(self):
self._attrs,
self._indexes,
self._encoding,
-self._file_obj,
+self._close,
)
return self._dask_postpersist, args

@@ -1007,7 +1008,7 @@ def _construct_direct(
attrs=None,
indexes=None,
encoding=None,
-file_obj=None,
+close=None,
):
"""Shortcut around __init__ for internal use when we want to skip
costly validation
@@ -1020,7 +1021,7 @@
obj._dims = dims
obj._indexes = indexes
obj._attrs = attrs
-obj._file_obj = file_obj
+obj._close = close
obj._encoding = encoding
return obj

@@ -2122,7 +2123,7 @@ def isel(
attrs=self._attrs,
indexes=indexes,
encoding=self._encoding,
-file_obj=self._file_obj,
+close=self._close,
)

def _isel_fancy(
