diff --git a/CHANGELOG.md b/CHANGELOG.md index b56a453d60..4d44be1acd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `opentelemetry-instrumentation-elasticsearch` Added `response_hook` and `request_hook` callbacks ([#670](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/670)) +### Added +- `opentelemetry-instrumentation-redis` added request_hook and response_hook callbacks passed as arguments to the instrument method. + ([#669](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/669)) + ### Changed - `opentelemetry-instrumentation-botocore` Unpatch botocore Endpoint.prepare_request on uninstrument ([#664](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/664)) diff --git a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py index ea40307e21..29b15e3424 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py @@ -38,11 +38,43 @@ client = redis.StrictRedis(host="localhost", port=6379) client.get("my-key") +The `instrument` method accepts the following keyword args: + +tracer_provider (TracerProvider) - an optional tracer provider + +request_hook (Callable) - a function with extra user-defined logic to be performed before performing the request +this function signature is: def request_hook(span: Span, instance: redis.connection.Connection, args, kwargs) -> None + +response_hook (Callable) - a function with extra user-defined logic to be performed after performing the request +this function signature is: def response_hook(span: Span, instance: redis.connection.Connection, response) -> None + +for example: + +.. code: python + + from opentelemetry.instrumentation.redis import RedisInstrumentor + import redis + + def request_hook(span, instance, args, kwargs): + if span and span.is_recording(): + span.set_attribute("custom_user_attribute_from_request_hook", "some-value") + + def response_hook(span, instance, response): + if span and span.is_recording(): + span.set_attribute("custom_user_attribute_from_response_hook", "some-value") + + # Instrument redis with hooks + RedisInstrumentor().instrument(request_hook=request_hook, response_hook=response_hook) + + # This will report a span with the default settings and the custom attributes added from the hooks + client = redis.StrictRedis(host="localhost", port=6379) + client.get("my-key") + API --- """ - -from typing import Collection +import typing +from typing import Any, Collection import redis from wrapt import wrap_function_wrapper @@ -57,9 +89,19 @@ from opentelemetry.instrumentation.redis.version import __version__ from opentelemetry.instrumentation.utils import unwrap from opentelemetry.semconv.trace import SpanAttributes +from opentelemetry.trace import Span _DEFAULT_SERVICE = "redis" +_RequestHookT = typing.Optional[ + typing.Callable[ + [Span, redis.connection.Connection, typing.List, typing.Dict], None + ] +] +_ResponseHookT = typing.Optional[ + typing.Callable[[Span, redis.connection.Connection, Any], None] +] + def _set_connection_attributes(span, conn): if not span.is_recording(): @@ -70,42 +112,68 @@ def _set_connection_attributes(span, conn): span.set_attribute(key, value) -def _traced_execute_command(func, instance, args, kwargs): - tracer = getattr(redis, "_opentelemetry_tracer") - query = _format_command_args(args) - name = "" - if len(args) > 0 and args[0]: - name = args[0] - else: - name = instance.connection_pool.connection_kwargs.get("db", 0) - with tracer.start_as_current_span( - name, kind=trace.SpanKind.CLIENT - ) as span: - if span.is_recording(): - span.set_attribute(SpanAttributes.DB_STATEMENT, query) - _set_connection_attributes(span, instance) - span.set_attribute("db.redis.args_length", len(args)) - return func(*args, **kwargs) - - -def _traced_execute_pipeline(func, instance, args, kwargs): - tracer = getattr(redis, "_opentelemetry_tracer") - - cmds = [_format_command_args(c) for c, _ in instance.command_stack] - resource = "\n".join(cmds) - - span_name = " ".join([args[0] for args, _ in instance.command_stack]) - - with tracer.start_as_current_span( - span_name, kind=trace.SpanKind.CLIENT - ) as span: - if span.is_recording(): - span.set_attribute(SpanAttributes.DB_STATEMENT, resource) - _set_connection_attributes(span, instance) - span.set_attribute( - "db.redis.pipeline_length", len(instance.command_stack) - ) - return func(*args, **kwargs) +def _instrument( + tracer, + request_hook: _RequestHookT = None, + response_hook: _ResponseHookT = None, +): + def _traced_execute_command(func, instance, args, kwargs): + query = _format_command_args(args) + name = "" + if len(args) > 0 and args[0]: + name = args[0] + else: + name = instance.connection_pool.connection_kwargs.get("db", 0) + with tracer.start_as_current_span( + name, kind=trace.SpanKind.CLIENT + ) as span: + if span.is_recording(): + span.set_attribute(SpanAttributes.DB_STATEMENT, query) + _set_connection_attributes(span, instance) + span.set_attribute("db.redis.args_length", len(args)) + if callable(request_hook): + request_hook(span, instance, args, kwargs) + response = func(*args, **kwargs) + if callable(response_hook): + response_hook(span, instance, response) + return response + + def _traced_execute_pipeline(func, instance, args, kwargs): + cmds = [_format_command_args(c) for c, _ in instance.command_stack] + resource = "\n".join(cmds) + + span_name = " ".join([args[0] for args, _ in instance.command_stack]) + + with tracer.start_as_current_span( + span_name, kind=trace.SpanKind.CLIENT + ) as span: + if span.is_recording(): + span.set_attribute(SpanAttributes.DB_STATEMENT, resource) + _set_connection_attributes(span, instance) + span.set_attribute( + "db.redis.pipeline_length", len(instance.command_stack) + ) + response = func(*args, **kwargs) + if callable(response_hook): + response_hook(span, instance, response) + return response + + pipeline_class = ( + "BasePipeline" if redis.VERSION < (3, 0, 0) else "Pipeline" + ) + redis_class = "StrictRedis" if redis.VERSION < (3, 0, 0) else "Redis" + + wrap_function_wrapper( + "redis", f"{redis_class}.execute_command", _traced_execute_command + ) + wrap_function_wrapper( + "redis.client", f"{pipeline_class}.execute", _traced_execute_pipeline, + ) + wrap_function_wrapper( + "redis.client", + f"{pipeline_class}.immediate_execute_command", + _traced_execute_command, + ) class RedisInstrumentor(BaseInstrumentor): @@ -117,41 +185,22 @@ def instrumentation_dependencies(self) -> Collection[str]: return _instruments def _instrument(self, **kwargs): + """Instruments the redis module + + Args: + **kwargs: Optional arguments + ``tracer_provider``: a TracerProvider, defaults to global. + ``response_hook``: An optional callback which is invoked right before the span is finished processing a response. + """ tracer_provider = kwargs.get("tracer_provider") - setattr( - redis, - "_opentelemetry_tracer", - trace.get_tracer( - __name__, __version__, tracer_provider=tracer_provider, - ), + tracer = trace.get_tracer( + __name__, __version__, tracer_provider=tracer_provider + ) + _instrument( + tracer, + request_hook=kwargs.get("request_hook"), + response_hook=kwargs.get("response_hook"), ) - - if redis.VERSION < (3, 0, 0): - wrap_function_wrapper( - "redis", "StrictRedis.execute_command", _traced_execute_command - ) - wrap_function_wrapper( - "redis.client", - "BasePipeline.execute", - _traced_execute_pipeline, - ) - wrap_function_wrapper( - "redis.client", - "BasePipeline.immediate_execute_command", - _traced_execute_command, - ) - else: - wrap_function_wrapper( - "redis", "Redis.execute_command", _traced_execute_command - ) - wrap_function_wrapper( - "redis.client", "Pipeline.execute", _traced_execute_pipeline - ) - wrap_function_wrapper( - "redis.client", - "Pipeline.immediate_execute_command", - _traced_execute_command, - ) def _uninstrument(self, **kwargs): if redis.VERSION < (3, 0, 0): diff --git a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py index 0f8d82e0ec..3780f0c245 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py +++ b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py @@ -80,3 +80,64 @@ def test_instrument_uninstrument(self): spans = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans), 1) + + def test_response_hook(self): + redis_client = redis.Redis() + connection = redis.connection.Connection() + redis_client.connection = connection + + response_attribute_name = "db.redis.response" + + def response_hook(span, conn, response): + span.set_attribute(response_attribute_name, response) + + RedisInstrumentor().uninstrument() + RedisInstrumentor().instrument( + tracer_provider=self.tracer_provider, response_hook=response_hook + ) + + test_value = "test_value" + + with mock.patch.object(connection, "send_command"): + with mock.patch.object( + redis_client, "parse_response", return_value=test_value + ): + redis_client.get("key") + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + + span = spans[0] + self.assertEqual( + span.attributes.get(response_attribute_name), test_value + ) + + def test_request_hook(self): + redis_client = redis.Redis() + connection = redis.connection.Connection() + redis_client.connection = connection + + custom_attribute_name = "my.request.attribute" + + def request_hook(span, conn, args, kwargs): + if span and span.is_recording(): + span.set_attribute(custom_attribute_name, args[0]) + + RedisInstrumentor().uninstrument() + RedisInstrumentor().instrument( + tracer_provider=self.tracer_provider, request_hook=request_hook + ) + + test_value = "test_value" + + with mock.patch.object(connection, "send_command"): + with mock.patch.object( + redis_client, "parse_response", return_value=test_value + ): + redis_client.get("key") + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + + span = spans[0] + self.assertEqual(span.attributes.get(custom_attribute_name), "GET")