diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 50165763461..f993f135577 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,7 +10,7 @@ env: # Otherwise, set variable to the commit of your branch on # opentelemetry-python-contrib which is compatible with these Core repo # changes. - CONTRIB_REPO_SHA: 5bc0fa1611502be47a1f4eb550fe255e4b707ba1 + CONTRIB_REPO_SHA: 040fa8f9b58e35fda60d2200068d082d1c237435 jobs: build: diff --git a/docs/conf.py b/docs/conf.py index d23cebfe96c..5d44c24dd22 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -104,7 +104,6 @@ ("py:class", "opentelemetry.trace._LinkBase",), # TODO: Understand why sphinx is not able to find this local class ("py:class", "opentelemetry.propagators.textmap.TextMapPropagator",), - ("py:class", "opentelemetry.propagators.textmap.DictGetter",), ("any", "opentelemetry.propagators.textmap.TextMapPropagator.extract",), ("any", "opentelemetry.propagators.textmap.TextMapPropagator.inject",), ] diff --git a/docs/examples/auto-instrumentation/README.rst b/docs/examples/auto-instrumentation/README.rst index 607aa1b44b7..0a6c07bb408 100644 --- a/docs/examples/auto-instrumentation/README.rst +++ b/docs/examples/auto-instrumentation/README.rst @@ -37,7 +37,7 @@ Manually instrumented server def server_request(): with tracer.start_as_current_span( "server_request", - context=propagators.extract(DictGetter(), request.headers + context=propagators.extract(request.headers ), ): print(request.args.get("param")) diff --git a/docs/examples/auto-instrumentation/client.py b/docs/examples/auto-instrumentation/client.py index fefc1f67b98..cc948cc54b8 100644 --- a/docs/examples/auto-instrumentation/client.py +++ b/docs/examples/auto-instrumentation/client.py @@ -37,7 +37,7 @@ with tracer.start_as_current_span("client-server"): headers = {} - propagators.inject(dict.__setitem__, headers) + propagators.inject(headers) requested = get( "http://localhost:8082/server_request", params={"param": argv[1]}, diff --git a/docs/examples/auto-instrumentation/server_instrumented.py b/docs/examples/auto-instrumentation/server_instrumented.py index 1ac1bd6b71b..652358e3a2e 100644 --- a/docs/examples/auto-instrumentation/server_instrumented.py +++ b/docs/examples/auto-instrumentation/server_instrumented.py @@ -17,7 +17,6 @@ from opentelemetry import trace from opentelemetry.instrumentation.wsgi import collect_request_attributes from opentelemetry.propagate import extract -from opentelemetry.propagators.textmap import DictGetter from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import ( ConsoleSpanExporter, @@ -38,7 +37,7 @@ def server_request(): with tracer.start_as_current_span( "server_request", - context=extract(DictGetter(), request.headers), + context=extract(request.headers), kind=trace.SpanKind.SERVER, attributes=collect_request_attributes(request.environ), ): diff --git a/docs/examples/datadog_exporter/client.py b/docs/examples/datadog_exporter/client.py index 6b4b5d00ec1..7c6196ad4ab 100644 --- a/docs/examples/datadog_exporter/client.py +++ b/docs/examples/datadog_exporter/client.py @@ -47,7 +47,7 @@ with tracer.start_as_current_span("client-server"): headers = {} - inject(dict.__setitem__, headers) + inject(headers) requested = get( "http://localhost:8082/server_request", params={"param": argv[1]}, diff --git a/docs/examples/django/client.py b/docs/examples/django/client.py index bc3606cbe76..3ae0cb6e1cf 100644 --- a/docs/examples/django/client.py +++ b/docs/examples/django/client.py @@ -36,7 +36,7 @@ with tracer.start_as_current_span("client-server"): headers = {} - inject(dict.__setitem__, headers) + inject(headers) requested = get( "http://localhost:8000", params={"param": argv[1]}, diff --git a/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py b/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py index e6d1c4207bc..d96fc067d09 100644 --- a/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py +++ b/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py @@ -11,29 +11,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# -import typing -import urllib.parse + +from typing import Optional, Set +from urllib.parse import quote_plus, unquote from opentelemetry import baggage from opentelemetry.context import get_current from opentelemetry.context.context import Context -from opentelemetry.propagators import textmap +from opentelemetry.propagators.textmap import ( + TextMapPropagator, + TextMapPropagatorT, +) -class W3CBaggagePropagator(textmap.TextMapPropagator): +class W3CBaggagePropagator(TextMapPropagator): """Extracts and injects Baggage which is used to annotate telemetry.""" - _MAX_HEADER_LENGTH = 8192 - _MAX_PAIR_LENGTH = 4096 - _MAX_PAIRS = 180 - _BAGGAGE_HEADER_NAME = "baggage" + _baggage_header_name = "baggage" + _max_header_length = 9182 + _max_pairs = 180 + _max_pair_length = 4096 def extract( - self, - getter: textmap.Getter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> Context: """Extract Baggage from the carrier. @@ -44,38 +44,36 @@ def extract( if context is None: context = get_current() - header = _extract_first_element( - getter.get(carrier, self._BAGGAGE_HEADER_NAME) - ) + value = carrier.get(self._baggage_header_name) + + if value is None: + header = None + else: + header = next(iter(value), None) - if not header or len(header) > self._MAX_HEADER_LENGTH: + if header is None or len(header) > self._max_header_length: return context - baggage_entries = header.split(",") - total_baggage_entries = self._MAX_PAIRS - for entry in baggage_entries: + total_baggage_entries = self._max_pairs + + for entry in header.split(","): if total_baggage_entries <= 0: return context total_baggage_entries -= 1 - if len(entry) > self._MAX_PAIR_LENGTH: + if len(entry) > self._max_pair_length: continue - try: + if "=" in entry: name, value = entry.split("=", 1) - except Exception: # pylint: disable=broad-except - continue - context = baggage.set_baggage( - urllib.parse.unquote(name).strip(), - urllib.parse.unquote(value).strip(), - context=context, - ) + context = baggage.set_baggage( + unquote(name).strip(), + unquote(value).strip(), + context=context, + ) return context def inject( - self, - set_in_carrier: textmap.Setter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> None: """Injects Baggage into the carrier. @@ -83,28 +81,14 @@ def inject( `opentelemetry.propagators.textmap.TextMapPropagator.inject` """ baggage_entries = baggage.get_all(context=context) - if not baggage_entries: - return - baggage_string = _format_baggage(baggage_entries) - set_in_carrier(carrier, self._BAGGAGE_HEADER_NAME, baggage_string) + if baggage_entries: + carrier[self._baggage_header_name] = ",".join( + key + "=" + quote_plus(str(value)) + for key, value in baggage_entries.items() + ) @property - def fields(self) -> typing.Set[str]: + def fields(self) -> Set[str]: """Returns a set with the fields set in `inject`.""" - return {self._BAGGAGE_HEADER_NAME} - - -def _format_baggage(baggage_entries: typing.Mapping[str, object]) -> str: - return ",".join( - key + "=" + urllib.parse.quote_plus(str(value)) - for key, value in baggage_entries.items() - ) - - -def _extract_first_element( - items: typing.Optional[typing.Iterable[textmap.TextMapPropagatorT]], -) -> typing.Optional[textmap.TextMapPropagatorT]: - if items is None: - return None - return next(iter(items), None) + return {self._baggage_header_name} diff --git a/opentelemetry-api/src/opentelemetry/propagate/__init__.py b/opentelemetry-api/src/opentelemetry/propagate/__init__.py index 44f9897a532..d23a0fcd239 100644 --- a/opentelemetry-api/src/opentelemetry/propagate/__init__.py +++ b/opentelemetry-api/src/opentelemetry/propagate/__init__.py @@ -40,23 +40,12 @@ PROPAGATOR = propagators.get_global_textmap() - def get_header_from_flask_request(request, key): - return request.headers.get_all(key) - - def set_header_into_requests_request(request: requests.Request, - key: str, value: str): - request.headers[key] = value - def example_route(): - context = PROPAGATOR.extract( - get_header_from_flask_request, - flask.request - ) + context = PROPAGATOR.extract(flask.request) request_to_downstream = requests.Request( "GET", "http://httpbin.org/get" ) PROPAGATOR.inject( - set_header_into_requests_request, request_to_downstream, context=context ) @@ -68,23 +57,25 @@ def example_route(): https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/context/api-propagators.md """ -import typing from logging import getLogger from os import environ +from typing import Optional from pkg_resources import iter_entry_points from opentelemetry.context.context import Context from opentelemetry.environment_variables import OTEL_PROPAGATORS -from opentelemetry.propagators import composite, textmap +from opentelemetry.propagators import composite +from opentelemetry.propagators.textmap import ( + TextMapPropagator, + TextMapPropagatorT, +) -logger = getLogger(__name__) +_logger = getLogger(__name__) def extract( - getter: textmap.Getter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, - context: typing.Optional[Context] = None, + carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> Context: """Uses the configured propagator to extract a Context from the carrier. @@ -99,26 +90,21 @@ def extract( context: an optional Context to use. Defaults to current context if not set. """ - return get_global_textmap().extract(getter, carrier, context) + return get_global_textmap().extract(carrier, context) def inject( - set_in_carrier: textmap.Setter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, - context: typing.Optional[Context] = None, + carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> None: """Uses the configured propagator to inject a Context into the carrier. Args: - set_in_carrier: A setter function that can set values - on the carrier. - carrier: An object that contains a representation of HTTP - headers. Should be paired with set_in_carrier, which - should know how to set header values on the carrier. + carrier: A dict-like object that contains a representation of HTTP + headers. context: an optional Context to use. Defaults to current context if not set. """ - get_global_textmap().inject(set_in_carrier, carrier, context) + get_global_textmap().inject(carrier, context) try: @@ -138,16 +124,16 @@ def inject( ) except Exception: # pylint: disable=broad-except - logger.exception("Failed to load configured propagators") + _logger.error("Failed to load configured propagators") raise -_HTTP_TEXT_FORMAT = composite.CompositeHTTPPropagator(propagators) # type: ignore +_textmap_propagator = composite.CompositeHTTPPropagator(propagators) # type: ignore -def get_global_textmap() -> textmap.TextMapPropagator: - return _HTTP_TEXT_FORMAT +def get_global_textmap() -> TextMapPropagator: + return _textmap_propagator -def set_global_textmap(http_text_format: textmap.TextMapPropagator,) -> None: - global _HTTP_TEXT_FORMAT # pylint:disable=global-statement - _HTTP_TEXT_FORMAT = http_text_format # type: ignore +def set_global_textmap(http_text_format: TextMapPropagator,) -> None: + global _textmap_propagator # pylint:disable=global-statement + _textmap_propagator = http_text_format # type: ignore diff --git a/opentelemetry-api/src/opentelemetry/propagators/composite.py b/opentelemetry-api/src/opentelemetry/propagators/composite.py index 92dc6b8a380..811934b5164 100644 --- a/opentelemetry-api/src/opentelemetry/propagators/composite.py +++ b/opentelemetry-api/src/opentelemetry/propagators/composite.py @@ -11,16 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging -import typing + +from logging import getLogger +from typing import Optional, Sequence, Set from opentelemetry.context.context import Context -from opentelemetry.propagators import textmap +from opentelemetry.propagators.textmap import ( + TextMapPropagator, + TextMapPropagatorT, +) -logger = logging.getLogger(__name__) +_logger = getLogger(__name__) -class CompositeHTTPPropagator(textmap.TextMapPropagator): +class CompositeHTTPPropagator(TextMapPropagator): """CompositeHTTPPropagator provides a mechanism for combining multiple propagators into a single one. @@ -28,46 +32,39 @@ class CompositeHTTPPropagator(textmap.TextMapPropagator): propagators: the list of propagators to use """ - def __init__( - self, propagators: typing.Sequence[textmap.TextMapPropagator] - ) -> None: + def __init__(self, propagators: Sequence[TextMapPropagator]) -> None: self._propagators = propagators def extract( - self, - getter: textmap.Getter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> Context: - """Run each of the configured propagators with the given context and carrier. + """Run each of the configured propagators with the given context and + carrier. Propagators are run in the order they are configured, if multiple - propagators write the same context key, the propagator later in the list - will override previous propagators. + propagators write the same context key, the last propagator that writes + the context key will override previous propagators. See `opentelemetry.propagators.textmap.TextMapPropagator.extract` """ for propagator in self._propagators: - context = propagator.extract(getter, carrier, context) + context = propagator.extract(carrier, context) return context # type: ignore def inject( - self, - set_in_carrier: textmap.Setter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> None: - """Run each of the configured propagators with the given context and carrier. - Propagators are run in the order they are configured, if multiple - propagators write the same carrier key, the propagator later in the list - will override previous propagators. + """Run each of the configured propagators with the given context and + carrier. Propagators are run in the order they are configured, if + multiple propagators write the same carrier key, the last propagator + that writes the carrier key will override previous propagators. See `opentelemetry.propagators.textmap.TextMapPropagator.inject` """ for propagator in self._propagators: - propagator.inject(set_in_carrier, carrier, context) + propagator.inject(carrier, context) @property - def fields(self) -> typing.Set[str]: + def fields(self) -> Set[str]: """Returns a set with the fields set in `inject`. See diff --git a/opentelemetry-api/src/opentelemetry/propagators/textmap.py b/opentelemetry-api/src/opentelemetry/propagators/textmap.py index cf93d1d6319..af6d3a49595 100644 --- a/opentelemetry-api/src/opentelemetry/propagators/textmap.py +++ b/opentelemetry-api/src/opentelemetry/propagators/textmap.py @@ -12,139 +12,58 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc -import typing +from abc import ABC, abstractmethod +from typing import Optional, Set, TypeVar from opentelemetry.context.context import Context -TextMapPropagatorT = typing.TypeVar("TextMapPropagatorT") -CarrierValT = typing.Union[typing.List[str], str] +TextMapPropagatorT = TypeVar("TextMapPropagatorT") -Setter = typing.Callable[[TextMapPropagatorT, str, str], None] - -class Getter(typing.Generic[TextMapPropagatorT]): - """This class implements a Getter that enables extracting propagated - fields from a carrier. - """ - - def get( - self, carrier: TextMapPropagatorT, key: str - ) -> typing.Optional[typing.List[str]]: - """Function that can retrieve zero - or more values from the carrier. In the case that - the value does not exist, returns None. - - Args: - carrier: An object which contains values that are used to - construct a Context. - key: key of a field in carrier. - Returns: first value of the propagation key or None if the key doesn't - exist. - """ - raise NotImplementedError() - - def keys(self, carrier: TextMapPropagatorT) -> typing.List[str]: - """Function that can retrieve all the keys in a carrier object. - - Args: - carrier: An object which contains values that are - used to construct a Context. - Returns: - list of keys from the carrier. - """ - raise NotImplementedError() - - -class DictGetter(Getter[typing.Dict[str, CarrierValT]]): - def get( - self, carrier: typing.Dict[str, CarrierValT], key: str - ) -> typing.Optional[typing.List[str]]: - """Getter implementation to retrieve a value from a dictionary. - - Args: - carrier: dictionary in which header - key: the key used to get the value - Returns: - A list with a single string with the value if it exists, else None. - """ - val = carrier.get(key, None) - if val is None: - return None - if isinstance(val, typing.Iterable) and not isinstance(val, str): - return list(val) - return [val] - - def keys(self, carrier: typing.Dict[str, CarrierValT]) -> typing.List[str]: - """Keys implementation that returns all keys from a dictionary.""" - return list(carrier.keys()) - - -class TextMapPropagator(abc.ABC): +class TextMapPropagator(ABC): """This class provides an interface that enables extracting and injecting - context into headers of HTTP requests. HTTP frameworks and clients - can integrate with TextMapPropagator by providing the object containing the - headers, and a getter and setter function for the extraction and - injection of values, respectively. - + context into headers of HTTP requests. HTTP frameworks and clients can + integrate with TextMapPropagator by providing the object containing the + headers. """ - @abc.abstractmethod + @abstractmethod def extract( - self, - getter: Getter[TextMapPropagatorT], - carrier: TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> Context: """Create a Context from values in the carrier. - The extract function should retrieve values from the carrier - object using getter, and use values to populate a - Context value and return it. + Retrieves values from the carrier object and uses them to populate a + context and returns it afterwards. Args: - getter: a function that can retrieve zero - or more values from the carrier. In the case that - the value does not exist, return an empty list. - carrier: and object which contains values that are - used to construct a Context. This object - must be paired with an appropriate getter - which understands how to extract a value from it. - context: an optional Context to use. Defaults to current - context if not set. + carrier: and object which contains values that are used to + construct a Context. + context: an optional Context to use. Defaults to current context if + not set. Returns: - A Context with configuration found in the carrier. - + A Context with the configuration found in the carrier. """ - @abc.abstractmethod + @abstractmethod def inject( - self, - set_in_carrier: Setter[TextMapPropagatorT], - carrier: TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> None: """Inject values from a Context into a carrier. - inject enables the propagation of values into HTTP clients or - other objects which perform an HTTP request. Implementations - should use the set_in_carrier method to set values on the - carrier. + Enables the propagation of values into HTTP clients or other objects + which perform an HTTP request. Args: - set_in_carrier: A setter function that can set values - on the carrier. - carrier: An object that a place to define HTTP headers. - Should be paired with set_in_carrier, which should - know how to set header values on the carrier. + carrier: An dict-like object where to store HTTP headers. context: an optional Context to use. Defaults to current context if not set. """ @property - @abc.abstractmethod - def fields(self) -> typing.Set[str]: + @abstractmethod + def fields(self) -> Set[str]: """ Gets the fields set in the carrier by the `inject` method. diff --git a/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py b/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py index 480e716bf78..c35c11b65e5 100644 --- a/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py +++ b/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py @@ -11,111 +11,120 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# -import re -import typing -import opentelemetry.trace as trace +from re import compile as compile_ +from re import search +from typing import Optional, Set + from opentelemetry.context.context import Context -from opentelemetry.propagators import textmap -from opentelemetry.trace import format_span_id, format_trace_id +from opentelemetry.propagators.textmap import ( + TextMapPropagator, + TextMapPropagatorT, +) +from opentelemetry.trace import ( + INVALID_SPAN, + INVALID_SPAN_CONTEXT, + NonRecordingSpan, + SpanContext, + TraceFlags, + format_span_id, + format_trace_id, + get_current_span, + set_span_in_context, +) from opentelemetry.trace.span import TraceState -class TraceContextTextMapPropagator(textmap.TextMapPropagator): +class TraceContextTextMapPropagator(TextMapPropagator): """Extracts and injects using w3c TraceContext's headers.""" - _TRACEPARENT_HEADER_NAME = "traceparent" - _TRACESTATE_HEADER_NAME = "tracestate" - _TRACEPARENT_HEADER_FORMAT = ( - "^[ \t]*([0-9a-f]{2})-([0-9a-f]{32})-([0-9a-f]{16})-([0-9a-f]{2})" - + "(-.*)?[ \t]*$" + _traceparent_header_name = "traceparent" + _tracestate_header_name = "tracestate" + _traceparent_header_format_re = compile_( + r"^\s*(?P[0-9a-f]{2})-" + r"(?P[0-9a-f]{32})-" + r"(?P[0-9a-f]{16})-" + r"(?P[0-9a-f]{2})" + r"(?P.+?)?\s*$" ) - _TRACEPARENT_HEADER_FORMAT_RE = re.compile(_TRACEPARENT_HEADER_FORMAT) def extract( - self, - getter: textmap.Getter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> Context: """Extracts SpanContext from the carrier. See `opentelemetry.propagators.textmap.TextMapPropagator.extract` """ - header = getter.get(carrier, self._TRACEPARENT_HEADER_NAME) + header = carrier.get(self._traceparent_header_name) + + if header is None: + return set_span_in_context(INVALID_SPAN, context) - if not header: - return trace.set_span_in_context(trace.INVALID_SPAN, context) + match = search(self._traceparent_header_format_re, header[0]) + if match is None: + return set_span_in_context(INVALID_SPAN, context) - match = re.search(self._TRACEPARENT_HEADER_FORMAT_RE, header[0]) - if not match: - return trace.set_span_in_context(trace.INVALID_SPAN, context) + version = match.group("version") + trace_id = match.group("trace_id") + span_id = match.group("span_id") - version = match.group(1) - trace_id = match.group(2) - span_id = match.group(3) - trace_flags = match.group(4) + if ( + version == "ff" + or trace_id == "0" * 32 + or span_id == "0" * 16 + or (version == "00" and match.group("remainder")) + ): - if trace_id == "0" * 32 or span_id == "0" * 16: - return trace.set_span_in_context(trace.INVALID_SPAN, context) + return set_span_in_context(INVALID_SPAN, context) - if version == "00": - if match.group(5): - return trace.set_span_in_context(trace.INVALID_SPAN, context) - if version == "ff": - return trace.set_span_in_context(trace.INVALID_SPAN, context) + tracestate_headers = carrier.get(self._tracestate_header_name) - tracestate_headers = getter.get(carrier, self._TRACESTATE_HEADER_NAME) if tracestate_headers is None: tracestate = None else: tracestate = TraceState.from_header(tracestate_headers) - span_context = trace.SpanContext( - trace_id=int(trace_id, 16), - span_id=int(span_id, 16), - is_remote=True, - trace_flags=trace.TraceFlags(trace_flags), - trace_state=tracestate, - ) - return trace.set_span_in_context( - trace.NonRecordingSpan(span_context), context + return set_span_in_context( + NonRecordingSpan( + SpanContext( + trace_id=int(trace_id, 16), + span_id=int(span_id, 16), + is_remote=True, + trace_flags=TraceFlags(match.group("trace_flags")), + trace_state=tracestate, + ) + ), + context, ) def inject( - self, - set_in_carrier: textmap.Setter[textmap.TextMapPropagatorT], - carrier: textmap.TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> None: """Injects SpanContext into the carrier. See `opentelemetry.propagators.textmap.TextMapPropagator.inject` """ - span = trace.get_current_span(context) - span_context = span.get_span_context() - if span_context == trace.INVALID_SPAN_CONTEXT: + span_context = get_current_span(context).get_span_context() + if span_context == INVALID_SPAN_CONTEXT: return - traceparent_string = "00-{trace_id}-{span_id}-{:02x}".format( + carrier[ + self._traceparent_header_name + ] = "00-{trace_id}-{span_id}-{:02x}".format( span_context.trace_flags, trace_id=format_trace_id(span_context.trace_id), span_id=format_span_id(span_context.span_id), ) - set_in_carrier( - carrier, self._TRACEPARENT_HEADER_NAME, traceparent_string - ) + if span_context.trace_state: - tracestate_string = span_context.trace_state.to_header() - set_in_carrier( - carrier, self._TRACESTATE_HEADER_NAME, tracestate_string - ) + carrier[ + self._tracestate_header_name + ] = span_context.trace_state.to_header() @property - def fields(self) -> typing.Set[str]: + def fields(self) -> Set[str]: """Returns a set with the fields set in `inject`. See `opentelemetry.propagators.textmap.TextMapPropagator.fields` """ - return {self._TRACEPARENT_HEADER_NAME, self._TRACESTATE_HEADER_NAME} + return {self._traceparent_header_name, self._tracestate_header_name} diff --git a/opentelemetry-api/tests/baggage/test_baggage_propagation.py b/opentelemetry-api/tests/baggage/test_baggage_propagation.py index a928a2fc8cb..57b2b531c02 100644 --- a/opentelemetry-api/tests/baggage/test_baggage_propagation.py +++ b/opentelemetry-api/tests/baggage/test_baggage_propagation.py @@ -11,26 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# -import unittest -from unittest.mock import Mock, patch + +from unittest import TestCase +from unittest.mock import patch from opentelemetry import baggage from opentelemetry.baggage.propagation import W3CBaggagePropagator from opentelemetry.context import get_current -from opentelemetry.propagators.textmap import DictGetter - -carrier_getter = DictGetter() -class TestBaggagePropagation(unittest.TestCase): +class TestBaggagePropagation(TestCase): def setUp(self): self.propagator = W3CBaggagePropagator() def _extract(self, header_value): """Test helper""" - header = {"baggage": [header_value]} - return baggage.get_all(self.propagator.extract(carrier_getter, header)) + return baggage.get_all( + self.propagator.extract({"baggage": [header_value]}) + ) def _inject(self, values): """Test helper""" @@ -38,122 +36,114 @@ def _inject(self, values): for k, v in values.items(): ctx = baggage.set_baggage(k, v, context=ctx) output = {} - self.propagator.inject(dict.__setitem__, output, context=ctx) + self.propagator.inject(output, context=ctx) return output.get("baggage") def test_no_context_header(self): - baggage_entries = baggage.get_all( - self.propagator.extract(carrier_getter, {}) - ) - self.assertEqual(baggage_entries, {}) + self.assertEqual(baggage.get_all(self.propagator.extract({})), {}) def test_empty_context_header(self): - header = "" - self.assertEqual(self._extract(header), {}) + self.assertEqual(self._extract(""), {}) def test_valid_header(self): - header = "key1=val1,key2=val2" - expected = {"key1": "val1", "key2": "val2"} - self.assertEqual(self._extract(header), expected) + self.assertEqual( + self._extract("key1=val1,key2=val2"), + {"key1": "val1", "key2": "val2"}, + ) def test_valid_header_with_space(self): - header = "key1 = val1, key2 =val2 " - expected = {"key1": "val1", "key2": "val2"} - self.assertEqual(self._extract(header), expected) + self.assertEqual( + self._extract("key1 = val1, key2 =val2 "), + {"key1": "val1", "key2": "val2"}, + ) def test_valid_header_with_properties(self): - header = "key1=val1,key2=val2;prop=1" - expected = {"key1": "val1", "key2": "val2;prop=1"} - self.assertEqual(self._extract(header), expected) + self.assertEqual( + self._extract("key1=val1,key2=val2;prop=1"), + {"key1": "val1", "key2": "val2;prop=1"}, + ) def test_valid_header_with_url_escaped_comma(self): - header = "key%2C1=val1,key2=val2%2Cval3" - expected = {"key,1": "val1", "key2": "val2,val3"} - self.assertEqual(self._extract(header), expected) + self.assertEqual( + self._extract("key%2C1=val1,key2=val2%2Cval3"), + {"key,1": "val1", "key2": "val2,val3"}, + ) def test_valid_header_with_invalid_value(self): - header = "key1=val1,key2=val2,a,val3" - expected = {"key1": "val1", "key2": "val2"} - self.assertEqual(self._extract(header), expected) + self.assertEqual( + self._extract("key1=val1,key2=val2,a,val3"), + {"key1": "val1", "key2": "val2"}, + ) def test_valid_header_with_empty_value(self): - header = "key1=,key2=val2" - expected = {"key1": "", "key2": "val2"} - self.assertEqual(self._extract(header), expected) + self.assertEqual( + self._extract("key1=,key2=val2"), {"key1": "", "key2": "val2"} + ) def test_invalid_header(self): - header = "header1" - expected = {} - self.assertEqual(self._extract(header), expected) + self.assertEqual(self._extract("header1"), {}) def test_header_too_long(self): - long_value = "s" * (W3CBaggagePropagator._MAX_HEADER_LENGTH + 1) - header = "key1={}".format(long_value) - expected = {} - self.assertEqual(self._extract(header), expected) + self.assertEqual( + self._extract( + "key1={}".format( + "s" * (W3CBaggagePropagator._max_header_length + 1) + ) + ), + {}, + ) def test_header_contains_too_many_entries(self): - header = ",".join( - [ - "key{}=val".format(k) - for k in range(W3CBaggagePropagator._MAX_PAIRS + 1) - ] - ) self.assertEqual( - len(self._extract(header)), W3CBaggagePropagator._MAX_PAIRS + len( + self._extract( + ",".join( + "key{}=val".format(k) + for k in range(W3CBaggagePropagator._max_pairs + 1) + ) + ) + ), + W3CBaggagePropagator._max_pairs, ) def test_header_contains_pair_too_long(self): - long_value = "s" * (W3CBaggagePropagator._MAX_PAIR_LENGTH + 1) - header = "key1=value1,key2={},key3=value3".format(long_value) - expected = {"key1": "value1", "key3": "value3"} - self.assertEqual(self._extract(header), expected) + self.assertEqual( + self._extract( + "key1=value1,key2={},key3=value3".format( + "s" * (W3CBaggagePropagator._max_pair_length + 1) + ) + ), + {"key1": "value1", "key3": "value3"}, + ) def test_inject_no_baggage_entries(self): - values = {} - output = self._inject(values) - self.assertEqual(None, output) + self.assertEqual(None, self._inject({})) def test_inject(self): - values = { - "key1": "val1", - "key2": "val2", - } - output = self._inject(values) + output = self._inject({"key1": "val1", "key2": "val2"}) self.assertIn("key1=val1", output) self.assertIn("key2=val2", output) def test_inject_escaped_values(self): - values = { - "key1": "val1,val2", - "key2": "val3=4", - } - output = self._inject(values) + output = self._inject({"key1": "val1,val2", "key2": "val3=4"}) self.assertIn("key1=val1%2Cval2", output) self.assertIn("key2=val3%3D4", output) def test_inject_non_string_values(self): - values = { - "key1": True, - "key2": 123, - "key3": 123.567, - } - output = self._inject(values) + output = self._inject({"key1": True, "key2": 123, "key3": 123.567}) self.assertIn("key1=True", output) self.assertIn("key2=123", output) self.assertIn("key3=123.567", output) @patch("opentelemetry.baggage.propagation.baggage") - @patch("opentelemetry.baggage.propagation._format_baggage") - def test_fields(self, mock_format_baggage, mock_baggage): - - mock_set_in_carrier = Mock() + def test_fields(self, mock_baggage): - self.propagator.inject(mock_set_in_carrier, {}) + mock_baggage.configure_mock( + **{"get_all.return_value": {"a": "b", "c": "d"}} + ) - inject_fields = set() + carrier = {} - for mock_call in mock_set_in_carrier.mock_calls: - inject_fields.add(mock_call[1][1]) + self.propagator.inject(carrier) - self.assertEqual(inject_fields, self.propagator.fields) + self.assertEqual(carrier.keys(), self.propagator.fields) diff --git a/opentelemetry-api/tests/propagators/test_composite.py b/opentelemetry-api/tests/propagators/test_composite.py index 232e177d3d0..1020c6d426b 100644 --- a/opentelemetry-api/tests/propagators/test_composite.py +++ b/opentelemetry-api/tests/propagators/test_composite.py @@ -18,22 +18,15 @@ from opentelemetry.propagators.composite import CompositeHTTPPropagator -def get_as_list(dict_object, key): - value = dict_object.get(key) - return [value] if value is not None else [] - - def mock_inject(name, value="data"): - def wrapped(setter, carrier=None, context=None): + def wrapped(carrier=None, context=None): carrier[name] = value - setter({}, "inject_field_{}_0".format(name), None) - setter({}, "inject_field_{}_1".format(name), None) return wrapped def mock_extract(name, value="context"): - def wrapped(getter, carrier=None, context=None): + def wrapped(carrier=None, context=None): new_context = context.copy() new_context[name] = value return new_context @@ -67,24 +60,20 @@ def setUpClass(cls): def test_no_propagators(self): propagator = CompositeHTTPPropagator([]) new_carrier = {} - propagator.inject(dict.__setitem__, carrier=new_carrier) + propagator.inject(carrier=new_carrier) self.assertEqual(new_carrier, {}) - context = propagator.extract( - get_as_list, carrier=new_carrier, context={} - ) + context = propagator.extract(carrier=new_carrier, context={}) self.assertEqual(context, {}) def test_single_propagator(self): propagator = CompositeHTTPPropagator([self.mock_propagator_0]) new_carrier = {} - propagator.inject(dict.__setitem__, carrier=new_carrier) + propagator.inject(carrier=new_carrier) self.assertEqual(new_carrier, {"mock-0": "data"}) - context = propagator.extract( - get_as_list, carrier=new_carrier, context={} - ) + context = propagator.extract(carrier=new_carrier, context={}) self.assertEqual(context, {"mock-0": "context"}) def test_multiple_propagators(self): @@ -93,12 +82,10 @@ def test_multiple_propagators(self): ) new_carrier = {} - propagator.inject(dict.__setitem__, carrier=new_carrier) + propagator.inject(carrier=new_carrier) self.assertEqual(new_carrier, {"mock-0": "data", "mock-1": "data"}) - context = propagator.extract( - get_as_list, carrier=new_carrier, context={} - ) + context = propagator.extract(carrier=new_carrier, context={}) self.assertEqual(context, {"mock-0": "context", "mock-1": "context"}) def test_multiple_propagators_same_key(self): @@ -109,12 +96,10 @@ def test_multiple_propagators_same_key(self): ) new_carrier = {} - propagator.inject(dict.__setitem__, carrier=new_carrier) + propagator.inject(carrier=new_carrier) self.assertEqual(new_carrier, {"mock-0": "data2"}) - context = propagator.extract( - get_as_list, carrier=new_carrier, context={} - ) + context = propagator.extract(carrier=new_carrier, context={}) self.assertEqual(context, {"mock-0": "context2"}) def test_fields(self): @@ -126,13 +111,14 @@ def test_fields(self): ] ) - mock_set_in_carrier = Mock() - - propagator.inject(mock_set_in_carrier, {}) - - inject_fields = set() + propagator.inject({}) - for mock_call in mock_set_in_carrier.mock_calls: - inject_fields.add(mock_call[1][1]) - - self.assertEqual(inject_fields, propagator.fields) + self.assertEqual( + { + "inject_field_mock-0_0", + "inject_field_mock-0_1", + "inject_field_mock-1_0", + "inject_field_mock-1_1", + }, + propagator.fields, + ) diff --git a/opentelemetry-api/tests/propagators/test_global_httptextformat.py b/opentelemetry-api/tests/propagators/test_global_httptextformat.py index faa4023d5da..940096d0a40 100644 --- a/opentelemetry-api/tests/propagators/test_global_httptextformat.py +++ b/opentelemetry-api/tests/propagators/test_global_httptextformat.py @@ -16,45 +16,43 @@ from opentelemetry import baggage, trace from opentelemetry.propagate import extract, inject -from opentelemetry.propagators.textmap import DictGetter from opentelemetry.trace import get_current_span, set_span_in_context from opentelemetry.trace.span import format_span_id, format_trace_id -carrier_getter = DictGetter() - class TestDefaultGlobalPropagator(unittest.TestCase): """Test ensures the default global composite propagator works as intended""" - TRACE_ID = int("12345678901234567890123456789012", 16) # type:int - SPAN_ID = int("1234567890123456", 16) # type:int + trace_id = int("12345678901234567890123456789012", 16) # type:int + span_id = int("1234567890123456", 16) # type:int def test_propagation(self): traceparent_value = "00-{trace_id}-{span_id}-00".format( - trace_id=format_trace_id(self.TRACE_ID), - span_id=format_span_id(self.SPAN_ID), + trace_id=format_trace_id(self.trace_id), + span_id=format_span_id(self.span_id), + ) + + ctx = extract( + { + "baggage": ["key1=val1,key2=val2"], + "traceparent": [traceparent_value], + "tracestate": ["foo=1,bar=2,baz=3"], + } + ) + self.assertEqual( + baggage.get_all(context=ctx), {"key1": "val1", "key2": "val2"} ) - tracestate_value = "foo=1,bar=2,baz=3" - headers = { - "baggage": ["key1=val1,key2=val2"], - "traceparent": [traceparent_value], - "tracestate": [tracestate_value], - } - ctx = extract(carrier_getter, headers) - baggage_entries = baggage.get_all(context=ctx) - expected = {"key1": "val1", "key2": "val2"} - self.assertEqual(baggage_entries, expected) span_context = get_current_span(context=ctx).get_span_context() - self.assertEqual(span_context.trace_id, self.TRACE_ID) - self.assertEqual(span_context.span_id, self.SPAN_ID) + self.assertEqual(span_context.trace_id, self.trace_id) + self.assertEqual(span_context.span_id, self.span_id) span = trace.NonRecordingSpan(span_context) ctx = baggage.set_baggage("key3", "val3") ctx = baggage.set_baggage("key4", "val4", context=ctx) ctx = set_span_in_context(span, context=ctx) output = {} - inject(dict.__setitem__, output, context=ctx) + inject(output, context=ctx) self.assertEqual(traceparent_value, output["traceparent"]) self.assertIn("key3=val3", output["baggage"]) self.assertIn("key4=val4", output["baggage"]) diff --git a/opentelemetry-api/tests/trace/propagation/test_textmap.py b/opentelemetry-api/tests/trace/propagation/test_textmap.py deleted file mode 100644 index 12e851de348..00000000000 --- a/opentelemetry-api/tests/trace/propagation/test_textmap.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright The OpenTelemetry Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -from opentelemetry.propagators.textmap import DictGetter - - -class TestDictGetter(unittest.TestCase): - def test_get_none(self): - getter = DictGetter() - carrier = {} - val = getter.get(carrier, "test") - self.assertIsNone(val) - - def test_get_str(self): - getter = DictGetter() - carrier = {"test": "val"} - val = getter.get(carrier, "test") - self.assertEqual(val, ["val"]) - - def test_get_iter(self): - getter = DictGetter() - carrier = {"test": ["val"]} - val = getter.get(carrier, "test") - self.assertEqual(val, ["val"]) - - def test_keys(self): - getter = DictGetter() - keys = getter.keys({"test": "val"}) - self.assertEqual(keys, ["test"]) diff --git a/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py b/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py index cff30b7c9b8..6fb44fa2dfa 100644 --- a/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py +++ b/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py @@ -13,22 +13,26 @@ # limitations under the License. import typing -import unittest +from unittest import TestCase from unittest.mock import Mock, patch -from opentelemetry import trace -from opentelemetry.propagators.textmap import DictGetter +from opentelemetry.trace import ( + INVALID_SPAN, + INVALID_SPAN_CONTEXT, + NonRecordingSpan, + SpanContext, + get_current_span, + set_span_in_context, +) from opentelemetry.trace.propagation import tracecontext from opentelemetry.trace.span import TraceState FORMAT = tracecontext.TraceContextTextMapPropagator() -carrier_getter = DictGetter() - -class TestTraceContextFormat(unittest.TestCase): - TRACE_ID = int("12345678901234567890123456789012", 16) # type:int - SPAN_ID = int("1234567890123456", 16) # type:int +class TestTraceContextFormat(TestCase): + trace_id = int("12345678901234567890123456789012", 16) # type:int + span_id = int("1234567890123456", 16) # type:int def test_no_traceparent_header(self): """When tracecontext headers are not present, a new SpanContext @@ -40,38 +44,37 @@ def test_no_traceparent_header(self): trace-id and parent-id that represents the current request. """ output = {} # type:typing.Dict[str, typing.List[str]] - span = trace.get_current_span(FORMAT.extract(carrier_getter, output)) - self.assertIsInstance(span.get_span_context(), trace.SpanContext) + span = get_current_span(FORMAT.extract(output)) + self.assertIsInstance(span.get_span_context(), SpanContext) def test_headers_with_tracestate(self): """When there is a traceparent and tracestate header, data from both should be addded to the SpanContext. """ traceparent_value = "00-{trace_id}-{span_id}-00".format( - trace_id=format(self.TRACE_ID, "032x"), - span_id=format(self.SPAN_ID, "016x"), + trace_id=format(self.trace_id, "032x"), + span_id=format(self.span_id, "016x"), ) tracestate_value = "foo=1,bar=2,baz=3" - span_context = trace.get_current_span( + span_context = get_current_span( FORMAT.extract( - carrier_getter, { "traceparent": [traceparent_value], "tracestate": [tracestate_value], }, ) ).get_span_context() - self.assertEqual(span_context.trace_id, self.TRACE_ID) - self.assertEqual(span_context.span_id, self.SPAN_ID) + self.assertEqual(span_context.trace_id, self.trace_id) + self.assertEqual(span_context.span_id, self.span_id) self.assertEqual( span_context.trace_state, {"foo": "1", "bar": "2", "baz": "3"} ) self.assertTrue(span_context.is_remote) output = {} # type:typing.Dict[str, str] - span = trace.NonRecordingSpan(span_context) + span = NonRecordingSpan(span_context) - ctx = trace.set_span_in_context(span) - FORMAT.inject(dict.__setitem__, output, ctx) + ctx = set_span_in_context(span) + FORMAT.inject(output, ctx) self.assertEqual(output["traceparent"], traceparent_value) for pair in ["foo=1", "bar=2", "baz=3"]: self.assertIn(pair, output["tracestate"]) @@ -96,18 +99,18 @@ def test_invalid_trace_id(self): Note that the opposite is not true: failure to parse tracestate MUST NOT affect the parsing of traceparent. """ - span = trace.get_current_span( + span = get_current_span( FORMAT.extract( - carrier_getter, { "traceparent": [ - "00-00000000000000000000000000000000-1234567890123456-00" + "00-00000000000000000000000000000000-" + "1234567890123456-00" ], "tracestate": ["foo=1,bar=2,foo=3"], }, ) ) - self.assertEqual(span.get_span_context(), trace.INVALID_SPAN_CONTEXT) + self.assertEqual(span.get_span_context(), INVALID_SPAN_CONTEXT) def test_invalid_parent_id(self): """If the parent id is invalid, we must ignore the full traceparent @@ -127,18 +130,18 @@ def test_invalid_parent_id(self): Note that the opposite is not true: failure to parse tracestate MUST NOT affect the parsing of traceparent. """ - span = trace.get_current_span( + span = get_current_span( FORMAT.extract( - carrier_getter, { "traceparent": [ - "00-00000000000000000000000000000000-0000000000000000-00" + "00-00000000000000000000000000000000-" + "0000000000000000-00" ], "tracestate": ["foo=1,bar=2,foo=3"], }, ) ) - self.assertEqual(span.get_span_context(), trace.INVALID_SPAN_CONTEXT) + self.assertEqual(span.get_span_context(), INVALID_SPAN_CONTEXT) def test_no_send_empty_tracestate(self): """If the tracestate is empty, do not set the header. @@ -149,11 +152,12 @@ def test_no_send_empty_tracestate(self): empty tracestate headers but SHOULD avoid sending them. """ output = {} # type:typing.Dict[str, str] - span = trace.NonRecordingSpan( - trace.SpanContext(self.TRACE_ID, self.SPAN_ID, is_remote=False) + span = NonRecordingSpan( + SpanContext(self.trace_id, self.span_id, is_remote=False) ) - ctx = trace.set_span_in_context(span) - FORMAT.inject(dict.__setitem__, output, ctx) + ctx = set_span_in_context(span) + + FORMAT.inject(output, ctx) self.assertTrue("traceparent" in output) self.assertFalse("tracestate" in output) @@ -165,9 +169,8 @@ def test_format_not_supported(self): If the version cannot be parsed, return an invalid trace header. """ - span = trace.get_current_span( + span = get_current_span( FORMAT.extract( - carrier_getter, { "traceparent": [ "00-12345678901234567890123456789012-" @@ -177,23 +180,23 @@ def test_format_not_supported(self): }, ) ) - self.assertEqual(span.get_span_context(), trace.INVALID_SPAN_CONTEXT) + self.assertEqual(span.get_span_context(), INVALID_SPAN_CONTEXT) def test_propagate_invalid_context(self): """Do not propagate invalid trace context.""" output = {} # type:typing.Dict[str, str] - ctx = trace.set_span_in_context(trace.INVALID_SPAN) - FORMAT.inject(dict.__setitem__, output, context=ctx) + ctx = set_span_in_context(INVALID_SPAN) + FORMAT.inject(output, context=ctx) self.assertFalse("traceparent" in output) def test_tracestate_empty_header(self): """Test tracestate with an additional empty header (should be ignored)""" - span = trace.get_current_span( + span = get_current_span( FORMAT.extract( - carrier_getter, { "traceparent": [ - "00-12345678901234567890123456789012-1234567890123456-00" + "00-12345678901234567890123456789012-" + "1234567890123456-00" ], "tracestate": ["foo=1", ""], }, @@ -203,12 +206,12 @@ def test_tracestate_empty_header(self): def test_tracestate_header_with_trailing_comma(self): """Do not propagate invalid trace context.""" - span = trace.get_current_span( + span = get_current_span( FORMAT.extract( - carrier_getter, { "traceparent": [ - "00-12345678901234567890123456789012-1234567890123456-00" + "00-12345678901234567890123456789012-" + "1234567890123456-00" ], "tracestate": ["foo=1,"], }, @@ -226,9 +229,8 @@ def test_tracestate_keys(self): "foo-_*/bar=bar4", ] ) - span = trace.get_current_span( + span = get_current_span( FORMAT.extract( - carrier_getter, { "traceparent": [ "00-12345678901234567890123456789012-" @@ -249,9 +251,8 @@ def test_tracestate_keys(self): span.get_span_context().trace_state["foo-_*/bar"], "bar4" ) - @patch("opentelemetry.trace.INVALID_SPAN_CONTEXT") - @patch("opentelemetry.trace.get_current_span") - def test_fields(self, mock_get_current_span, mock_invalid_span_context): + @patch("opentelemetry.trace.propagation.tracecontext.get_current_span") + def test_fields(self, mock_get_current_span): mock_get_current_span.configure_mock( return_value=Mock( @@ -268,13 +269,8 @@ def test_fields(self, mock_get_current_span, mock_invalid_span_context): ) ) - mock_set_in_carrier = Mock() - - FORMAT.inject(mock_set_in_carrier, {}) - - inject_fields = set() + carrier = {} - for mock_call in mock_set_in_carrier.mock_calls: - inject_fields.add(mock_call[1][1]) + FORMAT.inject(carrier) - self.assertEqual(inject_fields, FORMAT.fields) + self.assertEqual(carrier.keys(), {"traceparent", "tracestate"}) diff --git a/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py b/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py index 01abcc7c879..7906b049243 100644 --- a/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py +++ b/propagator/opentelemetry-propagator-b3/src/opentelemetry/propagators/b3/__init__.py @@ -12,14 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import typing -from re import compile as re_compile +from re import compile as compile_ +from typing import Optional, Set import opentelemetry.trace as trace from opentelemetry.context import Context from opentelemetry.propagators.textmap import ( - Getter, - Setter, TextMapPropagator, TextMapPropagatorT, ) @@ -39,23 +37,19 @@ class B3Format(TextMapPropagator): SAMPLED_KEY = "x-b3-sampled" FLAGS_KEY = "x-b3-flags" _SAMPLE_PROPAGATE_VALUES = set(["1", "True", "true", "d"]) - _trace_id_regex = re_compile(r"[\da-fA-F]{16}|[\da-fA-F]{32}") - _span_id_regex = re_compile(r"[\da-fA-F]{16}") + _trace_id_regex = compile_(r"[\da-fA-F]{16}|[\da-fA-F]{32}") + _span_id_regex = compile_(r"[\da-fA-F]{16}") def extract( - self, - getter: Getter[TextMapPropagatorT], - carrier: TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> Context: trace_id = format_trace_id(trace.INVALID_TRACE_ID) span_id = format_span_id(trace.INVALID_SPAN_ID) sampled = "0" flags = None - single_header = _extract_first_element( - getter.get(carrier, self.SINGLE_HEADER_KEY) - ) + single_header = next(iter(carrier.get(self.TRACE_ID_KEY, [])), False) + if single_header: # The b3 spec calls for the sampling state to be # "deferred", which is unspecified. This concept does not @@ -75,21 +69,16 @@ def extract( return trace.set_span_in_context(trace.INVALID_SPAN) else: trace_id = ( - _extract_first_element(getter.get(carrier, self.TRACE_ID_KEY)) + next(iter(carrier.get(self.TRACE_ID_KEY, [])), False) or trace_id ) span_id = ( - _extract_first_element(getter.get(carrier, self.SPAN_ID_KEY)) - or span_id + next(iter(carrier.get(self.SPAN_ID_KEY, [])), False) or span_id ) sampled = ( - _extract_first_element(getter.get(carrier, self.SAMPLED_KEY)) - or sampled - ) - flags = ( - _extract_first_element(getter.get(carrier, self.FLAGS_KEY)) - or flags + next(iter(carrier.get(self.SAMPLED_KEY, [])), False) or sampled ) + flags = next(iter(carrier.get(self.FLAGS_KEY, [])), False) or flags if ( self._trace_id_regex.fullmatch(trace_id) is None @@ -126,10 +115,7 @@ def extract( ) def inject( - self, - set_in_carrier: Setter[TextMapPropagatorT], - carrier: TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> None: span = trace.get_current_span(context=context) @@ -138,34 +124,20 @@ def inject( return sampled = (trace.TraceFlags.SAMPLED & span_context.trace_flags) != 0 - set_in_carrier( - carrier, self.TRACE_ID_KEY, format_trace_id(span_context.trace_id), - ) - set_in_carrier( - carrier, self.SPAN_ID_KEY, format_span_id(span_context.span_id) - ) + carrier[self.TRACE_ID_KEY] = format_trace_id(span_context.trace_id) + carrier[self.SPAN_ID_KEY] = format_span_id(span_context.span_id) span_parent = getattr(span, "parent", None) if span_parent is not None: - set_in_carrier( - carrier, - self.PARENT_SPAN_ID_KEY, + carrier[self.PARENT_SPAN_ID_KEY] = ( format_span_id(span_parent.span_id), ) - set_in_carrier(carrier, self.SAMPLED_KEY, "1" if sampled else "0") + carrier[self.SAMPLED_KEY] = "1" if sampled else "0" @property - def fields(self) -> typing.Set[str]: + def fields(self) -> Set[str]: return { self.TRACE_ID_KEY, self.SPAN_ID_KEY, self.PARENT_SPAN_ID_KEY, self.SAMPLED_KEY, } - - -def _extract_first_element( - items: typing.Iterable[TextMapPropagatorT], -) -> typing.Optional[TextMapPropagatorT]: - if items is None: - return None - return next(iter(items), None) diff --git a/propagator/opentelemetry-propagator-b3/tests/performance/benchmarks/trace/propagation/test_benchmark_b3_format.py b/propagator/opentelemetry-propagator-b3/tests/performance/benchmarks/trace/propagation/test_benchmark_b3_format.py index 5048f495f06..3a7a251ad88 100644 --- a/propagator/opentelemetry-propagator-b3/tests/performance/benchmarks/trace/propagation/test_benchmark_b3_format.py +++ b/propagator/opentelemetry-propagator-b3/tests/performance/benchmarks/trace/propagation/test_benchmark_b3_format.py @@ -14,7 +14,6 @@ import opentelemetry.propagators.b3 as b3_format import opentelemetry.sdk.trace as trace -from opentelemetry.propagators.textmap import DictGetter FORMAT = b3_format.B3Format() @@ -22,7 +21,6 @@ def test_extract_single_header(benchmark): benchmark( FORMAT.extract, - DictGetter(), { FORMAT.SINGLE_HEADER_KEY: "bdb5b63237ed38aea578af665aa5aa60-c32d953d73ad2251-1-11fd79a30b0896cd285b396ae102dd76" }, @@ -35,7 +33,6 @@ def test_inject_empty_context(benchmark): with tracer.start_as_current_span("Child Span"): benchmark( FORMAT.inject, - dict.__setitem__, { FORMAT.TRACE_ID_KEY: "bdb5b63237ed38aea578af665aa5aa60", FORMAT.SPAN_ID_KEY: "00000000000000000c32d953d73ad225", diff --git a/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py b/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py index f9d3bce1adb..ac48494aeaa 100644 --- a/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py +++ b/propagator/opentelemetry-propagator-b3/tests/test_b3_format.py @@ -12,31 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest +from unittest import TestCase from unittest.mock import Mock, patch -import opentelemetry.propagators.b3 as b3_format # pylint: disable=no-name-in-module,import-error -import opentelemetry.sdk.trace as trace -import opentelemetry.sdk.trace.id_generator as id_generator -import opentelemetry.trace as trace_api +from opentelemetry import trace from opentelemetry.context import get_current -from opentelemetry.propagators.textmap import DictGetter - -FORMAT = b3_format.B3Format() - - -carrier_getter = DictGetter() +from opentelemetry.propagators.b3 import ( + B3Format, + format_span_id, + format_trace_id, +) +from opentelemetry.sdk.trace import TracerProvider, _Span, id_generator +from opentelemetry.trace import ( + INVALID_SPAN_ID, + INVALID_TRACE_ID, + SpanContext, + get_current_span, + set_span_in_context, +) + +format_ = B3Format() def get_child_parent_new_carrier(old_carrier): - ctx = FORMAT.extract(carrier_getter, old_carrier) - parent_span_context = trace_api.get_current_span(ctx).get_span_context() + ctx = format_.extract(old_carrier) + parent_span_context = get_current_span(ctx).get_span_context() - parent = trace._Span("parent", parent_span_context) - child = trace._Span( + parent = _Span("parent", parent_span_context) + child = _Span( "child", - trace_api.SpanContext( + SpanContext( parent_span_context.trace_id, id_generator.RandomIdGenerator().generate_span_id(), is_remote=False, @@ -47,30 +53,26 @@ def get_child_parent_new_carrier(old_carrier): ) new_carrier = {} - ctx = trace_api.set_span_in_context(child) - FORMAT.inject(dict.__setitem__, new_carrier, context=ctx) + ctx = set_span_in_context(child) + format_.inject(new_carrier, context=ctx) return child, parent, new_carrier -class TestB3Format(unittest.TestCase): +class TestB3Format(TestCase): @classmethod def setUpClass(cls): generator = id_generator.RandomIdGenerator() - cls.serialized_trace_id = b3_format.format_trace_id( + cls.serialized_trace_id = format_trace_id( generator.generate_trace_id() ) - cls.serialized_span_id = b3_format.format_span_id( - generator.generate_span_id() - ) - cls.serialized_parent_id = b3_format.format_span_id( - generator.generate_span_id() - ) + cls.serialized_span_id = format_span_id(generator.generate_span_id()) + cls.serialized_parent_id = format_span_id(generator.generate_span_id()) def setUp(self) -> None: - tracer_provider = trace.TracerProvider() - patcher = unittest.mock.patch.object( - trace_api, "get_tracer_provider", return_value=tracer_provider + tracer_provider = TracerProvider() + patcher = patch.object( + trace, "get_tracer_provider", return_value=tracer_provider ) patcher.start() self.addCleanup(patcher.stop) @@ -79,52 +81,52 @@ def test_extract_multi_header(self): """Test the extraction of B3 headers.""" child, parent, new_carrier = get_child_parent_new_carrier( { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.PARENT_SPAN_ID_KEY: self.serialized_parent_id, - FORMAT.SAMPLED_KEY: "1", + format_.TRACE_ID_KEY: self.serialized_trace_id, + format_.SPAN_ID_KEY: self.serialized_span_id, + format_.PARENT_SPAN_ID_KEY: self.serialized_parent_id, + format_.SAMPLED_KEY: "1", } ) self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], - b3_format.format_trace_id(child.context.trace_id), + new_carrier[format_.TRACE_ID_KEY], + format_trace_id(child.context.trace_id), ) self.assertEqual( - new_carrier[FORMAT.SPAN_ID_KEY], - b3_format.format_span_id(child.context.span_id), + new_carrier[format_.SPAN_ID_KEY], + format_span_id(child.context.span_id), ) self.assertEqual( - new_carrier[FORMAT.PARENT_SPAN_ID_KEY], - b3_format.format_span_id(parent.context.span_id), + new_carrier[format_.PARENT_SPAN_ID_KEY], + format_span_id(parent.context.span_id), ) self.assertTrue(parent.context.is_remote) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + self.assertEqual(new_carrier[format_.SAMPLED_KEY], "1") def test_extract_single_header(self): """Test the extraction from a single b3 header.""" child, parent, new_carrier = get_child_parent_new_carrier( { - FORMAT.SINGLE_HEADER_KEY: "{}-{}".format( + format_.SINGLE_HEADER_KEY: "{}-{}".format( self.serialized_trace_id, self.serialized_span_id ) } ) self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], - b3_format.format_trace_id(child.context.trace_id), + new_carrier[format_.TRACE_ID_KEY], + format_trace_id(child.context.trace_id), ) self.assertEqual( - new_carrier[FORMAT.SPAN_ID_KEY], - b3_format.format_span_id(child.context.span_id), + new_carrier[format_.SPAN_ID_KEY], + format_span_id(child.context.span_id), ) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + self.assertEqual(new_carrier[format_.SAMPLED_KEY], "1") self.assertTrue(parent.context.is_remote) child, parent, new_carrier = get_child_parent_new_carrier( { - FORMAT.SINGLE_HEADER_KEY: "{}-{}-1-{}".format( + format_.SINGLE_HEADER_KEY: "{}-{}-1-{}".format( self.serialized_trace_id, self.serialized_span_id, self.serialized_parent_id, @@ -133,19 +135,19 @@ def test_extract_single_header(self): ) self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], - b3_format.format_trace_id(child.context.trace_id), + new_carrier[format_.TRACE_ID_KEY], + format_trace_id(child.context.trace_id), ) self.assertEqual( - new_carrier[FORMAT.SPAN_ID_KEY], - b3_format.format_span_id(child.context.span_id), + new_carrier[format_.SPAN_ID_KEY], + format_span_id(child.context.span_id), ) self.assertEqual( - new_carrier[FORMAT.PARENT_SPAN_ID_KEY], - b3_format.format_span_id(parent.context.span_id), + new_carrier[format_.PARENT_SPAN_ID_KEY], + format_span_id(parent.context.span_id), ) self.assertTrue(parent.context.is_remote) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + self.assertEqual(new_carrier[format_.SAMPLED_KEY], "1") def test_extract_header_precedence(self): """A single b3 header should take precedence over multiple @@ -155,17 +157,17 @@ def test_extract_header_precedence(self): _, _, new_carrier = get_child_parent_new_carrier( { - FORMAT.SINGLE_HEADER_KEY: "{}-{}".format( + format_.SINGLE_HEADER_KEY: "{}-{}".format( single_header_trace_id, self.serialized_span_id ), - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.SAMPLED_KEY: "1", + format_.TRACE_ID_KEY: self.serialized_trace_id, + format_.SPAN_ID_KEY: self.serialized_span_id, + format_.SAMPLED_KEY: "1", } ) self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], single_header_trace_id + new_carrier[format_.TRACE_ID_KEY], single_header_trace_id ) def test_enabled_sampling(self): @@ -173,50 +175,50 @@ def test_enabled_sampling(self): for variant in ["1", "True", "true", "d"]: _, _, new_carrier = get_child_parent_new_carrier( { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.SAMPLED_KEY: variant, + format_.TRACE_ID_KEY: self.serialized_trace_id, + format_.SPAN_ID_KEY: self.serialized_span_id, + format_.SAMPLED_KEY: variant, } ) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + self.assertEqual(new_carrier[format_.SAMPLED_KEY], "1") def test_disabled_sampling(self): """Test b3 sample key variants that turn off sampling.""" for variant in ["0", "False", "false", None]: _, _, new_carrier = get_child_parent_new_carrier( { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.SAMPLED_KEY: variant, + format_.TRACE_ID_KEY: self.serialized_trace_id, + format_.SPAN_ID_KEY: self.serialized_span_id, + format_.SAMPLED_KEY: variant, } ) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "0") + self.assertEqual(new_carrier[format_.SAMPLED_KEY], "0") def test_flags(self): """x-b3-flags set to "1" should result in propagation.""" _, _, new_carrier = get_child_parent_new_carrier( { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.FLAGS_KEY: "1", + format_.TRACE_ID_KEY: self.serialized_trace_id, + format_.SPAN_ID_KEY: self.serialized_span_id, + format_.FLAGS_KEY: "1", } ) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + self.assertEqual(new_carrier[format_.SAMPLED_KEY], "1") def test_flags_and_sampling(self): """Propagate if b3 flags and sampling are set.""" _, _, new_carrier = get_child_parent_new_carrier( { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.FLAGS_KEY: "1", + format_.TRACE_ID_KEY: self.serialized_trace_id, + format_.SPAN_ID_KEY: self.serialized_span_id, + format_.FLAGS_KEY: "1", } ) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + self.assertEqual(new_carrier[format_.SAMPLED_KEY], "1") def test_64bit_trace_id(self): """64 bit trace ids should be padded to 128 bit trace ids.""" @@ -224,36 +226,36 @@ def test_64bit_trace_id(self): _, _, new_carrier = get_child_parent_new_carrier( { - FORMAT.TRACE_ID_KEY: trace_id_64_bit, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.FLAGS_KEY: "1", + format_.TRACE_ID_KEY: trace_id_64_bit, + format_.SPAN_ID_KEY: self.serialized_span_id, + format_.FLAGS_KEY: "1", } ) self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], "0" * 16 + trace_id_64_bit + new_carrier[format_.TRACE_ID_KEY], "0" * 16 + trace_id_64_bit ) def test_invalid_single_header(self): """If an invalid single header is passed, return an invalid SpanContext. """ - carrier = {FORMAT.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"} - ctx = FORMAT.extract(carrier_getter, carrier) - span_context = trace_api.get_current_span(ctx).get_span_context() - self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID) - self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID) + carrier = {format_.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"} + ctx = format_.extract(carrier) + span_context = get_current_span(ctx).get_span_context() + self.assertEqual(span_context.trace_id, INVALID_TRACE_ID) + self.assertEqual(span_context.span_id, INVALID_SPAN_ID) def test_missing_trace_id(self): """If a trace id is missing, populate an invalid trace id.""" carrier = { - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.FLAGS_KEY: "1", + format_.SPAN_ID_KEY: self.serialized_span_id, + format_.FLAGS_KEY: "1", } - ctx = FORMAT.extract(carrier_getter, carrier) - span_context = trace_api.get_current_span(ctx).get_span_context() - self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID) + ctx = format_.extract(carrier) + span_context = get_current_span(ctx).get_span_context() + self.assertEqual(span_context.trace_id, INVALID_TRACE_ID) @patch( "opentelemetry.sdk.trace.id_generator.RandomIdGenerator.generate_trace_id" @@ -270,13 +272,13 @@ def test_invalid_trace_id( mock_generate_span_id.configure_mock(return_value=2) carrier = { - FORMAT.TRACE_ID_KEY: "abc123", - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.FLAGS_KEY: "1", + format_.TRACE_ID_KEY: "abc123", + format_.SPAN_ID_KEY: self.serialized_span_id, + format_.FLAGS_KEY: "1", } - ctx = FORMAT.extract(carrier_getter, carrier) - span_context = trace_api.get_current_span(ctx).get_span_context() + ctx = format_.extract(carrier) + span_context = get_current_span(ctx).get_span_context() self.assertEqual(span_context.trace_id, 1) self.assertEqual(span_context.span_id, 2) @@ -296,13 +298,13 @@ def test_invalid_span_id( mock_generate_span_id.configure_mock(return_value=2) carrier = { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: "abc123", - FORMAT.FLAGS_KEY: "1", + format_.TRACE_ID_KEY: self.serialized_trace_id, + format_.SPAN_ID_KEY: "abc123", + format_.FLAGS_KEY: "1", } - ctx = FORMAT.extract(carrier_getter, carrier) - span_context = trace_api.get_current_span(ctx).get_span_context() + ctx = format_.extract(carrier) + span_context = get_current_span(ctx).get_span_context() self.assertEqual(span_context.trace_id, 1) self.assertEqual(span_context.span_id, 2) @@ -310,49 +312,43 @@ def test_invalid_span_id( def test_missing_span_id(self): """If a trace id is missing, populate an invalid trace id.""" carrier = { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.FLAGS_KEY: "1", + format_.TRACE_ID_KEY: self.serialized_trace_id, + format_.FLAGS_KEY: "1", } - ctx = FORMAT.extract(carrier_getter, carrier) - span_context = trace_api.get_current_span(ctx).get_span_context() - self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID) + ctx = format_.extract(carrier) + span_context = get_current_span(ctx).get_span_context() + self.assertEqual(span_context.span_id, INVALID_SPAN_ID) - @staticmethod - def test_inject_empty_context(): + def test_inject_empty_context(self): """If the current context has no span, don't add headers""" new_carrier = {} - FORMAT.inject(dict.__setitem__, new_carrier, get_current()) - assert len(new_carrier) == 0 - - @staticmethod - def test_default_span(): - """Make sure propagator does not crash when working with NonRecordingSpan""" - - class CarrierGetter(DictGetter): - def get(self, carrier, key): - return carrier.get(key, None) + format_.inject(new_carrier, get_current()) + self.assertEqual(len(new_carrier), 0) - def setter(carrier, key, value): - carrier[key] = value + def test_default_span(self): + """Make sure propagator does not crash when working with + NonRecordingSpan""" - ctx = FORMAT.extract(CarrierGetter(), {}) - FORMAT.inject(setter, {}, ctx) + try: + format_.inject({}, format_.extract({})) + except Exception: # pylint: disable=broad-except + self.fail("propagator crashed when working with NonRecordingSpan") def test_fields(self): """Make sure the fields attribute returns the fields used in inject""" - tracer = trace.TracerProvider().get_tracer("sdk_tracer_provider") + tracer = TracerProvider().get_tracer("sdk_tracer_provider") mock_set_in_carrier = Mock() with tracer.start_as_current_span("parent"): with tracer.start_as_current_span("child"): - FORMAT.inject(mock_set_in_carrier, {}) + format_.inject(mock_set_in_carrier, {}) inject_fields = set() for call in mock_set_in_carrier.mock_calls: inject_fields.add(call[1][1]) - self.assertEqual(FORMAT.fields, inject_fields) + self.assertEqual(format_.fields, inject_fields) diff --git a/propagator/opentelemetry-propagator-jaeger/src/opentelemetry/propagators/jaeger/__init__.py b/propagator/opentelemetry-propagator-jaeger/src/opentelemetry/propagators/jaeger/__init__.py index 8e7fe5f69ff..18bef73c5bb 100644 --- a/propagator/opentelemetry-propagator-jaeger/src/opentelemetry/propagators/jaeger/__init__.py +++ b/propagator/opentelemetry-propagator-jaeger/src/opentelemetry/propagators/jaeger/__init__.py @@ -12,19 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -import typing -import urllib.parse +from typing import Optional, Set +from urllib.parse import quote, unquote -import opentelemetry.trace as trace from opentelemetry import baggage from opentelemetry.context import Context, get_current from opentelemetry.propagators.textmap import ( - Getter, - Setter, TextMapPropagator, TextMapPropagatorT, ) -from opentelemetry.trace import format_span_id, format_trace_id +from opentelemetry.trace import ( + INVALID_SPAN, + INVALID_SPAN_CONTEXT, + INVALID_SPAN_ID, + INVALID_TRACE_ID, + NonRecordingSpan, + SpanContext, + TraceFlags, + format_span_id, + format_trace_id, + get_current_span, + set_span_in_context, +) class JaegerPropagator(TextMapPropagator): @@ -33,73 +42,68 @@ class JaegerPropagator(TextMapPropagator): See: https://www.jaegertracing.io/docs/1.19/client-libraries/#propagation-format """ - TRACE_ID_KEY = "uber-trace-id" - BAGGAGE_PREFIX = "uberctx-" - DEBUG_FLAG = 0x02 + _trace_id_key = "uber-trace-id" + _baggage_prefix = "uberctx-" + _debug_flag = 0x02 def extract( - self, - getter: Getter[TextMapPropagatorT], - carrier: TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> Context: if context is None: context = get_current() - header = getter.get(carrier, self.TRACE_ID_KEY) - if not header: - return trace.set_span_in_context(trace.INVALID_SPAN, context) - fields = _extract_first_element(header).split(":") + header = carrier.get(self._trace_id_key) + if header is None: + return set_span_in_context(INVALID_SPAN, context) + fields = next(iter(header)).split(":") + + for key in [ + key + for key in carrier.keys() + if key.startswith(self._baggage_prefix) + ]: + context = baggage.set_baggage( + key.replace(self._baggage_prefix, ""), + unquote(next(iter(carrier[key]))).strip(), + context=context, + ) - context = self._extract_baggage(getter, carrier, context) if len(fields) != 4: - return trace.set_span_in_context(trace.INVALID_SPAN, context) + return set_span_in_context(INVALID_SPAN, context) trace_id, span_id, _parent_id, flags = fields - if ( - trace_id == trace.INVALID_TRACE_ID - or span_id == trace.INVALID_SPAN_ID - ): - return trace.set_span_in_context(trace.INVALID_SPAN, context) - - span = trace.NonRecordingSpan( - trace.SpanContext( + if trace_id == INVALID_TRACE_ID or span_id == INVALID_SPAN_ID: + return set_span_in_context(INVALID_SPAN, context) + + span = NonRecordingSpan( + SpanContext( trace_id=int(trace_id, 16), span_id=int(span_id, 16), is_remote=True, - trace_flags=trace.TraceFlags( - int(flags, 16) & trace.TraceFlags.SAMPLED - ), + trace_flags=TraceFlags(int(flags, 16) & TraceFlags.SAMPLED), ) ) - return trace.set_span_in_context(span, context) + return set_span_in_context(span, context) def inject( - self, - set_in_carrier: Setter[TextMapPropagatorT], - carrier: TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> None: - span = trace.get_current_span(context=context) + span = get_current_span(context=context) span_context = span.get_span_context() - if span_context == trace.INVALID_SPAN_CONTEXT: + if span_context == INVALID_SPAN_CONTEXT: return span_parent_id = span.parent.span_id if span.parent else 0 trace_flags = span_context.trace_flags if trace_flags.sampled: - trace_flags |= self.DEBUG_FLAG + trace_flags |= self._debug_flag # set span identity - set_in_carrier( - carrier, - self.TRACE_ID_KEY, - _format_uber_trace_id( - span_context.trace_id, - span_context.span_id, - span_parent_id, - trace_flags, - ), + carrier[self._trace_id_key] = _format_uber_trace_id( + span_context.trace_id, + span_context.span_id, + span_parent_id, + trace_flags, ) # set span baggage, if any @@ -107,43 +111,18 @@ def inject( if not baggage_entries: return for key, value in baggage_entries.items(): - baggage_key = self.BAGGAGE_PREFIX + key - set_in_carrier( - carrier, baggage_key, urllib.parse.quote(str(value)) - ) + baggage_key = self._baggage_prefix + key + carrier[baggage_key] = quote(str(value)) @property - def fields(self) -> typing.Set[str]: - return {self.TRACE_ID_KEY} - - def _extract_baggage(self, getter, carrier, context): - baggage_keys = [ - key - for key in getter.keys(carrier) - if key.startswith(self.BAGGAGE_PREFIX) - ] - for key in baggage_keys: - value = _extract_first_element(getter.get(carrier, key)) - context = baggage.set_baggage( - key.replace(self.BAGGAGE_PREFIX, ""), - urllib.parse.unquote(value).strip(), - context=context, - ) - return context + def fields(self) -> Set[str]: + return {self._trace_id_key} def _format_uber_trace_id(trace_id, span_id, parent_span_id, flags): - return "{trace_id}:{span_id}:{parent_id}:{:02x}".format( - flags, + return "{trace_id}:{span_id}:{parent_id}:{flags:02x}".format( trace_id=format_trace_id(trace_id), span_id=format_span_id(span_id), parent_id=format_span_id(parent_span_id), + flags=flags, ) - - -def _extract_first_element( - items: typing.Iterable[TextMapPropagatorT], -) -> typing.Optional[TextMapPropagatorT]: - if items is None: - return None - return next(iter(items), None) diff --git a/propagator/opentelemetry-propagator-jaeger/tests/test_jaeger_propagator.py b/propagator/opentelemetry-propagator-jaeger/tests/test_jaeger_propagator.py index da8a855edbe..1fdb6c74062 100644 --- a/propagator/opentelemetry-propagator-jaeger/tests/test_jaeger_propagator.py +++ b/propagator/opentelemetry-propagator-jaeger/tests/test_jaeger_propagator.py @@ -12,27 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest +from unittest import TestCase from unittest.mock import Mock import opentelemetry.sdk.trace as trace import opentelemetry.sdk.trace.id_generator as id_generator import opentelemetry.trace as trace_api from opentelemetry import baggage -from opentelemetry.propagators import ( # pylint: disable=no-name-in-module - jaeger, +from opentelemetry.propagators.jaeger import ( # pylint: disable=no-name-in-module + JaegerPropagator, + _format_uber_trace_id, ) -from opentelemetry.propagators.textmap import DictGetter -FORMAT = jaeger.JaegerPropagator() - - -carrier_getter = DictGetter() +format_ = JaegerPropagator() def get_context_new_carrier(old_carrier, carrier_baggage=None): - ctx = FORMAT.extract(carrier_getter, old_carrier) + ctx = format_.extract(old_carrier) if carrier_baggage: for key, value in carrier_baggage.items(): ctx = baggage.set_baggage(key, value, ctx) @@ -54,93 +51,94 @@ def get_context_new_carrier(old_carrier, carrier_baggage=None): new_carrier = {} ctx = trace_api.set_span_in_context(child, ctx) - FORMAT.inject(dict.__setitem__, new_carrier, context=ctx) + format_.inject(new_carrier, context=ctx) return ctx, new_carrier -class TestJaegerPropagator(unittest.TestCase): +class TestJaegerPropagator(TestCase): + # pylint: disable=protected-access @classmethod def setUpClass(cls): generator = id_generator.RandomIdGenerator() cls.trace_id = generator.generate_trace_id() cls.span_id = generator.generate_span_id() cls.parent_span_id = generator.generate_span_id() - cls.serialized_uber_trace_id = jaeger._format_uber_trace_id( # pylint: disable=protected-access + cls.serialized_uber_trace_id = _format_uber_trace_id( cls.trace_id, cls.span_id, cls.parent_span_id, 11 ) def test_extract_valid_span(self): - old_carrier = {FORMAT.TRACE_ID_KEY: self.serialized_uber_trace_id} - ctx = FORMAT.extract(carrier_getter, old_carrier) + old_carrier = {format_._trace_id_key: self.serialized_uber_trace_id} + ctx = format_.extract(old_carrier) span_context = trace_api.get_current_span(ctx).get_span_context() self.assertEqual(span_context.trace_id, self.trace_id) self.assertEqual(span_context.span_id, self.span_id) def test_missing_carrier(self): old_carrier = {} - ctx = FORMAT.extract(carrier_getter, old_carrier) + ctx = format_.extract(old_carrier) span_context = trace_api.get_current_span(ctx).get_span_context() self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID) self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID) def test_trace_id(self): - old_carrier = {FORMAT.TRACE_ID_KEY: self.serialized_uber_trace_id} + old_carrier = {format_._trace_id_key: self.serialized_uber_trace_id} _, new_carrier = get_context_new_carrier(old_carrier) self.assertEqual( self.serialized_uber_trace_id.split(":")[0], - new_carrier[FORMAT.TRACE_ID_KEY].split(":")[0], + new_carrier[format_._trace_id_key].split(":")[0], ) def test_parent_span_id(self): - old_carrier = {FORMAT.TRACE_ID_KEY: self.serialized_uber_trace_id} + old_carrier = {format_._trace_id_key: self.serialized_uber_trace_id} _, new_carrier = get_context_new_carrier(old_carrier) span_id = self.serialized_uber_trace_id.split(":")[1] - parent_span_id = new_carrier[FORMAT.TRACE_ID_KEY].split(":")[2] + parent_span_id = new_carrier[format_._trace_id_key].split(":")[2] self.assertEqual(span_id, parent_span_id) def test_sampled_flag_set(self): - old_carrier = {FORMAT.TRACE_ID_KEY: self.serialized_uber_trace_id} + old_carrier = {format_._trace_id_key: self.serialized_uber_trace_id} _, new_carrier = get_context_new_carrier(old_carrier) sample_flag_value = ( - int(new_carrier[FORMAT.TRACE_ID_KEY].split(":")[3]) & 0x01 + int(new_carrier[format_._trace_id_key].split(":")[3]) & 0x01 ) self.assertEqual(1, sample_flag_value) def test_debug_flag_set(self): - old_carrier = {FORMAT.TRACE_ID_KEY: self.serialized_uber_trace_id} + old_carrier = {format_._trace_id_key: self.serialized_uber_trace_id} _, new_carrier = get_context_new_carrier(old_carrier) debug_flag_value = ( - int(new_carrier[FORMAT.TRACE_ID_KEY].split(":")[3]) - & FORMAT.DEBUG_FLAG + int(new_carrier[format_._trace_id_key].split(":")[3]) + & format_._debug_flag ) - self.assertEqual(FORMAT.DEBUG_FLAG, debug_flag_value) + self.assertEqual(format_._debug_flag, debug_flag_value) def test_sample_debug_flags_unset(self): - uber_trace_id = jaeger._format_uber_trace_id( # pylint: disable=protected-access + uber_trace_id = _format_uber_trace_id( self.trace_id, self.span_id, self.parent_span_id, 0 ) - old_carrier = {FORMAT.TRACE_ID_KEY: uber_trace_id} + old_carrier = {format_._trace_id_key: uber_trace_id} _, new_carrier = get_context_new_carrier(old_carrier) - flags = int(new_carrier[FORMAT.TRACE_ID_KEY].split(":")[3]) + flags = int(new_carrier[format_._trace_id_key].split(":")[3]) sample_flag_value = flags & 0x01 - debug_flag_value = flags & FORMAT.DEBUG_FLAG + debug_flag_value = flags & format_._debug_flag self.assertEqual(0, sample_flag_value) self.assertEqual(0, debug_flag_value) def test_baggage(self): - old_carrier = {FORMAT.TRACE_ID_KEY: self.serialized_uber_trace_id} + old_carrier = {format_._trace_id_key: self.serialized_uber_trace_id} input_baggage = {"key1": "value1"} _, new_carrier = get_context_new_carrier(old_carrier, input_baggage) - ctx = FORMAT.extract(carrier_getter, new_carrier) + ctx = format_.extract(new_carrier) self.assertDictEqual(input_baggage, ctx["baggage"]) def test_non_string_baggage(self): - old_carrier = {FORMAT.TRACE_ID_KEY: self.serialized_uber_trace_id} + old_carrier = {format_._trace_id_key: self.serialized_uber_trace_id} input_baggage = {"key1": 1, "key2": True} formatted_baggage = {"key1": "1", "key2": "True"} _, new_carrier = get_context_new_carrier(old_carrier, input_baggage) - ctx = FORMAT.extract(carrier_getter, new_carrier) + ctx = format_.extract(new_carrier) self.assertDictEqual(formatted_baggage, ctx["baggage"]) def test_extract_invalid_uber_trace_id(self): @@ -149,7 +147,7 @@ def test_extract_invalid_uber_trace_id(self): "uberctx-key1": "value1", } formatted_baggage = {"key1": "value1"} - context = FORMAT.extract(carrier_getter, old_carrier) + context = format_.extract(old_carrier) span_context = trace_api.get_current_span(context).get_span_context() self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID) self.assertDictEqual(formatted_baggage, context["baggage"]) @@ -160,7 +158,7 @@ def test_extract_invalid_trace_id(self): "uberctx-key1": "value1", } formatted_baggage = {"key1": "value1"} - context = FORMAT.extract(carrier_getter, old_carrier) + context = format_.extract(old_carrier) span_context = trace_api.get_current_span(context).get_span_context() self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID) self.assertDictEqual(formatted_baggage, context["baggage"]) @@ -171,7 +169,7 @@ def test_extract_invalid_span_id(self): "uberctx-key1": "value1", } formatted_baggage = {"key1": "value1"} - context = FORMAT.extract(carrier_getter, old_carrier) + context = format_.extract(old_carrier) span_context = trace_api.get_current_span(context).get_span_context() self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID) self.assertDictEqual(formatted_baggage, context["baggage"]) @@ -181,8 +179,8 @@ def test_fields(self): mock_set_in_carrier = Mock() with tracer.start_as_current_span("parent"): with tracer.start_as_current_span("child"): - FORMAT.inject(mock_set_in_carrier, {}) + format_.inject(mock_set_in_carrier, {}) inject_fields = set() for call in mock_set_in_carrier.mock_calls: inject_fields.add(call[1][1]) - self.assertEqual(FORMAT.fields, inject_fields) + self.assertEqual(format_.fields, inject_fields) diff --git a/shim/opentelemetry-opentracing-shim/src/opentelemetry/shim/opentracing_shim/__init__.py b/shim/opentelemetry-opentracing-shim/src/opentelemetry/shim/opentracing_shim/__init__.py index b7a365302f9..2327abbfae1 100644 --- a/shim/opentelemetry-opentracing-shim/src/opentelemetry/shim/opentracing_shim/__init__.py +++ b/shim/opentelemetry-opentracing-shim/src/opentelemetry/shim/opentracing_shim/__init__.py @@ -102,7 +102,6 @@ from opentelemetry.baggage import get_baggage, set_baggage from opentelemetry.context import Context, attach, detach, get_value, set_value from opentelemetry.propagate import get_global_textmap -from opentelemetry.propagators.textmap import DictGetter from opentelemetry.shim.opentracing_shim import util from opentelemetry.shim.opentracing_shim.version import __version__ from opentelemetry.trace import INVALID_SPAN_CONTEXT, Link, NonRecordingSpan @@ -527,7 +526,6 @@ def __init__(self, tracer: OtelTracer): Format.TEXT_MAP, Format.HTTP_HEADERS, ) - self._carrier_getter = DictGetter() def unwrap(self): """Returns the :class:`opentelemetry.trace.Tracer` object that is @@ -684,7 +682,7 @@ def inject(self, span_context, format: object, carrier: object): propagator = get_global_textmap() ctx = set_span_in_context(NonRecordingSpan(span_context.unwrap())) - propagator.inject(type(carrier).__setitem__, carrier, context=ctx) + propagator.inject(carrier, context=ctx) def extract(self, format: object, carrier: object): """Returns an ``opentracing.SpanContext`` instance extracted from a @@ -712,7 +710,7 @@ def extract(self, format: object, carrier: object): raise UnsupportedFormatException propagator = get_global_textmap() - ctx = propagator.extract(self._carrier_getter, carrier) + ctx = propagator.extract(carrier) span = get_current_span(ctx) if span is not None: otel_context = span.get_span_context() diff --git a/shim/opentelemetry-opentracing-shim/tests/test_shim.py b/shim/opentelemetry-opentracing-shim/tests/test_shim.py index a27d30de718..ef73843073f 100644 --- a/shim/opentelemetry-opentracing-shim/tests/test_shim.py +++ b/shim/opentelemetry-opentracing-shim/tests/test_shim.py @@ -493,9 +493,9 @@ def test_inject_http_headers(self): headers = {} self.shim.inject(context, opentracing.Format.HTTP_HEADERS, headers) self.assertEqual( - headers[MockTextMapPropagator.TRACE_ID_KEY], str(1220) + headers[MockTextMapPropagator.trace_id_key], str(1220) ) - self.assertEqual(headers[MockTextMapPropagator.SPAN_ID_KEY], str(7478)) + self.assertEqual(headers[MockTextMapPropagator.span_id_key], str(7478)) def test_inject_text_map(self): """Test `inject()` method for Format.TEXT_MAP.""" @@ -509,10 +509,10 @@ def test_inject_text_map(self): text_map = {} self.shim.inject(context, opentracing.Format.TEXT_MAP, text_map) self.assertEqual( - text_map[MockTextMapPropagator.TRACE_ID_KEY], str(1220) + text_map[MockTextMapPropagator.trace_id_key], str(1220) ) self.assertEqual( - text_map[MockTextMapPropagator.SPAN_ID_KEY], str(7478) + text_map[MockTextMapPropagator.span_id_key], str(7478) ) def test_inject_binary(self): @@ -531,8 +531,8 @@ def test_extract_http_headers(self): """Test `extract()` method for Format.HTTP_HEADERS.""" carrier = { - MockTextMapPropagator.TRACE_ID_KEY: 1220, - MockTextMapPropagator.SPAN_ID_KEY: 7478, + MockTextMapPropagator.trace_id_key: 1220, + MockTextMapPropagator.span_id_key: 7478, } ctx = self.shim.extract(opentracing.Format.HTTP_HEADERS, carrier) @@ -557,8 +557,8 @@ def test_extract_text_map(self): """Test `extract()` method for Format.TEXT_MAP.""" carrier = { - MockTextMapPropagator.TRACE_ID_KEY: 1220, - MockTextMapPropagator.SPAN_ID_KEY: 7478, + MockTextMapPropagator.trace_id_key: 1220, + MockTextMapPropagator.span_id_key: 7478, } ctx = self.shim.extract(opentracing.Format.TEXT_MAP, carrier) diff --git a/tests/util/src/opentelemetry/test/mock_textmap.py b/tests/util/src/opentelemetry/test/mock_textmap.py index 1edd079042f..1961b9089cb 100644 --- a/tests/util/src/opentelemetry/test/mock_textmap.py +++ b/tests/util/src/opentelemetry/test/mock_textmap.py @@ -12,16 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import typing +from typing import Optional -from opentelemetry import trace from opentelemetry.context import Context, get_current from opentelemetry.propagators.textmap import ( - Getter, - Setter, TextMapPropagator, TextMapPropagatorT, ) +from opentelemetry.trace import ( + INVALID_SPAN, + NonRecordingSpan, + SpanContext, + get_current_span, + set_span_in_context, +) class NOOPTextMapPropagator(TextMapPropagator): @@ -32,18 +36,12 @@ class NOOPTextMapPropagator(TextMapPropagator): """ def extract( - self, - getter: Getter[TextMapPropagatorT], - carrier: TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> Context: return get_current() def inject( - self, - set_in_carrier: Setter[TextMapPropagatorT], - carrier: TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> None: return None @@ -55,45 +53,33 @@ def fields(self): class MockTextMapPropagator(TextMapPropagator): """Mock propagator for testing purposes.""" - TRACE_ID_KEY = "mock-traceid" - SPAN_ID_KEY = "mock-spanid" + trace_id_key = "mock-traceid" + span_id_key = "mock-spanid" def extract( - self, - getter: Getter[TextMapPropagatorT], - carrier: TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> Context: - trace_id_list = getter.get(carrier, self.TRACE_ID_KEY) - span_id_list = getter.get(carrier, self.SPAN_ID_KEY) - - if not trace_id_list or not span_id_list: - return trace.set_span_in_context(trace.INVALID_SPAN) - - return trace.set_span_in_context( - trace.NonRecordingSpan( - trace.SpanContext( - trace_id=int(trace_id_list[0]), - span_id=int(span_id_list[0]), - is_remote=True, + trace_id = carrier.get(self.trace_id_key) + span_id = carrier.get(self.span_id_key) + + if trace_id is None or span_id is None: + return set_span_in_context(INVALID_SPAN) + + return set_span_in_context( + NonRecordingSpan( + SpanContext( + trace_id=trace_id, span_id=span_id, is_remote=True, ) ) ) def inject( - self, - set_in_carrier: Setter[TextMapPropagatorT], - carrier: TextMapPropagatorT, - context: typing.Optional[Context] = None, + self, carrier: TextMapPropagatorT, context: Optional[Context] = None, ) -> None: - span = trace.get_current_span(context) - set_in_carrier( - carrier, self.TRACE_ID_KEY, str(span.get_span_context().trace_id) - ) - set_in_carrier( - carrier, self.SPAN_ID_KEY, str(span.get_span_context().span_id) - ) + span = get_current_span(context) + carrier[self.trace_id_key] = str(span.get_span_context().trace_id) + carrier[self.span_id_key] = str(span.get_span_context().span_id) @property def fields(self): - return {self.TRACE_ID_KEY, self.SPAN_ID_KEY} + return {self.trace_id_key, self.span_id_key}