Skip to content

Commit

Permalink
Add AwsXrayLambdaPropagator
Browse files Browse the repository at this point in the history
Fixes #2457
  • Loading branch information
ocelotl committed Jun 10, 2024
1 parent 9762152 commit e698be6
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 0 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- `opentelemetry-sdk-extension-aws` Add AwsXrayLambdaPropagator
([#2573](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2573))

### Breaking changes

- `opentelemetry-instrumentation-asgi`, `opentelemetry-instrumentation-fastapi`, `opentelemetry-instrumentation-starlette` Use `tracer` and `meter` of originating components instead of one from `asgi` middleware
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@

import logging
import typing
from os import environ

from opentelemetry import trace
from opentelemetry.context import Context
Expand All @@ -71,6 +72,8 @@
)

TRACE_HEADER_KEY = "X-Amzn-Trace-Id"
AWS_TRACE_HEADER_PROP = "com.amazonaws.xray.traceHeader"
AWS_TRACE_HEADER_ENV_KEY = "_X_AMZN_TRACE_ID"
KV_PAIR_DELIMITER = ";"
KEY_AND_VALUE_DELIMITER = "="

Expand Down Expand Up @@ -324,3 +327,41 @@ def fields(self):
"""Returns a set with the fields set in `inject`."""

return {TRACE_HEADER_KEY}


class AwsXrayLambdaPropagator(AwsXRayPropagator):

_instance = None

def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls, *args, **kwargs)

return cls._instance

def extract(
self,
carrier: CarrierT,
context: typing.Optional[Context] = None,
getter: Getter[CarrierT] = default_getter,
) -> Context:

xray_context = super().extract(carrier, context=context, getter=getter)

if trace.get_current_span(context=context).get_span_context().is_valid:
return xray_context

trace_header = environ.get(AWS_TRACE_HEADER_PROP) or environ.get(
AWS_TRACE_HEADER_ENV_KEY
)

if trace_header is None:
return xray_context

result = super().extract(
{TRACE_HEADER_KEY: trace_header},
context=xray_context,
getter=getter,
)

return result
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from os import environ
from unittest import TestCase
from unittest.mock import patch

from requests.structures import CaseInsensitiveDict

from opentelemetry.context import get_current
from opentelemetry.propagators.aws.aws_xray_propagator import (
TRACE_HEADER_KEY,
AwsXrayLambdaPropagator,
)
from opentelemetry.propagators.textmap import DefaultGetter
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.trace import Link, TraceState, get_current_span


class AwsXRayLambdaPropagatorTest(TestCase):

@patch.dict(
environ,
{
"_X_AMZN_TRACE_ID": (
"Root=1-00000001-d188f8fa79d48a391a778fa6;"
"Parent=53995c3f42cd8ad8;Sampled=1;Foo=Bar"
)
},
)
def test_extract_from_environment_variable(self):
propagator = AwsXrayLambdaPropagator()

default_getter = DefaultGetter()

actual_context = get_current_span(
propagator.extract(
{}, context=get_current(), getter=default_getter
)
).get_span_context()

self.assertEqual(
hex(actual_context.trace_id), "0x1d188f8fa79d48a391a778fa6"
)
self.assertEqual(hex(actual_context.span_id), "0x53995c3f42cd8ad8")
self.assertTrue(
actual_context.trace_flags.sampled,
)
self.assertEqual(actual_context.trace_state, TraceState.get_default())

@patch.dict(
environ,
{
"_X_AMZN_TRACE_ID": (
"Root=1-00000002-240000000000000000000002;"
"Parent=1600000000000002;Sampled=1;Foo=Bar"
)
},
)
def test_add_link_from_environment_variable(self):

propagator = AwsXrayLambdaPropagator()

default_getter = DefaultGetter()

carrier = CaseInsensitiveDict(
{
TRACE_HEADER_KEY: (
"Root=1-00000001-240000000000000000000001;"
"Parent=1600000000000001;Sampled=1"
)
}
)

extracted_context = propagator.extract(
carrier, context=get_current(), getter=default_getter
)

link_context = propagator.extract(
carrier, context=extracted_context, getter=default_getter
)

span = ReadableSpan(
"test", parent=extracted_context, links=[Link(link_context)]
)

span_parent_context = get_current_span(span.parent).get_span_context()

self.assertEqual(
hex(span_parent_context.trace_id), "0x2240000000000000000000002"
)
self.assertEqual(
hex(span_parent_context.span_id), "0x1600000000000002"
)
self.assertTrue(
span_parent_context.trace_flags.sampled,
)
self.assertEqual(
span_parent_context.trace_state, TraceState.get_default()
)

span_link_context = get_current_span(
span.links[0].context
).get_span_context()

self.assertEqual(
hex(span_link_context.trace_id), "0x1240000000000000000000001"
)
self.assertEqual(hex(span_link_context.span_id), "0x1600000000000001")
self.assertTrue(
span_link_context.trace_flags.sampled,
)
self.assertEqual(
span_link_context.trace_state, TraceState.get_default()
)

0 comments on commit e698be6

Please sign in to comment.