diff --git a/elasticsearch/__init__.py b/elasticsearch/__init__.py index 701fd95e5..c40077c08 100644 --- a/elasticsearch/__init__.py +++ b/elasticsearch/__init__.py @@ -32,6 +32,7 @@ from ._async.client import AsyncElasticsearch from ._sync.client import Elasticsearch +from .exceptions import ElasticsearchDeprecationWarning # noqa: F401 from .exceptions import ( ApiError, AuthenticationException, @@ -39,7 +40,6 @@ ConflictError, ConnectionError, ConnectionTimeout, - ElasticsearchDeprecationWarning, ElasticsearchException, ElasticsearchWarning, NotFoundError, @@ -73,5 +73,4 @@ "AuthorizationException", "UnsupportedProductError", "ElasticsearchWarning", - "ElasticsearchDeprecationWarning", ] diff --git a/elasticsearch/_async/client/__init__.py b/elasticsearch/_async/client/__init__.py index 3cc0becc1..3e520f8c4 100644 --- a/elasticsearch/_async/client/__init__.py +++ b/elasticsearch/_async/client/__init__.py @@ -18,14 +18,19 @@ import logging import warnings -from typing import Optional +from typing import Any, Callable, Dict, Optional, Union -from elastic_transport import AsyncTransport, TransportError +from elastic_transport import AsyncTransport, NodeConfig, TransportError from elastic_transport.client_utils import DEFAULT from ...exceptions import NotFoundError from ...serializer import DEFAULT_SERIALIZERS -from ._base import BaseClient, resolve_auth_headers +from ._base import ( + BaseClient, + create_sniff_callback, + default_sniff_callback, + resolve_auth_headers, +) from .async_search import AsyncSearchClient from .autoscaling import AutoscalingClient from .cat import CatClient @@ -148,9 +153,21 @@ def __init__( sniff_on_node_failure=DEFAULT, sniff_timeout=DEFAULT, min_delay_between_sniffing=DEFAULT, + sniffed_node_callback: Optional[ + Callable[[Dict[str, Any], NodeConfig], Optional[NodeConfig]] + ] = None, meta_header=DEFAULT, # Deprecated timeout=DEFAULT, + randomize_hosts=DEFAULT, + host_info_callback: Optional[ + Callable[ + [Dict[str, Any], Dict[str, Union[str, int]]], + Optional[Dict[str, Union[str, int]]], + ] + ] = None, + sniffer_timeout=DEFAULT, + sniff_on_connection_fail=DEFAULT, # Internal use only _transport: Optional[AsyncTransport] = None, ) -> None: @@ -170,6 +187,92 @@ def __init__( ) request_timeout = timeout + if randomize_hosts is not DEFAULT: + if randomize_nodes_in_pool is not DEFAULT: + raise ValueError( + "Can't specify both 'randomize_hosts' and 'randomize_nodes_in_pool', " + "instead only specify 'randomize_nodes_in_pool'" + ) + warnings.warn( + "The 'randomize_hosts' parameter is deprecated in favor of 'randomize_nodes_in_pool'", + category=DeprecationWarning, + stacklevel=2, + ) + randomize_nodes_in_pool = randomize_hosts + + if sniffer_timeout is not DEFAULT: + if min_delay_between_sniffing is not DEFAULT: + raise ValueError( + "Can't specify both 'sniffer_timeout' and 'min_delay_between_sniffing', " + "instead only specify 'min_delay_between_sniffing'" + ) + warnings.warn( + "The 'sniffer_timeout' parameter is deprecated in favor of 'min_delay_between_sniffing'", + category=DeprecationWarning, + stacklevel=2, + ) + min_delay_between_sniffing = sniffer_timeout + + if sniff_on_connection_fail is not DEFAULT: + if sniff_on_node_failure is not DEFAULT: + raise ValueError( + "Can't specify both 'sniff_on_connection_fail' and 'sniff_on_node_failure', " + "instead only specify 'sniff_on_node_failure'" + ) + warnings.warn( + "The 'sniff_on_connection_fail' parameter is deprecated in favor of 'sniff_on_node_failure'", + category=DeprecationWarning, + 
stacklevel=2, + ) + sniff_on_node_failure = sniff_on_connection_fail + + # Setting 'min_delay_between_sniffing' implies 'sniff_before_requests=True' + if min_delay_between_sniffing is not DEFAULT: + sniff_before_requests = True + + sniffing_options = ( + sniff_timeout, + sniff_on_start, + sniff_before_requests, + sniff_on_node_failure, + sniffed_node_callback, + min_delay_between_sniffing, + ) + if cloud_id is not None and any( + x is not DEFAULT and x is not None for x in sniffing_options + ): + raise ValueError( + "Sniffing should not be enabled when connecting to Elastic Cloud" + ) + + sniff_callback = None + if host_info_callback is not None: + if sniffed_node_callback is not None: + raise ValueError( + "Can't specify both 'host_info_callback' and 'sniffed_node_callback', " + "instead only specify 'sniffed_node_callback'" + ) + warnings.warn( + "The 'host_info_callback' parameter is deprecated in favor of 'sniffed_node_callback'", + category=DeprecationWarning, + stacklevel=2, + ) + + sniff_callback = create_sniff_callback( + host_info_callback=host_info_callback + ) + elif sniffed_node_callback is not None: + sniff_callback = create_sniff_callback( + sniffed_node_callback=sniffed_node_callback + ) + elif ( + sniff_on_start is True + or sniff_before_requests is True + or sniff_on_node_failure is True + ): + sniff_callback = default_sniff_callback + if _transport is None: node_configs = client_node_configs( hosts, @@ -222,6 +325,7 @@ def __init__( _transport = transport_class( node_configs, client_meta_service=CLIENT_META_SERVICE, + sniff_callback=sniff_callback, **transport_kwargs, ) diff --git a/elasticsearch/_async/client/_base.py b/elasticsearch/_async/client/_base.py index 312bec622..015775caf 100644 --- a/elasticsearch/_async/client/_base.py +++ b/elasticsearch/_async/client/_base.py @@ -17,19 +17,33 @@ import re import warnings -from typing import Any, Collection, Iterable, Mapping, Optional, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, + TypeVar, + Union, +) -from elastic_transport import AsyncTransport, HttpHeaders +from elastic_transport import AsyncTransport, HttpHeaders, NodeConfig, SniffOptions from elastic_transport.client_utils import DEFAULT, DefaultType, resolve_default from ...compat import urlencode, warn_stacklevel from ...exceptions import ( HTTP_EXCEPTIONS, ApiError, + ConnectionError, ElasticsearchWarning, + SerializationError, UnsupportedProductError, ) -from .utils import _base64_auth_header +from .utils import _TYPE_ASYNC_SNIFF_CALLBACK, _base64_auth_header SelfType = TypeVar("SelfType", bound="BaseClient") SelfNamespacedType = TypeVar("SelfNamespacedType", bound="NamespacedClient") @@ -83,6 +97,102 @@ def resolve_auth_headers( return headers +def create_sniff_callback( + host_info_callback: Optional[ + Callable[[Dict[str, Any], Dict[str, Any]], Optional[Dict[str, Any]]] + ] = None, + sniffed_node_callback: Optional[ + Callable[[Dict[str, Any], NodeConfig], Optional[NodeConfig]] + ] = None, +) -> _TYPE_ASYNC_SNIFF_CALLBACK: + assert (host_info_callback is None) != (sniffed_node_callback is None) + + # Wrap the deprecated 'host_info_callback' into 'sniffed_node_callback' + if host_info_callback is not None: + + def _sniffed_node_callback( + node_info: Dict[str, Any], node_config: NodeConfig + ) -> Optional[NodeConfig]: + assert host_info_callback is not None + if ( + host_info_callback( # type: ignore[misc] + node_info, {"host": node_config.host,
"port": node_config.port} + ) + is None + ): + return None + return node_config + + sniffed_node_callback = _sniffed_node_callback + + async def sniff_callback( + transport: AsyncTransport, sniff_options: SniffOptions + ) -> List[NodeConfig]: + for _ in transport.node_pool.all(): + try: + meta, node_infos = await transport.perform_request( + "GET", + "/_nodes/_all/http", + headers={"accept": "application/json"}, + request_timeout=( + sniff_options.sniff_timeout + if not sniff_options.is_initial_sniff + else None + ), + ) + except (SerializationError, ConnectionError): + continue + + if not 200 <= meta.status <= 299: + continue + + node_configs = [] + for node_info in node_infos.get("nodes", {}).values(): + address = node_info.get("http", {}).get("publish_address") + if not address or ":" not in address: + continue + + if "/" in address: + # Support 7.x host/ip:port behavior where http.publish_host has been set. + fqdn, ipaddress = address.split("/", 1) + host = fqdn + _, port_str = ipaddress.rsplit(":", 1) + port = int(port_str) + else: + host, port_str = address.rsplit(":", 1) + port = int(port_str) + + assert sniffed_node_callback is not None + sniffed_node = sniffed_node_callback( + node_info, meta.node.replace(host=host, port=port) + ) + if sniffed_node is None: + continue + + # Use the node which was able to make the request as a base. + node_configs.append(sniffed_node) + + if node_configs: + return node_configs + + return [] + + return sniff_callback + + +def _default_sniffed_node_callback( + node_info: Dict[str, Any], node_config: NodeConfig +) -> Optional[NodeConfig]: + if node_info.get("roles", []) == ["master"]: + return None + return node_config + + +default_sniff_callback = create_sniff_callback( + sniffed_node_callback=_default_sniffed_node_callback +) + + class BaseClient: def __init__(self, _transport: AsyncTransport) -> None: self._transport = _transport diff --git a/elasticsearch/_async/client/utils.py b/elasticsearch/_async/client/utils.py index 0067f149b..d62f1f8ef 100644 --- a/elasticsearch/_async/client/utils.py +++ b/elasticsearch/_async/client/utils.py @@ -16,6 +16,7 @@ # under the License. 
from ..._sync.client.utils import ( + _TYPE_ASYNC_SNIFF_CALLBACK, + _TYPE_HOSTS, CLIENT_META_SERVICE, SKIP_IN_PATH, @@ -30,6 +31,7 @@ __all__ = [ "CLIENT_META_SERVICE", + "_TYPE_ASYNC_SNIFF_CALLBACK", "_deprecated_options", "_TYPE_HOSTS", "SKIP_IN_PATH", diff --git a/elasticsearch/_sync/client/__init__.py b/elasticsearch/_sync/client/__init__.py index dfa5cfd73..1450ef082 100644 --- a/elasticsearch/_sync/client/__init__.py +++ b/elasticsearch/_sync/client/__init__.py @@ -18,14 +18,19 @@ import logging import warnings -from typing import Optional +from typing import Any, Callable, Dict, Optional, Union -from elastic_transport import Transport, TransportError +from elastic_transport import NodeConfig, Transport, TransportError from elastic_transport.client_utils import DEFAULT from ...exceptions import NotFoundError from ...serializer import DEFAULT_SERIALIZERS -from ._base import BaseClient, resolve_auth_headers +from ._base import ( + BaseClient, + create_sniff_callback, + default_sniff_callback, + resolve_auth_headers, +) from .async_search import AsyncSearchClient from .autoscaling import AutoscalingClient from .cat import CatClient @@ -148,9 +153,21 @@ def __init__( sniff_on_node_failure=DEFAULT, sniff_timeout=DEFAULT, min_delay_between_sniffing=DEFAULT, + sniffed_node_callback: Optional[ + Callable[[Dict[str, Any], NodeConfig], Optional[NodeConfig]] + ] = None, meta_header=DEFAULT, # Deprecated timeout=DEFAULT, + randomize_hosts=DEFAULT, + host_info_callback: Optional[ + Callable[ + [Dict[str, Any], Dict[str, Union[str, int]]], + Optional[Dict[str, Union[str, int]]], + ] + ] = None, + sniffer_timeout=DEFAULT, + sniff_on_connection_fail=DEFAULT, # Internal use only _transport: Optional[Transport] = None, ) -> None: @@ -170,6 +187,92 @@ def __init__( ) request_timeout = timeout + if randomize_hosts is not DEFAULT: + if randomize_nodes_in_pool is not DEFAULT: + raise ValueError( + "Can't specify both 'randomize_hosts' and 'randomize_nodes_in_pool', " + "instead only specify 'randomize_nodes_in_pool'" + ) + warnings.warn( + "The 'randomize_hosts' parameter is deprecated in favor of 'randomize_nodes_in_pool'", + category=DeprecationWarning, + stacklevel=2, + ) + randomize_nodes_in_pool = randomize_hosts + + if sniffer_timeout is not DEFAULT: + if min_delay_between_sniffing is not DEFAULT: + raise ValueError( + "Can't specify both 'sniffer_timeout' and 'min_delay_between_sniffing', " + "instead only specify 'min_delay_between_sniffing'" + ) + warnings.warn( + "The 'sniffer_timeout' parameter is deprecated in favor of 'min_delay_between_sniffing'", + category=DeprecationWarning, + stacklevel=2, + ) + min_delay_between_sniffing = sniffer_timeout + + if sniff_on_connection_fail is not DEFAULT: + if sniff_on_node_failure is not DEFAULT: + raise ValueError( + "Can't specify both 'sniff_on_connection_fail' and 'sniff_on_node_failure', " + "instead only specify 'sniff_on_node_failure'" + ) + warnings.warn( + "The 'sniff_on_connection_fail' parameter is deprecated in favor of 'sniff_on_node_failure'", + category=DeprecationWarning, + stacklevel=2, + ) + sniff_on_node_failure = sniff_on_connection_fail + + # Setting 'min_delay_between_sniffing' implies 'sniff_before_requests=True' + if min_delay_between_sniffing is not DEFAULT: + sniff_before_requests = True + + sniffing_options = ( + sniff_timeout, + sniff_on_start, + sniff_before_requests, + sniff_on_node_failure, + sniffed_node_callback, + min_delay_between_sniffing, + ) + if cloud_id is not None and any( + x is not DEFAULT
and x is not None for x in sniffing_options + ): + raise ValueError( + "Sniffing should not be enabled when connecting to Elastic Cloud" + ) + + sniff_callback = None + if host_info_callback is not None: + if sniffed_node_callback is not None: + raise ValueError( + "Can't specify both 'host_info_callback' and 'sniffed_node_callback', " + "instead only specify 'sniffed_node_callback'" + ) + warnings.warn( + "The 'host_info_callback' parameter is deprecated in favor of 'sniffed_node_callback'", + category=DeprecationWarning, + stacklevel=2, + ) + + sniff_callback = create_sniff_callback( + host_info_callback=host_info_callback + ) + elif sniffed_node_callback is not None: + sniff_callback = create_sniff_callback( + sniffed_node_callback=sniffed_node_callback + ) + elif ( + sniff_on_start is True + or sniff_before_requests is True + or sniff_on_node_failure is True + ): + sniff_callback = default_sniff_callback + if _transport is None: node_configs = client_node_configs( hosts, @@ -222,6 +325,7 @@ def __init__( _transport = transport_class( node_configs, client_meta_service=CLIENT_META_SERVICE, + sniff_callback=sniff_callback, **transport_kwargs, ) diff --git a/elasticsearch/_sync/client/_base.py b/elasticsearch/_sync/client/_base.py index afbbb6d00..ebaa82bdc 100644 --- a/elasticsearch/_sync/client/_base.py +++ b/elasticsearch/_sync/client/_base.py @@ -17,19 +17,33 @@ import re import warnings -from typing import Any, Collection, Mapping, Optional, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, + TypeVar, + Union, +) -from elastic_transport import HttpHeaders, Transport +from elastic_transport import HttpHeaders, NodeConfig, SniffOptions, Transport from elastic_transport.client_utils import DEFAULT, DefaultType, resolve_default from ...compat import urlencode, warn_stacklevel from ...exceptions import ( HTTP_EXCEPTIONS, ApiError, + ConnectionError, ElasticsearchWarning, + SerializationError, UnsupportedProductError, ) -from .utils import _base64_auth_header +from .utils import _TYPE_SYNC_SNIFF_CALLBACK, _base64_auth_header SelfType = TypeVar("SelfType", bound="BaseClient") SelfNamespacedType = TypeVar("SelfNamespacedType", bound="NamespacedClient") @@ -83,6 +97,102 @@ def resolve_auth_headers( return headers +def create_sniff_callback( + host_info_callback: Optional[ + Callable[[Dict[str, Any], Dict[str, Any]], Optional[Dict[str, Any]]] + ] = None, + sniffed_node_callback: Optional[ + Callable[[Dict[str, Any], NodeConfig], Optional[NodeConfig]] + ] = None, +) -> _TYPE_SYNC_SNIFF_CALLBACK: + assert (host_info_callback is None) != (sniffed_node_callback is None) + + # Wrap the deprecated 'host_info_callback' into 'sniffed_node_callback' + if host_info_callback is not None: + + def _sniffed_node_callback( + node_info: Dict[str, Any], node_config: NodeConfig + ) -> Optional[NodeConfig]: + assert host_info_callback is not None + if ( + host_info_callback( # type: ignore[misc] + node_info, {"host": node_config.host, "port": node_config.port} + ) + is None + ): + return None + return node_config + + sniffed_node_callback = _sniffed_node_callback + + def sniff_callback( + transport: Transport, sniff_options: SniffOptions + ) -> List[NodeConfig]: + for _ in transport.node_pool.all(): + try: + meta, node_infos = transport.perform_request( + "GET", + "/_nodes/_all/http", + headers={"accept": "application/json"}, + request_timeout=( + sniff_options.sniff_timeout + if not sniff_options.is_initial_sniff + else None
), + ) + except (SerializationError, ConnectionError): + continue + + if not 200 <= meta.status <= 299: + continue + + node_configs = [] + for node_info in node_infos.get("nodes", {}).values(): + address = node_info.get("http", {}).get("publish_address") + if not address or ":" not in address: + continue + + if "/" in address: + # Support 7.x host/ip:port behavior where http.publish_host has been set. + fqdn, ipaddress = address.split("/", 1) + host = fqdn + _, port_str = ipaddress.rsplit(":", 1) + port = int(port_str) + else: + host, port_str = address.rsplit(":", 1) + port = int(port_str) + + assert sniffed_node_callback is not None + sniffed_node = sniffed_node_callback( + node_info, meta.node.replace(host=host, port=port) + ) + if sniffed_node is None: + continue + + # Use the node which was able to make the request as a base. + node_configs.append(sniffed_node) + + if node_configs: + return node_configs + + return [] + + return sniff_callback + + +def _default_sniffed_node_callback( + node_info: Dict[str, Any], node_config: NodeConfig +) -> Optional[NodeConfig]: + if node_info.get("roles", []) == ["master"]: + return None + return node_config + + +default_sniff_callback = create_sniff_callback( + sniffed_node_callback=_default_sniffed_node_callback +) + + class BaseClient: def __init__(self, _transport: Transport) -> None: self._transport = _transport @@ -165,9 +275,10 @@ def _perform_request( # 'Warning' headers should be reraised as 'ElasticsearchWarning' warning_header = (meta.headers.get("warning") or "").strip() if warning_header: - for warning_message in _WARNING_RE.findall(warning_header) or ( + warning_messages: Iterable[str] = _WARNING_RE.findall(warning_header) or ( warning_header, - ): + ) + for warning_message in warning_messages: warnings.warn( warning_message, category=ElasticsearchWarning, diff --git a/elasticsearch/_sync/client/utils.py b/elasticsearch/_sync/client/utils.py index 4a00adb30..7dcb97546 100644 --- a/elasticsearch/_sync/client/utils.py +++ b/elasticsearch/_sync/client/utils.py @@ -23,6 +23,7 @@ from typing import ( TYPE_CHECKING, Any, + Awaitable, Callable, Collection, List, @@ -34,7 +35,7 @@ Union, ) -from elastic_transport import NodeConfig +from elastic_transport import AsyncTransport, NodeConfig, SniffOptions, Transport from elastic_transport.client_utils import ( DEFAULT, client_meta_version, @@ -57,6 +58,11 @@ _TYPE_HOSTS = Union[str, List[Union[str, Mapping[str, Union[str, int]], NodeConfig]]] +_TYPE_ASYNC_SNIFF_CALLBACK = Callable[ + [AsyncTransport, SniffOptions], Awaitable[List[NodeConfig]] +] +_TYPE_SYNC_SNIFF_CALLBACK = Callable[[Transport, SniffOptions], List[NodeConfig]] + def client_node_configs( hosts: _TYPE_HOSTS, cloud_id: str, **kwargs: Any diff --git a/elasticsearch/transport.py b/elasticsearch/transport.py index a8ddd6915..5d84866aa 100644 --- a/elasticsearch/transport.py +++ b/elasticsearch/transport.py @@ -15,9 +15,18 @@ # specific language governing permissions and limitations # under the License. +import warnings from typing import Any, Dict, Optional, Union -from elastic_transport import Transport # noqa: F401 +from elastic_transport import AsyncTransport, Transport # noqa: F401 + +# This file exists for backwards compatibility. +warnings.warn( + "Importing from the 'elasticsearch.transport' module is deprecated. 
" + "Instead import from 'elastic_transport'", + category=DeprecationWarning, + stacklevel=2, +) def get_host_info( @@ -34,6 +43,14 @@ def get_host_info( :arg node_info: node information from `/_cluster/nodes` :arg host: connection information (host, port) extracted from the node info """ + + warnings.warn( + "The 'get_host_info' function is deprecated. Instead " + "use the 'sniff_node_callback' parameter on the client", + category=DeprecationWarning, + stacklevel=2, + ) + # ignore master only nodes if node_info.get("roles", []) == ["master"]: return None diff --git a/noxfile.py b/noxfile.py index 3f93a32b2..719d811b1 100644 --- a/noxfile.py +++ b/noxfile.py @@ -57,8 +57,9 @@ def test(session): @nox.session() def format(session): - session.install("black", "isort", "flynt") + session.install("black", "isort", "flynt", "unasync") + session.run("python", "utils/run-unasync.py") session.run("isort", "--profile=black", *SOURCE_FILES) session.run("flynt", *SOURCE_FILES) session.run("black", "--target-version=py36", *SOURCE_FILES) diff --git a/test_elasticsearch/test_async/test_transport.py b/test_elasticsearch/test_async/test_transport.py index c551e5843..39b810840 100644 --- a/test_elasticsearch/test_async/test_transport.py +++ b/test_elasticsearch/test_async/test_transport.py @@ -19,17 +19,17 @@ from __future__ import unicode_literals import asyncio -import json import re import warnings +from typing import Any, Dict, Optional, Union import pytest from elastic_transport import ApiResponseMeta, BaseAsyncNode, HttpHeaders, NodeConfig from elastic_transport.client_utils import DEFAULT -from mock import patch from elasticsearch import AsyncElasticsearch from elasticsearch.exceptions import ( + ApiError, ConnectionError, ElasticsearchWarning, TransportError, @@ -39,9 +39,6 @@ pytestmark = pytest.mark.asyncio -sniffing_xfail = pytest.mark.xfail(strict=True) - - class DummyNode(BaseAsyncNode): def __init__(self, config: NodeConfig): self.resp_status = config._extras.pop("status", 200) @@ -126,6 +123,45 @@ async def close(self): } }""" +CLUSTER_NODES_MASTER_ONLY = """{ + "_nodes" : { + "total" : 2, + "successful" : 2, + "failed" : 0 + }, + "cluster_name" : "elasticsearch", + "nodes" : { + "SRZpKFZdQguhhvifmN6UVA" : { + "name" : "SRZpKFZa", + "transport_address" : "127.0.0.1:9300", + "host" : "127.0.0.1", + "ip" : "127.0.0.1", + "version" : "5.0.0", + "build_hash" : "253032b", + "roles" : ["master"], + "http" : { + "bound_address" : [ "[fe80::1]:9200", "[::1]:9200", "127.0.0.1:9200" ], + "publish_address" : "somehost.tld/1.1.1.1:123", + "max_content_length_in_bytes" : 104857600 + } + }, + "SRZpKFZdQguhhvifmN6UVB" : { + "name" : "SRZpKFZb", + "transport_address" : "127.0.0.1:9300", + "host" : "127.0.0.1", + "ip" : "127.0.0.1", + "version" : "5.0.0", + "build_hash" : "253032b", + "roles" : [ "master", "data", "ingest" ], + "http" : { + "bound_address" : [ "[fe80::1]:9200", "[::1]:9200", "127.0.0.1:9200" ], + "publish_address" : "somehost.tld/1.1.1.1:124", + "max_content_length_in_bytes" : 104857600 + } + } + } +}""" + class TestTransport: async def test_request_timeout_extracted_from_params_and_passed(self): @@ -303,253 +339,217 @@ async def test_resurrected_connection_will_be_marked_as_live_on_success(self): assert len(client.transport.node_pool.alive_nodes) == 1 assert len(client.transport.node_pool.dead_consecutive_failures) == 1 - @sniffing_xfail - async def test_sniff_will_use_seed_connections(self): - t = AsyncTransport( # noqa: F821 - [{"data": CLUSTER_NODES}], connection_class=DummyNode - ) - 
await t._async_call() - t.set_connections([{"data": "invalid"}]) - - await t.sniff_hosts() - assert 1 == len(t.connection_pool.connections) - assert "http://1.1.1.1:123" == t.get_connection().host - - @sniffing_xfail - async def test_sniff_on_start_fetches_and_uses_nodes_list(self): - t = AsyncTransport( # noqa: F821 - [{"data": CLUSTER_NODES}], - connection_class=DummyNode, + @pytest.mark.parametrize( + ["nodes_info_response", "node_host"], + [(CLUSTER_NODES, "1.1.1.1"), (CLUSTER_NODES_7x_PUBLISH_HOST, "somehost.tld")], + ) + async def test_sniff_will_use_seed_connections( + self, nodes_info_response, node_host + ): + client = AsyncElasticsearch( + [ + NodeConfig( + "http", "localhost", 9200, _extras={"data": nodes_info_response} + ) + ], + node_class=DummyNode, sniff_on_start=True, ) - await t._async_call() - await t.sniffing_task # Need to wait for the sniffing task to complete - assert 1 == len(t.connection_pool.connections) - assert "http://1.1.1.1:123" == t.get_connection().host + # Async sniffing happens in the background. + await client.transport._async_call() + assert client.transport._sniffing_task is not None + await client.transport._sniffing_task + + node_configs = [node.config for node in client.transport.node_pool.all()] + assert len(node_configs) == 2 + assert NodeConfig("http", node_host, 123) in node_configs - @sniffing_xfail async def test_sniff_on_start_ignores_sniff_timeout(self): - t = AsyncTransport( # noqa: F821 - [{"data": CLUSTER_NODES}], - connection_class=DummyNode, + client = AsyncElasticsearch( + [NodeConfig("http", "localhost", 9200, _extras={"data": CLUSTER_NODES})], + node_class=DummyNode, sniff_on_start=True, sniff_timeout=12, + meta_header=False, ) - await t._async_call() - await t.sniffing_task # Need to wait for the sniffing task to complete - assert (("GET", "/_nodes/_all/http"), {"timeout": None}) == t.seed_connections[ - 0 - ].calls[0] + # Async sniffing happens in the background. + await client.transport._async_call() + assert client.transport._sniffing_task is not None + await client.transport._sniffing_task + + node_config = client.transport.node_pool.seed_nodes[0] + calls = client.transport.node_pool.all_nodes[node_config].calls + + assert len(calls) == 1 + assert calls[0] == ( + ("GET", "/_nodes/_all/http"), + { + "body": None, + "headers": {"accept": "application/json"}, + "request_timeout": None, # <-- Should be None instead of 12 + }, + ) - @sniffing_xfail async def test_sniff_uses_sniff_timeout(self): - t = AsyncTransport( # noqa: F821 - [{"data": CLUSTER_NODES}], - connection_class=DummyNode, - sniff_timeout=42, - ) - await t._async_call() - await t.sniff_hosts() - - assert (("GET", "/_nodes/_all/http"), {"timeout": 42}) == t.seed_connections[ - 0 - ].calls[0] - - @sniffing_xfail - async def test_sniff_reuses_connection_instances_if_possible(self): - t = AsyncTransport( # noqa: F821 - [{"data": CLUSTER_NODES}, {"host": "1.1.1.1", "port": 123}], - connection_class=DummyNode, - randomize_hosts=False, - ) - await t._async_call() - connection = t.connection_pool.connections[1] - connection.delay = 3.0 # Add this delay to make the sniffing deterministic. 
- - await t.sniff_hosts() - assert 1 == len(t.connection_pool.connections) - assert connection is t.get_connection() - - @sniffing_xfail - async def test_sniff_on_fail_triggers_sniffing_on_fail(self): - t = AsyncTransport( # noqa: F821 - [{"exception": ConnectionError("abandon ship")}, {"data": CLUSTER_NODES}], - connection_class=DummyNode, - sniff_on_connection_fail=True, - max_retries=0, - randomize_hosts=False, + client = AsyncElasticsearch( + [NodeConfig("http", "localhost", 9200, _extras={"data": CLUSTER_NODES})], + node_class=DummyNode, + sniff_before_requests=True, + sniff_timeout=12, + meta_header=False, ) - await t._async_call() + await client.info() - connection_error = False - try: - await t.perform_request("GET", "/") - except ConnectionError: - connection_error = True - - await t.sniffing_task # Need to wait for the sniffing task to complete - - assert connection_error - assert 1 == len(t.connection_pool.connections) - assert "http://1.1.1.1:123" == t.get_connection().host - - @sniffing_xfail - @patch("elasticsearch._async.transport.AsyncTransport.sniff_hosts") - async def test_sniff_on_fail_failing_does_not_prevent_retires(self, sniff_hosts): - sniff_hosts.side_effect = [TransportError("sniff failed")] - t = AsyncTransport( # noqa: F821 - [{"exception": ConnectionError("abandon ship")}, {"data": CLUSTER_NODES}], - connection_class=DummyNode, - sniff_on_connection_fail=True, - max_retries=3, - randomize_hosts=False, + # Async sniffing happens in the background. + assert client.transport._sniffing_task is not None + await client.transport._sniffing_task + + node_config = client.transport.node_pool.seed_nodes[0] + calls = client.transport.node_pool.all_nodes[node_config].calls + + assert len(calls) == 2 + assert calls[0] == ( + ("GET", "/"), + { + "body": None, + "headers": {"content-type": "application/json"}, + "request_timeout": DEFAULT, + }, ) - await t._async_init() - - conn_err, conn_data = t.connection_pool.connections - response = await t.perform_request("GET", "/") - assert json.loads(CLUSTER_NODES) == response - assert 1 == sniff_hosts.call_count - assert 1 == len(conn_err.calls) - assert 1 == len(conn_data.calls) - - @sniffing_xfail - async def test_sniff_after_n_seconds(self, event_loop): - t = AsyncTransport( # noqa: F821 - [{"data": CLUSTER_NODES}], - connection_class=DummyNode, - sniffer_timeout=5, + assert calls[1] == ( + ("GET", "/_nodes/_all/http"), + { + "body": None, + "headers": {"accept": "application/json"}, + "request_timeout": 12, + }, ) - await t._async_call() - for _ in range(4): - await t.perform_request("GET", "/") - assert 1 == len(t.connection_pool.connections) - assert isinstance(t.get_connection(), DummyNode) - t.last_sniff = event_loop.time() - 5.1 - - await t.perform_request("GET", "/") - await t.sniffing_task # Need to wait for the sniffing task to complete - - assert 1 == len(t.connection_pool.connections) - assert "http://1.1.1.1:123" == t.get_connection().host - assert event_loop.time() - 1 < t.last_sniff < event_loop.time() + 0.01 - - @sniffing_xfail - async def test_sniff_7x_publish_host(self): - # Test the response shaped when a 7.x node has publish_host set - # and the returend data is shaped in the fqdn/ip:port format. - t = AsyncTransport( # noqa: F821 - [{"data": CLUSTER_NODES_7x_PUBLISH_HOST}], - connection_class=DummyNode, - sniff_timeout=42, - ) - await t._async_call() - await t.sniff_hosts() - # Ensure we parsed out the fqdn and port from the fqdn/ip:port string. 
- assert t.connection_pool.connection_opts[0][1] == { - "host": "somehost.tld", - "port": 123, - } - - @sniffing_xfail - @patch("elasticsearch._async.transport.AsyncTransport.sniff_hosts") - async def test_sniffing_disabled_on_cloud_instances(self, sniff_hosts): - t = AsyncTransport( # noqa: F821 - [{}], + async def test_sniff_on_start_awaits_before_request(self): + client = AsyncElasticsearch( + [NodeConfig("http", "localhost", 9200, _extras={"data": CLUSTER_NODES})], + node_class=DummyNode, sniff_on_start=True, - sniff_on_connection_fail=True, - connection_class=DummyNode, - cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==", + sniff_timeout=12, + meta_header=False, ) - await t._async_call() - assert not t.sniff_on_connection_fail - assert sniff_hosts.call_args is None # Assert not called. - await t.perform_request("GET", "/", body={}) - assert 1 == len(t.get_connection().calls) - assert ("GET", "/", None, b"{}") == t.get_connection().calls[0][0] + await client.info() - @sniffing_xfail - async def test_transport_close_closes_all_pool_connections(self): - t = AsyncTransport([{}], connection_class=DummyNode) # noqa: F821 - await t._async_call() + node_config = client.transport.node_pool.seed_nodes[0] + calls = client.transport.node_pool.all_nodes[node_config].calls - assert not any([conn.closed for conn in t.connection_pool.connections]) - await t.close() - assert all([conn.closed for conn in t.connection_pool.connections]) + assert len(calls) == 2 + # The sniff request happens first. + assert calls[0][0] == ("GET", "/_nodes/_all/http") + assert calls[1][0] == ("GET", "/") - t = AsyncTransport([{}, {}], connection_class=DummyNode) # noqa: F821 - await t._async_call() + async def test_sniff_reuses_node_instances(self): + client = AsyncElasticsearch( + [NodeConfig("http", "1.1.1.1", 123, _extras={"data": CLUSTER_NODES})], + node_class=DummyNode, + sniff_on_start=True, + ) - assert not any([conn.closed for conn in t.connection_pool.connections]) - await t.close() - assert all([conn.closed for conn in t.connection_pool.connections]) + assert len(client.transport.node_pool.all_nodes) == 1 + await client.info() + assert len(client.transport.node_pool.all_nodes) == 1 - @sniffing_xfail - async def test_sniff_on_start_no_viable_hosts(self, event_loop): - t = AsyncTransport( # noqa: F821 + @pytest.mark.parametrize( + "bad_node_extras", + [{"exception": ConnectionError("Abandon ship!")}, {"status": 500}], + ) + async def test_sniff_on_node_failure_triggers(self, bad_node_extras): + client = AsyncElasticsearch( [ - {"data": ""}, - {"data": ""}, - {"data": ""}, + NodeConfig("http", "localhost", 9200, _extras=bad_node_extras), + NodeConfig("http", "localhost", 9201, _extras={"data": CLUSTER_NODES}), ], - connection_class=DummyNode, - sniff_on_start=True, + node_class=DummyNode, + sniff_on_node_failure=True, + randomize_nodes_in_pool=False, + max_retries=0, ) - # If our initial sniffing attempt comes back - # empty then we raise an error. 
- with pytest.raises(TransportError) as e: - await t._async_call() - assert str(e.value) == "TransportError(N/A, 'Unable to sniff hosts.')" + request_failed_in_error = False + try: + await client.info() + except (ConnectionError, ApiError): + request_failed_in_error = True - @sniffing_xfail - async def test_sniff_on_start_waits_for_sniff_to_complete(self, event_loop): - t = AsyncTransport( # noqa: F821 - [ - {"delay": 1, "data": ""}, - {"delay": 1, "data": ""}, - {"delay": 1, "data": CLUSTER_NODES}, - ], - connection_class=DummyNode, - sniff_on_start=True, + assert client.transport._sniffing_task is not None + await client.transport._sniffing_task + + assert request_failed_in_error + assert len(client.transport.node_pool.all_nodes) == 3 + + async def test_sniff_after_n_seconds(self, event_loop): + client = AsyncElasticsearch( # noqa: F821 + [NodeConfig("http", "localhost", 9200, _extras={"data": CLUSTER_NODES})], + node_class=DummyNode, + min_delay_between_sniffing=5, ) + client.transport._last_sniffed_at = event_loop.time() - # Start the timer right before the first task - # and have a bunch of tasks come in immediately. - tasks = [] - start_time = event_loop.time() - for _ in range(5): - tasks.append(event_loop.create_task(t._async_call())) - await asyncio.sleep(0) # Yield to the loop + await client.info() + + for _ in range(4): + await client.info() + await asyncio.sleep(0) - assert t.sniffing_task is not None + assert 1 == len(client.transport.node_pool.all_nodes) - # Tasks streaming in later. - for _ in range(5): - tasks.append(event_loop.create_task(t._async_call())) - await asyncio.sleep(0.1) + client.transport._last_sniffed_at = event_loop.time() - 5.1 - # Now that all the API calls have come in we wait for - # them all to resolve before - await asyncio.gather(*tasks) - end_time = event_loop.time() - duration = end_time - start_time + await client.info() + await client.transport._sniffing_task # Need to wait for the sniffing task to complete + + assert 2 == len(client.transport.node_pool.all_nodes) + assert "http://1.1.1.1:123" in ( + node.base_url for node in client.transport.node_pool.all() + ) + assert ( + event_loop.time() - 1 + < client.transport._last_sniffed_at + < event_loop.time() + 0.01 + ) - # All the tasks blocked on the sniff of each node - # and then resolved immediately after. 
- assert 1 <= duration < 2 + @pytest.mark.parametrize( + "kwargs", + [ + {"sniff_on_start": True}, + {"sniff_on_connection_fail": True}, + {"sniff_on_node_failure": True}, + {"sniff_before_requests": True}, + {"sniffer_timeout": 1}, + {"sniff_timeout": 1}, + ], + ) + async def test_sniffing_disabled_on_elastic_cloud(self, kwargs): + with pytest.raises(ValueError) as e: + AsyncElasticsearch( + cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==", + **kwargs, + ) + + assert ( + str(e.value) + == "Sniffing should not be enabled when connecting to Elastic Cloud" + ) - @sniffing_xfail async def test_sniff_on_start_close_unlocks_async_calls(self, event_loop): - t = AsyncTransport( # noqa: F821 + client = AsyncElasticsearch( # noqa: F821 [ - {"delay": 10, "data": CLUSTER_NODES}, + NodeConfig( + "http", + "localhost", + 9200, + _extras={"delay": 10, "data": CLUSTER_NODES}, + ), ], - connection_class=DummyNode, + node_class=DummyNode, sniff_on_start=True, ) @@ -557,11 +557,11 @@ async def test_sniff_on_start_close_unlocks_async_calls(self, event_loop): tasks = [] start_time = event_loop.time() for _ in range(3): - tasks.append(event_loop.create_task(t._async_call())) + tasks.append(event_loop.create_task(client.info())) await asyncio.sleep(0) # Close the transport while the sniffing task is active! :( - await t.close() + await client.transport.close() # Now we start waiting on all those _async_calls() await asyncio.gather(*tasks) @@ -571,6 +571,89 @@ async def test_sniff_on_start_close_unlocks_async_calls(self, event_loop): # A lot quicker than 10 seconds defined in 'delay' assert duration < 1 + async def test_sniffing_master_only_filtered_by_default(self): + client = AsyncElasticsearch( # noqa: F821 + [ + NodeConfig( + "http", + "localhost", + 9200, + _extras={"data": CLUSTER_NODES_MASTER_ONLY}, + ) + ], + node_class=DummyNode, + sniff_on_start=True, + ) + await client.transport._async_call() + + assert len(client.transport.node_pool.all_nodes) == 2 + + async def test_sniff_node_callback(self): + def sniffed_node_callback( + node_info: Dict[str, Any], node_config: NodeConfig + ) -> Optional[NodeConfig]: + return ( + node_config + if node_info["http"]["publish_address"].endswith(":124") + else None + ) + + client = AsyncElasticsearch( # noqa: F821 + [ + NodeConfig( + "http", + "localhost", + 9200, + _extras={"data": CLUSTER_NODES_MASTER_ONLY}, + ) + ], + node_class=DummyNode, + sniff_on_start=True, + sniffed_node_callback=sniffed_node_callback, + ) + await client.transport._async_call() + + assert len(client.transport.node_pool.all_nodes) == 2 + + ports = {node.config.port for node in client.transport.node_pool.all()} + assert ports == {9200, 124} + + async def test_sniffing_deprecated_host_info_callback(self): + def host_info_callback( + node_info: Dict[str, Any], host: Dict[str, Union[int, str]] + ) -> Dict[str, Any]: + return ( + host if node_info["http"]["publish_address"].endswith(":124") else None + ) + + with warnings.catch_warnings(record=True) as w: + client = AsyncElasticsearch( # noqa: F821 + [ + NodeConfig( + "http", + "localhost", + 9200, + _extras={"data": CLUSTER_NODES_MASTER_ONLY}, + ) + ], + node_class=DummyNode, + sniff_on_start=True, + host_info_callback=host_info_callback, + ) + await client.transport._async_call() + + assert len(w) == 1 + assert w[0].category == DeprecationWarning + assert ( + str(w[0].message) + == "The 'host_info_callback' parameter is deprecated in favor of 
'sniffed_node_callback'" + ) + + assert len(client.transport.node_pool.all_nodes) == 2 + + ports = {node.config.port for node in client.transport.node_pool.all()} + assert ports == {9200, 124} + @pytest.mark.parametrize("headers", [{}, {"X-elastic-product": "BAD HEADER"}]) async def test_unsupported_product_error(headers): diff --git a/test_elasticsearch/test_client/test_deprecated_options.py b/test_elasticsearch/test_client/test_deprecated_options.py new file mode 100644 index 000000000..d7037521d --- /dev/null +++ b/test_elasticsearch/test_client/test_deprecated_options.py @@ -0,0 +1,52 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import warnings + +from elasticsearch import Elasticsearch + + +def test_sniff_on_connection_fail(): + with warnings.catch_warnings(record=True) as w: + client = Elasticsearch("http://localhost:9200", sniff_on_connection_fail=True) + assert client.transport._sniff_on_node_failure is True + assert len(w) == 1 + assert w[0].category == DeprecationWarning + assert str(w[0].message) == ( + "The 'sniff_on_connection_fail' parameter is deprecated in favor of 'sniff_on_node_failure'" + ) + + +def test_sniffer_timeout(): + with warnings.catch_warnings(record=True) as w: + client = Elasticsearch("http://localhost:9200", sniffer_timeout=1) + assert client.transport._min_delay_between_sniffing == 1 + assert len(w) == 1 + assert w[0].category == DeprecationWarning + assert str(w[0].message) == ( + "The 'sniffer_timeout' parameter is deprecated in favor of 'min_delay_between_sniffing'" + ) + + +def test_randomize_hosts(): + with warnings.catch_warnings(record=True) as w: + Elasticsearch("http://localhost:9200", randomize_hosts=True) + assert len(w) == 1 + assert w[0].category == DeprecationWarning + assert str(w[0].message) == ( + "The 'randomize_hosts' parameter is deprecated in favor of 'randomize_nodes_in_pool'" + ) diff --git a/test_elasticsearch/test_transport.py b/test_elasticsearch/test_transport.py index d9d59c795..c8fc1dbc0 100644 --- a/test_elasticsearch/test_transport.py +++ b/test_elasticsearch/test_transport.py @@ -18,15 +18,14 @@ from __future__ import unicode_literals -import json import re import time import warnings +from typing import Any, Dict, Optional, Union import pytest from elastic_transport import ApiResponseMeta, BaseNode, HttpHeaders, NodeConfig from elastic_transport.client_utils import DEFAULT -from mock import patch from elasticsearch import Elasticsearch from elasticsearch.exceptions import ( @@ -37,8 +36,6 @@ ) from elasticsearch.transport import get_host_info -sniffing_xfail = pytest.mark.xfail(strict=True) - class DummyNode(BaseNode): def __init__(self, config: NodeConfig): @@ -118,6 +115,45 @@ def perform_request(self, *args, **kwargs): } }""" +CLUSTER_NODES_MASTER_ONLY = """{ + "_nodes" : { + "total" : 
2, + "successful" : 2, + "failed" : 0 + }, + "cluster_name" : "elasticsearch", + "nodes" : { + "SRZpKFZdQguhhvifmN6UVA" : { + "name" : "SRZpKFZa", + "transport_address" : "127.0.0.1:9300", + "host" : "127.0.0.1", + "ip" : "127.0.0.1", + "version" : "5.0.0", + "build_hash" : "253032b", + "roles" : ["master"], + "http" : { + "bound_address" : [ "[fe80::1]:9200", "[::1]:9200", "127.0.0.1:9200" ], + "publish_address" : "somehost.tld/1.1.1.1:123", + "max_content_length_in_bytes" : 104857600 + } + }, + "SRZpKFZdQguhhvifmN6UVB" : { + "name" : "SRZpKFZb", + "transport_address" : "127.0.0.1:9300", + "host" : "127.0.0.1", + "ip" : "127.0.0.1", + "version" : "5.0.0", + "build_hash" : "253032b", + "roles" : [ "master", "data", "ingest" ], + "http" : { + "bound_address" : [ "[fe80::1]:9200", "[::1]:9200", "127.0.0.1:9200" ], + "publish_address" : "somehost.tld/1.1.1.1:124", + "max_content_length_in_bytes" : 104857600 + } + } + } +}""" + class TestHostsInfoCallback: def test_master_only_nodes_are_ignored(self): @@ -313,145 +349,215 @@ def test_resurrected_connection_will_be_marked_as_live_on_success(self): assert len(client.transport.node_pool.alive_nodes) == 1 assert len(client.transport.node_pool.dead_consecutive_failures) == 1 - @sniffing_xfail - def test_sniff_will_use_seed_connections(self): - t = Transport( # noqa: F821 - [{"data": CLUSTER_NODES}], connection_class=DummyNode - ) - t.set_connections([{"data": "invalid"}]) - - t.sniff_hosts() - assert 1 == len(t.connection_pool.connections) - assert "http://1.1.1.1:123" == t.get_connection().host - - @sniffing_xfail - def test_sniff_on_start_fetches_and_uses_nodes_list(self): - t = Transport( # noqa: F821 - [{"data": CLUSTER_NODES}], - connection_class=DummyNode, + @pytest.mark.parametrize( + ["nodes_info_response", "node_host"], + [(CLUSTER_NODES, "1.1.1.1"), (CLUSTER_NODES_7x_PUBLISH_HOST, "somehost.tld")], + ) + def test_sniff_will_use_seed_connections(self, nodes_info_response, node_host): + client = Elasticsearch( + [ + NodeConfig( + "http", "localhost", 9200, _extras={"data": nodes_info_response} + ) + ], + node_class=DummyNode, sniff_on_start=True, ) - assert 1 == len(t.connection_pool.connections) - assert "http://1.1.1.1:123" == t.get_connection().host - @sniffing_xfail + node_configs = [node.config for node in client.transport.node_pool.all()] + assert len(node_configs) == 2 + assert NodeConfig("http", node_host, 123) in node_configs + def test_sniff_on_start_ignores_sniff_timeout(self): - t = Transport( # noqa: F821 - [{"data": CLUSTER_NODES}], - connection_class=DummyNode, + client = Elasticsearch( + [NodeConfig("http", "localhost", 9200, _extras={"data": CLUSTER_NODES})], + node_class=DummyNode, sniff_on_start=True, sniff_timeout=12, + meta_header=False, + ) + + node_config = client.transport.node_pool.seed_nodes[0] + calls = client.transport.node_pool.all_nodes[node_config].calls + + assert len(calls) == 1 + assert calls[0] == ( + ("GET", "/_nodes/_all/http"), + { + "body": None, + "headers": {"accept": "application/json"}, + "request_timeout": None, # <-- Should be None instead of 12 + }, ) - assert (("GET", "/_nodes/_all/http"), {"timeout": None}) == t.seed_connections[ - 0 - ].calls[0] - @sniffing_xfail def test_sniff_uses_sniff_timeout(self): - t = Transport( # noqa: F821 - [{"data": CLUSTER_NODES}], - connection_class=DummyNode, - sniff_timeout=42, + client = Elasticsearch( + [NodeConfig("http", "localhost", 9200, _extras={"data": CLUSTER_NODES})], + node_class=DummyNode, + sniff_before_requests=True, + sniff_timeout=12, + 
meta_header=False, ) - t.sniff_hosts() - assert (("GET", "/_nodes/_all/http"), {"timeout": 42}) == t.seed_connections[ - 0 - ].calls[0] - - @sniffing_xfail - def test_sniff_reuses_connection_instances_if_possible(self): - t = Transport( # noqa: F821 - [{"data": CLUSTER_NODES}, {"host": "1.1.1.1", "port": 123}], - connection_class=DummyNode, - randomize_hosts=False, + client.info() + + node_config = client.transport.node_pool.seed_nodes[0] + calls = client.transport.node_pool.all_nodes[node_config].calls + + assert len(calls) == 2 + assert calls[0] == ( + ("GET", "/_nodes/_all/http"), + { + "body": None, + "headers": {"accept": "application/json"}, + "request_timeout": 12, + }, ) - connection = t.connection_pool.connections[1] - - t.sniff_hosts() - assert 1 == len(t.connection_pool.connections) - assert connection is t.get_connection() - - @sniffing_xfail - def test_sniff_on_fail_triggers_sniffing_on_fail(self): - t = Transport( # noqa: F821 - [{"exception": ConnectionError("abandon ship")}, {"data": CLUSTER_NODES}], - connection_class=DummyNode, - sniff_on_connection_fail=True, - max_retries=0, - randomize_hosts=False, + assert calls[1] == ( + ("GET", "/"), + { + "body": None, + "headers": {"content-type": "application/json"}, + "request_timeout": DEFAULT, + }, ) - with pytest.raises(ConnectionError): - t.perform_request("GET", "/") - assert 1 == len(t.connection_pool.connections) - assert "http://1.1.1.1:123" == t.get_connection().host - - @sniffing_xfail - @patch("elasticsearch.transport.Transport.sniff_hosts") - def test_sniff_on_fail_failing_does_not_prevent_retires(self, sniff_hosts): - sniff_hosts.side_effect = [TransportError("sniff failed")] - t = Transport( # noqa: F821 - [{"exception": ConnectionError("abandon ship")}, {"data": CLUSTER_NODES}], - connection_class=DummyNode, - sniff_on_connection_fail=True, - max_retries=3, - randomize_hosts=False, + def test_sniff_reuses_node_instances(self): + client = Elasticsearch( + [NodeConfig("http", "1.1.1.1", 123, _extras={"data": CLUSTER_NODES})], + node_class=DummyNode, + sniff_on_start=True, ) - conn_err, conn_data = t.connection_pool.connections - response = t.perform_request("GET", "/") - assert json.loads(CLUSTER_NODES) == response - assert 1 == sniff_hosts.call_count - assert 1 == len(conn_err.calls) - assert 1 == len(conn_data.calls) + assert len(client.transport.node_pool.all_nodes) == 1 + client.info() + assert len(client.transport.node_pool.all_nodes) == 1 - @sniffing_xfail def test_sniff_after_n_seconds(self): - t = Transport( # noqa: F821 - [{"data": CLUSTER_NODES}], - connection_class=DummyNode, - sniffer_timeout=5, + client = Elasticsearch( # noqa: F821 + [NodeConfig("http", "localhost", 9200, _extras={"data": CLUSTER_NODES})], + node_class=DummyNode, + min_delay_between_sniffing=5, ) + client.transport._last_sniffed_at = time.time() + client.info() for _ in range(4): - t.perform_request("GET", "/") - assert 1 == len(t.connection_pool.connections) - assert isinstance(t.get_connection(), DummyNode) - t.last_sniff = time.time() - 5.1 - - t.perform_request("GET", "/") - assert 1 == len(t.connection_pool.connections) - assert "http://1.1.1.1:123" == t.get_connection().host - assert time.time() - 1 < t.last_sniff < time.time() + 0.01 - - @sniffing_xfail - def test_sniff_7x_publish_host(self): - # Test the response shaped when a 7.x node has publish_host set - # and the returend data is shaped in the fqdn/ip:port format. 
- t = Transport( # noqa: F821 - [{"data": CLUSTER_NODES_7x_PUBLISH_HOST}], - connection_class=DummyNode, - sniff_timeout=42, + client.info() + + assert 1 == len(client.transport.node_pool.all_nodes) + + client.transport._last_sniffed_at = time.time() - 5.1 + + client.info() + + assert 2 == len(client.transport.node_pool.all_nodes) + assert "http://1.1.1.1:123" in ( + node.base_url for node in client.transport.node_pool.all() ) - t.sniff_hosts() - # Ensure we parsed out the fqdn and port from the fqdn/ip:port string. - assert t.connection_pool.connection_opts[0][1] == { - "host": "somehost.tld", - "port": 123, - } - - @sniffing_xfail - @patch("elasticsearch.transport.Transport.sniff_hosts") - def test_sniffing_disabled_on_cloud_instances(self, sniff_hosts): - t = Transport( # noqa: F821 - [{}], + assert time.time() - 1 < client.transport._last_sniffed_at < time.time() + 0.01 + + @pytest.mark.parametrize( + "kwargs", + [ + {"sniff_on_start": True}, + {"sniff_on_connection_fail": True}, + {"sniff_on_node_failure": True}, + {"sniff_before_requests": True}, + {"sniffer_timeout": 1}, + {"sniff_timeout": 1}, + ], + ) + def test_sniffing_disabled_on_elastic_cloud(self, kwargs): + with pytest.raises(ValueError) as e: + Elasticsearch( + cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==", + **kwargs, + ) + + assert ( + str(e.value) + == "Sniffing should not be enabled when connecting to Elastic Cloud" + ) + + def test_sniffing_master_only_filtered_by_default(self): + client = Elasticsearch( # noqa: F821 + [ + NodeConfig( + "http", + "localhost", + 9200, + _extras={"data": CLUSTER_NODES_MASTER_ONLY}, + ) + ], + node_class=DummyNode, + sniff_on_start=True, + ) + + assert len(client.transport.node_pool.all_nodes) == 2 + + def test_sniff_node_callback(self): + def sniffed_node_callback( + node_info: Dict[str, Any], node_config: NodeConfig + ) -> Optional[NodeConfig]: + return ( + node_config + if node_info["http"]["publish_address"].endswith(":124") + else None + ) + + client = Elasticsearch( # noqa: F821 + [ + NodeConfig( + "http", + "localhost", + 9200, + _extras={"data": CLUSTER_NODES_MASTER_ONLY}, + ) + ], + node_class=DummyNode, sniff_on_start=True, - sniff_on_connection_fail=True, - cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==", + sniffed_node_callback=sniffed_node_callback, ) - assert not t.sniff_on_connection_fail - assert sniff_hosts.call_args is None # Assert not called. 
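Review note: for anyone reading these rewritten tests as usage documentation, the new hook exercised here looks like the following sketch; the URL and the role filter are illustrative, and returning None from the callback excludes that node from the pool:

    from typing import Any, Dict, Optional

    from elastic_transport import NodeConfig

    from elasticsearch import Elasticsearch

    def only_data_nodes(
        node_info: Dict[str, Any], node_config: NodeConfig
    ) -> Optional[NodeConfig]:
        # Keep a sniffed node only if it advertises the 'data' role.
        return node_config if "data" in node_info.get("roles", ()) else None

    client = Elasticsearch(
        "http://localhost:9200",
        sniff_on_start=True,
        sniffed_node_callback=only_data_nodes,
    )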
+ assert len(client.transport.node_pool.all_nodes) == 2 + + ports = {node.config.port for node in client.transport.node_pool.all()} + assert ports == {9200, 124} + + def test_sniffing_deprecated_host_info_callback(self): + def host_info_callback( + node_info: Dict[str, Any], host: Dict[str, Union[int, str]] + ) -> Dict[str, Any]: + return ( + host if node_info["http"]["publish_address"].endswith(":124") else None + ) + + with warnings.catch_warnings(record=True) as w: + client = Elasticsearch( # noqa: F821 + [ + NodeConfig( + "http", + "localhost", + 9200, + _extras={"data": CLUSTER_NODES_MASTER_ONLY}, + ) + ], + node_class=DummyNode, + sniff_on_start=True, + host_info_callback=host_info_callback, + ) + + assert len(w) == 1 + assert w[0].category == DeprecationWarning + assert ( + str(w[0].message) + == "The 'host_info_callback' parameter is deprecated in favor of 'sniffed_node_callback'" + ) + + assert len(client.transport.node_pool.all_nodes) == 2 + + ports = {node.config.port for node in client.transport.node_pool.all()} + assert ports == {9200, 124} @pytest.mark.parametrize("headers", [{}, {"X-elastic-product": "BAD HEADER"}]) diff --git a/utils/generate-api.py b/utils/generate-api.py index 0275a5101..380d99f29 100644 --- a/utils/generate-api.py +++ b/utils/generate-api.py @@ -31,7 +31,6 @@ from pathlib import Path import black -import unasync import urllib3 from click.testing import CliRunner from jinja2 import Environment, FileSystemLoader, TemplateNotFound @@ -392,36 +391,7 @@ def dump_modules(modules): for mod in modules.values(): mod.dump() - # Unasync all the generated async code - additional_replacements = { - # We want to rewrite to 'Transport' instead of 'SyncTransport', etc - "AsyncTransport": "Transport", - "AsyncElasticsearch": "Elasticsearch", - # We don't want to rewrite this class - "AsyncSearchClient": "AsyncSearchClient", - } - rules = [ - unasync.Rule( - fromdir="/elasticsearch/_async/client/", - todir="/elasticsearch/_sync/client/", - additional_replacements=additional_replacements, - ), - ] - - filepaths = [] - for root, _, filenames in os.walk(CODE_ROOT / "elasticsearch/_async"): - for filename in filenames: - if ( - filename.rpartition(".")[-1] - in ( - "py", - "pyi", - ) - and not filename.startswith("utils.py") - ): - filepaths.append(os.path.join(root, filename)) - - unasync.unasync_files(filepaths, rules) + os.system("python utils/run-unasync.py") blacken(CODE_ROOT / "elasticsearch") diff --git a/utils/run-unasync.py b/utils/run-unasync.py new file mode 100644 index 000000000..bc0c380f2 --- /dev/null +++ b/utils/run-unasync.py @@ -0,0 +1,62 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
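Review note: the script below becomes the single place the async-to-sync rewrite happens; both 'nox -s format' and 'utils/generate-api.py' now shell out to it. Conceptually, unasync rewrites each async construct under 'elasticsearch/_async/client/' into its sync twin, roughly like this (illustrative input and output, not generated verbatim):

    # elasticsearch/_async/client/... (input)
    async def sniff_callback(transport, sniff_options):
        meta, node_infos = await transport.perform_request("GET", "/_nodes/_all/http")

    # elasticsearch/_sync/client/... (output)
    def sniff_callback(transport, sniff_options):
        meta, node_infos = transport.perform_request("GET", "/_nodes/_all/http")

The extra '_TYPE_ASYNC_SNIFF_CALLBACK' -> '_TYPE_SYNC_SNIFF_CALLBACK' replacement below is needed because unasync substitutes tokens one-for-one and cannot unwrap the 'Awaitable[...]' inside the async alias on its own.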
+ +import os +from pathlib import Path + +import unasync + + +def main(): + # Unasync all the generated async code + additional_replacements = { + # We want to rewrite to 'Transport' instead of 'SyncTransport', etc + "AsyncTransport": "Transport", + "AsyncElasticsearch": "Elasticsearch", + # We don't want to rewrite this class + "AsyncSearchClient": "AsyncSearchClient", + # Handling typing.Awaitable[...] isn't done yet by unasync. + "_TYPE_ASYNC_SNIFF_CALLBACK": "_TYPE_SYNC_SNIFF_CALLBACK", + } + rules = [ + unasync.Rule( + fromdir="/elasticsearch/_async/client/", + todir="/elasticsearch/_sync/client/", + additional_replacements=additional_replacements, + ), + ] + + filepaths = [] + for root, _, filenames in os.walk( + Path(__file__).absolute().parent.parent / "elasticsearch/_async" + ): + for filename in filenames: + if ( + filename.rpartition(".")[-1] + in ( + "py", + "pyi", + ) + and not filename.startswith("utils.py") + ): + filepaths.append(os.path.join(root, filename)) + + unasync.unasync_files(filepaths, rules) + + +if __name__ == "__main__": + main()
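Review note: taken together, the compatibility shims in this diff map each 7.x sniffing option onto its 8.x spelling: 'randomize_hosts' -> 'randomize_nodes_in_pool', 'sniffer_timeout' -> 'min_delay_between_sniffing', 'sniff_on_connection_fail' -> 'sniff_on_node_failure', and 'host_info_callback' -> 'sniffed_node_callback'. A quick sketch of what callers see during the transition (the URL is illustrative):

    import warnings

    from elasticsearch import Elasticsearch

    # 7.x spelling: still accepted, but emits a DeprecationWarning.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        Elasticsearch("http://localhost:9200", sniffer_timeout=5)
    assert any(w.category is DeprecationWarning for w in caught)

    # 8.x spelling: equivalent, and warning-free.
    Elasticsearch("http://localhost:9200", min_delay_between_sniffing=5)

Passing both the old and the new spelling of the same option raises a ValueError instead of silently picking one.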