Skip to content

Commit 58a5620

Browse files
committed
refactor: remove redundant class and simplify capture OTel context usage
1 parent 016101e commit 58a5620

File tree

2 files changed

+32
-47
lines changed

2 files changed

+32
-47
lines changed

instrumentation/opentelemetry-instrumentation-threading/src/opentelemetry/instrumentation/threading/__init__.py

+3-12
Original file line numberDiff line numberDiff line change
@@ -136,25 +136,16 @@ def __wrap_threading_run(call_wrapped, instance, args, kwargs):
136136
def __wrap_thread_pool_submit(call_wrapped, instance, args, kwargs):
137137
# obtain the original function and wrapped kwargs
138138
original_func = args[0]
139-
wrapped_kwargs = {
140-
ThreadingInstrumentor.__WRAPPER_KWARGS: kwargs,
141-
ThreadingInstrumentor.__WRAPPER_CONTEXT: context.get_current(),
142-
}
139+
otel_context = context.get_current()
143140

144141
def wrapped_func(*func_args, **func_kwargs):
145-
original_kwargs = func_kwargs.pop(
146-
ThreadingInstrumentor.__WRAPPER_KWARGS
147-
)
148-
otel_context = func_kwargs.pop(
149-
ThreadingInstrumentor.__WRAPPER_CONTEXT
150-
)
151142
token = None
152143
try:
153144
token = context.attach(otel_context)
154-
return original_func(*func_args, **original_kwargs)
145+
return original_func(*func_args, **func_kwargs)
155146
finally:
156147
context.detach(token)
157148

158149
# replace the original function with the wrapped function
159150
new_args = (wrapped_func,) + args[1:]
160-
return call_wrapped(*new_args, **wrapped_kwargs)
151+
return call_wrapped(*new_args, **kwargs)

instrumentation/opentelemetry-instrumentation-threading/tests/test_threading.py

+29-35
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,18 @@
1414

1515
import threading
1616
from concurrent.futures import ThreadPoolExecutor, as_completed
17-
from dataclasses import dataclass
1817
from typing import List
1918

2019
from opentelemetry import trace
2120
from opentelemetry.instrumentation.threading import ThreadingInstrumentor
2221
from opentelemetry.test.test_base import TestBase
2322

2423

25-
@dataclass
26-
class MockContext:
27-
span_context: trace.SpanContext = None
28-
trace_id: int = None
29-
span_id: int = None
30-
31-
3224
class TestThreading(TestBase):
3325
def setUp(self):
3426
super().setUp()
3527
self._tracer = self.tracer_provider.get_tracer(__name__)
36-
self._mock_contexts: List[MockContext] = []
28+
self._mock_span_contexts: List[trace.SpanContext] = []
3729
ThreadingInstrumentor().instrument()
3830

3931
def tearDown(self):
@@ -53,56 +45,58 @@ def test_trace_context_propagation_in_timer(self):
5345

5446
def run_threading_test(self, thread: threading.Thread):
5547
with self.get_root_span() as span:
56-
span_context = span.get_span_context()
57-
expected_context = span_context
58-
expected_trace_id = span_context.trace_id
59-
expected_span_id = span_context.span_id
48+
expected_span_context = span.get_span_context()
6049
thread.start()
6150
thread.join()
6251

6352
# check result
64-
self.assertEqual(len(self._mock_contexts), 1)
65-
66-
current_mock_context = self._mock_contexts[0]
67-
self.assertEqual(
68-
current_mock_context.span_context, expected_context
53+
self.assertEqual(len(self._mock_span_contexts), 1)
54+
self.assert_span_context_equality(
55+
self._mock_span_contexts[0], expected_span_context
6956
)
70-
self.assertEqual(current_mock_context.trace_id, expected_trace_id)
71-
self.assertEqual(current_mock_context.span_id, expected_span_id)
7257

7358
def test_trace_context_propagation_in_thread_pool(self):
7459
max_workers = 10
7560
executor = ThreadPoolExecutor(max_workers=max_workers)
7661

77-
expected_contexts: List[trace.SpanContext] = []
62+
expected_span_contexts: List[trace.SpanContext] = []
7863
futures_list = []
7964
for num in range(max_workers):
8065
with self._tracer.start_as_current_span(f"trace_{num}") as span:
81-
span_context = span.get_span_context()
82-
expected_contexts.append(span_context)
66+
expected_span_context = span.get_span_context()
67+
expected_span_contexts.append(expected_span_context)
8368
future = executor.submit(self.fake_func)
8469
futures_list.append(future)
8570

8671
for future in as_completed(futures_list):
8772
future.result()
8873

8974
# check result
90-
self.assertEqual(len(self._mock_contexts), max_workers)
91-
self.assertEqual(len(self._mock_contexts), len(expected_contexts))
92-
for index, mock_context in enumerate(self._mock_contexts):
93-
span_context = expected_contexts[index]
94-
self.assertEqual(mock_context.span_context, span_context)
95-
self.assertEqual(mock_context.trace_id, span_context.trace_id)
96-
self.assertEqual(mock_context.span_id, span_context.span_id)
75+
self.assertEqual(len(self._mock_span_contexts), max_workers)
76+
self.assertEqual(
77+
len(self._mock_span_contexts), len(expected_span_contexts)
78+
)
79+
for index, mock_span_context in enumerate(self._mock_span_contexts):
80+
self.assert_span_context_equality(
81+
mock_span_context, expected_span_contexts[index]
82+
)
9783

9884
def fake_func(self):
9985
span_context = trace.get_current_span().get_span_context()
100-
mock_context = MockContext(
101-
span_context=span_context,
102-
trace_id=span_context.trace_id,
103-
span_id=span_context.span_id,
86+
self._mock_span_contexts.append(span_context)
87+
88+
def assert_span_context_equality(
89+
self,
90+
result_span_context: trace.SpanContext,
91+
expected_span_context: trace.SpanContext,
92+
):
93+
self.assertEqual(result_span_context, expected_span_context)
94+
self.assertEqual(
95+
result_span_context.trace_id, expected_span_context.trace_id
96+
)
97+
self.assertEqual(
98+
result_span_context.span_id, expected_span_context.span_id
10499
)
105-
self._mock_contexts.append(mock_context)
106100

107101
def print_square(self, num):
108102
with self._tracer.start_as_current_span("square"):

0 commit comments

Comments
 (0)