diff --git a/CHANGELOG.md b/CHANGELOG.md index 835601d58ce..2523b5834d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased](https://github.com/open-telemetry/opentelemetry-python/compare/v1.5.0-0.24b0...HEAD) +- Fix race in `set_tracer_provider()` + ([#2182](https://github.com/open-telemetry/opentelemetry-python/pull/2182)) - Automatically load OTEL environment variables as options for `opentelemetry-instrument` ([#1969](https://github.com/open-telemetry/opentelemetry-python/pull/1969)) - `opentelemetry-semantic-conventions` Update to semantic conventions v1.6.1 diff --git a/exporter/opentelemetry-exporter-jaeger-thrift/tests/test_jaeger_exporter_thrift.py b/exporter/opentelemetry-exporter-jaeger-thrift/tests/test_jaeger_exporter_thrift.py index c72cd579ff7..8a30527eebc 100644 --- a/exporter/opentelemetry-exporter-jaeger-thrift/tests/test_jaeger_exporter_thrift.py +++ b/exporter/opentelemetry-exporter-jaeger-thrift/tests/test_jaeger_exporter_thrift.py @@ -15,7 +15,6 @@ import unittest from unittest import mock -from unittest.mock import patch # pylint:disable=no-name-in-module # pylint:disable=import-error @@ -38,6 +37,7 @@ from opentelemetry.sdk.resources import SERVICE_NAME from opentelemetry.sdk.trace import Resource, TracerProvider from opentelemetry.sdk.util.instrumentation import InstrumentationInfo +from opentelemetry.test.globals_test import TraceGlobalsTest from opentelemetry.test.spantestutil import ( get_span_with_dropped_attributes_events_links, ) @@ -53,7 +53,7 @@ def _translate_spans_with_dropped_attributes(): return translate._translate(ThriftTranslator(max_tag_value_length=5)) -class TestJaegerExporter(unittest.TestCase): +class TestJaegerExporter(TraceGlobalsTest, unittest.TestCase): def setUp(self): # create and save span to be used in tests self.context = 
trace_api.SpanContext( @@ -73,7 +73,6 @@ def setUp(self): self._test_span.end(end_time=3) # pylint: disable=protected-access - @patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None) def test_constructor_default(self): # pylint: disable=protected-access """Test the default values assigned by constructor.""" @@ -98,7 +97,6 @@ def test_constructor_default(self): self.assertTrue(exporter._agent_client is not None) self.assertIsNone(exporter._max_tag_value_length) - @patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None) def test_constructor_explicit(self): # pylint: disable=protected-access """Test the constructor passing all the options.""" @@ -143,7 +141,6 @@ def test_constructor_explicit(self): self.assertTrue(exporter._collector_http_client.auth is None) self.assertEqual(exporter._max_tag_value_length, 42) - @patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None) def test_constructor_by_environment_variables(self): # pylint: disable=protected-access """Test the constructor using Environment Variables.""" @@ -198,7 +195,6 @@ def test_constructor_by_environment_variables(self): self.assertTrue(exporter._collector_http_client.auth is None) environ_patcher.stop() - @patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None) def test_constructor_with_no_traceprovider_resource(self): """Test the constructor when there is no resource attached to trace_provider""" @@ -480,7 +476,6 @@ def test_translate_to_jaeger(self): self.assertEqual(spans, expected_spans) - @patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None) def test_export(self): """Test that agent and/or collector are invoked""" @@ -511,9 +506,7 @@ def test_export(self): exporter.export((self._test_span,)) self.assertEqual(agent_client_mock.emit.call_count, 1) self.assertEqual(collector_mock.submit.call_count, 1) - # trace_api._TRACER_PROVIDER = None - 
@patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None) def test_export_span_service_name(self): trace_api.set_tracer_provider( TracerProvider( diff --git a/exporter/opentelemetry-exporter-opencensus/tests/test_otcollector_trace_exporter.py b/exporter/opentelemetry-exporter-opencensus/tests/test_otcollector_trace_exporter.py index 43d9bcd430b..cd4dcb1a08c 100644 --- a/exporter/opentelemetry-exporter-opencensus/tests/test_otcollector_trace_exporter.py +++ b/exporter/opentelemetry-exporter-opencensus/tests/test_otcollector_trace_exporter.py @@ -29,15 +29,12 @@ from opentelemetry.sdk.resources import SERVICE_NAME, Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import SpanExportResult +from opentelemetry.test.globals_test import TraceGlobalsTest from opentelemetry.trace import TraceFlags # pylint: disable=no-member -class TestCollectorSpanExporter(unittest.TestCase): - @mock.patch( - "opentelemetry.exporter.opencensus.trace_exporter.trace._TRACER_PROVIDER", - None, - ) +class TestCollectorSpanExporter(TraceGlobalsTest, unittest.TestCase): def test_constructor(self): mock_get_node = mock.Mock() patch = mock.patch( @@ -329,10 +326,6 @@ def test_export(self): getattr(output_identifier, "host_name"), "testHostName" ) - @mock.patch( - "opentelemetry.exporter.opencensus.trace_exporter.trace._TRACER_PROVIDER", - None, - ) def test_export_service_name(self): trace_api.set_tracer_provider( TracerProvider( diff --git a/opentelemetry-api/src/opentelemetry/trace/__init__.py b/opentelemetry-api/src/opentelemetry/trace/__init__.py index 24c42b04c64..26df821cbc3 100644 --- a/opentelemetry-api/src/opentelemetry/trace/__init__.py +++ b/opentelemetry-api/src/opentelemetry/trace/__init__.py @@ -108,6 +108,7 @@ ) from opentelemetry.trace.status import Status, StatusCode from opentelemetry.util import types +from opentelemetry.util._once import Once from opentelemetry.util._providers import _load_provider logger = 
getLogger(__name__) @@ -452,8 +453,9 @@ def start_as_current_span( yield INVALID_SPAN -_TRACER_PROVIDER = None -_PROXY_TRACER_PROVIDER = None +_TRACER_PROVIDER_SET_ONCE = Once() +_TRACER_PROVIDER: Optional[TracerProvider] = None +_PROXY_TRACER_PROVIDER = ProxyTracerProvider() def get_tracer( @@ -476,40 +478,40 @@ def get_tracer( ) +def _set_tracer_provider(tracer_provider: TracerProvider, log: bool) -> None: + def set_tp() -> None: + global _TRACER_PROVIDER # pylint: disable=global-statement + _TRACER_PROVIDER = tracer_provider + + did_set = _TRACER_PROVIDER_SET_ONCE.do_once(set_tp) + + if log and not did_set: + logger.warning("Overriding of current TracerProvider is not allowed") + + def set_tracer_provider(tracer_provider: TracerProvider) -> None: """Sets the current global :class:`~.TracerProvider` object. This can only be done once, a warning will be logged if any furter attempt is made. """ - global _TRACER_PROVIDER # pylint: disable=global-statement - - if _TRACER_PROVIDER is not None: - logger.warning("Overriding of current TracerProvider is not allowed") - return - - _TRACER_PROVIDER = tracer_provider + _set_tracer_provider(tracer_provider, log=True) def get_tracer_provider() -> TracerProvider: """Gets the current global :class:`~.TracerProvider` object.""" - # pylint: disable=global-statement - global _TRACER_PROVIDER - global _PROXY_TRACER_PROVIDER - if _TRACER_PROVIDER is None: # if a global tracer provider has not been set either via code or env # vars, return a proxy tracer provider if OTEL_PYTHON_TRACER_PROVIDER not in os.environ: - if not _PROXY_TRACER_PROVIDER: - _PROXY_TRACER_PROVIDER = ProxyTracerProvider() return _PROXY_TRACER_PROVIDER - _TRACER_PROVIDER = cast( # type: ignore - "TracerProvider", - _load_provider(OTEL_PYTHON_TRACER_PROVIDER, "tracer_provider"), + tracer_provider: TracerProvider = _load_provider( + OTEL_PYTHON_TRACER_PROVIDER, "tracer_provider" ) - return _TRACER_PROVIDER + _set_tracer_provider(tracer_provider, log=False) + # 
_TRACER_PROVIDER will have been set by one thread + return cast("TracerProvider", _TRACER_PROVIDER) @contextmanager # type: ignore diff --git a/opentelemetry-api/src/opentelemetry/util/_once.py b/opentelemetry-api/src/opentelemetry/util/_once.py new file mode 100644 index 00000000000..c0cee43a174 --- /dev/null +++ b/opentelemetry-api/src/opentelemetry/util/_once.py @@ -0,0 +1,47 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from threading import Lock +from typing import Callable + + +class Once: + """Execute a function exactly once and block all callers until the function returns + + Same as golang's `sync.Once <https://pkg.go.dev/sync#Once>`_ + """ + + def __init__(self) -> None: + self._lock = Lock() + self._done = False + + def do_once(self, func: Callable[[], None]) -> bool: + """Execute ``func`` if it hasn't been executed or return. + + Will block until ``func`` has been called by one thread. 
+ + Returns: + Whether or not ``func`` was executed in this call + """ + + # fast path, try to avoid locking + if self._done: + return False + + with self._lock: + if not self._done: + func() + self._done = True + return True + return False diff --git a/opentelemetry-api/tests/trace/test_globals.py b/opentelemetry-api/tests/trace/test_globals.py index 034a97e4ded..421b72d65fe 100644 --- a/opentelemetry-api/tests/trace/test_globals.py +++ b/opentelemetry-api/tests/trace/test_globals.py @@ -1,7 +1,9 @@ import unittest -from unittest.mock import patch +from unittest.mock import Mock, patch from opentelemetry import context, trace +from opentelemetry.test.concurrency_test import ConcurrencyTestBase, MockFunc +from opentelemetry.test.globals_test import TraceGlobalsTest from opentelemetry.trace.status import Status, StatusCode @@ -25,25 +27,60 @@ def record_exception( self.recorded_exception = exception -class TestGlobals(unittest.TestCase): - def setUp(self): - self._patcher = patch("opentelemetry.trace._TRACER_PROVIDER") - self._mock_tracer_provider = self._patcher.start() - - def tearDown(self) -> None: - self._patcher.stop() - - def test_get_tracer(self): +class TestGlobals(TraceGlobalsTest, unittest.TestCase): + @staticmethod + @patch("opentelemetry.trace._TRACER_PROVIDER") + def test_get_tracer(mock_tracer_provider): # type: ignore """trace.get_tracer should proxy to the global tracer provider.""" trace.get_tracer("foo", "var") - self._mock_tracer_provider.get_tracer.assert_called_with( - "foo", "var", None - ) - mock_provider = unittest.mock.Mock() + mock_tracer_provider.get_tracer.assert_called_with("foo", "var", None) + mock_provider = Mock() trace.get_tracer("foo", "var", mock_provider) mock_provider.get_tracer.assert_called_with("foo", "var", None) +class TestGlobalsConcurrency(TraceGlobalsTest, ConcurrencyTestBase): + @patch("opentelemetry.trace.logger") + def test_set_tracer_provider_many_threads(self, mock_logger) -> None: # type: ignore + 
mock_logger.warning = MockFunc() + + def do_concurrently() -> Mock: + # first get a proxy tracer + proxy_tracer = trace.ProxyTracerProvider().get_tracer("foo") + + # try to set the global tracer provider + mock_tracer_provider = Mock(get_tracer=MockFunc()) + trace.set_tracer_provider(mock_tracer_provider) + + # start a span through the proxy which will call through to the mock provider + proxy_tracer.start_span("foo") + + return mock_tracer_provider + + num_threads = 100 + mock_tracer_providers = self.run_with_many_threads( + do_concurrently, + num_threads=num_threads, + ) + + # despite trying to set tracer provider many times, only one of the + # mock_tracer_providers should have stuck and been called from + # proxy_tracer.start_span() + mock_tps_with_any_call = [ + mock + for mock in mock_tracer_providers + if mock.get_tracer.call_count > 0 + ] + + self.assertEqual(len(mock_tps_with_any_call), 1) + self.assertEqual( + mock_tps_with_any_call[0].get_tracer.call_count, num_threads + ) + + # should have warned every time except for the successful set + self.assertEqual(mock_logger.warning.call_count, num_threads - 1) + + class TestTracer(unittest.TestCase): def setUp(self): # pylint: disable=protected-access diff --git a/opentelemetry-api/tests/trace/test_proxy.py b/opentelemetry-api/tests/trace/test_proxy.py index 42a31b41322..da1d60c74e1 100644 --- a/opentelemetry-api/tests/trace/test_proxy.py +++ b/opentelemetry-api/tests/trace/test_proxy.py @@ -17,6 +17,7 @@ import unittest from opentelemetry import trace +from opentelemetry.test.globals_test import TraceGlobalsTest from opentelemetry.trace.span import INVALID_SPAN_CONTEXT, NonRecordingSpan @@ -39,10 +40,8 @@ class TestSpan(NonRecordingSpan): pass -class TestProxy(unittest.TestCase): +class TestProxy(TraceGlobalsTest, unittest.TestCase): def test_proxy_tracer(self): - original_provider = trace._TRACER_PROVIDER - provider = trace.get_tracer_provider() # proxy provider self.assertIsInstance(provider, 
trace.ProxyTracerProvider) @@ -60,6 +59,9 @@ def test_proxy_tracer(self): # set a real provider trace.set_tracer_provider(TestProvider()) + # get_tracer_provider() now returns the real provider + self.assertIsInstance(trace.get_tracer_provider(), TestProvider) + # tracer provider now returns real instance self.assertIsInstance(trace.get_tracer_provider(), TestProvider) @@ -71,5 +73,3 @@ def test_proxy_tracer(self): # creates real spans with tracer.start_span("") as span: self.assertIsInstance(span, TestSpan) - - trace._TRACER_PROVIDER = original_provider diff --git a/opentelemetry-api/tests/util/test_once.py b/opentelemetry-api/tests/util/test_once.py new file mode 100644 index 00000000000..ee94318d228 --- /dev/null +++ b/opentelemetry-api/tests/util/test_once.py @@ -0,0 +1,48 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from opentelemetry.test.concurrency_test import ConcurrencyTestBase, MockFunc +from opentelemetry.util._once import Once + + +class TestOnce(ConcurrencyTestBase): + def test_once_single_thread(self): + once_func = MockFunc() + once = Once() + + self.assertEqual(once_func.call_count, 0) + + # first call should run + called = once.do_once(once_func) + self.assertTrue(called) + self.assertEqual(once_func.call_count, 1) + + # subsequent calls do nothing + called = once.do_once(once_func) + self.assertFalse(called) + self.assertEqual(once_func.call_count, 1) + + def test_once_many_threads(self): + once_func = MockFunc() + once = Once() + + def run_concurrently() -> bool: + return once.do_once(once_func) + + results = self.run_with_many_threads(run_concurrently, num_threads=100) + + self.assertEqual(once_func.call_count, 1) + + # check that only one of the threads got True + self.assertEqual(results.count(True), 1) diff --git a/tests/util/src/opentelemetry/test/concurrency_test.py b/tests/util/src/opentelemetry/test/concurrency_test.py new file mode 100644 index 00000000000..5d178e24fff --- /dev/null +++ b/tests/util/src/opentelemetry/test/concurrency_test.py @@ -0,0 +1,90 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import threading +import unittest +from functools import partial +from typing import Callable, List, Optional, TypeVar +from unittest.mock import Mock + +ReturnT = TypeVar("ReturnT") + + +class MockFunc: + """A thread safe mock function + + Use this as part of your mock if you want to count calls across multiple + threads. + """ + + def __init__(self) -> None: + self.lock = threading.Lock() + self.call_count = 0 + self.mock = Mock() + + def __call__(self, *args, **kwargs): + with self.lock: + self.call_count += 1 + return self.mock + + +class ConcurrencyTestBase(unittest.TestCase): + """Test base class/mixin for tests of concurrent code + + This test class calls ``sys.setswitchinterval(1e-12)`` to try to create more + contention while running tests that use many threads. It also provides + ``run_with_many_threads`` to run some test code in many threads + concurrently. + """ + + orig_switch_interval = sys.getswitchinterval() + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + # switch threads more often to increase chance of contention + sys.setswitchinterval(1e-12) + + @classmethod + def tearDownClass(cls) -> None: + super().tearDownClass() + sys.setswitchinterval(cls.orig_switch_interval) + + @staticmethod + def run_with_many_threads( + func_to_test: Callable[[], ReturnT], + num_threads: int = 100, + ) -> List[ReturnT]: + """Util to run ``func_to_test`` in ``num_threads`` concurrently""" + + barrier = threading.Barrier(num_threads) + results: List[Optional[ReturnT]] = [None] * num_threads + + def thread_start(idx: int) -> None: + nonlocal results + # Get all threads here before releasing them to create contention + barrier.wait() + results[idx] = func_to_test() + + threads = [ + threading.Thread(target=partial(thread_start, i)) + for i in range(num_threads) + ] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + return results # type: ignore diff --git 
a/tests/util/src/opentelemetry/test/globals_test.py b/tests/util/src/opentelemetry/test/globals_test.py new file mode 100644 index 00000000000..bb2cad6a0ac --- /dev/null +++ b/tests/util/src/opentelemetry/test/globals_test.py @@ -0,0 +1,41 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from opentelemetry import trace as trace_api +from opentelemetry.util._once import Once + + +# pylint: disable=protected-access +def reset_trace_globals() -> None: + """WARNING: only use this for tests.""" + trace_api._TRACER_PROVIDER_SET_ONCE = Once() + trace_api._TRACER_PROVIDER = None + trace_api._PROXY_TRACER_PROVIDER = trace_api.ProxyTracerProvider() + + +class TraceGlobalsTest(unittest.TestCase): + """Resets trace API globals in setUp/tearDown + + Use as a base class or mixin for your test that modifies trace API globals. 
+ """ + + def setUp(self) -> None: + super().setUp() + reset_trace_globals() + + def tearDown(self) -> None: + super().tearDown() + reset_trace_globals() diff --git a/tests/util/src/opentelemetry/test/test_base.py b/tests/util/src/opentelemetry/test/test_base.py index 9762a08010e..f176238add3 100644 --- a/tests/util/src/opentelemetry/test/test_base.py +++ b/tests/util/src/opentelemetry/test/test_base.py @@ -21,6 +21,7 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( InMemorySpanExporter, ) +from opentelemetry.test.globals_test import reset_trace_globals class TestBase(unittest.TestCase): @@ -28,20 +29,18 @@ class TestBase(unittest.TestCase): @classmethod def setUpClass(cls): - cls.original_tracer_provider = trace_api.get_tracer_provider() result = cls.create_tracer_provider() cls.tracer_provider, cls.memory_exporter = result # This is done because set_tracer_provider cannot override the # current tracer provider. - trace_api._TRACER_PROVIDER = None # pylint: disable=protected-access + reset_trace_globals() trace_api.set_tracer_provider(cls.tracer_provider) @classmethod def tearDownClass(cls): # This is done because set_tracer_provider cannot override the # current tracer provider. - trace_api._TRACER_PROVIDER = None # pylint: disable=protected-access - trace_api.set_tracer_provider(cls.original_tracer_provider) + reset_trace_globals() def setUp(self): self.memory_exporter.clear() diff --git a/tox.ini b/tox.ini index 724a7f20185..0210e7e6e9d 100644 --- a/tox.ini +++ b/tox.ini @@ -85,7 +85,7 @@ setenv = ; i.e: CONTRIB_REPO_SHA=dde62cebffe519c35875af6d06fae053b3be65ec tox -e CONTRIB_REPO_SHA={env:CONTRIB_REPO_SHA:"main"} CONTRIB_REPO="git+https://github.com/open-telemetry/opentelemetry-python-contrib.git@{env:CONTRIB_REPO_SHA}" - mypy: MYPYPATH={toxinidir}/opentelemetry-api/src/ + mypy: MYPYPATH={toxinidir}/opentelemetry-api/src/:{toxinidir}/tests/util/src/ changedir = api: opentelemetry-api/tests