Skip to content

Commit

Permalink
BUG: Close file handle with lock=False
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 committed May 24, 2021
1 parent b7d5999 commit 236f049
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ History
Latest
------
- BUG: pass kwargs with lock=False (issue #344)
- BUG: Close file handle with lock=False (pull #346)

0.4.0
------
Expand Down
73 changes: 72 additions & 1 deletion rioxarray/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import contextlib
import os
import re
import threading
import warnings
from distutils.version import LooseVersion

Expand All @@ -35,6 +36,61 @@
NO_LOCK = contextlib.nullcontext()


class FileHandleLocal(threading.local):
"""
This contains the thread local ThreadURIManager
"""

def __init__(self): # pylint: disable=super-init-not-called
self.thread_manager = None # Initialises in each thread


class ThreadURIManager:
"""
This handles opening & closing file handles in each thread.
"""

def __init__(
self,
opener,
*args,
mode="r",
kwargs=None,
):
self._opener = opener
self._args = args
self._mode = mode
self._kwargs = {} if kwargs is None else dict(kwargs)
self._file_handle = None

@property
def file_handle(self):
"""
File handle returned by the opener.
"""
if self._file_handle is not None:
return self._file_handle
self._file_handle = self._opener(*self._args, mode=self._mode, **self._kwargs)
return self._file_handle

def close(self):
"""
Close file handle.
"""
if self._file_handle is not None:
self._file_handle.close()
self._file_handle = None

def __del__(self):
self.close()

def __enter__(self):
return self

def __exit__(self, type_, value, traceback):
self.close()


class URIManager(FileManager):
"""
The URI manager is used for lockless reading
Expand All @@ -51,16 +107,31 @@ def __init__(
self._args = args
self._mode = mode
self._kwargs = {} if kwargs is None else dict(kwargs)
self._local = FileHandleLocal()

def acquire(self, needs_lock=True):
return self._opener(*self._args, mode=self._mode, **self._kwargs)
if self._local.thread_manager is None:
self._local.thread_manager = ThreadURIManager(
self._opener, *self._args, mode=self._mode, kwargs=self._kwargs
)
return self._local.thread_manager.file_handle

@contextlib.contextmanager
def acquire_context(self, needs_lock=True):
yield self.acquire(needs_lock=needs_lock)

def close(self, needs_lock=True):
pass

def __getstate__(self):
"""State for pickling."""
return (self._opener, self._args, self._mode, self._kwargs)

def __setstate__(self, state):
"""Restore from a pickle."""
opener, args, mode, kwargs = state
self.__init__(opener, *args, mode=mode, kwargs=kwargs)


class RasterioArrayWrapper(BackendArray):
"""A wrapper around rasterio dataset objects"""
Expand Down
4 changes: 2 additions & 2 deletions test/integration/test_integration__io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,14 +1009,14 @@ def test_nc_attr_loading__disable_decode_times(open_rasterio):

def test_lockless():
with rioxarray.open_rasterio(
os.path.join(TEST_INPUT_DATA_DIR, "PLANET_SCOPE_3D.nc"), lock=False, chunks=True
os.path.join(TEST_INPUT_DATA_DIR, "cog.tif"), lock=False, chunks=True
) as rds:
rds.mean().compute()


def test_lock_true():
with rioxarray.open_rasterio(
os.path.join(TEST_INPUT_DATA_DIR, "PLANET_SCOPE_3D.nc"), lock=True, chunks=True
os.path.join(TEST_INPUT_DATA_DIR, "cog.tif"), lock=True, chunks=True
) as rds:
rds.mean().compute()

Expand Down
133 changes: 133 additions & 0 deletions test/integration/test_integration__io_uri_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import concurrent.futures
import gc
import pickle
from unittest import mock

import pytest

from rioxarray._io import URIManager


def test_uri_manager_mock_write():
mock_file = mock.Mock()
opener = mock.Mock(spec=open, return_value=mock_file)

manager = URIManager(opener, "filename")
f = manager.acquire()
f.write("contents")
manager.close()

opener.assert_called_once_with("filename", mode="r")
mock_file.write.assert_called_once_with("contents")
mock_file.close.assert_called_once_with()


def test_uri_manager_mock_write__threaded():
mock_file = mock.Mock()
opener = mock.Mock(spec=open, return_value=mock_file)

manager = URIManager(opener, "filename")

def write(iter):
nonlocal manager
fh = manager.acquire()
fh.write("contents")
manager._local.thread_manager = None

with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
for result in executor.map(write, range(5)):
pass

gc.collect()

opener.assert_has_calls([mock.call("filename", mode="r") for _ in range(5)])
mock_file.write.assert_has_calls([mock.call("contents") for _ in range(5)])
mock_file.close.assert_has_calls([mock.call() for _ in range(5)])


@pytest.mark.parametrize("expected_warning", [None, RuntimeWarning])
def test_uri_manager_autoclose(expected_warning):
mock_file = mock.Mock()
opener = mock.Mock(return_value=mock_file)

manager = URIManager(opener, "filename")
manager.acquire()

del manager
gc.collect()

mock_file.close.assert_called_once_with()


def test_uri_manager_write_concurrent(tmpdir):
path = str(tmpdir.join("testing.txt"))
manager = URIManager(open, path, mode="w")
f1 = manager.acquire()
f2 = manager.acquire()
f3 = manager.acquire()
assert f1 is f2
assert f2 is f3
f1.write("foo")
f1.flush()
f2.write("bar")
f2.flush()
f3.write("baz")
f3.flush()
manager.close()

with open(path) as f:
assert f.read() == "foobarbaz"


def test_uri_manager_write_pickle(tmpdir):
path = str(tmpdir.join("testing.txt"))
manager = URIManager(open, path, mode="a")
f = manager.acquire()
f.write("foo")
f.flush()
manager2 = pickle.loads(pickle.dumps(manager))
f2 = manager2.acquire()
f2.write("bar")
manager2.close()
manager.close()

with open(path) as f:
assert f.read() == "foobar"


def test_uri_manager_read(tmpdir):
path = str(tmpdir.join("testing.txt"))

with open(path, "w") as f:
f.write("foobar")

manager = URIManager(open, path)
f = manager.acquire()
assert f.read() == "foobar"
manager.close()


def test_uri_manager_acquire_context(tmpdir):
path = str(tmpdir.join("testing.txt"))

with open(path, "w") as f:
f.write("foobar")

class AcquisitionError(Exception):
pass

manager = URIManager(open, path)
with pytest.raises(AcquisitionError):
with manager.acquire_context() as f:
assert f.read() == "foobar"
raise AcquisitionError

with manager.acquire_context() as f:
assert f.read() == "foobar"

with pytest.raises(AcquisitionError):
with manager.acquire_context() as f:
f.seek(0)
assert f.read() == "foobar"
raise AcquisitionError
manager.close()

0 comments on commit 236f049

Please sign in to comment.