Skip to content

Commit bd16884

Browse files
committed
adding Iterbale to DefaultDict of TextMap
1 parent f3afa1f commit bd16884

File tree

2 files changed

+19
-35
lines changed

2 files changed

+19
-35
lines changed

opentelemetry-api/src/opentelemetry/trace/propagation/textmap.py

+14-29
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from opentelemetry.context.context import Context
1919

2020
TextMapPropagatorT = typing.TypeVar("TextMapPropagatorT")
21-
CarrierValT = typing.Union[typing.List[str], str]
21+
CarrierValT = typing.TypeVar("CarrierValT")
2222

2323
Setter = typing.Callable[[TextMapPropagatorT, str, str], None]
2424

@@ -34,22 +34,21 @@ def get(self, carrier: TextMapPropagatorT, key: str) -> typing.List[str]:
3434
or more values from the carrier. In the case that
3535
the value does not exist, returns an empty list.
3636
37-
Args: carrier: and object which contains values that are used to
38-
construct a Context. This object must be paired with an appropriate
39-
getter which understands how to extract a value from it.
40-
key: key of a field in carrier. Returns: first value of the
41-
propagation key or an empty list if the key doesn't exist.
37+
Args:
38+
carrier: An object which contains values that are used to
39+
construct a Context.
40+
key: key of a field in carrier.
41+
Returns: first value of the propagation key or an empty list if the
42+
key doesn't exist.
4243
"""
4344
raise NotImplementedError()
4445

4546
def keys(self, carrier: TextMapPropagatorT) -> typing.List[str]:
4647
"""Function that can retrieve all the keys in a carrier object.
4748
4849
Args:
49-
carrier: and object which contains values that are
50-
used to construct a Context. This object
51-
must be paired with an appropriate getter
52-
which understands how to extract a value from it.
50+
carrier: An object which contains values that are
51+
used to construct a Context.
5352
Returns:
5453
list of keys from the carrier.
5554
"""
@@ -60,31 +59,17 @@ class DictGetter(Getter[typing.Dict[str, CarrierValT]]):
6059
def get(
6160
self, carrier: typing.Dict[str, CarrierValT], key: str
6261
) -> typing.List[str]:
63-
val = carrier.get(key, None)
64-
if not val:
62+
value = carrier.get(key, None)
63+
if not value:
6564
return []
66-
return val if isinstance(val, typing.List) else [val]
65+
if not isinstance(value, str) and isinstance(value, typing.Iterable):
66+
return list(value)
67+
return [value]
6768

6869
def keys(self, carrier: typing.Dict[str, CarrierValT]) -> typing.List[str]:
6970
return list(carrier.keys())
7071

7172

72-
class CustomGetter(Getter[TextMapPropagatorT]):
73-
def __init__(
74-
self,
75-
get: typing.Callable[[TextMapPropagatorT, str], typing.List[str]],
76-
keys: typing.Callable[[TextMapPropagatorT], typing.List[str]],
77-
):
78-
self._get = get
79-
self._keys = keys
80-
81-
def get(self, carrier: TextMapPropagatorT, key: str) -> typing.List[str]:
82-
return self._get(carrier, key)
83-
84-
def keys(self, carrier: TextMapPropagatorT) -> typing.List[str]:
85-
return self._keys(carrier)
86-
87-
8873
class TextMapPropagator(abc.ABC):
8974
"""This class provides an interface that enables extracting and injecting
9075
context into headers of HTTP requests. HTTP frameworks and clients

opentelemetry-sdk/tests/trace/propagation/test_b3_format.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import opentelemetry.sdk.trace.propagation.b3_format as b3_format
2020
import opentelemetry.trace as trace_api
2121
from opentelemetry.context import get_current
22-
from opentelemetry.trace.propagation.textmap import CustomGetter, DictGetter
22+
from opentelemetry.trace.propagation.textmap import DictGetter
2323

2424
FORMAT = b3_format.B3Format()
2525

@@ -320,13 +320,12 @@ def test_inject_empty_context():
320320
def test_default_span():
321321
"""Make sure propagator does not crash when working with DefaultSpan"""
322322

323-
def default_span_getter(carrier, key):
324-
return carrier.get(key, None)
323+
class CarrierGetter(DictGetter):
324+
def get(self, carrier, key):
325+
return carrier.get(key, None)
325326

326327
def setter(carrier, key, value):
327328
carrier[key] = value
328329

329-
ctx = FORMAT.extract(
330-
CustomGetter(default_span_getter, DictGetter().keys), {}
331-
)
330+
ctx = FORMAT.extract(CarrierGetter(), {})
332331
FORMAT.inject(setter, {}, ctx)

0 commit comments

Comments
 (0)