diff --git a/CHANGELOG.md b/CHANGELOG.md index ee5ece78bf..99b702799e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- `opentelemetry-instrumentation-redis` Add `sanitize_query` config option to allow query sanitization. ([#1572](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1572)) - `opentelemetry-instrumentation-celery` Record exceptions as events on the span. ([#1573](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1573)) - Add metric instrumentation for urllib diff --git a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py index b85c2336b0..0f18639bd2 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py @@ -64,6 +64,8 @@ async def redis_get(): response_hook (Callable) - a function with extra user-defined logic to be performed after performing the request this function signature is: def response_hook(span: Span, instance: redis.connection.Connection, response) -> None +sanitize_query (Boolean) - default False, enable the Redis query sanitization + for example: .. code: python @@ -139,9 +141,11 @@ def _instrument( tracer, request_hook: _RequestHookT = None, response_hook: _ResponseHookT = None, + sanitize_query: bool = False, ): def _traced_execute_command(func, instance, args, kwargs): - query = _format_command_args(args) + query = _format_command_args(args, sanitize_query) + if len(args) > 0 and args[0]: name = args[0] else: @@ -169,7 +173,9 @@ def _traced_execute_pipeline(func, instance, args, kwargs): ) cmds = [ - _format_command_args(c.args if hasattr(c, "args") else c[0]) + _format_command_args( + c.args if hasattr(c, "args") else c[0], sanitize_query + ) for c in command_stack ] resource = "\n".join(cmds) @@ -281,6 +287,7 @@ def _instrument(self, **kwargs): tracer, request_hook=kwargs.get("request_hook"), response_hook=kwargs.get("response_hook"), + sanitize_query=kwargs.get("sanitize_query", False), ) def _uninstrument(self, **kwargs): diff --git a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/util.py b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/util.py index fdc5cb5fd6..1eadaba718 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/util.py +++ b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/util.py @@ -48,11 +48,27 @@ def _extract_conn_attributes(conn_kwargs): return attributes -def _format_command_args(args): - """Format command arguments and trim them as needed""" - value_max_len = 100 - value_too_long_mark = "..." +def _format_command_args(args, sanitize_query): + """Format and sanitize command arguments, and trim them as needed""" cmd_max_len = 1000 + value_too_long_mark = "..." + if sanitize_query: + # Sanitized query format: "COMMAND ? ?" + args_length = len(args) + if args_length > 0: + out = [str(args[0])] + ["?"] * (args_length - 1) + out_str = " ".join(out) + + if len(out_str) > cmd_max_len: + out_str = ( + out_str[: cmd_max_len - len(value_too_long_mark)] + + value_too_long_mark + ) + else: + out_str = "" + return out_str + + value_max_len = 100 length = 0 out = [] for arg in args: diff --git a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py index 1ae1690efa..1c64fcb13e 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py +++ b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py @@ -148,6 +148,40 @@ def request_hook(span, conn, args, kwargs): span = spans[0] self.assertEqual(span.attributes.get(custom_attribute_name), "GET") + def test_query_sanitizer_enabled(self): + redis_client = redis.Redis() + connection = redis.connection.Connection() + redis_client.connection = connection + + RedisInstrumentor().uninstrument() + RedisInstrumentor().instrument( + tracer_provider=self.tracer_provider, + sanitize_query=True, + ) + + with mock.patch.object(redis_client, "connection"): + redis_client.set("key", "value") + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + + span = spans[0] + self.assertEqual(span.attributes.get("db.statement"), "SET ? ?") + + def test_query_sanitizer_disabled(self): + redis_client = redis.Redis() + connection = redis.connection.Connection() + redis_client.connection = connection + + with mock.patch.object(redis_client, "connection"): + redis_client.set("key", "value") + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + + span = spans[0] + self.assertEqual(span.attributes.get("db.statement"), "SET key value") + def test_no_op_tracer_provider(self): RedisInstrumentor().uninstrument() tracer_provider = trace.NoOpTracerProvider() diff --git a/tests/opentelemetry-docker-tests/tests/redis/test_redis_functional.py b/tests/opentelemetry-docker-tests/tests/redis/test_redis_functional.py index db82ec489c..675a37fa9f 100644 --- a/tests/opentelemetry-docker-tests/tests/redis/test_redis_functional.py +++ b/tests/opentelemetry-docker-tests/tests/redis/test_redis_functional.py @@ -45,6 +45,27 @@ def _check_span(self, span, name): ) self.assertEqual(span.attributes[SpanAttributes.NET_PEER_PORT], 6379) + def test_long_command_sanitized(self): + RedisInstrumentor().uninstrument() + RedisInstrumentor().instrument( + tracer_provider=self.tracer_provider, sanitize_query=True + ) + + self.redis_client.mget(*range(2000)) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + self._check_span(span, "MGET") + self.assertTrue( + span.attributes.get(SpanAttributes.DB_STATEMENT).startswith( + "MGET ? ? ? ?" + ) + ) + self.assertTrue( + span.attributes.get(SpanAttributes.DB_STATEMENT).endswith("...") + ) + def test_long_command(self): self.redis_client.mget(*range(1000)) @@ -61,6 +82,22 @@ def test_long_command(self): span.attributes.get(SpanAttributes.DB_STATEMENT).endswith("...") ) + def test_basics_sanitized(self): + RedisInstrumentor().uninstrument() + RedisInstrumentor().instrument( + tracer_provider=self.tracer_provider, sanitize_query=True + ) + + self.assertIsNone(self.redis_client.get("cheese")) + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + self._check_span(span, "GET") + self.assertEqual( + span.attributes.get(SpanAttributes.DB_STATEMENT), "GET ?" + ) + self.assertEqual(span.attributes.get("db.redis.args_length"), 2) + def test_basics(self): self.assertIsNone(self.redis_client.get("cheese")) spans = self.memory_exporter.get_finished_spans() @@ -72,6 +109,28 @@ def test_basics(self): ) self.assertEqual(span.attributes.get("db.redis.args_length"), 2) + def test_pipeline_traced_sanitized(self): + RedisInstrumentor().uninstrument() + RedisInstrumentor().instrument( + tracer_provider=self.tracer_provider, sanitize_query=True + ) + + with self.redis_client.pipeline(transaction=False) as pipeline: + pipeline.set("blah", 32) + pipeline.rpush("foo", "éé") + pipeline.hgetall("xxx") + pipeline.execute() + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + self._check_span(span, "SET RPUSH HGETALL") + self.assertEqual( + span.attributes.get(SpanAttributes.DB_STATEMENT), + "SET ? ?\nRPUSH ? ?\nHGETALL ?", + ) + self.assertEqual(span.attributes.get("db.redis.pipeline_length"), 3) + def test_pipeline_traced(self): with self.redis_client.pipeline(transaction=False) as pipeline: pipeline.set("blah", 32) @@ -89,6 +148,27 @@ def test_pipeline_traced(self): ) self.assertEqual(span.attributes.get("db.redis.pipeline_length"), 3) + def test_pipeline_immediate_sanitized(self): + RedisInstrumentor().uninstrument() + RedisInstrumentor().instrument( + tracer_provider=self.tracer_provider, sanitize_query=True + ) + + with self.redis_client.pipeline() as pipeline: + pipeline.set("a", 1) + pipeline.immediate_execute_command("SET", "b", 2) + pipeline.execute() + + spans = self.memory_exporter.get_finished_spans() + # expecting two separate spans here, rather than a + # single span for the whole pipeline + self.assertEqual(len(spans), 2) + span = spans[0] + self._check_span(span, "SET") + self.assertEqual( + span.attributes.get(SpanAttributes.DB_STATEMENT), "SET ? ?" + ) + def test_pipeline_immediate(self): with self.redis_client.pipeline() as pipeline: pipeline.set("a", 1)