54
54
from typing import Callable , Collection , Iterable , Optional
55
55
from urllib .parse import urlparse
56
56
57
- from requests .models import Response
57
+ from requests .models import PreparedRequest , Response
58
58
from requests .sessions import Session
59
59
from requests .structures import CaseInsensitiveDict
60
60
85
85
86
86
_excluded_urls_from_env = get_excluded_urls ("REQUESTS" )
87
87
88
+ _RequestHookT = Optional [Callable [[Span , PreparedRequest ], None ]]
89
+ _ResponseHookT = Optional [Callable [[Span , PreparedRequest ], None ]]
90
+
88
91
89
92
# pylint: disable=unused-argument
90
93
# pylint: disable=R0915
91
94
def _instrument (
92
95
tracer : Tracer ,
93
96
duration_histogram : Histogram ,
94
- span_callback : Optional [ Callable [[ Span , Response ], str ]] = None ,
95
- name_callback : Optional [ Callable [[ str , str ], str ]] = None ,
97
+ request_hook : _RequestHookT = None ,
98
+ response_hook : _ResponseHookT = None ,
96
99
excluded_urls : Iterable [str ] = None ,
97
100
):
98
101
"""Enables tracing of all requests calls that go through
@@ -106,29 +109,9 @@ def _instrument(
106
109
# before v1.0.0, Dec 17, 2012, see
107
110
# https://github.com/psf/requests/commit/4e5c4a6ab7bb0195dececdd19bb8505b872fe120)
108
111
109
- wrapped_request = Session .request
110
112
wrapped_send = Session .send
111
113
112
- @functools .wraps (wrapped_request )
113
- def instrumented_request (self , method , url , * args , ** kwargs ):
114
- if excluded_urls and excluded_urls .url_disabled (url ):
115
- return wrapped_request (self , method , url , * args , ** kwargs )
116
-
117
- def get_or_create_headers ():
118
- headers = kwargs .get ("headers" )
119
- if headers is None :
120
- headers = {}
121
- kwargs ["headers" ] = headers
122
-
123
- return headers
124
-
125
- def call_wrapped ():
126
- return wrapped_request (self , method , url , * args , ** kwargs )
127
-
128
- return _instrumented_requests_call (
129
- method , url , call_wrapped , get_or_create_headers
130
- )
131
-
114
+ # pylint: disable-msg=too-many-locals,too-many-branches
132
115
@functools .wraps (wrapped_send )
133
116
def instrumented_send (self , request , ** kwargs ):
134
117
if excluded_urls and excluded_urls .url_disabled (request .url ):
@@ -142,32 +125,17 @@ def get_or_create_headers():
142
125
)
143
126
return request .headers
144
127
145
- def call_wrapped ():
146
- return wrapped_send (self , request , ** kwargs )
147
-
148
- return _instrumented_requests_call (
149
- request .method , request .url , call_wrapped , get_or_create_headers
150
- )
151
-
152
- # pylint: disable-msg=too-many-locals,too-many-branches
153
- def _instrumented_requests_call (
154
- method : str , url : str , call_wrapped , get_or_create_headers
155
- ):
156
128
if context .get_value (
157
129
_SUPPRESS_INSTRUMENTATION_KEY
158
130
) or context .get_value (_SUPPRESS_HTTP_INSTRUMENTATION_KEY ):
159
- return call_wrapped ( )
131
+ return wrapped_send ( self , request , ** kwargs )
160
132
161
133
# See
162
134
# https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/http.md#http-client
163
- method = method .upper ()
164
- span_name = ""
165
- if name_callback is not None :
166
- span_name = name_callback (method , url )
167
- if not span_name or not isinstance (span_name , str ):
168
- span_name = get_default_span_name (method )
135
+ method = request .method .upper ()
136
+ span_name = get_default_span_name (method )
169
137
170
- url = remove_url_credentials (url )
138
+ url = remove_url_credentials (request . url )
171
139
172
140
span_attributes = {
173
141
SpanAttributes .HTTP_METHOD : method ,
@@ -195,6 +163,8 @@ def _instrumented_requests_call(
195
163
span_name , kind = SpanKind .CLIENT , attributes = span_attributes
196
164
) as span , set_ip_on_next_http_connection (span ):
197
165
exception = None
166
+ if callable (request_hook ):
167
+ request_hook (span , request )
198
168
199
169
headers = get_or_create_headers ()
200
170
inject (headers )
@@ -206,7 +176,7 @@ def _instrumented_requests_call(
206
176
start_time = default_timer ()
207
177
208
178
try :
209
- result = call_wrapped ( ) # *** PROCEED
179
+ result = wrapped_send ( self , request , ** kwargs ) # *** PROCEED
210
180
except Exception as exc : # pylint: disable=W0703
211
181
exception = exc
212
182
result = getattr (exc , "response" , None )
@@ -236,8 +206,8 @@ def _instrumented_requests_call(
236
206
"1.1" if version == 11 else "1.0"
237
207
)
238
208
239
- if span_callback is not None :
240
- span_callback (span , result )
209
+ if callable ( response_hook ) :
210
+ response_hook (span , request , result )
241
211
242
212
duration_histogram .record (elapsed_time , attributes = metric_labels )
243
213
@@ -246,9 +216,6 @@ def _instrumented_requests_call(
246
216
247
217
return result
248
218
249
- instrumented_request .opentelemetry_instrumentation_requests_applied = True
250
- Session .request = instrumented_request
251
-
252
219
instrumented_send .opentelemetry_instrumentation_requests_applied = True
253
220
Session .send = instrumented_send
254
221
@@ -295,10 +262,8 @@ def _instrument(self, **kwargs):
295
262
Args:
296
263
**kwargs: Optional arguments
297
264
``tracer_provider``: a TracerProvider, defaults to global
298
- ``span_callback``: An optional callback invoked before returning the http response. Invoked with Span and requests.Response
299
- ``name_callback``: Callback which calculates a generic span name for an
300
- outgoing HTTP request based on the method and url.
301
- Optional: Defaults to get_default_span_name.
265
+ ``request_hook``: An optional callback that is invoked right after a span is created.
266
+ ``response_hook``: An optional callback which is invoked right before the span is finished processing a response.
302
267
``excluded_urls``: A string containing a comma-delimited
303
268
list of regexes used to exclude URLs from tracking
304
269
"""
@@ -319,8 +284,8 @@ def _instrument(self, **kwargs):
319
284
_instrument (
320
285
tracer ,
321
286
duration_histogram ,
322
- span_callback = kwargs .get ("span_callback " ),
323
- name_callback = kwargs .get ("name_callback " ),
287
+ request_hook = kwargs .get ("request_hook " ),
288
+ response_hook = kwargs .get ("response_hook " ),
324
289
excluded_urls = _excluded_urls_from_env
325
290
if excluded_urls is None
326
291
else parse_excluded_urls (excluded_urls ),
0 commit comments