Skip to content

Commit 30ba497

Browse files
committed
Fix race in set_tracer_provider
1 parent c9b18c6 commit 30ba497

File tree

11 files changed

+292
-38
lines changed

11 files changed

+292
-38
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
55
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
66

77
## [Unreleased](https://github.com/open-telemetry/opentelemetry-python/compare/v1.5.0-0.24b0...HEAD)
8+
- Fix race in `set_tracer_provider()`
9+
([#2182](https://github.com/open-telemetry/opentelemetry-python/pull/2182))
810
- Automatically load OTEL environment variables as options for `opentelemetry-instrument`
911
([#1969](https://github.com/open-telemetry/opentelemetry-python/pull/1969))
1012
- `opentelemetry-semantic-conventions` Update to semantic conventions v1.6.1

exporter/opentelemetry-exporter-jaeger-thrift/tests/test_jaeger_exporter_thrift.py

+5
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def _translate_spans_with_dropped_attributes():
5555

5656
class TestJaegerExporter(unittest.TestCase):
5757
def setUp(self):
58+
trace_api._reset_globals() # pylint: disable=protected-access
5859
# create and save span to be used in tests
5960
self.context = trace_api.SpanContext(
6061
trace_id=0x000000000000000000000000DEADBEEF,
@@ -73,6 +74,10 @@ def setUp(self):
7374
self._test_span.end(end_time=3)
7475
# pylint: disable=protected-access
7576

77+
def tearDown(self):
78+
super().tearDown()
79+
trace_api._reset_globals() # pylint: disable=protected-access
80+
7681
@patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None)
7782
def test_constructor_default(self):
7883
# pylint: disable=protected-access

exporter/opentelemetry-exporter-opencensus/tests/test_otcollector_trace_exporter.py

+8
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@
3434

3535
# pylint: disable=no-member
3636
class TestCollectorSpanExporter(unittest.TestCase):
37+
def setUp(self):
38+
super().setUp()
39+
trace_api._reset_globals() # pylint: disable=protected-access
40+
41+
def tearDown(self):
42+
super().tearDown()
43+
trace_api._reset_globals() # pylint: disable=protected-access
44+
3745
@mock.patch(
3846
"opentelemetry.exporter.opencensus.trace_exporter.trace._TRACER_PROVIDER",
3947
None,

opentelemetry-api/src/opentelemetry/trace/__init__.py

+31-19
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@
108108
)
109109
from opentelemetry.trace.status import Status, StatusCode
110110
from opentelemetry.util import types
111+
from opentelemetry.util._once import Once
111112
from opentelemetry.util._providers import _load_provider
112113

113114
logger = getLogger(__name__)
@@ -452,8 +453,19 @@ def start_as_current_span(
452453
yield INVALID_SPAN
453454

454455

455-
_TRACER_PROVIDER = None
456-
_PROXY_TRACER_PROVIDER = None
456+
_TRACER_PROVIDER_SET_ONCE = Once()
457+
_TRACER_PROVIDER: Optional[TracerProvider] = None
458+
_PROXY_TRACER_PROVIDER = ProxyTracerProvider()
459+
460+
461+
def _reset_globals() -> None:
462+
"""WARNING: only use this for tests."""
463+
global _TRACER_PROVIDER_SET_ONCE # pylint: disable=global-statement
464+
global _TRACER_PROVIDER # pylint: disable=global-statement
465+
global _PROXY_TRACER_PROVIDER # pylint: disable=global-statement
466+
_TRACER_PROVIDER_SET_ONCE = Once()
467+
_TRACER_PROVIDER = None
468+
_PROXY_TRACER_PROVIDER = ProxyTracerProvider()
457469

458470

459471
def get_tracer(
@@ -476,40 +488,40 @@ def get_tracer(
476488
)
477489

478490

491+
def _set_tracer_provider(tracer_provider: TracerProvider, log: bool) -> None:
492+
def set_tp() -> None:
493+
global _TRACER_PROVIDER # pylint: disable=global-statement
494+
_TRACER_PROVIDER = tracer_provider
495+
496+
did_set = _TRACER_PROVIDER_SET_ONCE.do_once(set_tp)
497+
498+
if not did_set:
499+
logger.warning("Overriding of current TracerProvider is not allowed")
500+
501+
479502
def set_tracer_provider(tracer_provider: TracerProvider) -> None:
480503
"""Sets the current global :class:`~.TracerProvider` object.
481504
482505
This can only be done once, a warning will be logged if any furter attempt
483506
is made.
484507
"""
485-
global _TRACER_PROVIDER # pylint: disable=global-statement
486-
487-
if _TRACER_PROVIDER is not None:
488-
logger.warning("Overriding of current TracerProvider is not allowed")
489-
return
490-
491-
_TRACER_PROVIDER = tracer_provider
508+
_set_tracer_provider(tracer_provider, log=True)
492509

493510

494511
def get_tracer_provider() -> TracerProvider:
495512
"""Gets the current global :class:`~.TracerProvider` object."""
496-
# pylint: disable=global-statement
497-
global _TRACER_PROVIDER
498-
global _PROXY_TRACER_PROVIDER
499-
500513
if _TRACER_PROVIDER is None:
501514
# if a global tracer provider has not been set either via code or env
502515
# vars, return a proxy tracer provider
503516
if OTEL_PYTHON_TRACER_PROVIDER not in os.environ:
504-
if not _PROXY_TRACER_PROVIDER:
505-
_PROXY_TRACER_PROVIDER = ProxyTracerProvider()
506517
return _PROXY_TRACER_PROVIDER
507518

508-
_TRACER_PROVIDER = cast( # type: ignore
509-
"TracerProvider",
510-
_load_provider(OTEL_PYTHON_TRACER_PROVIDER, "tracer_provider"),
519+
tracer_provider: TracerProvider = _load_provider(
520+
OTEL_PYTHON_TRACER_PROVIDER, "tracer_provider"
511521
)
512-
return _TRACER_PROVIDER
522+
_set_tracer_provider(tracer_provider, log=False)
523+
# _TRACER_PROVIDER will have been set by one thread
524+
return cast("TracerProvider", _TRACER_PROVIDER)
513525

514526

515527
@contextmanager # type: ignore
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright The OpenTelemetry Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from threading import Lock
16+
from typing import Callable
17+
18+
19+
class Once:
20+
"""Execute a function exactly once and block all callers until the function returns
21+
22+
Same as golang's `sync.Once <https://pkg.go.dev/sync#Once>`_
23+
"""
24+
25+
def __init__(self) -> None:
26+
self._lock = Lock()
27+
self._done = False
28+
29+
def do_once(self, func: Callable[[], None]) -> bool:
30+
"""Execute ``func`` if it hasn't been executed or return.
31+
32+
Will block until ``func`` has been called by one thread.
33+
34+
Returns:
35+
Whether or not ``func`` was executed in this call
36+
"""
37+
38+
# fast path, try to avoid locking
39+
if self._done:
40+
return False
41+
42+
with self._lock:
43+
if not self._done:
44+
func()
45+
self._done = True
46+
return True
47+
return False

opentelemetry-api/tests/trace/test_globals.py

+62-10
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import unittest
2-
from unittest.mock import patch
2+
from unittest.mock import Mock, patch
33

44
from opentelemetry import context, trace
5+
from opentelemetry.test.concurrency_test import ConcurrencyTestBase, MockFunc
56
from opentelemetry.trace.status import Status, StatusCode
67

78

@@ -27,23 +28,74 @@ def record_exception(
2728

2829
class TestGlobals(unittest.TestCase):
2930
def setUp(self):
30-
self._patcher = patch("opentelemetry.trace._TRACER_PROVIDER")
31-
self._mock_tracer_provider = self._patcher.start()
31+
super().setUp()
32+
trace._reset_globals() # pylint: disable=protected-access
3233

33-
def tearDown(self) -> None:
34-
self._patcher.stop()
34+
def tearDown(self):
35+
super().tearDown()
36+
trace._reset_globals() # pylint: disable=protected-access
3537

36-
def test_get_tracer(self):
38+
@staticmethod
39+
@patch("opentelemetry.trace._TRACER_PROVIDER")
40+
def test_get_tracer(mock_tracer_provider): # type: ignore
3741
"""trace.get_tracer should proxy to the global tracer provider."""
3842
trace.get_tracer("foo", "var")
39-
self._mock_tracer_provider.get_tracer.assert_called_with(
40-
"foo", "var", None
41-
)
42-
mock_provider = unittest.mock.Mock()
43+
mock_tracer_provider.get_tracer.assert_called_with("foo", "var", None)
44+
mock_provider = Mock()
4345
trace.get_tracer("foo", "var", mock_provider)
4446
mock_provider.get_tracer.assert_called_with("foo", "var", None)
4547

4648

49+
class TestGlobalsConcurrency(ConcurrencyTestBase):
50+
def setUp(self):
51+
super().setUp()
52+
trace._reset_globals() # pylint: disable=protected-access
53+
54+
def tearDown(self):
55+
super().tearDown()
56+
trace._reset_globals() # pylint: disable=protected-access
57+
58+
@patch("opentelemetry.trace.logger")
59+
def test_set_tracer_provider_many_threads(self, mock_logger) -> None: # type: ignore
60+
mock_logger.warning = MockFunc()
61+
62+
def do_concurrently() -> Mock:
63+
# first get a proxy tracer
64+
proxy_tracer = trace.ProxyTracerProvider().get_tracer("foo")
65+
66+
# try to set the global tracer provider
67+
mock_tracer_provider = Mock(get_tracer=MockFunc())
68+
trace.set_tracer_provider(mock_tracer_provider)
69+
70+
# start a span through the proxy which will call through to the mock provider
71+
proxy_tracer.start_span("foo")
72+
73+
return mock_tracer_provider
74+
75+
num_threads = 100
76+
mock_tracer_providers = self.run_with_many_threads(
77+
do_concurrently,
78+
num_threads=num_threads,
79+
)
80+
81+
# despite trying to set tracer provider many times, only one of the
82+
# mock_tracer_providers should have stuck and been called from
83+
# proxy_tracer.start_span()
84+
mock_tps_with_any_call = [
85+
mock
86+
for mock in mock_tracer_providers
87+
if mock.get_tracer.call_count > 0
88+
]
89+
90+
self.assertEqual(len(mock_tps_with_any_call), 1)
91+
self.assertEqual(
92+
mock_tps_with_any_call[0].get_tracer.call_count, num_threads
93+
)
94+
95+
# should have warned everytime except for the successful set
96+
self.assertEqual(mock_logger.warning.call_count, num_threads - 1)
97+
98+
4799
class TestTracer(unittest.TestCase):
48100
def setUp(self):
49101
# pylint: disable=protected-access

opentelemetry-api/tests/trace/test_proxy.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,15 @@ class TestSpan(NonRecordingSpan):
4040

4141

4242
class TestProxy(unittest.TestCase):
43-
def test_proxy_tracer(self):
44-
original_provider = trace._TRACER_PROVIDER
43+
def setUp(self) -> None:
44+
super().setUp()
45+
trace._reset_globals()
46+
47+
def tearDown(self) -> None:
48+
super().tearDown()
49+
trace._reset_globals()
4550

51+
def test_proxy_tracer(self):
4652
provider = trace.get_tracer_provider()
4753
# proxy provider
4854
self.assertIsInstance(provider, trace.ProxyTracerProvider)
@@ -60,6 +66,9 @@ def test_proxy_tracer(self):
6066
# set a real provider
6167
trace.set_tracer_provider(TestProvider())
6268

69+
# get_tracer_provider() now returns the real provider
70+
self.assertIsInstance(trace.get_tracer_provider(), TestProvider)
71+
6372
# tracer provider now returns real instance
6473
self.assertIsInstance(trace.get_tracer_provider(), TestProvider)
6574

@@ -71,5 +80,3 @@ def test_proxy_tracer(self):
7180
# creates real spans
7281
with tracer.start_span("") as span:
7382
self.assertIsInstance(span, TestSpan)
74-
75-
trace._TRACER_PROVIDER = original_provider
+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright The OpenTelemetry Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from opentelemetry.test.concurrency_test import ConcurrencyTestBase, MockFunc
16+
from opentelemetry.util._once import Once
17+
18+
19+
class TestOnce(ConcurrencyTestBase):
20+
def test_once_single_thread(self):
21+
once_func = MockFunc()
22+
once = Once()
23+
24+
self.assertEqual(once_func.call_count, 0)
25+
26+
# first call should run
27+
called = once.do_once(once_func)
28+
self.assertTrue(called)
29+
self.assertEqual(once_func.call_count, 1)
30+
31+
# subsequent calls do nothing
32+
called = once.do_once(once_func)
33+
self.assertFalse(called)
34+
self.assertEqual(once_func.call_count, 1)
35+
36+
def test_once_many_threads(self):
37+
once_func = MockFunc()
38+
once = Once()
39+
40+
def run_concurrently() -> bool:
41+
return once.do_once(once_func)
42+
43+
results = self.run_with_many_threads(run_concurrently, num_threads=100)
44+
45+
self.assertEqual(once_func.call_count, 1)
46+
47+
# check that only one of the threads got True
48+
self.assertEqual(results.count(True), 1)

0 commit comments

Comments
 (0)