From 656c79648c8e91e2afb8ccdc824e8b498e92a2d7 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Mon, 27 Nov 2023 12:54:39 +0100 Subject: [PATCH] Accept transaction config for execute_query (#991) `Driver.execute_query` now accepts a `Query` object to specify transaction config like metadata and transaction timeout. Example: ```python from neo4j import ( GraphDatabase, Query, ) with GraphDatabase.driver(...) as driver: driver.execute_query( Query( "MATCH (n) RETURN n", # metadata to be logged with the transaction metadata={"foo": "bar"}, # give the transaction 5 seconds to complete on the DBMS timeout=5, ), # all the other configuration options as before database_="neo4j", # ... ) ``` --- docs/source/api.rst | 18 +++++++++---- docs/source/async_api.rst | 18 +++++++++---- src/neo4j/_async/driver.py | 42 +++++++++++++++++++++++-------- src/neo4j/_sync/driver.py | 42 +++++++++++++++++++++++-------- src/neo4j/_work/query.py | 13 +++++++--- testkitbackend/_async/requests.py | 7 +++++- testkitbackend/_sync/requests.py | 7 +++++- tests/unit/async_/test_driver.py | 40 ++++++++++++++++++++++++----- tests/unit/sync/test_driver.py | 40 ++++++++++++++++++++++++----- 9 files changed, 178 insertions(+), 49 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 960f25e71..7a30f6004 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -187,8 +187,9 @@ Closing a driver will immediately shut down all connections in the pool. query_, parameters_, routing_, database_, impersonated_user_, bookmark_manager_, auth_, result_transformer_, **kwargs ): + @unit_of_work(query_.metadata, query_.timeout) def work(tx): - result = tx.run(query_, parameters_, **kwargs) + result = tx.run(query_.text, parameters_, **kwargs) return result_transformer_(result) with driver.session( @@ -245,16 +246,19 @@ Closing a driver will immediately shut down all connections in the pool. assert isinstance(count, int) return count - :param query_: cypher query to execute - :type query_: typing.LiteralString + :param query_: + Cypher query to execute. + Use a :class:`.Query` object to pass a query with additional + transaction configuration. + :type query_: typing.LiteralString | Query :param parameters_: parameters to use in the query :type parameters_: typing.Dict[str, typing.Any] | None :param routing_: - whether to route the query to a reader (follower/read replica) or + Whether to route the query to a reader (follower/read replica) or a writer (leader) in the cluster. Default is to route to a writer. :type routing_: RoutingControl :param database_: - database to execute the query against. + Database to execute the query against. None (default) uses the database configured on the server side. @@ -375,6 +379,10 @@ Closing a driver will immediately shut down all connections in the pool. .. versionchanged:: 5.14 Stabilized ``auth_`` parameter from preview. + .. versionchanged:: 5.15 + The ``query_`` parameter now also accepts a :class:`.Query` object + instead of only :class:`str`. + .. _driver-configuration-ref: diff --git a/docs/source/async_api.rst b/docs/source/async_api.rst index 5518628e8..a346b538a 100644 --- a/docs/source/async_api.rst +++ b/docs/source/async_api.rst @@ -174,8 +174,9 @@ Closing a driver will immediately shut down all connections in the pool. query_, parameters_, routing_, database_, impersonated_user_, bookmark_manager_, auth_, result_transformer_, **kwargs ): + @unit_of_work(query_.metadata, query_.timeout) async def work(tx): - result = await tx.run(query_, parameters_, **kwargs) + result = await tx.run(query_.text, parameters_, **kwargs) return await result_transformer_(result) async with driver.session( @@ -232,16 +233,19 @@ Closing a driver will immediately shut down all connections in the pool. assert isinstance(count, int) return count - :param query_: cypher query to execute - :type query_: typing.LiteralString + :param query_: + Cypher query to execute. + Use a :class:`.Query` object to pass a query with additional + transaction configuration. + :type query_: typing.LiteralString | Query :param parameters_: parameters to use in the query :type parameters_: typing.Dict[str, typing.Any] | None :param routing_: - whether to route the query to a reader (follower/read replica) or + Whether to route the query to a reader (follower/read replica) or a writer (leader) in the cluster. Default is to route to a writer. :type routing_: RoutingControl :param database_: - database to execute the query against. + Database to execute the query against. None (default) uses the database configured on the server side. @@ -362,6 +366,10 @@ Closing a driver will immediately shut down all connections in the pool. .. versionchanged:: 5.14 Stabilized ``auth_`` parameter from preview. + .. versionchanged:: 5.15 + The ``query_`` parameter now also accepts a :class:`.Query` object + instead of only :class:`str`. + .. _async-driver-configuration-ref: diff --git a/src/neo4j/_async/driver.py b/src/neo4j/_async/driver.py index f0229ed39..8bbbb9d2f 100644 --- a/src/neo4j/_async/driver.py +++ b/src/neo4j/_async/driver.py @@ -47,7 +47,11 @@ experimental_warn, unclosed_resource_warn, ) -from .._work import EagerResult +from .._work import ( + EagerResult, + Query, + unit_of_work, +) from ..addressing import Address from ..api import ( AsyncBookmarkManager, @@ -581,7 +585,7 @@ async def close(self) -> None: @t.overload async def execute_query( self, - query_: te.LiteralString, + query_: t.Union[te.LiteralString, Query], parameters_: t.Optional[t.Dict[str, t.Any]] = None, routing_: T_RoutingControl = RoutingControl.WRITE, database_: t.Optional[str] = None, @@ -600,7 +604,7 @@ async def execute_query( @t.overload async def execute_query( self, - query_: te.LiteralString, + query_: t.Union[te.LiteralString, Query], parameters_: t.Optional[t.Dict[str, t.Any]] = None, routing_: T_RoutingControl = RoutingControl.WRITE, database_: t.Optional[str] = None, @@ -618,7 +622,7 @@ async def execute_query( async def execute_query( self, - query_: te.LiteralString, + query_: t.Union[te.LiteralString, Query], parameters_: t.Optional[t.Dict[str, t.Any]] = None, routing_: T_RoutingControl = RoutingControl.WRITE, database_: t.Optional[str] = None, @@ -651,8 +655,9 @@ async def execute_query( query_, parameters_, routing_, database_, impersonated_user_, bookmark_manager_, auth_, result_transformer_, **kwargs ): + @unit_of_work(query_.metadata, query_.timeout) async def work(tx): - result = await tx.run(query_, parameters_, **kwargs) + result = await tx.run(query_.text, parameters_, **kwargs) return await result_transformer_(result) async with driver.session( @@ -709,16 +714,19 @@ async def example(driver: neo4j.AsyncDriver) -> int: assert isinstance(count, int) return count - :param query_: cypher query to execute - :type query_: typing.LiteralString + :param query_: + Cypher query to execute. + Use a :class:`.Query` object to pass a query with additional + transaction configuration. + :type query_: typing.LiteralString | Query :param parameters_: parameters to use in the query :type parameters_: typing.Optional[typing.Dict[str, typing.Any]] :param routing_: - whether to route the query to a reader (follower/read replica) or + Whether to route the query to a reader (follower/read replica) or a writer (leader) in the cluster. Default is to route to a writer. :type routing_: RoutingControl :param database_: - database to execute the query against. + Database to execute the query against. None (default) uses the database configured on the server side. @@ -838,6 +846,10 @@ async def example(driver: neo4j.AsyncDriver) -> neo4j.Record:: .. versionchanged:: 5.14 Stabilized ``auth_`` parameter from preview. + + .. versionchanged:: 5.15 + The ``query_`` parameter now also accepts a :class:`.Query` object + instead of only :class:`str`. """ self._check_state() invalid_kwargs = [k for k in kwargs if @@ -850,6 +862,14 @@ async def example(driver: neo4j.AsyncDriver) -> neo4j.Record:: "latter case, use the `parameters_` dictionary instead." % invalid_kwargs ) + if isinstance(query_, Query): + timeout = query_.timeout + metadata = query_.metadata + query_str = query_.text + work = unit_of_work(metadata, timeout)(_work) + else: + query_str = query_ + work = _work parameters = dict(parameters_ or {}, **kwargs) if bookmark_manager_ is _default: @@ -876,7 +896,7 @@ async def example(driver: neo4j.AsyncDriver) -> neo4j.Record:: with session._pipelined_begin: return await session._run_transaction( access_mode, TelemetryAPI.DRIVER, - _work, (query_, parameters, result_transformer_), {} + work, (query_str, parameters, result_transformer_), {} ) @property @@ -1195,7 +1215,7 @@ async def _get_server_info(self, session_config) -> ServerInfo: async def _work( tx: AsyncManagedTransaction, - query: str, + query: te.LiteralString, parameters: t.Dict[str, t.Any], transformer: t.Callable[[AsyncResult], t.Awaitable[_T]] ) -> _T: diff --git a/src/neo4j/_sync/driver.py b/src/neo4j/_sync/driver.py index dbe92fcd2..0bed94284 100644 --- a/src/neo4j/_sync/driver.py +++ b/src/neo4j/_sync/driver.py @@ -47,7 +47,11 @@ experimental_warn, unclosed_resource_warn, ) -from .._work import EagerResult +from .._work import ( + EagerResult, + Query, + unit_of_work, +) from ..addressing import Address from ..api import ( Auth, @@ -580,7 +584,7 @@ def close(self) -> None: @t.overload def execute_query( self, - query_: te.LiteralString, + query_: t.Union[te.LiteralString, Query], parameters_: t.Optional[t.Dict[str, t.Any]] = None, routing_: T_RoutingControl = RoutingControl.WRITE, database_: t.Optional[str] = None, @@ -599,7 +603,7 @@ def execute_query( @t.overload def execute_query( self, - query_: te.LiteralString, + query_: t.Union[te.LiteralString, Query], parameters_: t.Optional[t.Dict[str, t.Any]] = None, routing_: T_RoutingControl = RoutingControl.WRITE, database_: t.Optional[str] = None, @@ -617,7 +621,7 @@ def execute_query( def execute_query( self, - query_: te.LiteralString, + query_: t.Union[te.LiteralString, Query], parameters_: t.Optional[t.Dict[str, t.Any]] = None, routing_: T_RoutingControl = RoutingControl.WRITE, database_: t.Optional[str] = None, @@ -650,8 +654,9 @@ def execute_query( query_, parameters_, routing_, database_, impersonated_user_, bookmark_manager_, auth_, result_transformer_, **kwargs ): + @unit_of_work(query_.metadata, query_.timeout) def work(tx): - result = tx.run(query_, parameters_, **kwargs) + result = tx.run(query_.text, parameters_, **kwargs) return result_transformer_(result) with driver.session( @@ -708,16 +713,19 @@ def example(driver: neo4j.Driver) -> int: assert isinstance(count, int) return count - :param query_: cypher query to execute - :type query_: typing.LiteralString + :param query_: + Cypher query to execute. + Use a :class:`.Query` object to pass a query with additional + transaction configuration. + :type query_: typing.LiteralString | Query :param parameters_: parameters to use in the query :type parameters_: typing.Optional[typing.Dict[str, typing.Any]] :param routing_: - whether to route the query to a reader (follower/read replica) or + Whether to route the query to a reader (follower/read replica) or a writer (leader) in the cluster. Default is to route to a writer. :type routing_: RoutingControl :param database_: - database to execute the query against. + Database to execute the query against. None (default) uses the database configured on the server side. @@ -837,6 +845,10 @@ def example(driver: neo4j.Driver) -> neo4j.Record:: .. versionchanged:: 5.14 Stabilized ``auth_`` parameter from preview. + + .. versionchanged:: 5.15 + The ``query_`` parameter now also accepts a :class:`.Query` object + instead of only :class:`str`. """ self._check_state() invalid_kwargs = [k for k in kwargs if @@ -849,6 +861,14 @@ def example(driver: neo4j.Driver) -> neo4j.Record:: "latter case, use the `parameters_` dictionary instead." % invalid_kwargs ) + if isinstance(query_, Query): + timeout = query_.timeout + metadata = query_.metadata + query_str = query_.text + work = unit_of_work(metadata, timeout)(_work) + else: + query_str = query_ + work = _work parameters = dict(parameters_ or {}, **kwargs) if bookmark_manager_ is _default: @@ -875,7 +895,7 @@ def example(driver: neo4j.Driver) -> neo4j.Record:: with session._pipelined_begin: return session._run_transaction( access_mode, TelemetryAPI.DRIVER, - _work, (query_, parameters, result_transformer_), {} + work, (query_str, parameters, result_transformer_), {} ) @property @@ -1194,7 +1214,7 @@ def _get_server_info(self, session_config) -> ServerInfo: def _work( tx: ManagedTransaction, - query: str, + query: te.LiteralString, parameters: t.Dict[str, t.Any], transformer: t.Callable[[Result], t.Union[_T]] ) -> _T: diff --git a/src/neo4j/_work/query.py b/src/neo4j/_work/query.py index 216743273..22e50b99f 100644 --- a/src/neo4j/_work/query.py +++ b/src/neo4j/_work/query.py @@ -29,8 +29,10 @@ class Query: """A query with attached extra data. This wrapper class for queries is used to attach extra data to queries - passed to :meth:`.Session.run` and :meth:`.AsyncSession.run`, fulfilling - a similar role as :func:`.unit_of_work` for transactions functions. + passed to :meth:`.Session.run`/:meth:`.AsyncSession.run` and + :meth:`.Driver.execute_query`/:meth:`.AsyncDriver.execute_query`, + fulfilling a similar role as :func:`.unit_of_work` for transactions + functions. :param text: The query text. :type text: typing.LiteralString @@ -74,7 +76,12 @@ def __init__( self.timeout = timeout def __str__(self) -> te.LiteralString: - return str(self.text) + # we know that if Query is constructed with a LiteralString, + # str(self.text) will be a LiteralString as well. The conversion isn't + # necessary if the user adheres to the type hints. However, it was + # here before, and we don't want to break backwards compatibility. + text: te.LiteralString = str(self.text) # type: ignore[assignment] + return text def unit_of_work( diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index 869812235..be95a9ed7 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -363,6 +363,11 @@ async def ExecuteQuery(backend, data): value = config.get(config_key, None) if value is not None: kwargs[kwargs_key] = value + tx_kwargs = fromtestkit.to_tx_kwargs(config) + if tx_kwargs: + query = neo4j.Query(cypher, **tx_kwargs) + else: + query = cypher bookmark_manager_id = config.get("bookmarkManagerId") if bookmark_manager_id is not None: if bookmark_manager_id == -1: @@ -371,7 +376,7 @@ async def ExecuteQuery(backend, data): bookmark_manager = backend.bookmark_managers[bookmark_manager_id] kwargs["bookmark_manager_"] = bookmark_manager - eager_result = await driver.execute_query(cypher, params, **kwargs) + eager_result = await driver.execute_query(query, params, **kwargs) await backend.send_response("EagerResult", { "keys": eager_result.keys, "records": list(map(totestkit.record, eager_result.records)), diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index 82ec61c39..b85e1c8a0 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -363,6 +363,11 @@ def ExecuteQuery(backend, data): value = config.get(config_key, None) if value is not None: kwargs[kwargs_key] = value + tx_kwargs = fromtestkit.to_tx_kwargs(config) + if tx_kwargs: + query = neo4j.Query(cypher, **tx_kwargs) + else: + query = cypher bookmark_manager_id = config.get("bookmarkManagerId") if bookmark_manager_id is not None: if bookmark_manager_id == -1: @@ -371,7 +376,7 @@ def ExecuteQuery(backend, data): bookmark_manager = backend.bookmark_managers[bookmark_manager_id] kwargs["bookmark_manager_"] = bookmark_manager - eager_result = driver.execute_query(cypher, params, **kwargs) + eager_result = driver.execute_query(query, params, **kwargs) backend.send_response("EagerResult", { "keys": eager_result.keys, "records": list(map(totestkit.record, eager_result.records)), diff --git a/tests/unit/async_/test_driver.py b/tests/unit/async_/test_driver.py index d83645d9c..e5418d46b 100644 --- a/tests/unit/async_/test_driver.py +++ b/tests/unit/async_/test_driver.py @@ -33,6 +33,7 @@ ExperimentalWarning, NotificationDisabledCategory, NotificationMinimumSeverity, + Query, TRUST_ALL_CERTIFICATES, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, TrustAll, @@ -72,6 +73,13 @@ def session_cls_mock(mocker): yield session_cls_mock +@pytest.fixture +def unit_of_work_mock(mocker): + unit_of_work_mock = mocker.patch("neo4j._async.driver.unit_of_work", + autospec=True) + yield unit_of_work_mock + + @pytest.mark.parametrize("protocol", ("bolt://", "bolt+s://", "bolt+ssc://")) @pytest.mark.parametrize("host", ("localhost", "127.0.0.1", "[::1]", "[0:0:0:0:0:0:0:1]")) @@ -625,11 +633,20 @@ async def test_execute_query_work(mocker) -> None: assert res is transformer_mock.return_value -@pytest.mark.parametrize("query", ("foo", "bar", "RETURN 1 AS n")) +@pytest.mark.parametrize("query", ( + "foo", + "bar", + "RETURN 1 AS n", + Query("RETURN 1 AS n"), + Query("RETURN 1 AS n", metadata={"key": "value"}), + Query("RETURN 1 AS n", timeout=1234), + Query("RETURN 1 AS n", metadata={"key": "value"}, timeout=1234), +)) @pytest.mark.parametrize("positional", (True, False)) @mark_async_test async def test_execute_query_query( - query: str, positional: bool, session_cls_mock, mocker + query: te.LiteralString | Query, positional: bool, session_cls_mock, + unit_of_work_mock, mocker ) -> None: driver = AsyncGraphDatabase.driver("bolt://localhost") @@ -644,10 +661,21 @@ async def test_execute_query_query( session_mock.__aenter__.assert_awaited_once() session_mock.__aexit__.assert_awaited_once() session_executor_mock = session_mock._run_transaction - session_executor_mock.assert_awaited_once_with( - WRITE_ACCESS, TelemetryAPI.DRIVER, _work, - (query, mocker.ANY, mocker.ANY), {} - ) + if isinstance(query, Query): + unit_of_work_mock.assert_called_once_with(query.metadata, + query.timeout) + unit_of_work = unit_of_work_mock.return_value + unit_of_work.assert_called_once_with(_work) + session_executor_mock.assert_awaited_once_with( + WRITE_ACCESS, TelemetryAPI.DRIVER, unit_of_work.return_value, + (query.text, mocker.ANY, mocker.ANY), {} + ) + else: + unit_of_work_mock.assert_not_called() + session_executor_mock.assert_awaited_once_with( + WRITE_ACCESS, TelemetryAPI.DRIVER, _work, + (query, mocker.ANY, mocker.ANY), {} + ) assert res is session_executor_mock.return_value diff --git a/tests/unit/sync/test_driver.py b/tests/unit/sync/test_driver.py index fe833b6df..2d9767d58 100644 --- a/tests/unit/sync/test_driver.py +++ b/tests/unit/sync/test_driver.py @@ -32,6 +32,7 @@ Neo4jDriver, NotificationDisabledCategory, NotificationMinimumSeverity, + Query, Result, TRUST_ALL_CERTIFICATES, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, @@ -71,6 +72,13 @@ def session_cls_mock(mocker): yield session_cls_mock +@pytest.fixture +def unit_of_work_mock(mocker): + unit_of_work_mock = mocker.patch("neo4j._sync.driver.unit_of_work", + autospec=True) + yield unit_of_work_mock + + @pytest.mark.parametrize("protocol", ("bolt://", "bolt+s://", "bolt+ssc://")) @pytest.mark.parametrize("host", ("localhost", "127.0.0.1", "[::1]", "[0:0:0:0:0:0:0:1]")) @@ -624,11 +632,20 @@ def test_execute_query_work(mocker) -> None: assert res is transformer_mock.return_value -@pytest.mark.parametrize("query", ("foo", "bar", "RETURN 1 AS n")) +@pytest.mark.parametrize("query", ( + "foo", + "bar", + "RETURN 1 AS n", + Query("RETURN 1 AS n"), + Query("RETURN 1 AS n", metadata={"key": "value"}), + Query("RETURN 1 AS n", timeout=1234), + Query("RETURN 1 AS n", metadata={"key": "value"}, timeout=1234), +)) @pytest.mark.parametrize("positional", (True, False)) @mark_sync_test def test_execute_query_query( - query: str, positional: bool, session_cls_mock, mocker + query: te.LiteralString | Query, positional: bool, session_cls_mock, + unit_of_work_mock, mocker ) -> None: driver = GraphDatabase.driver("bolt://localhost") @@ -643,10 +660,21 @@ def test_execute_query_query( session_mock.__enter__.assert_called_once() session_mock.__exit__.assert_called_once() session_executor_mock = session_mock._run_transaction - session_executor_mock.assert_called_once_with( - WRITE_ACCESS, TelemetryAPI.DRIVER, _work, - (query, mocker.ANY, mocker.ANY), {} - ) + if isinstance(query, Query): + unit_of_work_mock.assert_called_once_with(query.metadata, + query.timeout) + unit_of_work = unit_of_work_mock.return_value + unit_of_work.assert_called_once_with(_work) + session_executor_mock.assert_called_once_with( + WRITE_ACCESS, TelemetryAPI.DRIVER, unit_of_work.return_value, + (query.text, mocker.ANY, mocker.ANY), {} + ) + else: + unit_of_work_mock.assert_not_called() + session_executor_mock.assert_called_once_with( + WRITE_ACCESS, TelemetryAPI.DRIVER, _work, + (query, mocker.ANY, mocker.ANY), {} + ) assert res is session_executor_mock.return_value