class ThreadURIManager:
    """
    This handles opening & closing file handles in each thread.

    The handle is opened lazily the first time ``file_handle`` is read and
    is cached on the instance until ``close`` is called.

    Parameters
    ----------
    opener: callable
        Called as ``opener(*args, mode=mode, **kwargs)`` to open the handle.
    *args
        Positional arguments forwarded to ``opener``.
    mode: str, optional
        Forwarded to ``opener`` as the ``mode`` keyword. Default is "r".
    kwargs: dict, optional
        Extra keyword arguments forwarded to ``opener``. A shallow copy is
        stored, so later changes to the caller's dict are not seen.
    """

    def __init__(
        self,
        opener,
        *args,
        mode="r",
        kwargs=None,
    ):
        self._opener = opener
        self._args = args
        self._mode = mode
        self._kwargs = {} if kwargs is None else dict(kwargs)
        self._file_handle = None

    @property
    def file_handle(self):
        """
        File handle returned by the opener.

        Opened on first access; later accesses return the cached handle.
        """
        if self._file_handle is None:
            self._file_handle = self._opener(
                *self._args, mode=self._mode, **self._kwargs
            )
        return self._file_handle

    def close(self):
        """
        Close file handle.

        Does nothing when no handle is open. Any exception raised by the
        handle's own ``close`` propagates to the caller.
        """
        # getattr: __del__ also runs on an instance whose __init__ raised
        # before ``_file_handle`` was assigned.
        file_handle = getattr(self, "_file_handle", None)
        if file_handle is None:
            return
        # Forget the handle before closing it so that a failing close() is
        # not attempted a second time from __del__.
        self._file_handle = None
        file_handle.close()

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, type_, value, traceback):
        self.close()
def test_uri_manager_mock_write__threaded():
    """Five threaded writes each open a handle, write to it and close it."""
    n_tasks = 5
    mock_file = mock.Mock()
    opener = mock.Mock(spec=open, return_value=mock_file)
    manager = URIManager(opener, "filename")

    def write(_index):
        manager.acquire().write("contents")
        # Discard this thread's cached manager so the next task has to open
        # a new handle.
        manager._local.thread_manager = None

    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        # Drain the result iterator so an exception raised in a worker
        # surfaces here.
        for _ in executor.map(write, range(n_tasks)):
            pass

    gc.collect()

    opener.assert_has_calls([mock.call("filename", mode="r")] * n_tasks)
    mock_file.write.assert_has_calls([mock.call("contents")] * n_tasks)
    mock_file.close.assert_has_calls([mock.call()] * n_tasks)
def test_uri_manager_acquire_context(tmpdir):
    """Check ``acquire_context`` across exceptions raised inside the block.

    A file containing "foobar" is written, then read three times through
    ``manager.acquire_context()``: once with an exception raised in the
    block, once cleanly, and once more with an exception after an explicit
    ``seek(0)``.
    """
    path = str(tmpdir.join("testing.txt"))

    with open(path, "w") as f:
        f.write("foobar")

    # Local exception type so pytest.raises only matches the error raised
    # deliberately inside the context blocks below.
    class AcquisitionError(Exception):
        pass

    manager = URIManager(open, path)
    with pytest.raises(AcquisitionError):
        with manager.acquire_context() as f:
            assert f.read() == "foobar"
            raise AcquisitionError

    # NOTE(review): this read has no seek(0), so it expects a handle that is
    # positioned at the start of the file. That only holds if the handle used
    # above was closed and re-opened after the exception - confirm that
    # acquire_context closes the handle on error.
    with manager.acquire_context() as f:
        assert f.read() == "foobar"

    with pytest.raises(AcquisitionError):
        with manager.acquire_context() as f:
            # Rewind explicitly: the handle may already have been read to EOF.
            f.seek(0)
            assert f.read() == "foobar"
            raise AcquisitionError
    manager.close()