diff --git a/docs/getting_started/tests/test_flask.py b/docs/getting_started/tests/test_flask.py index 321098ce97e..6ba296a5c03 100644 --- a/docs/getting_started/tests/test_flask.py +++ b/docs/getting_started/tests/test_flask.py @@ -38,6 +38,6 @@ def test_flask(self): server.terminate() output = str(server.stdout.read()) - self.assertIn('"name": "HTTP get"', output) + self.assertIn('"name": "HTTP GET"', output) self.assertIn('"name": "example-request"', output) self.assertIn('"name": "hello"', output) diff --git a/instrumentation/opentelemetry-instrumentation-requests/CHANGELOG.md b/instrumentation/opentelemetry-instrumentation-requests/CHANGELOG.md index 3f18f6101bc..eeb2e837252 100644 --- a/instrumentation/opentelemetry-instrumentation-requests/CHANGELOG.md +++ b/instrumentation/opentelemetry-instrumentation-requests/CHANGELOG.md @@ -2,6 +2,9 @@ ## Unreleased +- Add support for instrumenting prepared requests + ([#1040](https://github.com/open-telemetry/opentelemetry-python/pull/1040)) + ## Version 0.12b0 Released 2020-08-14 diff --git a/instrumentation/opentelemetry-instrumentation-requests/src/opentelemetry/instrumentation/requests/__init__.py b/instrumentation/opentelemetry-instrumentation-requests/src/opentelemetry/instrumentation/requests/__init__.py index e2c54b7f1b0..16e8952fea4 100644 --- a/instrumentation/opentelemetry-instrumentation-requests/src/opentelemetry/instrumentation/requests/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-requests/src/opentelemetry/instrumentation/requests/__init__.py @@ -29,26 +29,17 @@ opentelemetry.instrumentation.requests.RequestsInstrumentor().instrument() response = requests.get(url="https://www.example.org/") -Limitations ------------ - -Note that calls that do not use the higher-level APIs but use -:code:`requests.sessions.Session.send` (or an alias thereof) directly, are -currently not traced. If you find any other way to trigger an untraced HTTP -request, please report it via a GitHub issue with :code:`[requests: untraced -API]` in the title. - API --- """ import functools import types -from urllib.parse import urlparse from requests import Timeout, URLRequired from requests.exceptions import InvalidSchema, InvalidURL, MissingSchema from requests.sessions import Session +from requests.structures import CaseInsensitiveDict from opentelemetry import context, propagators from opentelemetry.instrumentation.instrumentor import BaseInstrumentor @@ -57,6 +48,10 @@ from opentelemetry.trace import SpanKind, get_tracer from opentelemetry.trace.status import Status, StatusCanonicalCode +# A key to a context variable to avoid creating duplicate spans when instrumenting +# both, Session.request and Session.send, since Session.request calls into Session.send +_SUPPRESS_REQUESTS_INSTRUMENTATION_KEY = "suppress_requests_instrumentation" + # pylint: disable=unused-argument def _instrument(tracer_provider=None, span_callback=None): @@ -71,15 +66,54 @@ def _instrument(tracer_provider=None, span_callback=None): # before v1.0.0, Dec 17, 2012, see # https://github.com/psf/requests/commit/4e5c4a6ab7bb0195dececdd19bb8505b872fe120) - wrapped = Session.request + wrapped_request = Session.request + wrapped_send = Session.send - @functools.wraps(wrapped) + @functools.wraps(wrapped_request) def instrumented_request(self, method, url, *args, **kwargs): - if context.get_value("suppress_instrumentation"): - return wrapped(self, method, url, *args, **kwargs) + def get_or_create_headers(): + headers = kwargs.get("headers") + if headers is None: + headers = {} + kwargs["headers"] = headers + + return headers + + def call_wrapped(): + return wrapped_request(self, method, url, *args, **kwargs) + + return _instrumented_requests_call( + method, url, call_wrapped, get_or_create_headers + ) + + @functools.wraps(wrapped_send) + def instrumented_send(self, request, **kwargs): + def get_or_create_headers(): + request.headers = ( + request.headers + if request.headers is not None + else CaseInsensitiveDict() + ) + return request.headers + + def call_wrapped(): + return wrapped_send(self, request, **kwargs) + + return _instrumented_requests_call( + request.method, request.url, call_wrapped, get_or_create_headers + ) + + def _instrumented_requests_call( + method: str, url: str, call_wrapped, get_or_create_headers + ): + if context.get_value("suppress_instrumentation") or context.get_value( + _SUPPRESS_REQUESTS_INSTRUMENTATION_KEY + ): + return call_wrapped() # See # https://github.com/open-telemetry/opentelemetry-specification/blob/master/specification/trace/semantic_conventions/http.md#http-client + method = method.upper() span_name = "HTTP {}".format(method) exception = None @@ -91,17 +125,19 @@ def instrumented_request(self, method, url, *args, **kwargs): span.set_attribute("http.method", method.upper()) span.set_attribute("http.url", url) - headers = kwargs.get("headers", {}) or {} + headers = get_or_create_headers() propagators.inject(type(headers).__setitem__, headers) - kwargs["headers"] = headers + token = context.attach( + context.set_value(_SUPPRESS_REQUESTS_INSTRUMENTATION_KEY, True) + ) try: - result = wrapped( - self, method, url, *args, **kwargs - ) # *** PROCEED + result = call_wrapped() # *** PROCEED except Exception as exc: # pylint: disable=W0703 exception = exc result = getattr(exc, "response", None) + finally: + context.detach(token) if exception is not None: span.set_status( @@ -124,24 +160,34 @@ def instrumented_request(self, method, url, *args, **kwargs): return result - instrumented_request.opentelemetry_ext_requests_applied = True - + instrumented_request.opentelemetry_instrumentation_requests_applied = True Session.request = instrumented_request - # TODO: We should also instrument requests.sessions.Session.send - # but to avoid doubled spans, we would need some context-local - # state (i.e., only create a Span if the current context's URL is - # different, then push the current URL, pop it afterwards) + instrumented_send.opentelemetry_instrumentation_requests_applied = True + Session.send = instrumented_send def _uninstrument(): - # pylint: disable=global-statement """Disables instrumentation of :code:`requests` through this module. Note that this only works if no other module also patches requests.""" - if getattr(Session.request, "opentelemetry_ext_requests_applied", False): - original = Session.request.__wrapped__ # pylint:disable=no-member - Session.request = original + _uninstrument_from(Session) + + +def _uninstrument_from(instr_root, restore_as_bound_func=False): + for instr_func_name in ("request", "send"): + instr_func = getattr(instr_root, instr_func_name) + if not getattr( + instr_func, + "opentelemetry_instrumentation_requests_applied", + False, + ): + continue + + original = instr_func.__wrapped__ # pylint:disable=no-member + if restore_as_bound_func: + original = types.MethodType(original, instr_root) + setattr(instr_root, instr_func_name, original) def _exception_to_canonical_code(exc: Exception) -> StatusCanonicalCode: @@ -179,8 +225,4 @@ def _uninstrument(self, **kwargs): @staticmethod def uninstrument_session(session): """Disables instrumentation on the session object.""" - if getattr( - session.request, "opentelemetry_ext_requests_applied", False - ): - original = session.request.__wrapped__ # pylint:disable=no-member - session.request = types.MethodType(original, session) + _uninstrument_from(session, restore_as_bound_func=True) diff --git a/instrumentation/opentelemetry-instrumentation-requests/tests/test_requests_integration.py b/instrumentation/opentelemetry-instrumentation-requests/tests/test_requests_integration.py index da09118e5bc..0e0492f47e3 100644 --- a/instrumentation/opentelemetry-instrumentation-requests/tests/test_requests_integration.py +++ b/instrumentation/opentelemetry-instrumentation-requests/tests/test_requests_integration.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc from unittest import mock import httpretty @@ -26,32 +27,47 @@ from opentelemetry.trace.status import StatusCanonicalCode -class TestRequestsIntegration(TestBase): +class RequestsIntegrationTestBase(abc.ABC): + # pylint: disable=no-member + URL = "http://httpbin.org/status/200" + # pylint: disable=invalid-name def setUp(self): super().setUp() RequestsInstrumentor().instrument() httpretty.enable() - httpretty.register_uri( - httpretty.GET, self.URL, body="Hello!", - ) + httpretty.register_uri(httpretty.GET, self.URL, body="Hello!") + # pylint: disable=invalid-name def tearDown(self): super().tearDown() RequestsInstrumentor().uninstrument() httpretty.disable() + def assert_span(self, exporter=None, num_spans=1): + if exporter is None: + exporter = self.memory_exporter + span_list = exporter.get_finished_spans() + self.assertEqual(num_spans, len(span_list)) + if num_spans == 0: + return None + if num_spans == 1: + return span_list[0] + return span_list + + @staticmethod + @abc.abstractmethod + def perform_request(url: str, session: requests.Session = None): + pass + def test_basic(self): - result = requests.get(self.URL) + result = self.perform_request(self.URL) self.assertEqual(result.text, "Hello!") - - span_list = self.memory_exporter.get_finished_spans() - self.assertEqual(len(span_list), 1) - span = span_list[0] + span = self.assert_span() self.assertIs(span.kind, trace.SpanKind.CLIENT) - self.assertEqual(span.name, "HTTP get") + self.assertEqual(span.name, "HTTP GET") self.assertEqual( span.attributes, @@ -77,12 +93,10 @@ def test_not_foundbasic(self): httpretty.register_uri( httpretty.GET, url_404, status=404, ) - result = requests.get(url_404) + result = self.perform_request(url_404) self.assertEqual(result.status_code, 404) - span_list = self.memory_exporter.get_finished_spans() - self.assertEqual(len(span_list), 1) - span = span_list[0] + span = self.assert_span() self.assertEqual(span.attributes.get("http.status_code"), 404) self.assertEqual(span.attributes.get("http.status_text"), "Not Found") @@ -92,31 +106,11 @@ def test_not_foundbasic(self): trace.status.StatusCanonicalCode.NOT_FOUND, ) - def test_invalid_url(self): - url = "http://[::1/nope" - - with self.assertRaises(ValueError): - requests.post(url) - - span_list = self.memory_exporter.get_finished_spans() - self.assertEqual(len(span_list), 1) - span = span_list[0] - - self.assertEqual(span.name, "HTTP post") - self.assertEqual( - span.attributes, - {"component": "http", "http.method": "POST", "http.url": url}, - ) - self.assertEqual( - span.status.canonical_code, StatusCanonicalCode.INVALID_ARGUMENT - ) - def test_uninstrument(self): RequestsInstrumentor().uninstrument() - result = requests.get(self.URL) + result = self.perform_request(self.URL) self.assertEqual(result.text, "Hello!") - span_list = self.memory_exporter.get_finished_spans() - self.assertEqual(len(span_list), 0) + self.assert_span(num_spans=0) # instrument again to avoid annoying warning message RequestsInstrumentor().instrument() @@ -124,49 +118,43 @@ def test_uninstrument_session(self): session1 = requests.Session() RequestsInstrumentor().uninstrument_session(session1) - result = session1.get(self.URL) + result = self.perform_request(self.URL, session1) self.assertEqual(result.text, "Hello!") - span_list = self.memory_exporter.get_finished_spans() - self.assertEqual(len(span_list), 0) + self.assert_span(num_spans=0) # Test that other sessions as well as global requests is still # instrumented session2 = requests.Session() - result = session2.get(self.URL) + result = self.perform_request(self.URL, session2) self.assertEqual(result.text, "Hello!") - span_list = self.memory_exporter.get_finished_spans() - self.assertEqual(len(span_list), 1) + self.assert_span() self.memory_exporter.clear() - result = requests.get(self.URL) + result = self.perform_request(self.URL) self.assertEqual(result.text, "Hello!") - span_list = self.memory_exporter.get_finished_spans() - self.assertEqual(len(span_list), 1) + self.assert_span() def test_suppress_instrumentation(self): token = context.attach( context.set_value("suppress_instrumentation", True) ) try: - result = requests.get(self.URL) + result = self.perform_request(self.URL) self.assertEqual(result.text, "Hello!") finally: context.detach(token) - span_list = self.memory_exporter.get_finished_spans() - self.assertEqual(len(span_list), 0) + self.assert_span(num_spans=0) def test_distributed_context(self): previous_propagator = propagators.get_global_httptextformat() try: propagators.set_global_httptextformat(MockHTTPTextFormat()) - result = requests.get(self.URL) + result = self.perform_request(self.URL) self.assertEqual(result.text, "Hello!") - span_list = self.memory_exporter.get_finished_spans() - self.assertEqual(len(span_list), 1) - span = span_list[0] + span = self.assert_span() headers = dict(httpretty.last_request().headers) self.assertIn(MockHTTPTextFormat.TRACE_ID_KEY, headers) @@ -195,13 +183,10 @@ def span_callback(span, result: requests.Response): tracer_provider=self.tracer_provider, span_callback=span_callback, ) - result = requests.get(self.URL) + result = self.perform_request(self.URL) self.assertEqual(result.text, "Hello!") - span_list = self.memory_exporter.get_finished_spans() - self.assertEqual(len(span_list), 1) - span = span_list[0] - + span = self.assert_span() self.assertEqual( span.attributes, { @@ -221,28 +206,21 @@ def test_custom_tracer_provider(self): RequestsInstrumentor().uninstrument() RequestsInstrumentor().instrument(tracer_provider=tracer_provider) - result = requests.get(self.URL) + result = self.perform_request(self.URL) self.assertEqual(result.text, "Hello!") - span_list = exporter.get_finished_spans() - self.assertEqual(len(span_list), 1) - span = span_list[0] - + span = self.assert_span(exporter=exporter) self.assertIs(span.resource, resource) - def test_if_headers_equals_none(self): - result = requests.get(self.URL, headers=None) - self.assertEqual(result.text, "Hello!") - - @mock.patch("requests.Session.send", side_effect=requests.RequestException) + @mock.patch( + "requests.adapters.HTTPAdapter.send", + side_effect=requests.RequestException, + ) def test_requests_exception_without_response(self, *_, **__): - with self.assertRaises(requests.RequestException): - requests.get(self.URL) + self.perform_request(self.URL) - span_list = self.memory_exporter.get_finished_spans() - self.assertEqual(len(span_list), 1) - span = span_list[0] + span = self.assert_span() self.assertEqual( span.attributes, {"component": "http", "http.method": "GET", "http.url": self.URL}, @@ -256,17 +234,14 @@ def test_requests_exception_without_response(self, *_, **__): mocked_response.reason = "Internal Server Error" @mock.patch( - "requests.Session.send", + "requests.adapters.HTTPAdapter.send", side_effect=requests.RequestException(response=mocked_response), ) def test_requests_exception_with_response(self, *_, **__): - with self.assertRaises(requests.RequestException): - requests.get(self.URL) + self.perform_request(self.URL) - span_list = self.memory_exporter.get_finished_spans() - self.assertEqual(len(span_list), 1) - span = span_list[0] + span = self.assert_span() self.assertEqual( span.attributes, { @@ -281,27 +256,66 @@ def test_requests_exception_with_response(self, *_, **__): span.status.canonical_code, StatusCanonicalCode.INTERNAL ) - @mock.patch("requests.Session.send", side_effect=Exception) + @mock.patch("requests.adapters.HTTPAdapter.send", side_effect=Exception) def test_requests_basic_exception(self, *_, **__): - with self.assertRaises(Exception): - requests.get(self.URL) + self.perform_request(self.URL) - span_list = self.memory_exporter.get_finished_spans() - self.assertEqual(len(span_list), 1) + span = self.assert_span() self.assertEqual( - span_list[0].status.canonical_code, StatusCanonicalCode.UNKNOWN + span.status.canonical_code, StatusCanonicalCode.UNKNOWN ) - @mock.patch("requests.Session.send", side_effect=requests.Timeout) + @mock.patch( + "requests.adapters.HTTPAdapter.send", side_effect=requests.Timeout + ) def test_requests_timeout_exception(self, *_, **__): - with self.assertRaises(Exception): - requests.get(self.URL) + self.perform_request(self.URL) - span_list = self.memory_exporter.get_finished_spans() - self.assertEqual(len(span_list), 1) + span = self.assert_span() self.assertEqual( - span_list[0].status.canonical_code, - StatusCanonicalCode.DEADLINE_EXCEEDED, + span.status.canonical_code, StatusCanonicalCode.DEADLINE_EXCEEDED ) + + +class TestRequestsIntegration(RequestsIntegrationTestBase, TestBase): + @staticmethod + def perform_request(url: str, session: requests.Session = None): + if session is None: + return requests.get(url) + return session.get(url) + + def test_invalid_url(self): + url = "http://[::1/nope" + + with self.assertRaises(ValueError): + requests.post(url) + + span = self.assert_span() + + self.assertEqual(span.name, "HTTP POST") + self.assertEqual( + span.attributes, + {"component": "http", "http.method": "POST", "http.url": url}, + ) + self.assertEqual( + span.status.canonical_code, StatusCanonicalCode.INVALID_ARGUMENT + ) + + def test_if_headers_equals_none(self): + result = requests.get(self.URL, headers=None) + self.assertEqual(result.text, "Hello!") + self.assert_span() + + +class TestRequestsIntegrationPreparedRequest( + RequestsIntegrationTestBase, TestBase +): + @staticmethod + def perform_request(url: str, session: requests.Session = None): + if session is None: + session = requests.Session() + request = requests.Request("GET", url) + prepared_request = session.prepare_request(request) + return session.send(prepared_request)