@@ -780,9 +780,15 @@ def test_custom_tracer_provider(self):
780
780
HTTPXClientInstrumentor ().uninstrument ()
781
781
782
782
def test_response_hook (self ):
783
+ response_hook_key = (
784
+ "async_response_hook"
785
+ if asyncio .iscoroutinefunction (self .response_hook )
786
+ else "response_hook"
787
+ )
788
+ response_hook_kwargs = {response_hook_key : self .response_hook }
783
789
HTTPXClientInstrumentor ().instrument (
784
790
tracer_provider = self .tracer_provider ,
785
- response_hook = self . response_hook ,
791
+ ** response_hook_kwargs ,
786
792
)
787
793
client = self .create_client ()
788
794
result = self .perform_request (self .URL , client = client )
@@ -823,9 +829,15 @@ def test_response_hook_sync_async_kwargs(self):
823
829
HTTPXClientInstrumentor ().uninstrument ()
824
830
825
831
def test_request_hook (self ):
832
+ request_hook_key = (
833
+ "async_request_hook"
834
+ if asyncio .iscoroutinefunction (self .request_hook )
835
+ else "request_hook"
836
+ )
837
+ request_hook_kwargs = {request_hook_key : self .request_hook }
826
838
HTTPXClientInstrumentor ().instrument (
827
839
tracer_provider = self .tracer_provider ,
828
- request_hook = self . request_hook ,
840
+ ** request_hook_kwargs ,
829
841
)
830
842
client = self .create_client ()
831
843
result = self .perform_request (self .URL , client = client )
@@ -1214,3 +1226,36 @@ def test_basic_multiple(self):
1214
1226
self .perform_request (self .URL , client = self .client )
1215
1227
self .perform_request (self .URL , client = self .client2 )
1216
1228
self .assert_span (num_spans = 2 )
1229
+
1230
+ def test_async_response_hook_does_nothing_if_not_coroutine (self ):
1231
+ HTTPXClientInstrumentor ().instrument (
1232
+ tracer_provider = self .tracer_provider ,
1233
+ async_response_hook = _response_hook ,
1234
+ )
1235
+ client = self .create_client ()
1236
+ result = self .perform_request (self .URL , client = client )
1237
+
1238
+ self .assertEqual (result .text , "Hello!" )
1239
+ span = self .assert_span ()
1240
+ self .assertEqual (
1241
+ dict (span .attributes ),
1242
+ {
1243
+ SpanAttributes .HTTP_METHOD : "GET" ,
1244
+ SpanAttributes .HTTP_URL : self .URL ,
1245
+ SpanAttributes .HTTP_STATUS_CODE : 200 ,
1246
+ },
1247
+ )
1248
+ HTTPXClientInstrumentor ().uninstrument ()
1249
+
1250
+ def test_async_request_hook_does_nothing_if_not_coroutine (self ):
1251
+ HTTPXClientInstrumentor ().instrument (
1252
+ tracer_provider = self .tracer_provider ,
1253
+ async_request_hook = _request_hook ,
1254
+ )
1255
+ client = self .create_client ()
1256
+ result = self .perform_request (self .URL , client = client )
1257
+
1258
+ self .assertEqual (result .text , "Hello!" )
1259
+ span = self .assert_span ()
1260
+ self .assertEqual (span .name , "GET" )
1261
+ HTTPXClientInstrumentor ().uninstrument ()
0 commit comments