Skip to content

Commit d89f865

Browse files
authored
Sync with Make setters and getters optional (#372)
1 parent 36e7ad0 commit d89f865

File tree

27 files changed

+113
-144
lines changed

27 files changed

+113
-144
lines changed

.github/workflows/test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ on:
66
- 'release/*'
77
pull_request:
88
env:
9-
CORE_REPO_SHA: d3694fc520f8542b232fd1065133286f4591dcec
9+
CORE_REPO_SHA: 94d9c1eb77723123a779ddd18d2b12d4bec445fb
1010

1111
jobs:
1212
build:

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
55
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
66

77
## [Unreleased](https://github.com/open-telemetry/opentelemetry-python-contrib/compare/v0.18b0...HEAD)
8+
- Make getters and setters optional
9+
([#372](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/372))
810
- Updated instrumentations to use `opentelemetry.trace.use_span` instead of `Tracer.use_span()`
911
([#364](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/364))
1012
- `opentelemetry-propagator-ot-trace` Do not throw an exception when headers are not present

docs/nitpick-exceptions.ini

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
[default]
22
class_references=
33
; TODO: Understand why sphinx is not able to find this local class
4+
opentelemetry.propagators.textmap.CarrierT
5+
opentelemetry.propagators.textmap.Setter
6+
opentelemetry.propagators.textmap.Getter
47
opentelemetry.propagators.textmap.TextMapPropagator
58
; - AwsXRayFormat
6-
opentelemetry.propagators.textmap.DictGetter
9+
opentelemetry.propagators.textmap.DefaultGetter
710
; API
811
opentelemetry.propagators.textmap.Getter
912
; - DatadogFormat
1013
; - AWSXRayFormat
1114
opentelemetry.sdk.trace.id_generator.IdGenerator
1215
; - AwsXRayIdGenerator
13-
TextMapPropagatorT
16+
TextMapPropagator
17+
CarrierT
18+
Setter
19+
Getter
1420
; - AwsXRayFormat.extract
1521

1622
anys=

exporter/opentelemetry-exporter-datadog/src/opentelemetry/exporter/datadog/propagator.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
from opentelemetry.context import Context
1919
from opentelemetry.exporter.datadog import constants
2020
from opentelemetry.propagators.textmap import (
21+
CarrierT,
2122
Getter,
2223
Setter,
2324
TextMapPropagator,
24-
TextMapPropagatorT,
25+
default_getter,
26+
default_setter,
2527
)
2628
from opentelemetry.trace import get_current_span, set_span_in_context
2729

@@ -36,9 +38,9 @@ class DatadogFormat(TextMapPropagator):
3638

3739
def extract(
3840
self,
39-
getter: Getter[TextMapPropagatorT],
40-
carrier: TextMapPropagatorT,
41+
carrier: CarrierT,
4142
context: typing.Optional[Context] = None,
43+
getter: Getter = default_getter,
4244
) -> Context:
4345
trace_id = extract_first_element(
4446
getter.get(carrier, self.TRACE_ID_KEY)
@@ -81,28 +83,28 @@ def extract(
8183

8284
def inject(
8385
self,
84-
set_in_carrier: Setter[TextMapPropagatorT],
85-
carrier: TextMapPropagatorT,
86+
carrier: CarrierT,
8687
context: typing.Optional[Context] = None,
88+
setter: Setter = default_setter,
8789
) -> None:
8890
span = get_current_span(context)
8991
span_context = span.get_span_context()
9092
if span_context == trace.INVALID_SPAN_CONTEXT:
9193
return
9294
sampled = (trace.TraceFlags.SAMPLED & span.context.trace_flags) != 0
93-
set_in_carrier(
95+
setter.set(
9496
carrier, self.TRACE_ID_KEY, format_trace_id(span.context.trace_id),
9597
)
96-
set_in_carrier(
98+
setter.set(
9799
carrier, self.PARENT_ID_KEY, format_span_id(span.context.span_id)
98100
)
99-
set_in_carrier(
101+
setter.set(
100102
carrier,
101103
self.SAMPLING_PRIORITY_KEY,
102104
str(constants.AUTO_KEEP if sampled else constants.AUTO_REJECT),
103105
)
104106
if constants.DD_ORIGIN in span.context.trace_state:
105-
set_in_carrier(
107+
setter.set(
106108
carrier,
107109
self.ORIGIN_KEY,
108110
span.context.trace_state[constants.DD_ORIGIN],
@@ -134,8 +136,8 @@ def format_span_id(span_id: int) -> str:
134136

135137

136138
def extract_first_element(
137-
items: typing.Iterable[TextMapPropagatorT],
138-
) -> typing.Optional[TextMapPropagatorT]:
139+
items: typing.Iterable[CarrierT],
140+
) -> typing.Optional[CarrierT]:
139141
if items is None:
140142
return None
141143
return next(iter(items), None)

exporter/opentelemetry-exporter-datadog/tests/test_datadog_format.py

+7-13
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,12 @@
1717

1818
from opentelemetry import trace as trace_api
1919
from opentelemetry.exporter.datadog import constants, propagator
20-
from opentelemetry.propagators.textmap import DictGetter
2120
from opentelemetry.sdk import trace
2221
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
2322
from opentelemetry.trace import get_current_span, set_span_in_context
2423

2524
FORMAT = propagator.DatadogFormat()
2625

27-
carrier_getter = DictGetter()
28-
2926

3027
class TestDatadogFormat(unittest.TestCase):
3128
@classmethod
@@ -45,7 +42,6 @@ def test_malformed_headers(self):
4542
malformed_parent_id_key = FORMAT.PARENT_ID_KEY + "-x"
4643
context = get_current_span(
4744
FORMAT.extract(
48-
carrier_getter,
4945
{
5046
malformed_trace_id_key: self.serialized_trace_id,
5147
malformed_parent_id_key: self.serialized_parent_id,
@@ -63,7 +59,7 @@ def test_missing_trace_id(self):
6359
FORMAT.PARENT_ID_KEY: self.serialized_parent_id,
6460
}
6561

66-
ctx = FORMAT.extract(carrier_getter, carrier)
62+
ctx = FORMAT.extract(carrier)
6763
span_context = get_current_span(ctx).get_span_context()
6864
self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID)
6965

@@ -73,15 +69,14 @@ def test_missing_parent_id(self):
7369
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
7470
}
7571

76-
ctx = FORMAT.extract(carrier_getter, carrier)
72+
ctx = FORMAT.extract(carrier)
7773
span_context = get_current_span(ctx).get_span_context()
7874
self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID)
7975

8076
def test_context_propagation(self):
8177
"""Test the propagation of Datadog headers."""
8278
parent_span_context = get_current_span(
8379
FORMAT.extract(
84-
carrier_getter,
8580
{
8681
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
8782
FORMAT.PARENT_ID_KEY: self.serialized_parent_id,
@@ -118,7 +113,7 @@ def test_context_propagation(self):
118113

119114
child_carrier = {}
120115
child_context = set_span_in_context(child)
121-
FORMAT.inject(dict.__setitem__, child_carrier, context=child_context)
116+
FORMAT.inject(child_carrier, context=child_context)
122117

123118
self.assertEqual(
124119
child_carrier[FORMAT.TRACE_ID_KEY], self.serialized_trace_id
@@ -138,7 +133,6 @@ def test_sampling_priority_auto_reject(self):
138133
"""Test sampling priority rejected."""
139134
parent_span_context = get_current_span(
140135
FORMAT.extract(
141-
carrier_getter,
142136
{
143137
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
144138
FORMAT.PARENT_ID_KEY: self.serialized_parent_id,
@@ -165,7 +159,7 @@ def test_sampling_priority_auto_reject(self):
165159

166160
child_carrier = {}
167161
child_context = set_span_in_context(child)
168-
FORMAT.inject(dict.__setitem__, child_carrier, context=child_context)
162+
FORMAT.inject(child_carrier, context=child_context)
169163

170164
self.assertEqual(
171165
child_carrier[FORMAT.SAMPLING_PRIORITY_KEY],
@@ -178,7 +172,7 @@ def test_fields(self, mock_get_current_span):
178172

179173
tracer = trace.TracerProvider().get_tracer("sdk_tracer_provider")
180174

181-
mock_set_in_carrier = Mock()
175+
mock_setter = Mock()
182176

183177
mock_get_current_span.configure_mock(
184178
**{
@@ -195,11 +189,11 @@ def test_fields(self, mock_get_current_span):
195189

196190
with tracer.start_as_current_span("parent"):
197191
with tracer.start_as_current_span("child"):
198-
FORMAT.inject(mock_set_in_carrier, {})
192+
FORMAT.inject({}, setter=mock_setter)
199193

200194
inject_fields = set()
201195

202-
for call in mock_set_in_carrier.mock_calls:
196+
for call in mock_setter.mock_calls:
203197
inject_fields.add(call[1][1])
204198

205199
self.assertEqual(FORMAT.fields, inject_fields)

instrumentation/opentelemetry-instrumentation-aiohttp-client/src/opentelemetry/instrumentation/aiohttp_client/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ async def on_request_start(
181181
trace.set_span_in_context(trace_config_ctx.span)
182182
)
183183

184-
inject(type(params.headers).__setitem__, params.headers)
184+
inject(params.headers)
185185

186186
async def on_request_end(
187187
unused_session: aiohttp.ClientSession,

instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@
2929
from opentelemetry.instrumentation.asgi.version import __version__ # noqa
3030
from opentelemetry.instrumentation.utils import http_status_to_status_code
3131
from opentelemetry.propagate import extract
32-
from opentelemetry.propagators.textmap import DictGetter
32+
from opentelemetry.propagators.textmap import Getter
3333
from opentelemetry.trace.status import Status, StatusCode
3434

3535

36-
class CarrierGetter(DictGetter):
36+
class ASGIGetter(Getter):
3737
def get(
3838
self, carrier: dict, key: str
3939
) -> typing.Optional[typing.List[str]]:
@@ -62,8 +62,11 @@ def get(
6262
return None
6363
return decoded
6464

65+
def keys(self, carrier: dict) -> typing.List[str]:
66+
return list(carrier.keys())
6567

66-
carrier_getter = CarrierGetter()
68+
69+
asgi_getter = ASGIGetter()
6770

6871

6972
def collect_request_attributes(scope):
@@ -88,10 +91,10 @@ def collect_request_attributes(scope):
8891
if http_method:
8992
result["http.method"] = http_method
9093

91-
http_host_value_list = carrier_getter.get(scope, "host")
94+
http_host_value_list = asgi_getter.get(scope, "host")
9295
if http_host_value_list:
9396
result["http.server_name"] = ",".join(http_host_value_list)
94-
http_user_agent = carrier_getter.get(scope, "user-agent")
97+
http_user_agent = asgi_getter.get(scope, "user-agent")
9598
if http_user_agent:
9699
result["http.user_agent"] = http_user_agent[0]
97100

@@ -186,7 +189,7 @@ async def __call__(self, scope, receive, send):
186189
if self.excluded_urls and self.excluded_urls.url_disabled(url):
187190
return await self.app(scope, receive, send)
188191

189-
token = context.attach(extract(carrier_getter, scope))
192+
token = context.attach(extract(scope, getter=asgi_getter))
190193
span_name, additional_attributes = self.span_details_callback(scope)
191194

192195
try:

instrumentation/opentelemetry-instrumentation-asgi/tests/test_getter.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,18 @@
1414

1515
from unittest import TestCase
1616

17-
from opentelemetry.instrumentation.asgi import CarrierGetter
17+
from opentelemetry.instrumentation.asgi import ASGIGetter
1818

1919

20-
class TestCarrierGetter(TestCase):
20+
class TestASGIGetter(TestCase):
2121
def test_get_none(self):
22-
getter = CarrierGetter()
22+
getter = ASGIGetter()
2323
carrier = {}
2424
val = getter.get(carrier, "test")
2525
self.assertIsNone(val)
2626

2727
def test_get_(self):
28-
getter = CarrierGetter()
28+
getter = ASGIGetter()
2929
carrier = {"headers": [(b"test-key", b"val")]}
3030
expected_val = ["val"]
3131
self.assertEqual(
@@ -45,6 +45,6 @@ def test_get_(self):
4545
)
4646

4747
def test_keys(self):
48-
getter = CarrierGetter()
48+
getter = ASGIGetter()
4949
keys = getter.keys({})
5050
self.assertEqual(keys, [])

instrumentation/opentelemetry-instrumentation-botocore/src/opentelemetry/instrumentation/botocore/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
def _patched_endpoint_prepare_request(wrapped, instance, args, kwargs):
6767
request = args[0]
6868
headers = request.headers
69-
inject(type(headers).__setitem__, headers)
69+
inject(headers)
7070
return wrapped(*args, **kwargs)
7171

7272

instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def add(x, y):
6161
from opentelemetry.instrumentation.celery.version import __version__
6262
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
6363
from opentelemetry.propagate import extract, inject
64-
from opentelemetry.propagators.textmap import DictGetter
64+
from opentelemetry.propagators.textmap import Getter
6565
from opentelemetry.trace.status import Status, StatusCode
6666

6767
logger = logging.getLogger(__name__)
@@ -78,7 +78,7 @@ def add(x, y):
7878
_MESSAGE_ID_ATTRIBUTE_NAME = "messaging.message_id"
7979

8080

81-
class CarrierGetter(DictGetter):
81+
class CeleryGetter(Getter):
8282
def get(self, carrier, key):
8383
value = getattr(carrier, key, None)
8484
if value is None:
@@ -91,7 +91,7 @@ def keys(self, carrier):
9191
return []
9292

9393

94-
carrier_getter = CarrierGetter()
94+
celery_getter = CeleryGetter()
9595

9696

9797
class CeleryInstrumentor(BaseInstrumentor):
@@ -128,7 +128,7 @@ def _trace_prerun(self, *args, **kwargs):
128128
return
129129

130130
request = task.request
131-
tracectx = extract(carrier_getter, request) or None
131+
tracectx = extract(request, getter=celery_getter) or None
132132

133133
logger.debug("prerun signal start task_id=%s", task_id)
134134

@@ -193,7 +193,7 @@ def _trace_before_publish(self, *args, **kwargs):
193193

194194
headers = kwargs.get("headers")
195195
if headers:
196-
inject(type(headers).__setitem__, headers)
196+
inject(headers)
197197

198198
@staticmethod
199199
def _trace_after_publish(*args, **kwargs):

instrumentation/opentelemetry-instrumentation-celery/tests/test_getter.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -14,31 +14,31 @@
1414

1515
from unittest import TestCase, mock
1616

17-
from opentelemetry.instrumentation.celery import CarrierGetter
17+
from opentelemetry.instrumentation.celery import CeleryGetter
1818

1919

20-
class TestCarrierGetter(TestCase):
20+
class TestCeleryGetter(TestCase):
2121
def test_get_none(self):
22-
getter = CarrierGetter()
22+
getter = CeleryGetter()
2323
carrier = {}
2424
val = getter.get(carrier, "test")
2525
self.assertIsNone(val)
2626

2727
def test_get_str(self):
2828
mock_obj = mock.Mock()
29-
getter = CarrierGetter()
29+
getter = CeleryGetter()
3030
mock_obj.test = "val"
3131
val = getter.get(mock_obj, "test")
3232
self.assertEqual(val, ("val",))
3333

3434
def test_get_iter(self):
3535
mock_obj = mock.Mock()
36-
getter = CarrierGetter()
36+
getter = CeleryGetter()
3737
mock_obj.test = ["val"]
3838
val = getter.get(mock_obj, "test")
3939
self.assertEqual(val, ["val"])
4040

4141
def test_keys(self):
42-
getter = CarrierGetter()
42+
getter = CeleryGetter()
4343
keys = getter.keys({})
4444
self.assertEqual(keys, [])

0 commit comments

Comments
 (0)