13
13
# limitations under the License.
14
14
15
15
import uuid
16
- from typing import List , Tuple
16
+ from typing import List , Sequence , Tuple
17
17
from unittest import IsolatedAsyncioTestCase , mock
18
18
19
19
from aiokafka import (
24
24
)
25
25
from wrapt import BoundFunctionWrapper
26
26
27
+ from opentelemetry import baggage , context
27
28
from opentelemetry .instrumentation .aiokafka import AIOKafkaInstrumentor
28
- from opentelemetry .sdk .trace import Span
29
+ from opentelemetry .sdk .trace import ReadableSpan
29
30
from opentelemetry .semconv ._incubating .attributes import messaging_attributes
30
31
from opentelemetry .semconv .attributes import server_attributes
31
32
from opentelemetry .test .test_base import TestBase
32
- from opentelemetry .trace import SpanKind , format_trace_id
33
+ from opentelemetry .trace import SpanKind , format_trace_id , set_span_in_context
33
34
34
35
35
36
class TestAIOKafka (TestBase , IsolatedAsyncioTestCase ):
@@ -51,6 +52,19 @@ def consumer_record_factory(
51
52
headers = headers ,
52
53
)
53
54
55
+ @staticmethod
56
+ def producer_factory () -> AIOKafkaProducer :
57
+ producer = AIOKafkaProducer (api_version = "1.0" )
58
+
59
+ add_message_mock = mock .AsyncMock ()
60
+ producer .client ._wait_on_metadata = mock .AsyncMock ()
61
+ producer .client .bootstrap = mock .AsyncMock ()
62
+ producer ._message_accumulator .add_message = add_message_mock
63
+ producer ._sender .start = mock .AsyncMock ()
64
+ producer ._partition = mock .Mock (return_value = 1 )
65
+
66
+ return producer
67
+
54
68
def test_instrument_api (self ) -> None :
55
69
instrumentation = AIOKafkaInstrumentor ()
56
70
@@ -147,7 +161,46 @@ async def test_anext(self) -> None:
147
161
span_list = self .memory_exporter .get_finished_spans ()
148
162
self ._compare_spans (span_list , expected_spans )
149
163
150
- async def test_anext_consumer_hook (self ) -> None :
164
+ async def test_anext_baggage (self ) -> None :
165
+ received_baggage = None
166
+
167
+ async def async_consume_hook (span , * _ ) -> None :
168
+ nonlocal received_baggage
169
+ received_baggage = baggage .get_all (set_span_in_context (span ))
170
+
171
+ AIOKafkaInstrumentor ().uninstrument ()
172
+ AIOKafkaInstrumentor ().instrument (
173
+ tracer_provider = self .tracer_provider ,
174
+ async_consume_hook = async_consume_hook ,
175
+ )
176
+
177
+ consumer = AIOKafkaConsumer ()
178
+
179
+ self .memory_exporter .clear ()
180
+
181
+ getone_mock = mock .AsyncMock ()
182
+ consumer .getone = getone_mock
183
+
184
+ getone_mock .side_effect = [
185
+ self .consumer_record_factory (
186
+ 1 ,
187
+ headers = (
188
+ (
189
+ "traceparent" ,
190
+ b"00-03afa25236b8cd948fa853d67038ac79-405ff022e8247c46-01" ,
191
+ ),
192
+ ("baggage" , b"foo=bar" ),
193
+ ),
194
+ ),
195
+ self .consumer_record_factory (2 , headers = ()),
196
+ ]
197
+
198
+ await consumer .__anext__ ()
199
+ getone_mock .assert_awaited_with ()
200
+
201
+ self .assertEqual (received_baggage , {"foo" : "bar" })
202
+
203
+ async def test_anext_consume_hook (self ) -> None :
151
204
async_consume_hook_mock = mock .AsyncMock ()
152
205
153
206
AIOKafkaInstrumentor ().uninstrument ()
@@ -171,14 +224,10 @@ async def test_send(self) -> None:
171
224
AIOKafkaInstrumentor ().uninstrument ()
172
225
AIOKafkaInstrumentor ().instrument (tracer_provider = self .tracer_provider )
173
226
174
- producer = AIOKafkaProducer (api_version = "1.0" )
175
-
176
- add_message_mock = mock .AsyncMock ()
177
- producer .client ._wait_on_metadata = mock .AsyncMock ()
178
- producer .client .bootstrap = mock .AsyncMock ()
179
- producer ._message_accumulator .add_message = add_message_mock
180
- producer ._sender .start = mock .AsyncMock ()
181
- producer ._partition = mock .Mock (return_value = 1 )
227
+ producer = self .producer_factory ()
228
+ add_message_mock : mock .AsyncMock = (
229
+ producer ._message_accumulator .add_message
230
+ )
182
231
183
232
await producer .start ()
184
233
@@ -208,6 +257,33 @@ async def test_send(self) -> None:
208
257
headers = [("traceparent" , mock .ANY )],
209
258
)
210
259
260
+ async def test_send_baggage (self ) -> None :
261
+ AIOKafkaInstrumentor ().uninstrument ()
262
+ AIOKafkaInstrumentor ().instrument (tracer_provider = self .tracer_provider )
263
+
264
+ producer = self .producer_factory ()
265
+ add_message_mock : mock .AsyncMock = (
266
+ producer ._message_accumulator .add_message
267
+ )
268
+
269
+ await producer .start ()
270
+
271
+ tracer = self .tracer_provider .get_tracer (__name__ )
272
+ ctx = baggage .set_baggage ("foo" , "bar" )
273
+ context .attach (ctx )
274
+
275
+ with tracer .start_as_current_span ("test_span" , context = ctx ):
276
+ await producer .send ("topic_1" , b"value_1" )
277
+
278
+ add_message_mock .assert_awaited_with (
279
+ TopicPartition (topic = "topic_1" , partition = 1 ),
280
+ None ,
281
+ b"value_1" ,
282
+ 40.0 ,
283
+ timestamp_ms = None ,
284
+ headers = [("traceparent" , mock .ANY ), ("baggage" , b"foo=bar" )],
285
+ )
286
+
211
287
async def test_send_produce_hook (self ) -> None :
212
288
async_produce_hook_mock = mock .AsyncMock ()
213
289
@@ -217,13 +293,7 @@ async def test_send_produce_hook(self) -> None:
217
293
async_produce_hook = async_produce_hook_mock ,
218
294
)
219
295
220
- producer = AIOKafkaProducer (api_version = "1.0" )
221
-
222
- producer .client ._wait_on_metadata = mock .AsyncMock ()
223
- producer .client .bootstrap = mock .AsyncMock ()
224
- producer ._message_accumulator .add_message = mock .AsyncMock ()
225
- producer ._sender .start = mock .AsyncMock ()
226
- producer ._partition = mock .Mock (return_value = 1 )
296
+ producer = self .producer_factory ()
227
297
228
298
await producer .start ()
229
299
@@ -232,7 +302,7 @@ async def test_send_produce_hook(self) -> None:
232
302
async_produce_hook_mock .assert_awaited_once ()
233
303
234
304
def _compare_spans (
235
- self , spans : List [ Span ], expected_spans : List [dict ]
305
+ self , spans : Sequence [ ReadableSpan ], expected_spans : List [dict ]
236
306
) -> None :
237
307
self .assertEqual (len (spans ), len (expected_spans ))
238
308
for span , expected_span in zip (spans , expected_spans ):
0 commit comments