@@ -85,7 +85,7 @@ def get_default_span_name():
85
85
return span_name
86
86
87
87
88
- def _rewrapped_app (wsgi_app ):
88
+ def _rewrapped_app (wsgi_app , response_hook = None ):
89
89
def _wrapped_app (wrapped_app_environ , start_response ):
90
90
# We want to measure the time for route matching, etc.
91
91
# In theory, we could start the span here and use
@@ -114,21 +114,21 @@ def _start_response(status, response_headers, *args, **kwargs):
114
114
"missing at _start_response(%s)" ,
115
115
status ,
116
116
)
117
-
117
+ if response_hook is not None :
118
+ response_hook (span , status , response_headers )
118
119
return start_response (status , response_headers , * args , ** kwargs )
119
120
120
121
return wsgi_app (wrapped_app_environ , _start_response )
121
122
122
123
return _wrapped_app
123
124
124
125
125
- def _wrapped_before_request (name_callback , tracer ):
126
+ def _wrapped_before_request (request_hook = None , tracer = None ):
126
127
def _before_request ():
127
128
if _excluded_urls .url_disabled (flask .request .url ):
128
129
return
129
-
130
130
flask_request_environ = flask .request .environ
131
- span_name = name_callback ()
131
+ span_name = get_default_span_name ()
132
132
token = context .attach (
133
133
extract (flask_request_environ , getter = otel_wsgi .wsgi_getter )
134
134
)
@@ -138,6 +138,9 @@ def _before_request():
138
138
kind = trace .SpanKind .SERVER ,
139
139
start_time = flask_request_environ .get (_ENVIRON_STARTTIME_KEY ),
140
140
)
141
+ if request_hook :
142
+ request_hook (span , flask_request_environ )
143
+
141
144
if span .is_recording ():
142
145
attributes = otel_wsgi .collect_request_attributes (
143
146
flask_request_environ
@@ -183,21 +186,25 @@ def _teardown_request(exc):
183
186
184
187
class _InstrumentedFlask (flask .Flask ):
185
188
186
- name_callback = get_default_span_name
187
189
_tracer_provider = None
190
+ _request_hook = None
191
+ _response_hook = None
188
192
189
193
def __init__ (self , * args , ** kwargs ):
190
194
super ().__init__ (* args , ** kwargs )
191
195
192
196
self ._original_wsgi_ = self .wsgi_app
193
- self .wsgi_app = _rewrapped_app (self .wsgi_app )
197
+
198
+ self .wsgi_app = _rewrapped_app (
199
+ self .wsgi_app , _InstrumentedFlask ._response_hook
200
+ )
194
201
195
202
tracer = trace .get_tracer (
196
203
__name__ , __version__ , _InstrumentedFlask ._tracer_provider
197
204
)
198
205
199
206
_before_request = _wrapped_before_request (
200
- _InstrumentedFlask .name_callback , tracer ,
207
+ _InstrumentedFlask ._request_hook , tracer ,
201
208
)
202
209
self ._before_request = _before_request
203
210
self .before_request (_before_request )
@@ -216,26 +223,30 @@ def instrumentation_dependencies(self) -> Collection[str]:
216
223
217
224
def _instrument (self , ** kwargs ):
218
225
self ._original_flask = flask .Flask
219
- name_callback = kwargs .get ("name_callback" )
226
+ request_hook = kwargs .get ("request_hook" )
227
+ response_hook = kwargs .get ("response_hook" )
228
+ if callable (request_hook ):
229
+ _InstrumentedFlask ._request_hook = request_hook
230
+ if callable (response_hook ):
231
+ _InstrumentedFlask ._response_hook = response_hook
232
+ flask .Flask = _InstrumentedFlask
220
233
tracer_provider = kwargs .get ("tracer_provider" )
221
- if callable (name_callback ):
222
- _InstrumentedFlask .name_callback = name_callback
223
234
_InstrumentedFlask ._tracer_provider = tracer_provider
224
235
flask .Flask = _InstrumentedFlask
225
236
226
237
def instrument_app (
227
- self , app , name_callback = get_default_span_name , tracer_provider = None
238
+ self , app , request_hook = None , response_hook = None , tracer_provider = None
228
239
): # pylint: disable=no-self-use
229
240
if not hasattr (app , "_is_instrumented" ):
230
241
app ._is_instrumented = False
231
242
232
243
if not app ._is_instrumented :
233
244
app ._original_wsgi_app = app .wsgi_app
234
- app .wsgi_app = _rewrapped_app (app .wsgi_app )
245
+ app .wsgi_app = _rewrapped_app (app .wsgi_app , response_hook )
235
246
236
247
tracer = trace .get_tracer (__name__ , __version__ , tracer_provider )
237
248
238
- _before_request = _wrapped_before_request (name_callback , tracer )
249
+ _before_request = _wrapped_before_request (request_hook , tracer )
239
250
app ._before_request = _before_request
240
251
app .before_request (_before_request )
241
252
app .teardown_request (_teardown_request )
0 commit comments