From 6e71f2e6265aa35c71d4656c485b7944a605b094 Mon Sep 17 00:00:00 2001 From: Seth Michael Larson Date: Tue, 29 Jun 2021 09:53:04 -0500 Subject: [PATCH] Tolerate RecursionError not being defined in Python<3.5 --- elasticsearch/_async/http_aiohttp.py | 4 +- elasticsearch/compat.py | 14 +++++++ elasticsearch/compat.pyi | 3 +- elasticsearch/connection/http_requests.py | 4 +- elasticsearch/connection/http_urllib3.py | 4 +- .../test_async/test_connection.py | 21 ++++++++++ test_elasticsearch/test_connection.py | 42 +++++++++++++++++++ 7 files changed, 85 insertions(+), 7 deletions(-) diff --git a/elasticsearch/_async/http_aiohttp.py b/elasticsearch/_async/http_aiohttp.py index 2057cdd6a..20223472e 100644 --- a/elasticsearch/_async/http_aiohttp.py +++ b/elasticsearch/_async/http_aiohttp.py @@ -22,7 +22,7 @@ import urllib3 # type: ignore -from ..compat import urlencode +from ..compat import reraise_exceptions, urlencode from ..connection.base import Connection from ..exceptions import ( ConnectionError, @@ -304,7 +304,7 @@ async def perform_request( duration = self.loop.time() - start # We want to reraise a cancellation or recursion error. - except (asyncio.CancelledError, RecursionError): + except reraise_exceptions: raise except Exception as e: self.log_request_fail( diff --git a/elasticsearch/compat.py b/elasticsearch/compat.py index e98478523..99425ce6c 100644 --- a/elasticsearch/compat.py +++ b/elasticsearch/compat.py @@ -58,8 +58,22 @@ def to_bytes(x, encoding="ascii"): from collections import Mapping +try: + reraise_exceptions = (RecursionError,) +except NameError: + reraise_exceptions = () + +try: + import asyncio + + reraise_exceptions += (asyncio.CancelledError,) +except (ImportError, AttributeError): + pass + + __all__ = [ "string_types", + "reraise_exceptions", "quote_plus", "quote", "urlencode", diff --git a/elasticsearch/compat.pyi b/elasticsearch/compat.pyi index 50ae33cc2..249ff0415 100644 --- a/elasticsearch/compat.pyi +++ b/elasticsearch/compat.pyi @@ -16,13 +16,14 @@ # under the License. import sys -from typing import Callable, Tuple, Union +from typing import Callable, Tuple, Type, Union PY2: bool string_types: Tuple[type, ...] to_str: Callable[[Union[str, bytes]], str] to_bytes: Callable[[Union[str, bytes]], bytes] +reraise_exceptions: Tuple[Type[Exception], ...] if sys.version_info[0] == 2: from itertools import imap as map diff --git a/elasticsearch/connection/http_requests.py b/elasticsearch/connection/http_requests.py index 84526d8d5..e45cbde81 100644 --- a/elasticsearch/connection/http_requests.py +++ b/elasticsearch/connection/http_requests.py @@ -18,7 +18,7 @@ import time import warnings -from ..compat import string_types, urlencode +from ..compat import reraise_exceptions, string_types, urlencode from ..exceptions import ( ConnectionError, ConnectionTimeout, @@ -166,7 +166,7 @@ def perform_request( response = self.session.send(prepared_request, **send_kwargs) duration = time.time() - start raw_data = response.content.decode("utf-8", "surrogatepass") - except RecursionError: + except reraise_exceptions: raise except Exception as e: self.log_request_fail( diff --git a/elasticsearch/connection/http_urllib3.py b/elasticsearch/connection/http_urllib3.py index e343616db..cac8659bf 100644 --- a/elasticsearch/connection/http_urllib3.py +++ b/elasticsearch/connection/http_urllib3.py @@ -24,7 +24,7 @@ from urllib3.exceptions import SSLError as UrllibSSLError # type: ignore from urllib3.util.retry import Retry # type: ignore -from ..compat import urlencode +from ..compat import reraise_exceptions, urlencode from ..exceptions import ( ConnectionError, ConnectionTimeout, @@ -253,7 +253,7 @@ def perform_request( ) duration = time.time() - start raw_data = response.data.decode("utf-8", "surrogatepass") - except RecursionError: + except reraise_exceptions: raise except Exception as e: self.log_request_fail( diff --git a/test_elasticsearch/test_async/test_connection.py b/test_elasticsearch/test_async/test_connection.py index 6b53adfed..d0646e1c7 100644 --- a/test_elasticsearch/test_async/test_connection.py +++ b/test_elasticsearch/test_async/test_connection.py @@ -29,6 +29,8 @@ from multidict import CIMultiDict from elasticsearch import AIOHttpConnection, __versionstr__ +from elasticsearch.compat import reraise_exceptions +from elasticsearch.exceptions import ConnectionError pytestmark = pytest.mark.asyncio @@ -318,6 +320,20 @@ async def test_surrogatepass_into_bytes(self): status, headers, data = await con.perform_request("GET", "/") assert u"你好\uda6a" == data + @pytest.mark.parametrize("exception_cls", reraise_exceptions) + async def test_recursion_error_reraised(self, exception_cls): + conn = AIOHttpConnection() + + def request_raise(*_, **__): + raise exception_cls("Wasn't modified!") + + await conn._create_aiohttp_session() + conn.session.request = request_raise + + with pytest.raises(exception_cls) as e: + await conn.perform_request("GET", "/") + assert str(e.value) == "Wasn't modified!" + class TestConnectionHttpbin: """Tests the HTTP connection implementations against a live server E2E""" @@ -389,3 +405,8 @@ async def test_aiohttp_connection(self): "Header2": "value2", "User-Agent": user_agent, } + + async def test_aiohttp_connection_error(self): + conn = AIOHttpConnection("not.a.host.name") + with pytest.raises(ConnectionError): + await conn.perform_request("GET", "/") diff --git a/test_elasticsearch/test_connection.py b/test_elasticsearch/test_connection.py index 878c661e6..0baa4f456 100644 --- a/test_elasticsearch/test_connection.py +++ b/test_elasticsearch/test_connection.py @@ -32,6 +32,7 @@ from urllib3._collections import HTTPHeaderDict from elasticsearch import __versionstr__ +from elasticsearch.compat import reraise_exceptions from elasticsearch.connection import ( Connection, RequestsHttpConnection, @@ -39,6 +40,7 @@ ) from elasticsearch.exceptions import ( ConflictError, + ConnectionError, NotFoundError, RequestError, TransportError, @@ -466,6 +468,21 @@ def test_surrogatepass_into_bytes(self): status, headers, data = con.perform_request("GET", "/") self.assertEqual(u"你好\uda6a", data) + @pytest.mark.skipif( + not reraise_exceptions, reason="RecursionError isn't defined in Python <3.5" + ) + def test_recursion_error_reraised(self): + conn = Urllib3HttpConnection() + + def urlopen_raise(*_, **__): + raise RecursionError("Wasn't modified!") + + conn.pool.urlopen = urlopen_raise + + with pytest.raises(RecursionError) as e: + conn.perform_request("GET", "/") + assert str(e.value) == "Wasn't modified!" + class TestRequestsConnection(TestCase): def _get_mock_connection( @@ -868,6 +885,21 @@ def test_surrogatepass_into_bytes(self): status, headers, data = con.perform_request("GET", "/") self.assertEqual(u"你好\uda6a", data) + @pytest.mark.skipif( + not reraise_exceptions, reason="RecursionError isn't defined in Python <3.5" + ) + def test_recursion_error_reraised(self): + conn = RequestsHttpConnection() + + def send_raise(*_, **__): + raise RecursionError("Wasn't modified!") + + conn.session.send = send_raise + + with pytest.raises(RecursionError) as e: + conn.perform_request("GET", "/") + assert str(e.value) == "Wasn't modified!" + class TestConnectionHttpbin: """Tests the HTTP connection implementations against a live server E2E""" @@ -942,6 +974,11 @@ def test_urllib3_connection(self): "User-Agent": user_agent, } + def test_urllib3_connection_error(self): + conn = Urllib3HttpConnection("not.a.host.name") + with pytest.raises(ConnectionError): + conn.perform_request("GET", "/") + def test_requests_connection(self): # Defaults conn = RequestsHttpConnection("httpbin.org", port=443, use_ssl=True) @@ -1003,3 +1040,8 @@ def test_requests_connection(self): "Header2": "value2", "User-Agent": user_agent, } + + def test_requests_connection_error(self): + conn = RequestsHttpConnection("not.a.host.name") + with pytest.raises(ConnectionError): + conn.perform_request("GET", "/")