diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e889bccc13..0e64b0c511d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added example for running Django with auto instrumentation. ([#1803](https://github.com/open-telemetry/opentelemetry-python/pull/1803)) +- Added `B3SingleFormat` and `B3MultiFormat` propagators to the `opentelemetry-propagator-b3` package. + ([#1823](https://github.com/open-telemetry/opentelemetry-python/pull/1823)) ### Changed - Fixed OTLP gRPC exporter silently failing if scheme is not specified in endpoint. @@ -18,6 +20,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Propagators use the root context as default for `extract` and do not modify the context if extracting from carrier does not work. ([#1811](https://github.com/open-telemetry/opentelemetry-python/pull/1811)) +- Fixed `b3` propagator entrypoint to point to `B3SingleFormat` propagator. + ([#1823](https://github.com/open-telemetry/opentelemetry-python/pull/1823)) +- Added `b3multi` propagator entrypoint to point to `B3MultiFormat` propagator. + ([#1823](https://github.com/open-telemetry/opentelemetry-python/pull/1823)) ### Removed - Moved `opentelemetry-instrumentation` to contrib repository. diff --git a/propagator/opentelemetry-propagator-b3/setup.cfg b/propagator/opentelemetry-propagator-b3/setup.cfg index 4739f9ff7ba..86962a8c814 100644 --- a/propagator/opentelemetry-propagator-b3/setup.cfg +++ b/propagator/opentelemetry-propagator-b3/setup.cfg @@ -41,6 +41,7 @@ package_dir= packages=find_namespace: install_requires = opentelemetry-api == 1.2.0.dev0 + deprecated >= 1.2.6 [options.extras_require] test = @@ -50,4 +51,5 @@ where = src [options.entry_points] opentelemetry_propagator = - b3 = opentelemetry.propagators.b3:B3Format \ No newline at end of file + b3 = opentelemetry.propagators.b3:B3SingleFormat + b3multi = opentelemetry.propagators.b3:B3MultiFormat \ No newline at end of file diff --git a/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py b/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py index d0beec401a8..c75d8e9a544 100644 --- a/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py +++ b/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py @@ -15,6 +15,8 @@ import typing from re import compile as re_compile +from deprecated import deprecated + import opentelemetry.trace as trace from opentelemetry.context import Context from opentelemetry.propagators.textmap import ( @@ -28,10 +30,11 @@ from opentelemetry.trace import format_span_id, format_trace_id -class B3Format(TextMapPropagator): - """Propagator for the B3 HTTP header format. +class B3MultiFormat(TextMapPropagator): + """Propagator for the B3 HTTP multi-header format. See: https://github.com/openzipkin/b3-propagation + https://github.com/openzipkin/b3-propagation#multiple-headers """ SINGLE_HEADER_KEY = "b3" @@ -165,6 +168,53 @@ def fields(self) -> typing.Set[str]: } +class B3SingleFormat(B3MultiFormat): + """Propagator for the B3 HTTP single-header format. + + See: https://github.com/openzipkin/b3-propagation + https://github.com/openzipkin/b3-propagation#single-header + """ + + def inject( + self, + carrier: CarrierT, + context: typing.Optional[Context] = None, + setter: Setter = default_setter, + ) -> None: + span = trace.get_current_span(context=context) + + span_context = span.get_span_context() + if span_context == trace.INVALID_SPAN_CONTEXT: + return + + sampled = (trace.TraceFlags.SAMPLED & span_context.trace_flags) != 0 + + fields = [ + format_trace_id(span_context.trace_id), + format_span_id(span_context.span_id), + "1" if sampled else "0", + ] + + span_parent = getattr(span, "parent", None) + if span_parent: + fields.append(format_span_id(span_parent.span_id)) + + setter.set(carrier, self.SINGLE_HEADER_KEY, "-".join(fields)) + + @property + def fields(self) -> typing.Set[str]: + return {self.SINGLE_HEADER_KEY} + + +class B3Format(B3MultiFormat): + @deprecated( + version="1.2.0", + reason="B3Format is deprecated in favor of B3MultiFormat", + ) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _extract_first_element( items: typing.Iterable[CarrierT], ) -> typing.Optional[CarrierT]: diff --git a/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py b/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py index 29d8a472ea3..87a059f57e2 100644 --- a/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py +++ b/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py @@ -13,21 +13,23 @@ # limitations under the License. import unittest +from abc import abstractclassmethod from unittest.mock import Mock -import opentelemetry.propagators.b3 as b3_format # pylint: disable=no-name-in-module,import-error import opentelemetry.sdk.trace as trace import opentelemetry.sdk.trace.id_generator as id_generator import opentelemetry.trace as trace_api from opentelemetry.context import Context, get_current +from opentelemetry.propagators.b3 import ( # pylint: disable=no-name-in-module,import-error + B3MultiFormat, + B3SingleFormat, +) from opentelemetry.propagators.textmap import DefaultGetter -FORMAT = b3_format.B3Format() +def get_child_parent_new_carrier(old_carrier, propagator): -def get_child_parent_new_carrier(old_carrier): - - ctx = FORMAT.extract(old_carrier) + ctx = propagator.extract(old_carrier) parent_span_context = trace_api.get_current_span(ctx).get_span_context() parent = trace._Span("parent", parent_span_context) @@ -45,24 +47,24 @@ def get_child_parent_new_carrier(old_carrier): new_carrier = {} ctx = trace_api.set_span_in_context(child) - FORMAT.inject(new_carrier, context=ctx) + propagator.inject(new_carrier, context=ctx) return child, parent, new_carrier -class TestB3Format(unittest.TestCase): - # pylint: disable=too-many-public-methods +class AbstractB3FormatTestCase: + # pylint: disable=too-many-public-methods,no-member,invalid-name @classmethod def setUpClass(cls): generator = id_generator.RandomIdGenerator() - cls.serialized_trace_id = b3_format.format_trace_id( + cls.serialized_trace_id = trace_api.format_trace_id( generator.generate_trace_id() ) - cls.serialized_span_id = b3_format.format_span_id( + cls.serialized_span_id = trace_api.format_span_id( generator.generate_span_id() ) - cls.serialized_parent_id = b3_format.format_span_id( + cls.serialized_parent_id = trace_api.format_span_id( generator.generate_span_id() ) @@ -74,56 +76,72 @@ def setUp(self) -> None: patcher.start() self.addCleanup(patcher.stop) + @classmethod + def get_child_parent_new_carrier(cls, old_carrier): + return get_child_parent_new_carrier(old_carrier, cls.get_propagator()) + + @abstractclassmethod + def get_propagator(cls): + pass + + @abstractclassmethod + def get_trace_id(cls, carrier): + pass + + def assertSampled(self, carrier): + pass + + def assertNotSampled(self, carrier): + pass + def test_extract_multi_header(self): """Test the extraction of B3 headers.""" - child, parent, new_carrier = get_child_parent_new_carrier( - { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.PARENT_SPAN_ID_KEY: self.serialized_parent_id, - FORMAT.SAMPLED_KEY: "1", - } - ) + propagator = self.get_propagator() + context = { + propagator.TRACE_ID_KEY: self.serialized_trace_id, + propagator.SPAN_ID_KEY: self.serialized_span_id, + propagator.PARENT_SPAN_ID_KEY: self.serialized_parent_id, + propagator.SAMPLED_KEY: "1", + } + child, parent, _ = self.get_child_parent_new_carrier(context) self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], - b3_format.format_trace_id(child.context.trace_id), - ) - self.assertEqual( - new_carrier[FORMAT.SPAN_ID_KEY], - b3_format.format_span_id(child.context.span_id), + context[propagator.TRACE_ID_KEY], + trace_api.format_trace_id(child.context.trace_id), ) + self.assertEqual( - new_carrier[FORMAT.PARENT_SPAN_ID_KEY], - b3_format.format_span_id(parent.context.span_id), + context[propagator.SPAN_ID_KEY], + trace_api.format_span_id(child.parent.span_id), ) self.assertTrue(parent.context.is_remote) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + self.assertTrue(parent.context.trace_flags.sampled) def test_extract_single_header(self): """Test the extraction from a single b3 header.""" - child, parent, new_carrier = get_child_parent_new_carrier( + propagator = self.get_propagator() + child, parent, _ = self.get_child_parent_new_carrier( { - FORMAT.SINGLE_HEADER_KEY: "{}-{}".format( + propagator.SINGLE_HEADER_KEY: "{}-{}".format( self.serialized_trace_id, self.serialized_span_id ) } ) self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], - b3_format.format_trace_id(child.context.trace_id), + self.serialized_trace_id, + trace_api.format_trace_id(child.context.trace_id), ) self.assertEqual( - new_carrier[FORMAT.SPAN_ID_KEY], - b3_format.format_span_id(child.context.span_id), + self.serialized_span_id, + trace_api.format_span_id(child.parent.span_id), ) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") self.assertTrue(parent.context.is_remote) + self.assertTrue(parent.context.trace_flags.sampled) - child, parent, new_carrier = get_child_parent_new_carrier( + child, parent, _ = self.get_child_parent_new_carrier( { - FORMAT.SINGLE_HEADER_KEY: "{}-{}-1-{}".format( + propagator.SINGLE_HEADER_KEY: "{}-{}-1-{}".format( self.serialized_trace_id, self.serialized_span_id, self.serialized_parent_id, @@ -132,99 +150,100 @@ def test_extract_single_header(self): ) self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], - b3_format.format_trace_id(child.context.trace_id), - ) - self.assertEqual( - new_carrier[FORMAT.SPAN_ID_KEY], - b3_format.format_span_id(child.context.span_id), + self.serialized_trace_id, + trace_api.format_trace_id(child.context.trace_id), ) self.assertEqual( - new_carrier[FORMAT.PARENT_SPAN_ID_KEY], - b3_format.format_span_id(parent.context.span_id), + self.serialized_span_id, + trace_api.format_span_id(child.parent.span_id), ) + self.assertTrue(parent.context.is_remote) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + self.assertTrue(parent.context.trace_flags.sampled) def test_extract_header_precedence(self): """A single b3 header should take precedence over multiple headers. """ + propagator = self.get_propagator() single_header_trace_id = self.serialized_trace_id[:-3] + "123" - _, _, new_carrier = get_child_parent_new_carrier( + _, _, new_carrier = self.get_child_parent_new_carrier( { - FORMAT.SINGLE_HEADER_KEY: "{}-{}".format( + propagator.SINGLE_HEADER_KEY: "{}-{}".format( single_header_trace_id, self.serialized_span_id ), - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.SAMPLED_KEY: "1", + propagator.TRACE_ID_KEY: self.serialized_trace_id, + propagator.SPAN_ID_KEY: self.serialized_span_id, + propagator.SAMPLED_KEY: "1", } ) self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], single_header_trace_id + self.get_trace_id(new_carrier), single_header_trace_id ) def test_enabled_sampling(self): """Test b3 sample key variants that turn on sampling.""" + propagator = self.get_propagator() for variant in ["1", "True", "true", "d"]: - _, _, new_carrier = get_child_parent_new_carrier( + _, _, new_carrier = self.get_child_parent_new_carrier( { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.SAMPLED_KEY: variant, + propagator.TRACE_ID_KEY: self.serialized_trace_id, + propagator.SPAN_ID_KEY: self.serialized_span_id, + propagator.SAMPLED_KEY: variant, } ) - - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + self.assertSampled(new_carrier) def test_disabled_sampling(self): """Test b3 sample key variants that turn off sampling.""" + propagator = self.get_propagator() for variant in ["0", "False", "false", None]: - _, _, new_carrier = get_child_parent_new_carrier( + _, _, new_carrier = self.get_child_parent_new_carrier( { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.SAMPLED_KEY: variant, + propagator.TRACE_ID_KEY: self.serialized_trace_id, + propagator.SPAN_ID_KEY: self.serialized_span_id, + propagator.SAMPLED_KEY: variant, } ) - - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "0") + self.assertNotSampled(new_carrier) def test_flags(self): """x-b3-flags set to "1" should result in propagation.""" - _, _, new_carrier = get_child_parent_new_carrier( + propagator = self.get_propagator() + _, _, new_carrier = self.get_child_parent_new_carrier( { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.FLAGS_KEY: "1", + propagator.TRACE_ID_KEY: self.serialized_trace_id, + propagator.SPAN_ID_KEY: self.serialized_span_id, + propagator.FLAGS_KEY: "1", } ) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + self.assertSampled(new_carrier) def test_flags_and_sampling(self): """Propagate if b3 flags and sampling are set.""" - _, _, new_carrier = get_child_parent_new_carrier( + propagator = self.get_propagator() + _, _, new_carrier = self.get_child_parent_new_carrier( { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.FLAGS_KEY: "1", + propagator.TRACE_ID_KEY: self.serialized_trace_id, + propagator.SPAN_ID_KEY: self.serialized_span_id, + propagator.FLAGS_KEY: "1", } ) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + self.assertSampled(new_carrier) def test_derived_ctx_is_returned_for_success(self): """Ensure returned context is derived from the given context.""" old_ctx = Context({"k1": "v1"}) - new_ctx = FORMAT.extract( + propagator = self.get_propagator() + new_ctx = propagator.extract( { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.FLAGS_KEY: "1", + propagator.TRACE_ID_KEY: self.serialized_trace_id, + propagator.SPAN_ID_KEY: self.serialized_span_id, + propagator.FLAGS_KEY: "1", }, old_ctx, ) @@ -237,7 +256,7 @@ def test_derived_ctx_is_returned_for_success(self): def test_derived_ctx_is_returned_for_failure(self): """Ensure returned context is derived from the given context.""" old_ctx = Context({"k2": "v2"}) - new_ctx = FORMAT.extract({}, old_ctx) + new_ctx = self.get_propagator().extract({}, old_ctx) self.assertNotIn("current-span", new_ctx) for key, value in old_ctx.items(): # pylint:disable=no-member self.assertIn(key, new_ctx) @@ -246,120 +265,131 @@ def test_derived_ctx_is_returned_for_failure(self): def test_64bit_trace_id(self): """64 bit trace ids should be padded to 128 bit trace ids.""" + propagator = self.get_propagator() trace_id_64_bit = self.serialized_trace_id[:16] - _, _, new_carrier = get_child_parent_new_carrier( + _, _, new_carrier = self.get_child_parent_new_carrier( { - FORMAT.TRACE_ID_KEY: trace_id_64_bit, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.FLAGS_KEY: "1", - } + propagator.TRACE_ID_KEY: trace_id_64_bit, + propagator.SPAN_ID_KEY: self.serialized_span_id, + propagator.FLAGS_KEY: "1", + }, ) self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], "0" * 16 + trace_id_64_bit + self.get_trace_id(new_carrier), "0" * 16 + trace_id_64_bit ) def test_extract_invalid_single_header_to_explicit_ctx(self): """Given unparsable header, do not modify context""" old_ctx = Context({"k1": "v1"}) + propagator = self.get_propagator() - carrier = {FORMAT.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"} - new_ctx = FORMAT.extract(carrier, old_ctx) + carrier = {propagator.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"} + new_ctx = propagator.extract(carrier, old_ctx) self.assertDictEqual(new_ctx, old_ctx) def test_extract_invalid_single_header_to_implicit_ctx(self): - carrier = {FORMAT.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"} - new_ctx = FORMAT.extract(carrier) + propagator = self.get_propagator() + carrier = {propagator.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"} + new_ctx = propagator.extract(carrier) self.assertDictEqual(Context(), new_ctx) def test_extract_missing_trace_id_to_explicit_ctx(self): """Given no trace ID, do not modify context""" old_ctx = Context({"k1": "v1"}) + propagator = self.get_propagator() carrier = { - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.FLAGS_KEY: "1", + propagator.SPAN_ID_KEY: self.serialized_span_id, + propagator.FLAGS_KEY: "1", } - new_ctx = FORMAT.extract(carrier, old_ctx) + new_ctx = propagator.extract(carrier, old_ctx) self.assertDictEqual(new_ctx, old_ctx) def test_extract_missing_trace_id_to_implicit_ctx(self): + propagator = self.get_propagator() carrier = { - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.FLAGS_KEY: "1", + propagator.SPAN_ID_KEY: self.serialized_span_id, + propagator.FLAGS_KEY: "1", } - new_ctx = FORMAT.extract(carrier) + new_ctx = propagator.extract(carrier) self.assertDictEqual(Context(), new_ctx) def test_extract_invalid_trace_id_to_explicit_ctx(self): """Given invalid trace ID, do not modify context""" old_ctx = Context({"k1": "v1"}) + propagator = self.get_propagator() carrier = { - FORMAT.TRACE_ID_KEY: "abc123", - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.FLAGS_KEY: "1", + propagator.TRACE_ID_KEY: "abc123", + propagator.SPAN_ID_KEY: self.serialized_span_id, + propagator.FLAGS_KEY: "1", } - new_ctx = FORMAT.extract(carrier, old_ctx) + new_ctx = propagator.extract(carrier, old_ctx) self.assertDictEqual(new_ctx, old_ctx) def test_extract_invalid_trace_id_to_implicit_ctx(self): + propagator = self.get_propagator() carrier = { - FORMAT.TRACE_ID_KEY: "abc123", - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.FLAGS_KEY: "1", + propagator.TRACE_ID_KEY: "abc123", + propagator.SPAN_ID_KEY: self.serialized_span_id, + propagator.FLAGS_KEY: "1", } - new_ctx = FORMAT.extract(carrier) + new_ctx = propagator.extract(carrier) self.assertDictEqual(Context(), new_ctx) def test_extract_invalid_span_id_to_explicit_ctx(self): """Given invalid span ID, do not modify context""" old_ctx = Context({"k1": "v1"}) + propagator = self.get_propagator() carrier = { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: "abc123", - FORMAT.FLAGS_KEY: "1", + propagator.TRACE_ID_KEY: self.serialized_trace_id, + propagator.SPAN_ID_KEY: "abc123", + propagator.FLAGS_KEY: "1", } - new_ctx = FORMAT.extract(carrier, old_ctx) + new_ctx = propagator.extract(carrier, old_ctx) self.assertDictEqual(new_ctx, old_ctx) def test_extract_invalid_span_id_to_implicit_ctx(self): + propagator = self.get_propagator() carrier = { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: "abc123", - FORMAT.FLAGS_KEY: "1", + propagator.TRACE_ID_KEY: self.serialized_trace_id, + propagator.SPAN_ID_KEY: "abc123", + propagator.FLAGS_KEY: "1", } - new_ctx = FORMAT.extract(carrier) + new_ctx = propagator.extract(carrier) self.assertDictEqual(Context(), new_ctx) def test_extract_missing_span_id_to_explicit_ctx(self): """Given no span ID, do not modify context""" old_ctx = Context({"k1": "v1"}) + propagator = self.get_propagator() carrier = { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.FLAGS_KEY: "1", + propagator.TRACE_ID_KEY: self.serialized_trace_id, + propagator.FLAGS_KEY: "1", } - new_ctx = FORMAT.extract(carrier, old_ctx) + new_ctx = propagator.extract(carrier, old_ctx) self.assertDictEqual(new_ctx, old_ctx) def test_extract_missing_span_id_to_implicit_ctx(self): + propagator = self.get_propagator() carrier = { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.FLAGS_KEY: "1", + propagator.TRACE_ID_KEY: self.serialized_trace_id, + propagator.FLAGS_KEY: "1", } - new_ctx = FORMAT.extract(carrier) + new_ctx = propagator.extract(carrier) self.assertDictEqual(Context(), new_ctx) @@ -368,54 +398,90 @@ def test_extract_empty_carrier_to_explicit_ctx(self): old_ctx = Context({"k1": "v1"}) carrier = {} - new_ctx = FORMAT.extract(carrier, old_ctx) + new_ctx = self.get_propagator().extract(carrier, old_ctx) self.assertDictEqual(new_ctx, old_ctx) def test_extract_empty_carrier_to_implicit_ctx(self): - new_ctx = FORMAT.extract({}) + new_ctx = self.get_propagator().extract({}) self.assertDictEqual(Context(), new_ctx) - @staticmethod - def test_inject_empty_context(): + def test_inject_empty_context(self): """If the current context has no span, don't add headers""" new_carrier = {} - FORMAT.inject(new_carrier, get_current()) + self.get_propagator().inject(new_carrier, get_current()) assert len(new_carrier) == 0 - @staticmethod - def test_default_span(): + def test_default_span(self): """Make sure propagator does not crash when working with NonRecordingSpan""" class CarrierGetter(DefaultGetter): def get(self, carrier, key): return carrier.get(key, None) - ctx = FORMAT.extract({}, getter=CarrierGetter()) - FORMAT.inject({}, context=ctx) + propagator = self.get_propagator() + ctx = propagator.extract({}, getter=CarrierGetter()) + propagator.inject({}, context=ctx) def test_fields(self): """Make sure the fields attribute returns the fields used in inject""" + propagator = self.get_propagator() tracer = trace.TracerProvider().get_tracer("sdk_tracer_provider") mock_setter = Mock() with tracer.start_as_current_span("parent"): with tracer.start_as_current_span("child"): - FORMAT.inject({}, setter=mock_setter) + propagator.inject({}, setter=mock_setter) inject_fields = set() for call in mock_setter.mock_calls: inject_fields.add(call[1][1]) - self.assertEqual(FORMAT.fields, inject_fields) + self.assertEqual(propagator.fields, inject_fields) def test_extract_none_context(self): """Given no trace ID, do not modify context""" old_ctx = None carrier = {} - new_ctx = FORMAT.extract(carrier, old_ctx) + new_ctx = self.get_propagator().extract(carrier, old_ctx) self.assertDictEqual(Context(), new_ctx) + + +class TestB3MultiFormat(AbstractB3FormatTestCase, unittest.TestCase): + @classmethod + def get_propagator(cls): + return B3MultiFormat() + + @classmethod + def get_trace_id(cls, carrier): + return carrier[cls.get_propagator().TRACE_ID_KEY] + + def assertSampled(self, carrier): + self.assertEqual(carrier[self.get_propagator().SAMPLED_KEY], "1") + + def assertNotSampled(self, carrier): + self.assertEqual(carrier[self.get_propagator().SAMPLED_KEY], "0") + + +class TestB3SingleFormat(AbstractB3FormatTestCase, unittest.TestCase): + @classmethod + def get_propagator(cls): + return B3SingleFormat() + + @classmethod + def get_trace_id(cls, carrier): + return carrier[cls.get_propagator().SINGLE_HEADER_KEY].split("-")[0] + + def assertSampled(self, carrier): + self.assertEqual( + carrier[self.get_propagator().SINGLE_HEADER_KEY].split("-")[2], "1" + ) + + def assertNotSampled(self, carrier): + self.assertEqual( + carrier[self.get_propagator().SINGLE_HEADER_KEY].split("-")[2], "0" + )