From d130b9c41f4d2d38f76279367162d99a2b088527 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Mon, 27 Nov 2023 15:00:56 +0100 Subject: [PATCH] Optional check for concurrent usage errors (#989) Optional check for concurrent usage errors Some driver objects (e.g, Sessions, Transactions, Result streams) are not safe for concurrent use. By default, it will cause hard to interpret errors or, in the worst case, wrong behavior. To aid finding such bugs, the driver now detects if the Python interpreter is running development mode and enables extra locking around those objects. If they are used concurrently, an error will be raised. The way this is implemented, it will only cause a one-time overhead when loading the driver's modules if the checks are disabled. Obviously, those checks are somewhat expensive as they entail locks (less so in the async driver). Therefore, the checks are only happening if either * Python is started in development mode (`python -X dev ...`) or * The environment variable `PYTHONNEO4JDEBUG` is set (to anything non-empty) at the time the driver's modules is loaded. --- docs/source/index.rst | 19 +++ .../themes/neo4j/static/css/neo4j.css_t | 19 +-- src/neo4j/_async/_debug/__init__.py | 20 +++ src/neo4j/_async/_debug/_concurrency_check.py | 152 ++++++++++++++++++ src/neo4j/_async/work/result.py | 24 ++- src/neo4j/_async/work/session.py | 9 ++ src/neo4j/_async/work/transaction.py | 10 +- src/neo4j/_async/work/workspace.py | 5 +- src/neo4j/_async_compat/util.py | 13 ++ src/neo4j/_sync/_debug/__init__.py | 20 +++ src/neo4j/_sync/_debug/_concurrency_check.py | 152 ++++++++++++++++++ src/neo4j/_sync/work/result.py | 24 ++- src/neo4j/_sync/work/session.py | 9 ++ src/neo4j/_sync/work/transaction.py | 10 +- src/neo4j/_sync/work/workspace.py | 5 +- 15 files changed, 467 insertions(+), 24 deletions(-) create mode 100644 src/neo4j/_async/_debug/__init__.py create mode 100644 src/neo4j/_async/_debug/_concurrency_check.py create mode 100644 src/neo4j/_sync/_debug/__init__.py create mode 100644 src/neo4j/_sync/_debug/_concurrency_check.py diff --git a/docs/source/index.rst b/docs/source/index.rst index c7ab462a0..0bfb00637 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -99,6 +99,25 @@ To deactivate the current active virtual environment, use: deactivate +Development Environment +======================= + +For development, we recommend to run Python in `development mode`_ (``python -X dev ...``). +Specifically for this driver, this will: + + * enable :class:`ResourceWarning`, which the driver emits if resources (e.g., Sessions) aren't properly closed. + * enable :class:`DeprecationWarning`, which the driver emits if deprecated APIs are used. + * enable the driver's debug mode (this can also be achieved by setting the environment variable ``PYTHONNEO4JDEBUG``): + + * **This is experimental**. + It might be changed or removed any time even without prior notice. + * the driver will raise an exception if non-concurrency-safe methods are used concurrently. + + .. versionadded:: 5.15 + +.. _development mode: https://docs.python.org/3/library/devmode.html + + ************* Quick Example ************* diff --git a/docs/source/themes/neo4j/static/css/neo4j.css_t b/docs/source/themes/neo4j/static/css/neo4j.css_t index c246ede27..39bf61482 100644 --- a/docs/source/themes/neo4j/static/css/neo4j.css_t +++ b/docs/source/themes/neo4j/static/css/neo4j.css_t @@ -503,25 +503,16 @@ dl.field-list > dd > ol { margin-left: 0; } -ol.simple p, ul.simple p { - margin-bottom: 0; -} - -ol.simple > li:not(:first-child) > p, -ul.simple > li:not(:first-child) > p, -:not(li) > ol > li:first-child > :first-child, -:not(li) > ul > li:first-child > :first-child { +.content ol li > p:first-of-type, +.content ul li > p:first-of-type { margin-top: 0; } - -li > p:last-child { - margin-top: 10px; +.content ol li > p:last-of-type, +.content ul li > p:last-of-type { + margin-bottom: 0; } -li > p:first-child { - margin-top: 10px; -} table.docutils { margin-top: 10px; diff --git a/src/neo4j/_async/_debug/__init__.py b/src/neo4j/_async/_debug/__init__.py new file mode 100644 index 000000000..c02365c3c --- /dev/null +++ b/src/neo4j/_async/_debug/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ._concurrency_check import AsyncNonConcurrentMethodChecker + + +__all__ = ["AsyncNonConcurrentMethodChecker"] diff --git a/src/neo4j/_async/_debug/_concurrency_check.py b/src/neo4j/_async/_debug/_concurrency_check.py new file mode 100644 index 000000000..478884921 --- /dev/null +++ b/src/neo4j/_async/_debug/_concurrency_check.py @@ -0,0 +1,152 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import inspect +import os +import sys +import traceback +import typing as t +from copy import deepcopy +from functools import wraps + +from ..._async_compat.concurrency import ( + AsyncLock, + AsyncRLock, +) +from ..._async_compat.util import AsyncUtil +from ..._meta import copy_signature + + +_TWrapped = t.TypeVar("_TWrapped", bound=t.Callable[..., t.Awaitable[t.Any]]) +_TWrappedIter = t.TypeVar("_TWrappedIter", + bound=t.Callable[..., t.AsyncIterator]) + + +ENABLED = sys.flags.dev_mode or bool(os.getenv("PYTHONNEO4JDEBUG")) + + +class NonConcurrentMethodError(RuntimeError): + pass + + +class AsyncNonConcurrentMethodChecker: + if ENABLED: + + def __init__(self): + self.__lock = AsyncRLock() + self.__tracebacks_lock = AsyncLock() + self.__tracebacks = [] + + def __make_error(self, tbs): + msg = (f"Methods of {self.__class__} are not concurrency " + "safe, but were invoked concurrently.") + if tbs: + msg += ("\n\nOther invocation site:\n\n" + f"{''.join(traceback.format_list(tbs[0]))}") + return NonConcurrentMethodError(msg) + + @classmethod + def non_concurrent_method(cls, f: _TWrapped) -> _TWrapped: + if AsyncUtil.is_async_code: + if not inspect.iscoroutinefunction(f): + raise TypeError( + "cannot decorate non-coroutine function with " + "AsyncNonConcurrentMethodChecked.non_concurrent_method" + ) + else: + if not callable(f): + raise TypeError( + "cannot decorate non-callable object with " + "NonConcurrentMethodChecked.non_concurrent_method" + ) + + @wraps(f) + @copy_signature(f) + async def inner(*args, **kwargs): + self = args[0] + assert isinstance(self, cls) + + async with self.__tracebacks_lock: + acquired = await self.__lock.acquire(blocking=False) + if acquired: + self.__tracebacks.append(AsyncUtil.extract_stack()) + else: + tbs = deepcopy(self.__tracebacks) + if acquired: + try: + return await f(*args, **kwargs) + finally: + async with self.__tracebacks_lock: + self.__tracebacks.pop() + self.__lock.release() + else: + raise self.__make_error(tbs) + + return inner + + @classmethod + def non_concurrent_iter(cls, f: _TWrappedIter) -> _TWrappedIter: + if AsyncUtil.is_async_code: + if not inspect.isasyncgenfunction(f): + raise TypeError( + "cannot decorate non-async-generator function with " + "AsyncNonConcurrentMethodChecked.non_concurrent_iter" + ) + else: + if not inspect.isgeneratorfunction(f): + raise TypeError( + "cannot decorate non-generator function with " + "NonConcurrentMethodChecked.non_concurrent_iter" + ) + + @wraps(f) + @copy_signature(f) + async def inner(*args, **kwargs): + self = args[0] + assert isinstance(self, cls) + + iter_ = f(*args, **kwargs) + while True: + async with self.__tracebacks_lock: + acquired = await self.__lock.acquire(blocking=False) + if acquired: + self.__tracebacks.append(AsyncUtil.extract_stack()) + else: + tbs = deepcopy(self.__tracebacks) + if acquired: + try: + item = await iter_.__anext__() + finally: + async with self.__tracebacks_lock: + self.__tracebacks.pop() + self.__lock.release() + yield item + else: + raise self.__make_error(tbs) + + return inner + + else: + + @classmethod + def non_concurrent_method(cls, f: _TWrapped) -> _TWrapped: + return f + + @classmethod + def non_concurrent_iter(cls, f: _TWrappedIter) -> _TWrappedIter: + return f diff --git a/src/neo4j/_async/work/result.py b/src/neo4j/_async/work/result.py index c74588fc4..6f1d8d2d2 100644 --- a/src/neo4j/_async/work/result.py +++ b/src/neo4j/_async/work/result.py @@ -43,6 +43,7 @@ Date, DateTime, ) +from .._debug import AsyncNonConcurrentMethodChecker from ..io import ConnectionErrorHandler @@ -71,7 +72,7 @@ ) -class AsyncResult: +class AsyncResult(AsyncNonConcurrentMethodChecker): """Handler for the result of Cypher query execution. Instances of this class are typically constructed and returned by @@ -109,6 +110,7 @@ def __init__(self, connection, fetch_size, on_closed, on_error): self._out_of_scope = False # exception shared across all results of a transaction self._exception = None + super().__init__() async def _connection_error_handler(self, exc): self._exception = exc @@ -251,11 +253,15 @@ def on_success(summary_metadata): ) self._streaming = True + @AsyncNonConcurrentMethodChecker.non_concurrent_iter async def __aiter__(self) -> t.AsyncIterator[Record]: """Iterator returning Records. - :returns: Record, it is an immutable ordered collection of key-value pairs. - :rtype: :class:`neo4j.Record` + Advancing the iterator advances the underlying result stream. + So even when creating multiple iterators from the same result, each + Record will only be returned once. + + :returns: Iterator over the result stream's records. """ while self._record_buffer or self._attached: if self._record_buffer: @@ -278,7 +284,9 @@ async def __aiter__(self) -> t.AsyncIterator[Record]: if self._consumed: raise ResultConsumedError(self, _RESULT_CONSUMED_ERROR) + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def __anext__(self) -> Record: + """Advance the result stream and return the record.""" return await self.__aiter__().__anext__() async def _attach(self): @@ -367,6 +375,7 @@ def _tx_failure(self, exc): self._attached = False self._exception = exc + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def consume(self) -> ResultSummary: """Consume the remainder of this result and return a :class:`neo4j.ResultSummary`. @@ -434,6 +443,7 @@ async def single( async def single(self, strict: te.Literal[True]) -> Record: ... + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def single(self, strict: bool = False) -> t.Optional[Record]: """Obtain the next and only remaining record or None. @@ -495,6 +505,7 @@ async def single(self, strict: bool = False) -> t.Optional[Record]: ) return buffer.popleft() + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def fetch(self, n: int) -> t.List[Record]: """Obtain up to n records from this result. @@ -517,6 +528,7 @@ async def fetch(self, n: int) -> t.List[Record]: for _ in range(min(n, len(self._record_buffer))) ] + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def peek(self) -> t.Optional[Record]: """Obtain the next record from this result without consuming it. @@ -537,6 +549,7 @@ async def peek(self) -> t.Optional[Record]: return self._record_buffer[0] return None + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def graph(self) -> Graph: """Turn the result into a :class:`neo4j.Graph`. @@ -559,6 +572,7 @@ async def graph(self) -> Graph: await self._buffer_all() return self._hydration_scope.get_graph() + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def value( self, key: _TResultKey = 0, default: t.Optional[object] = None ) -> t.List[t.Any]: @@ -580,6 +594,7 @@ async def value( """ return [record.value(key, default) async for record in self] + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def values( self, *keys: _TResultKey ) -> t.List[t.List[t.Any]]: @@ -600,6 +615,7 @@ async def values( """ return [record.values(*keys) async for record in self] + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def data(self, *keys: _TResultKey) -> t.List[t.Dict[str, t.Any]]: """Return the remainder of the result as a list of dictionaries. @@ -626,6 +642,7 @@ async def data(self, *keys: _TResultKey) -> t.List[t.Dict[str, t.Any]]: """ return [record.data(*keys) async for record in self] + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def to_eager_result(self) -> EagerResult: """Convert this result to an :class:`.EagerResult`. @@ -650,6 +667,7 @@ async def to_eager_result(self) -> EagerResult: summary=await self.consume() ) + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def to_df( self, expand: bool = False, diff --git a/src/neo4j/_async/work/session.py b/src/neo4j/_async/work/session.py index b1b5a5c4a..b7b4452a4 100644 --- a/src/neo4j/_async/work/session.py +++ b/src/neo4j/_async/work/session.py @@ -42,6 +42,7 @@ SessionExpired, TransactionError, ) +from .._debug import AsyncNonConcurrentMethodChecker from ..auth_management import AsyncAuthManagers from .result import AsyncResult from .transaction import ( @@ -178,6 +179,7 @@ async def _verify_authentication(self): await self._connect(READ_ACCESS, force_auth=True) await self._disconnect() + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def close(self) -> None: """Close the session. @@ -247,6 +249,7 @@ def cancel(self) -> None: """ self._handle_cancellation(message="manual cancel") + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def run( self, query: t.Union[te.LiteralString, Query], @@ -320,6 +323,7 @@ async def run( "`last_bookmark` has been deprecated in favor of `last_bookmarks`. " "This method can lead to unexpected behaviour." ) + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def last_bookmark(self) -> t.Optional[str]: """Get the bookmark received following the last completed transaction. @@ -434,6 +438,7 @@ async def _open_transaction( pipelined=self._pipelined_begin ) + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def begin_transaction( self, metadata: t.Optional[t.Dict[str, t.Any]] = None, @@ -583,6 +588,7 @@ def api_success_cb(meta): else: raise ServiceUnavailable("Transaction failed") + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def execute_read( self, transaction_function: t.Callable[ @@ -658,6 +664,7 @@ async def get_two_tx(tx): # TODO: 6.0 - Remove this method @deprecated("read_transaction has been renamed to execute_read") + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def read_transaction( self, transaction_function: t.Callable[ @@ -695,6 +702,7 @@ async def read_transaction( transaction_function, args, kwargs ) + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def execute_write( self, transaction_function: t.Callable[ @@ -752,6 +760,7 @@ async def create_node_tx(tx, name): # TODO: 6.0 - Remove this method @deprecated("write_transaction has been renamed to execute_write") + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def write_transaction( self, transaction_function: t.Callable[ diff --git a/src/neo4j/_async/work/transaction.py b/src/neo4j/_async/work/transaction.py index e806c1f09..d5b7b1a71 100644 --- a/src/neo4j/_async/work/transaction.py +++ b/src/neo4j/_async/work/transaction.py @@ -23,6 +23,7 @@ from ..._async_compat.util import AsyncUtil from ..._work import Query from ...exceptions import TransactionError +from .._debug import AsyncNonConcurrentMethodChecker from ..io import ConnectionErrorHandler from .result import AsyncResult @@ -38,7 +39,7 @@ ) -class AsyncTransactionBase: +class AsyncTransactionBase(AsyncNonConcurrentMethodChecker): def __init__(self, connection, fetch_size, on_closed, on_error, on_cancel): self._connection = connection @@ -54,10 +55,12 @@ def __init__(self, connection, fetch_size, on_closed, on_error, self._on_closed = on_closed self._on_error = on_error self._on_cancel = on_cancel + super().__init__() async def _enter(self): return self + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def _exit(self, exception_type, exception_value, traceback): if self._closed_flag: return @@ -69,6 +72,7 @@ async def _exit(self, exception_type, exception_value, traceback): return await self._close() + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def _begin( self, database, imp_user, bookmarks, access_mode, metadata, timeout, notifications_min_severity, notifications_disabled_categories, @@ -102,6 +106,7 @@ async def _consume_results(self): await result._tx_end() self._results = [] + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def run( self, query: te.LiteralString, @@ -165,6 +170,7 @@ async def run( return result + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def _commit(self): """Mark this transaction as successful and close in order to trigger a COMMIT. @@ -194,6 +200,7 @@ async def _commit(self): return self._bookmark + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def _rollback(self): """Mark this transaction as unsuccessful and close in order to trigger a ROLLBACK. @@ -219,6 +226,7 @@ async def _rollback(self): self._closed_flag = True await AsyncUtil.callback(self._on_closed) + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def _close(self): """Close this transaction, triggering a ROLLBACK if not closed. """ diff --git a/src/neo4j/_async/work/workspace.py b/src/neo4j/_async/work/workspace.py index 5434963bf..cc1ee246c 100644 --- a/src/neo4j/_async/work/workspace.py +++ b/src/neo4j/_async/work/workspace.py @@ -31,6 +31,7 @@ SessionError, SessionExpired, ) +from .._debug import AsyncNonConcurrentMethodChecker from ..io import ( AcquireAuth, AsyncNeo4jPool, @@ -40,7 +41,7 @@ log = logging.getLogger("neo4j") -class AsyncWorkspace: +class AsyncWorkspace(AsyncNonConcurrentMethodChecker): def __init__(self, pool, config): assert isinstance(config, WorkspaceConfig) @@ -56,6 +57,7 @@ def __init__(self, pool, config): self._last_from_bookmark_manager = None # Workspace has been closed. self._closed = False + super().__init__() def __del__(self): if self._closed: @@ -189,6 +191,7 @@ async def _disconnect(self, sync=False): self._connection = None self._connection_access_mode = None + @AsyncNonConcurrentMethodChecker.non_concurrent_method async def close(self) -> None: if self._closed: return diff --git a/src/neo4j/_async_compat/util.py b/src/neo4j/_async_compat/util.py index 878b9dd87..3bafba1c3 100644 --- a/src/neo4j/_async_compat/util.py +++ b/src/neo4j/_async_compat/util.py @@ -18,6 +18,7 @@ import asyncio import inspect +import traceback import typing as t from functools import wraps @@ -86,6 +87,14 @@ async def shielded_function(*args, **kwargs): is_async_code: t.ClassVar = True + @staticmethod + def extract_stack(limit=None): + # can maybe be improved in the future + # https://github.com/python/cpython/issues/91048 + stack = asyncio.current_task().get_stack(limit=limit) + stack_walk = ((f, f.f_lineno) for f in stack) + return traceback.StackSummary.extract(stack_walk, limit=limit) + class Util: iter: t.ClassVar = iter @@ -113,3 +122,7 @@ def shielded(coro_function): return coro_function is_async_code: t.ClassVar = False + + @staticmethod + def extract_stack(limit=None): + return traceback.extract_stack(limit=limit)[:-1] diff --git a/src/neo4j/_sync/_debug/__init__.py b/src/neo4j/_sync/_debug/__init__.py new file mode 100644 index 000000000..8764ebb39 --- /dev/null +++ b/src/neo4j/_sync/_debug/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ._concurrency_check import NonConcurrentMethodChecker + + +__all__ = ["NonConcurrentMethodChecker"] diff --git a/src/neo4j/_sync/_debug/_concurrency_check.py b/src/neo4j/_sync/_debug/_concurrency_check.py new file mode 100644 index 000000000..6a5ef9b3c --- /dev/null +++ b/src/neo4j/_sync/_debug/_concurrency_check.py @@ -0,0 +1,152 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import inspect +import os +import sys +import traceback +import typing as t +from copy import deepcopy +from functools import wraps + +from ..._async_compat.concurrency import ( + Lock, + RLock, +) +from ..._async_compat.util import Util +from ..._meta import copy_signature + + +_TWrapped = t.TypeVar("_TWrapped", bound=t.Callable[..., t.Union[t.Any]]) +_TWrappedIter = t.TypeVar("_TWrappedIter", + bound=t.Callable[..., t.Iterator]) + + +ENABLED = sys.flags.dev_mode or bool(os.getenv("PYTHONNEO4JDEBUG")) + + +class NonConcurrentMethodError(RuntimeError): + pass + + +class NonConcurrentMethodChecker: + if ENABLED: + + def __init__(self): + self.__lock = RLock() + self.__tracebacks_lock = Lock() + self.__tracebacks = [] + + def __make_error(self, tbs): + msg = (f"Methods of {self.__class__} are not concurrency " + "safe, but were invoked concurrently.") + if tbs: + msg += ("\n\nOther invocation site:\n\n" + f"{''.join(traceback.format_list(tbs[0]))}") + return NonConcurrentMethodError(msg) + + @classmethod + def non_concurrent_method(cls, f: _TWrapped) -> _TWrapped: + if Util.is_async_code: + if not inspect.iscoroutinefunction(f): + raise TypeError( + "cannot decorate non-coroutine function with " + "NonConcurrentMethodChecked.non_concurrent_method" + ) + else: + if not callable(f): + raise TypeError( + "cannot decorate non-callable object with " + "NonConcurrentMethodChecked.non_concurrent_method" + ) + + @wraps(f) + @copy_signature(f) + def inner(*args, **kwargs): + self = args[0] + assert isinstance(self, cls) + + with self.__tracebacks_lock: + acquired = self.__lock.acquire(blocking=False) + if acquired: + self.__tracebacks.append(Util.extract_stack()) + else: + tbs = deepcopy(self.__tracebacks) + if acquired: + try: + return f(*args, **kwargs) + finally: + with self.__tracebacks_lock: + self.__tracebacks.pop() + self.__lock.release() + else: + raise self.__make_error(tbs) + + return inner + + @classmethod + def non_concurrent_iter(cls, f: _TWrappedIter) -> _TWrappedIter: + if Util.is_async_code: + if not inspect.isasyncgenfunction(f): + raise TypeError( + "cannot decorate non-async-generator function with " + "NonConcurrentMethodChecked.non_concurrent_iter" + ) + else: + if not inspect.isgeneratorfunction(f): + raise TypeError( + "cannot decorate non-generator function with " + "NonConcurrentMethodChecked.non_concurrent_iter" + ) + + @wraps(f) + @copy_signature(f) + def inner(*args, **kwargs): + self = args[0] + assert isinstance(self, cls) + + iter_ = f(*args, **kwargs) + while True: + with self.__tracebacks_lock: + acquired = self.__lock.acquire(blocking=False) + if acquired: + self.__tracebacks.append(Util.extract_stack()) + else: + tbs = deepcopy(self.__tracebacks) + if acquired: + try: + item = iter_.__next__() + finally: + with self.__tracebacks_lock: + self.__tracebacks.pop() + self.__lock.release() + yield item + else: + raise self.__make_error(tbs) + + return inner + + else: + + @classmethod + def non_concurrent_method(cls, f: _TWrapped) -> _TWrapped: + return f + + @classmethod + def non_concurrent_iter(cls, f: _TWrappedIter) -> _TWrappedIter: + return f diff --git a/src/neo4j/_sync/work/result.py b/src/neo4j/_sync/work/result.py index 634064094..58fc704bd 100644 --- a/src/neo4j/_sync/work/result.py +++ b/src/neo4j/_sync/work/result.py @@ -43,6 +43,7 @@ Date, DateTime, ) +from .._debug import NonConcurrentMethodChecker from ..io import ConnectionErrorHandler @@ -71,7 +72,7 @@ ) -class Result: +class Result(NonConcurrentMethodChecker): """Handler for the result of Cypher query execution. Instances of this class are typically constructed and returned by @@ -109,6 +110,7 @@ def __init__(self, connection, fetch_size, on_closed, on_error): self._out_of_scope = False # exception shared across all results of a transaction self._exception = None + super().__init__() def _connection_error_handler(self, exc): self._exception = exc @@ -251,11 +253,15 @@ def on_success(summary_metadata): ) self._streaming = True + @NonConcurrentMethodChecker.non_concurrent_iter def __iter__(self) -> t.Iterator[Record]: """Iterator returning Records. - :returns: Record, it is an immutable ordered collection of key-value pairs. - :rtype: :class:`neo4j.Record` + Advancing the iterator advances the underlying result stream. + So even when creating multiple iterators from the same result, each + Record will only be returned once. + + :returns: Iterator over the result stream's records. """ while self._record_buffer or self._attached: if self._record_buffer: @@ -278,7 +284,9 @@ def __iter__(self) -> t.Iterator[Record]: if self._consumed: raise ResultConsumedError(self, _RESULT_CONSUMED_ERROR) + @NonConcurrentMethodChecker.non_concurrent_method def __next__(self) -> Record: + """Advance the result stream and return the record.""" return self.__iter__().__next__() def _attach(self): @@ -367,6 +375,7 @@ def _tx_failure(self, exc): self._attached = False self._exception = exc + @NonConcurrentMethodChecker.non_concurrent_method def consume(self) -> ResultSummary: """Consume the remainder of this result and return a :class:`neo4j.ResultSummary`. @@ -434,6 +443,7 @@ def single( def single(self, strict: te.Literal[True]) -> Record: ... + @NonConcurrentMethodChecker.non_concurrent_method def single(self, strict: bool = False) -> t.Optional[Record]: """Obtain the next and only remaining record or None. @@ -495,6 +505,7 @@ def single(self, strict: bool = False) -> t.Optional[Record]: ) return buffer.popleft() + @NonConcurrentMethodChecker.non_concurrent_method def fetch(self, n: int) -> t.List[Record]: """Obtain up to n records from this result. @@ -517,6 +528,7 @@ def fetch(self, n: int) -> t.List[Record]: for _ in range(min(n, len(self._record_buffer))) ] + @NonConcurrentMethodChecker.non_concurrent_method def peek(self) -> t.Optional[Record]: """Obtain the next record from this result without consuming it. @@ -537,6 +549,7 @@ def peek(self) -> t.Optional[Record]: return self._record_buffer[0] return None + @NonConcurrentMethodChecker.non_concurrent_method def graph(self) -> Graph: """Turn the result into a :class:`neo4j.Graph`. @@ -559,6 +572,7 @@ def graph(self) -> Graph: self._buffer_all() return self._hydration_scope.get_graph() + @NonConcurrentMethodChecker.non_concurrent_method def value( self, key: _TResultKey = 0, default: t.Optional[object] = None ) -> t.List[t.Any]: @@ -580,6 +594,7 @@ def value( """ return [record.value(key, default) for record in self] + @NonConcurrentMethodChecker.non_concurrent_method def values( self, *keys: _TResultKey ) -> t.List[t.List[t.Any]]: @@ -600,6 +615,7 @@ def values( """ return [record.values(*keys) for record in self] + @NonConcurrentMethodChecker.non_concurrent_method def data(self, *keys: _TResultKey) -> t.List[t.Dict[str, t.Any]]: """Return the remainder of the result as a list of dictionaries. @@ -626,6 +642,7 @@ def data(self, *keys: _TResultKey) -> t.List[t.Dict[str, t.Any]]: """ return [record.data(*keys) for record in self] + @NonConcurrentMethodChecker.non_concurrent_method def to_eager_result(self) -> EagerResult: """Convert this result to an :class:`.EagerResult`. @@ -650,6 +667,7 @@ def to_eager_result(self) -> EagerResult: summary=self.consume() ) + @NonConcurrentMethodChecker.non_concurrent_method def to_df( self, expand: bool = False, diff --git a/src/neo4j/_sync/work/session.py b/src/neo4j/_sync/work/session.py index cb90ac09c..1479b90b5 100644 --- a/src/neo4j/_sync/work/session.py +++ b/src/neo4j/_sync/work/session.py @@ -42,6 +42,7 @@ SessionExpired, TransactionError, ) +from .._debug import NonConcurrentMethodChecker from ..auth_management import AuthManagers from .result import Result from .transaction import ( @@ -178,6 +179,7 @@ def _verify_authentication(self): self._connect(READ_ACCESS, force_auth=True) self._disconnect() + @NonConcurrentMethodChecker.non_concurrent_method def close(self) -> None: """Close the session. @@ -247,6 +249,7 @@ def cancel(self) -> None: """ self._handle_cancellation(message="manual cancel") + @NonConcurrentMethodChecker.non_concurrent_method def run( self, query: t.Union[te.LiteralString, Query], @@ -320,6 +323,7 @@ def run( "`last_bookmark` has been deprecated in favor of `last_bookmarks`. " "This method can lead to unexpected behaviour." ) + @NonConcurrentMethodChecker.non_concurrent_method def last_bookmark(self) -> t.Optional[str]: """Get the bookmark received following the last completed transaction. @@ -434,6 +438,7 @@ def _open_transaction( pipelined=self._pipelined_begin ) + @NonConcurrentMethodChecker.non_concurrent_method def begin_transaction( self, metadata: t.Optional[t.Dict[str, t.Any]] = None, @@ -583,6 +588,7 @@ def api_success_cb(meta): else: raise ServiceUnavailable("Transaction failed") + @NonConcurrentMethodChecker.non_concurrent_method def execute_read( self, transaction_function: t.Callable[ @@ -658,6 +664,7 @@ def get_two_tx(tx): # TODO: 6.0 - Remove this method @deprecated("read_transaction has been renamed to execute_read") + @NonConcurrentMethodChecker.non_concurrent_method def read_transaction( self, transaction_function: t.Callable[ @@ -695,6 +702,7 @@ def read_transaction( transaction_function, args, kwargs ) + @NonConcurrentMethodChecker.non_concurrent_method def execute_write( self, transaction_function: t.Callable[ @@ -752,6 +760,7 @@ def create_node_tx(tx, name): # TODO: 6.0 - Remove this method @deprecated("write_transaction has been renamed to execute_write") + @NonConcurrentMethodChecker.non_concurrent_method def write_transaction( self, transaction_function: t.Callable[ diff --git a/src/neo4j/_sync/work/transaction.py b/src/neo4j/_sync/work/transaction.py index 3e029cb1d..ea8ddac88 100644 --- a/src/neo4j/_sync/work/transaction.py +++ b/src/neo4j/_sync/work/transaction.py @@ -23,6 +23,7 @@ from ..._async_compat.util import Util from ..._work import Query from ...exceptions import TransactionError +from .._debug import NonConcurrentMethodChecker from ..io import ConnectionErrorHandler from .result import Result @@ -38,7 +39,7 @@ ) -class TransactionBase: +class TransactionBase(NonConcurrentMethodChecker): def __init__(self, connection, fetch_size, on_closed, on_error, on_cancel): self._connection = connection @@ -54,10 +55,12 @@ def __init__(self, connection, fetch_size, on_closed, on_error, self._on_closed = on_closed self._on_error = on_error self._on_cancel = on_cancel + super().__init__() def _enter(self): return self + @NonConcurrentMethodChecker.non_concurrent_method def _exit(self, exception_type, exception_value, traceback): if self._closed_flag: return @@ -69,6 +72,7 @@ def _exit(self, exception_type, exception_value, traceback): return self._close() + @NonConcurrentMethodChecker.non_concurrent_method def _begin( self, database, imp_user, bookmarks, access_mode, metadata, timeout, notifications_min_severity, notifications_disabled_categories, @@ -102,6 +106,7 @@ def _consume_results(self): result._tx_end() self._results = [] + @NonConcurrentMethodChecker.non_concurrent_method def run( self, query: te.LiteralString, @@ -165,6 +170,7 @@ def run( return result + @NonConcurrentMethodChecker.non_concurrent_method def _commit(self): """Mark this transaction as successful and close in order to trigger a COMMIT. @@ -194,6 +200,7 @@ def _commit(self): return self._bookmark + @NonConcurrentMethodChecker.non_concurrent_method def _rollback(self): """Mark this transaction as unsuccessful and close in order to trigger a ROLLBACK. @@ -219,6 +226,7 @@ def _rollback(self): self._closed_flag = True Util.callback(self._on_closed) + @NonConcurrentMethodChecker.non_concurrent_method def _close(self): """Close this transaction, triggering a ROLLBACK if not closed. """ diff --git a/src/neo4j/_sync/work/workspace.py b/src/neo4j/_sync/work/workspace.py index b0780dab6..31f5bd111 100644 --- a/src/neo4j/_sync/work/workspace.py +++ b/src/neo4j/_sync/work/workspace.py @@ -31,6 +31,7 @@ SessionError, SessionExpired, ) +from .._debug import NonConcurrentMethodChecker from ..io import ( AcquireAuth, Neo4jPool, @@ -40,7 +41,7 @@ log = logging.getLogger("neo4j") -class Workspace: +class Workspace(NonConcurrentMethodChecker): def __init__(self, pool, config): assert isinstance(config, WorkspaceConfig) @@ -56,6 +57,7 @@ def __init__(self, pool, config): self._last_from_bookmark_manager = None # Workspace has been closed. self._closed = False + super().__init__() def __del__(self): if self._closed: @@ -189,6 +191,7 @@ def _disconnect(self, sync=False): self._connection = None self._connection_access_mode = None + @NonConcurrentMethodChecker.non_concurrent_method def close(self) -> None: if self._closed: return