@@ -233,6 +233,38 @@ def _prepare_headers(headers: typing.Optional[Headers]) -> httpx.Headers:
233
233
return httpx .Headers (headers )
234
234
235
235
236
+ def _extract_parameters (args , kwargs ):
237
+ if isinstance (args [0 ], httpx .Request ):
238
+ # In httpx >= 0.20.0, handle_request receives a Request object
239
+ request : httpx .Request = args [0 ]
240
+ method = request .method .encode ()
241
+ url = request .url
242
+ headers = request .headers
243
+ stream = request .stream
244
+ extensions = request .extensions
245
+ else :
246
+ # In httpx < 0.20.0, handle_request receives the parameters separately
247
+ method = args [0 ]
248
+ url = args [1 ]
249
+ headers = kwargs .get ("headers" , args [2 ] if len (args ) > 2 else None )
250
+ stream = kwargs .get ("stream" , args [3 ] if len (args ) > 3 else None )
251
+ extensions = kwargs .get (
252
+ "extensions" , args [4 ] if len (args ) > 4 else None
253
+ )
254
+
255
+ return method , url , headers , stream , extensions
256
+
257
+
258
+ def _inject_propagation_headers (headers , args , kwargs ):
259
+ _headers = _prepare_headers (headers )
260
+ inject (_headers )
261
+ if isinstance (args [0 ], httpx .Request ):
262
+ request : httpx .Request = args [0 ]
263
+ request .headers = _headers
264
+ else :
265
+ kwargs ["headers" ] = _headers .raw
266
+
267
+
236
268
class SyncOpenTelemetryTransport (httpx .BaseTransport ):
237
269
"""Sync transport class that will trace all requests made with a client.
238
270
@@ -263,60 +295,53 @@ def __init__(
263
295
264
296
def handle_request (
265
297
self ,
266
- method : bytes ,
267
- url : URL ,
268
- headers : typing .Optional [ Headers ] = None ,
269
- stream : typing .Optional [ httpx .SyncByteStream ] = None ,
270
- extensions : typing . Optional [ dict ] = None ,
271
- ) -> typing . Tuple [ int , "Headers" , httpx . SyncByteStream , dict ]:
298
+ * args ,
299
+ ** kwargs ,
300
+ ) -> typing .Union [
301
+ typing .Tuple [ int , "Headers" , httpx .SyncByteStream , dict ] ,
302
+ httpx . Response ,
303
+ ]:
272
304
"""Add request info to span."""
273
305
if context .get_value ("suppress_instrumentation" ):
274
- return self ._transport .handle_request (
275
- method ,
276
- url ,
277
- headers = headers ,
278
- stream = stream ,
279
- extensions = extensions ,
280
- )
306
+ return self ._transport .handle_request (* args , ** kwargs )
281
307
308
+ method , url , headers , stream , extensions = _extract_parameters (
309
+ args , kwargs
310
+ )
282
311
span_attributes = _prepare_attributes (method , url )
283
- _headers = _prepare_headers (headers )
312
+
313
+ request_info = RequestInfo (method , url , headers , stream , extensions )
284
314
span_name = _get_default_span_name (
285
315
span_attributes [SpanAttributes .HTTP_METHOD ]
286
316
)
287
- request = RequestInfo (method , url , headers , stream , extensions )
288
317
289
318
with self ._tracer .start_as_current_span (
290
319
span_name , kind = SpanKind .CLIENT , attributes = span_attributes
291
320
) as span :
292
321
if self ._request_hook is not None :
293
- self ._request_hook (span , request )
294
-
295
- inject (_headers )
296
-
297
- (
298
- status_code ,
299
- headers ,
300
- stream ,
301
- extensions ,
302
- ) = self ._transport .handle_request (
303
- method ,
304
- url ,
305
- headers = _headers .raw ,
306
- stream = stream ,
307
- extensions = extensions ,
308
- )
322
+ self ._request_hook (span , request_info )
323
+
324
+ _inject_propagation_headers (headers , args , kwargs )
325
+ response = self ._transport .handle_request (* args , ** kwargs )
326
+ if isinstance (response , httpx .Response ):
327
+ response : httpx .Response = response
328
+ status_code = response .status_code
329
+ headers = response .headers
330
+ stream = response .stream
331
+ extensions = response .extensions
332
+ else :
333
+ status_code , headers , stream , extensions = response
309
334
310
335
_apply_status_code (span , status_code )
311
336
312
337
if self ._response_hook is not None :
313
338
self ._response_hook (
314
339
span ,
315
- request ,
340
+ request_info ,
316
341
ResponseInfo (status_code , headers , stream , extensions ),
317
342
)
318
343
319
- return status_code , headers , stream , extensions
344
+ return response
320
345
321
346
322
347
class AsyncOpenTelemetryTransport (httpx .AsyncBaseTransport ):
@@ -348,61 +373,55 @@ def __init__(
348
373
self ._response_hook = response_hook
349
374
350
375
async def handle_async_request (
351
- self ,
352
- method : bytes ,
353
- url : URL ,
354
- headers : typing .Optional [Headers ] = None ,
355
- stream : typing .Optional [httpx .AsyncByteStream ] = None ,
356
- extensions : typing .Optional [dict ] = None ,
357
- ) -> typing .Tuple [int , "Headers" , httpx .AsyncByteStream , dict ]:
376
+ self , * args , ** kwargs
377
+ ) -> typing .Union [
378
+ typing .Tuple [int , "Headers" , httpx .AsyncByteStream , dict ],
379
+ httpx .Response ,
380
+ ]:
358
381
"""Add request info to span."""
359
382
if context .get_value ("suppress_instrumentation" ):
360
- return await self ._transport .handle_async_request (
361
- method ,
362
- url ,
363
- headers = headers ,
364
- stream = stream ,
365
- extensions = extensions ,
366
- )
383
+ return await self ._transport .handle_async_request (* args , ** kwargs )
367
384
385
+ method , url , headers , stream , extensions = _extract_parameters (
386
+ args , kwargs
387
+ )
368
388
span_attributes = _prepare_attributes (method , url )
369
- _headers = _prepare_headers ( headers )
389
+
370
390
span_name = _get_default_span_name (
371
391
span_attributes [SpanAttributes .HTTP_METHOD ]
372
392
)
373
- request = RequestInfo (method , url , headers , stream , extensions )
393
+ request_info = RequestInfo (method , url , headers , stream , extensions )
374
394
375
395
with self ._tracer .start_as_current_span (
376
396
span_name , kind = SpanKind .CLIENT , attributes = span_attributes
377
397
) as span :
378
398
if self ._request_hook is not None :
379
- await self ._request_hook (span , request )
380
-
381
- inject (_headers )
382
-
383
- (
384
- status_code ,
385
- headers ,
386
- stream ,
387
- extensions ,
388
- ) = await self ._transport .handle_async_request (
389
- method ,
390
- url ,
391
- headers = _headers .raw ,
392
- stream = stream ,
393
- extensions = extensions ,
399
+ await self ._request_hook (span , request_info )
400
+
401
+ _inject_propagation_headers (headers , args , kwargs )
402
+
403
+ response = await self ._transport .handle_async_request (
404
+ * args , ** kwargs
394
405
)
406
+ if isinstance (response , httpx .Response ):
407
+ response : httpx .Response = response
408
+ status_code = response .status_code
409
+ headers = response .headers
410
+ stream = response .stream
411
+ extensions = response .extensions
412
+ else :
413
+ status_code , headers , stream , extensions = response
395
414
396
415
_apply_status_code (span , status_code )
397
416
398
417
if self ._response_hook is not None :
399
418
await self ._response_hook (
400
419
span ,
401
- request ,
420
+ request_info ,
402
421
ResponseInfo (status_code , headers , stream , extensions ),
403
422
)
404
423
405
- return status_code , headers , stream , extensions
424
+ return response
406
425
407
426
408
427
class _InstrumentedClient (httpx .Client ):
0 commit comments