Skip to content

Commit ec0056b

Browse files
committed
Fix race in set_tracer_provider
1 parent c9b18c6 commit ec0056b

File tree

7 files changed

+275
-34
lines changed

7 files changed

+275
-34
lines changed

opentelemetry-api/src/opentelemetry/trace/__init__.py

+31-19
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@
108108
)
109109
from opentelemetry.trace.status import Status, StatusCode
110110
from opentelemetry.util import types
111+
from opentelemetry.util._once import Once
111112
from opentelemetry.util._providers import _load_provider
112113

113114
logger = getLogger(__name__)
@@ -452,8 +453,19 @@ def start_as_current_span(
452453
yield INVALID_SPAN
453454

454455

455-
_TRACER_PROVIDER = None
456-
_PROXY_TRACER_PROVIDER = None
456+
_TRACER_PROVIDER_SET_ONCE = Once()
457+
_TRACER_PROVIDER: Optional[TracerProvider] = None
458+
_PROXY_TRACER_PROVIDER = ProxyTracerProvider()
459+
460+
461+
def _reset_globals() -> None:
    """Put the module-level tracer provider state back to its initial values.

    WARNING: only use this for tests.
    """
    # pylint: disable=global-statement
    global _TRACER_PROVIDER_SET_ONCE, _TRACER_PROVIDER, _PROXY_TRACER_PROVIDER

    # a fresh Once allows the next set attempt to take effect again
    _TRACER_PROVIDER_SET_ONCE = Once()
    _TRACER_PROVIDER = None
    _PROXY_TRACER_PROVIDER = ProxyTracerProvider()
457469

458470

459471
def get_tracer(
@@ -476,40 +488,40 @@ def get_tracer(
476488
)
477489

478490

491+
def _set_tracer_provider(tracer_provider: TracerProvider, log: bool) -> None:
    """Store ``tracer_provider`` as the global provider, at most once.

    The assignment is routed through ``_TRACER_PROVIDER_SET_ONCE`` so that
    only the first caller's provider is kept.

    Args:
        tracer_provider: The provider to assign to ``_TRACER_PROVIDER``.
        log: Whether to log a warning when the provider had already been
            set and this call therefore did not assign anything.
    """

    def set_tp() -> None:
        global _TRACER_PROVIDER  # pylint: disable=global-statement
        _TRACER_PROVIDER = tracer_provider

    did_set = _TRACER_PROVIDER_SET_ONCE.do_once(set_tp)

    # Only warn when the caller asked for it: get_tracer_provider() passes
    # log=False so that losing the race while loading the provider stays
    # silent, whereas set_tracer_provider() passes log=True.
    if log and not did_set:
        logger.warning("Overriding of current TracerProvider is not allowed")
500+
501+
479502
def set_tracer_provider(tracer_provider: TracerProvider) -> None:
480503
"""Sets the current global :class:`~.TracerProvider` object.
481504
482505
This can only be done once, a warning will be logged if any further attempt
483506
is made.
484507
"""
485-
global _TRACER_PROVIDER # pylint: disable=global-statement
486-
487-
if _TRACER_PROVIDER is not None:
488-
logger.warning("Overriding of current TracerProvider is not allowed")
489-
return
490-
491-
_TRACER_PROVIDER = tracer_provider
508+
_set_tracer_provider(tracer_provider, log=True)
492509

493510

494511
def get_tracer_provider() -> TracerProvider:
495512
"""Gets the current global :class:`~.TracerProvider` object."""
496-
# pylint: disable=global-statement
497-
global _TRACER_PROVIDER
498-
global _PROXY_TRACER_PROVIDER
499-
500513
if _TRACER_PROVIDER is None:
501514
# if a global tracer provider has not been set either via code or env
502515
# vars, return a proxy tracer provider
503516
if OTEL_PYTHON_TRACER_PROVIDER not in os.environ:
504-
if not _PROXY_TRACER_PROVIDER:
505-
_PROXY_TRACER_PROVIDER = ProxyTracerProvider()
506517
return _PROXY_TRACER_PROVIDER
507518

508-
_TRACER_PROVIDER = cast( # type: ignore
509-
"TracerProvider",
510-
_load_provider(OTEL_PYTHON_TRACER_PROVIDER, "tracer_provider"),
519+
tracer_provider: TracerProvider = _load_provider(
520+
OTEL_PYTHON_TRACER_PROVIDER, "tracer_provider"
511521
)
512-
return _TRACER_PROVIDER
522+
_set_tracer_provider(tracer_provider, log=False)
523+
# _TRACER_PROVIDER will have been set by one thread
524+
return cast("TracerProvider", _TRACER_PROVIDER)
513525

514526

515527
@contextmanager # type: ignore
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright The OpenTelemetry Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from threading import Lock
16+
from typing import Callable
17+
18+
19+
class Once:
    """Run a function a single time, making concurrent callers wait for it.

    Same as golang's `sync.Once <https://pkg.go.dev/sync#Once>`_
    """

    def __init__(self) -> None:
        self._lock = Lock()
        self._done = False

    def do_once(self, func: Callable[[], None]) -> bool:
        """Call ``func`` unless a previous call already completed it.

        Callers that arrive while ``func`` is running wait on the lock until
        it has returned.

        Returns:
            ``True`` if this call executed ``func``, ``False`` otherwise
        """
        # unlocked check first so calls made after completion skip the lock
        if self._done:
            return False

        with self._lock:
            if self._done:
                # another caller completed ``func`` while we were waiting
                return False

            func()
            # only marked done once ``func`` has returned without raising
            self._done = True
            return True

opentelemetry-api/tests/trace/test_globals.py

+60-10
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import sys
12
import unittest
2-
from unittest.mock import patch
3+
from unittest.mock import Mock, patch
34

45
from opentelemetry import context, trace
6+
from opentelemetry.test.concurrency_test import ConcurrencyTestBase, MockFunc
57
from opentelemetry.trace.status import Status, StatusCode
68

79

@@ -27,23 +29,71 @@ def record_exception(
2729

2830
class TestGlobals(unittest.TestCase):
2931
def setUp(self):
30-
self._patcher = patch("opentelemetry.trace._TRACER_PROVIDER")
31-
self._mock_tracer_provider = self._patcher.start()
32+
trace._reset_globals()
3233

33-
def tearDown(self) -> None:
34-
self._patcher.stop()
34+
def tearDown(self):
35+
trace._reset_globals()
3536

36-
def test_get_tracer(self):
37+
@patch("opentelemetry.trace._TRACER_PROVIDER")
38+
def test_get_tracer(self, mock_tracer_provider): # type: ignore
3739
"""trace.get_tracer should proxy to the global tracer provider."""
3840
trace.get_tracer("foo", "var")
39-
self._mock_tracer_provider.get_tracer.assert_called_with(
40-
"foo", "var", None
41-
)
42-
mock_provider = unittest.mock.Mock()
41+
mock_tracer_provider.get_tracer.assert_called_with("foo", "var", None)
42+
mock_provider = Mock()
4343
trace.get_tracer("foo", "var", mock_provider)
4444
mock_provider.get_tracer.assert_called_with("foo", "var", None)
4545

4646

47+
class TestGlobalsConcurrency(ConcurrencyTestBase):
    """Exercises ``trace.set_tracer_provider`` from many threads at once."""

    def setUp(self):
        super().setUp()
        trace._reset_globals()

    def tearDown(self):
        super().tearDown()
        trace._reset_globals()

    @patch("opentelemetry.trace.logger")
    def test_set_tracer_provider_many_threads(self, mock_logger) -> None:  # type: ignore
        # MockFunc is used because its call counter is guarded by a lock
        mock_logger.warning = MockFunc()
        num_threads = 100

        def do_concurrently() -> Mock:
            # obtain a proxy tracer before any provider is set
            proxy_tracer = trace.ProxyTracerProvider().get_tracer("foo")

            # attempt to install this thread's own provider globally
            candidate = Mock()
            candidate.get_tracer = MockFunc()
            trace.set_tracer_provider(candidate)

            # starting a span through the proxy calls through to whichever
            # provider ended up being the global one
            proxy_tracer.start_span("foo")

            return candidate

        candidates = self.run_with_many_threads(
            do_concurrently,
            num_threads=num_threads,
        )

        # although every thread tried to set a provider, only one of them
        # should have stuck and been reached from proxy_tracer.start_span()
        used = [
            candidate
            for candidate in candidates
            if candidate.get_tracer.call_count > 0
        ]
        self.assertEqual(len(used), 1)

        winner = used[0]
        self.assertEqual(winner.get_tracer.call_count, num_threads)

        # every attempt except the successful one should have warned
        self.assertEqual(mock_logger.warning.call_count, num_threads - 1)
95+
96+
4797
class TestTracer(unittest.TestCase):
4898
def setUp(self):
4999
# pylint: disable=protected-access

opentelemetry-api/tests/trace/test_proxy.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,15 @@ class TestSpan(NonRecordingSpan):
4040

4141

4242
class TestProxy(unittest.TestCase):
43-
def test_proxy_tracer(self):
44-
original_provider = trace._TRACER_PROVIDER
43+
def setUp(self) -> None:
44+
super().setUp()
45+
trace._reset_globals()
46+
47+
def tearDown(self) -> None:
48+
super().tearDown()
49+
trace._reset_globals()
4550

51+
def test_proxy_tracer(self):
4652
provider = trace.get_tracer_provider()
4753
# proxy provider
4854
self.assertIsInstance(provider, trace.ProxyTracerProvider)
@@ -60,6 +66,9 @@ def test_proxy_tracer(self):
6066
# set a real provider
6167
trace.set_tracer_provider(TestProvider())
6268

69+
# get tracer provider immediately returns the real provider
70+
self.assertIsInstance(trace.get_tracer_provider(), TestProvider)
71+
6372
# tracer provider now returns real instance
6473
self.assertIsInstance(trace.get_tracer_provider(), TestProvider)
6574

@@ -71,5 +80,3 @@ def test_proxy_tracer(self):
7180
# creates real spans
7281
with tracer.start_span("") as span:
7382
self.assertIsInstance(span, TestSpan)
74-
75-
trace._TRACER_PROVIDER = original_provider
+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright The OpenTelemetry Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import threading
16+
17+
from opentelemetry.test.concurrency_test import ConcurrencyTestBase, MockFunc
18+
from opentelemetry.util._once import Once
19+
20+
21+
class TestOnce(ConcurrencyTestBase):
    """Checks that ``Once.do_once`` runs its function a single time."""

    def test_once_single_thread(self):
        once = Once()
        once_func = MockFunc()
        self.assertEqual(once_func.call_count, 0)

        # the initial call executes the function
        self.assertTrue(once.do_once(once_func))
        self.assertEqual(once_func.call_count, 1)

        # a repeated call leaves the count untouched
        self.assertFalse(once.do_once(once_func))
        self.assertEqual(once_func.call_count, 1)

    def test_once_many_threads(self):
        once = Once()
        once_func = MockFunc()

        results = self.run_with_many_threads(
            lambda: once.do_once(once_func), num_threads=100
        )

        self.assertEqual(once_func.call_count, 1)

        # exactly one thread should report having run the function
        self.assertEqual(results.count(True), 1)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright The OpenTelemetry Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import sys
16+
import threading
17+
import unittest
18+
from functools import partial
19+
from typing import Callable, List, Optional, TypeVar
20+
from unittest.mock import Mock
21+
22+
ReturnT = TypeVar("ReturnT")
23+
24+
25+
# Can't use Mock directly because its call count is not thread safe
26+
class MockFunc:
    """Callable stand-in that counts its invocations while holding a lock."""

    def __init__(self) -> None:
        self.mock = Mock()
        self.call_count = 0
        self.lock = threading.Lock()

    def __call__(self, *args, **kwargs):
        # arguments are accepted and ignored; the increment is done under
        # the lock so that concurrent callers are all counted
        with self.lock:
            self.call_count += 1
        return self.mock
36+
37+
38+
class ConcurrencyTestBase(unittest.TestCase):
    """TestCase base providing a helper to run a function on many threads.

    While the class's tests run, the interpreter switch interval is lowered;
    the value read at class creation is put back afterwards.
    """

    # switch interval in effect when this class body was executed
    orig_switch_interval = sys.getswitchinterval()

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()
        # switch threads more often to increase chance of contention
        sys.setswitchinterval(1e-12)

    @classmethod
    def tearDownClass(cls) -> None:
        super().tearDownClass()
        sys.setswitchinterval(cls.orig_switch_interval)

    @staticmethod
    def run_with_many_threads(
        func_to_test: Callable[[], ReturnT],
        num_threads: int = 100,
    ) -> List[ReturnT]:
        """Call ``func_to_test`` once on each of ``num_threads`` threads.

        Returns:
            The return values of ``func_to_test``, one per thread, ordered
            by thread index.
        """
        barrier = threading.Barrier(num_threads)
        results: List[Optional[ReturnT]] = [None] * num_threads

        def worker(slot: int) -> None:
            # hold every thread at the barrier so they are released
            # together, which creates contention
            barrier.wait()
            results[slot] = func_to_test()

        threads = []
        for slot in range(num_threads):
            threads.append(threading.Thread(target=worker, args=(slot,)))

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        return results  # type: ignore

tox.ini

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ setenv =
8585
; i.e: CONTRIB_REPO_SHA=dde62cebffe519c35875af6d06fae053b3be65ec tox -e <env to test>
8686
CONTRIB_REPO_SHA={env:CONTRIB_REPO_SHA:"main"}
8787
CONTRIB_REPO="git+https://github.com/open-telemetry/opentelemetry-python-contrib.git@{env:CONTRIB_REPO_SHA}"
88-
mypy: MYPYPATH={toxinidir}/opentelemetry-api/src/
88+
mypy: MYPYPATH={toxinidir}/opentelemetry-api/src/:{toxinidir}/tests/util/src/
8989

9090
changedir =
9191
api: opentelemetry-api/tests

0 commit comments

Comments
 (0)