@@ -67,6 +67,7 @@ def add(x, y):
67
67
from billiard .einfo import ExceptionInfo
68
68
from celery import signals # pylint: disable=no-name-in-module
69
69
70
+ from opentelemetry import context as context_api
70
71
from opentelemetry import trace
71
72
from opentelemetry .instrumentation .celery import utils
72
73
from opentelemetry .instrumentation .celery .package import _instruments
@@ -169,6 +170,7 @@ def _trace_prerun(self, *args, **kwargs):
169
170
self .update_task_duration_time (task_id )
170
171
request = task .request
171
172
tracectx = extract (request , getter = celery_getter ) or None
173
+ token = context_api .attach (tracectx ) if tracectx is not None else None
172
174
173
175
logger .debug ("prerun signal start task_id=%s" , task_id )
174
176
@@ -179,7 +181,7 @@ def _trace_prerun(self, *args, **kwargs):
179
181
180
182
activation = trace .use_span (span , end_on_exit = True )
181
183
activation .__enter__ () # pylint: disable=E1101
182
- utils .attach_span (task , task_id , ( span , activation ) )
184
+ utils .attach_context (task , task_id , span , activation , token )
183
185
184
186
def _trace_postrun (self , * args , ** kwargs ):
185
187
task = utils .retrieve_task (kwargs )
@@ -191,11 +193,14 @@ def _trace_postrun(self, *args, **kwargs):
191
193
logger .debug ("postrun signal task_id=%s" , task_id )
192
194
193
195
# retrieve and finish the Span
194
- span , activation = utils .retrieve_span (task , task_id )
195
- if span is None :
196
+ ctx = utils .retrieve_context (task , task_id )
197
+
198
+ if ctx is None :
196
199
logger .warning ("no existing span found for task_id=%s" , task_id )
197
200
return
198
201
202
+ span , activation , token = ctx
203
+
199
204
# request context tags
200
205
if span .is_recording ():
201
206
span .set_attribute (_TASK_TAG_KEY , _TASK_RUN )
@@ -204,10 +209,11 @@ def _trace_postrun(self, *args, **kwargs):
204
209
span .set_attribute (_TASK_NAME_KEY , task .name )
205
210
206
211
activation .__exit__ (None , None , None )
207
- utils .detach_span (task , task_id )
212
+ utils .detach_context (task , task_id )
208
213
self .update_task_duration_time (task_id )
209
214
labels = {"task" : task .name , "worker" : task .request .hostname }
210
215
self ._record_histograms (task_id , labels )
216
+ context_api .detach (token )
211
217
212
218
def _trace_before_publish (self , * args , ** kwargs ):
213
219
task = utils .retrieve_task_from_sender (kwargs )
@@ -238,7 +244,9 @@ def _trace_before_publish(self, *args, **kwargs):
238
244
activation = trace .use_span (span , end_on_exit = True )
239
245
activation .__enter__ () # pylint: disable=E1101
240
246
241
- utils .attach_span (task , task_id , (span , activation ), is_publish = True )
247
+ utils .attach_context (
248
+ task , task_id , span , activation , None , is_publish = True
249
+ )
242
250
243
251
headers = kwargs .get ("headers" )
244
252
if headers :
@@ -253,13 +261,16 @@ def _trace_after_publish(*args, **kwargs):
253
261
return
254
262
255
263
# retrieve and finish the Span
256
- _ , activation = utils .retrieve_span (task , task_id , is_publish = True )
257
- if activation is None :
264
+ ctx = utils .retrieve_context (task , task_id , is_publish = True )
265
+
266
+ if ctx is None :
258
267
logger .warning ("no existing span found for task_id=%s" , task_id )
259
268
return
260
269
270
+ _ , activation , _ = ctx
271
+
261
272
activation .__exit__ (None , None , None ) # pylint: disable=E1101
262
- utils .detach_span (task , task_id , is_publish = True )
273
+ utils .detach_context (task , task_id , is_publish = True )
263
274
264
275
@staticmethod
265
276
def _trace_failure (* args , ** kwargs ):
@@ -269,9 +280,14 @@ def _trace_failure(*args, **kwargs):
269
280
if task is None or task_id is None :
270
281
return
271
282
272
- # retrieve and pass exception info to activation
273
- span , _ = utils .retrieve_span (task , task_id )
274
- if span is None or not span .is_recording ():
283
+ ctx = utils .retrieve_context (task , task_id )
284
+
285
+ if ctx is None :
286
+ return
287
+
288
+ span , _ , _ = ctx
289
+
290
+ if not span .is_recording ():
275
291
return
276
292
277
293
status_kwargs = {"status_code" : StatusCode .ERROR }
@@ -311,8 +327,14 @@ def _trace_retry(*args, **kwargs):
311
327
if task is None or task_id is None or reason is None :
312
328
return
313
329
314
- span , _ = utils .retrieve_span (task , task_id )
315
- if span is None or not span .is_recording ():
330
+ ctx = utils .retrieve_context (task , task_id )
331
+
332
+ if ctx is None :
333
+ return
334
+
335
+ span , _ , _ = ctx
336
+
337
+ if not span .is_recording ():
316
338
return
317
339
318
340
# Add retry reason metadata to span
0 commit comments