-
Notifications
You must be signed in to change notification settings - Fork 188
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Optional check for concurrent usage errors (#989)
Optional check for concurrent usage errors Some driver objects (e.g, Sessions, Transactions, Result streams) are not safe for concurrent use. By default, it will cause hard to interpret errors or, in the worst case, wrong behavior. To aid finding such bugs, the driver now detects if the Python interpreter is running development mode and enables extra locking around those objects. If they are used concurrently, an error will be raised. The way this is implemented, it will only cause a one-time overhead when loading the driver's modules if the checks are disabled. Obviously, those checks are somewhat expensive as they entail locks (less so in the async driver). Therefore, the checks are only happening if either * Python is started in development mode (`python -X dev ...`) or * The environment variable `PYTHONNEO4JDEBUG` is set (to anything non-empty) at the time the driver's modules is loaded.
- Loading branch information
1 parent
f722f65
commit d130b9c
Showing
15 changed files
with
467 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Copyright (c) "Neo4j" | ||
# Neo4j Sweden AB [https://neo4j.com] | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from ._concurrency_check import AsyncNonConcurrentMethodChecker | ||
|
||
|
||
__all__ = ["AsyncNonConcurrentMethodChecker"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
# Copyright (c) "Neo4j" | ||
# Neo4j Sweden AB [https://neo4j.com] | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from __future__ import annotations | ||
|
||
import inspect | ||
import os | ||
import sys | ||
import traceback | ||
import typing as t | ||
from copy import deepcopy | ||
from functools import wraps | ||
|
||
from ..._async_compat.concurrency import ( | ||
AsyncLock, | ||
AsyncRLock, | ||
) | ||
from ..._async_compat.util import AsyncUtil | ||
from ..._meta import copy_signature | ||
|
||
|
||
_TWrapped = t.TypeVar("_TWrapped", bound=t.Callable[..., t.Awaitable[t.Any]]) | ||
_TWrappedIter = t.TypeVar("_TWrappedIter", | ||
bound=t.Callable[..., t.AsyncIterator]) | ||
|
||
|
||
ENABLED = sys.flags.dev_mode or bool(os.getenv("PYTHONNEO4JDEBUG")) | ||
|
||
|
||
class NonConcurrentMethodError(RuntimeError): | ||
pass | ||
|
||
|
||
class AsyncNonConcurrentMethodChecker: | ||
if ENABLED: | ||
|
||
def __init__(self): | ||
self.__lock = AsyncRLock() | ||
self.__tracebacks_lock = AsyncLock() | ||
self.__tracebacks = [] | ||
|
||
def __make_error(self, tbs): | ||
msg = (f"Methods of {self.__class__} are not concurrency " | ||
"safe, but were invoked concurrently.") | ||
if tbs: | ||
msg += ("\n\nOther invocation site:\n\n" | ||
f"{''.join(traceback.format_list(tbs[0]))}") | ||
return NonConcurrentMethodError(msg) | ||
|
||
@classmethod | ||
def non_concurrent_method(cls, f: _TWrapped) -> _TWrapped: | ||
if AsyncUtil.is_async_code: | ||
if not inspect.iscoroutinefunction(f): | ||
raise TypeError( | ||
"cannot decorate non-coroutine function with " | ||
"AsyncNonConcurrentMethodChecked.non_concurrent_method" | ||
) | ||
else: | ||
if not callable(f): | ||
raise TypeError( | ||
"cannot decorate non-callable object with " | ||
"NonConcurrentMethodChecked.non_concurrent_method" | ||
) | ||
|
||
@wraps(f) | ||
@copy_signature(f) | ||
async def inner(*args, **kwargs): | ||
self = args[0] | ||
assert isinstance(self, cls) | ||
|
||
async with self.__tracebacks_lock: | ||
acquired = await self.__lock.acquire(blocking=False) | ||
if acquired: | ||
self.__tracebacks.append(AsyncUtil.extract_stack()) | ||
else: | ||
tbs = deepcopy(self.__tracebacks) | ||
if acquired: | ||
try: | ||
return await f(*args, **kwargs) | ||
finally: | ||
async with self.__tracebacks_lock: | ||
self.__tracebacks.pop() | ||
self.__lock.release() | ||
else: | ||
raise self.__make_error(tbs) | ||
|
||
return inner | ||
|
||
@classmethod | ||
def non_concurrent_iter(cls, f: _TWrappedIter) -> _TWrappedIter: | ||
if AsyncUtil.is_async_code: | ||
if not inspect.isasyncgenfunction(f): | ||
raise TypeError( | ||
"cannot decorate non-async-generator function with " | ||
"AsyncNonConcurrentMethodChecked.non_concurrent_iter" | ||
) | ||
else: | ||
if not inspect.isgeneratorfunction(f): | ||
raise TypeError( | ||
"cannot decorate non-generator function with " | ||
"NonConcurrentMethodChecked.non_concurrent_iter" | ||
) | ||
|
||
@wraps(f) | ||
@copy_signature(f) | ||
async def inner(*args, **kwargs): | ||
self = args[0] | ||
assert isinstance(self, cls) | ||
|
||
iter_ = f(*args, **kwargs) | ||
while True: | ||
async with self.__tracebacks_lock: | ||
acquired = await self.__lock.acquire(blocking=False) | ||
if acquired: | ||
self.__tracebacks.append(AsyncUtil.extract_stack()) | ||
else: | ||
tbs = deepcopy(self.__tracebacks) | ||
if acquired: | ||
try: | ||
item = await iter_.__anext__() | ||
finally: | ||
async with self.__tracebacks_lock: | ||
self.__tracebacks.pop() | ||
self.__lock.release() | ||
yield item | ||
else: | ||
raise self.__make_error(tbs) | ||
|
||
return inner | ||
|
||
else: | ||
|
||
@classmethod | ||
def non_concurrent_method(cls, f: _TWrapped) -> _TWrapped: | ||
return f | ||
|
||
@classmethod | ||
def non_concurrent_iter(cls, f: _TWrappedIter) -> _TWrappedIter: | ||
return f |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.