diff --git a/CHANGELOG.md b/CHANGELOG.md index c332cda6d8..4eee938181 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -66,6 +66,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#563](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/563)) - `opentelemetry-exporter-datadog` Datadog exporter should not use `unknown_service` as fallback resource service name. ([#570](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/570)) +- Add support for the async extension of SQLAlchemy (>= 1.4) + ([#568](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/568)) ### Added - `opentelemetry-instrumentation-httpx` Add `httpx` instrumentation diff --git a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py index 01db312b3f..05e6451626 100644 --- a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py @@ -36,18 +36,32 @@ engine=engine, ) + # of the async variant of SQLAlchemy + + from sqlalchemy.ext.asyncio import create_async_engine + + from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor + import sqlalchemy + + engine = create_async_engine("sqlite:///:memory:") + SQLAlchemyInstrumentor().instrument( + engine=engine.sync_engine + ) + API --- """ from typing import Collection import sqlalchemy +from packaging.version import parse as parse_version from wrapt import wrap_function_wrapper as _w from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.instrumentation.sqlalchemy.engine import ( EngineTracer, _get_tracer, + _wrap_create_async_engine, _wrap_create_engine, ) from opentelemetry.instrumentation.sqlalchemy.package import _instruments @@ -76,6 +90,13 @@ def _instrument(self, **kwargs): """ _w("sqlalchemy", "create_engine", _wrap_create_engine) _w("sqlalchemy.engine", "create_engine", _wrap_create_engine) + if parse_version(sqlalchemy.__version__).release >= (1, 4): + _w( + "sqlalchemy.ext.asyncio", + "create_async_engine", + _wrap_create_async_engine, + ) + if kwargs.get("engine") is not None: return EngineTracer( _get_tracer( @@ -88,3 +109,5 @@ def _instrument(self, **kwargs): def _uninstrument(self, **kwargs): unwrap(sqlalchemy, "create_engine") unwrap(sqlalchemy.engine, "create_engine") + if parse_version(sqlalchemy.__version__).release >= (1, 4): + unwrap(sqlalchemy.ext.asyncio, "create_async_engine") diff --git a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py index e69c6dbcb4..ed1dfb1976 100644 --- a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py +++ b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from threading import local - from sqlalchemy.event import listen # pylint: disable=no-name-in-module from opentelemetry import trace @@ -44,6 +42,16 @@ def _get_tracer(engine, tracer_provider=None): ) +# pylint: disable=unused-argument +def _wrap_create_async_engine(func, module, args, kwargs): + """Trace the SQLAlchemy engine, creating an `EngineTracer` + object that will listen to SQLAlchemy events. + """ + engine = func(*args, **kwargs) + EngineTracer(_get_tracer(engine), engine.sync_engine) + return engine + + # pylint: disable=unused-argument def _wrap_create_engine(func, module, args, kwargs): """Trace the SQLAlchemy engine, creating an `EngineTracer` @@ -59,20 +67,10 @@ def __init__(self, tracer, engine): self.tracer = tracer self.engine = engine self.vendor = _normalize_vendor(engine.name) - self.cursor_mapping = {} - self.local = local() listen(engine, "before_cursor_execute", self._before_cur_exec) - listen(engine, "after_cursor_execute", self._after_cur_exec) - listen(engine, "handle_error", self._handle_error) - - @property - def current_thread_span(self): - return getattr(self.local, "current_span", None) - - @current_thread_span.setter - def current_thread_span(self, span): - setattr(self.local, "current_span", span) + listen(engine, "after_cursor_execute", _after_cur_exec) + listen(engine, "handle_error", _handle_error) def _operation_name(self, db_name, statement): parts = [] @@ -90,7 +88,9 @@ def _operation_name(self, db_name, statement): return " ".join(parts) # pylint: disable=unused-argument - def _before_cur_exec(self, conn, cursor, statement, *args): + def _before_cur_exec( + self, conn, cursor, statement, params, context, executemany + ): attrs, found = _get_attributes_from_url(conn.engine.url) if not found: attrs = _get_attributes_from_cursor(self.vendor, cursor, attrs) @@ -100,7 +100,6 @@ def _before_cur_exec(self, conn, cursor, statement, *args): self._operation_name(db_name, statement), kind=trace.SpanKind.CLIENT, ) - self.current_thread_span = self.cursor_mapping[cursor] = span with trace.use_span(span, end_on_exit=False): if span.is_recording(): span.set_attribute(SpanAttributes.DB_STATEMENT, statement) @@ -108,34 +107,28 @@ def _before_cur_exec(self, conn, cursor, statement, *args): for key, value in attrs.items(): span.set_attribute(key, value) - # pylint: disable=unused-argument - def _after_cur_exec(self, conn, cursor, statement, *args): - span = self.cursor_mapping.get(cursor, None) - if span is None: - return + context._otel_span = span - span.end() - self._cleanup(cursor) - def _handle_error(self, context): - span = self.current_thread_span - if span is None: - return +# pylint: disable=unused-argument +def _after_cur_exec(conn, cursor, statement, params, context, executemany): + span = getattr(context, "_otel_span", None) + if span is None: + return - try: - if span.is_recording(): - span.set_status( - Status(StatusCode.ERROR, str(context.original_exception),) - ) - finally: - span.end() - self._cleanup(context.cursor) - - def _cleanup(self, cursor): - try: - del self.cursor_mapping[cursor] - except KeyError: - pass + span.end() + + +def _handle_error(context): + span = getattr(context.execution_context, "_otel_span", None) + if span is None: + return + + if span.is_recording(): + span.set_status( + Status(StatusCode.ERROR, str(context.original_exception),) + ) + span.end() def _get_attributes_from_url(url): diff --git a/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py b/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py index 4a633687e6..bed2b5f312 100644 --- a/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py +++ b/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py @@ -11,8 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio from unittest import mock +import pytest +import sqlalchemy from sqlalchemy import create_engine from opentelemetry import trace @@ -38,6 +41,29 @@ def test_trace_integration(self): self.assertEqual(spans[0].name, "SELECT :memory:") self.assertEqual(spans[0].kind, trace.SpanKind.CLIENT) + @pytest.mark.skipif( + not sqlalchemy.__version__.startswith("1.4"), + reason="only run async tests for 1.4", + ) + def test_async_trace_integration(self): + async def run(): + from sqlalchemy.ext.asyncio import ( # pylint: disable-all + create_async_engine, + ) + + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + SQLAlchemyInstrumentor().instrument( + engine=engine.sync_engine, tracer_provider=self.tracer_provider + ) + async with engine.connect() as cnx: + await cnx.execute(sqlalchemy.text("SELECT 1 + 1;")) + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + self.assertEqual(spans[0].name, "SELECT :memory:") + self.assertEqual(spans[0].kind, trace.SpanKind.CLIENT) + + asyncio.get_event_loop().run_until_complete(run()) + def test_not_recording(self): mock_tracer = mock.Mock() mock_span = mock.Mock() @@ -68,3 +94,24 @@ def test_create_engine_wrapper(self): self.assertEqual(len(spans), 1) self.assertEqual(spans[0].name, "SELECT :memory:") self.assertEqual(spans[0].kind, trace.SpanKind.CLIENT) + + @pytest.mark.skipif( + not sqlalchemy.__version__.startswith("1.4"), + reason="only run async tests for 1.4", + ) + def test_create_async_engine_wrapper(self): + async def run(): + SQLAlchemyInstrumentor().instrument() + from sqlalchemy.ext.asyncio import ( # pylint: disable-all + create_async_engine, + ) + + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + async with engine.connect() as cnx: + await cnx.execute(sqlalchemy.text("SELECT 1 + 1;")) + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + self.assertEqual(spans[0].name, "SELECT :memory:") + self.assertEqual(spans[0].kind, trace.SpanKind.CLIENT) + + asyncio.get_event_loop().run_until_complete(run()) diff --git a/tests/opentelemetry-docker-tests/tests/sqlalchemy_tests/mixins.py b/tests/opentelemetry-docker-tests/tests/sqlalchemy_tests/mixins.py index c2e5548ab1..b9c766ad1c 100644 --- a/tests/opentelemetry-docker-tests/tests/sqlalchemy_tests/mixins.py +++ b/tests/opentelemetry-docker-tests/tests/sqlalchemy_tests/mixins.py @@ -15,6 +15,7 @@ import contextlib import logging import threading +import unittest from sqlalchemy import Column, Integer, String, create_engine, insert from sqlalchemy.ext.declarative import declarative_base @@ -242,4 +243,10 @@ def insert_players(session): close_all_sessions() spans = self.memory_exporter.get_finished_spans() - self.assertEqual(len(spans), 5) + + # SQLAlchemy 1.4 uses the `execute_values` extension of the psycopg2 dialect to + # batch inserts together which means `insert_players` only generates one span. + # See https://docs.sqlalchemy.org/en/14/changelog/migration_14.html#orm-batch-inserts-with-psycopg2-now-batch-statements-with-returning-in-most-cases + self.assertEqual( + len(spans), 5 if self.VENDOR not in ["postgresql"] else 3 + ) diff --git a/tox.ini b/tox.ini index 2d58b8928e..e7a5445043 100644 --- a/tox.ini +++ b/tox.ini @@ -122,8 +122,8 @@ envlist = py3{6,7,8,9}-test-instrumentation-grpc ; opentelemetry-instrumentation-sqlalchemy - py3{6,7,8,9}-test-instrumentation-sqlalchemy - pypy3-test-instrumentation-sqlalchemy + py3{6,7,8,9}-test-instrumentation-sqlalchemy{11,14} + pypy3-test-instrumentation-sqlalchemy{11,14} ; opentelemetry-instrumentation-redis py3{6,7,8,9}-test-instrumentation-redis @@ -173,6 +173,9 @@ deps = elasticsearch6: elasticsearch>=6.0,<7.0 elasticsearch7: elasticsearch-dsl>=7.0,<8.0 elasticsearch7: elasticsearch>=7.0,<8.0 + sqlalchemy11: sqlalchemy>=1.1,<1.2 + sqlalchemy14: aiosqlite + sqlalchemy14: sqlalchemy~=1.4 ; FIXME: add coverage testing ; FIXME: add mypy testing @@ -205,7 +208,7 @@ changedir = test-instrumentation-redis: instrumentation/opentelemetry-instrumentation-redis/tests test-instrumentation-requests: instrumentation/opentelemetry-instrumentation-requests/tests test-instrumentation-sklearn: instrumentation/opentelemetry-instrumentation-sklearn/tests - test-instrumentation-sqlalchemy: instrumentation/opentelemetry-instrumentation-sqlalchemy/tests + test-instrumentation-sqlalchemy{11,14}: instrumentation/opentelemetry-instrumentation-sqlalchemy/tests test-instrumentation-sqlite3: instrumentation/opentelemetry-instrumentation-sqlite3/tests test-instrumentation-starlette: instrumentation/opentelemetry-instrumentation-starlette/tests test-instrumentation-tornado: instrumentation/opentelemetry-instrumentation-tornado/tests @@ -290,7 +293,7 @@ commands_pre = sklearn: pip install {toxinidir}/instrumentation/opentelemetry-instrumentation-sklearn[test] - sqlalchemy: pip install {toxinidir}/instrumentation/opentelemetry-instrumentation-sqlalchemy[test] + sqlalchemy{11,14}: pip install {toxinidir}/instrumentation/opentelemetry-instrumentation-sqlalchemy[test] elasticsearch{2,5,6,7}: pip install {toxinidir}/opentelemetry-python-core/opentelemetry-instrumentation {toxinidir}/instrumentation/opentelemetry-instrumentation-elasticsearch[test] @@ -329,7 +332,7 @@ commands = [testenv:lint] basepython: python3.9 -recreate = False +recreate = False deps = -c dev-requirements.txt flaky @@ -399,7 +402,7 @@ deps = PyMySQL ~= 0.10.1 psycopg2 ~= 2.8.4 aiopg >= 0.13.0, < 1.3.0 - sqlalchemy ~= 1.3.16 + sqlalchemy ~= 1.4 redis ~= 3.3.11 celery[pytest] >= 4.0, < 6.0 protobuf>=3.13.0