Skip to content

Commit

Permalink
Accept transaction config for execute_query (#991)
Browse files Browse the repository at this point in the history
`Driver.execute_query` now accepts a `Query` object to specify transaction
config like metadata and transaction timeout. Example:

```python
from neo4j import (
    GraphDatabase,
    Query,
)

with GraphDatabase.driver(...) as driver:
    driver.execute_query(
        Query(
            "MATCH (n) RETURN n",
            # metadata to be logged with the transaction
            metadata={"foo": "bar"},
            # give the transaction 5 seconds to complete on the DBMS
            timeout=5,
        ),
        # all the other configuration options as before
        database_="neo4j",
        # ...
    )
```
  • Loading branch information
robsdedude authored Nov 27, 2023
1 parent 17c6097 commit 656c796
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 49 deletions.
18 changes: 13 additions & 5 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,9 @@ Closing a driver will immediately shut down all connections in the pool.
query_, parameters_, routing_, database_, impersonated_user_,
bookmark_manager_, auth_, result_transformer_, **kwargs
):
@unit_of_work(query_.metadata, query_.timeout)
def work(tx):
result = tx.run(query_, parameters_, **kwargs)
result = tx.run(query_.text, parameters_, **kwargs)
return result_transformer_(result)

with driver.session(
Expand Down Expand Up @@ -245,16 +246,19 @@ Closing a driver will immediately shut down all connections in the pool.
assert isinstance(count, int)
return count

:param query_: cypher query to execute
:type query_: typing.LiteralString
:param query_:
Cypher query to execute.
Use a :class:`.Query` object to pass a query with additional
transaction configuration.
:type query_: typing.LiteralString | Query
:param parameters_: parameters to use in the query
:type parameters_: typing.Dict[str, typing.Any] | None
:param routing_:
whether to route the query to a reader (follower/read replica) or
Whether to route the query to a reader (follower/read replica) or
a writer (leader) in the cluster. Default is to route to a writer.
:type routing_: RoutingControl
:param database_:
database to execute the query against.
Database to execute the query against.

None (default) uses the database configured on the server side.

Expand Down Expand Up @@ -375,6 +379,10 @@ Closing a driver will immediately shut down all connections in the pool.
.. versionchanged:: 5.14
Stabilized ``auth_`` parameter from preview.

.. versionchanged:: 5.15
The ``query_`` parameter now also accepts a :class:`.Query` object
instead of only :class:`str`.


.. _driver-configuration-ref:

Expand Down
18 changes: 13 additions & 5 deletions docs/source/async_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,9 @@ Closing a driver will immediately shut down all connections in the pool.
query_, parameters_, routing_, database_, impersonated_user_,
bookmark_manager_, auth_, result_transformer_, **kwargs
):
@unit_of_work(query_.metadata, query_.timeout)
async def work(tx):
result = await tx.run(query_, parameters_, **kwargs)
result = await tx.run(query_.text, parameters_, **kwargs)
return await result_transformer_(result)

async with driver.session(
Expand Down Expand Up @@ -232,16 +233,19 @@ Closing a driver will immediately shut down all connections in the pool.
assert isinstance(count, int)
return count

:param query_: cypher query to execute
:type query_: typing.LiteralString
:param query_:
Cypher query to execute.
Use a :class:`.Query` object to pass a query with additional
transaction configuration.
:type query_: typing.LiteralString | Query
:param parameters_: parameters to use in the query
:type parameters_: typing.Dict[str, typing.Any] | None
:param routing_:
whether to route the query to a reader (follower/read replica) or
Whether to route the query to a reader (follower/read replica) or
a writer (leader) in the cluster. Default is to route to a writer.
:type routing_: RoutingControl
:param database_:
database to execute the query against.
Database to execute the query against.

None (default) uses the database configured on the server side.

Expand Down Expand Up @@ -362,6 +366,10 @@ Closing a driver will immediately shut down all connections in the pool.
.. versionchanged:: 5.14
Stabilized ``auth_`` parameter from preview.

.. versionchanged:: 5.15
The ``query_`` parameter now also accepts a :class:`.Query` object
instead of only :class:`str`.


.. _async-driver-configuration-ref:

Expand Down
42 changes: 31 additions & 11 deletions src/neo4j/_async/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@
experimental_warn,
unclosed_resource_warn,
)
from .._work import EagerResult
from .._work import (
EagerResult,
Query,
unit_of_work,
)
from ..addressing import Address
from ..api import (
AsyncBookmarkManager,
Expand Down Expand Up @@ -581,7 +585,7 @@ async def close(self) -> None:
@t.overload
async def execute_query(
self,
query_: te.LiteralString,
query_: t.Union[te.LiteralString, Query],
parameters_: t.Optional[t.Dict[str, t.Any]] = None,
routing_: T_RoutingControl = RoutingControl.WRITE,
database_: t.Optional[str] = None,
Expand All @@ -600,7 +604,7 @@ async def execute_query(
@t.overload
async def execute_query(
self,
query_: te.LiteralString,
query_: t.Union[te.LiteralString, Query],
parameters_: t.Optional[t.Dict[str, t.Any]] = None,
routing_: T_RoutingControl = RoutingControl.WRITE,
database_: t.Optional[str] = None,
Expand All @@ -618,7 +622,7 @@ async def execute_query(

async def execute_query(
self,
query_: te.LiteralString,
query_: t.Union[te.LiteralString, Query],
parameters_: t.Optional[t.Dict[str, t.Any]] = None,
routing_: T_RoutingControl = RoutingControl.WRITE,
database_: t.Optional[str] = None,
Expand Down Expand Up @@ -651,8 +655,9 @@ async def execute_query(
query_, parameters_, routing_, database_, impersonated_user_,
bookmark_manager_, auth_, result_transformer_, **kwargs
):
@unit_of_work(query_.metadata, query_.timeout)
async def work(tx):
result = await tx.run(query_, parameters_, **kwargs)
result = await tx.run(query_.text, parameters_, **kwargs)
return await result_transformer_(result)
async with driver.session(
Expand Down Expand Up @@ -709,16 +714,19 @@ async def example(driver: neo4j.AsyncDriver) -> int:
assert isinstance(count, int)
return count
:param query_: cypher query to execute
:type query_: typing.LiteralString
:param query_:
Cypher query to execute.
Use a :class:`.Query` object to pass a query with additional
transaction configuration.
:type query_: typing.LiteralString | Query
:param parameters_: parameters to use in the query
:type parameters_: typing.Optional[typing.Dict[str, typing.Any]]
:param routing_:
whether to route the query to a reader (follower/read replica) or
Whether to route the query to a reader (follower/read replica) or
a writer (leader) in the cluster. Default is to route to a writer.
:type routing_: RoutingControl
:param database_:
database to execute the query against.
Database to execute the query against.
None (default) uses the database configured on the server side.
Expand Down Expand Up @@ -838,6 +846,10 @@ async def example(driver: neo4j.AsyncDriver) -> neo4j.Record::
.. versionchanged:: 5.14
Stabilized ``auth_`` parameter from preview.
.. versionchanged:: 5.15
The ``query_`` parameter now also accepts a :class:`.Query` object
instead of only :class:`str`.
"""
self._check_state()
invalid_kwargs = [k for k in kwargs if
Expand All @@ -850,6 +862,14 @@ async def example(driver: neo4j.AsyncDriver) -> neo4j.Record::
"latter case, use the `parameters_` dictionary instead."
% invalid_kwargs
)
if isinstance(query_, Query):
timeout = query_.timeout
metadata = query_.metadata
query_str = query_.text
work = unit_of_work(metadata, timeout)(_work)
else:
query_str = query_
work = _work
parameters = dict(parameters_ or {}, **kwargs)

if bookmark_manager_ is _default:
Expand All @@ -876,7 +896,7 @@ async def example(driver: neo4j.AsyncDriver) -> neo4j.Record::
with session._pipelined_begin:
return await session._run_transaction(
access_mode, TelemetryAPI.DRIVER,
_work, (query_, parameters, result_transformer_), {}
work, (query_str, parameters, result_transformer_), {}
)

@property
Expand Down Expand Up @@ -1195,7 +1215,7 @@ async def _get_server_info(self, session_config) -> ServerInfo:

async def _work(
tx: AsyncManagedTransaction,
query: str,
query: te.LiteralString,
parameters: t.Dict[str, t.Any],
transformer: t.Callable[[AsyncResult], t.Awaitable[_T]]
) -> _T:
Expand Down
42 changes: 31 additions & 11 deletions src/neo4j/_sync/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@
experimental_warn,
unclosed_resource_warn,
)
from .._work import EagerResult
from .._work import (
EagerResult,
Query,
unit_of_work,
)
from ..addressing import Address
from ..api import (
Auth,
Expand Down Expand Up @@ -580,7 +584,7 @@ def close(self) -> None:
@t.overload
def execute_query(
self,
query_: te.LiteralString,
query_: t.Union[te.LiteralString, Query],
parameters_: t.Optional[t.Dict[str, t.Any]] = None,
routing_: T_RoutingControl = RoutingControl.WRITE,
database_: t.Optional[str] = None,
Expand All @@ -599,7 +603,7 @@ def execute_query(
@t.overload
def execute_query(
self,
query_: te.LiteralString,
query_: t.Union[te.LiteralString, Query],
parameters_: t.Optional[t.Dict[str, t.Any]] = None,
routing_: T_RoutingControl = RoutingControl.WRITE,
database_: t.Optional[str] = None,
Expand All @@ -617,7 +621,7 @@ def execute_query(

def execute_query(
self,
query_: te.LiteralString,
query_: t.Union[te.LiteralString, Query],
parameters_: t.Optional[t.Dict[str, t.Any]] = None,
routing_: T_RoutingControl = RoutingControl.WRITE,
database_: t.Optional[str] = None,
Expand Down Expand Up @@ -650,8 +654,9 @@ def execute_query(
query_, parameters_, routing_, database_, impersonated_user_,
bookmark_manager_, auth_, result_transformer_, **kwargs
):
@unit_of_work(query_.metadata, query_.timeout)
def work(tx):
result = tx.run(query_, parameters_, **kwargs)
result = tx.run(query_.text, parameters_, **kwargs)
return result_transformer_(result)
with driver.session(
Expand Down Expand Up @@ -708,16 +713,19 @@ def example(driver: neo4j.Driver) -> int:
assert isinstance(count, int)
return count
:param query_: cypher query to execute
:type query_: typing.LiteralString
:param query_:
Cypher query to execute.
Use a :class:`.Query` object to pass a query with additional
transaction configuration.
:type query_: typing.LiteralString | Query
:param parameters_: parameters to use in the query
:type parameters_: typing.Optional[typing.Dict[str, typing.Any]]
:param routing_:
whether to route the query to a reader (follower/read replica) or
Whether to route the query to a reader (follower/read replica) or
a writer (leader) in the cluster. Default is to route to a writer.
:type routing_: RoutingControl
:param database_:
database to execute the query against.
Database to execute the query against.
None (default) uses the database configured on the server side.
Expand Down Expand Up @@ -837,6 +845,10 @@ def example(driver: neo4j.Driver) -> neo4j.Record::
.. versionchanged:: 5.14
Stabilized ``auth_`` parameter from preview.
.. versionchanged:: 5.15
The ``query_`` parameter now also accepts a :class:`.Query` object
instead of only :class:`str`.
"""
self._check_state()
invalid_kwargs = [k for k in kwargs if
Expand All @@ -849,6 +861,14 @@ def example(driver: neo4j.Driver) -> neo4j.Record::
"latter case, use the `parameters_` dictionary instead."
% invalid_kwargs
)
if isinstance(query_, Query):
timeout = query_.timeout
metadata = query_.metadata
query_str = query_.text
work = unit_of_work(metadata, timeout)(_work)
else:
query_str = query_
work = _work
parameters = dict(parameters_ or {}, **kwargs)

if bookmark_manager_ is _default:
Expand All @@ -875,7 +895,7 @@ def example(driver: neo4j.Driver) -> neo4j.Record::
with session._pipelined_begin:
return session._run_transaction(
access_mode, TelemetryAPI.DRIVER,
_work, (query_, parameters, result_transformer_), {}
work, (query_str, parameters, result_transformer_), {}
)

@property
Expand Down Expand Up @@ -1194,7 +1214,7 @@ def _get_server_info(self, session_config) -> ServerInfo:

def _work(
tx: ManagedTransaction,
query: str,
query: te.LiteralString,
parameters: t.Dict[str, t.Any],
transformer: t.Callable[[Result], t.Union[_T]]
) -> _T:
Expand Down
13 changes: 10 additions & 3 deletions src/neo4j/_work/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ class Query:
"""A query with attached extra data.
This wrapper class for queries is used to attach extra data to queries
passed to :meth:`.Session.run` and :meth:`.AsyncSession.run`, fulfilling
a similar role as :func:`.unit_of_work` for transactions functions.
passed to :meth:`.Session.run`/:meth:`.AsyncSession.run` and
:meth:`.Driver.execute_query`/:meth:`.AsyncDriver.execute_query`,
fulfilling a similar role as :func:`.unit_of_work` for transactions
functions.
:param text: The query text.
:type text: typing.LiteralString
Expand Down Expand Up @@ -74,7 +76,12 @@ def __init__(
self.timeout = timeout

def __str__(self) -> te.LiteralString:
return str(self.text)
# we know that if Query is constructed with a LiteralString,
# str(self.text) will be a LiteralString as well. The conversion isn't
# necessary if the user adheres to the type hints. However, it was
# here before, and we don't want to break backwards compatibility.
text: te.LiteralString = str(self.text) # type: ignore[assignment]
return text


def unit_of_work(
Expand Down
7 changes: 6 additions & 1 deletion testkitbackend/_async/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,11 @@ async def ExecuteQuery(backend, data):
value = config.get(config_key, None)
if value is not None:
kwargs[kwargs_key] = value
tx_kwargs = fromtestkit.to_tx_kwargs(config)
if tx_kwargs:
query = neo4j.Query(cypher, **tx_kwargs)
else:
query = cypher
bookmark_manager_id = config.get("bookmarkManagerId")
if bookmark_manager_id is not None:
if bookmark_manager_id == -1:
Expand All @@ -371,7 +376,7 @@ async def ExecuteQuery(backend, data):
bookmark_manager = backend.bookmark_managers[bookmark_manager_id]
kwargs["bookmark_manager_"] = bookmark_manager

eager_result = await driver.execute_query(cypher, params, **kwargs)
eager_result = await driver.execute_query(query, params, **kwargs)
await backend.send_response("EagerResult", {
"keys": eager_result.keys,
"records": list(map(totestkit.record, eager_result.records)),
Expand Down
7 changes: 6 additions & 1 deletion testkitbackend/_sync/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,11 @@ def ExecuteQuery(backend, data):
value = config.get(config_key, None)
if value is not None:
kwargs[kwargs_key] = value
tx_kwargs = fromtestkit.to_tx_kwargs(config)
if tx_kwargs:
query = neo4j.Query(cypher, **tx_kwargs)
else:
query = cypher
bookmark_manager_id = config.get("bookmarkManagerId")
if bookmark_manager_id is not None:
if bookmark_manager_id == -1:
Expand All @@ -371,7 +376,7 @@ def ExecuteQuery(backend, data):
bookmark_manager = backend.bookmark_managers[bookmark_manager_id]
kwargs["bookmark_manager_"] = bookmark_manager

eager_result = driver.execute_query(cypher, params, **kwargs)
eager_result = driver.execute_query(query, params, **kwargs)
backend.send_response("EagerResult", {
"keys": eager_result.keys,
"records": list(map(totestkit.record, eager_result.records)),
Expand Down
Loading

0 comments on commit 656c796

Please sign in to comment.