Skip to content

Commit

Permalink
Add xray propagators that prioritizes xray environment variable (#2573)
Browse files Browse the repository at this point in the history
* Add AwsXrayLambdaPropagator

Fixes #2457

* Remove unnecessary AWS_TRACE_HEADER_PROP

* Add docstring

* Fix nit

* Add no environment variable test case

* Add test case for valid context

* Remove ipdb

* Fix lint

* Add missing entry point
  • Loading branch information
ocelotl authored Jun 13, 2024
1 parent 361da3e commit 881a179
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 0 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- `opentelemetry-sdk-extension-aws` Add AwsXrayLambdaPropagator
([#2573](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2573))

### Breaking changes

- `opentelemetry-instrumentation-asgi`, `opentelemetry-instrumentation-fastapi`, `opentelemetry-instrumentation-starlette` Use `tracer` and `meter` of originating components instead of one from `asgi` middleware
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [

[project.entry-points.opentelemetry_propagator]
xray = "opentelemetry.propagators.aws:AwsXRayPropagator"
xray_lambda = "opentelemetry.propagators.aws:AwsXRayLambdaPropagator"

[project.urls]
Homepage = "https://github.com/open-telemetry/opentelemetry-python-contrib/tree/main/propagator/opentelemetry-propagator-aws-xray"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@

import logging
import typing
from os import environ

from opentelemetry import trace
from opentelemetry.context import Context
Expand All @@ -71,6 +72,7 @@
)

TRACE_HEADER_KEY = "X-Amzn-Trace-Id"
AWS_TRACE_HEADER_ENV_KEY = "_X_AMZN_TRACE_ID"
KV_PAIR_DELIMITER = ";"
KEY_AND_VALUE_DELIMITER = "="

Expand Down Expand Up @@ -324,3 +326,33 @@ def fields(self):
"""Returns a set with the fields set in `inject`."""

return {TRACE_HEADER_KEY}


class AwsXrayLambdaPropagator(AwsXRayPropagator):
"""Implementation of the AWS X-Ray Trace Header propagation protocol but
with special handling for Lambda's ``_X_AMZN_TRACE_ID` environment
variable.
"""

def extract(
self,
carrier: CarrierT,
context: typing.Optional[Context] = None,
getter: Getter[CarrierT] = default_getter,
) -> Context:

xray_context = super().extract(carrier, context=context, getter=getter)

if trace.get_current_span(context=context).get_span_context().is_valid:
return xray_context

trace_header = environ.get(AWS_TRACE_HEADER_ENV_KEY)

if trace_header is None:
return xray_context

return super().extract(
{TRACE_HEADER_KEY: trace_header},
context=xray_context,
getter=getter,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from os import environ
from unittest import TestCase
from unittest.mock import patch

from requests.structures import CaseInsensitiveDict

from opentelemetry.context import get_current
from opentelemetry.propagators.aws.aws_xray_propagator import (
TRACE_HEADER_KEY,
AwsXrayLambdaPropagator,
)
from opentelemetry.propagators.textmap import DefaultGetter
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.trace import (
Link,
NonRecordingSpan,
SpanContext,
TraceState,
get_current_span,
use_span,
)


class AwsXRayLambdaPropagatorTest(TestCase):

def test_extract_no_environment_variable(self):

actual_context = get_current_span(
AwsXrayLambdaPropagator().extract(
{}, context=get_current(), getter=DefaultGetter()
)
).get_span_context()

self.assertEqual(hex(actual_context.trace_id), "0x0")
self.assertEqual(hex(actual_context.span_id), "0x0")
self.assertFalse(
actual_context.trace_flags.sampled,
)
self.assertEqual(actual_context.trace_state, TraceState.get_default())

def test_extract_no_environment_variable_valid_context(self):

with use_span(NonRecordingSpan(SpanContext(1, 2, False))):

actual_context = get_current_span(
AwsXrayLambdaPropagator().extract(
{}, context=get_current(), getter=DefaultGetter()
)
).get_span_context()

self.assertEqual(hex(actual_context.trace_id), "0x1")
self.assertEqual(hex(actual_context.span_id), "0x2")
self.assertFalse(
actual_context.trace_flags.sampled,
)
self.assertEqual(
actual_context.trace_state, TraceState.get_default()
)

@patch.dict(
environ,
{
"_X_AMZN_TRACE_ID": (
"Root=1-00000001-d188f8fa79d48a391a778fa6;"
"Parent=53995c3f42cd8ad8;Sampled=1;Foo=Bar"
)
},
)
def test_extract_from_environment_variable(self):

actual_context = get_current_span(
AwsXrayLambdaPropagator().extract(
{}, context=get_current(), getter=DefaultGetter()
)
).get_span_context()

self.assertEqual(
hex(actual_context.trace_id), "0x1d188f8fa79d48a391a778fa6"
)
self.assertEqual(hex(actual_context.span_id), "0x53995c3f42cd8ad8")
self.assertTrue(
actual_context.trace_flags.sampled,
)
self.assertEqual(actual_context.trace_state, TraceState.get_default())

@patch.dict(
environ,
{
"_X_AMZN_TRACE_ID": (
"Root=1-00000002-240000000000000000000002;"
"Parent=1600000000000002;Sampled=1;Foo=Bar"
)
},
)
def test_add_link_from_environment_variable(self):

propagator = AwsXrayLambdaPropagator()

default_getter = DefaultGetter()

carrier = CaseInsensitiveDict(
{
TRACE_HEADER_KEY: (
"Root=1-00000001-240000000000000000000001;"
"Parent=1600000000000001;Sampled=1"
)
}
)

extracted_context = propagator.extract(
carrier, context=get_current(), getter=default_getter
)

link_context = propagator.extract(
carrier, context=extracted_context, getter=default_getter
)

span = ReadableSpan(
"test", parent=extracted_context, links=[Link(link_context)]
)

span_parent_context = get_current_span(span.parent).get_span_context()

self.assertEqual(
hex(span_parent_context.trace_id), "0x2240000000000000000000002"
)
self.assertEqual(
hex(span_parent_context.span_id), "0x1600000000000002"
)
self.assertTrue(
span_parent_context.trace_flags.sampled,
)
self.assertEqual(
span_parent_context.trace_state, TraceState.get_default()
)

span_link_context = get_current_span(
span.links[0].context
).get_span_context()

self.assertEqual(
hex(span_link_context.trace_id), "0x1240000000000000000000001"
)
self.assertEqual(hex(span_link_context.span_id), "0x1600000000000001")
self.assertTrue(
span_link_context.trace_flags.sampled,
)
self.assertEqual(
span_link_context.trace_state, TraceState.get_default()
)

0 comments on commit 881a179

Please sign in to comment.