13
13
# limitations under the License.
14
14
15
15
import threading
16
- from concurrent .futures import ThreadPoolExecutor , as_completed
16
+ from concurrent .futures import ThreadPoolExecutor
17
17
from typing import List
18
18
19
19
from opentelemetry import trace
@@ -51,11 +51,13 @@ def run_threading_test(self, thread: threading.Thread):
51
51
52
52
# check result
53
53
self .assertEqual (len (self ._mock_span_contexts ), 1 )
54
- self .assert_span_context_equality (
54
+ self .assertEqual (
55
55
self ._mock_span_contexts [0 ], expected_span_context
56
56
)
57
57
58
- def test_trace_context_propagation_in_thread_pool (self ):
58
+ def test_trace_context_propagation_in_thread_pool_with_multiple_workers (
59
+ self ,
60
+ ):
59
61
max_workers = 10
60
62
executor = ThreadPoolExecutor (max_workers = max_workers )
61
63
@@ -65,38 +67,65 @@ def test_trace_context_propagation_in_thread_pool(self):
65
67
with self ._tracer .start_as_current_span (f"trace_{ num } " ) as span :
66
68
expected_span_context = span .get_span_context ()
67
69
expected_span_contexts .append (expected_span_context )
68
- future = executor .submit (self .fake_func )
70
+ future = executor .submit (
71
+ self .get_current_span_context_for_test
72
+ )
69
73
futures_list .append (future )
70
74
71
- for future in as_completed (futures_list ):
72
- future .result ()
75
+ result_span_contexts = [future .result () for future in futures_list ]
73
76
74
77
# check result
75
- self .assertEqual (len (self . _mock_span_contexts ), max_workers )
78
+ self .assertEqual (len (result_span_contexts ), max_workers )
76
79
self .assertEqual (
77
- len (self . _mock_span_contexts ), len (expected_span_contexts )
80
+ len (result_span_contexts ), len (expected_span_contexts )
78
81
)
79
- for index , mock_span_context in enumerate (self . _mock_span_contexts ):
80
- self .assert_span_context_equality (
81
- mock_span_context , expected_span_contexts [index ]
82
+ for index , result_span_context in enumerate (result_span_contexts ):
83
+ self .assertEqual (
84
+ result_span_context , expected_span_contexts [index ]
82
85
)
83
86
84
- def fake_func (self ):
85
- span_context = trace .get_current_span ().get_span_context ()
87
+ def test_trace_context_propagation_in_thread_pool_with_single_worker (self ):
88
+ max_workers = 1
89
+ with ThreadPoolExecutor (max_workers = max_workers ) as executor :
90
+ # test propagation of the same trace context across multiple tasks
91
+ with self ._tracer .start_as_current_span (f"task" ) as task_span :
92
+ expected_task_context = task_span .get_span_context ()
93
+ future1 = executor .submit (
94
+ self .get_current_span_context_for_test
95
+ )
96
+ future2 = executor .submit (
97
+ self .get_current_span_context_for_test
98
+ )
99
+
100
+ # check result
101
+ self .assertEqual (future1 .result (), expected_task_context )
102
+ self .assertEqual (future2 .result (), expected_task_context )
103
+
104
+ # test propagation of different trace contexts across tasks in sequence
105
+ with self ._tracer .start_as_current_span (f"task1" ) as task1_span :
106
+ expected_task1_context = task1_span .get_span_context ()
107
+ future1 = executor .submit (
108
+ self .get_current_span_context_for_test
109
+ )
110
+
111
+ # check result
112
+ self .assertEqual (future1 .result (), expected_task1_context )
113
+
114
+ with self ._tracer .start_as_current_span (f"task2" ) as task2_span :
115
+ expected_task2_context = task2_span .get_span_context ()
116
+ future2 = executor .submit (
117
+ self .get_current_span_context_for_test
118
+ )
119
+
120
+ # check result
121
+ self .assertEqual (future2 .result (), expected_task2_context )
122
+
123
+ def fake_func (self ) -> trace .SpanContext :
124
+ span_context = self .get_current_span_context_for_test ()
86
125
self ._mock_span_contexts .append (span_context )
87
126
88
- def assert_span_context_equality (
89
- self ,
90
- result_span_context : trace .SpanContext ,
91
- expected_span_context : trace .SpanContext ,
92
- ):
93
- self .assertEqual (result_span_context , expected_span_context )
94
- self .assertEqual (
95
- result_span_context .trace_id , expected_span_context .trace_id
96
- )
97
- self .assertEqual (
98
- result_span_context .span_id , expected_span_context .span_id
99
- )
127
+ def get_current_span_context_for_test (self ) -> trace .SpanContext :
128
+ return trace .get_current_span ().get_span_context ()
100
129
101
130
def print_square (self , num ):
102
131
with self ._tracer .start_as_current_span ("square" ):
0 commit comments