Skip to content

Commit 06c561e

Browse files
committed
Allow to override global tracer_provider after tracers creation
1 parent 00578e3 commit 06c561e

File tree

3 files changed

+263
-25
lines changed

3 files changed

+263
-25
lines changed

opentelemetry-api/src/opentelemetry/trace/__init__.py

+81-5
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676

7777
import abc
7878
import enum
79+
import functools
7980
import typing
8081
from contextlib import contextmanager
8182
from logging import getLogger
@@ -409,9 +410,80 @@ def use_span(
409410
yield
410411

411412

413+
class ProxyTracer(Tracer):
414+
"""Proxies all calls to current TracerProvider
415+
"""
416+
417+
def __init__(
418+
self, get_current_tracer: typing.Callable[[], "Tracer"],
419+
) -> None:
420+
self._get_current_tracer = get_current_tracer
421+
422+
def start_span(
423+
self,
424+
name: str,
425+
context: typing.Optional[Context] = None,
426+
kind: SpanKind = SpanKind.INTERNAL,
427+
attributes: types.Attributes = None,
428+
links: typing.Sequence[Link] = (),
429+
start_time: typing.Optional[int] = None,
430+
record_exception: bool = True,
431+
set_status_on_exception: bool = True,
432+
) -> "Span":
433+
return self._get_current_tracer().start_span(
434+
name=name,
435+
context=context,
436+
kind=kind,
437+
attributes=attributes,
438+
links=links,
439+
start_time=start_time,
440+
record_exception=record_exception,
441+
set_status_on_exception=set_status_on_exception,
442+
)
443+
444+
@contextmanager # type: ignore
445+
def start_as_current_span(
446+
self,
447+
name: str,
448+
context: typing.Optional[Context] = None,
449+
kind: SpanKind = SpanKind.INTERNAL,
450+
attributes: types.Attributes = None,
451+
links: typing.Sequence[Link] = (),
452+
start_time: typing.Optional[int] = None,
453+
record_exception: bool = True,
454+
set_status_on_exception: bool = True,
455+
) -> typing.Iterator["Span"]:
456+
with self._get_current_tracer().start_as_current_span(
457+
name=name,
458+
context=context,
459+
kind=kind,
460+
attributes=attributes,
461+
links=links,
462+
start_time=start_time,
463+
record_exception=record_exception,
464+
set_status_on_exception=set_status_on_exception,
465+
) as span:
466+
yield span
467+
468+
@contextmanager # type: ignore
469+
def use_span(
470+
self, span: "Span", end_on_exit: bool = False,
471+
) -> typing.Iterator[None]:
472+
with self._get_current_tracer().use_span(
473+
span=span, end_on_exit=end_on_exit,
474+
) as context:
475+
yield context
476+
477+
412478
_TRACER_PROVIDER = None
413479

414480

481+
@functools.lru_cache # type: ignore
482+
def _get_current_tracer(*args: typing.Any, **kwargs: typing.Any,) -> "Tracer":
483+
tracer_provider = get_tracer_provider()
484+
return tracer_provider.get_tracer(*args, **kwargs) # type: ignore
485+
486+
415487
def get_tracer(
416488
instrumenting_module_name: str,
417489
instrumenting_library_version: str = "",
@@ -425,7 +497,14 @@ def get_tracer(
425497
If tracer_provider is ommited the current configured one is used.
426498
"""
427499
if tracer_provider is None:
428-
tracer_provider = get_tracer_provider()
500+
return ProxyTracer(
501+
functools.partial(
502+
_get_current_tracer,
503+
instrumenting_module_name=instrumenting_module_name,
504+
instrumenting_library_version=instrumenting_library_version,
505+
)
506+
)
507+
429508
return tracer_provider.get_tracer(
430509
instrumenting_module_name, instrumenting_library_version
431510
)
@@ -439,11 +518,8 @@ def set_tracer_provider(tracer_provider: TracerProvider) -> None:
439518
"""
440519
global _TRACER_PROVIDER # pylint: disable=global-statement
441520

442-
if _TRACER_PROVIDER is not None:
443-
logger.warning("Overriding of current TracerProvider is not allowed")
444-
return
445-
446521
_TRACER_PROVIDER = tracer_provider
522+
_get_current_tracer.cache_clear() # pylint: disable=no-member
447523

448524

449525
def get_tracer_provider() -> TracerProvider:

opentelemetry-api/tests/trace/test_globals.py

+174-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import unittest
2-
from unittest.mock import patch
2+
from unittest.mock import MagicMock, patch
33

44
from opentelemetry import context, trace
55

@@ -13,13 +13,77 @@ def tearDown(self) -> None:
1313
self._patcher.stop()
1414

1515
def test_get_tracer(self):
16-
"""trace.get_tracer should proxy to the global tracer provider."""
17-
trace.get_tracer("foo", "var")
18-
self._mock_tracer_provider.get_tracer.assert_called_with("foo", "var")
16+
"""trace.get_tracer should create a proxy to the global tracer provider."""
17+
tracer = trace.get_tracer("foo", "var")
18+
self._mock_tracer_provider.get_tracer.assert_not_called()
19+
self.assertIsInstance(tracer, trace.ProxyTracer)
20+
21+
tracer.start_span("one")
22+
tracer.start_span("two")
23+
self._mock_tracer_provider.get_tracer.assert_called_once_with(
24+
instrumenting_module_name="foo",
25+
instrumenting_library_version="var",
26+
)
27+
1928
mock_provider = unittest.mock.Mock()
2029
trace.get_tracer("foo", "var", mock_provider)
2130
mock_provider.get_tracer.assert_called_with("foo", "var")
2231

32+
def test_set_tracer_provider(self):
33+
"""trace.get_tracer should update global tracer provider."""
34+
self.assertIs(
35+
trace._TRACER_PROVIDER, # pylint: disable=protected-access
36+
self._mock_tracer_provider,
37+
)
38+
39+
tracer_provider1 = trace.DefaultTracerProvider()
40+
trace.set_tracer_provider(tracer_provider1)
41+
self.assertIs(
42+
trace._TRACER_PROVIDER, # pylint: disable=protected-access
43+
tracer_provider1,
44+
)
45+
46+
tracer_provider2 = trace.DefaultTracerProvider()
47+
trace.set_tracer_provider(tracer_provider2)
48+
self.assertIs(
49+
trace._TRACER_PROVIDER, # pylint: disable=protected-access
50+
tracer_provider2,
51+
)
52+
53+
@patch("opentelemetry.trace._load_trace_provider")
54+
def test_get_tracer_provider(self, load_trace_provider_mock):
55+
"""trace.get_tracer should get or create a global tracer provider."""
56+
load_trace_provider_mock.assert_not_called()
57+
58+
tracer_provider = trace.get_tracer_provider()
59+
self.assertIs(
60+
trace._TRACER_PROVIDER, # pylint: disable=protected-access
61+
tracer_provider,
62+
)
63+
load_trace_provider_mock.assert_not_called()
64+
65+
trace._TRACER_PROVIDER = None # pylint: disable=protected-access
66+
tracer_provider1 = trace.get_tracer_provider()
67+
self.assertIsNotNone(
68+
trace._TRACER_PROVIDER # pylint: disable=protected-access
69+
)
70+
self.assertIs(
71+
trace._TRACER_PROVIDER, # pylint: disable=protected-access
72+
tracer_provider1,
73+
)
74+
load_trace_provider_mock.assert_called_once_with("tracer_provider")
75+
76+
tracer_provider2 = trace.get_tracer_provider()
77+
self.assertIsNotNone(
78+
trace._TRACER_PROVIDER # pylint: disable=protected-access
79+
)
80+
self.assertIs(
81+
trace._TRACER_PROVIDER, # pylint: disable=protected-access
82+
tracer_provider2,
83+
)
84+
self.assertIs(tracer_provider1, tracer_provider2)
85+
load_trace_provider_mock.assert_called_once_with("tracer_provider")
86+
2387

2488
class TestTracer(unittest.TestCase):
2589
def setUp(self):
@@ -38,3 +102,109 @@ def test_get_current_span(self):
38102
finally:
39103
context.detach(token)
40104
self.assertEqual(trace.get_current_span(), trace.INVALID_SPAN)
105+
106+
107+
class TestProxyTracer(unittest.TestCase):
108+
def setUp(self):
109+
self._patcher = patch("opentelemetry.trace._TRACER_PROVIDER")
110+
self._patcher.start()
111+
112+
self.proxy_tracer = trace.get_tracer("foo", "var")
113+
self.inner_tracer = MagicMock(wraps=trace.DefaultTracer())
114+
115+
tracer_provider = MagicMock(wraps=trace.DefaultTracerProvider())
116+
tracer_provider.get_tracer.return_value = self.inner_tracer
117+
trace.set_tracer_provider(tracer_provider)
118+
119+
def tearDown(self):
120+
self._patcher.stop()
121+
122+
def test_start_span(self):
123+
"""ProxyTracer should call `start_span` on a real `Tracer`
124+
"""
125+
self.inner_tracer.start_span.assert_not_called()
126+
127+
span = self.proxy_tracer.start_span("span1")
128+
self.assertIs(span, trace.INVALID_SPAN)
129+
self.inner_tracer.start_span.assert_called_once_with(
130+
name="span1",
131+
context=None,
132+
kind=trace.SpanKind.INTERNAL,
133+
attributes=None,
134+
links=(),
135+
start_time=None,
136+
record_exception=True,
137+
set_status_on_exception=True,
138+
)
139+
140+
def test_start_as_current_span(self):
141+
"""ProxyTracer should call `start_as_current_span` on a real `Tracer`
142+
"""
143+
self.inner_tracer.start_as_current_span.assert_not_called()
144+
145+
with self.proxy_tracer.start_as_current_span("span1") as span:
146+
self.assertIs(span, trace.INVALID_SPAN)
147+
self.inner_tracer.start_as_current_span.assert_called_once_with(
148+
name="span1",
149+
context=None,
150+
kind=trace.SpanKind.INTERNAL,
151+
attributes=None,
152+
links=(),
153+
start_time=None,
154+
record_exception=True,
155+
set_status_on_exception=True,
156+
)
157+
158+
self.inner_tracer.start_as_current_span.reset_mock()
159+
160+
@self.proxy_tracer.start_as_current_span("span1")
161+
def func(arg, kwarg):
162+
self.assertEqual(arg, "argval")
163+
self.assertEqual(kwarg, "kwargval")
164+
return "retval"
165+
166+
self.inner_tracer.start_as_current_span.assert_not_called()
167+
168+
result = func("argval", kwarg="kwargval")
169+
self.assertEqual(result, "retval")
170+
171+
self.inner_tracer.start_as_current_span.assert_called_once_with(
172+
name="span1",
173+
context=None,
174+
kind=trace.SpanKind.INTERNAL,
175+
attributes=None,
176+
links=(),
177+
start_time=None,
178+
record_exception=True,
179+
set_status_on_exception=True,
180+
)
181+
182+
def test_use_span(self):
183+
"""ProxyTracer should call `use_span` on a real `Tracer`
184+
"""
185+
self.inner_tracer.use_span.assert_not_called()
186+
187+
span = self.proxy_tracer.start_span("span1")
188+
189+
with self.proxy_tracer.use_span(span) as current_context:
190+
self.assertIsNone(current_context)
191+
self.inner_tracer.use_span.assert_called_once_with(
192+
span=span, end_on_exit=False,
193+
)
194+
195+
self.inner_tracer.use_span.reset_mock()
196+
197+
@self.proxy_tracer.use_span(span)
198+
def func(arg, kwarg):
199+
self.assertEqual(arg, "argval")
200+
self.assertEqual(kwarg, "kwargval")
201+
return "retval"
202+
203+
self.inner_tracer.use_span.assert_not_called()
204+
205+
result = func("argval", kwarg="kwargval")
206+
self.assertEqual(result, "retval")
207+
208+
self.inner_tracer.use_span.assert_called_once_with(
209+
span=span, end_on_exit=False,
210+
)
+8-16
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,17 @@
11
# type:ignore
22
import unittest
3-
from logging import WARNING
43

54
from opentelemetry import trace
65
from opentelemetry.sdk.trace import TracerProvider # type:ignore
76

87

98
class TestGlobals(unittest.TestCase):
10-
def test_tracer_provider_override_warning(self):
11-
"""trace.set_tracer_provider should throw a warning when overridden"""
12-
trace.set_tracer_provider(TracerProvider())
13-
tracer_provider = trace.get_tracer_provider()
14-
with self.assertLogs(level=WARNING) as test:
15-
trace.set_tracer_provider(TracerProvider())
16-
self.assertEqual(
17-
test.output,
18-
[
19-
(
20-
"WARNING:opentelemetry.trace:Overriding of current "
21-
"TracerProvider is not allowed"
22-
)
23-
],
24-
)
9+
def test_tracer_provider_override(self):
10+
"""trace.set_tracer_provider should override global tracer"""
11+
tracer_provider = TracerProvider()
12+
trace.set_tracer_provider(tracer_provider)
2513
self.assertIs(tracer_provider, trace.get_tracer_provider())
14+
15+
new_tracer_provider = TracerProvider()
16+
trace.set_tracer_provider(new_tracer_provider)
17+
self.assertIs(new_tracer_provider, trace.get_tracer_provider())

0 commit comments

Comments
 (0)