Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pickle test #27

Merged
merged 16 commits into from
Nov 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/nd2/_sdk/latest.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ cdef class ND2Reader:
@property
def attributes(self) -> structures.Attributes:
if not hasattr(self, '__attributes'):
if not self._is_open:
raise ValueError("Attempt to get attributes from closed nd2 file")
cont = self._metadata().get('contents')
attrs = self._attributes()
nC = cont.get('channelCount') if cont else attrs.get("componentCount", 1)
Expand All @@ -88,21 +90,32 @@ cdef class ND2Reader:
return (1, 1, 1)

def _metadata(self) -> dict:
    """Fetch and parse the file-level metadata from the SDK.

    Returns
    -------
    dict
        The metadata decoded by ``_loads``.

    Raises
    ------
    ValueError
        If the underlying nd2 file has already been closed.
    """
    if self._is_open:
        return _loads(Lim_FileGetMetadata(self._fh))
    raise ValueError("Attempt to get metadata from closed nd2 file")

def metadata(self) -> structures.Metadata:
    """Return the file-level metadata as a ``structures.Metadata`` object."""
    meta_dict = self._metadata()
    return structures.Metadata(**meta_dict)

def _frame_metadata(self, seq_index: int) -> dict:
    """Fetch and parse raw metadata for a single frame.

    Parameters
    ----------
    seq_index : int
        Linear (sequence) index of the frame to query.

    Raises
    ------
    ValueError
        If the underlying nd2 file has already been closed.
    """
    if self._is_open:
        return _loads(Lim_FileGetFrameMetadata(self._fh, seq_index))
    raise ValueError("Attempt to get frame_metadata from closed nd2 file")

def frame_metadata(self, seq_index: int) -> structures.FrameMetadata:
    """Return structured metadata for a single frame.

    Parameters
    ----------
    seq_index : int
        Linear (sequence) index of the frame to query.

    Returns
    -------
    structures.FrameMetadata

    Raises
    ------
    ValueError
        If the underlying nd2 file has already been closed (raised by
        ``_frame_metadata``).
    """
    # BUG FIX: the previous version called self._frame_metadata() with no
    # argument, but _frame_metadata requires seq_index, so every call raised
    # TypeError; accept and forward seq_index instead.  The return annotation
    # also said structures.Metadata while the body constructs FrameMetadata.
    return structures.FrameMetadata(**self._frame_metadata(seq_index))

def text_info(self) -> dict:
    """Fetch and parse the file's text-info block.

    Raises
    ------
    ValueError
        If the underlying nd2 file has already been closed.
    """
    if self._is_open:
        return _loads(Lim_FileGetTextinfo(self._fh))
    raise ValueError("Attempt to get text_info from closed nd2 file")

def _description(self) -> str:
    """Return the ``description`` field of text_info ('' when absent)."""
    info = self.text_info()
    return info.get("description", '')

def _experiment(self) -> list:
    """Fetch and parse the experiment (loop) description.

    Raises
    ------
    ValueError
        If the underlying nd2 file has already been closed.
    """
    if self._is_open:
        return _loads(Lim_FileGetExperiment(self._fh), list)
    raise ValueError("Attempt to get experiment from closed nd2 file")

def experiment(self) -> List[structures.ExpLoop]:
Expand All @@ -116,6 +129,8 @@ cdef class ND2Reader:
return Lim_FileGetCoordSize(self._fh)

def _seq_index_from_coords(self, coords: Sequence) -> int:
if not self._is_open:
raise ValueError("Attempt to seq_index from closed nd2 file")
cdef LIMSIZE size = self._coord_size()
if size == 0:
return -1
Expand Down
56 changes: 29 additions & 27 deletions src/nd2/nd2file.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,10 @@ def close(self) -> None:
self._rdr.close()
self._closed = True

def __getstate__(self):
state = self.__dict__.copy()
del state["_rdr"]
return state

def __setstate__(self, d):
self.__dict__ = d
self._rdr = get_reader(self._path)
if not self._closed:
self.open()
@property
def closed(self) -> bool:
    """Whether the file is closed."""
    # read-only view of the flag maintained by open()/close()
    return self._closed

def __enter__(self) -> ND2File:
self.open()
Expand All @@ -99,10 +93,18 @@ def __enter__(self) -> ND2File:
def __exit__(self, *_) -> None:
    # context-manager exit: always close the file; exception info is ignored
    # (exceptions are never suppressed — nothing is returned)
    self.close()

@property
def closed(self) -> bool:
"""Whether the file is closed."""
return self._closed
def __getstate__(self):
    """Return picklable state: ``__dict__`` minus the reader and lock.

    Neither the SDK reader handle nor an ``RLock`` can be pickled, so both
    are dropped here and rebuilt in ``__setstate__``.
    """
    state = dict(self.__dict__)
    state.pop("_rdr")
    state.pop("_lock")
    return state

def __setstate__(self, d):
    """Restore pickled state and rebuild the unpicklable members.

    ``get_reader`` re-opens the underlying file, so when the pickled
    instance had been closed, close the fresh reader again to match.
    """
    self.__dict__ = d
    self._lock = threading.RLock()
    self._rdr = get_reader(self._path)
    if self._closed:
        self._rdr.close()

@cached_property
def attributes(self) -> Attributes:
Expand Down Expand Up @@ -256,43 +258,43 @@ def __array__(self) -> np.ndarray:
"""array protocol"""
return self.asarray()

def to_dask(self, wrapper=True) -> da.Array:
    """Create dask array (delayed reader) representing image.

    Parameters
    ----------
    wrapper : bool, optional
        If True (the default), the returned object will be a thin subclass of
        a :class:`dask.array.Array` (an
        `nd2.resource_backed_array.ResourceBackedDaskArray`) that manages the
        opening and closing of this file when getting chunks via compute().
        If `wrapper` is `False`, then a pure `da.Array` will be returned.
        However, when that array is computed, it will incur a file open/close
        on *every* chunk that is read (in the `_dask_block` method). As such
        `wrapper` will generally be much faster, however, it *may* fail
        (i.e. result in segmentation faults) with certain dask schedulers.

    Returns
    -------
    da.Array
    """
    # NOTE(review): this span was reconstructed from a diff scrape that had
    # interleaved old ("opening_array") and new ("wrapper") lines; the merged
    # (new) side is kept.
    from dask.array import map_blocks

    from .resource_backed_array import ResourceBackedDaskArray

    chunks = [(1,) * x for x in self._coord_shape]
    chunks += [(x,) for x in self._frame_shape]
    dask_arr = map_blocks(
        self._dask_block,
        # the wrapper doesn't need a threading lock
        nullcontext() if wrapper else self._lock,
        chunks=chunks,
        dtype=self.dtype,
    )
    if wrapper:
        # this subtype allows the dask array to re-open the underlying
        # nd2 file on compute.
        return ResourceBackedDaskArray.from_array(dask_arr, self)
    return dask_arr

_NO_IDX = -1
Expand Down
44 changes: 34 additions & 10 deletions src/nd2/opening_dask_array.py → src/nd2/resource_backed_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _copy_doc(method):
return method


class OpeningDaskArray(da.Array):
class ResourceBackedDaskArray(da.Array):
_file_ctx: CheckableContext

def __new__(
    cls,
    dask,
    name,
    chunks,
    dtype=None,
    meta=None,
    shape=None,
    _file_ctx: CheckableContext = None,
):
    """Construct the array and attach the checkable file context.

    ``_file_ctx`` defaults to None only so the signature is pickle-friendly
    (see ``__reduce__``); passing None is a programming error, enforced by
    the assert below.
    """
    # NOTE(review): reconstructed from a diff scrape that interleaved the old
    # keyword-only signature with the new defaulted one; the merged side
    # (positional, ``= None`` default plus assert) is kept.
    arr = super().__new__(
        cls, dask, name, chunks, dtype=dtype, meta=meta, shape=shape
    )
    assert _file_ctx is not None
    arr._file_ctx = _file_ctx
    return arr

@classmethod
def from_array(cls, arr, ctx: CheckableContext) -> OpeningDaskArray:
def from_array(cls, arr, ctx: CheckableContext) -> ResourceBackedDaskArray:
"""Create an OpeningDaskArray with a checkable context.

`ctx` must be a context manager that opens/closes some underlying resource (like
a file), and has a `closed` attribute that returns the current state of the
resource. This subclass will take care of opening and closing the resource on
compute.
"""
if isinstance(arr, OpeningDaskArray):
if isinstance(arr, ResourceBackedDaskArray):
return arr
_a = arr if isinstance(arr, da.Array) else da.from_array(arr)
arr = cls(
Expand All @@ -85,27 +85,51 @@ def compute(self, **kwargs: Any) -> np.ndarray:

def __getitem__(self, index):
    """Index the array; the result is again resource-backed."""
    # The diff scrape left two conflicting return statements (old and new
    # class name); keep the merged (ResourceBackedDaskArray) version.
    return ResourceBackedDaskArray.from_array(
        super().__getitem__(index), self._file_ctx
    )

def __getattribute__(self, name: Any) -> Any:
    """Wrap public callable attributes so that methods like ``array.mean()``
    also return a ResourceBackedDaskArray (via ``_ArrayMethodProxy``)."""
    # The diff scrape left two conflicting membership tests (old and new
    # class name); keep the merged (ResourceBackedDaskArray) version.
    attr = object.__getattribute__(self, name)
    if (
        not name.startswith("_")
        and name not in ResourceBackedDaskArray.__dict__
        and callable(attr)
    ):
        return _ArrayMethodProxy(attr, self._file_ctx)
    return attr

def __array_function__(self, func, types, args, kwargs):
    """Obey NEP 18: dispatch as a plain da.Array, then re-wrap the result."""
    # The diff scrape left duplicated ``types =`` and return lines (old and
    # new class name); keep the merged (ResourceBackedDaskArray) version.
    types = tuple(da.Array if x is ResourceBackedDaskArray else x for x in types)
    arr = super().__array_function__(func, types, args, kwargs)
    if isinstance(arr, da.Array):
        return ResourceBackedDaskArray.from_array(arr, self._file_ctx)
    return arr

def __reduce__(self):
    """Support pickling: rebuild via the class constructor, then run
    ``__setstate__`` with an (intentionally empty) state dict."""
    ctor_args = (
        self.dask,
        self.name,
        self.chunks,
        self.dtype,
        None,
        None,
        self._file_ctx,
    )
    # the empty dict causes __setstate__ to be called during pickle.load,
    # allowing us to close the newly created file_ctx, preventing a leaked
    # handle
    return (ResourceBackedDaskArray, ctor_args, {})

def __setstate__(self, d):
    # Called during unpickling because __reduce__ supplies a (empty) state
    # dict.  The constructor that pickle.load just invoked re-opened the
    # underlying resource, so close it here to avoid a leaked handle; `d`
    # itself is deliberately ignored.
    # NOTE(review): assumes _file_ctx.__exit__ accepts zero arguments —
    # confirm against CheckableContext's definition.
    if not self._file_ctx.closed:
        self._file_ctx.__exit__()


class _ArrayMethodProxy:
"""Wraps method on a dask array and returns a OpeningDaskArray if the result of the
def __call__(self, *args: Any, **kwds: Any) -> Any:
    """Invoke the wrapped array method, opening the backing resource first
    if it is currently closed, and re-wrap any dask-array result."""
    # The diff scrape fused the def line with a hunk header and left two
    # conflicting return lines; keep the merged (ResourceBackedDaskArray)
    # version.
    with self._file_ctx if self._file_ctx.closed else nullcontext():
        result = self.method(*args, **kwds)
    if isinstance(result, da.Array):
        return ResourceBackedDaskArray.from_array(result, self._file_ctx)
    return result
16 changes: 8 additions & 8 deletions tests/test_dask_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
import numpy as np
import pytest
from nd2 import ND2File
from nd2.opening_dask_array import OpeningDaskArray
from nd2.resource_backed_array import ResourceBackedDaskArray


@pytest.mark.parametrize("leave_open", [True, False])
@pytest.mark.parametrize("wrapper", [True, False])
def test_nd2_dask_closed_wrapper(single_nd2, wrapper, leave_open):
    """to_dask must compute whether or not the file stays open; only
    wrapper=True yields the resource-backed subclass."""
    # NOTE(review): the diff scrape interleaved old (opening_array=) and new
    # (wrapper=) call lines; the merged side is kept.
    f = ND2File(single_nd2)
    arr = f.to_dask(wrapper=wrapper)
    if not leave_open:
        f.close()

    is_wrapped = isinstance(arr, ResourceBackedDaskArray)
    assert is_wrapped if wrapper else not is_wrapped
    assert isinstance(arr, da.Array)
    assert isinstance(arr.compute(), np.ndarray)
Expand All @@ -25,7 +25,7 @@ def test_nd2_dask_closed_wrapper(single_nd2, wrapper, leave_open):
def test_nd2_dask_einsum(single_nd2):
    """da.einsum on a wrapped array still computes to a numpy array."""
    # NOTE(review): the diff scrape left duplicate isinstance asserts (old
    # and new class name); the merged side is kept.  Body nesting under the
    # `with` was inferred — indentation was stripped by the scrape.
    with ND2File(single_nd2) as f:
        arr = f.to_dask()
        assert isinstance(arr, ResourceBackedDaskArray)
        assert arr.shape == (3, 2, 32, 32)
        reordered_dask = da.einsum("abcd->abcd", arr)
        assert isinstance(reordered_dask[:1, :1, :1, :1].compute(), np.ndarray)
Expand All @@ -34,9 +34,9 @@ def test_nd2_dask_einsum(single_nd2):
def test_nd2_dask_einsum_via_nep18(single_nd2):
    """np.einsum (NEP 18 dispatch) preserves the resource-backed subclass."""
    # NOTE(review): the diff scrape left duplicate isinstance asserts (old
    # and new class name); the merged side is kept.  Body nesting under the
    # `with` was inferred — indentation was stripped by the scrape.
    with ND2File(single_nd2) as f:
        arr = f.to_dask()
        assert isinstance(arr, ResourceBackedDaskArray)
        reordered_nep18 = np.einsum("abcd->abcd", arr)
        assert isinstance(reordered_nep18, ResourceBackedDaskArray)
        assert isinstance(reordered_nep18[:1, :1, :1, :1].compute(), np.ndarray)


Expand All @@ -50,9 +50,9 @@ def test_synthetic_dask_einsum_via_nep18():
def test_nd2_dask_einsum_via_nep18_small(single_nd2):
    """Slicing keeps the wrapper; np.einsum on the slice still computes."""
    # NOTE(review): the diff scrape left duplicate isinstance asserts (old
    # and new class name); the merged side is kept.  Body nesting under the
    # `with` was inferred — indentation was stripped by the scrape.
    with ND2File(single_nd2) as f:
        arr = f.to_dask()
        assert isinstance(arr, ResourceBackedDaskArray)
        arr = arr[:10, :10, :10, :10]
        assert isinstance(arr, ResourceBackedDaskArray)
        reordered_nep18 = np.einsum("abcd->abcd", arr)
        assert isinstance(reordered_nep18, da.Array)
        assert isinstance(reordered_nep18[:1, :1, :1, :1].compute(), np.ndarray)
Loading