Optional check for concurrent usage errors (#989)

Some driver objects (e.g., Sessions, Transactions, Result streams) are not safe
for concurrent use. By default, using them concurrently causes hard-to-interpret
errors or, in the worst case, wrong behavior.

To help find such bugs, the driver now detects whether the Python interpreter is
running in development mode and, if so, enables extra locking around those
objects. If they are used concurrently, an error is raised. Because the checks
are wired in when the driver's modules are loaded, leaving them disabled costs
only a one-time overhead at import.
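
The idea in the paragraph above can be sketched with plain threads. This is a minimal illustration, not the driver's implementation: `ConcurrentUseError`, `make_non_concurrent`, and `Stream` are hypothetical names, and a `threading.RLock` stands in for the driver's own lock types. The decorator is chosen once; when checks are disabled it returns the function unchanged, so later calls carry no extra cost.

```python
import threading
from functools import wraps


class ConcurrentUseError(RuntimeError):
    """Hypothetical error type; the driver's own class may differ."""


def make_non_concurrent(enabled: bool):
    """Pick the decorator once, e.g. when the module is loaded."""
    if not enabled:
        # Disabled: hand the function back untouched.
        return lambda f: f

    def decorator(f):
        @wraps(f)
        def inner(self, *args, **kwargs):
            # Non-blocking acquire: failure means another thread is inside.
            if not self._lock.acquire(blocking=False):
                raise ConcurrentUseError(
                    f"{f.__name__} was invoked concurrently"
                )
            try:
                return f(self, *args, **kwargs)
            finally:
                self._lock.release()
        return inner

    return decorator


non_concurrent = make_non_concurrent(enabled=True)


class Stream:
    def __init__(self):
        # Re-entrant, so checked methods may call each other in one thread.
        self._lock = threading.RLock()

    @non_concurrent
    def hold(self, entered: threading.Event, leave: threading.Event):
        entered.set()
        leave.wait(timeout=5)

    @non_concurrent
    def ping(self):
        return "pong"


def demo():
    stream = Stream()
    entered, leave = threading.Event(), threading.Event()
    worker = threading.Thread(target=stream.hold, args=(entered, leave))
    worker.start()
    entered.wait(timeout=5)
    try:
        stream.ping()  # main thread enters while the worker is still inside
        outcome = "no error"
    except ConcurrentUseError:
        outcome = "error"
    finally:
        leave.set()
        worker.join()
    # Once the worker has left, the object is usable again.
    return outcome, stream.ping()
```

Calling `demo()` triggers the error on the overlapping call and succeeds on the later, non-overlapping one.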

These checks are somewhat expensive because they entail locks (less so in
the async driver). Therefore, the checks only run if either
 * Python is started in development mode (`python -X dev ...`) or
 * the environment variable `PYTHONNEO4JDEBUG` is set (to anything non-empty)
   at the time the driver's modules are loaded.
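
The two conditions above boil down to one boolean that is evaluated when the module is loaded. A small sketch, assuming a hypothetical helper `debug_enabled` whose parameters exist only so the logic can be tried without restarting the interpreter:

```python
import os
import sys
from typing import Mapping, Optional


def debug_enabled(environ: Optional[Mapping[str, str]] = None,
                  dev_mode: Optional[bool] = None) -> bool:
    """Return True if either condition from the list above holds."""
    env = os.environ if environ is None else environ
    dev = sys.flags.dev_mode if dev_mode is None else dev_mode
    # An empty string counts as "not set".
    return bool(dev) or bool(env.get("PYTHONNEO4JDEBUG"))
```

Since the value is read at load time, setting the environment variable after the driver has been imported would not switch the checks on.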
robsdedude authored Nov 27, 2023
1 parent f722f65 commit d130b9c
Showing 15 changed files with 467 additions and 24 deletions.
19 changes: 19 additions & 0 deletions docs/source/index.rst
@@ -99,6 +99,25 @@ To deactivate the current active virtual environment, use:
     deactivate
 
 
+Development Environment
+=======================
+
+For development, we recommend to run Python in `development mode`_ (``python -X dev ...``).
+Specifically for this driver, this will:
+
+* enable :class:`ResourceWarning`, which the driver emits if resources (e.g., Sessions) aren't properly closed.
+* enable :class:`DeprecationWarning`, which the driver emits if deprecated APIs are used.
+* enable the driver's debug mode (this can also be achieved by setting the environment variable ``PYTHONNEO4JDEBUG``):
+
+  * **This is experimental**.
+    It might be changed or removed any time even without prior notice.
+  * the driver will raise an exception if non-concurrency-safe methods are used concurrently.
+
+  .. versionadded:: 5.15
+
+.. _development mode: https://docs.python.org/3/library/devmode.html
+
+
 *************
 Quick Example
 *************
19 changes: 5 additions & 14 deletions docs/source/themes/neo4j/static/css/neo4j.css_t
@@ -503,25 +503,16 @@ dl.field-list > dd > ol {
     margin-left: 0;
 }
 
-ol.simple p, ul.simple p {
-    margin-bottom: 0;
-}
-
-ol.simple > li:not(:first-child) > p,
-ul.simple > li:not(:first-child) > p,
-:not(li) > ol > li:first-child > :first-child,
-:not(li) > ul > li:first-child > :first-child {
+.content ol li > p:first-of-type,
+.content ul li > p:first-of-type {
     margin-top: 0;
 }
 
 
-li > p:last-child {
-    margin-top: 10px;
+.content ol li > p:last-of-type,
+.content ul li > p:last-of-type {
+    margin-bottom: 0;
 }
 
-li > p:first-child {
-    margin-top: 10px;
-}
-
 table.docutils {
     margin-top: 10px;
20 changes: 20 additions & 0 deletions src/neo4j/_async/_debug/__init__.py
@@ -0,0 +1,20 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ._concurrency_check import AsyncNonConcurrentMethodChecker


__all__ = ["AsyncNonConcurrentMethodChecker"]
152 changes: 152 additions & 0 deletions src/neo4j/_async/_debug/_concurrency_check.py
@@ -0,0 +1,152 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

import inspect
import os
import sys
import traceback
import typing as t
from copy import deepcopy
from functools import wraps

from ..._async_compat.concurrency import (
    AsyncLock,
    AsyncRLock,
)
from ..._async_compat.util import AsyncUtil
from ..._meta import copy_signature


_TWrapped = t.TypeVar("_TWrapped", bound=t.Callable[..., t.Awaitable[t.Any]])
_TWrappedIter = t.TypeVar("_TWrappedIter",
                          bound=t.Callable[..., t.AsyncIterator])


ENABLED = sys.flags.dev_mode or bool(os.getenv("PYTHONNEO4JDEBUG"))


class NonConcurrentMethodError(RuntimeError):
    pass


class AsyncNonConcurrentMethodChecker:
    if ENABLED:

        def __init__(self):
            self.__lock = AsyncRLock()
            self.__tracebacks_lock = AsyncLock()
            self.__tracebacks = []

        def __make_error(self, tbs):
            msg = (f"Methods of {self.__class__} are not concurrency "
                   "safe, but were invoked concurrently.")
            if tbs:
                msg += ("\n\nOther invocation site:\n\n"
                        f"{''.join(traceback.format_list(tbs[0]))}")
            return NonConcurrentMethodError(msg)

        @classmethod
        def non_concurrent_method(cls, f: _TWrapped) -> _TWrapped:
            if AsyncUtil.is_async_code:
                if not inspect.iscoroutinefunction(f):
                    raise TypeError(
                        "cannot decorate non-coroutine function with "
                        "AsyncNonConcurrentMethodChecked.non_concurrent_method"
                    )
            else:
                if not callable(f):
                    raise TypeError(
                        "cannot decorate non-callable object with "
                        "NonConcurrentMethodChecked.non_concurrent_method"
                    )

            @wraps(f)
            @copy_signature(f)
            async def inner(*args, **kwargs):
                self = args[0]
                assert isinstance(self, cls)

                async with self.__tracebacks_lock:
                    acquired = await self.__lock.acquire(blocking=False)
                    if acquired:
                        self.__tracebacks.append(AsyncUtil.extract_stack())
                    else:
                        tbs = deepcopy(self.__tracebacks)
                if acquired:
                    try:
                        return await f(*args, **kwargs)
                    finally:
                        async with self.__tracebacks_lock:
                            self.__tracebacks.pop()
                            self.__lock.release()
                else:
                    raise self.__make_error(tbs)

            return inner

        @classmethod
        def non_concurrent_iter(cls, f: _TWrappedIter) -> _TWrappedIter:
            if AsyncUtil.is_async_code:
                if not inspect.isasyncgenfunction(f):
                    raise TypeError(
                        "cannot decorate non-async-generator function with "
                        "AsyncNonConcurrentMethodChecked.non_concurrent_iter"
                    )
            else:
                if not inspect.isgeneratorfunction(f):
                    raise TypeError(
                        "cannot decorate non-generator function with "
                        "NonConcurrentMethodChecked.non_concurrent_iter"
                    )

            @wraps(f)
            @copy_signature(f)
            async def inner(*args, **kwargs):
                self = args[0]
                assert isinstance(self, cls)

                iter_ = f(*args, **kwargs)
                while True:
                    async with self.__tracebacks_lock:
                        acquired = await self.__lock.acquire(blocking=False)
                        if acquired:
                            self.__tracebacks.append(AsyncUtil.extract_stack())
                        else:
                            tbs = deepcopy(self.__tracebacks)
                    if acquired:
                        try:
                            item = await iter_.__anext__()
                        finally:
                            async with self.__tracebacks_lock:
                                self.__tracebacks.pop()
                                self.__lock.release()
                        yield item
                    else:
                        raise self.__make_error(tbs)

            return inner

    else:

        @classmethod
        def non_concurrent_method(cls, f: _TWrapped) -> _TWrapped:
            return f

        @classmethod
        def non_concurrent_iter(cls, f: _TWrappedIter) -> _TWrappedIter:
            return f
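
For readers who want to try the async variant of this check without the driver's internal `AsyncRLock`, here is a rough standalone sketch using only `asyncio`. It is an assumption-laden stand-in, not the code from this commit: `asyncio.Lock` offers no non-blocking acquire, so the sketch tracks the owning task and a nesting depth instead. `ConcurrentUseError`, `non_concurrent`, and `Result` are hypothetical names.

```python
import asyncio
from functools import wraps


class ConcurrentUseError(RuntimeError):
    """Hypothetical error type for this sketch."""


def non_concurrent(f):
    @wraps(f)
    async def inner(self, *args, **kwargs):
        task = asyncio.current_task()
        # Another task is already inside a checked method of this object.
        if self._owner is not None and self._owner is not task:
            raise ConcurrentUseError(
                f"{f.__name__} was invoked concurrently"
            )
        self._owner = task
        self._depth += 1  # allow nested calls from the same task
        try:
            return await f(self, *args, **kwargs)
        finally:
            self._depth -= 1
            if self._depth == 0:
                self._owner = None
    return inner


class Result:
    def __init__(self):
        self._owner = None
        self._depth = 0

    @non_concurrent
    async def fetch(self):
        await asyncio.sleep(0.01)  # stand-in for network I/O
        return "record"

    @non_concurrent
    async def fetch_twice(self):
        # Nested checked calls within one task are fine.
        return [await self.fetch(), await self.fetch()]


async def demo():
    result = Result()
    # Two tasks use the same object at once; the second one is rejected.
    return await asyncio.gather(
        result.fetch(), result.fetch(), return_exceptions=True
    )
```

Running `asyncio.run(demo())` yields one record and one `ConcurrentUseError`, while `fetch_twice` completes normally because both inner calls happen in the same task.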
24 changes: 21 additions & 3 deletions src/neo4j/_async/work/result.py
@@ -43,6 +43,7 @@
     Date,
     DateTime,
 )
+from .._debug import AsyncNonConcurrentMethodChecker
 from ..io import ConnectionErrorHandler


@@ -71,7 +72,7 @@
 )


-class AsyncResult:
+class AsyncResult(AsyncNonConcurrentMethodChecker):
     """Handler for the result of Cypher query execution.

     Instances of this class are typically constructed and returned by
@@ -109,6 +110,7 @@ def __init__(self, connection, fetch_size, on_closed, on_error):
         self._out_of_scope = False
         # exception shared across all results of a transaction
         self._exception = None
+        super().__init__()

     async def _connection_error_handler(self, exc):
         self._exception = exc
@@ -251,11 +253,15 @@ def on_success(summary_metadata):
         )
         self._streaming = True

+    @AsyncNonConcurrentMethodChecker.non_concurrent_iter
     async def __aiter__(self) -> t.AsyncIterator[Record]:
         """Iterator returning Records.

-        :returns: Record, it is an immutable ordered collection of key-value pairs.
-        :rtype: :class:`neo4j.Record`
+        Advancing the iterator advances the underlying result stream.
+        So even when creating multiple iterators from the same result, each
+        Record will only be returned once.
+
+        :returns: Iterator over the result stream's records.
         """
         while self._record_buffer or self._attached:
             if self._record_buffer:
@@ -278,7 +284,9 @@ async def __aiter__(self) -> t.AsyncIterator[Record]:
         if self._consumed:
             raise ResultConsumedError(self, _RESULT_CONSUMED_ERROR)

+    @AsyncNonConcurrentMethodChecker.non_concurrent_method
     async def __anext__(self) -> Record:
+        """Advance the result stream and return the record."""
         return await self.__aiter__().__anext__()

     async def _attach(self):
@@ -367,6 +375,7 @@ def _tx_failure(self, exc):
         self._attached = False
         self._exception = exc

+    @AsyncNonConcurrentMethodChecker.non_concurrent_method
     async def consume(self) -> ResultSummary:
         """Consume the remainder of this result and return a :class:`neo4j.ResultSummary`.
@@ -434,6 +443,7 @@ async def single(
     async def single(self, strict: te.Literal[True]) -> Record:
         ...

+    @AsyncNonConcurrentMethodChecker.non_concurrent_method
     async def single(self, strict: bool = False) -> t.Optional[Record]:
         """Obtain the next and only remaining record or None.
@@ -495,6 +505,7 @@ async def single(self, strict: bool = False) -> t.Optional[Record]:
             )
         return buffer.popleft()

+    @AsyncNonConcurrentMethodChecker.non_concurrent_method
     async def fetch(self, n: int) -> t.List[Record]:
         """Obtain up to n records from this result.
@@ -517,6 +528,7 @@ async def fetch(self, n: int) -> t.List[Record]:
             for _ in range(min(n, len(self._record_buffer)))
         ]

+    @AsyncNonConcurrentMethodChecker.non_concurrent_method
     async def peek(self) -> t.Optional[Record]:
         """Obtain the next record from this result without consuming it.
@@ -537,6 +549,7 @@ async def peek(self) -> t.Optional[Record]:
             return self._record_buffer[0]
         return None

+    @AsyncNonConcurrentMethodChecker.non_concurrent_method
     async def graph(self) -> Graph:
         """Turn the result into a :class:`neo4j.Graph`.
@@ -559,6 +572,7 @@ async def graph(self) -> Graph:
         await self._buffer_all()
         return self._hydration_scope.get_graph()

+    @AsyncNonConcurrentMethodChecker.non_concurrent_method
     async def value(
         self, key: _TResultKey = 0, default: t.Optional[object] = None
     ) -> t.List[t.Any]:
@@ -580,6 +594,7 @@ async def value(
         """
         return [record.value(key, default) async for record in self]

+    @AsyncNonConcurrentMethodChecker.non_concurrent_method
     async def values(
         self, *keys: _TResultKey
     ) -> t.List[t.List[t.Any]]:
@@ -600,6 +615,7 @@ async def values(
         """
         return [record.values(*keys) async for record in self]

+    @AsyncNonConcurrentMethodChecker.non_concurrent_method
     async def data(self, *keys: _TResultKey) -> t.List[t.Dict[str, t.Any]]:
         """Return the remainder of the result as a list of dictionaries.
@@ -626,6 +642,7 @@ async def data(self, *keys: _TResultKey) -> t.List[t.Dict[str, t.Any]]:
         """
         return [record.data(*keys) async for record in self]

+    @AsyncNonConcurrentMethodChecker.non_concurrent_method
     async def to_eager_result(self) -> EagerResult:
         """Convert this result to an :class:`.EagerResult`.
@@ -650,6 +667,7 @@ async def to_eager_result(self) -> EagerResult:
             summary=await self.consume()
         )

+    @AsyncNonConcurrentMethodChecker.non_concurrent_method
     async def to_df(
         self,
         expand: bool = False,