Skip to content

Commit 3eb27ca

Browse files
authored
Return none for Getter if key does not exist (#233)
1 parent 3b48a38 commit 3eb27ca

File tree

9 files changed

+114
-14
lines changed

9 files changed

+114
-14
lines changed

Diff for: .github/workflows/test.yml

+1-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@ on:
66
- 'release/*'
77
pull_request:
88
env:
9-
CORE_REPO_SHA: f69e12fba8d0afd587dd21adbedfe751153aa73c
10-
9+
CORE_REPO_SHA: master
1110

1211
jobs:
1312
build:

Diff for: instrumentation/opentelemetry-instrumentation-asgi/CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
## Unreleased
44

5+
- Return `None` for `CarrierGetter` if key not found
6+
([#1374](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/233))
7+
58
## Version 0.12b0
69

710
Released 2020-08-14

Diff for: instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434

3535

3636
class CarrierGetter(DictGetter):
37-
def get(self, carrier: dict, key: str) -> typing.List[str]:
37+
def get(
38+
self, carrier: dict, key: str
39+
) -> typing.Optional[typing.List[str]]:
3840
"""Getter implementation to retrieve a HTTP header value from the ASGI
3941
scope.
4042
@@ -43,14 +45,17 @@ def get(self, carrier: dict, key: str) -> typing.List[str]:
4345
key: header name in scope
4446
Returns:
4547
A list with a single string with the header value if it exists,
46-
else an empty list.
48+
else None.
4749
"""
4850
headers = carrier.get("headers")
49-
return [
51+
decoded = [
5052
_value.decode("utf8")
5153
for (_key, _value) in headers
5254
if _key.decode("utf8") == key
5355
]
56+
if not decoded:
57+
return None
58+
return decoded
5459

5560

5661
carrier_getter = CarrierGetter()
@@ -82,11 +87,12 @@ def collect_request_attributes(scope):
8287
http_method = scope.get("method")
8388
if http_method:
8489
result["http.method"] = http_method
85-
http_host_value = ",".join(carrier_getter.get(scope, "host"))
86-
if http_host_value:
87-
result["http.server_name"] = http_host_value
90+
91+
http_host_value_list = carrier_getter.get(scope, "host")
92+
if http_host_value_list:
93+
result["http.server_name"] = ",".join(http_host_value_list)
8894
http_user_agent = carrier_getter.get(scope, "user-agent")
89-
if len(http_user_agent) > 0:
95+
if http_user_agent:
9096
result["http.user_agent"] = http_user_agent[0]
9197

9298
if "client" in scope and scope["client"] is not None:

Diff for: instrumentation/opentelemetry-instrumentation-asgi/tests/test_asgi_middleware.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def test_basic_asgi_call(self):
164164
outputs = self.get_all_output()
165165
self.validate_outputs(outputs)
166166

167-
def test_wsgi_not_recording(self):
167+
def test_asgi_not_recording(self):
168168
mock_tracer = mock.Mock()
169169
mock_span = mock.Mock()
170170
mock_span.is_recording.return_value = False
@@ -312,8 +312,12 @@ def setUp(self):
312312

313313
def test_request_attributes(self):
314314
self.scope["query_string"] = b"foo=bar"
315+
headers = []
316+
headers.append(("host".encode("utf8"), "test".encode("utf8")))
317+
self.scope["headers"] = headers
315318

316319
attrs = otel_asgi.collect_request_attributes(self.scope)
320+
317321
self.assertDictEqual(
318322
attrs,
319323
{
@@ -324,6 +328,7 @@ def test_request_attributes(self):
324328
"http.url": "http://127.0.0.1/?foo=bar",
325329
"host.port": 80,
326330
"http.scheme": "http",
331+
"http.server_name": "test",
327332
"http.flavor": "1.0",
328333
"net.peer.ip": "127.0.0.1",
329334
"net.peer.port": 32767,

Diff for: instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ def add(x, y):
8080

8181
class CarrierGetter(DictGetter):
8282
def get(self, carrier, key):
83-
value = getattr(carrier, key, [])
83+
value = getattr(carrier, key, None)
84+
if value is None:
85+
return None
8486
if isinstance(value, str) or not isinstance(value, Iterable):
8587
value = (value,)
8688
return value
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright The OpenTelemetry Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest import TestCase, mock
16+
17+
from opentelemetry.instrumentation.celery import CarrierGetter
18+
19+
20+
class TestCarrierGetter(TestCase):
21+
def test_get_none(self):
22+
getter = CarrierGetter()
23+
carrier = {}
24+
val = getter.get(carrier, "test")
25+
self.assertIsNone(val)
26+
27+
def test_get_str(self):
28+
mock_obj = mock.Mock()
29+
getter = CarrierGetter()
30+
mock_obj.test = "val"
31+
val = getter.get(mock_obj, "test")
32+
self.assertEqual(val, ("val",))
33+
34+
def test_get_iter(self):
35+
mock_obj = mock.Mock()
36+
getter = CarrierGetter()
37+
mock_obj.test = ["val"]
38+
val = getter.get(mock_obj, "test")
39+
self.assertEqual(val, ["val"])
40+
41+
def test_keys(self):
42+
getter = CarrierGetter()
43+
keys = getter.keys({})
44+
self.assertEqual(keys, [])

Diff for: instrumentation/opentelemetry-instrumentation-wsgi/CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
## Unreleased
44

5+
- Return `None` for `CarrierGetter` if key not found
6+
([#1374](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/233))
7+
58
## Version 0.13b0
69

710
Released 2020-09-17

Diff for: instrumentation/opentelemetry-instrumentation-wsgi/src/opentelemetry/instrumentation/wsgi/__init__.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ def hello():
6868

6969

7070
class CarrierGetter(DictGetter):
71-
def get(self, carrier: dict, key: str) -> typing.List[str]:
71+
def get(
72+
self, carrier: dict, key: str
73+
) -> typing.Optional[typing.List[str]]:
7274
"""Getter implementation to retrieve a HTTP header value from the
7375
PEP3333-conforming WSGI environ
7476
@@ -77,13 +79,13 @@ def get(self, carrier: dict, key: str) -> typing.List[str]:
7779
key: header name in environ object
7880
Returns:
7981
A list with a single string with the header value if it exists,
80-
else an empty list.
82+
else None.
8183
"""
8284
environ_key = "HTTP_" + key.upper().replace("-", "_")
8385
value = carrier.get(environ_key)
8486
if value is not None:
8587
return [value]
86-
return []
88+
return None
8789

8890
def keys(self, carrier):
8991
return []
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright The OpenTelemetry Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest import TestCase, mock
16+
17+
from opentelemetry.instrumentation.wsgi import CarrierGetter
18+
19+
20+
class TestCarrierGetter(TestCase):
21+
def test_get_none(self):
22+
getter = CarrierGetter()
23+
carrier = {}
24+
val = getter.get(carrier, "test")
25+
self.assertIsNone(val)
26+
27+
def test_get_(self):
28+
getter = CarrierGetter()
29+
carrier = {"HTTP_TEST_KEY": "val"}
30+
val = getter.get(carrier, "test-key")
31+
self.assertEqual(val, ["val"])
32+
33+
def test_keys(self):
34+
getter = CarrierGetter()
35+
keys = getter.keys({})
36+
self.assertEqual(keys, [])

0 commit comments

Comments
 (0)