Skip to content

Add support for elastic-transport sniffing #1771

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions elasticsearch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@

from ._async.client import AsyncElasticsearch
from ._sync.client import Elasticsearch
from .exceptions import ElasticsearchDeprecationWarning # noqa: F401
from .exceptions import (
ApiError,
AuthenticationException,
AuthorizationException,
ConflictError,
ConnectionError,
ConnectionTimeout,
ElasticsearchDeprecationWarning,
ElasticsearchException,
ElasticsearchWarning,
NotFoundError,
Expand Down Expand Up @@ -73,5 +73,4 @@
"AuthorizationException",
"UnsupportedProductError",
"ElasticsearchWarning",
"ElasticsearchDeprecationWarning",
]
110 changes: 107 additions & 3 deletions elasticsearch/_async/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,19 @@

import logging
import warnings
from typing import Optional
from typing import Any, Callable, Dict, Optional, Union

from elastic_transport import AsyncTransport, TransportError
from elastic_transport import AsyncTransport, NodeConfig, TransportError
from elastic_transport.client_utils import DEFAULT

from ...exceptions import NotFoundError
from ...serializer import DEFAULT_SERIALIZERS
from ._base import BaseClient, resolve_auth_headers
from ._base import (
BaseClient,
create_sniff_callback,
default_sniff_callback,
resolve_auth_headers,
)
from .async_search import AsyncSearchClient
from .autoscaling import AutoscalingClient
from .cat import CatClient
Expand Down Expand Up @@ -148,9 +153,21 @@ def __init__(
sniff_on_node_failure=DEFAULT,
sniff_timeout=DEFAULT,
min_delay_between_sniffing=DEFAULT,
sniffed_node_callback: Optional[
Callable[[Dict[str, Any], NodeConfig], Optional[NodeConfig]]
] = None,
meta_header=DEFAULT,
# Deprecated
timeout=DEFAULT,
randomize_hosts=DEFAULT,
host_info_callback: Optional[
Callable[
[Dict[str, Any], Dict[str, Union[str, int]]],
Optional[Dict[str, Union[str, int]]],
]
] = None,
sniffer_timeout=DEFAULT,
sniff_on_connection_fail=DEFAULT,
# Internal use only
_transport: Optional[AsyncTransport] = None,
) -> None:
Expand All @@ -170,6 +187,92 @@ def __init__(
)
request_timeout = timeout

if randomize_hosts is not DEFAULT:
if randomize_nodes_in_pool is not DEFAULT:
raise ValueError(
"Can't specify both 'randomize_hosts' and 'randomize_nodes_in_pool', "
"instead only specify 'randomize_nodes_in_pool'"
)
warnings.warn(
"The 'randomize_hosts' parameter is deprecated in favor of 'randomize_nodes_in_pool'",
category=DeprecationWarning,
stacklevel=2,
)
randomize_nodes_in_pool = randomize_hosts

if sniffer_timeout is not DEFAULT:
if min_delay_between_sniffing is not DEFAULT:
raise ValueError(
"Can't specify both 'sniffer_timeout' and 'min_delay_between_sniffing', "
"instead only specify 'min_delay_between_sniffing'"
)
warnings.warn(
"The 'sniffer_timeout' parameter is deprecated in favor of 'min_delay_between_sniffing'",
category=DeprecationWarning,
stacklevel=2,
)
min_delay_between_sniffing = sniffer_timeout

if sniff_on_connection_fail is not DEFAULT:
if sniff_on_node_failure is not DEFAULT:
raise ValueError(
"Can't specify both 'sniff_on_connection_fail' and 'sniff_on_node_failure', "
"instead only specify 'sniff_on_node_failure'"
)
warnings.warn(
"The 'sniff_on_connection_fail' parameter is deprecated in favor of 'sniff_on_node_failure'",
category=DeprecationWarning,
stacklevel=2,
)
sniff_on_node_failure = sniff_on_connection_fail

# Explicitly setting 'min_delay_between_sniffing' (to any value) implies sniff_before_requests=True
if min_delay_between_sniffing is not DEFAULT:
sniff_before_requests = True

sniffing_options = (
sniff_timeout,
sniff_on_start,
sniff_before_requests,
sniff_on_node_failure,
sniffed_node_callback,
min_delay_between_sniffing,
sniffed_node_callback,
)
if cloud_id is not None and any(
x is not DEFAULT and x is not None for x in sniffing_options
):
raise ValueError(
"Sniffing should not be enabled when connecting to Elastic Cloud"
)

sniff_callback = None
if host_info_callback is not None:
if sniffed_node_callback is not None:
raise ValueError(
"Can't specify both 'host_info_callback' and 'sniffed_node_callback', "
"instead only specify 'sniffed_node_callback'"
)
warnings.warn(
"The 'host_info_callback' parameter is deprecated in favor of 'sniffed_node_callback'",
category=DeprecationWarning,
stacklevel=2,
)

sniff_callback = create_sniff_callback(
host_info_callback=host_info_callback
)
elif sniffed_node_callback is not None:
sniff_callback = create_sniff_callback(
sniffed_node_callback=sniffed_node_callback
)
elif (
sniff_on_start is True
or sniff_before_requests is True
or sniff_on_node_failure is True
):
sniff_callback = default_sniff_callback

if _transport is None:
node_configs = client_node_configs(
hosts,
Expand Down Expand Up @@ -222,6 +325,7 @@ def __init__(
_transport = transport_class(
node_configs,
client_meta_service=CLIENT_META_SERVICE,
sniff_callback=sniff_callback,
**transport_kwargs,
)

Expand Down
116 changes: 113 additions & 3 deletions elasticsearch/_async/client/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,33 @@

import re
import warnings
from typing import Any, Collection, Iterable, Mapping, Optional, Tuple, TypeVar, Union
from typing import (
Any,
Callable,
Collection,
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
TypeVar,
Union,
)

from elastic_transport import AsyncTransport, HttpHeaders
from elastic_transport import AsyncTransport, HttpHeaders, NodeConfig, SniffOptions
from elastic_transport.client_utils import DEFAULT, DefaultType, resolve_default

from ...compat import urlencode, warn_stacklevel
from ...exceptions import (
HTTP_EXCEPTIONS,
ApiError,
ConnectionError,
ElasticsearchWarning,
SerializationError,
UnsupportedProductError,
)
from .utils import _base64_auth_header
from .utils import _TYPE_ASYNC_SNIFF_CALLBACK, _base64_auth_header

SelfType = TypeVar("SelfType", bound="BaseClient")
SelfNamespacedType = TypeVar("SelfNamespacedType", bound="NamespacedClient")
Expand Down Expand Up @@ -83,6 +97,102 @@ def resolve_auth_headers(
return headers


def create_sniff_callback(
    host_info_callback: Optional[
        Callable[[Dict[str, Any], Dict[str, Any]], Optional[Dict[str, Any]]]
    ] = None,
    sniffed_node_callback: Optional[
        Callable[[Dict[str, Any], NodeConfig], Optional[NodeConfig]]
    ] = None,
) -> _TYPE_ASYNC_SNIFF_CALLBACK:
    """Build the async sniff callback that is handed to the transport.

    Exactly one of the two callbacks must be supplied.

    :arg host_info_callback: Deprecated filter. It is called with the node's
        info and a ``{"host": ..., "port": ...}`` dict. Returning ``None``
        drops the node; any other return value keeps the node unchanged.
    :arg sniffed_node_callback: Called with the node's info and the candidate
        ``NodeConfig``. Returns the ``NodeConfig`` to use, or ``None`` to
        drop the node.
    :returns: A coroutine function taking ``(transport, sniff_options)`` and
        returning the list of sniffed node configs, or an empty list when no
        attempt yielded a usable node.
    """
    # Internal invariant: callers pass exactly one of the two callbacks.
    assert (host_info_callback is None) != (sniffed_node_callback is None)

    # Wrap the deprecated 'host_info_callback' into 'sniffed_node_callback'
    if host_info_callback is not None:

        def _sniffed_node_callback(
            node_info: Dict[str, Any], node_config: NodeConfig
        ) -> Optional[NodeConfig]:
            # Presumably repeated so type checkers narrow the captured
            # variable inside this closure.
            assert host_info_callback is not None
            host_info = host_info_callback(
                node_info, {"host": node_config.host, "port": node_config.port}
            )
            # Only a None result matters (drop the node); the contents of a
            # returned dict are not applied to the node config.
            if host_info is None:
                return None
            return node_config

        sniffed_node_callback = _sniffed_node_callback

    async def sniff_callback(
        transport: AsyncTransport, sniff_options: SniffOptions
    ) -> List[NodeConfig]:
        # Make at most one attempt per node currently in the pool.
        for _ in transport.node_pool.all():
            try:
                meta, node_infos = await transport.perform_request(
                    "GET",
                    "/_nodes/_all/http",
                    headers={"accept": "application/json"},
                    # 'sniff_timeout' is only applied to non-initial sniffs.
                    request_timeout=(
                        sniff_options.sniff_timeout
                        if not sniff_options.is_initial_sniff
                        else None
                    ),
                )
            except (SerializationError, ConnectionError):
                continue

            if not 200 <= meta.status <= 299:
                continue

            node_configs = []
            for node_info in node_infos.get("nodes", {}).values():
                address = node_info.get("http", {}).get("publish_address")
                if not address or ":" not in address:
                    continue

                if "/" in address:
                    # Support 7.x host/ip:port behavior where http.publish_host has been set.
                    fqdn, address_with_port = address.split("/", 1)
                else:
                    fqdn, address_with_port = None, address

                try:
                    ip_or_host, port_str = address_with_port.rsplit(":", 1)
                    port = int(port_str)
                except ValueError:
                    # Malformed address/port: skip this node like any other
                    # unusable address instead of aborting the whole sniff.
                    continue
                host = fqdn if fqdn is not None else ip_or_host

                assert sniffed_node_callback is not None
                # Use the node which was able to make the request as a base.
                sniffed_node = sniffed_node_callback(
                    node_info, meta.node.replace(host=host, port=port)
                )
                if sniffed_node is None:
                    continue

                node_configs.append(sniffed_node)

            # The first attempt that yields at least one node wins.
            if node_configs:
                return node_configs

        return []

    return sniff_callback


def _default_sniffed_node_callback(
    node_info: Dict[str, Any], node_config: NodeConfig
) -> Optional[NodeConfig]:
    """Default node filter for sniffing: drop master-only nodes.

    :arg node_info: A node's entry from the nodes-info response.
    :arg node_config: The candidate config built for that node.
    :returns: ``None`` when the node's ``roles`` list is exactly
        ``["master"]``; otherwise ``node_config`` unchanged.
    """
    # A missing 'roles' key is treated as an empty list, so the node is kept.
    if node_info.get("roles", []) == ["master"]:
        return None
    return node_config


# Sniff callback used by the client when sniffing is enabled but no
# 'sniffed_node_callback' / 'host_info_callback' was supplied; it filters
# sniffed nodes through '_default_sniffed_node_callback'.
default_sniff_callback = create_sniff_callback(
sniffed_node_callback=_default_sniffed_node_callback
)


class BaseClient:
def __init__(self, _transport: AsyncTransport) -> None:
self._transport = _transport
Expand Down
2 changes: 2 additions & 0 deletions elasticsearch/_async/client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

from ..._sync.client.utils import (
_TYPE_ASYNC_SNIFF_CALLBACK,
_TYPE_HOSTS,
CLIENT_META_SERVICE,
SKIP_IN_PATH,
Expand All @@ -30,6 +31,7 @@

__all__ = [
"CLIENT_META_SERVICE",
"_TYPE_ASYNC_SNIFF_CALLBACK",
"_deprecated_options",
"_TYPE_HOSTS",
"SKIP_IN_PATH",
Expand Down
Loading