Skip to content

Commit 444161d

Browse files
lzchenAlex Boten
authored and
Alex Boten
committed
Return none for Getter if key does not exist (open-telemetry#1449)
1 parent 7500f73 commit 444161d

File tree

5 files changed

+58
-9
lines changed

5 files changed

+58
-9
lines changed

Diff for: opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def _format_baggage(baggage_entries: typing.Mapping[str, object]) -> str:
101101

102102

103103
def _extract_first_element(
104-
items: typing.Iterable[textmap.TextMapPropagatorT],
104+
items: typing.Optional[typing.Iterable[textmap.TextMapPropagatorT]],
105105
) -> typing.Optional[textmap.TextMapPropagatorT]:
106106
if items is None:
107107
return None

Diff for: opentelemetry-api/src/opentelemetry/trace/propagation/textmap.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ class Getter(typing.Generic[TextMapPropagatorT]):
2929
3030
"""
3131

32-
def get(self, carrier: TextMapPropagatorT, key: str) -> typing.List[str]:
32+
def get(
33+
self, carrier: TextMapPropagatorT, key: str
34+
) -> typing.Optional[typing.List[str]]:
3335
"""Function that can retrieve zero
3436
or more values from the carrier. In the case that
3537
the value does not exist, returns an empty list.
@@ -38,8 +40,8 @@ def get(self, carrier: TextMapPropagatorT, key: str) -> typing.List[str]:
3840
carrier: An object which contains values that are used to
3941
construct a Context.
4042
key: key of a field in carrier.
41-
Returns: first value of the propagation key or an empty list if the
42-
key doesn't exist.
43+
Returns: first value of the propagation key or None if the key doesn't
44+
exist.
4345
"""
4446
raise NotImplementedError()
4547

@@ -58,8 +60,10 @@ def keys(self, carrier: TextMapPropagatorT) -> typing.List[str]:
5860
class DictGetter(Getter[typing.Dict[str, CarrierValT]]):
5961
def get(
6062
self, carrier: typing.Dict[str, CarrierValT], key: str
61-
) -> typing.List[str]:
62-
val = carrier.get(key, [])
63+
) -> typing.Optional[typing.List[str]]:
64+
val = carrier.get(key, None)
65+
if val is None:
66+
return None
6367
if isinstance(val, typing.Iterable) and not isinstance(val, str):
6468
return list(val)
6569
return [val]

Diff for: opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,10 @@ def extract(
9191
return trace.set_span_in_context(trace.INVALID_SPAN, context)
9292

9393
tracestate_headers = getter.get(carrier, self._TRACESTATE_HEADER_NAME)
94-
tracestate = _parse_tracestate(tracestate_headers)
94+
if tracestate_headers is None:
95+
tracestate = None
96+
else:
97+
tracestate = _parse_tracestate(tracestate_headers)
9598

9699
span_context = trace.SpanContext(
97100
trace_id=int(trace_id, 16),

Diff for: opentelemetry-api/src/opentelemetry/trace/span.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,8 @@ def __new__(
175175
trace_id: int,
176176
span_id: int,
177177
is_remote: bool,
178-
trace_flags: "TraceFlags" = DEFAULT_TRACE_OPTIONS,
179-
trace_state: "TraceState" = DEFAULT_TRACE_STATE,
178+
trace_flags: typing.Optional["TraceFlags"] = DEFAULT_TRACE_OPTIONS,
179+
trace_state: typing.Optional["TraceState"] = DEFAULT_TRACE_STATE,
180180
) -> "SpanContext":
181181
if trace_flags is None:
182182
trace_flags = DEFAULT_TRACE_OPTIONS
+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright The OpenTelemetry Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
from opentelemetry.trace.propagation.textmap import DictGetter
18+
19+
20+
class TestDictGetter(unittest.TestCase):
21+
def test_get_none(self):
22+
getter = DictGetter()
23+
carrier = {}
24+
val = getter.get(carrier, "test")
25+
self.assertIsNone(val)
26+
27+
def test_get_str(self):
28+
getter = DictGetter()
29+
carrier = {"test": "val"}
30+
val = getter.get(carrier, "test")
31+
self.assertEqual(val, ["val"])
32+
33+
def test_get_iter(self):
34+
getter = DictGetter()
35+
carrier = {"test": ["val"]}
36+
val = getter.get(carrier, "test")
37+
self.assertEqual(val, ["val"])
38+
39+
def test_keys(self):
40+
getter = DictGetter()
41+
keys = getter.keys({"test": "val"})
42+
self.assertEqual(keys, ["test"])

0 commit comments

Comments
 (0)