Skip to content

Commit

Permalink
Fix race in set_tracer_provider() (#2182)
Browse files Browse the repository at this point in the history
* Fix race in set_tracer_provider

* refactor _reset_globals to a test util

* get rid of "Mixin" name and simplify code a bit

* add some comments to concurrency_test.py

* actually respect log option

Co-authored-by: Diego Hurtado <ocelotl@users.noreply.github.com>
Co-authored-by: Leighton Chen <lechen@microsoft.com>
Co-authored-by: Owais Lone <owais@users.noreply.github.com>
  • Loading branch information
4 people authored Oct 12, 2021
1 parent 0770fcd commit 7867202
Show file tree
Hide file tree
Showing 12 changed files with 313 additions and 61 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased](https://github.com/open-telemetry/opentelemetry-python/compare/v1.5.0-0.24b0...HEAD)
- Fix race in `set_tracer_provider()`
([#2182](https://github.com/open-telemetry/opentelemetry-python/pull/2182))
- Automatically load OTEL environment variables as options for `opentelemetry-instrument`
([#1969](https://github.com/open-telemetry/opentelemetry-python/pull/1969))
- `opentelemetry-semantic-conventions` Update to semantic conventions v1.6.1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import unittest
from unittest import mock
from unittest.mock import patch

# pylint:disable=no-name-in-module
# pylint:disable=import-error
Expand All @@ -38,6 +37,7 @@
from opentelemetry.sdk.resources import SERVICE_NAME
from opentelemetry.sdk.trace import Resource, TracerProvider
from opentelemetry.sdk.util.instrumentation import InstrumentationInfo
from opentelemetry.test.globals_test import TraceGlobalsTest
from opentelemetry.test.spantestutil import (
get_span_with_dropped_attributes_events_links,
)
Expand All @@ -53,7 +53,7 @@ def _translate_spans_with_dropped_attributes():
return translate._translate(ThriftTranslator(max_tag_value_length=5))


class TestJaegerExporter(unittest.TestCase):
class TestJaegerExporter(TraceGlobalsTest, unittest.TestCase):
def setUp(self):
# create and save span to be used in tests
self.context = trace_api.SpanContext(
Expand All @@ -73,7 +73,6 @@ def setUp(self):
self._test_span.end(end_time=3)
# pylint: disable=protected-access

@patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None)
def test_constructor_default(self):
# pylint: disable=protected-access
"""Test the default values assigned by constructor."""
Expand All @@ -98,7 +97,6 @@ def test_constructor_default(self):
self.assertTrue(exporter._agent_client is not None)
self.assertIsNone(exporter._max_tag_value_length)

@patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None)
def test_constructor_explicit(self):
# pylint: disable=protected-access
"""Test the constructor passing all the options."""
Expand Down Expand Up @@ -143,7 +141,6 @@ def test_constructor_explicit(self):
self.assertTrue(exporter._collector_http_client.auth is None)
self.assertEqual(exporter._max_tag_value_length, 42)

@patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None)
def test_constructor_by_environment_variables(self):
# pylint: disable=protected-access
"""Test the constructor using Environment Variables."""
Expand Down Expand Up @@ -198,7 +195,6 @@ def test_constructor_by_environment_variables(self):
self.assertTrue(exporter._collector_http_client.auth is None)
environ_patcher.stop()

@patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None)
def test_constructor_with_no_traceprovider_resource(self):

"""Test the constructor when there is no resource attached to trace_provider"""
Expand Down Expand Up @@ -480,7 +476,6 @@ def test_translate_to_jaeger(self):

self.assertEqual(spans, expected_spans)

@patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None)
def test_export(self):

"""Test that agent and/or collector are invoked"""
Expand Down Expand Up @@ -511,9 +506,7 @@ def test_export(self):
exporter.export((self._test_span,))
self.assertEqual(agent_client_mock.emit.call_count, 1)
self.assertEqual(collector_mock.submit.call_count, 1)
# trace_api._TRACER_PROVIDER = None

@patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None)
def test_export_span_service_name(self):
trace_api.set_tracer_provider(
TracerProvider(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,12 @@
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SpanExportResult
from opentelemetry.test.globals_test import TraceGlobalsTest
from opentelemetry.trace import TraceFlags


# pylint: disable=no-member
class TestCollectorSpanExporter(unittest.TestCase):
@mock.patch(
"opentelemetry.exporter.opencensus.trace_exporter.trace._TRACER_PROVIDER",
None,
)
class TestCollectorSpanExporter(TraceGlobalsTest, unittest.TestCase):
def test_constructor(self):
mock_get_node = mock.Mock()
patch = mock.patch(
Expand Down Expand Up @@ -329,10 +326,6 @@ def test_export(self):
getattr(output_identifier, "host_name"), "testHostName"
)

@mock.patch(
"opentelemetry.exporter.opencensus.trace_exporter.trace._TRACER_PROVIDER",
None,
)
def test_export_service_name(self):
trace_api.set_tracer_provider(
TracerProvider(
Expand Down
40 changes: 21 additions & 19 deletions opentelemetry-api/src/opentelemetry/trace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
)
from opentelemetry.trace.status import Status, StatusCode
from opentelemetry.util import types
from opentelemetry.util._once import Once
from opentelemetry.util._providers import _load_provider

logger = getLogger(__name__)
Expand Down Expand Up @@ -452,8 +453,9 @@ def start_as_current_span(
yield INVALID_SPAN


_TRACER_PROVIDER = None
_PROXY_TRACER_PROVIDER = None
_TRACER_PROVIDER_SET_ONCE = Once()
_TRACER_PROVIDER: Optional[TracerProvider] = None
_PROXY_TRACER_PROVIDER = ProxyTracerProvider()


def get_tracer(
Expand All @@ -476,40 +478,40 @@ def get_tracer(
)


def _set_tracer_provider(tracer_provider: TracerProvider, log: bool) -> None:
    """Store ``tracer_provider`` as the global provider, at most once.

    The assignment is routed through ``_TRACER_PROVIDER_SET_ONCE`` so that
    only the first call actually writes ``_TRACER_PROVIDER``; every later
    call leaves the global untouched.

    Args:
        tracer_provider: The provider to store in ``_TRACER_PROVIDER``.
        log: When true, log a warning if this call did not set the
            provider because it had already been set.
    """

    def set_tp() -> None:
        global _TRACER_PROVIDER  # pylint: disable=global-statement
        _TRACER_PROVIDER = tracer_provider

    # do_once() reports whether *this* call was the one that ran set_tp.
    did_set = _TRACER_PROVIDER_SET_ONCE.do_once(set_tp)

    if log and not did_set:
        logger.warning("Overriding of current TracerProvider is not allowed")


def set_tracer_provider(tracer_provider: TracerProvider) -> None:
"""Sets the current global :class:`~.TracerProvider` object.
This can only be done once; a warning will be logged if any further
attempt is made.
"""
global _TRACER_PROVIDER # pylint: disable=global-statement

if _TRACER_PROVIDER is not None:
logger.warning("Overriding of current TracerProvider is not allowed")
return

_TRACER_PROVIDER = tracer_provider
_set_tracer_provider(tracer_provider, log=True)


def get_tracer_provider() -> TracerProvider:
"""Gets the current global :class:`~.TracerProvider` object."""
# pylint: disable=global-statement
global _TRACER_PROVIDER
global _PROXY_TRACER_PROVIDER

if _TRACER_PROVIDER is None:
# if a global tracer provider has not been set either via code or env
# vars, return a proxy tracer provider
if OTEL_PYTHON_TRACER_PROVIDER not in os.environ:
if not _PROXY_TRACER_PROVIDER:
_PROXY_TRACER_PROVIDER = ProxyTracerProvider()
return _PROXY_TRACER_PROVIDER

_TRACER_PROVIDER = cast( # type: ignore
"TracerProvider",
_load_provider(OTEL_PYTHON_TRACER_PROVIDER, "tracer_provider"),
tracer_provider: TracerProvider = _load_provider(
OTEL_PYTHON_TRACER_PROVIDER, "tracer_provider"
)
return _TRACER_PROVIDER
_set_tracer_provider(tracer_provider, log=False)
# _TRACER_PROVIDER will have been set by one thread
return cast("TracerProvider", _TRACER_PROVIDER)


@contextmanager # type: ignore
Expand Down
47 changes: 47 additions & 0 deletions opentelemetry-api/src/opentelemetry/util/_once.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from threading import Lock
from typing import Callable


class Once:
    """Execute a function exactly once and block all callers until the function returns.

    Same as golang's `sync.Once <https://pkg.go.dev/sync#Once>`_
    """

    def __init__(self) -> None:
        # Serializes the slow path so only one caller can run the function.
        self._lock = Lock()
        # Becomes True only after a function has run to completion.
        self._done = False

    def do_once(self, func: Callable[[], None]) -> bool:
        """Execute ``func`` if it hasn't been executed or return.

        Will block until ``func`` has been called by one thread.

        If ``func`` raises, the exception propagates to the caller and the
        instance is not marked as done, so a later call will run the
        function it is given.

        Args:
            func: Zero-argument callable to run at most once per instance.

        Returns:
            Whether or not ``func`` was executed in this call
        """

        # fast path, try to avoid locking
        if self._done:
            return False

        with self._lock:
            # Re-check under the lock: another caller may have finished
            # between the unlocked check above and acquiring the lock.
            if not self._done:
                func()
                self._done = True
                return True
        return False
65 changes: 51 additions & 14 deletions opentelemetry-api/tests/trace/test_globals.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import unittest
from unittest.mock import patch
from unittest.mock import Mock, patch

from opentelemetry import context, trace
from opentelemetry.test.concurrency_test import ConcurrencyTestBase, MockFunc
from opentelemetry.test.globals_test import TraceGlobalsTest
from opentelemetry.trace.status import Status, StatusCode


Expand All @@ -25,25 +27,60 @@ def record_exception(
self.recorded_exception = exception


class TestGlobals(unittest.TestCase):
def setUp(self):
self._patcher = patch("opentelemetry.trace._TRACER_PROVIDER")
self._mock_tracer_provider = self._patcher.start()

def tearDown(self) -> None:
self._patcher.stop()

def test_get_tracer(self):
class TestGlobals(TraceGlobalsTest, unittest.TestCase):
@staticmethod
@patch("opentelemetry.trace._TRACER_PROVIDER")
def test_get_tracer(mock_tracer_provider): # type: ignore
"""trace.get_tracer should proxy to the global tracer provider."""
trace.get_tracer("foo", "var")
self._mock_tracer_provider.get_tracer.assert_called_with(
"foo", "var", None
)
mock_provider = unittest.mock.Mock()
mock_tracer_provider.get_tracer.assert_called_with("foo", "var", None)
mock_provider = Mock()
trace.get_tracer("foo", "var", mock_provider)
mock_provider.get_tracer.assert_called_with("foo", "var", None)


class TestGlobalsConcurrency(TraceGlobalsTest, ConcurrencyTestBase):
@patch("opentelemetry.trace.logger")
def test_set_tracer_provider_many_threads(self, mock_logger) -> None: # type: ignore
mock_logger.warning = MockFunc()

def do_concurrently() -> Mock:
# first get a proxy tracer
proxy_tracer = trace.ProxyTracerProvider().get_tracer("foo")

# try to set the global tracer provider
mock_tracer_provider = Mock(get_tracer=MockFunc())
trace.set_tracer_provider(mock_tracer_provider)

# start a span through the proxy which will call through to the mock provider
proxy_tracer.start_span("foo")

return mock_tracer_provider

num_threads = 100
mock_tracer_providers = self.run_with_many_threads(
do_concurrently,
num_threads=num_threads,
)

# despite trying to set tracer provider many times, only one of the
# mock_tracer_providers should have stuck and been called from
# proxy_tracer.start_span()
mock_tps_with_any_call = [
mock
for mock in mock_tracer_providers
if mock.get_tracer.call_count > 0
]

self.assertEqual(len(mock_tps_with_any_call), 1)
self.assertEqual(
mock_tps_with_any_call[0].get_tracer.call_count, num_threads
)

# should have warned every time except for the successful set
self.assertEqual(mock_logger.warning.call_count, num_threads - 1)


class TestTracer(unittest.TestCase):
def setUp(self):
# pylint: disable=protected-access
Expand Down
10 changes: 5 additions & 5 deletions opentelemetry-api/tests/trace/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import unittest

from opentelemetry import trace
from opentelemetry.test.globals_test import TraceGlobalsTest
from opentelemetry.trace.span import INVALID_SPAN_CONTEXT, NonRecordingSpan


Expand All @@ -39,10 +40,8 @@ class TestSpan(NonRecordingSpan):
pass


class TestProxy(unittest.TestCase):
class TestProxy(TraceGlobalsTest, unittest.TestCase):
def test_proxy_tracer(self):
original_provider = trace._TRACER_PROVIDER

provider = trace.get_tracer_provider()
# proxy provider
self.assertIsInstance(provider, trace.ProxyTracerProvider)
Expand All @@ -60,6 +59,9 @@ def test_proxy_tracer(self):
# set a real provider
trace.set_tracer_provider(TestProvider())

# get_tracer_provider() now returns the real provider
self.assertIsInstance(trace.get_tracer_provider(), TestProvider)

# tracer provider now returns real instance
self.assertIsInstance(trace.get_tracer_provider(), TestProvider)

Expand All @@ -71,5 +73,3 @@ def test_proxy_tracer(self):
# creates real spans
with tracer.start_span("") as span:
self.assertIsInstance(span, TestSpan)

trace._TRACER_PROVIDER = original_provider
Loading

0 comments on commit 7867202

Please sign in to comment.