Skip to content

Commit 72bc5d9

Browse files
committed
feat: support ThreadPool and update document
1 parent 4b1ac55 commit 72bc5d9

File tree

4 files changed

+123
-39
lines changed

4 files changed

+123
-39
lines changed

instrumentation/opentelemetry-instrumentation-threading/README.rst

+3-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ OpenTelemetry threading Instrumentation
77
:target: https://pypi.org/project/opentelemetry-instrumentation-threading/
88

99
This library provides instrumentation for the `threading` module to ensure that
10-
the OpenTelemetry context is propagated across threads.
10+
the OpenTelemetry context is propagated across threads. It is important to note
11+
that this instrumentation does not produce any telemetry data on its own. It
12+
merely ensures that the context is correctly propagated when threads are used.
1113

1214
Installation
1315
------------

instrumentation/opentelemetry-instrumentation-threading/pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ build-backend = "hatchling.build"
55
[project]
66
name = "opentelemetry-instrumentation-threading"
77
dynamic = ["version"]
8-
description = "Threading tracing for OpenTelemetry"
8+
description = "Thread context propagation support for OpenTelemetry"
99
readme = "README.rst"
1010
license = "Apache-2.0"
1111
requires-python = ">=3.8"

instrumentation/opentelemetry-instrumentation-threading/src/opentelemetry/instrumentation/threading/__init__.py

+73-15
Original file line numberDiff line numberDiff line change
@@ -24,72 +24,103 @@
2424
ThreadingInstrumentor().instrument()
2525
2626
This library provides instrumentation for the `threading` module to ensure that
27-
the OpenTelemetry context is propagated across threads.
27+
the OpenTelemetry context is propagated across threads. It is important to note
28+
that this instrumentation does not produce any telemetry data on its own. It
29+
merely ensures that the context is correctly propagated when threads are used.
2830
29-
When instrumented, new threads created using `threading.Thread` or `threading.Timer`
30-
will have the current OpenTelemetry context attached, and this context will be
31-
re-activated in the thread's run method.
31+
32+
When instrumented, new threads created using threading.Thread, threading.Timer,
33+
or within futures.ThreadPoolExecutor will have the current OpenTelemetry
34+
context attached, and this context will be re-activated in the thread's
35+
run method or the executor's worker thread."
3236
"""
3337

3438
import threading
39+
from concurrent import futures
3540
from typing import Collection
3641

3742
from wrapt import wrap_function_wrapper
3843

39-
from opentelemetry import context, trace
44+
from opentelemetry import context
4045
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
4146
from opentelemetry.instrumentation.threading.package import _instruments
4247
from opentelemetry.instrumentation.utils import unwrap
4348

4449

4550
class ThreadingInstrumentor(BaseInstrumentor):
51+
__WRAPPER_START_METHOD = "start"
52+
__WRAPPER_RUN_METHOD = "run"
53+
__WRAPPER_SUBMIT_METHOD = "submit"
54+
__WRAPPER_KWARGS = "kwargs"
55+
__WRAPPER_CONTEXT = "_otel_context"
56+
4657
def instrumentation_dependencies(self) -> Collection[str]:
4758
return _instruments
4859

4960
def _instrument(self, **kwargs):
5061
self._instrument_thread()
5162
self._instrument_timer()
63+
self._instrument_thread_pool()
5264

5365
def _uninstrument(self, **kwargs):
5466
self._uninstrument_thread()
5567
self._uninstrument_timer()
68+
self._uninstrument_thread_pool()
5669

5770
@staticmethod
5871
def _instrument_thread():
5972
wrap_function_wrapper(
6073
threading.Thread,
61-
"start",
74+
ThreadingInstrumentor.__WRAPPER_START_METHOD,
6275
ThreadingInstrumentor.__wrap_threading_start,
6376
)
6477
wrap_function_wrapper(
65-
threading.Thread, "run", ThreadingInstrumentor.__wrap_threading_run
78+
threading.Thread,
79+
ThreadingInstrumentor.__WRAPPER_RUN_METHOD,
80+
ThreadingInstrumentor.__wrap_threading_run,
6681
)
6782

6883
@staticmethod
6984
def _instrument_timer():
7085
wrap_function_wrapper(
7186
threading.Timer,
72-
"start",
87+
ThreadingInstrumentor.__WRAPPER_START_METHOD,
7388
ThreadingInstrumentor.__wrap_threading_start,
7489
)
7590
wrap_function_wrapper(
76-
threading.Timer, "run", ThreadingInstrumentor.__wrap_threading_run
91+
threading.Timer,
92+
ThreadingInstrumentor.__WRAPPER_RUN_METHOD,
93+
ThreadingInstrumentor.__wrap_threading_run,
94+
)
95+
96+
@staticmethod
97+
def _instrument_thread_pool():
98+
wrap_function_wrapper(
99+
futures.ThreadPoolExecutor,
100+
ThreadingInstrumentor.__WRAPPER_SUBMIT_METHOD,
101+
ThreadingInstrumentor.__wrap_thread_pool_submit,
77102
)
78103

79104
@staticmethod
80105
def _uninstrument_thread():
81-
unwrap(threading.Thread, "start")
82-
unwrap(threading.Thread, "run")
106+
unwrap(threading.Thread, ThreadingInstrumentor.__WRAPPER_START_METHOD)
107+
unwrap(threading.Thread, ThreadingInstrumentor.__WRAPPER_RUN_METHOD)
83108

84109
@staticmethod
85110
def _uninstrument_timer():
86-
unwrap(threading.Timer, "start")
87-
unwrap(threading.Timer, "run")
111+
unwrap(threading.Timer, ThreadingInstrumentor.__WRAPPER_START_METHOD)
112+
unwrap(threading.Timer, ThreadingInstrumentor.__WRAPPER_RUN_METHOD)
113+
114+
@staticmethod
115+
def _uninstrument_thread_pool():
116+
unwrap(
117+
futures.ThreadPoolExecutor,
118+
ThreadingInstrumentor.__WRAPPER_SUBMIT_METHOD,
119+
)
88120

89121
@staticmethod
90122
def __wrap_threading_start(call_wrapped, instance, args, kwargs):
91-
span = trace.get_current_span()
92-
instance._otel_context = trace.set_span_in_context(span)
123+
instance._otel_context = context.get_current()
93124
return call_wrapped(*args, **kwargs)
94125

95126
@staticmethod
@@ -100,3 +131,30 @@ def __wrap_threading_run(call_wrapped, instance, args, kwargs):
100131
return call_wrapped(*args, **kwargs)
101132
finally:
102133
context.detach(token)
134+
135+
@staticmethod
136+
def __wrap_thread_pool_submit(call_wrapped, instance, args, kwargs):
137+
# obtain the original function and wrapped kwargs
138+
original_func = args[0]
139+
wrapped_kwargs = {
140+
ThreadingInstrumentor.__WRAPPER_KWARGS: kwargs,
141+
ThreadingInstrumentor.__WRAPPER_CONTEXT: context.get_current(),
142+
}
143+
144+
def wrapped_func(*func_args, **func_kwargs):
145+
original_kwargs = func_kwargs.pop(
146+
ThreadingInstrumentor.__WRAPPER_KWARGS
147+
)
148+
otel_context = func_kwargs.pop(
149+
ThreadingInstrumentor.__WRAPPER_CONTEXT
150+
)
151+
token = None
152+
try:
153+
token = context.attach(otel_context)
154+
return original_func(*func_args, **original_kwargs)
155+
finally:
156+
context.detach(token)
157+
158+
# replace the original function with the wrapped function
159+
new_args = (wrapped_func,) + args[1:]
160+
return call_wrapped(*new_args, **wrapped_kwargs)

instrumentation/opentelemetry-instrumentation-threading/tests/test_threading.py

+46-22
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,27 @@
1313
# limitations under the License.
1414

1515
import threading
16-
from concurrent import futures
16+
from concurrent.futures import ThreadPoolExecutor, as_completed
17+
from dataclasses import dataclass
18+
from typing import List
1719

1820
from opentelemetry import trace
1921
from opentelemetry.instrumentation.threading import ThreadingInstrumentor
2022
from opentelemetry.test.test_base import TestBase
2123

2224

25+
@dataclass
26+
class MockContext:
27+
span_context: trace.SpanContext = None
28+
trace_id: int = None
29+
span_id: int = None
30+
31+
2332
class TestThreading(TestBase):
2433
def setUp(self):
2534
super().setUp()
2635
self._tracer = self.tracer_provider.get_tracer(__name__)
27-
self.global_context = None
28-
self.global_trace_id = None
29-
self.global_span_id = None
36+
self._mock_contexts: List[MockContext] = []
3037
ThreadingInstrumentor().instrument()
3138

3239
def tearDown(self):
@@ -54,31 +61,48 @@ def run_threading_test(self, thread: threading.Thread):
5461
thread.join()
5562

5663
# check result
57-
self.assertEqual(self.global_context, expected_context)
58-
self.assertEqual(self.global_trace_id, expected_trace_id)
59-
self.assertEqual(self.global_span_id, expected_span_id)
64+
self.assertEqual(len(self._mock_contexts), 1)
6065

61-
def test_trace_context_propagation_in_thread_pool(self):
62-
with self.get_root_span() as span:
63-
span_context = span.get_span_context()
64-
expected_context = span_context
65-
expected_trace_id = span_context.trace_id
66-
expected_span_id = span_context.span_id
66+
current_mock_context = self._mock_contexts[0]
67+
self.assertEqual(
68+
current_mock_context.span_context, expected_context
69+
)
70+
self.assertEqual(current_mock_context.trace_id, expected_trace_id)
71+
self.assertEqual(current_mock_context.span_id, expected_span_id)
6772

68-
with futures.ThreadPoolExecutor(max_workers=1) as executor:
73+
def test_trace_context_propagation_in_thread_pool(self):
74+
max_workers = 10
75+
executor = ThreadPoolExecutor(max_workers=max_workers)
76+
77+
expected_contexts: List[trace.SpanContext] = []
78+
futures_list = []
79+
for num in range(max_workers):
80+
with self._tracer.start_as_current_span(f"trace_{num}") as span:
81+
span_context = span.get_span_context()
82+
expected_contexts.append(span_context)
6983
future = executor.submit(self.fake_func)
70-
future.result()
84+
futures_list.append(future)
85+
86+
for future in as_completed(futures_list):
87+
future.result()
7188

72-
# check result
73-
self.assertEqual(self.global_context, expected_context)
74-
self.assertEqual(self.global_trace_id, expected_trace_id)
75-
self.assertEqual(self.global_span_id, expected_span_id)
89+
# check result
90+
self.assertEqual(len(self._mock_contexts), max_workers)
91+
self.assertEqual(len(self._mock_contexts), len(expected_contexts))
92+
for index, mock_context in enumerate(self._mock_contexts):
93+
span_context = expected_contexts[index]
94+
self.assertEqual(mock_context.span_context, span_context)
95+
self.assertEqual(mock_context.trace_id, span_context.trace_id)
96+
self.assertEqual(mock_context.span_id, span_context.span_id)
7697

7798
def fake_func(self):
7899
span_context = trace.get_current_span().get_span_context()
79-
self.global_context = span_context
80-
self.global_trace_id = span_context.trace_id
81-
self.global_span_id = span_context.span_id
100+
mock_context = MockContext(
101+
span_context=span_context,
102+
trace_id=span_context.trace_id,
103+
span_id=span_context.span_id,
104+
)
105+
self._mock_contexts.append(mock_context)
82106

83107
def print_square(self, num):
84108
with self._tracer.start_as_current_span("square"):

0 commit comments

Comments
 (0)