Skip to content

Commit 389d707

Browse files
committed
modify kafka context getter helper methods to work on dict and list
1 parent 24b7234 commit 389d707

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

instrumentation/opentelemetry-instrumentation-confluent-kafka/src/opentelemetry/instrumentation/confluent_kafka/utils.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,26 @@ class KafkaContextGetter(textmap.Getter):
4141
def get(self, carrier: textmap.CarrierT, key: str) -> Optional[List[str]]:
4242
if carrier is None:
4343
return None
44-
for item_key, value in carrier:
44+
45+
carrier_items = carrier
46+
if isinstance(carrier, dict):
47+
carrier_items = carrier.items()
48+
49+
for item_key, value in carrier_items:
4550
if item_key == key:
4651
if value is not None:
4752
return [value.decode()]
53+
4854
return None
4955

5056
def keys(self, carrier: textmap.CarrierT) -> List[str]:
5157
if carrier is None:
5258
return []
53-
return [key for (key, value) in carrier]
59+
60+
carrier_items = carrier
61+
if isinstance(carrier, dict):
62+
carrier_items = carrier.items()
63+
return [key for (key, value) in carrier_items]
5464

5565

5666
class KafkaContextSetter(textmap.Setter):

instrumentation/opentelemetry-instrumentation-confluent-kafka/tests/test_instrumentation.py

+15
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626
from opentelemetry.instrumentation.confluent_kafka.utils import (
2727
KafkaContextSetter,
28+
KafkaContextGetter,
2829
)
2930

3031

@@ -89,3 +90,17 @@ def test_context_setter(self) -> None:
8990
carrier_list = [("key1", "val1")]
9091
context_setter.set(carrier_list, "key2", "val2")
9192
self.assertTrue(("key2", "val2".encode()) in carrier_list)
93+
94+
def test_context_getter(self) -> None:
95+
context_setter = KafkaContextSetter()
96+
context_getter = KafkaContextGetter()
97+
98+
carrier_dict = {}
99+
context_setter.set(carrier_dict, "key1", "val1")
100+
self.assertEqual(context_getter.get(carrier_dict, "key1"), ["val1"])
101+
self.assertEqual(["key1"], context_getter.keys(carrier_dict))
102+
103+
carrier_list = []
104+
context_setter.set(carrier_list, "key1", "val1")
105+
self.assertEqual(context_getter.get(carrier_list, "key1"), ["val1"])
106+
self.assertEqual(["key1"], context_getter.keys(carrier_list))

0 commit comments

Comments
 (0)