11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
- from unittest import TestCase
15
14
16
- from aiokafka import AIOKafkaConsumer , AIOKafkaProducer
15
+ import uuid
16
+ from typing import List , Tuple
17
+ from unittest import IsolatedAsyncioTestCase , mock
18
+
19
+ from aiokafka import (
20
+ AIOKafkaConsumer ,
21
+ AIOKafkaProducer ,
22
+ ConsumerRecord ,
23
+ TopicPartition ,
24
+ )
17
25
from wrapt import BoundFunctionWrapper
18
26
19
27
from opentelemetry .instrumentation .aiokafka import AIOKafkaInstrumentor
28
+ from opentelemetry .sdk .trace import Span
29
+ from opentelemetry .semconv ._incubating .attributes import messaging_attributes
30
+ from opentelemetry .semconv .attributes import server_attributes
31
+ from opentelemetry .test .test_base import TestBase
32
+ from opentelemetry .trace import SpanKind , format_trace_id
33
+
20
34
35
+ class TestAIOKafka (TestBase , IsolatedAsyncioTestCase ):
36
+ @staticmethod
37
+ def consumer_record_factory (
38
+ number : int , headers : Tuple [Tuple [str , bytes ], ...]
39
+ ) -> ConsumerRecord :
40
+ return ConsumerRecord (
41
+ f"topic_{ number } " ,
42
+ number ,
43
+ number ,
44
+ number ,
45
+ number ,
46
+ f"key_{ number } " .encode (),
47
+ f"value_{ number } " .encode (),
48
+ None ,
49
+ number ,
50
+ number ,
51
+ headers = headers ,
52
+ )
21
53
22
- class TestAIOKafka (TestCase ):
23
54
def test_instrument_api (self ) -> None :
24
55
instrumentation = AIOKafkaInstrumentor ()
25
56
@@ -38,3 +69,175 @@ def test_instrument_api(self) -> None:
38
69
self .assertFalse (
39
70
isinstance (AIOKafkaConsumer .__anext__ , BoundFunctionWrapper )
40
71
)
72
+
73
+ async def test_anext (self ) -> None :
74
+ AIOKafkaInstrumentor ().uninstrument ()
75
+ AIOKafkaInstrumentor ().instrument (tracer_provider = self .tracer_provider )
76
+
77
+ client_id = str (uuid .uuid4 ())
78
+ group_id = str (uuid .uuid4 ())
79
+ consumer = AIOKafkaConsumer (client_id = client_id , group_id = group_id )
80
+
81
+ expected_spans = [
82
+ {
83
+ "name" : "topic_1 receive" ,
84
+ "kind" : SpanKind .CONSUMER ,
85
+ "attributes" : {
86
+ messaging_attributes .MESSAGING_SYSTEM : messaging_attributes .MessagingSystemValues .KAFKA .value ,
87
+ server_attributes .SERVER_ADDRESS : '"localhost"' ,
88
+ messaging_attributes .MESSAGING_CLIENT_ID : client_id ,
89
+ messaging_attributes .MESSAGING_DESTINATION_NAME : "topic_1" ,
90
+ messaging_attributes .MESSAGING_DESTINATION_PARTITION_ID : "1" ,
91
+ messaging_attributes .MESSAGING_KAFKA_MESSAGE_KEY : "key_1" ,
92
+ messaging_attributes .MESSAGING_CONSUMER_GROUP_NAME : group_id ,
93
+ messaging_attributes .MESSAGING_OPERATION_NAME : "receive" ,
94
+ messaging_attributes .MESSAGING_OPERATION_TYPE : messaging_attributes .MessagingOperationTypeValues .RECEIVE .value ,
95
+ messaging_attributes .MESSAGING_KAFKA_MESSAGE_OFFSET : 1 ,
96
+ messaging_attributes .MESSAGING_MESSAGE_ID : "topic_1.1.1" ,
97
+ },
98
+ },
99
+ {
100
+ "name" : "topic_2 receive" ,
101
+ "kind" : SpanKind .CONSUMER ,
102
+ "attributes" : {
103
+ messaging_attributes .MESSAGING_SYSTEM : messaging_attributes .MessagingSystemValues .KAFKA .value ,
104
+ server_attributes .SERVER_ADDRESS : '"localhost"' ,
105
+ messaging_attributes .MESSAGING_CLIENT_ID : client_id ,
106
+ messaging_attributes .MESSAGING_DESTINATION_NAME : "topic_2" ,
107
+ messaging_attributes .MESSAGING_DESTINATION_PARTITION_ID : "2" ,
108
+ messaging_attributes .MESSAGING_KAFKA_MESSAGE_KEY : "key_2" ,
109
+ messaging_attributes .MESSAGING_CONSUMER_GROUP_NAME : group_id ,
110
+ messaging_attributes .MESSAGING_OPERATION_NAME : "receive" ,
111
+ messaging_attributes .MESSAGING_OPERATION_TYPE : messaging_attributes .MessagingOperationTypeValues .RECEIVE .value ,
112
+ messaging_attributes .MESSAGING_KAFKA_MESSAGE_OFFSET : 2 ,
113
+ messaging_attributes .MESSAGING_MESSAGE_ID : "topic_2.2.2" ,
114
+ },
115
+ },
116
+ ]
117
+ self .memory_exporter .clear ()
118
+
119
+ getone_mock = mock .AsyncMock ()
120
+ consumer .getone = getone_mock
121
+
122
+ getone_mock .side_effect = [
123
+ self .consumer_record_factory (
124
+ 1 ,
125
+ headers = (
126
+ (
127
+ "traceparent" ,
128
+ b"00-03afa25236b8cd948fa853d67038ac79-405ff022e8247c46-01" ,
129
+ ),
130
+ ),
131
+ ),
132
+ self .consumer_record_factory (2 , headers = ()),
133
+ ]
134
+
135
+ await consumer .__anext__ ()
136
+ getone_mock .assert_awaited_with ()
137
+
138
+ first_span = self .memory_exporter .get_finished_spans ()[0 ]
139
+ self .assertEqual (
140
+ format_trace_id (first_span .get_span_context ().trace_id ),
141
+ "03afa25236b8cd948fa853d67038ac79" ,
142
+ )
143
+
144
+ await consumer .__anext__ ()
145
+ getone_mock .assert_awaited_with ()
146
+
147
+ span_list = self .memory_exporter .get_finished_spans ()
148
+ self ._compare_spans (span_list , expected_spans )
149
+
150
+ async def test_anext_consumer_hook (self ) -> None :
151
+ async_consume_hook_mock = mock .AsyncMock ()
152
+
153
+ AIOKafkaInstrumentor ().uninstrument ()
154
+ AIOKafkaInstrumentor ().instrument (
155
+ tracer_provider = self .tracer_provider ,
156
+ async_consume_hook = async_consume_hook_mock ,
157
+ )
158
+
159
+ consumer = AIOKafkaConsumer ()
160
+
161
+ getone_mock = mock .AsyncMock ()
162
+ consumer .getone = getone_mock
163
+
164
+ getone_mock .side_effect = [self .consumer_record_factory (1 , headers = ())]
165
+
166
+ await consumer .__anext__ ()
167
+
168
+ async_consume_hook_mock .assert_awaited_once ()
169
+
170
+ async def test_send (self ) -> None :
171
+ AIOKafkaInstrumentor ().uninstrument ()
172
+ AIOKafkaInstrumentor ().instrument (tracer_provider = self .tracer_provider )
173
+
174
+ producer = AIOKafkaProducer (api_version = "1.0" )
175
+
176
+ add_message_mock = mock .AsyncMock ()
177
+ producer .client ._wait_on_metadata = mock .AsyncMock ()
178
+ producer .client .bootstrap = mock .AsyncMock ()
179
+ producer ._message_accumulator .add_message = add_message_mock
180
+ producer ._sender .start = mock .AsyncMock ()
181
+ producer ._partition = mock .Mock (return_value = 1 )
182
+
183
+ await producer .start ()
184
+
185
+ tracer = self .tracer_provider .get_tracer (__name__ )
186
+ with tracer .start_as_current_span ("test_span" ) as span :
187
+ await producer .send ("topic_1" , b"value_1" )
188
+
189
+ add_message_mock .assert_awaited_with (
190
+ TopicPartition (topic = "topic_1" , partition = 1 ),
191
+ None ,
192
+ b"value_1" ,
193
+ 40.0 ,
194
+ timestamp_ms = None ,
195
+ headers = [("traceparent" , mock .ANY )],
196
+ )
197
+ add_message_mock .call_args_list [0 ].kwargs ["headers" ][0 ][1 ].startswith (
198
+ f"00-{ format_trace_id (span .get_span_context ().trace_id )} -" .encode ()
199
+ )
200
+
201
+ await producer .send ("topic_2" , b"value_2" )
202
+ add_message_mock .assert_awaited_with (
203
+ TopicPartition (topic = "topic_2" , partition = 1 ),
204
+ None ,
205
+ b"value_2" ,
206
+ 40.0 ,
207
+ timestamp_ms = None ,
208
+ headers = [("traceparent" , mock .ANY )],
209
+ )
210
+
211
+ async def test_send_produce_hook (self ) -> None :
212
+ async_produce_hook_mock = mock .AsyncMock ()
213
+
214
+ AIOKafkaInstrumentor ().uninstrument ()
215
+ AIOKafkaInstrumentor ().instrument (
216
+ tracer_provider = self .tracer_provider ,
217
+ async_produce_hook = async_produce_hook_mock ,
218
+ )
219
+
220
+ producer = AIOKafkaProducer (api_version = "1.0" )
221
+
222
+ producer .client ._wait_on_metadata = mock .AsyncMock ()
223
+ producer .client .bootstrap = mock .AsyncMock ()
224
+ producer ._message_accumulator .add_message = mock .AsyncMock ()
225
+ producer ._sender .start = mock .AsyncMock ()
226
+ producer ._partition = mock .Mock (return_value = 1 )
227
+
228
+ await producer .start ()
229
+
230
+ await producer .send ("topic_1" , b"value_1" )
231
+
232
+ async_produce_hook_mock .assert_awaited_once ()
233
+
234
+ def _compare_spans (
235
+ self , spans : List [Span ], expected_spans : List [dict ]
236
+ ) -> None :
237
+ self .assertEqual (len (spans ), len (expected_spans ))
238
+ for span , expected_span in zip (spans , expected_spans ):
239
+ self .assertEqual (expected_span ["name" ], span .name )
240
+ self .assertEqual (expected_span ["kind" ], span .kind )
241
+ self .assertEqual (
242
+ expected_span ["attributes" ], dict (span .attributes )
243
+ )
0 commit comments