14
14
15
15
import threading
16
16
from concurrent .futures import ThreadPoolExecutor , as_completed
17
- from dataclasses import dataclass
18
17
from typing import List
19
18
20
19
from opentelemetry import trace
21
20
from opentelemetry .instrumentation .threading import ThreadingInstrumentor
22
21
from opentelemetry .test .test_base import TestBase
23
22
24
23
25
- @dataclass
26
- class MockContext :
27
- span_context : trace .SpanContext = None
28
- trace_id : int = None
29
- span_id : int = None
30
-
31
-
32
24
class TestThreading (TestBase ):
33
25
def setUp (self ):
34
26
super ().setUp ()
35
27
self ._tracer = self .tracer_provider .get_tracer (__name__ )
36
- self ._mock_contexts : List [MockContext ] = []
28
+ self ._mock_span_contexts : List [trace . SpanContext ] = []
37
29
ThreadingInstrumentor ().instrument ()
38
30
39
31
def tearDown (self ):
@@ -53,56 +45,58 @@ def test_trace_context_propagation_in_timer(self):
53
45
54
46
def run_threading_test (self , thread : threading .Thread ):
55
47
with self .get_root_span () as span :
56
- span_context = span .get_span_context ()
57
- expected_context = span_context
58
- expected_trace_id = span_context .trace_id
59
- expected_span_id = span_context .span_id
48
+ expected_span_context = span .get_span_context ()
60
49
thread .start ()
61
50
thread .join ()
62
51
63
52
# check result
64
- self .assertEqual (len (self ._mock_contexts ), 1 )
65
-
66
- current_mock_context = self ._mock_contexts [0 ]
67
- self .assertEqual (
68
- current_mock_context .span_context , expected_context
53
+ self .assertEqual (len (self ._mock_span_contexts ), 1 )
54
+ self .assert_span_context_equality (
55
+ self ._mock_span_contexts [0 ], expected_span_context
69
56
)
70
- self .assertEqual (current_mock_context .trace_id , expected_trace_id )
71
- self .assertEqual (current_mock_context .span_id , expected_span_id )
72
57
73
58
def test_trace_context_propagation_in_thread_pool (self ):
74
59
max_workers = 10
75
60
executor = ThreadPoolExecutor (max_workers = max_workers )
76
61
77
- expected_contexts : List [trace .SpanContext ] = []
62
+ expected_span_contexts : List [trace .SpanContext ] = []
78
63
futures_list = []
79
64
for num in range (max_workers ):
80
65
with self ._tracer .start_as_current_span (f"trace_{ num } " ) as span :
81
- span_context = span .get_span_context ()
82
- expected_contexts .append (span_context )
66
+ expected_span_context = span .get_span_context ()
67
+ expected_span_contexts .append (expected_span_context )
83
68
future = executor .submit (self .fake_func )
84
69
futures_list .append (future )
85
70
86
71
for future in as_completed (futures_list ):
87
72
future .result ()
88
73
89
74
# check result
90
- self .assertEqual (len (self ._mock_contexts ), max_workers )
91
- self .assertEqual (len (self ._mock_contexts ), len (expected_contexts ))
92
- for index , mock_context in enumerate (self ._mock_contexts ):
93
- span_context = expected_contexts [index ]
94
- self .assertEqual (mock_context .span_context , span_context )
95
- self .assertEqual (mock_context .trace_id , span_context .trace_id )
96
- self .assertEqual (mock_context .span_id , span_context .span_id )
75
+ self .assertEqual (len (self ._mock_span_contexts ), max_workers )
76
+ self .assertEqual (
77
+ len (self ._mock_span_contexts ), len (expected_span_contexts )
78
+ )
79
+ for index , mock_span_context in enumerate (self ._mock_span_contexts ):
80
+ self .assert_span_context_equality (
81
+ mock_span_context , expected_span_contexts [index ]
82
+ )
97
83
98
84
def fake_func (self ):
99
85
span_context = trace .get_current_span ().get_span_context ()
100
- mock_context = MockContext (
101
- span_context = span_context ,
102
- trace_id = span_context .trace_id ,
103
- span_id = span_context .span_id ,
86
+ self ._mock_span_contexts .append (span_context )
87
+
88
+ def assert_span_context_equality (
89
+ self ,
90
+ result_span_context : trace .SpanContext ,
91
+ expected_span_context : trace .SpanContext ,
92
+ ):
93
+ self .assertEqual (result_span_context , expected_span_context )
94
+ self .assertEqual (
95
+ result_span_context .trace_id , expected_span_context .trace_id
96
+ )
97
+ self .assertEqual (
98
+ result_span_context .span_id , expected_span_context .span_id
104
99
)
105
- self ._mock_contexts .append (mock_context )
106
100
107
101
def print_square (self , num ):
108
102
with self ._tracer .start_as_current_span ("square" ):
0 commit comments