diff --git a/CHANGELOG.md b/CHANGELOG.md index b5031bc905..09c3d5fbd9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#2573](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2573)) - `opentelemetry-instrumentation-confluent-kafka` Add support for version 2.4.0 of confluent_kafka ([#2616](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2616)) +- `opentelemetry-instrumentation-asyncpg` Add instrumentation to cursor based queries + ([#2501](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2501)) - `opentelemetry-instrumentation-confluent-kafka` Add support for produce purge ([#2638](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2638)) - `opentelemetry-instrumentation-httpx` Implement new semantic convention opt-in migration with stable http semantic conventions diff --git a/instrumentation/opentelemetry-instrumentation-asyncpg/src/opentelemetry/instrumentation/asyncpg/__init__.py b/instrumentation/opentelemetry-instrumentation-asyncpg/src/opentelemetry/instrumentation/asyncpg/__init__.py index 798a5dc00b..ba76254aa8 100644 --- a/instrumentation/opentelemetry-instrumentation-asyncpg/src/opentelemetry/instrumentation/asyncpg/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-asyncpg/src/opentelemetry/instrumentation/asyncpg/__init__.py @@ -127,15 +127,27 @@ def _instrument(self, **kwargs): "asyncpg.connection", method, self._do_execute ) - def _uninstrument(self, **__): for method in [ - "execute", - "executemany", - "fetch", - "fetchval", - "fetchrow", + "Cursor.fetch", + "Cursor.forward", + "Cursor.fetchrow", + "CursorIterator.__anext__", ]: - unwrap(asyncpg.Connection, method) + wrapt.wrap_function_wrapper( + "asyncpg.cursor", method, self._do_cursor_execute + ) + + def _uninstrument(self, **__): + for cls, methods in [ + ( + asyncpg.connection.Connection, + ("execute", "executemany", "fetch", "fetchval", "fetchrow"), + ), + (asyncpg.cursor.Cursor, ("forward", "fetch", "fetchrow")), + (asyncpg.cursor.CursorIterator, ("__anext__",)), + ]: + for method_name in methods: + unwrap(cls, method_name) async def _do_execute(self, func, instance, args, kwargs): exception = None @@ -170,3 +182,49 @@ async def _do_execute(self, func, instance, args, kwargs): span.set_status(Status(StatusCode.ERROR)) return result + + async def _do_cursor_execute(self, func, instance, args, kwargs): + """Wrap cursor based functions. For every call this will generate a new span.""" + exception = None + params = getattr(instance._connection, "_params", {}) + name = ( + instance._query + if instance._query + else params.get("database", "postgresql") + ) + + try: + # Strip leading comments so we get the operation name. + name = self._leading_comment_remover.sub("", name).split()[0] + except IndexError: + name = "" + + stop = False + with self._tracer.start_as_current_span( + f"CURSOR: {name}", + kind=SpanKind.CLIENT, + ) as span: + if span.is_recording(): + span_attributes = _hydrate_span_from_args( + instance._connection, + instance._query, + instance._args if self.capture_parameters else None, + ) + for attribute, value in span_attributes.items(): + span.set_attribute(attribute, value) + + try: + result = await func(*args, **kwargs) + except StopAsyncIteration: + # Do not show this exception to the span + stop = True + except Exception as exc: # pylint: disable=W0703 + exception = exc + raise + finally: + if span.is_recording() and exception is not None: + span.set_status(Status(StatusCode.ERROR)) + + if not stop: + return result + raise StopAsyncIteration diff --git a/instrumentation/opentelemetry-instrumentation-asyncpg/tests/test_asyncpg_wrapper.py b/instrumentation/opentelemetry-instrumentation-asyncpg/tests/test_asyncpg_wrapper.py index 12aad0c6dc..7c88b9c005 100644 --- a/instrumentation/opentelemetry-instrumentation-asyncpg/tests/test_asyncpg_wrapper.py +++ b/instrumentation/opentelemetry-instrumentation-asyncpg/tests/test_asyncpg_wrapper.py @@ -1,4 +1,9 @@ -from asyncpg import Connection +import asyncio +from unittest import mock + +import pytest +from asyncpg import Connection, Record, cursor +from wrapt import ObjectProxy from opentelemetry.instrumentation.asyncpg import AsyncPGInstrumentor from opentelemetry.test.test_base import TestBase @@ -34,3 +39,69 @@ def test_duplicated_uninstrumentation(self): self.assertFalse( hasattr(method, "_opentelemetry_ext_asyncpg_applied") ) + + def test_cursor_instrumentation(self): + def assert_wrapped(assert_fnc): + for cls, methods in [ + (cursor.Cursor, ("forward", "fetch", "fetchrow")), + (cursor.CursorIterator, ("__anext__",)), + ]: + for method_name in methods: + method = getattr(cls, method_name, None) + assert_fnc( + isinstance(method, ObjectProxy), + f"{method} isinstance {type(method)}", + ) + + assert_wrapped(self.assertFalse) + AsyncPGInstrumentor().instrument() + assert_wrapped(self.assertTrue) + AsyncPGInstrumentor().uninstrument() + assert_wrapped(self.assertFalse) + + def test_cursor_span_creation(self): + """Test the cursor wrapper if it creates spans correctly.""" + + # Mock out all interaction with postgres + async def bind_mock(*args, **kwargs): + return [] + + async def exec_mock(*args, **kwargs): + return [], None, True + + conn = mock.Mock() + conn.is_closed = lambda: False + + conn._protocol = mock.Mock() + conn._protocol.bind = bind_mock + conn._protocol.execute = exec_mock + conn._protocol.bind_execute = exec_mock + conn._protocol.close_portal = bind_mock + + state = mock.Mock() + state.closed = False + + apg = AsyncPGInstrumentor() + apg.instrument(tracer_provider=self.tracer_provider) + + # init the cursor and fetch a single record + crs = cursor.Cursor(conn, "SELECT * FROM test", state, [], Record) + asyncio.run(crs._init(1)) + asyncio.run(crs.fetch(1)) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + self.assertEqual(spans[0].name, "CURSOR: SELECT") + self.assertTrue(spans[0].status.is_ok) + + # Now test that the StopAsyncIteration of the cursor does not get recorded as an ERROR + crs_iter = cursor.CursorIterator( + conn, "SELECT * FROM test", state, [], Record, 1, 1 + ) + + with pytest.raises(StopAsyncIteration): + asyncio.run(crs_iter.__anext__()) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 2) + self.assertEqual([span.status.is_ok for span in spans], [True, True])