Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make propagators conform to spec #1811

Merged
merged 6 commits into from
May 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#1806](https://github.com/open-telemetry/opentelemetry-python/pull/1806))
- Rename CompositeHTTPPropagator to CompositePropagator as per specification.
([#1807](https://github.com/open-telemetry/opentelemetry-python/pull/1807))
- Propagators use the root context as default for `extract` and do not modify
the context if extracting from carrier does not work.
([#1811](https://github.com/open-telemetry/opentelemetry-python/pull/1811))

### Removed
- Moved `opentelemetry-instrumentation` to contrib repository.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def extract(
used to construct a Context. This object
must be paired with an appropriate getter
which understands how to extract a value from it.
context: an optional Context to use. Defaults to current
context: an optional Context to use. Defaults to root
context if not set.
"""
return get_global_textmap().extract(carrier, context, getter=getter)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def extract(
used to construct a Context. This object
must be paired with an appropriate getter
which understands how to extract a value from it.
context: an optional Context to use. Defaults to current
context: an optional Context to use. Defaults to root
context if not set.
Returns:
A Context with configuration found in the carrier.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,28 +43,31 @@ def extract(

See `opentelemetry.propagators.textmap.TextMapPropagator.extract`
"""
if context is None:
context = Context()

header = getter.get(carrier, self._TRACEPARENT_HEADER_NAME)

if not header:
return trace.set_span_in_context(trace.INVALID_SPAN, context)
return context

match = re.search(self._TRACEPARENT_HEADER_FORMAT_RE, header[0])
if not match:
return trace.set_span_in_context(trace.INVALID_SPAN, context)
return context

version = match.group(1)
trace_id = match.group(2)
span_id = match.group(3)
trace_flags = match.group(4)

if trace_id == "0" * 32 or span_id == "0" * 16:
return trace.set_span_in_context(trace.INVALID_SPAN, context)
return context

if version == "00":
if match.group(5):
return trace.set_span_in_context(trace.INVALID_SPAN, context)
return context
if version == "ff":
return trace.set_span_in_context(trace.INVALID_SPAN, context)
return context

tracestate_headers = getter.get(carrier, self._TRACESTATE_HEADER_NAME)
if tracestate_headers is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from unittest.mock import Mock, patch

from opentelemetry import trace
from opentelemetry.context import Context
from opentelemetry.trace.propagation import tracecontext
from opentelemetry.trace.span import TraceState

Expand Down Expand Up @@ -270,3 +271,51 @@ def test_fields(self, mock_get_current_span, mock_invalid_span_context):
inject_fields.add(mock_call[1][1])

self.assertEqual(inject_fields, FORMAT.fields)

def test_extract_no_trace_parent_to_explicit_ctx(self):
carrier = {"tracestate": ["foo=1"]}
orig_ctx = Context({"k1": "v1"})

ctx = FORMAT.extract(carrier, orig_ctx)
self.assertDictEqual(orig_ctx, ctx)

def test_extract_no_trace_parent_to_implicit_ctx(self):
carrier = {"tracestate": ["foo=1"]}

ctx = FORMAT.extract(carrier)
self.assertDictEqual(Context(), ctx)

def test_extract_invalid_trace_parent_to_explicit_ctx(self):
trace_parent_headers = [
"invalid",
"00-00000000000000000000000000000000-1234567890123456-00",
"00-12345678901234567890123456789012-0000000000000000-00",
"00-12345678901234567890123456789012-1234567890123456-00-residue",
]
for trace_parent in trace_parent_headers:
with self.subTest(trace_parent=trace_parent):
carrier = {
"traceparent": [trace_parent],
"tracestate": ["foo=1"],
}
orig_ctx = Context({"k1": "v1"})

ctx = FORMAT.extract(carrier, orig_ctx)
self.assertDictEqual(orig_ctx, ctx)

def test_extract_invalid_trace_parent_to_implicit_ctx(self):
trace_parent_headers = [
"invalid",
"00-00000000000000000000000000000000-1234567890123456-00",
"00-12345678901234567890123456789012-0000000000000000-00",
"00-12345678901234567890123456789012-1234567890123456-00-residue",
]
for trace_parent in trace_parent_headers:
with self.subTest(trace_parent=trace_parent):
carrier = {
"traceparent": [trace_parent],
"tracestate": ["foo=1"],
}

ctx = FORMAT.extract(carrier)
self.assertDictEqual(Context(), ctx)
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def extract(
context: typing.Optional[Context] = None,
getter: Getter = default_getter,
) -> Context:
if context is None:
context = Context()
trace_id = trace.INVALID_TRACE_ID
span_id = trace.INVALID_SPAN_ID
sampled = "0"
Expand Down Expand Up @@ -97,8 +99,6 @@ def extract(
or self._trace_id_regex.fullmatch(trace_id) is None
or self._span_id_regex.fullmatch(span_id) is None
):
if context is None:
return trace.set_span_in_context(trace.INVALID_SPAN, context)
return context

trace_id = int(trace_id, 16)
Expand Down
87 changes: 68 additions & 19 deletions propagator/opentelemetry-propagator-b3/tests/test_b3_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import opentelemetry.sdk.trace as trace
import opentelemetry.sdk.trace.id_generator as id_generator
import opentelemetry.trace as trace_api
from opentelemetry.context import get_current
from opentelemetry.context import Context, get_current
from opentelemetry.propagators.textmap import DefaultGetter

FORMAT = b3_format.B3Format()
Expand Down Expand Up @@ -219,7 +219,7 @@ def test_flags_and_sampling(self):

def test_derived_ctx_is_returned_for_success(self):
"""Ensure returned context is derived from the given context."""
old_ctx = {"k1": "v1"}
old_ctx = Context({"k1": "v1"})
new_ctx = FORMAT.extract(
{
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
Expand All @@ -229,17 +229,19 @@ def test_derived_ctx_is_returned_for_success(self):
old_ctx,
)
self.assertIn("current-span", new_ctx)
for key, value in old_ctx.items():
for key, value in old_ctx.items(): # pylint:disable=no-member
self.assertIn(key, new_ctx)
# pylint:disable=unsubscriptable-object
self.assertEqual(new_ctx[key], value)

def test_derived_ctx_is_returned_for_failure(self):
"""Ensure returned context is derived from the given context."""
old_ctx = {"k2": "v2"}
old_ctx = Context({"k2": "v2"})
new_ctx = FORMAT.extract({}, old_ctx)
self.assertNotIn("current-span", new_ctx)
for key, value in old_ctx.items():
for key, value in old_ctx.items(): # pylint:disable=no-member
self.assertIn(key, new_ctx)
# pylint:disable=unsubscriptable-object
self.assertEqual(new_ctx[key], value)

def test_64bit_trace_id(self):
Expand All @@ -258,18 +260,24 @@ def test_64bit_trace_id(self):
new_carrier[FORMAT.TRACE_ID_KEY], "0" * 16 + trace_id_64_bit
)

def test_extract_invalid_single_header(self):
def test_extract_invalid_single_header_to_explicit_ctx(self):
"""Given unparsable header, do not modify context"""
old_ctx = {}
old_ctx = Context({"k1": "v1"})

carrier = {FORMAT.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"}
new_ctx = FORMAT.extract(carrier, old_ctx)

self.assertDictEqual(new_ctx, old_ctx)

def test_extract_missing_trace_id(self):
def test_extract_invalid_single_header_to_implicit_ctx(self):
carrier = {FORMAT.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"}
new_ctx = FORMAT.extract(carrier)

self.assertDictEqual(Context(), new_ctx)

def test_extract_missing_trace_id_to_explicit_ctx(self):
"""Given no trace ID, do not modify context"""
old_ctx = {}
old_ctx = Context({"k1": "v1"})

carrier = {
FORMAT.SPAN_ID_KEY: self.serialized_span_id,
Expand All @@ -279,9 +287,18 @@ def test_extract_missing_trace_id(self):

self.assertDictEqual(new_ctx, old_ctx)

def test_extract_invalid_trace_id(self):
def test_extract_missing_trace_id_to_implicit_ctx(self):
carrier = {
FORMAT.SPAN_ID_KEY: self.serialized_span_id,
FORMAT.FLAGS_KEY: "1",
}
new_ctx = FORMAT.extract(carrier)

self.assertDictEqual(Context(), new_ctx)

def test_extract_invalid_trace_id_to_explicit_ctx(self):
"""Given invalid trace ID, do not modify context"""
old_ctx = {}
old_ctx = Context({"k1": "v1"})

carrier = {
FORMAT.TRACE_ID_KEY: "abc123",
Expand All @@ -292,9 +309,19 @@ def test_extract_invalid_trace_id(self):

self.assertDictEqual(new_ctx, old_ctx)

def test_extract_invalid_span_id(self):
def test_extract_invalid_trace_id_to_implicit_ctx(self):
carrier = {
FORMAT.TRACE_ID_KEY: "abc123",
FORMAT.SPAN_ID_KEY: self.serialized_span_id,
FORMAT.FLAGS_KEY: "1",
}
new_ctx = FORMAT.extract(carrier)

self.assertDictEqual(Context(), new_ctx)

def test_extract_invalid_span_id_to_explicit_ctx(self):
"""Given invalid span ID, do not modify context"""
old_ctx = {}
old_ctx = Context({"k1": "v1"})

carrier = {
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
Expand All @@ -305,9 +332,19 @@ def test_extract_invalid_span_id(self):

self.assertDictEqual(new_ctx, old_ctx)

def test_extract_missing_span_id(self):
def test_extract_invalid_span_id_to_implicit_ctx(self):
carrier = {
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
FORMAT.SPAN_ID_KEY: "abc123",
FORMAT.FLAGS_KEY: "1",
}
new_ctx = FORMAT.extract(carrier)

self.assertDictEqual(Context(), new_ctx)

def test_extract_missing_span_id_to_explicit_ctx(self):
"""Given no span ID, do not modify context"""
old_ctx = {}
old_ctx = Context({"k1": "v1"})

carrier = {
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
Expand All @@ -317,15 +354,28 @@ def test_extract_missing_span_id(self):

self.assertDictEqual(new_ctx, old_ctx)

def test_extract_empty_carrier(self):
def test_extract_missing_span_id_to_implicit_ctx(self):
carrier = {
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
FORMAT.FLAGS_KEY: "1",
}
new_ctx = FORMAT.extract(carrier)

self.assertDictEqual(Context(), new_ctx)

def test_extract_empty_carrier_to_explicit_ctx(self):
"""Given no headers at all, do not modify context"""
old_ctx = {}
old_ctx = Context({"k1": "v1"})

carrier = {}
new_ctx = FORMAT.extract(carrier, old_ctx)

self.assertDictEqual(new_ctx, old_ctx)

def test_extract_empty_carrier_to_implicit_ctx(self):
new_ctx = FORMAT.extract({})
self.assertDictEqual(Context(), new_ctx)

@staticmethod
def test_inject_empty_context():
"""If the current context has no span, don't add headers"""
Expand Down Expand Up @@ -368,5 +418,4 @@ def test_extract_none_context(self):

carrier = {}
new_ctx = FORMAT.extract(carrier, old_ctx)
self.assertIsNotNone(new_ctx)
self.assertEqual(new_ctx["current-span"], trace_api.INVALID_SPAN)
self.assertDictEqual(Context(), new_ctx)
Original file line number Diff line number Diff line change
Expand Up @@ -47,31 +47,26 @@ def extract(
) -> Context:

if context is None:
context = get_current()
context = Context()
header = getter.get(carrier, self.TRACE_ID_KEY)
if not header:
return trace.set_span_in_context(trace.INVALID_SPAN, context)
fields = _extract_first_element(header).split(":")
return context

context = self._extract_baggage(getter, carrier, context)
if len(fields) != 4:
return trace.set_span_in_context(trace.INVALID_SPAN, context)

trace_id, span_id, _parent_id, flags = fields
trace_id, span_id, flags = _parse_trace_id_header(header)
if (
trace_id == trace.INVALID_TRACE_ID
or span_id == trace.INVALID_SPAN_ID
):
return trace.set_span_in_context(trace.INVALID_SPAN, context)
return context

span = trace.NonRecordingSpan(
trace.SpanContext(
trace_id=int(trace_id, 16),
span_id=int(span_id, 16),
trace_id=trace_id,
span_id=span_id,
is_remote=True,
trace_flags=trace.TraceFlags(
int(flags, 16) & trace.TraceFlags.SAMPLED
),
trace_flags=trace.TraceFlags(flags & trace.TraceFlags.SAMPLED),
)
)
return trace.set_span_in_context(span, context)
Expand Down Expand Up @@ -147,3 +142,35 @@ def _extract_first_element(
if items is None:
return None
return next(iter(items), None)


def _parse_trace_id_header(
items: typing.Iterable[CarrierT],
) -> typing.Tuple[int]:
invalid_header_result = (trace.INVALID_TRACE_ID, trace.INVALID_SPAN_ID, 0)

header = _extract_first_element(items)
if header is None:
return invalid_header_result

fields = header.split(":")
if len(fields) != 4:
return invalid_header_result

trace_id_str, span_id_str, _parent_id_str, flags_str = fields
flags = _int_from_hex_str(flags_str, None)
if flags is None:
return invalid_header_result

trace_id = _int_from_hex_str(trace_id_str, trace.INVALID_TRACE_ID)
span_id = _int_from_hex_str(span_id_str, trace.INVALID_SPAN_ID)
return trace_id, span_id, flags


def _int_from_hex_str(
identifier: str, default: typing.Optional[int]
) -> typing.Optional[int]:
try:
return int(identifier, 16)
except ValueError:
return default
Loading