Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sanitize redis db_statement by default #1776

Merged
merged 11 commits into from
Jun 13, 2023
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fix redis db.statement attribute to be sanitized by default
([#1776](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1776))
nemoshlag marked this conversation as resolved.
Show resolved Hide resolved
- Fix elasticsearch db.statement attribute to be sanitized by default
([#1758](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1758))
- Fix `AttributeError` when AWS Lambda handler receives a list event
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ async def redis_get():
response_hook (Callable) - a function with extra user-defined logic to be performed after performing the request
this function signature is: def response_hook(span: Span, instance: redis.connection.Connection, response) -> None

sanitize_query (Boolean) - default False, enable the Redis query sanitization

for example:

.. code:: python
Expand All @@ -88,37 +86,18 @@ def response_hook(span, instance, response):
client = redis.StrictRedis(host="localhost", port=6379)
client.get("my-key")

Configuration
-------------

Query sanitization
******************
To enable query sanitization with an environment variable, set
``OTEL_PYTHON_INSTRUMENTATION_SANITIZE_REDIS`` to "true".

For example,

::

export OTEL_PYTHON_INSTRUMENTATION_SANITIZE_REDIS="true"

will result in traced queries like "SET ? ?".

API
---
"""
import typing
from os import environ
from typing import Any, Collection

import redis
from wrapt import wrap_function_wrapper

from opentelemetry import trace
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.redis.environment_variables import (
OTEL_PYTHON_INSTRUMENTATION_SANITIZE_REDIS,
)
from opentelemetry.instrumentation.redis.package import _instruments
from opentelemetry.instrumentation.redis.util import (
_extract_conn_attributes,
Expand Down Expand Up @@ -161,10 +140,9 @@ def _instrument(
tracer,
request_hook: _RequestHookT = None,
response_hook: _ResponseHookT = None,
sanitize_query: bool = False,
):
def _traced_execute_command(func, instance, args, kwargs):
query = _format_command_args(args, sanitize_query)
query = _format_command_args(args)

if len(args) > 0 and args[0]:
name = args[0]
Expand Down Expand Up @@ -194,7 +172,7 @@ def _traced_execute_pipeline(func, instance, args, kwargs):

cmds = [
_format_command_args(
c.args if hasattr(c, "args") else c[0], sanitize_query
c.args if hasattr(c, "args") else c[0],
)
for c in command_stack
]
Expand Down Expand Up @@ -307,15 +285,6 @@ def _instrument(self, **kwargs):
tracer,
request_hook=kwargs.get("request_hook"),
response_hook=kwargs.get("response_hook"),
sanitize_query=kwargs.get(
"sanitize_query",
environ.get(
OTEL_PYTHON_INSTRUMENTATION_SANITIZE_REDIS, "false"
)
.lower()
.strip()
== "true",
),
)

def _uninstrument(self, **kwargs):
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -48,41 +48,23 @@ def _extract_conn_attributes(conn_kwargs):
return attributes


def _format_command_args(args, sanitize_query):
def _format_command_args(args):
"""Format and sanitize command arguments, and trim them as needed"""
cmd_max_len = 1000
value_too_long_mark = "..."
if sanitize_query:
# Sanitized query format: "COMMAND ? ?"
args_length = len(args)
if args_length > 0:
out = [str(args[0])] + ["?"] * (args_length - 1)
out_str = " ".join(out)

if len(out_str) > cmd_max_len:
out_str = (
out_str[: cmd_max_len - len(value_too_long_mark)]
+ value_too_long_mark
)
else:
out_str = ""
return out_str
# Sanitized query format: "COMMAND ? ?"
args_length = len(args)
if args_length > 0:
out = [str(args[0])] + ["?"] * (args_length - 1)
out_str = " ".join(out)

value_max_len = 100
length = 0
out = []
for arg in args:
cmd = str(arg)
if len(out_str) > cmd_max_len:
out_str = (
out_str[: cmd_max_len - len(value_too_long_mark)]
+ value_too_long_mark
)
else:
out_str = ""

if len(cmd) > value_max_len:
cmd = cmd[:value_max_len] + value_too_long_mark

if length + len(cmd) > cmd_max_len:
prefix = cmd[: cmd_max_len - length]
out.append(f"{prefix}{value_too_long_mark}")
break

out.append(cmd)
length += len(cmd)

return " ".join(out)
return out_str
Original file line number Diff line number Diff line change
Expand Up @@ -168,32 +168,6 @@ def test_query_sanitizer_enabled(self):
span = spans[0]
self.assertEqual(span.attributes.get("db.statement"), "SET ? ?")

def test_query_sanitizer_enabled_env(self):
redis_client = redis.Redis()
connection = redis.connection.Connection()
redis_client.connection = connection

RedisInstrumentor().uninstrument()

env_patch = mock.patch.dict(
"os.environ",
{"OTEL_PYTHON_INSTRUMENTATION_SANITIZE_REDIS": "true"},
)
env_patch.start()
RedisInstrumentor().instrument(
tracer_provider=self.tracer_provider,
)

with mock.patch.object(redis_client, "connection"):
redis_client.set("key", "value")

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)

span = spans[0]
self.assertEqual(span.attributes.get("db.statement"), "SET ? ?")
env_patch.stop()

def test_query_sanitizer_disabled(self):
nemoshlag marked this conversation as resolved.
Show resolved Hide resolved
redis_client = redis.Redis()
connection = redis.connection.Connection()
Expand All @@ -206,7 +180,7 @@ def test_query_sanitizer_disabled(self):
self.assertEqual(len(spans), 1)

span = spans[0]
self.assertEqual(span.attributes.get("db.statement"), "SET key value")
self.assertEqual(span.attributes.get("db.statement"), "SET ? ?")

def test_no_op_tracer_provider(self):
RedisInstrumentor().uninstrument()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ def _check_span(self, span, name):

def test_long_command_sanitized(self):
RedisInstrumentor().uninstrument()
RedisInstrumentor().instrument(
tracer_provider=self.tracer_provider, sanitize_query=True
)
RedisInstrumentor().instrument(tracer_provider=self.tracer_provider)

self.redis_client.mget(*range(2000))

Expand All @@ -75,7 +73,7 @@ def test_long_command(self):
self._check_span(span, "MGET")
self.assertTrue(
span.attributes.get(SpanAttributes.DB_STATEMENT).startswith(
"MGET 0 1 2 3"
"MGET ? ? ? ?"
)
)
self.assertTrue(
Expand All @@ -84,9 +82,7 @@ def test_long_command(self):

def test_basics_sanitized(self):
RedisInstrumentor().uninstrument()
RedisInstrumentor().instrument(
tracer_provider=self.tracer_provider, sanitize_query=True
)
RedisInstrumentor().instrument(tracer_provider=self.tracer_provider)

self.assertIsNone(self.redis_client.get("cheese"))
spans = self.memory_exporter.get_finished_spans()
Expand All @@ -105,15 +101,13 @@ def test_basics(self):
span = spans[0]
self._check_span(span, "GET")
self.assertEqual(
span.attributes.get(SpanAttributes.DB_STATEMENT), "GET cheese"
span.attributes.get(SpanAttributes.DB_STATEMENT), "GET ?"
)
self.assertEqual(span.attributes.get("db.redis.args_length"), 2)

def test_pipeline_traced_sanitized(self):
RedisInstrumentor().uninstrument()
RedisInstrumentor().instrument(
tracer_provider=self.tracer_provider, sanitize_query=True
)
RedisInstrumentor().instrument(tracer_provider=self.tracer_provider)

with self.redis_client.pipeline(transaction=False) as pipeline:
pipeline.set("blah", 32)
Expand Down Expand Up @@ -144,15 +138,13 @@ def test_pipeline_traced(self):
self._check_span(span, "SET RPUSH HGETALL")
self.assertEqual(
span.attributes.get(SpanAttributes.DB_STATEMENT),
"SET blah 32\nRPUSH foo éé\nHGETALL xxx",
"SET ? ?\nRPUSH ? ?\nHGETALL ?",
)
self.assertEqual(span.attributes.get("db.redis.pipeline_length"), 3)

def test_pipeline_immediate_sanitized(self):
RedisInstrumentor().uninstrument()
RedisInstrumentor().instrument(
tracer_provider=self.tracer_provider, sanitize_query=True
)
RedisInstrumentor().instrument(tracer_provider=self.tracer_provider)

with self.redis_client.pipeline() as pipeline:
pipeline.set("a", 1)
Expand Down Expand Up @@ -182,7 +174,7 @@ def test_pipeline_immediate(self):
span = spans[0]
self._check_span(span, "SET")
self.assertEqual(
span.attributes.get(SpanAttributes.DB_STATEMENT), "SET b 2"
span.attributes.get(SpanAttributes.DB_STATEMENT), "SET ? ?"
)

def test_parent(self):
Expand Down Expand Up @@ -230,7 +222,7 @@ def test_basics(self):
span = spans[0]
self._check_span(span, "GET")
self.assertEqual(
span.attributes.get(SpanAttributes.DB_STATEMENT), "GET cheese"
span.attributes.get(SpanAttributes.DB_STATEMENT), "GET ?"
)
self.assertEqual(span.attributes.get("db.redis.args_length"), 2)

Expand All @@ -247,7 +239,7 @@ def test_pipeline_traced(self):
self._check_span(span, "SET RPUSH HGETALL")
self.assertEqual(
span.attributes.get(SpanAttributes.DB_STATEMENT),
"SET blah 32\nRPUSH foo éé\nHGETALL xxx",
"SET ? ?\nRPUSH ? ?\nHGETALL ?",
)
self.assertEqual(span.attributes.get("db.redis.pipeline_length"), 3)

Expand Down Expand Up @@ -308,7 +300,7 @@ def test_long_command(self):
self._check_span(span, "MGET")
self.assertTrue(
span.attributes.get(SpanAttributes.DB_STATEMENT).startswith(
"MGET 0 1 2 3"
"MGET ? ? ? ?"
)
)
self.assertTrue(
Expand All @@ -322,7 +314,7 @@ def test_basics(self):
span = spans[0]
self._check_span(span, "GET")
self.assertEqual(
span.attributes.get(SpanAttributes.DB_STATEMENT), "GET cheese"
span.attributes.get(SpanAttributes.DB_STATEMENT), "GET ?"
)
self.assertEqual(span.attributes.get("db.redis.args_length"), 2)

Expand All @@ -344,7 +336,7 @@ async def pipeline_simple():
self._check_span(span, "SET RPUSH HGETALL")
self.assertEqual(
span.attributes.get(SpanAttributes.DB_STATEMENT),
"SET blah 32\nRPUSH foo éé\nHGETALL xxx",
"SET ? ?\nRPUSH ? ?\nHGETALL ?",
)
self.assertEqual(span.attributes.get("db.redis.pipeline_length"), 3)

Expand All @@ -364,7 +356,7 @@ async def pipeline_immediate():
span = spans[0]
self._check_span(span, "SET")
self.assertEqual(
span.attributes.get(SpanAttributes.DB_STATEMENT), "SET b 2"
span.attributes.get(SpanAttributes.DB_STATEMENT), "SET ? ?"
)

def test_parent(self):
Expand Down Expand Up @@ -412,7 +404,7 @@ def test_basics(self):
span = spans[0]
self._check_span(span, "GET")
self.assertEqual(
span.attributes.get(SpanAttributes.DB_STATEMENT), "GET cheese"
span.attributes.get(SpanAttributes.DB_STATEMENT), "GET ?"
)
self.assertEqual(span.attributes.get("db.redis.args_length"), 2)

Expand All @@ -434,7 +426,7 @@ async def pipeline_simple():
self._check_span(span, "SET RPUSH HGETALL")
self.assertEqual(
span.attributes.get(SpanAttributes.DB_STATEMENT),
"SET blah 32\nRPUSH foo éé\nHGETALL xxx",
"SET ? ?\nRPUSH ? ?\nHGETALL ?",
)
self.assertEqual(span.attributes.get("db.redis.pipeline_length"), 3)

Expand Down Expand Up @@ -488,5 +480,5 @@ def test_get(self):
span = spans[0]
self._check_span(span, "GET")
self.assertEqual(
span.attributes.get(SpanAttributes.DB_STATEMENT), "GET foo"
span.attributes.get(SpanAttributes.DB_STATEMENT), "GET ?"
)