diff --git a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py index 9618ab29d1..a313949953 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py @@ -96,8 +96,7 @@ def response_hook(span, instance, response): import redis from wrapt import wrap_function_wrapper -from opentelemetry import context, trace -from opentelemetry.context import _SUPPRESS_INSTRUMENTATION_KEY +from opentelemetry import trace from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.instrumentation.redis.package import _instruments from opentelemetry.instrumentation.redis.util import ( @@ -105,7 +104,7 @@ def response_hook(span, instance, response): _format_command_args, ) from opentelemetry.instrumentation.redis.version import __version__ -from opentelemetry.instrumentation.utils import unwrap +from opentelemetry.instrumentation.utils import unwrap, is_instrumentation_enabled from opentelemetry.semconv.trace import SpanAttributes from opentelemetry.trace import Span @@ -180,14 +179,12 @@ def _instrument( response_hook: _ResponseHookT = None, ): def _traced_execute_command(func, instance, args, kwargs): - query = _format_command_args(args) - name = _build_span_name(instance, args) - if context.get_value( - _SUPPRESS_INSTRUMENTATION_KEY - ): + if not is_instrumentation_enabled(): return func(*args, **kwargs) + query = _format_command_args(args) + name = _build_span_name(instance, args) with tracer.start_as_current_span( name, kind=trace.SpanKind.CLIENT ) as span: @@ -204,9 +201,7 @@ def _traced_execute_command(func, instance, args, kwargs): def _traced_execute_pipeline(func, instance, args, kwargs): - if context.get_value( - _SUPPRESS_INSTRUMENTATION_KEY - ): + if not is_instrumentation_enabled(): return func(*args, **kwargs) ( @@ -260,6 +255,10 @@ def _traced_execute_pipeline(func, instance, args, kwargs): ) async def _async_traced_execute_command(func, instance, args, kwargs): + + if not is_instrumentation_enabled(): + return func(*args, **kwargs) + query = _format_command_args(args) name = _build_span_name(instance, args) @@ -278,6 +277,10 @@ async def _async_traced_execute_command(func, instance, args, kwargs): return response async def _async_traced_execute_pipeline(func, instance, args, kwargs): + + if not is_instrumentation_enabled(): + return func(*args, **kwargs) + ( command_stack, resource, diff --git a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py index a41d1ea0f8..e8633d30ca 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py +++ b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py @@ -20,7 +20,7 @@ from opentelemetry import context, trace from opentelemetry.instrumentation.redis import RedisInstrumentor -from opentelemetry.instrumentation.utils import _SUPPRESS_INSTRUMENTATION_KEY +from opentelemetry.instrumentation.utils import suppress_instrumentation from opentelemetry.test.test_base import TestBase from opentelemetry.trace import SpanKind @@ -65,12 +65,35 @@ def test_not_recording(self): def test_suppress_instrumentation_no_span(self): redis_client = redis.Redis() - token = context.attach(context.set_value(_SUPPRESS_INSTRUMENTATION_KEY, True)) with mock.patch.object(redis_client, "connection"): - redis_client.ping() - context.detach(token) + redis_client.get("key") + spans = self.memory_exporter.get_finished_spans() + + self.assertEqual(len(spans), 1) + self.memory_exporter.clear() + + with suppress_instrumentation(): + with mock.patch.object(redis_client, "connection"): + redis_client.ping() + spans = self.memory_exporter.get_finished_spans() + + self.assertEqual(len(spans), 0) + + def test_suppress_async_instrumentation_no_span(self): + redis_client = redis.Redis() + + with mock.patch.object(redis_client, "connection", AsyncMock()): + redis_client.get("key") spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + self.memory_exporter.clear() + + with suppress_instrumentation(): + with mock.patch.object(redis_client, "connection", AsyncMock()): + redis_client.ping() + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) def test_instrument_uninstrument(self):