Skip to content

Commit 0d1d010

Browse files
Better context management
1 parent 1c0eea0 commit 0d1d010

File tree

3 files changed

+84
-22
lines changed

3 files changed

+84
-22
lines changed

instrumentation/opentelemetry-instrumentation-kafka-python/src/opentelemetry/instrumentation/kafka/__init__.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
API
4242
___
4343
"""
44+
import atexit
4445
from typing import Collection
4546

4647
import kafka
@@ -50,6 +51,7 @@
5051
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
5152
from opentelemetry.instrumentation.kafka.package import _instruments
5253
from opentelemetry.instrumentation.kafka.utils import (
54+
KafkaInstrumentorContextManager,
5355
_wrap_next,
5456
_wrap_send,
5557
dummy_callback,
@@ -83,11 +85,16 @@ def _instrument(self, **kwargs):
8385
__name__, __version__, tracer_provider=tracer_provider
8486
)
8587

88+
context_manager = KafkaInstrumentorContextManager()
89+
atexit.register(context_manager.close)
90+
8691
wrap_function_wrapper(
8792
kafka.KafkaProducer, "send", _wrap_send(tracer, produce_hook)
8893
)
8994
wrap_function_wrapper(
90-
kafka.KafkaConsumer, "__next__", _wrap_next(tracer, consume_hook)
95+
kafka.KafkaConsumer,
96+
"__next__",
97+
_wrap_next(tracer, context_manager, consume_hook),
9198
)
9299

93100
def _uninstrument(self, **kwargs):

instrumentation/opentelemetry-instrumentation-kafka-python/src/opentelemetry/instrumentation/kafka/utils.py

+57-10
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
from logging import getLogger
33
from typing import Callable, Dict, List, Optional
44

5+
from kafka import KafkaConsumer
6+
57
from opentelemetry import trace
6-
from opentelemetry.context import attach
8+
from opentelemetry.context import attach, detach
9+
from opentelemetry.context.context import Context
710
from opentelemetry.propagate import extract, inject
811
from opentelemetry.propagators import textmap
912
from opentelemetry.semconv.trace import SpanAttributes
@@ -13,6 +16,46 @@
1316
_LOG = getLogger(__name__)
1417

1518

19+
class KafkaInstrumentorContextManager:
20+
def __init__(self):
21+
self.spans = dict()
22+
self.tokens = dict()
23+
24+
def set_consumer_context(
25+
self, consumer: KafkaConsumer, context: Context, span: Span
26+
):
27+
self.set_span(consumer, span)
28+
self.attach_context(consumer, context)
29+
30+
def set_span(self, consumer: KafkaConsumer, span: Span):
31+
self.close_span(consumer)
32+
self.spans[consumer] = span
33+
34+
def close_span(self, consumer: KafkaConsumer):
35+
if consumer in self.spans:
36+
self.spans.get(consumer).close()
37+
del self.spans[consumer]
38+
39+
def attach_context(self, consumer: KafkaConsumer, context: Context):
40+
self.detach_context(consumer)
41+
self.tokens[consumer] = attach(context)
42+
43+
def detach_context(self, consumer: KafkaConsumer):
44+
if consumer in self.tokens:
45+
detach(self.tokens.get(consumer))
46+
del self.tokens[consumer]
47+
48+
def close(self, kafka_consumer: KafkaConsumer = None):
49+
if kafka_consumer:
50+
self.close_span(kafka_consumer)
51+
self.detach_context(kafka_consumer)
52+
else:
53+
for consumer in self.spans:
54+
self.close_span(consumer)
55+
for consumer in self.tokens:
56+
self.detach_context(consumer)
57+
58+
1659
class KafkaPropertiesExtractor:
1760
@staticmethod
1861
def extract_bootstrap_servers(instance):
@@ -167,26 +210,30 @@ def _traced_send(func, instance, args, kwargs):
167210

168211

169212
def _start_consume_span_with_extracted_context(
170-
tracer: Tracer, headers: List, topic: str
213+
tracer: Tracer,
214+
context_manager: KafkaInstrumentorContextManager,
215+
instance: KafkaConsumer,
216+
headers: List,
217+
topic: str,
171218
) -> Span:
172219
extracted_context = extract(headers, getter=_kafka_getter)
173220
span_name = _get_span_name("receive", topic)
174221
span = tracer.start_span(
175222
span_name, context=extracted_context, kind=trace.SpanKind.CONSUMER
176223
)
177224
new_context = set_span_in_context(span, extracted_context)
178-
attach(new_context)
225+
context_manager.set_consumer_context(instance, new_context, span)
179226
return span
180227

181228

182-
def _wrap_next(tracer: Tracer, consume_hook: HookT) -> Callable:
229+
def _wrap_next(
230+
tracer: Tracer,
231+
context_manager: KafkaInstrumentorContextManager,
232+
consume_hook: HookT,
233+
) -> Callable:
183234
def _traced_next(func, instance, args, kwargs):
184235
# End the current span if exists before processing the next record
185-
current_span = trace.get_current_span()
186-
if current_span.is_recording() and current_span.name.startswith(
187-
"receive"
188-
):
189-
current_span.end()
236+
context_manager.close(instance)
190237

191238
record = func(*args, **kwargs)
192239

@@ -198,7 +245,7 @@ def _traced_next(func, instance, args, kwargs):
198245
)
199246
partition = record.partition
200247
span = _start_consume_span_with_extracted_context(
201-
tracer, headers, topic
248+
tracer, context_manager, instance, headers, topic
202249
)
203250
with trace.use_span(span):
204251
_enrich_span(span, bootstrap_servers, topic, partition)

instrumentation/opentelemetry-instrumentation-kafka-python/tests/test_utils.py

+19-11
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ def test_wrap_send(
7676
)
7777
self.assertEqual(retval, original_send_callback.return_value)
7878

79-
@mock.patch("opentelemetry.trace.get_current_span")
8079
@mock.patch("opentelemetry.trace.use_span")
8180
@mock.patch(
8281
"opentelemetry.instrumentation.kafka.utils._start_consume_span_with_extracted_context"
@@ -91,31 +90,34 @@ def test_wrap_next(
9190
enrich_span: mock.MagicMock,
9291
start_consume_span_with_extracted_context: mock.MagicMock,
9392
use_span: mock.MagicMock,
94-
get_current_span: mock.MagicMock,
9593
) -> None:
9694
tracer = mock.MagicMock()
9795
consume_hook = mock.MagicMock()
9896
original_next_callback = mock.MagicMock()
9997
kafka_consumer = mock.MagicMock()
98+
context_manager = mock.MagicMock()
10099

101-
wrapped_next = _wrap_next(tracer, consume_hook)
100+
wrapped_next = _wrap_next(tracer, context_manager, consume_hook)
102101
record = wrapped_next(
103102
original_next_callback, kafka_consumer, self.args, self.kwargs
104103
)
105104

106105
extract_bootstrap_servers.assert_called_once_with(kafka_consumer)
107106
bootstrap_servers = extract_bootstrap_servers.return_value
108-
get_current_span.assert_called_once()
109-
current_span = get_current_span.return_value
110-
current_span.end.assert_called_once()
107+
108+
context_manager.close.assert_called_once_with(kafka_consumer)
111109

112110
original_next_callback.assert_called_once_with(
113111
*self.args, **self.kwargs
114112
)
115113
self.assertEqual(record, original_next_callback.return_value)
116114

117115
start_consume_span_with_extracted_context.assert_called_once_with(
118-
tracer, record.headers, record.topic
116+
tracer,
117+
context_manager,
118+
kafka_consumer,
119+
record.headers,
120+
record.topic,
119121
)
120122
span = start_consume_span_with_extracted_context.return_value
121123
use_span.assert_called_once_with(span)
@@ -124,20 +126,24 @@ def test_wrap_next(
124126
)
125127
consume_hook.assert_called_once_with(span, self.args, self.kwargs)
126128

127-
@mock.patch("opentelemetry.context.attach")
128129
@mock.patch("opentelemetry.trace.set_span_in_context")
129130
@mock.patch("opentelemetry.propagate.extract")
130131
def test_start_consume_span_with_extracted_context(
131132
self,
132133
extract: mock.MagicMock,
133134
set_span_in_context: mock.MagicMock,
134-
attach: mock.MagicMock,
135135
):
136136
tracer = mock.MagicMock()
137+
context_manager = mock.MagicMock()
138+
kafka_consumer = mock.MagicMock()
137139
expected_span_name = _get_span_name("receive", self.topic_name)
138140

139141
_start_consume_span_with_extracted_context(
140-
tracer, self.headers, self.topic_name
142+
tracer,
143+
context_manager,
144+
kafka_consumer,
145+
self.headers,
146+
self.topic_name,
141147
)
142148

143149
extract.assert_called_once_with(self.headers, _kafka_getter)
@@ -148,4 +154,6 @@ def test_start_consume_span_with_extracted_context(
148154
span = tracer.start_span.return_value
149155
set_span_in_context.assert_called_once_with(span, context)
150156
new_context = set_span_in_context.return_value
151-
attach.assert_called_once_with(new_context)
157+
context_manager.set_consumer_context.assert_called_once_with(
158+
kafka_consumer, new_context, span
159+
)

0 commit comments

Comments
 (0)