Skip to content

Commit

Permalink
address PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
martimors committed Feb 29, 2024
1 parent a6260a2 commit 1ff4c0b
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,15 @@ def response_hook(span, instance, response):
import redis
from wrapt import wrap_function_wrapper

from opentelemetry import context, trace
from opentelemetry.context import _SUPPRESS_INSTRUMENTATION_KEY
from opentelemetry import trace
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.redis.package import _instruments
from opentelemetry.instrumentation.redis.util import (
_extract_conn_attributes,
_format_command_args,
)
from opentelemetry.instrumentation.redis.version import __version__
from opentelemetry.instrumentation.utils import unwrap
from opentelemetry.instrumentation.utils import unwrap, is_instrumentation_enabled
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import Span

Expand Down Expand Up @@ -180,14 +179,12 @@ def _instrument(
response_hook: _ResponseHookT = None,
):
def _traced_execute_command(func, instance, args, kwargs):
query = _format_command_args(args)
name = _build_span_name(instance, args)

if context.get_value(
_SUPPRESS_INSTRUMENTATION_KEY
):
if not is_instrumentation_enabled():
return func(*args, **kwargs)

query = _format_command_args(args)
name = _build_span_name(instance, args)
with tracer.start_as_current_span(
name, kind=trace.SpanKind.CLIENT
) as span:
Expand All @@ -204,9 +201,7 @@ def _traced_execute_command(func, instance, args, kwargs):

def _traced_execute_pipeline(func, instance, args, kwargs):

if context.get_value(
_SUPPRESS_INSTRUMENTATION_KEY
):
if not is_instrumentation_enabled():
return func(*args, **kwargs)

(
Expand Down Expand Up @@ -260,6 +255,10 @@ def _traced_execute_pipeline(func, instance, args, kwargs):
)

async def _async_traced_execute_command(func, instance, args, kwargs):

if not is_instrumentation_enabled():
return func(*args, **kwargs)

query = _format_command_args(args)
name = _build_span_name(instance, args)

Expand All @@ -278,6 +277,10 @@ async def _async_traced_execute_command(func, instance, args, kwargs):
return response

async def _async_traced_execute_pipeline(func, instance, args, kwargs):

if not is_instrumentation_enabled():
return func(*args, **kwargs)

(
command_stack,
resource,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from opentelemetry import context, trace
from opentelemetry.instrumentation.redis import RedisInstrumentor
from opentelemetry.instrumentation.utils import _SUPPRESS_INSTRUMENTATION_KEY
from opentelemetry.instrumentation.utils import suppress_instrumentation
from opentelemetry.test.test_base import TestBase
from opentelemetry.trace import SpanKind

Expand Down Expand Up @@ -65,12 +65,35 @@ def test_not_recording(self):
def test_suppress_instrumentation_no_span(self):
redis_client = redis.Redis()

token = context.attach(context.set_value(_SUPPRESS_INSTRUMENTATION_KEY, True))
with mock.patch.object(redis_client, "connection"):
redis_client.ping()
context.detach(token)
redis_client.get("key")
spans = self.memory_exporter.get_finished_spans()

self.assertEqual(len(spans), 1)
self.memory_exporter.clear()

with suppress_instrumentation():
with mock.patch.object(redis_client, "connection"):
redis_client.ping()
spans = self.memory_exporter.get_finished_spans()

self.assertEqual(len(spans), 0)

def test_suppress_async_instrumentation_no_span(self):
redis_client = redis.Redis()

with mock.patch.object(redis_client, "connection", AsyncMock()):
redis_client.get("key")
spans = self.memory_exporter.get_finished_spans()

self.assertEqual(len(spans), 1)
self.memory_exporter.clear()

with suppress_instrumentation():
with mock.patch.object(redis_client, "connection", AsyncMock()):
redis_client.ping()
spans = self.memory_exporter.get_finished_spans()

self.assertEqual(len(spans), 0)

def test_instrument_uninstrument(self):
Expand Down

0 comments on commit 1ff4c0b

Please sign in to comment.