1
1
import unittest
2
- from unittest .mock import patch
2
+ from unittest .mock import Mock , patch
3
3
4
4
from opentelemetry import context , trace
5
+ from opentelemetry .test .concurrency_test import ConcurrencyTestBase , MockFunc
5
6
from opentelemetry .trace .status import Status , StatusCode
6
7
7
8
@@ -27,23 +28,74 @@ def record_exception(
27
28
28
29
class TestGlobals (unittest .TestCase ):
29
30
def setUp (self ):
30
- self . _patcher = patch ( "opentelemetry.trace._TRACER_PROVIDER" )
31
- self . _mock_tracer_provider = self . _patcher . start ()
31
+ super (). setUp ( )
32
+ trace . _reset_globals () # pylint: disable=protected-access
32
33
33
- def tearDown (self ) -> None :
34
- self ._patcher .stop ()
34
+ def tearDown (self ):
35
+ super ().tearDown ()
36
+ trace ._reset_globals () # pylint: disable=protected-access
35
37
36
- def test_get_tracer (self ):
38
+ @staticmethod
39
+ @patch ("opentelemetry.trace._TRACER_PROVIDER" )
40
+ def test_get_tracer (mock_tracer_provider ): # type: ignore
37
41
"""trace.get_tracer should proxy to the global tracer provider."""
38
42
trace .get_tracer ("foo" , "var" )
39
- self ._mock_tracer_provider .get_tracer .assert_called_with (
40
- "foo" , "var" , None
41
- )
42
- mock_provider = unittest .mock .Mock ()
43
+ mock_tracer_provider .get_tracer .assert_called_with ("foo" , "var" , None )
44
+ mock_provider = Mock ()
43
45
trace .get_tracer ("foo" , "var" , mock_provider )
44
46
mock_provider .get_tracer .assert_called_with ("foo" , "var" , None )
45
47
46
48
49
+ class TestGlobalsConcurrency (ConcurrencyTestBase ):
50
+ def setUp (self ):
51
+ super ().setUp ()
52
+ trace ._reset_globals () # pylint: disable=protected-access
53
+
54
+ def tearDown (self ):
55
+ super ().tearDown ()
56
+ trace ._reset_globals () # pylint: disable=protected-access
57
+
58
+ @patch ("opentelemetry.trace.logger" )
59
+ def test_set_tracer_provider_many_threads (self , mock_logger ) -> None : # type: ignore
60
+ mock_logger .warning = MockFunc ()
61
+
62
+ def do_concurrently () -> Mock :
63
+ # first get a proxy tracer
64
+ proxy_tracer = trace .ProxyTracerProvider ().get_tracer ("foo" )
65
+
66
+ # try to set the global tracer provider
67
+ mock_tracer_provider = Mock (get_tracer = MockFunc ())
68
+ trace .set_tracer_provider (mock_tracer_provider )
69
+
70
+ # start a span through the proxy which will call through to the mock provider
71
+ proxy_tracer .start_span ("foo" )
72
+
73
+ return mock_tracer_provider
74
+
75
+ num_threads = 100
76
+ mock_tracer_providers = self .run_with_many_threads (
77
+ do_concurrently ,
78
+ num_threads = num_threads ,
79
+ )
80
+
81
+ # despite trying to set tracer provider many times, only one of the
82
+ # mock_tracer_providers should have stuck and been called from
83
+ # proxy_tracer.start_span()
84
+ mock_tps_with_any_call = [
85
+ mock
86
+ for mock in mock_tracer_providers
87
+ if mock .get_tracer .call_count > 0
88
+ ]
89
+
90
+ self .assertEqual (len (mock_tps_with_any_call ), 1 )
91
+ self .assertEqual (
92
+ mock_tps_with_any_call [0 ].get_tracer .call_count , num_threads
93
+ )
94
+
95
+ # should have warned everytime except for the successful set
96
+ self .assertEqual (mock_logger .warning .call_count , num_threads - 1 )
97
+
98
+
47
99
class TestTracer (unittest .TestCase ):
48
100
def setUp (self ):
49
101
# pylint: disable=protected-access
0 commit comments