Skip to content

Commit 62d6737

Browse files
committed
httpx: fix handling of async hooks
Don't default to sync hooks for async hooks, then check that they are actually async. Fixes open-telemetry#2734
1 parent bc4d2c5 commit 62d6737

File tree

2 files changed

+48
-6
lines changed

2 files changed

+48
-6
lines changed

Diff for: instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ async def async_response_hook(span, request, response):
192192
"""
193193
import logging
194194
import typing
195+
from asyncio import iscoroutinefunction
195196
from types import TracebackType
196197

197198
import httpx
@@ -731,15 +732,15 @@ def _instrument(self, **kwargs):
731732
self._original_async_client = httpx.AsyncClient
732733
request_hook = kwargs.get("request_hook")
733734
response_hook = kwargs.get("response_hook")
734-
async_request_hook = kwargs.get("async_request_hook", request_hook)
735-
async_response_hook = kwargs.get("async_response_hook", response_hook)
735+
async_request_hook = kwargs.get("async_request_hook")
736+
async_response_hook = kwargs.get("async_response_hook")
736737
if callable(request_hook):
737738
_InstrumentedClient._request_hook = request_hook
738-
if callable(async_request_hook):
739+
if callable(async_request_hook) and iscoroutinefunction(async_request_hook):
739740
_InstrumentedAsyncClient._request_hook = async_request_hook
740741
if callable(response_hook):
741742
_InstrumentedClient._response_hook = response_hook
742-
if callable(async_response_hook):
743+
if callable(async_response_hook) and iscoroutinefunction(async_response_hook):
743744
_InstrumentedAsyncClient._response_hook = async_response_hook
744745
tracer_provider = kwargs.get("tracer_provider")
745746
_InstrumentedClient._tracer_provider = tracer_provider

Diff for: instrumentation/opentelemetry-instrumentation-httpx/tests/test_httpx_integration.py

+43-2
Original file line numberDiff line numberDiff line change
@@ -780,9 +780,13 @@ def test_custom_tracer_provider(self):
780780
HTTPXClientInstrumentor().uninstrument()
781781

782782
def test_response_hook(self):
783+
response_hook_key = "async_response_hook" if asyncio.iscoroutinefunction(self.response_hook) else "response_hook"
784+
response_hook_kwargs = {
785+
response_hook_key: self.response_hook
786+
}
783787
HTTPXClientInstrumentor().instrument(
784788
tracer_provider=self.tracer_provider,
785-
response_hook=self.response_hook,
789+
**response_hook_kwargs
786790
)
787791
client = self.create_client()
788792
result = self.perform_request(self.URL, client=client)
@@ -823,9 +827,13 @@ def test_response_hook_sync_async_kwargs(self):
823827
HTTPXClientInstrumentor().uninstrument()
824828

825829
def test_request_hook(self):
830+
request_hook_key = "async_request_hook" if asyncio.iscoroutinefunction(self.request_hook) else "request_hook"
831+
request_hook_kwargs = {
832+
request_hook_key: self.request_hook
833+
}
826834
HTTPXClientInstrumentor().instrument(
827835
tracer_provider=self.tracer_provider,
828-
request_hook=self.request_hook,
836+
**request_hook_kwargs,
829837
)
830838
client = self.create_client()
831839
result = self.perform_request(self.URL, client=client)
@@ -1214,3 +1222,36 @@ def test_basic_multiple(self):
12141222
self.perform_request(self.URL, client=self.client)
12151223
self.perform_request(self.URL, client=self.client2)
12161224
self.assert_span(num_spans=2)
1225+
1226+
def test_async_response_hook_does_nothing_if_not_coroutine(self):
1227+
HTTPXClientInstrumentor().instrument(
1228+
tracer_provider=self.tracer_provider,
1229+
async_response_hook=_response_hook,
1230+
)
1231+
client = self.create_client()
1232+
result = self.perform_request(self.URL, client=client)
1233+
1234+
self.assertEqual(result.text, "Hello!")
1235+
span = self.assert_span()
1236+
self.assertEqual(
1237+
dict(span.attributes),
1238+
{
1239+
SpanAttributes.HTTP_METHOD: "GET",
1240+
SpanAttributes.HTTP_URL: self.URL,
1241+
SpanAttributes.HTTP_STATUS_CODE: 200,
1242+
},
1243+
)
1244+
HTTPXClientInstrumentor().uninstrument()
1245+
1246+
def test_async_request_hook_does_nothing_if_not_coroutine(self):
1247+
HTTPXClientInstrumentor().instrument(
1248+
tracer_provider=self.tracer_provider,
1249+
async_request_hook=_request_hook,
1250+
)
1251+
client = self.create_client()
1252+
result = self.perform_request(self.URL, client=client)
1253+
1254+
self.assertEqual(result.text, "Hello!")
1255+
span = self.assert_span()
1256+
self.assertEqual(span.name, "GET")
1257+
HTTPXClientInstrumentor().uninstrument()

0 commit comments

Comments
 (0)