Skip to content

Commit c3e9f75

Browse files
fix(opentelemetry-instrumentation-celery): attach incoming context on… (#2385)
* fix(opentelemetry-instrumentation-celery): attach incoming context on _trace_prerun * docs(CHANGELOG): add entry for fix #2385 * fix(opentelemetry-instrumentation-celery): detach context after task is run * test(opentelemetry-instrumentation-celery): add context utils tests * fix(opentelemetry-instrumentation-celery): remove duplicated signal registration * refactor(opentelemetry-instrumentation-celery): fix lint issues * refactor(opentelemetry-instrumentation-celery): fix types and tests for python 3.8 * refactor(opentelemetry-instrumentation-celery): fix lint issues * refactor(opentelemetry-instrumentation-celery): fix lint issues * fix(opentelemetry-instrumentation-celery): attach context only if it is not None * refactor(opentelemetry-instrumentation-celery): fix lint issues
1 parent 4ea9e5a commit c3e9f75

File tree

6 files changed

+114
-49
lines changed

6 files changed

+114
-49
lines changed

Diff for: CHANGELOG.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2323
([#2756](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2756))
2424
- `opentelemetry-instrumentation-aws-lambda` Fixing w3c baggage support
2525
([#2589](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2589))
26+
- `opentelemetry-instrumentation-celery` propagates baggage
27+
([#2385](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2385))
2628

2729
## Version 1.26.0/0.47b0 (2024-07-23)
2830

@@ -119,10 +121,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
119121
([#2610](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2610))
120122
- `opentelemetry-instrumentation-asgi` Bugfix: Middleware did not set status code attribute on duration metrics for non-recording spans.
121123
([#2627](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2627))
122-
<<<<<<< HEAD
123-
- `opentelemetry-instrumentation-mysql` Add support for `mysql-connector-python` v9 ([#2751](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2751))
124-
=======
125-
>>>>>>> 5a623233 (Changelog update)
124+
- `opentelemetry-instrumentation-mysql` Add support for `mysql-connector-python` v9
125+
([#2751](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2751))
126126

127127
## Version 1.25.0/0.46b0 (2024-05-31)
128128

Diff for: instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py

+35-13
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def add(x, y):
6767
from billiard.einfo import ExceptionInfo
6868
from celery import signals # pylint: disable=no-name-in-module
6969

70+
from opentelemetry import context as context_api
7071
from opentelemetry import trace
7172
from opentelemetry.instrumentation.celery import utils
7273
from opentelemetry.instrumentation.celery.package import _instruments
@@ -169,6 +170,7 @@ def _trace_prerun(self, *args, **kwargs):
169170
self.update_task_duration_time(task_id)
170171
request = task.request
171172
tracectx = extract(request, getter=celery_getter) or None
173+
token = context_api.attach(tracectx) if tracectx is not None else None
172174

173175
logger.debug("prerun signal start task_id=%s", task_id)
174176

@@ -179,7 +181,7 @@ def _trace_prerun(self, *args, **kwargs):
179181

180182
activation = trace.use_span(span, end_on_exit=True)
181183
activation.__enter__() # pylint: disable=E1101
182-
utils.attach_span(task, task_id, (span, activation))
184+
utils.attach_context(task, task_id, span, activation, token)
183185

184186
def _trace_postrun(self, *args, **kwargs):
185187
task = utils.retrieve_task(kwargs)
@@ -191,11 +193,14 @@ def _trace_postrun(self, *args, **kwargs):
191193
logger.debug("postrun signal task_id=%s", task_id)
192194

193195
# retrieve and finish the Span
194-
span, activation = utils.retrieve_span(task, task_id)
195-
if span is None:
196+
ctx = utils.retrieve_context(task, task_id)
197+
198+
if ctx is None:
196199
logger.warning("no existing span found for task_id=%s", task_id)
197200
return
198201

202+
span, activation, token = ctx
203+
199204
# request context tags
200205
if span.is_recording():
201206
span.set_attribute(_TASK_TAG_KEY, _TASK_RUN)
@@ -204,10 +209,11 @@ def _trace_postrun(self, *args, **kwargs):
204209
span.set_attribute(_TASK_NAME_KEY, task.name)
205210

206211
activation.__exit__(None, None, None)
207-
utils.detach_span(task, task_id)
212+
utils.detach_context(task, task_id)
208213
self.update_task_duration_time(task_id)
209214
labels = {"task": task.name, "worker": task.request.hostname}
210215
self._record_histograms(task_id, labels)
216+
context_api.detach(token)
211217

212218
def _trace_before_publish(self, *args, **kwargs):
213219
task = utils.retrieve_task_from_sender(kwargs)
@@ -238,7 +244,9 @@ def _trace_before_publish(self, *args, **kwargs):
238244
activation = trace.use_span(span, end_on_exit=True)
239245
activation.__enter__() # pylint: disable=E1101
240246

241-
utils.attach_span(task, task_id, (span, activation), is_publish=True)
247+
utils.attach_context(
248+
task, task_id, span, activation, None, is_publish=True
249+
)
242250

243251
headers = kwargs.get("headers")
244252
if headers:
@@ -253,13 +261,16 @@ def _trace_after_publish(*args, **kwargs):
253261
return
254262

255263
# retrieve and finish the Span
256-
_, activation = utils.retrieve_span(task, task_id, is_publish=True)
257-
if activation is None:
264+
ctx = utils.retrieve_context(task, task_id, is_publish=True)
265+
266+
if ctx is None:
258267
logger.warning("no existing span found for task_id=%s", task_id)
259268
return
260269

270+
_, activation, _ = ctx
271+
261272
activation.__exit__(None, None, None) # pylint: disable=E1101
262-
utils.detach_span(task, task_id, is_publish=True)
273+
utils.detach_context(task, task_id, is_publish=True)
263274

264275
@staticmethod
265276
def _trace_failure(*args, **kwargs):
@@ -269,9 +280,14 @@ def _trace_failure(*args, **kwargs):
269280
if task is None or task_id is None:
270281
return
271282

272-
# retrieve and pass exception info to activation
273-
span, _ = utils.retrieve_span(task, task_id)
274-
if span is None or not span.is_recording():
283+
ctx = utils.retrieve_context(task, task_id)
284+
285+
if ctx is None:
286+
return
287+
288+
span, _, _ = ctx
289+
290+
if not span.is_recording():
275291
return
276292

277293
status_kwargs = {"status_code": StatusCode.ERROR}
@@ -311,8 +327,14 @@ def _trace_retry(*args, **kwargs):
311327
if task is None or task_id is None or reason is None:
312328
return
313329

314-
span, _ = utils.retrieve_span(task, task_id)
315-
if span is None or not span.is_recording():
330+
ctx = utils.retrieve_context(task, task_id)
331+
332+
if ctx is None:
333+
return
334+
335+
span, _, _ = ctx
336+
337+
if not span.is_recording():
316338
return
317339

318340
# Add retry reason metadata to span

Diff for: instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/utils.py

+38-21
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@
1313
# limitations under the License.
1414

1515
import logging
16+
from typing import ContextManager, Optional, Tuple
1617

1718
from celery import registry # pylint: disable=no-name-in-module
19+
from celery.app.task import Task
1820

1921
from opentelemetry.semconv.trace import SpanAttributes
22+
from opentelemetry.trace import Span
2023

2124
logger = logging.getLogger(__name__)
2225

@@ -81,10 +84,12 @@ def set_attributes_from_context(span, context):
8184
elif key == "delivery_info":
8285
# Get also destination from this
8386
routing_key = value.get("routing_key")
87+
8488
if routing_key is not None:
8589
span.set_attribute(
8690
SpanAttributes.MESSAGING_DESTINATION, routing_key
8791
)
92+
8893
value = str(value)
8994

9095
elif key == "id":
@@ -114,11 +119,18 @@ def set_attributes_from_context(span, context):
114119
span.set_attribute(attribute_name, value)
115120

116121

117-
def attach_span(task, task_id, span, is_publish=False):
118-
"""Helper to propagate a `Span` for the given `Task` instance. This
119-
function uses a `dict` that stores the Span using the
120-
`(task_id, is_publish)` as a key. This is useful when information must be
121-
propagated from one Celery signal to another.
122+
def attach_context(
123+
task: Optional[Task],
124+
task_id: str,
125+
span: Span,
126+
activation: ContextManager[Span],
127+
token: Optional[object],
128+
is_publish: bool = False,
129+
) -> None:
130+
"""Helper to propagate a `Span`, `ContextManager` and context token
131+
for the given `Task` instance. This function uses a `dict` that stores
132+
the Span using the `(task_id, is_publish)` as a key. This is useful
133+
when information must be propagated from one Celery signal to another.
122134
123135
We use (task_id, is_publish) for the key to ensure that publishing a
124136
task from within another task does not cause any conflicts.
@@ -134,36 +146,41 @@ def attach_span(task, task_id, span, is_publish=False):
134146
"""
135147
if task is None:
136148
return
137-
span_dict = getattr(task, CTX_KEY, None)
138-
if span_dict is None:
139-
span_dict = {}
140-
setattr(task, CTX_KEY, span_dict)
141149

142-
span_dict[(task_id, is_publish)] = span
150+
ctx_dict = getattr(task, CTX_KEY, None)
151+
152+
if ctx_dict is None:
153+
ctx_dict = {}
154+
setattr(task, CTX_KEY, ctx_dict)
155+
156+
ctx_dict[(task_id, is_publish)] = (span, activation, token)
143157

144158

145-
def detach_span(task, task_id, is_publish=False):
146-
"""Helper to remove a `Span` in a Celery task when it's propagated.
147-
This function handles tasks where the `Span` is not attached.
159+
def detach_context(task, task_id, is_publish=False) -> None:
160+
"""Helper to remove `Span`, `ContextManager` and context token in a
161+
Celery task when it's propagated.
162+
This function handles tasks where no values are attached to the `Task`.
148163
"""
149164
span_dict = getattr(task, CTX_KEY, None)
150165
if span_dict is None:
151166
return
152167

153-
# See note in `attach_span` for key info
154-
span_dict.pop((task_id, is_publish), (None, None))
168+
# See note in `attach_context` for key info
169+
span_dict.pop((task_id, is_publish), None)
155170

156171

157-
def retrieve_span(task, task_id, is_publish=False):
158-
"""Helper to retrieve an active `Span` stored in a `Task`
159-
instance
172+
def retrieve_context(
173+
task, task_id, is_publish=False
174+
) -> Optional[Tuple[Span, ContextManager[Span], Optional[object]]]:
175+
"""Helper to retrieve an active `Span`, `ContextManager` and context token
176+
stored in a `Task` instance
160177
"""
161178
span_dict = getattr(task, CTX_KEY, None)
162179
if span_dict is None:
163-
return (None, None)
180+
return None
164181

165-
# See note in `attach_span` for key info
166-
return span_dict.get((task_id, is_publish), (None, None))
182+
# See note in `attach_context` for key info
183+
return span_dict.get((task_id, is_publish), None)
167184

168185

169186
def retrieve_task(kwargs):

Diff for: instrumentation/opentelemetry-instrumentation-celery/tests/celery_test_tasks.py

+7
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from celery import Celery
1616

17+
from opentelemetry import baggage
18+
1719

1820
class Config:
1921
result_backend = "rpc"
@@ -36,3 +38,8 @@ def task_add(num_a, num_b):
3638
@app.task
3739
def task_raises():
3840
raise CustomError("The task failed!")
41+
42+
43+
@app.task
44+
def task_returns_baggage():
45+
return dict(baggage.get_all())

Diff for: instrumentation/opentelemetry-instrumentation-celery/tests/test_tasks.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
import threading
1616
import time
1717

18+
from opentelemetry import baggage, context
1819
from opentelemetry.instrumentation.celery import CeleryInstrumentor
1920
from opentelemetry.semconv.trace import SpanAttributes
2021
from opentelemetry.test.test_base import TestBase
2122
from opentelemetry.trace import SpanKind, StatusCode
2223

23-
from .celery_test_tasks import app, task_add, task_raises
24+
from .celery_test_tasks import app, task_add, task_raises, task_returns_baggage
2425

2526

2627
class TestCeleryInstrumentation(TestBase):
@@ -168,6 +169,22 @@ def test_uninstrument(self):
168169
spans = self.memory_exporter.get_finished_spans()
169170
self.assertEqual(len(spans), 0)
170171

172+
def test_baggage(self):
173+
CeleryInstrumentor().instrument()
174+
175+
ctx = baggage.set_baggage("key", "value")
176+
context.attach(ctx)
177+
178+
task = task_returns_baggage.delay()
179+
180+
timeout = time.time() + 60 * 1 # 1 minutes from now
181+
while not task.ready():
182+
if time.time() > timeout:
183+
break
184+
time.sleep(0.05)
185+
186+
self.assertEqual(task.result, {"key": "value"})
187+
171188

172189
class TestCelerySignatureTask(TestBase):
173190
def setUp(self):

Diff for: instrumentation/opentelemetry-instrumentation-celery/tests/test_utils.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,10 @@ def fn_task():
167167
# propagate and retrieve a Span
168168
task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f"
169169
span = trace._Span("name", mock.Mock(spec=trace_api.SpanContext))
170-
utils.attach_span(fn_task, task_id, span)
171-
span_after = utils.retrieve_span(fn_task, task_id)
170+
utils.attach_context(fn_task, task_id, span, mock.Mock(), "")
171+
ctx = utils.retrieve_context(fn_task, task_id)
172+
self.assertIsNotNone(ctx)
173+
span_after, _, _ = ctx
172174
self.assertIs(span, span_after)
173175

174176
def test_span_delete(self):
@@ -180,17 +182,19 @@ def fn_task():
180182
# propagate a Span
181183
task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f"
182184
span = trace._Span("name", mock.Mock(spec=trace_api.SpanContext))
183-
utils.attach_span(fn_task, task_id, span)
185+
utils.attach_context(fn_task, task_id, span, mock.Mock(), "")
184186
# delete the Span
185-
utils.detach_span(fn_task, task_id)
186-
self.assertEqual(utils.retrieve_span(fn_task, task_id), (None, None))
187+
utils.detach_context(fn_task, task_id)
188+
self.assertEqual(utils.retrieve_context(fn_task, task_id), None)
187189

188190
def test_optional_task_span_attach(self):
189191
task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f"
190192
span = trace._Span("name", mock.Mock(spec=trace_api.SpanContext))
191193

192194
# assert this is is a no-aop
193-
self.assertIsNone(utils.attach_span(None, task_id, span))
195+
self.assertIsNone(
196+
utils.attach_context(None, task_id, span, mock.Mock(), "")
197+
)
194198

195199
def test_span_delete_empty(self):
196200
# ensure detach_span doesn't raise an exception if span is not present
@@ -201,10 +205,8 @@ def fn_task():
201205
# delete the Span
202206
task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f"
203207
try:
204-
utils.detach_span(fn_task, task_id)
205-
self.assertEqual(
206-
utils.retrieve_span(fn_task, task_id), (None, None)
207-
)
208+
utils.detach_context(fn_task, task_id)
209+
self.assertEqual(utils.retrieve_context(fn_task, task_id), None)
208210
except Exception as ex: # pylint: disable=broad-except
209211
self.fail(f"Exception was raised: {ex}")
210212

0 commit comments

Comments
 (0)