diff --git a/elasticapm/instrumentation/packages/botocore.py b/elasticapm/instrumentation/packages/botocore.py index 7320e02d3..78ef99feb 100644 --- a/elasticapm/instrumentation/packages/botocore.py +++ b/elasticapm/instrumentation/packages/botocore.py @@ -32,12 +32,14 @@ from elasticapm.conf import constants from elasticapm.instrumentation.packages.base import AbstractInstrumentedModule -from elasticapm.traces import capture_span +from elasticapm.traces import capture_span, execution_context from elasticapm.utils.compat import urlparse from elasticapm.utils.logging import get_logger logger = get_logger("elasticapm.instrument") +SQS_MAX_ATTRIBUTES = 10 + HandlerInfo = namedtuple("HandlerInfo", ("signature", "span_type", "span_subtype", "span_action", "context")) @@ -171,7 +173,12 @@ def handle_sqs(operation_name, service, instance, args, kwargs, context): def modify_span_sqs(span, args, kwargs): - trace_parent = span.transaction.trace_parent.copy_from(span_id=span.id) + if span.id: + trace_parent = span.transaction.trace_parent.copy_from(span_id=span.id) + else: + # this is a dropped span, use transaction id instead + transaction = execution_context.get_transaction() + trace_parent = transaction.trace_parent.copy_from(span_id=transaction.id) attributes = {constants.TRACEPARENT_HEADER_NAME: {"DataType": "String", "StringValue": trace_parent.to_string()}} if trace_parent.tracestate: attributes[constants.TRACESTATE_HEADER_NAME] = {"DataType": "String", "StringValue": trace_parent.tracestate} @@ -179,12 +186,12 @@ def modify_span_sqs(span, args, kwargs): attributes_count = len(attributes) if "MessageAttributes" in args[1]: messages = [args[1]] - # elif "Entries" in args[1]: - # messages = args[1]["Entries"] + elif "Entries" in args[1]: + messages = args[1]["Entries"] else: messages = [] for message in messages: - if len(message["MessageAttributes"]) + attributes_count <= 10: + if len(message["MessageAttributes"]) + attributes_count <= SQS_MAX_ATTRIBUTES: message["MessageAttributes"].update(attributes) else: logger.info("Not adding disttracing headers to message due to attribute limit reached") diff --git a/tests/instrumentation/botocore_tests.py b/tests/instrumentation/botocore_tests.py index b8de516d0..5e9319886 100644 --- a/tests/instrumentation/botocore_tests.py +++ b/tests/instrumentation/botocore_tests.py @@ -29,12 +29,13 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import os -import mock import pytest +import elasticapm from elasticapm.conf import constants -from elasticapm.instrumentation.packages.botocore import BotocoreInstrumentation +from elasticapm.instrumentation.packages.botocore import SQS_MAX_ATTRIBUTES from elasticapm.utils.compat import urlparse +from tests.utils import assert_any_record_contains boto3 = pytest.importorskip("boto3") @@ -79,9 +80,7 @@ def dynamodb(): @pytest.fixture() def sqs_client_and_queue(): sqs = boto3.client("sqs", endpoint_url=LOCALSTACK_ENDPOINT) - response = sqs.create_queue( - QueueName="myqueue", Attributes={"DelaySeconds": "60", "MessageRetentionPeriod": "86400"} - ) + response = sqs.create_queue(QueueName="myqueue", Attributes={"MessageRetentionPeriod": "86400"}) queue_url = response["QueueUrl"] yield sqs, queue_url sqs.delete_queue(QueueUrl=queue_url) @@ -213,7 +212,7 @@ def test_sqs_send(instrument, elasticapm_client, sqs_client_and_queue): }, MessageBody=("bar"), ) - elasticapm_client.end_transaction("test", "test") + transaction = elasticapm_client.end_transaction("test", "test") span = elasticapm_client.events[constants.SPAN][0] assert span["name"] == "SQS SEND to myqueue" assert span["type"] == "messaging" @@ -224,6 +223,19 @@ def test_sqs_send(instrument, elasticapm_client, sqs_client_and_queue): assert span["context"]["destination"]["service"]["resource"] == "sqs/myqueue" assert span["context"]["destination"]["service"]["type"] == "messaging" + messages = sqs.receive_message( + QueueUrl=queue_url, + AttributeNames=["All"], + MessageAttributeNames=[ + "All", + ], + ) + message = messages["Messages"][0] + assert "traceparent" in message["MessageAttributes"] + traceparent = message["MessageAttributes"]["traceparent"]["StringValue"] + assert transaction.trace_parent.trace_id in traceparent + assert span["id"] in traceparent + def test_sqs_send_batch(instrument, elasticapm_client, sqs_client_and_queue): sqs, queue_url = sqs_client_and_queue @@ -234,12 +246,11 @@ def test_sqs_send_batch(instrument, elasticapm_client, sqs_client_and_queue): { "Id": "foo", "MessageBody": "foo", - "DelaySeconds": 123, "MessageAttributes": {"string": {"StringValue": "foo", "DataType": "String"}}, }, ], ) - elasticapm_client.end_transaction("test", "test") + transaction = elasticapm_client.end_transaction("test", "test") span = elasticapm_client.events[constants.SPAN][0] assert span["name"] == "SQS SEND_BATCH to myqueue" assert span["type"] == "messaging" @@ -249,6 +260,68 @@ def test_sqs_send_batch(instrument, elasticapm_client, sqs_client_and_queue): assert span["context"]["destination"]["service"]["name"] == "sqs" assert span["context"]["destination"]["service"]["resource"] == "sqs/myqueue" assert span["context"]["destination"]["service"]["type"] == "messaging" + messages = sqs.receive_message( + QueueUrl=queue_url, + AttributeNames=["All"], + MessageAttributeNames=[ + "All", + ], + ) + message = messages["Messages"][0] + assert "traceparent" in message["MessageAttributes"] + traceparent = message["MessageAttributes"]["traceparent"]["StringValue"] + assert transaction.trace_parent.trace_id in traceparent + assert span["id"] in traceparent + + +def test_sqs_send_too_many_attributes_for_disttracing(instrument, elasticapm_client, sqs_client_and_queue, caplog): + sqs, queue_url = sqs_client_and_queue + attributes = {str(i): {"DataType": "String", "StringValue": str(i)} for i in range(SQS_MAX_ATTRIBUTES)} + elasticapm_client.begin_transaction("test") + with caplog.at_level("INFO"): + sqs.send_message( + QueueUrl=queue_url, + MessageAttributes=attributes, + MessageBody=("bar"), + ) + elasticapm_client.end_transaction("test", "test") + messages = sqs.receive_message( + QueueUrl=queue_url, + AttributeNames=["All"], + MessageAttributeNames=[ + "All", + ], + ) + message = messages["Messages"][0] + assert "traceparent" not in message["MessageAttributes"] + assert_any_record_contains(caplog.records, "Not adding disttracing headers") + + +def test_sqs_send_disttracing_dropped_span(instrument, elasticapm_client, sqs_client_and_queue): + sqs, queue_url = sqs_client_and_queue + elasticapm_client.begin_transaction("test") + with elasticapm.capture_span("test", leaf=True): + sqs.send_message( + QueueUrl=queue_url, + MessageAttributes={ + "Title": {"DataType": "String", "StringValue": "foo"}, + }, + MessageBody=("bar"), + ) + transaction = elasticapm_client.end_transaction("test", "test") + assert len(elasticapm_client.events[constants.SPAN]) == 1 + messages = sqs.receive_message( + QueueUrl=queue_url, + AttributeNames=["All"], + MessageAttributeNames=[ + "All", + ], + ) + message = messages["Messages"][0] + assert "traceparent" in message["MessageAttributes"] + traceparent = message["MessageAttributes"]["traceparent"]["StringValue"] + assert transaction.trace_parent.trace_id in traceparent + assert transaction.id in traceparent # due to DroppedSpan, transaction.id is used instead of span.id def test_sqs_receive(instrument, elasticapm_client, sqs_client_and_queue):