17
17
18
18
from opentelemetry import trace as trace_api
19
19
from opentelemetry .exporter .datadog import constants , propagator
20
- from opentelemetry .propagators .textmap import DictGetter
21
20
from opentelemetry .sdk import trace
22
21
from opentelemetry .sdk .trace .id_generator import RandomIdGenerator
23
22
from opentelemetry .trace import get_current_span , set_span_in_context
24
23
25
24
FORMAT = propagator .DatadogFormat ()
26
25
27
- carrier_getter = DictGetter ()
28
-
29
26
30
27
class TestDatadogFormat (unittest .TestCase ):
31
28
@classmethod
@@ -45,7 +42,6 @@ def test_malformed_headers(self):
45
42
malformed_parent_id_key = FORMAT .PARENT_ID_KEY + "-x"
46
43
context = get_current_span (
47
44
FORMAT .extract (
48
- carrier_getter ,
49
45
{
50
46
malformed_trace_id_key : self .serialized_trace_id ,
51
47
malformed_parent_id_key : self .serialized_parent_id ,
@@ -63,7 +59,7 @@ def test_missing_trace_id(self):
63
59
FORMAT .PARENT_ID_KEY : self .serialized_parent_id ,
64
60
}
65
61
66
- ctx = FORMAT .extract (carrier_getter , carrier )
62
+ ctx = FORMAT .extract (carrier )
67
63
span_context = get_current_span (ctx ).get_span_context ()
68
64
self .assertEqual (span_context .trace_id , trace_api .INVALID_TRACE_ID )
69
65
@@ -73,15 +69,14 @@ def test_missing_parent_id(self):
73
69
FORMAT .TRACE_ID_KEY : self .serialized_trace_id ,
74
70
}
75
71
76
- ctx = FORMAT .extract (carrier_getter , carrier )
72
+ ctx = FORMAT .extract (carrier )
77
73
span_context = get_current_span (ctx ).get_span_context ()
78
74
self .assertEqual (span_context .span_id , trace_api .INVALID_SPAN_ID )
79
75
80
76
def test_context_propagation (self ):
81
77
"""Test the propagation of Datadog headers."""
82
78
parent_span_context = get_current_span (
83
79
FORMAT .extract (
84
- carrier_getter ,
85
80
{
86
81
FORMAT .TRACE_ID_KEY : self .serialized_trace_id ,
87
82
FORMAT .PARENT_ID_KEY : self .serialized_parent_id ,
@@ -118,7 +113,7 @@ def test_context_propagation(self):
118
113
119
114
child_carrier = {}
120
115
child_context = set_span_in_context (child )
121
- FORMAT .inject (dict . __setitem__ , child_carrier , context = child_context )
116
+ FORMAT .inject (child_carrier , context = child_context )
122
117
123
118
self .assertEqual (
124
119
child_carrier [FORMAT .TRACE_ID_KEY ], self .serialized_trace_id
@@ -138,7 +133,6 @@ def test_sampling_priority_auto_reject(self):
138
133
"""Test sampling priority rejected."""
139
134
parent_span_context = get_current_span (
140
135
FORMAT .extract (
141
- carrier_getter ,
142
136
{
143
137
FORMAT .TRACE_ID_KEY : self .serialized_trace_id ,
144
138
FORMAT .PARENT_ID_KEY : self .serialized_parent_id ,
@@ -165,7 +159,7 @@ def test_sampling_priority_auto_reject(self):
165
159
166
160
child_carrier = {}
167
161
child_context = set_span_in_context (child )
168
- FORMAT .inject (dict . __setitem__ , child_carrier , context = child_context )
162
+ FORMAT .inject (child_carrier , context = child_context )
169
163
170
164
self .assertEqual (
171
165
child_carrier [FORMAT .SAMPLING_PRIORITY_KEY ],
@@ -178,8 +172,6 @@ def test_fields(self, mock_get_current_span):
178
172
179
173
tracer = trace .TracerProvider ().get_tracer ("sdk_tracer_provider" )
180
174
181
- mock_set_in_carrier = Mock ()
182
-
183
175
mock_get_current_span .configure_mock (
184
176
** {
185
177
"return_value" : Mock (
@@ -193,13 +185,10 @@ def test_fields(self, mock_get_current_span):
193
185
}
194
186
)
195
187
188
+ carrier = {}
189
+
196
190
with tracer .start_as_current_span ("parent" ):
197
191
with tracer .start_as_current_span ("child" ):
198
- FORMAT .inject (mock_set_in_carrier , {})
199
-
200
- inject_fields = set ()
201
-
202
- for call in mock_set_in_carrier .mock_calls :
203
- inject_fields .add (call [1 ][1 ])
192
+ FORMAT .inject (carrier )
204
193
205
- self .assertEqual (FORMAT .fields , inject_fields )
194
+ self .assertEqual (FORMAT .fields , carrier . keys () )
0 commit comments